Learning Objectives
By the end of this section, you will:
- Understand why time conditioning is essential for diffusion models
- Implement sinusoidal positional embeddings for timestep encoding
- Learn the mathematical properties that make sinusoidal embeddings effective
- Build the complete time embedding pipeline with MLP projection
- Compare injection methods: addition vs FiLM vs adaptive normalization
Why Time Conditioning?
The denoising task changes dramatically as the timestep t varies:
| Timestep | Noise Level | Task | Network Behavior |
|---|---|---|---|
| t = T (1000) | ~100% | Generate from noise | Predict large-scale structure |
| t ~ T/2 (500) | ~50% | Refine coarse details | Balance structure and texture |
| t ~ 100 | ~10% | Fine-tune details | Predict high-frequency features |
| t ~ 1 | ~0.1% | Final cleanup | Remove residual noise artifacts |
A network without time conditioning would have to "guess" the noise level from the image alone—an ambiguous and unreliable task. Time conditioning provides explicit information about where we are in the denoising process.
The Noise Schedule Connection
Each timestep t corresponds to a specific point on the noise schedule: given t, the amount of signal and noise in x_t is fully determined, so conditioning on t tells the network exactly how corrupted its input is.
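This connection can be made concrete with a short sketch. The linear β range (1e-4 to 0.02) and T = 1000 used below are the conventional DDPM defaults, assumed here rather than taken from this text:

```python
import torch

# Standard DDPM linear beta schedule (assumed defaults: 1e-4 to 0.02, T = 1000)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)  # cumulative signal retention

# Fraction of the original signal (and of noise) present at a few timesteps
for t in [1, 100, 500, 1000]:
    signal = alpha_bar[t - 1].sqrt().item()
    noise = (1 - alpha_bar[t - 1]).sqrt().item()
    print(f"t = {t:4d}: signal scale = {signal:.3f}, noise scale = {noise:.3f}")
```

Knowing t therefore pins down the exact signal-to-noise ratio the network is facing.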
Sinusoidal Embeddings
The standard approach for encoding timesteps is sinusoidal positional embedding, borrowed from the Transformer architecture. The key idea: represent a scalar timestep as a high-dimensional vector using sine and cosine functions at different frequencies.
The Formula
For timestep $t$ and embedding dimension $d$, the embedding at position $i$ is:

$$\text{emb}(t)_{2i} = \sin\!\left(\frac{t}{10000^{2i/d}}\right), \qquad \text{emb}(t)_{2i+1} = \cos\!\left(\frac{t}{10000^{2i/d}}\right)$$

The denominator $10000^{2i/d}$ creates a geometric progression of frequencies from $1$ down to $1/10000$.
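As a sketch, the formula fits in a few lines. One convention choice is assumed here: the sin and cos halves are concatenated rather than interleaved, a common implementation shortcut that permutes the components without changing the embedding's properties.

```python
import math
import torch

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """Encode a batch of scalar timesteps as `dim`-dimensional sinusoidal vectors.

    Frequencies follow a geometric progression from 1 down to ~1/10000,
    as in the Transformer positional encoding.
    """
    half = dim // 2
    # freqs[i] = 10000^(-i/half), i.e. frequencies from 1 down toward 1/10000
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps.float()[:, None] * freqs[None, :]  # (B, half)
    # Concatenate sin and cos halves (rather than interleaving)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, dim)
```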
Embedding Properties
Consider how the embedding vector changes as the timestep varies: low-frequency components drift slowly across the full range of t, while high-frequency components cycle rapidly and distinguish nearby timesteps.
Key Properties
The sinusoidal embedding has several important properties that make it ideal for time conditioning:
- Unique representation: Each timestep produces a distinct embedding vector. The high-dimensional space ensures no two timesteps collide.
- Smooth interpolation: Adjacent timesteps have similar embeddings. This allows the network to generalize between training timesteps.
- Multi-scale information: Low-frequency components capture coarse time structure; high-frequency components capture fine distinctions.
- Bounded values: Outputs are always in [-1, 1], which is numerically stable.
- No learned parameters: The embedding is deterministic, reducing the parameter count and training complexity.
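These properties are easy to check numerically. The sketch below builds the embedding (assuming the common concatenated sin/cos convention) and verifies boundedness, smoothness, and uniqueness:

```python
import math
import torch

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = sinusoidal_embedding(torch.arange(1000), 128)

# Bounded: every component lies in [-1, 1]
assert emb.abs().max().item() <= 1.0

# Smooth: neighbouring timesteps are far more similar than distant ones
cos_sim = torch.nn.functional.cosine_similarity
near = cos_sim(emb[500:501], emb[501:502]).item()
far = cos_sim(emb[500:501], emb[900:901]).item()
print(f"sim(500, 501) = {near:.3f}, sim(500, 900) = {far:.3f}")
assert near > far

# Unique: no two of the 1000 embedding vectors coincide
assert torch.unique(emb, dim=0).shape[0] == 1000
```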
Relative Position Information
A further property inherited from the Transformer encoding: for any fixed offset k, emb(t + k) can be written as a linear transformation of emb(t). This makes relative differences between timesteps easy for the network to compute.
MLP Projection
The raw sinusoidal embedding is passed through an MLP to create the final time embedding. This serves several purposes:
- Learn task-specific representations: The MLP can transform the generic positional encoding into features useful for denoising
- Increase capacity: Two linear layers with nonlinearity provide more expressive power
- Match dimensions: Project to the dimension needed by the ResBlocks
Dimension Choices
- Sinusoidal dimension: 128 (same as model_channels)
- Time embedding dimension: 512 (4x model_channels)
- Per-ResBlock projection: 512 → the channel count of that block
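Putting these dimension choices together, here is a sketch of the full pipeline (the module name `TimeEmbedding` is illustrative, not from a specific library):

```python
import math
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):
    """Sinusoidal encoding followed by a two-layer MLP.

    Dimensions follow the text: 128 sinusoidal channels (model_channels)
    projected up to a 512-dim time embedding (4x model_channels).
    """

    def __init__(self, model_channels: int = 128):
        super().__init__()
        self.model_channels = model_channels
        time_embed_dim = 4 * model_channels
        self.mlp = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        # Sinusoidal encoding (concatenated sin/cos convention)
        half = self.model_channels // 2
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, dtype=torch.float32, device=timesteps.device)
            / half
        )
        args = timesteps.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        # MLP projection to the shared time embedding
        return self.mlp(emb)
```

Each ResBlock then applies its own `Linear(512, channels)` (or `Linear(512, 2 * channels)` for FiLM) on top of this shared embedding.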
Injection Methods
Once we have the time embedding, how do we incorporate it into each ResBlock? There are three main approaches:
Method 1: Simple Addition
The simplest approach: project time embedding to match channel dimension and add to features:
```python
# Simple addition (used in DDPM)
time_proj = nn.Linear(time_embed_dim, channels)

def forward(features, time_emb):
    # Project, then broadcast over the spatial dimensions and add
    time_emb = time_proj(time_emb)[:, :, None, None]
    return features + time_emb
```
Method 2: FiLM Conditioning
Feature-wise Linear Modulation (FiLM) predicts both a per-channel scale and a shift from the time embedding, modulating features as h · (1 + γ) + β.
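A minimal sketch of FiLM-style injection (names like `film_proj` are illustrative; dimensions follow the choices above):

```python
import torch
import torch.nn as nn

# FiLM: predict a per-channel scale (gamma) and shift (beta) from the
# time embedding, then modulate the feature map.
time_embed_dim, channels = 512, 128
film_proj = nn.Linear(time_embed_dim, channels * 2)

def film(features: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
    gamma, beta = film_proj(time_emb).chunk(2, dim=-1)
    gamma = gamma[:, :, None, None]  # broadcast over H, W
    beta = beta[:, :, None, None]
    # (1 + gamma) keeps the modulation near identity at initialization
    return features * (1 + gamma) + beta
```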
Method 3: Adaptive Group Normalization (AdaGN)
Used in modern diffusion models: replace GroupNorm's learned scale/shift with time-dependent parameters:
```python
import torch
import torch.nn as nn

class AdaptiveGroupNorm(nn.Module):
    """
    GroupNorm with time-dependent scale and shift.

    Instead of learning fixed gamma/beta, predict them from the time embedding.
    """

    def __init__(self, num_groups: int, channels: int, time_embed_dim: int):
        super().__init__()

        self.norm = nn.GroupNorm(num_groups, channels, affine=False)
        self.proj = nn.Linear(time_embed_dim, channels * 2)

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        # Normalize without affine transform
        x = self.norm(x)

        # Get time-dependent scale and shift
        params = self.proj(time_emb)
        gamma, beta = params.chunk(2, dim=-1)
        gamma = gamma[:, :, None, None]
        beta = beta[:, :, None, None]

        # Apply time-dependent affine transform
        return x * (1 + gamma) + beta
```
| Method | Parameters | Expressiveness | Used In |
|---|---|---|---|
| Addition | Linear(D, C) | Low (additive only) | DDPM, early models |
| FiLM | Linear(D, 2C) | Medium (scale + shift) | Many modern models |
| AdaGN | Linear(D, 2C) | High (replaces norm) | DiT, advanced models |
Complete Implementation
Here's how time conditioning integrates with the ResBlock from the previous section:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TimedResBlock(nn.Module):
    """ResBlock with time conditioning via FiLM."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_embed_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        # First conv block
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        # Time conditioning (FiLM style)
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, out_channels * 2),
        )

        # Second conv block
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # Skip projection
        self.skip = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        # First block
        h = F.silu(self.norm1(x))
        h = self.conv1(h)

        # Time conditioning (FiLM)
        time_params = self.time_proj(time_emb)
        gamma, beta = time_params.chunk(2, dim=-1)
        h = h * (1 + gamma[:, :, None, None]) + beta[:, :, None, None]

        # Second block
        h = F.silu(self.norm2(h))
        h = self.dropout(h)
        h = self.conv2(h)

        # Residual
        return h + self.skip(x)
```
Summary
In this section, we implemented time conditioning for diffusion U-Nets:
- Sinusoidal embeddings: Encode scalar timesteps as high-dimensional vectors using sine and cosine at multiple frequencies
- Multi-scale information: Different frequencies capture both coarse and fine time structure, enabling smooth interpolation
- MLP projection: Transform raw embeddings into task-specific representations with increased capacity
- Injection methods: Addition (simple), FiLM (scale + shift), or AdaGN (replaces normalization parameters)