Chapter 5

Time Conditioning

U-Net Architecture for Diffusion

Learning Objectives

By the end of this section, you will:

  1. Understand why time conditioning is essential for diffusion models
  2. Implement sinusoidal positional embeddings for timestep encoding
  3. Learn the mathematical properties that make sinusoidal embeddings effective
  4. Build the complete time embedding pipeline with MLP projection
  5. Compare injection methods: addition vs FiLM vs adaptive normalization

Why This Matters

Time conditioning is what allows a single neural network to denoise images at all noise levels. Without it, the network would have to infer the noise level from the noisy image alone, or you would need a separate network for each timestep. The sinusoidal embedding gives nearby timesteps similar encodings, which helps the network share what it learns across noise levels, while the injection method determines how the time information modulates the features in each layer.

Why Time Conditioning?

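The denoising task changes dramatically as $t$ varies:

| Timestep | Noise level (approx.) | Task | Network behavior |
|---|---|---|---|
| t = T (1000) | ~100% | Generate from noise | Predict large-scale structure |
| t ≈ T/2 (500) | ~50% | Refine coarse details | Balance structure and texture |
| t ≈ 100 | ~10% | Fine-tune details | Predict high-frequency features |
| t ≈ 1 | ~0.1% | Final cleanup | Remove residual noise artifacts |

The noise levels are rough guides; the exact values depend on the noise schedule.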

A network without time conditioning would have to "guess" the noise level from the image alone—an ambiguous and unreliable task. Time conditioning provides explicit information about where we are in the denoising process.

The Noise Schedule Connection

Recall from Chapter 3 that at timestep $t$, the noisy image follows:

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

The network needs to know $\bar{\alpha}_t$ to correctly scale its prediction. Time conditioning provides this information.
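To make the noise levels concrete, the sketch below computes $\bar{\alpha}_t$ for the linear schedule from the original DDPM paper, with $\beta_t$ rising from $10^{-4}$ to $0.02$ over $T = 1000$ steps. The schedule and the helper name linear_alpha_bar are assumptions for illustration; Chapter 3's schedule may differ.

```python
import torch

def linear_alpha_bar(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for an assumed DDPM-style linear schedule.

    Index 0 corresponds to t = 1 and index T - 1 to t = T.
    """
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float64)
    return torch.cumprod(1.0 - betas, dim=0)

alpha_bar = linear_alpha_bar()
for t in [1, 100, 500, 1000]:
    a = alpha_bar[t - 1].item()
    # Signal weight sqrt(alpha_bar_t) and noise weight sqrt(1 - alpha_bar_t)
    print(f"t={t:4d}  alpha_bar={a:.5f}  signal={a ** 0.5:.3f}  noise={(1 - a) ** 0.5:.3f}")
```

Under this linear schedule, $\bar{\alpha}_t$ is already below 0.1 at $t = 500$, while the cosine schedule of Nichol & Dhariwal keeps it near 0.5 at the midpoint. The percentages in the table are therefore best read as schedule-dependent approximations.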

Sinusoidal Embeddings

The standard approach for encoding timesteps is sinusoidal positional embedding, borrowed from the Transformer architecture. The key idea: represent a scalar timestep as a high-dimensional vector using sine and cosine functions at different frequencies.

The Formula

For timestep $t$ and embedding dimension $d$, the components at indices $2i$ and $2i+1$ (for $i = 0, 1, \dots, d/2 - 1$) are:

$$\text{PE}(t, 2i) = \sin\left(\frac{t}{10000^{2i/d}}\right)$$

$$\text{PE}(t, 2i+1) = \cos\left(\frac{t}{10000^{2i/d}}\right)$$

The denominator $10000^{2i/d}$ makes the frequencies $1/10000^{2i/d}$ a geometric progression that starts at 1 (for $i = 0$) and decreases toward $1/10000$ as $i$ approaches $d/2$. The corresponding wavelengths range from $2\pi$ to nearly $2\pi \cdot 10000$.
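As a quick check of the formula, take $d = 8$. The exponent $2i/d$ runs over $0, 1/4, 2/4, 3/4$, so the four frequencies $1/10000^{2i/d}$ are $1, 0.1, 0.01, 0.001$. The sketch below evaluates the formula literally, placing sines at even indices and cosines at odd indices. pe_interleaved is a hypothetical helper name.

```python
import math

def pe_interleaved(t: float, d: int, max_period: float = 10000.0) -> list[float]:
    """Literal evaluation of PE(t, 2i) = sin(t / max_period^(2i/d)) and
    PE(t, 2i+1) = cos(t / max_period^(2i/d)) for i = 0, ..., d/2 - 1."""
    pe = []
    for i in range(d // 2):
        freq = 1.0 / max_period ** (2 * i / d)
        pe.append(math.sin(t * freq))  # even index 2i
        pe.append(math.cos(t * freq))  # odd index 2i + 1
    return pe

freqs = [1.0 / 10000.0 ** (2 * i / 8) for i in range(4)]
print(freqs)                 # approximately [1.0, 0.1, 0.01, 0.001]
print(pe_interleaved(0, 8))  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

The implementation below stores the same values in a different order, with all sines first and all cosines second.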

Sinusoidal Time Embedding Implementation
🐍sinusoidal_embedding.py
Import Math and Torch

We use Python's math module for the logarithm and PyTorch for tensor operations.

Function Signature

The function takes a tensor of timesteps (shape [B]) and the desired embedding dimension, which must be even, and returns embeddings of shape [B, dim].

Compute Half Dimension

The embedding is split into a sine half and a cosine half, each of size dim // 2; concatenating them gives the full dimension.

Frequency Calculation

The key insight: different dimensions use different frequencies. Low indices have high frequency and high indices have low frequency, the same idea as in Transformer positional encoding. The code computes exp(-ln(max_period) * i / half_dim), which equals 1 / max_period^(2i/dim).

EXAMPLE
freq_i = 1 / 10000^(2i/d)

Compute Arguments

Multiply timesteps by frequencies. Broadcasting: timesteps [B, 1] * freqs [1, half_dim] = args [B, half_dim].

Apply Sin and Cos

Take the sine and cosine of every argument and concatenate them, so the first half of the embedding holds all the sines and the second half all the cosines. This is a fixed reordering of the interleaved layout in the formula above (sines at even indices, cosines at odd indices), so a learned layer that consumes the embedding receives the same information. Each timestep gets a distinct encoding that varies smoothly with t.
```python
import torch
import math

def sinusoidal_embedding(
    timesteps: torch.Tensor,
    dim: int,
    max_period: float = 10000.0,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings.

    Args:
        timesteps: Tensor of timesteps [B]
        dim: Embedding dimension (must be even)
        max_period: Controls frequency range (default 10000)

    Returns:
        Embeddings of shape [B, dim]
    """
    if dim % 2 != 0:
        raise ValueError(f"dim must be even, got {dim}")
    half_dim = dim // 2

    # Compute frequency for each dimension
    # freqs[i] = 1 / (max_period ^ (2i / dim)) = exp(-ln(max_period) * i / half_dim)
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(half_dim, device=timesteps.device, dtype=torch.float32)
        / half_dim
    )

    # Compute arguments for sin and cos
    # args[b, i] = timesteps[b] * freqs[i]
    args = timesteps[:, None].float() * freqs[None, :]

    # Apply sin and cos and concatenate: [sin(args), cos(args)]
    embedding = torch.cat([
        torch.sin(args),
        torch.cos(args),
    ], dim=-1)

    return embedding


# Example usage
t = torch.tensor([0, 100, 500, 999])  # Batch of timesteps
emb = sinusoidal_embedding(t, dim=256)
print(f"Embedding shape: {emb.shape}")  # torch.Size([4, 256])
```

Embedding Properties

Key Properties

The sinusoidal embedding has several important properties that make it ideal for time conditioning:

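  1. Unique representation: Each timestep produces a distinct embedding vector. For typical embedding sizes such as $d = 128$, the lowest-frequency sine/cosine pair covers only a small fraction of one cycle as $t$ runs over $[0, 1000)$, so no two timesteps in that range collide.
  2. Smooth interpolation: Nearby timesteps have nearby embeddings, and the distance between $\text{PE}(t)$ and $\text{PE}(t+k)$ depends only on the offset $k$. This helps the network treat similar noise levels similarly.
  3. Multi-scale information: Low-frequency components capture coarse time structure; high-frequency components capture fine distinctions.
  4. Bounded values: Every component lies in [-1, 1] regardless of $t$, whereas a raw timestep would range up to 1000; this keeps the input to the MLP on a stable scale.
  5. No learned parameters: The encoding is a fixed function of $t$, so it adds no parameters; all learning happens in the MLP that follows.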
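These properties are easy to check numerically. The sketch below re-implements the section's concatenated [sin, cos] embedding in float64 (to keep rounding error out of the comparisons) with $d = 128$. It checks that values are bounded, that every embedding has norm $\sqrt{d/2} = 8$ because each sine/cosine pair has unit length, and that the distance between $\text{PE}(t)$ and $\text{PE}(t+k)$ depends on $k$ but not on $t$. The helper dist is illustrative.

```python
import math
import torch

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    # Same [sin(args), cos(args)] layout as the section's code, computed in float64
    half_dim = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float64) / half_dim)
    args = timesteps[:, None].to(torch.float64) * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

t = torch.arange(1000)
emb = sinusoidal_embedding(t, dim=128)

# Bounded values: every entry lies in [-1, 1]
print(emb.abs().max().item() <= 1.0)  # True

# Each (sin, cos) pair has unit norm, so every embedding has norm sqrt(64) = 8
print(emb.norm(dim=-1)[:3])

def dist(a: int, b: int) -> float:
    """Euclidean distance between the embeddings of timesteps a and b."""
    return (emb[a] - emb[b]).norm().item()

# Distance depends only on the offset k = b - a, not on a
print(dist(10, 11), dist(700, 701))  # equal up to rounding
# Nearby timesteps are closer than distant ones
print(dist(10, 11) < dist(10, 510))  # True
```

The offset result follows from $\lVert \text{PE}(t+k) - \text{PE}(t) \rVert^2 = \sum_i \left(2 - 2\cos(\omega_i k)\right)$, where $\omega_i$ are the frequencies; $t$ does not appear.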

Relative Position Information

A key property from the original Transformer paper: for any fixed offset $k$, $\text{PE}(t+k)$ can be expressed as a linear function of $\text{PE}(t)$. This makes it plausible for the network to learn relative time relationships, not just absolute positions.
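Concretely, for a single frequency $\omega$ the angle-addition identities give

$$\begin{pmatrix} \sin(\omega(t+k)) \\ \cos(\omega(t+k)) \end{pmatrix} = \begin{pmatrix} \cos(\omega k) & \sin(\omega k) \\ -\sin(\omega k) & \cos(\omega k) \end{pmatrix} \begin{pmatrix} \sin(\omega t) \\ \cos(\omega t) \end{pmatrix}$$

The matrix depends on $k$ but not on $t$, so stacking one such rotation per frequency gives a single matrix $M_k$ with $\text{PE}(t+k) = M_k\,\text{PE}(t)$. The sketch below builds $M_k$ for the concatenated [sin, cos] layout used in this section's code and checks the identity numerically; offset_matrix is a hypothetical helper name.

```python
import math
import torch

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    # Concatenated [sin(args), cos(args)] layout, in float64 for a tight comparison
    half_dim = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float64) / half_dim)
    args = timesteps[:, None].to(torch.float64) * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

def offset_matrix(k: float, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    """Matrix M_k with PE(t + k) = M_k @ PE(t) for the [sin, cos] layout."""
    half_dim = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float64) / half_dim)
    c = torch.diag(torch.cos(k * freqs))
    s = torch.diag(torch.sin(k * freqs))
    # sin(w(t+k)) =  cos(wk) sin(wt) + sin(wk) cos(wt)
    # cos(w(t+k)) = -sin(wk) sin(wt) + cos(wk) cos(wt)
    return torch.cat([torch.cat([c, s], dim=1), torch.cat([-s, c], dim=1)], dim=0)

dim, k = 64, 37
M = offset_matrix(k, dim)
t = torch.tensor([0, 5, 250, 900])
lhs = sinusoidal_embedding(t + k, dim)
rhs = sinusoidal_embedding(t, dim) @ M.T  # apply M_k to each row
print(torch.allclose(lhs, rhs))  # True
```

$M_k$ is a block rotation, so it is orthogonal: shifting time rotates the embedding without changing its length.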

MLP Projection

The raw sinusoidal embedding is passed through an MLP to create the final time embedding. This serves several purposes:

  • Learn task-specific representations: The MLP can transform the generic positional encoding into features useful for denoising
  • Increase capacity: Two linear layers with nonlinearity provide more expressive power
  • Match dimensions: Project to the dimension needed by the ResBlocks
Complete Time Embedding Module
🐍time_embedding.py
Time Embedding Module

The full time embedding pipeline: sinusoidal encoding followed by MLP projection. This is what gets injected into each ResBlock.

Sinusoidal Encoding Dimension

We typically use model_channels as the sinusoidal dimension, then project to 4x for the MLP output.

MLP Architecture

Two linear layers with SiLU activation. This allows the network to learn complex transformations of the time information.

Forward Pass

First compute sinusoidal encoding, then pass through MLP. The output is ready to be injected into ResBlocks.
```python
import torch
import torch.nn as nn
import math
from typing import Optional

class TimeEmbedding(nn.Module):
    """
    Complete time embedding: sinusoidal encoding + MLP projection.

    Used in diffusion U-Net to condition on timestep.
    """

    def __init__(
        self,
        model_channels: int,
        time_embed_dim: Optional[int] = None,
        max_period: float = 10000.0,
    ):
        super().__init__()

        # Time embedding dimension (typically 4x model channels)
        time_embed_dim = time_embed_dim or model_channels * 4

        # model_channels is also the sinusoidal dimension, so it must be even
        self.model_channels = model_channels
        self.max_period = max_period

        # MLP: project sinusoidal to time_embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def sinusoidal_embedding(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Compute sinusoidal embedding for timesteps."""
        half_dim = self.model_channels // 2

        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(half_dim, device=timesteps.device, dtype=torch.float32)
            / half_dim
        )

        args = timesteps[:, None].float() * freqs[None, :]
        embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

        return embedding

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Args:
            timesteps: Integer timesteps [B]

        Returns:
            Time embeddings [B, time_embed_dim]
        """
        # Sinusoidal encoding
        emb = self.sinusoidal_embedding(timesteps)

        # MLP projection
        emb = self.mlp(emb)

        return emb


# Example usage
time_emb_module = TimeEmbedding(model_channels=128)
timesteps = torch.randint(0, 1000, (8,))
time_emb = time_emb_module(timesteps)
print(f"Output embedding shape: {time_emb.shape}")  # torch.Size([8, 512])
```

Dimension Choices

Common choices: if model_channels = 128, then:
  • Sinusoidal dimension: 128 (same as model_channels)
  • Time embedding dimension: 512 (4x model_channels)
  • Per-ResBlock projection: 512 to that block's channel count C (or to 2C for FiLM and AdaGN, which predict both a scale and a shift)
The 4x multiplier follows DDPM's convention and gives the MLP a hidden layer wider than its sinusoidal input, leaving room to learn rich time representations.
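To put the 4x multiplier in perspective, the sketch below counts parameters for model_channels = 128: the shared MLP, plus one projection per ResBlock for addition (D to C) and for FiLM or AdaGN (D to 2C). The channel counts 128, 256, and 512 are hypothetical per-resolution values chosen for illustration.

```python
import torch.nn as nn

model_channels = 128
time_embed_dim = 4 * model_channels  # 512

# Shared MLP: sinusoidal (128) -> 512 -> 512
mlp = nn.Sequential(
    nn.Linear(model_channels, time_embed_dim),
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),
)
mlp_params = sum(p.numel() for p in mlp.parameters())
print(mlp_params)  # 128*512 + 512 + 512*512 + 512 = 328704

# Hypothetical channel counts at three U-Net resolutions
for channels in (128, 256, 512):
    add_proj = nn.Linear(time_embed_dim, channels)       # addition: D -> C
    film_proj = nn.Linear(time_embed_dim, 2 * channels)  # FiLM / AdaGN: D -> 2C
    print(channels,
          sum(p.numel() for p in add_proj.parameters()),
          sum(p.numel() for p in film_proj.parameters()))
```

The MLP runs once per forward pass and its output is shared; each ResBlock adds only its own projection.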

Injection Methods

Once we have the time embedding, how do we incorporate it into each ResBlock? There are three main approaches:

Method 1: Simple Addition

The simplest approach: project time embedding to match channel dimension and add to features:

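```python
import torch
import torch.nn as nn

# Simple addition (used in DDPM)
time_embed_dim, channels = 512, 128  # example sizes
time_proj = nn.Linear(time_embed_dim, channels)

def forward(features: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
    # Project [B, time_embed_dim] -> [B, C], then reshape to [B, C, 1, 1]
    # so the same per-channel offset is added at every spatial position
    time_emb = time_proj(time_emb)[:, :, None, None]
    return features + time_emb
```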

Method 2: FiLM Conditioning

Feature-wise Linear Modulation learns both scale and shift from time:

FiLM: Scale and Shift Conditioning
🐍film_conditioning.py
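FiLM Conditioning

Feature-wise Linear Modulation: learn a per-channel scale (gamma) and shift (beta) from the time embedding, then apply them to the features.

Scale and Shift Projection

Project the time embedding to 2x channels: the first half is the scale (gamma), the second half is the shift (beta).

Apply Modulation

Multiply features by (1 + gamma) and add beta. With this parameterization, gamma = 0 and beta = 0 leave the features unchanged, so the layer starts as the identity if its projection is zero-initialized. nn.Linear is randomly initialized by default, so the module below does not start as the identity unless the projection is zeroed explicitly.

EXAMPLE
output = features * (1 + gamma) + beta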
```python
import torch
import torch.nn as nn

class FiLMConditioner(nn.Module):
    """
    Feature-wise Linear Modulation for time conditioning.

    More expressive than simple addition: learns both
    scale (multiplicative) and shift (additive) modulation.
    """

    def __init__(self, time_embed_dim: int, channels: int):
        super().__init__()

        # Project to 2x channels: [gamma, beta]
        self.proj = nn.Linear(time_embed_dim, channels * 2)

    def forward(
        self, features: torch.Tensor, time_emb: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            features: [B, C, H, W]
            time_emb: [B, time_embed_dim]
        Returns:
            Modulated features [B, C, H, W]
        """
        # Project time embedding
        params = self.proj(time_emb)  # [B, 2*C]

        # Split into scale and shift
        gamma, beta = params.chunk(2, dim=-1)  # Each [B, C]

        # Reshape for broadcasting
        gamma = gamma[:, :, None, None]  # [B, C, 1, 1]
        beta = beta[:, :, None, None]    # [B, C, 1, 1]

        # Apply modulation: (1 + gamma) * features + beta
        return features * (1 + gamma) + beta
```
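The (1 + gamma) parameterization pays off when the projection is zero-initialized: gamma and beta are exactly 0 at the start of training, so the layer is the identity. PyTorch's nn.Linear does not zero-initialize by default, so this must be done explicitly. The sketch below assumes that zero-init is wanted (it is a design choice, not something FiLMConditioner does); ZeroInitFiLM is a hypothetical name.

```python
import torch
import torch.nn as nn

class ZeroInitFiLM(nn.Module):
    """FiLM whose projection starts at zero, so the layer is the identity at init."""

    def __init__(self, time_embed_dim: int, channels: int):
        super().__init__()
        self.proj = nn.Linear(time_embed_dim, channels * 2)
        # Zero weights and bias => gamma = beta = 0 for any time embedding
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, features: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        gamma, beta = self.proj(time_emb).chunk(2, dim=-1)
        return features * (1 + gamma[:, :, None, None]) + beta[:, :, None, None]

film = ZeroInitFiLM(time_embed_dim=512, channels=64)
x = torch.randn(2, 64, 8, 8)
t_emb = torch.randn(2, 512)
print(torch.equal(film(x, t_emb), x))  # True: gamma = beta = 0 at init
```

The weights still receive nonzero gradients (the gradient of gamma with respect to the weight involves the time embedding), so the modulation moves away from the identity as training proceeds.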

Method 3: Adaptive Group Normalization (AdaGN)

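Introduced in ADM (Dhariwal & Nichol, 2021): replace GroupNorm's fixed learned scale and shift with parameters predicted from the time embedding. In effect, AdaGN is FiLM applied immediately after a normalization layer that has no affine parameters of its own: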

```python
import torch
import torch.nn as nn

class AdaptiveGroupNorm(nn.Module):
    """
    GroupNorm with time-dependent scale and shift.

    Instead of learning fixed gamma/beta, predict them from the time embedding.
    """

    def __init__(self, num_groups: int, channels: int, time_embed_dim: int):
        super().__init__()

        # Normalization without its own learned affine parameters
        self.norm = nn.GroupNorm(num_groups, channels, affine=False)
        # Predicts per-channel [gamma, beta] from the time embedding
        self.proj = nn.Linear(time_embed_dim, channels * 2)

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        # Normalize without affine transform
        x = self.norm(x)

        # Get time-dependent scale and shift, each [B, C] -> [B, C, 1, 1]
        params = self.proj(time_emb)
        gamma, beta = params.chunk(2, dim=-1)
        gamma = gamma[:, :, None, None]
        beta = beta[:, :, None, None]

        # Apply time-dependent affine transform
        return x * (1 + gamma) + beta
```
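
| Method | Projection | What the time embedding controls | Used in |
|---|---|---|---|
| Addition | Linear(D, C) | Per-channel shift only | DDPM |
| FiLM | Linear(D, 2C) | Per-channel scale and shift, wherever it is placed | General-purpose conditioning layer (Perez et al., 2018) |
| AdaGN | Linear(D, 2C) | Scale and shift applied directly after a non-affine GroupNorm | ADM (Dhariwal & Nichol, 2021); DiT uses the analogous adaLN with LayerNorm |

Here D is time_embed_dim and C is the number of channels. FiLM and AdaGN have the same number of parameters; they differ in where the modulation is applied.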
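Where the modulation sits relative to the normalization matters. The sketch below applies a per-channel scale and shift before a GroupNorm with one channel per group. That normalization subtracts each channel's mean and divides by its standard deviation, so a positive per-channel scale and any shift are undone. Applying the same scale and shift after the normalization, as AdaGN does, preserves them. The tensor sizes are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 32, 8, 8
x = torch.randn(B, C, H, W)

# Per-sample, per-channel modulation with 1 + gamma kept positive
gamma = torch.rand(B, C, 1, 1) - 0.5  # in [-0.5, 0.5)
beta = torch.randn(B, C, 1, 1)

# One channel per group: each channel is normalized over H and W on its own
norm = nn.GroupNorm(num_groups=C, num_channels=C, affine=False)

film_then_norm = norm(x * (1 + gamma) + beta)  # modulation before normalization
norm_only = norm(x)
norm_then_film = norm(x) * (1 + gamma) + beta  # AdaGN-style: modulation after normalization

print(torch.allclose(film_then_norm, norm_only, atol=1e-3))  # True: modulation erased
print(torch.allclose(norm_then_film, norm_only, atol=1e-3))  # False: modulation kept
```

With several channels per group the cancellation is only partial, but the part of the scale and shift shared across a group is still removed.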

Complete Implementation

Here's how time conditioning integrates with the ResBlock from the previous section:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TimedResBlock(nn.Module):
    """ResBlock with time conditioning via FiLM-style scale and shift.

    The scale/shift is applied right after norm2 (the AdaGN arrangement):
    applied before norm2, GroupNorm would partly undo it.
    Channel counts must be divisible by 32 for GroupNorm(32, ...).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_embed_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        # First conv block
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        # Time conditioning: predicts per-channel [gamma, beta]
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, out_channels * 2),
        )

        # Second conv block
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # Skip projection
        self.skip = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        # First block
        h = F.silu(self.norm1(x))
        h = self.conv1(h)

        # Time conditioning: normalize first, then apply scale and shift
        time_params = self.time_proj(time_emb)
        gamma, beta = time_params.chunk(2, dim=-1)
        h = self.norm2(h)
        h = h * (1 + gamma[:, :, None, None]) + beta[:, :, None, None]

        # Second block
        h = F.silu(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Residual
        return h + self.skip(x)
```
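In a full U-Net, the time embedding is computed once per forward pass and the same vector is handed to every ResBlock, each of which owns its own projection. The compact, self-contained sketch below shows that wiring. TinyTimedNet, ResBlock, and the channel sizes are hypothetical; the ResBlock is a condensed block in the spirit of TimedResBlock that applies the scale and shift right after its second GroupNorm.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    half_dim = dim // 2
    freqs = torch.exp(-math.log(max_period)
                      * torch.arange(half_dim, device=timesteps.device, dtype=torch.float32) / half_dim)
    args = timesteps[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class ResBlock(nn.Module):
    """Condensed time-conditioned ResBlock (scale/shift after the second GroupNorm)."""

    def __init__(self, in_ch: int, out_ch: int, time_embed_dim: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_proj = nn.Linear(time_embed_dim, 2 * out_ch)  # per-block projection
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))
        gamma, beta = self.time_proj(F.silu(t_emb)).chunk(2, dim=-1)
        h = self.norm2(h) * (1 + gamma[:, :, None, None]) + beta[:, :, None, None]
        h = self.conv2(F.silu(h))
        return h + self.skip(x)

class TinyTimedNet(nn.Module):
    """Hypothetical two-block network sharing one time embedding."""

    def __init__(self, model_channels: int = 64):
        super().__init__()
        self.model_channels = model_channels
        time_embed_dim = 4 * model_channels
        self.time_mlp = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim), nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )
        self.stem = nn.Conv2d(3, model_channels, 3, padding=1)
        self.block1 = ResBlock(model_channels, model_channels, time_embed_dim)
        self.block2 = ResBlock(model_channels, 2 * model_channels, time_embed_dim)
        self.out = nn.Conv2d(2 * model_channels, 3, 3, padding=1)

    def forward(self, x: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        # Computed once, then reused by every block
        t_emb = self.time_mlp(sinusoidal_embedding(timesteps, self.model_channels))
        h = self.stem(x)
        h = self.block1(h, t_emb)
        h = self.block2(h, t_emb)
        return self.out(h)

net = TinyTimedNet()
x = torch.randn(4, 3, 16, 16)
t = torch.randint(0, 1000, (4,))
print(net(x, t).shape)  # torch.Size([4, 3, 16, 16])
```

Because the MLP output is shared, adding more ResBlocks costs only one extra Linear(time_embed_dim, 2C) each for time conditioning.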

Summary

In this section, we implemented time conditioning for diffusion U-Nets:

  1. Sinusoidal embeddings: Encode scalar timesteps as high-dimensional vectors using sine and cosine at multiple frequencies
  2. Multi-scale information: Different frequencies capture both coarse and fine time structure, enabling smooth interpolation
  3. MLP projection: Transform raw embeddings into task-specific representations with increased capacity
  4. Injection methods: Addition (simple), FiLM (scale + shift), or AdaGN (replaces normalization parameters)

Coming Up Next

In the next section, we'll add attention layers to our U-Net. Self-attention enables global context aggregation, allowing the network to model long-range dependencies that convolutions alone cannot capture efficiently.