Learning Objectives
By the end of this section, you will be able to:
- Understand how the complex ELBO objective simplifies to a simple MSE loss on predicted noise
- Explain why dropping the weighting terms improves sample quality despite being theoretically suboptimal
- Compare uniform vs importance-weighted timestep sampling and their practical implications
- Connect the simplified loss to maximum likelihood estimation and understand the trade-offs
From ELBO to Simple MSE
In the previous chapter, we derived the variational lower bound (ELBO) for diffusion models. The full objective contains time-dependent weighting terms that emerge from the KL divergence decomposition:

$$L_{\text{VLB}} = \mathbb{E}_{t,\, x_0,\, \epsilon}\left[\, w_t \,\|\epsilon - \epsilon_\theta(x_t, t)\|^2 \,\right]$$

where the weights depend on the noise schedule:

$$w_t = \frac{\beta_t^2}{2\,\sigma_t^2\,\alpha_t\,(1 - \bar\alpha_t)}$$
The DDPM Insight: Ho et al. (2020) made a surprising discovery: simply ignoring these weights and using uniform weighting leads to better sample quality, even though it deviates from the variational objective.
The Simplified Loss
The simplified loss that is actually used in practice:

$$L_{\text{simple}} = \mathbb{E}_{t \sim \mathcal{U}\{1,\dots,T\},\, x_0,\, \epsilon \sim \mathcal{N}(0, I)}\left[\, \|\epsilon - \epsilon_\theta(x_t, t)\|^2 \,\right]$$
This is remarkably simple: sample a random timestep, add noise to the data, predict the noise, and minimize the MSE. No complicated weighting needed.
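To make this concrete, here is a minimal sketch of the simple loss, assuming an epsilon-prediction network `model(x_t, t)` and a precomputed `alpha_bar` tensor of shape `(T,)` (both hypothetical names, not from a specific library):

```python
import torch
import torch.nn.functional as F

def simple_loss(model, x_0, alpha_bar):
    """Minimal sketch of L_simple: uniform timesteps, MSE on predicted noise."""
    batch_size = x_0.shape[0]
    T = alpha_bar.shape[0]

    # Sample a random timestep per example and Gaussian noise
    t = torch.randint(1, T + 1, (batch_size,), device=x_0.device)
    epsilon = torch.randn_like(x_0)

    # Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
    a = alpha_bar[t - 1].view(-1, 1, 1, 1)
    x_t = torch.sqrt(a) * x_0 + torch.sqrt(1 - a) * epsilon

    # Predict the noise and minimize plain MSE -- no per-timestep weighting
    return F.mse_loss(model(x_t, t), epsilon)
```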
| Objective | Weighting | Optimizes |
|---|---|---|
| L_VLB (Full ELBO) | w_t from KL derivation | Log-likelihood bound |
| L_simple (DDPM) | Uniform (w_t = 1) | Sample quality |
| L_hybrid | lambda * L_VLB + L_simple | Both objectives |
Why Simple MSE Works
The success of the simplified loss seems counterintuitive: we are deliberately ignoring the mathematically derived weighting. Let's understand why this works.
The Signal-to-Noise Perspective
The key insight comes from analyzing what the VLB weights actually do. With the common choice $\sigma_t^2 = \beta_t$, the weight can be rewritten as

$$w_t = \frac{\beta_t}{2\,\alpha_t\,(1 - \bar\alpha_t)} = \frac{\beta_t}{2\,\alpha_t}\bigl(1 + \mathrm{SNR}(t)\bigr)$$

where $\mathrm{SNR}(t) = \bar\alpha_t / (1 - \bar\alpha_t)$ is the signal-to-noise ratio at timestep $t$. These weights become very large for small timesteps (where the signal dominates) and small for large timesteps (where noise dominates).
The Problem: High weights on small timesteps means the network focuses on nearly-clean data where prediction is easy. Meanwhile, the challenging high-noise regime gets less attention, hurting generation quality.
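As a quick numeric check, we can evaluate these weights under the linear schedule used later in this section (betas from 1e-4 to 0.02 over T = 1000 steps); this is an illustrative sketch, not code from the original DDPM release:

```python
import torch

# Linear beta schedule (same values as the example at the end of this section)
T = 1000
beta = torch.linspace(1e-4, 0.02, T)
alpha = 1 - beta
alpha_bar = torch.cumprod(alpha, dim=0)

# VLB weights with sigma_t^2 = beta_t: w_t = beta_t / (2 * alpha_t * (1 - alpha_bar_t))
w = beta / (2 * alpha * (1 - alpha_bar))

for t in [1, 10, 100, 500, 1000]:
    print(f"t={t:4d}  w_t={w[t - 1].item():.4f}")
# The weight at t=1 is roughly 50x larger than at t=1000,
# so the VLB objective emphasizes the nearly-clean, easy timesteps.
```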
What Uniform Weighting Does
Uniform weighting redistributes learning effort more evenly across all timesteps:
- Early timesteps (low noise): Get relatively less attention. These are easy to predict anyway - almost like predicting zero noise.
- Middle timesteps: Get more balanced attention. This is where the interesting structure-preserving denoising happens.
- Late timesteps (high noise): Get more attention relative to VLB. Critical for generation quality since we start from pure noise.
Interactive Visualization: Loss Weighting
Explore how different weighting strategies affect the emphasis placed on each timestep. Notice how uniform weighting is more balanced compared to VLB weighting.
Timestep Sampling Strategies
Beyond the loss weighting, how we sample timesteps during training also affects model performance:
Uniform Sampling
The standard approach: sample with equal probability for each timestep.
```python
import torch

def sample_timesteps_uniform(batch_size: int, T: int) -> torch.Tensor:
    """Sample timesteps uniformly from 1 to T."""
    return torch.randint(1, T + 1, (batch_size,))
```

Importance Sampling
Sample timesteps proportionally to their loss contribution to reduce variance:

$$p_t = \frac{\mathbb{E}[L_t]}{\sum_{s=1}^{T} \mathbb{E}[L_s]}, \qquad \text{with importance weight } \frac{1}{T\,p_t} \text{ to keep the gradient estimate unbiased.}$$
This requires maintaining running estimates of per-timestep loss magnitudes:
```python
import torch

class ImportanceSampler:
    """Importance sampling for timesteps based on loss history."""

    def __init__(self, T: int, history_size: int = 10):
        self.T = T
        self.loss_history = torch.ones(T, history_size)  # (T, history_size)
        self.history_idx = torch.zeros(T, dtype=torch.long)  # per-timestep write position

    def update(self, timesteps: torch.Tensor, losses: torch.Tensor):
        """Update the ring buffer of recent losses for each observed timestep."""
        for t, loss in zip(timesteps, losses):
            idx = self.history_idx[t - 1]
            self.loss_history[t - 1, idx] = loss.item()
            self.history_idx[t - 1] = (idx + 1) % self.loss_history.shape[1]

    def sample(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample timesteps and return importance weights."""
        # Sampling probabilities proportional to mean recent loss per timestep
        loss_means = self.loss_history.mean(dim=1)
        probs = loss_means / loss_means.sum()

        # Sample timesteps (1-indexed)
        timesteps = torch.multinomial(probs, batch_size, replacement=True) + 1

        # Importance weights 1 / (T * p_t) for unbiased gradients
        weights = 1.0 / (self.T * probs[timesteps - 1])
        weights = weights / weights.mean()  # Normalize to mean 1

        return timesteps, weights
```
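A quick, fully self-contained sketch of how the sampler might be used, with dummy stand-ins for the model and data (the random `per_sample_loss` here is a hypothetical placeholder for the real per-sample diffusion loss):

```python
import torch

# Hypothetical end-to-end usage with dummy stand-ins
T = 1000
sampler = ImportanceSampler(T)
x_0 = torch.randn(8, 3, 32, 32)  # dummy batch of images

t, weights = sampler.sample(batch_size=x_0.shape[0])

# Stand-in for the real per-sample loss (e.g. per-sample MSE on noise)
per_sample_loss = torch.rand(x_0.shape[0])

# Importance-weighted objective keeps the gradient estimate unbiased
loss = (weights * per_sample_loss).mean()

# Feed observed losses back so future sampling tracks the loss landscape
sampler.update(t, per_sample_loss.detach())
print(f"sampled t range: {t.min().item()}..{t.max().item()}, loss={loss.item():.4f}")
```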
Stratified Sampling

Ensure each batch covers the full range of timesteps more evenly:
```python
import torch

def sample_timesteps_stratified(batch_size: int, T: int) -> torch.Tensor:
    """
    Stratified sampling: divide [1, T] into batch_size strata
    and sample one timestep from each stratum.
    """
    # Create strata boundaries
    strata_size = T / batch_size

    # Sample uniformly within each stratum
    u = torch.rand(batch_size)  # Uniform in [0, 1)
    timesteps = (torch.arange(batch_size) * strata_size + u * strata_size).long() + 1
    timesteps = timesteps.clamp(1, T)

    # Shuffle to avoid ordering bias in the batch
    timesteps = timesteps[torch.randperm(batch_size)]

    return timesteps
```

| Strategy | Pros | Cons |
|---|---|---|
| Uniform | Simple, no overhead | May undersample important regions |
| Importance | Lower gradient variance | Requires loss tracking, overhead |
| Stratified | Better coverage per batch | Slight implementation complexity |
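A quick way to see the coverage difference is to draw one batch with each strategy and inspect the spread; this is an illustrative sketch using the two functions defined above:

```python
import torch

torch.manual_seed(0)  # for a reproducible comparison
T, batch_size = 1000, 16

uniform_t = sample_timesteps_uniform(batch_size, T)
stratified_t = sample_timesteps_stratified(batch_size, T)

# A stratified batch is guaranteed one sample per stratum of width T/batch_size,
# while a uniform batch can leave large gaps or duplicate regions by chance.
print("uniform:   ", sorted(uniform_t.tolist()))
print("stratified:", sorted(stratified_t.tolist()))
```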
Connection to Maximum Likelihood
Understanding how the simplified loss relates to maximum likelihood training provides deeper insight into what we are optimizing.
The VLB-Likelihood Connection
The ELBO provides a lower bound on the log-likelihood:

$$\log p_\theta(x_0) \;\ge\; \mathbb{E}_q\!\left[\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] \;=\; -L_{\text{VLB}}$$

Maximizing the ELBO is equivalent to maximizing a lower bound on the log-likelihood. This is the principled variational inference approach.
What Does Simple Loss Optimize?
The simplified loss can be seen as optimizing a reweighted likelihood objective. Define the reweighted ELBO:

$$L_{\text{simple}} = \mathbb{E}_{t,\, x_0,\, \epsilon}\left[\, \tilde w_t \,\|\epsilon - \epsilon_\theta(x_t, t)\|^2 \,\right]$$

where $\tilde w_t = 1$ for all $t$ (uniform) instead of the VLB-derived weights $w_t$. This is no longer a valid bound on the likelihood, but it optimizes a different objective that empirically produces better samples.
The Quality-Likelihood Trade-off: Models trained with VLB weighting achieve better bits-per-dimension (likelihood) but often produce samples that look worse to humans. Models trained with simple loss produce better-looking samples but have worse likelihood scores.
Hybrid Approaches
Some works combine both objectives:

$$L_{\text{hybrid}} = L_{\text{simple}} + \lambda\, L_{\text{VLB}}$$

This allows trading off between sample quality and likelihood. The parameter $\lambda$ controls the balance.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def hybrid_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: dict,
    vlb_weight: float = 0.001,
) -> torch.Tensor:
    """
    Compute hybrid loss combining simple and VLB objectives.

    Args:
        model: Noise prediction network
        x_0: Clean data
        noise_schedule: Dict with "alpha_bar", "beta", etc.
        vlb_weight: Weight for the VLB term (typically small)

    Returns:
        Combined loss value
    """
    batch_size = x_0.shape[0]
    T = len(noise_schedule["beta"])

    # Sample timesteps and noise
    t = torch.randint(1, T + 1, (batch_size,), device=x_0.device)
    epsilon = torch.randn_like(x_0)

    # Get schedule values
    alpha_bar_t = noise_schedule["alpha_bar"][t - 1]
    beta_t = noise_schedule["beta"][t - 1]

    # Reshape for broadcasting
    alpha_bar_t = alpha_bar_t.view(-1, 1, 1, 1)

    # Create noisy input
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    # Predict noise
    epsilon_pred = model(x_t, t)

    # Simple loss (uniform weighting)
    simple_loss = F.mse_loss(epsilon_pred, epsilon)

    # VLB weighting: w_t = beta_t^2 / (2 * sigma_t^2 * alpha_t * (1 - alpha_bar_t))
    alpha_t = 1 - beta_t
    sigma_t_sq = beta_t  # Using beta_t as the reverse-process variance
    vlb_weights = beta_t ** 2 / (2 * sigma_t_sq * alpha_t * (1 - alpha_bar_t.squeeze()))

    # Per-sample MSE with VLB weighting
    per_sample_mse = ((epsilon_pred - epsilon) ** 2).mean(dim=(1, 2, 3))
    vlb_loss = (vlb_weights * per_sample_mse).mean()

    return simple_loss + vlb_weight * vlb_loss
```

Implementation
Let's put together a complete, production-ready training loss implementation that supports multiple modes:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Literal
from dataclasses import dataclass


@dataclass
class LossConfig:
    """Configuration for diffusion loss computation."""
    loss_type: Literal["simple", "vlb", "hybrid"] = "simple"
    vlb_weight: float = 0.001  # Only used for hybrid
    prediction_type: Literal["epsilon", "x0", "v"] = "epsilon"


class DiffusionLoss(nn.Module):
    """
    Flexible diffusion loss supporting multiple objectives.

    This implementation covers:
    - Simple MSE loss (DDPM default)
    - VLB-weighted loss (theoretically optimal)
    - Hybrid loss combining both
    - Multiple prediction targets (epsilon, x0, v)
    """

    def __init__(
        self,
        config: LossConfig,
        noise_schedule: dict,
    ):
        super().__init__()
        self.config = config

        # Register schedule buffers
        self.register_buffer("alpha_bar", noise_schedule["alpha_bar"])
        self.register_buffer("beta", noise_schedule["beta"])
        self.register_buffer("alpha", 1 - noise_schedule["beta"])

        # Precompute VLB weights
        sigma_sq = self.beta
        vlb_weights = self.beta ** 2 / (
            2 * sigma_sq * self.alpha * (1 - self.alpha_bar)
        )
        # Clamp to avoid numerical issues at small t
        vlb_weights = vlb_weights.clamp(max=100.0)
        self.register_buffer("vlb_weights", vlb_weights)

    def get_target(
        self,
        x_0: torch.Tensor,
        epsilon: torch.Tensor,
        alpha_bar_t: torch.Tensor,
    ) -> torch.Tensor:
        """Get the prediction target based on prediction_type."""
        if self.config.prediction_type == "epsilon":
            return epsilon
        elif self.config.prediction_type == "x0":
            return x_0
        elif self.config.prediction_type == "v":
            # v = sqrt(alpha_bar) * epsilon - sqrt(1 - alpha_bar) * x_0
            return (
                torch.sqrt(alpha_bar_t) * epsilon
                - torch.sqrt(1 - alpha_bar_t) * x_0
            )
        else:
            raise ValueError(f"Unknown prediction type: {self.config.prediction_type}")

    def forward(
        self,
        model_output: torch.Tensor,
        x_0: torch.Tensor,
        epsilon: torch.Tensor,
        t: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Compute diffusion loss.

        Args:
            model_output: Network prediction
            x_0: Clean data
            epsilon: Added noise
            t: Timesteps (1-indexed)

        Returns:
            Dict with 'loss' and optional diagnostic values
        """
        # Get target
        alpha_bar_t = self.alpha_bar[t - 1].view(-1, 1, 1, 1)
        target = self.get_target(x_0, epsilon, alpha_bar_t)

        # Per-sample MSE
        per_sample_mse = ((model_output - target) ** 2).mean(dim=(1, 2, 3))

        # Compute losses based on type
        if self.config.loss_type == "simple":
            loss = per_sample_mse.mean()
            return {"loss": loss, "mse": loss}

        elif self.config.loss_type == "vlb":
            weights = self.vlb_weights[t - 1]
            loss = (weights * per_sample_mse).mean()
            return {
                "loss": loss,
                "mse": per_sample_mse.mean(),
                "weighted_mse": loss,
            }

        elif self.config.loss_type == "hybrid":
            simple_loss = per_sample_mse.mean()
            weights = self.vlb_weights[t - 1]
            vlb_loss = (weights * per_sample_mse).mean()
            loss = simple_loss + self.config.vlb_weight * vlb_loss
            return {
                "loss": loss,
                "simple_loss": simple_loss,
                "vlb_loss": vlb_loss,
            }
        else:
            raise ValueError(f"Unknown loss type: {self.config.loss_type}")
```

Training Loop Integration
Here's how to use this loss in a training loop:
```python
def train_step(
    model: nn.Module,
    loss_fn: DiffusionLoss,
    optimizer: torch.optim.Optimizer,
    x_0: torch.Tensor,
    T: int,
) -> dict:
    """Single training step for a diffusion model."""
    batch_size = x_0.shape[0]
    device = x_0.device

    # Sample timesteps uniformly
    t = torch.randint(1, T + 1, (batch_size,), device=device)

    # Sample noise
    epsilon = torch.randn_like(x_0)

    # Create noisy input: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
    alpha_bar_t = loss_fn.alpha_bar[t - 1].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    # Forward pass
    model_output = model(x_t, t)

    # Compute loss
    loss_dict = loss_fn(model_output, x_0, epsilon, t)

    # Backward pass
    optimizer.zero_grad()
    loss_dict["loss"].backward()
    optimizer.step()

    return {k: v.item() for k, v in loss_dict.items()}


# Example usage
if __name__ == "__main__":
    # Setup
    T = 1000
    betas = torch.linspace(1e-4, 0.02, T)
    alpha_bar = torch.cumprod(1 - betas, dim=0)

    noise_schedule = {
        "beta": betas,
        "alpha_bar": alpha_bar,
    }

    config = LossConfig(loss_type="simple", prediction_type="epsilon")
    loss_fn = DiffusionLoss(config, noise_schedule)

    # Mock setup only: nn.Identity cannot be called with (x_t, t),
    # so train_step is not invoked here -- see the smoke test below
    model = nn.Identity()  # Placeholder
    x_0 = torch.randn(4, 3, 32, 32)  # Batch of images

    print("Loss function configured successfully!")
```
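To verify that `train_step` actually runs end to end, here is a small smoke test continuing from the setup above (`loss_fn`, `T`, `train_step`); the `TinyEpsModel` name and architecture are hypothetical stand-ins for a real noise-prediction network:

```python
import torch
import torch.nn as nn


class TinyEpsModel(nn.Module):
    """Toy stand-in for a real noise-prediction network (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # A real network would condition on t (e.g. via timestep embeddings)
        return self.conv(x_t)


model = TinyEpsModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
x_0 = torch.randn(4, 3, 32, 32)

metrics = train_step(model, loss_fn, optimizer, x_0, T)
print(metrics)  # e.g. {'loss': ..., 'mse': ...}
```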
Key Takeaways

- Simple beats complex: The simplified uniform-weighted MSE loss produces better samples than the theoretically derived VLB weighting
- VLB overweights easy timesteps: The variational weights focus too much on low-noise (easy) timesteps, neglecting the high-noise regime critical for generation
- Timestep sampling matters: Uniform sampling is standard, but importance and stratified sampling can reduce variance
- Quality vs likelihood trade-off: Simple loss optimizes for sample quality rather than likelihood; hybrid approaches can balance both
- Implementation flexibility: A good loss implementation should support multiple objectives and prediction types for experimentation
Looking Ahead: In the next section, we'll explore more sophisticated loss weighting strategies that attempt to get the best of both worlds: better sample quality while retaining some of the theoretical benefits of proper likelihood optimization.