Learning Objectives
By the end of this section, you will be able to:
- Derive the Evidence Lower Bound (ELBO) for diffusion models from first principles
- Decompose the ELBO into reconstruction, prior, and denoising terms
- Simplify the bound to the practical loss $L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\big[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\big]$
- Understand every step of the derivation
Starting with the Variational Bound
Our goal is to maximize the log-likelihood of the data:

$$\log p_\theta(x_0) = \log \int p_\theta(x_{0:T}) \, dx_{1:T}$$

Since this integral is intractable (it requires marginalizing over all latent variables $x_{1:T}$), we derive a variational lower bound (ELBO).
Introducing the Variational Distribution
We introduce the forward process as our variational distribution:

$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$
Multiplying and dividing by $q(x_{1:T} \mid x_0)$ inside the integral:

$$\log p_\theta(x_0) = \log \int q(x_{1:T} \mid x_0) \, \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)} \, dx_{1:T} = \log \mathbb{E}_{q(x_{1:T} \mid x_0)}\!\left[\frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right]$$

By Jensen's inequality (since $\log$ is concave):

$$\log p_\theta(x_0) \geq \mathbb{E}_{q(x_{1:T} \mid x_0)}\!\left[\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] = \text{ELBO}$$
The ELBO: Maximizing this lower bound pushes up the true log-likelihood. The gap between them is exactly the KL divergence $D_{\mathrm{KL}}\big(q(x_{1:T} \mid x_0) \,\|\, p_\theta(x_{1:T} \mid x_0)\big)$ between the variational distribution and the true posterior.
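Jensen's inequality is easy to verify numerically. The following standalone sketch (an illustration, not part of the derivation) draws lognormal samples and checks that $\log \mathbb{E}[X] \geq \mathbb{E}[\log X]$; the difference between the two sides plays the same role as the ELBO gap:

```python
import math
import random

# Monte Carlo check of Jensen's inequality for the concave log:
# log E[X] >= E[log X] for any positive random variable X.
random.seed(0)
samples = [random.lognormvariate(0.0, 1.0) for _ in range(100_000)]

log_of_mean = math.log(sum(samples) / len(samples))          # log E[X]
mean_of_log = sum(math.log(x) for x in samples) / len(samples)  # E[log X]

print(f"log E[X] = {log_of_mean:.4f}")
print(f"E[log X] = {mean_of_log:.4f}")
assert log_of_mean >= mean_of_log  # Jensen's inequality holds
```

For a lognormal with $\mu = 0, \sigma = 1$ the gap is exactly $\sigma^2/2 = 0.5$, so the two printed values differ by about that much.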
ELBO Decomposition
Now we expand the joint distributions using the Markov structure of both processes:

$$p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t), \qquad q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$

Substituting into the ELBO:

$$\text{ELBO} = \mathbb{E}_q\!\left[\log p(x_T) + \sum_{t=1}^{T} \log \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_t \mid x_{t-1})}\right]$$
Rearranging the Terms
After careful algebraic manipulation (using Bayes' rule to rewrite $q(x_t \mid x_{t-1})$ in terms of the posterior $q(x_{t-1} \mid x_t, x_0)$), the negative ELBO decomposes into:

$$L = \underbrace{D_{\mathrm{KL}}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big)}_{L_T} + \sum_{t=2}^{T} \underbrace{\mathbb{E}_q\big[D_{\mathrm{KL}}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big)\big]}_{L_{t-1}} \underbrace{-\, \mathbb{E}_q\big[\log p_\theta(x_0 \mid x_1)\big]}_{L_0}$$
| Term | Name | Interpretation |
|---|---|---|
| L_0 | Reconstruction | How well can we reconstruct x_0 from x_1? |
| L_T | Prior matching | How close is q(x_T|x_0) to pure Gaussian p(x_T)? |
| L_{t-1} | Denoising matching | How well does p_theta match the true reverse q? |
The L_T Term

$$L_T = D_{\mathrm{KL}}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big)$$

Because the forward process is fixed and $q(x_T \mid x_0) \approx \mathcal{N}(0, I)$ for large $T$, this term contains no trainable parameters and is close to zero, so it is typically dropped from the training objective.
The Denoising Term
The key term is $L_{t-1}$, which measures how well our learned reverse transition $p_\theta(x_{t-1} \mid x_t)$ matches the true reverse posterior $q(x_{t-1} \mid x_t, x_0)$.
KL Between Gaussians
Both distributions are Gaussian:
- True posterior: $q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\, \tilde{\mu}_t(x_t, x_0),\, \tilde{\beta}_t I\big)$
- Learned reverse: $p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\, \mu_\theta(x_t, t),\, \sigma_t^2 I\big)$
The KL divergence between two Gaussians with the same isotropic variance simplifies to:

$$D_{\mathrm{KL}}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big) = \frac{1}{2\sigma_t^2} \big\|\tilde{\mu}_t(x_t, x_0) - \mu_\theta(x_t, t)\big\|^2 + C$$

where $C$ is a constant independent of $\theta$.
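This closed form can be sanity-checked against PyTorch's general Gaussian KL; the sketch below (an illustrative check, not part of the derivation) compares the two on random means with a shared variance:

```python
import torch

torch.manual_seed(0)
d, sigma2 = 8, 0.25          # dimension and shared variance sigma^2
mu_q = torch.randn(d)        # stands in for the true posterior mean
mu_p = torch.randn(d)        # stands in for the learned mean

# Closed form from above: KL(q || p) = ||mu_q - mu_p||^2 / (2 sigma^2)
kl_closed = (mu_q - mu_p).pow(2).sum() / (2 * sigma2)

# General multivariate-Gaussian KL via torch.distributions
q = torch.distributions.MultivariateNormal(mu_q, sigma2 * torch.eye(d))
p = torch.distributions.MultivariateNormal(mu_p, sigma2 * torch.eye(d))
kl_general = torch.distributions.kl_divergence(q, p)

assert torch.allclose(kl_closed, kl_general, atol=1e-4)
print(f"KL = {kl_closed.item():.4f}")
```

With equal covariances, the trace and log-determinant terms of the general formula cancel, leaving only the squared mean difference.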
Substituting the Parameterization
Using the epsilon-prediction parameterization from Section 3.3:

$$\tilde{\mu}_t(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right), \qquad \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)$$

The difference is:

$$\tilde{\mu}_t(x_t, x_0) - \mu_\theta(x_t, t) = \frac{\beta_t}{\sqrt{\alpha_t}\sqrt{1-\bar{\alpha}_t}}\,\big(\epsilon_\theta(x_t, t) - \epsilon\big)$$

Therefore:

$$L_{t-1} = \mathbb{E}_{x_0, \epsilon}\!\left[\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1-\bar{\alpha}_t)}\,\big\|\epsilon - \epsilon_\theta(x_t, t)\big\|^2\right] + C$$
The Key Result: Minimizing the KL divergence between the true and learned reverse is equivalent to minimizing the MSE between the true noise and the predicted noise, up to a time-dependent weighting.
The Simplified Loss
Ho et al. (2020) found that ignoring the time-dependent weighting and using a simple unweighted loss works better in practice:

$$L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\!\left[\big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\; t\big)\big\|^2\right]$$
The Training Algorithm
- Sample a training image $x_0 \sim q(x_0)$
- Sample a random timestep $t \sim \mathrm{Uniform}(\{1, \dots, T\})$
- Sample noise $\epsilon \sim \mathcal{N}(0, I)$
- Create the noisy sample: $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$
- Compute the loss: $\|\epsilon - \epsilon_\theta(x_t, t)\|^2$
- Take a gradient step on $\nabla_\theta \|\epsilon - \epsilon_\theta(x_t, t)\|^2$
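The steps above fit in a few lines. This is a minimal sketch that assumes a `model(x_t, t)` predicting the noise; a fuller implementation appears in the PyTorch section of this chapter:

```python
import torch

# Minimal sketch of one training step, mirroring the bullets above.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def train_step(model, optimizer, x0):
    t = torch.randint(0, T, (x0.shape[0],))          # random timestep per sample
    eps = torch.randn_like(x0)                       # noise epsilon ~ N(0, I)
    ab = alphas_bar[t].view(-1, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps     # noisy sample x_t
    loss = (eps - model(x_t, t)).pow(2).mean()       # simple MSE loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                                 # gradient step on theta
    return loss.item()
```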
Sampling t Uniformly

Each minibatch draws $t$ uniformly from $\{1, \dots, T\}$, so in expectation the network sees every noise level equally often. A single network $\epsilon_\theta(x_t, t)$ is thus trained across all timesteps, rather than a separate denoiser per step.
Why the Simple Loss Works
The simple loss ignores the VLB weighting $\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1-\bar{\alpha}_t)}$. Why does this work?
Empirical Observations
- Sample quality: Unweighted loss produces better perceptual quality (FID scores)
- Stability: Avoids numerical issues at extreme timesteps
- Simplicity: One less hyperparameter to tune
Theoretical Interpretation
The simple loss can be viewed as:
- Denoising score matching: Learning the score function at all noise levels equally
- Re-weighted VLB: A different weighting that emphasizes mid-level noise
- Multi-scale denoising: Training a hierarchy of denoisers
| Loss Type | Weighting | Emphasis |
|---|---|---|
| VLB (L_{t-1}) | $\beta_t^2 / \big(2\sigma_t^2\,\alpha_t\,(1-\bar{\alpha}_t)\big)$ | Varies with schedule |
| Simple (L_simple) | 1 (uniform) | All timesteps equally |
| SNR-weighted | $\mathrm{SNR}(t) = \bar{\alpha}_t / (1-\bar{\alpha}_t)$ | High-SNR (low-noise) timesteps |
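To see how these weightings differ across timesteps, here is a small sketch (assuming the linear beta schedule used later in this section, and $\sigma_t^2 = \beta_t$ for the VLB row) that evaluates each weighting at the start, middle, and end of the schedule:

```python
import torch

# Compare the three weightings from the table on a linear beta schedule.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

# VLB weighting with sigma_t^2 = beta_t (an assumption; other variance
# choices change the shape of this curve)
vlb_w = betas**2 / (2 * betas * alphas * (1 - alphas_bar))
simple_w = torch.ones(T)
snr_w = alphas_bar / (1 - alphas_bar)

for name, w in [("VLB", vlb_w), ("simple", simple_w), ("SNR", snr_w)]:
    print(f"{name:>6}: t=1 -> {w[1]:.3g}, t=T/2 -> {w[T // 2]:.3g}, "
          f"t=T-1 -> {w[-1]:.3g}")
```

The SNR weighting decays by orders of magnitude as $t$ grows, while the simple loss treats all timesteps identically; this is the sense in which $L_{\text{simple}}$ is a re-weighted VLB.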
Practical Insight: While the VLB is theoretically motivated, the simple loss works better in practice for perceptual quality. The VLB optimizes likelihood, but likelihood doesn't always correlate with visual quality.
PyTorch Implementation
Let's implement the complete training objective:
```python
import torch
import torch.nn as nn
from typing import Optional, Dict

class DiffusionLoss:
    """Computes the diffusion training loss."""

    def __init__(
        self,
        betas: torch.Tensor,
        loss_type: str = "simple",  # "simple", "vlb", or "hybrid"
        vlb_weight: float = 0.001
    ):
        """
        Args:
            betas: Noise schedule (T,)
            loss_type: Type of loss to compute
            vlb_weight: Weight for VLB term in hybrid loss
        """
        self.loss_type = loss_type
        self.vlb_weight = vlb_weight

        # Precompute schedule values
        self.betas = betas
        self.alphas = 1.0 - betas
        self.alphas_bar = torch.cumprod(self.alphas, dim=0)
        self.alphas_bar_prev = torch.cat([
            torch.tensor([1.0]),
            self.alphas_bar[:-1]
        ])

        # sqrt values for forward process
        self.sqrt_alphas_bar = torch.sqrt(self.alphas_bar)
        self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alphas_bar)

        # Posterior variance: beta_tilde_t
        self.posterior_variance = (
            (1.0 - self.alphas_bar_prev) /
            (1.0 - self.alphas_bar) *
            self.betas
        )

        # VLB weights: beta_t^2 / (2 * sigma_t^2 * alpha_t * (1 - alpha_bar_t))
        # Using sigma_t^2 = beta_t, this simplifies to
        # beta_t / (2 * alpha_t * (1 - alpha_bar_t))
        self.vlb_weights = (
            self.betas ** 2 /
            (2.0 * self.betas * self.alphas * (1.0 - self.alphas_bar))
        )
        # Handle edge case at t = 0
        self.vlb_weights[0] = self.vlb_weights[1]

    def _extract(self, tensor: torch.Tensor, t: torch.Tensor, x: torch.Tensor):
        """Extract values at indices t and reshape for broadcasting against x."""
        values = tensor.to(x.device)[t]
        while values.dim() < x.dim():
            values = values.unsqueeze(-1)
        return values

    def q_sample(
        self,
        x_start: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward diffusion: sample x_t from q(x_t | x_0).

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_start)
        sqrt_one_minus_alpha_bar = self._extract(
            self.sqrt_one_minus_alphas_bar, t, x_start
        )

        return sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise

    def compute_loss(
        self,
        model: nn.Module,
        x_start: torch.Tensor,
        t: Optional[torch.Tensor] = None,
        noise: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the training loss.

        Args:
            model: The noise prediction network
            x_start: Clean images (B, C, H, W)
            t: Timesteps (B,). If None, sampled uniformly.
            noise: Noise to add (B, C, H, W). If None, sampled.

        Returns:
            Dictionary with loss terms
        """
        B = x_start.shape[0]
        device = x_start.device
        T = len(self.betas)

        # Sample timesteps uniformly if not provided
        if t is None:
            t = torch.randint(0, T, (B,), device=device)

        # Sample noise if not provided
        if noise is None:
            noise = torch.randn_like(x_start)

        # Create noisy samples
        x_t = self.q_sample(x_start, t, noise)

        # Get model prediction
        pred_noise = model(x_t, t)

        # Simple loss: per-sample unweighted MSE
        mse_loss = (noise - pred_noise).pow(2).mean(dim=[1, 2, 3])

        if self.loss_type == "simple":
            loss = mse_loss.mean()
            return {"loss": loss, "mse": mse_loss.mean()}

        elif self.loss_type == "vlb":
            # Weighted by VLB coefficients, flattened to shape (B,)
            weights = self._extract(self.vlb_weights, t, x_start).view(-1)
            weighted_loss = (weights * mse_loss).mean()
            return {"loss": weighted_loss, "mse": mse_loss.mean()}

        elif self.loss_type == "hybrid":
            # Combination of simple and VLB
            simple_loss = mse_loss.mean()
            weights = self._extract(self.vlb_weights, t, x_start).view(-1)
            vlb_loss = (weights * mse_loss).mean()
            loss = simple_loss + self.vlb_weight * vlb_loss
            return {
                "loss": loss,
                "simple": simple_loss,
                "vlb": vlb_loss,
                "mse": mse_loss.mean()
            }

        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")


class SimpleDiffusionTrainer:
    """Simple training loop for diffusion models."""

    def __init__(
        self,
        model: nn.Module,
        betas: torch.Tensor,
        lr: float = 2e-4,
        loss_type: str = "simple"
    ):
        self.model = model
        self.loss_fn = DiffusionLoss(betas, loss_type)
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def train_step(self, x_start: torch.Tensor) -> Dict[str, float]:
        """Single training step."""
        self.model.train()
        self.optimizer.zero_grad()

        losses = self.loss_fn.compute_loss(self.model, x_start)
        losses["loss"].backward()
        self.optimizer.step()

        return {k: v.item() for k, v in losses.items()}


# Example usage
if __name__ == "__main__":
    # Simple test network
    class SimpleNoisePredictor(nn.Module):
        def __init__(self, channels=3, dim=64):
            super().__init__()
            self.time_embed = nn.Embedding(1000, dim)
            self.net = nn.Sequential(
                nn.Conv2d(channels, dim, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(dim, dim, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(dim, channels, 3, padding=1),
            )

        def forward(self, x, t):
            # Very simple -- a real U-Net would be much more complex.
            # Slice the time embedding to the image channels and
            # broadcast it over the spatial dimensions.
            time_emb = self.time_embed(t)[:, : x.shape[1], None, None]
            return self.net(x) + time_emb

    # Setup
    T = 1000
    betas = torch.linspace(0.0001, 0.02, T)
    model = SimpleNoisePredictor()

    # Test loss computation
    loss_fn = DiffusionLoss(betas, loss_type="simple")
    x_start = torch.randn(4, 3, 32, 32)
    losses = loss_fn.compute_loss(model, x_start)

    print(f"Simple loss: {losses['loss'].item():.4f}")
    print(f"MSE: {losses['mse'].item():.4f}")

    # Test hybrid loss
    loss_fn_hybrid = DiffusionLoss(betas, loss_type="hybrid", vlb_weight=0.001)
    losses_hybrid = loss_fn_hybrid.compute_loss(model, x_start)
    print(f"Hybrid loss: {losses_hybrid['loss'].item():.4f}")
```

Key Takeaways
- ELBO derivation: We maximize a lower bound on log p(x_0) using the forward process as variational distribution
- Three terms: Reconstruction (L_0), prior matching (L_T, typically ignored), and denoising matching (L_{t-1})
- KL to MSE: The KL divergence between Gaussians reduces to MSE between means
- Simple loss: $L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\big[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\big]$ works better than the weighted VLB in practice
- Training is simple: Sample x_0, t, noise; predict noise; compute MSE
Looking Ahead: In the next section, we'll connect the epsilon-prediction formulation to score matching, revealing that the network is learning the gradient of the log probability density.