Learning Objectives
By the end of this section, you will be able to:
- Explain why we need to parameterize the reverse process mean
- Compare epsilon-prediction, x0-prediction, and v-prediction parameterizations
- Derive the relationships between different parameterizations
- Choose the appropriate parameterization for different use cases
The Parameterization Problem
From Section 3.2, we know the tractable posterior $q(x_{t-1} \mid x_t, x_0)$ has mean:

$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t$$

The problem: During generation, we don't have $x_0$! We need to learn a neural network to predict something that lets us compute the mean.
Key Insight: We have three equivalent ways to compute the posterior mean - by predicting the noise, the clean image, or a combination. Each has different training dynamics and advantages.
The Learned Reverse Process
We parameterize the reverse process as:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\, \mu_\theta(x_t, t),\, \sigma_t^2 I\big)$$

The question is: how should we parameterize $\mu_\theta(x_t, t)$?
Epsilon-Prediction
The most common parameterization, introduced in DDPM (Ho et al., 2020), is to predict the noise $\epsilon$ that was added.
The Key Relationship
From the forward process, we know:

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

Rearranging for $x_0$:

$$x_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon}{\sqrt{\bar{\alpha}_t}}$$
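This rearrangement is easy to sanity-check numerically. A minimal sketch with toy tensors (the scalar value of $\bar{\alpha}_t$ here is chosen arbitrarily for illustration):

```python
import torch

torch.manual_seed(0)
alpha_bar_t = torch.tensor(0.7)   # example value of alpha-bar at some timestep t
x0 = torch.randn(4)               # toy "clean" sample
eps = torch.randn(4)              # the noise that was added

# Forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps
x_t = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * eps

# Rearranged: x0 is recovered exactly when the true eps is known
x0_recovered = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
max_err = (x0 - x0_recovered).abs().max().item()
```

During generation we of course only have the *predicted* noise, but the algebraic round trip itself is exact.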
Substituting into the Posterior Mean
Substituting this expression for $x_0$ into the posterior mean formula and simplifying:

Epsilon-Prediction: Train the network $\epsilon_\theta(x_t, t)$ to predict the noise $\epsilon$. Then compute:

$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t)\right)$$
Why Predict Noise?
- Consistent magnitude: The noise $\epsilon \sim \mathcal{N}(0, I)$ has unit variance regardless of timestep
- Denoising interpretation: The network learns to "see" what noise was added
- Connection to score: Predicting noise is equivalent to predicting the score function (Section 3.5)
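The substitution step above can be verified numerically: for a sample built by the forward process, the mean computed from $\epsilon$ matches the true posterior mean computed from $x_0$ exactly. A minimal sketch (scalar schedule values chosen arbitrarily for illustration):

```python
import torch

torch.manual_seed(0)
alpha_t = torch.tensor(0.95)
alpha_bar_prev = torch.tensor(0.6)
alpha_bar_t = alpha_bar_prev * alpha_t   # alpha_bar_t = alpha_bar_{t-1} * alpha_t
beta_t = 1 - alpha_t

x0 = torch.randn(5)
eps = torch.randn(5)
x_t = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * eps

# Mean via the true posterior formula (needs x0)
mu_posterior = (alpha_bar_prev.sqrt() * beta_t / (1 - alpha_bar_t)) * x0 \
             + (alpha_t.sqrt() * (1 - alpha_bar_prev) / (1 - alpha_bar_t)) * x_t

# Mean via the epsilon form (needs only x_t and the noise)
mu_eps = (x_t - beta_t / (1 - alpha_bar_t).sqrt() * eps) / alpha_t.sqrt()

max_err = (mu_posterior - mu_eps).abs().max().item()
```

The two means agree up to floating-point precision, which is exactly what the derivation claims.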
x0-Prediction
An alternative is to directly predict the clean image $x_0$ with a network $\hat{x}_\theta(x_t, t)$.
Computing the Mean
Using the posterior mean formula directly:

$$\mu_\theta(x_t, t) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, \hat{x}_\theta(x_t, t) + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t$$
Advantages of x0-Prediction
- Interpretable output: The network output is directly the denoised image
- Better for low-step sampling: With few denoising steps, x0-prediction can produce cleaner results
- Useful for guidance: Can apply constraints directly to the predicted $\hat{x}_0$
Disadvantages
- Varying target magnitude: The clean image statistics depend on the data distribution
- Harder at high noise: Predicting $x_0$ from (nearly) pure noise is very difficult
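The "harder at high noise" point can be made concrete: under a standard DDPM-style linear beta schedule, the fraction of $x_t$'s scale that comes from the clean signal, $\sqrt{\bar{\alpha}_t}$, collapses toward zero at large $t$. A quick illustration:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # standard linear schedule
alphas_bar = torch.cumprod(1 - betas, dim=0)

# Fraction of x_t's standard deviation contributed by the clean signal x0
signal_scale = alphas_bar.sqrt()
early = signal_scale[10].item()   # near t = 0: x_t is almost all signal
late = signal_scale[-1].item()    # near t = T: x_t is almost all noise
```

Near $t = T$ the network must produce a full image from an input that carries almost no information about it, which is why x0-prediction targets are hard there.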
Equivalence
The two parameterizations are interchangeable: given one prediction, the other follows from the forward-process relationship:

$$\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}, \qquad \epsilon_\theta(x_t, t) = \frac{x_t - \sqrt{\bar{\alpha}_t}\, \hat{x}_\theta(x_t, t)}{\sqrt{1 - \bar{\alpha}_t}}$$
v-Prediction (Velocity)
Introduced by Salimans & Ho (2022), v-prediction trains the network to predict a combination of $\epsilon$ and $x_0$:

$$v \equiv \sqrt{\bar{\alpha}_t}\, \epsilon - \sqrt{1 - \bar{\alpha}_t}\, x_0$$
Intuition: The Velocity
Think of the forward process as a rotation in the plane spanned by $x_0$ and $\epsilon$:

$$x_t = \cos(\phi_t)\, x_0 + \sin(\phi_t)\, \epsilon$$

where $\cos(\phi_t) = \sqrt{\bar{\alpha}_t}$ and $\sin(\phi_t) = \sqrt{1 - \bar{\alpha}_t}$.
The velocity is the derivative of this path with respect to the angle:

$$v_t = \frac{dx_t}{d\phi_t} = \cos(\phi_t)\, \epsilon - \sin(\phi_t)\, x_0 = \sqrt{\bar{\alpha}_t}\, \epsilon - \sqrt{1 - \bar{\alpha}_t}\, x_0$$
Recovering x0 and epsilon
Given the predicted velocity $v_\theta(x_t, t)$:

$$\hat{x}_0 = \sqrt{\bar{\alpha}_t}\, x_t - \sqrt{1 - \bar{\alpha}_t}\, v_\theta, \qquad \hat{\epsilon} = \sqrt{1 - \bar{\alpha}_t}\, x_t + \sqrt{\bar{\alpha}_t}\, v_\theta$$
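Both recovery identities follow from the rotation picture (with $c = \sqrt{\bar{\alpha}_t}$, $s = \sqrt{1 - \bar{\alpha}_t}$, and $c^2 + s^2 = 1$), and are exact when the true velocity is known. A quick numeric check with toy tensors:

```python
import torch

torch.manual_seed(0)
alpha_bar_t = torch.tensor(0.3)                      # arbitrary example value
c, s = alpha_bar_t.sqrt(), (1 - alpha_bar_t).sqrt()  # cos(phi_t), sin(phi_t)

x0 = torch.randn(4)
eps = torch.randn(4)
x_t = c * x0 + s * eps      # forward process
v = c * eps - s * x0        # velocity target

# Recover both quantities from (x_t, v)
x0_rec = c * x_t - s * v
eps_rec = s * x_t + c * v
err = max((x0 - x0_rec).abs().max().item(),
          (eps - eps_rec).abs().max().item())
```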
Advantages of v-Prediction
- Balanced across timesteps: The velocity has similar magnitude at all timesteps
- Better training stability: Gradients are more consistent
- Improved FID scores: Empirically produces better samples
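One way to see the "balanced" behavior: since $v = \sqrt{\bar{\alpha}_t}\,\epsilon - \sqrt{1 - \bar{\alpha}_t}\,x_0$, the target is almost pure noise near $t = 0$ and almost (minus) the clean image near $t = T$, so v-prediction smoothly interpolates between the two other tasks. A small sketch under a standard linear schedule:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1 - betas, dim=0)
c, s = alphas_bar.sqrt(), (1 - alphas_bar).sqrt()

# v = c * eps - s * x0: coefficient on eps dominates at low t,
# coefficient on x0 dominates at high t
eps_weight_low_t = c[0].item()    # ~1: near t=0, v is essentially the noise
x0_weight_high_t = s[-1].item()   # ~1: near t=T, v is essentially -x0
```

In effect, the network predicts whichever quantity is the *non-trivial* one at each noise level, which is one intuition for the more consistent gradients.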
Comparing Parameterizations
| Aspect | epsilon-prediction | x0-prediction | v-prediction |
|---|---|---|---|
| Output | Noise epsilon | Clean image x_0 | Velocity v |
| Target magnitude | Unit variance (consistent) | Data-dependent (varies) | Approximately unit (consistent) |
| Low noise (t near 0) | Harder (small signal) | Easier (clear target) | Balanced |
| High noise (t near T) | Easier (large signal) | Harder (pure noise input) | Balanced |
| Used in | DDPM, Stable Diffusion | DDIM, Imagen (partial) | Progressive Distillation |
| Loss weighting | Often needs SNR weighting | Often needs SNR weighting | More uniform naturally |
Practical Recommendation: Start with epsilon-prediction (it's the most common). Consider v-prediction for improved stability or when using few sampling steps. Use x0-prediction when you need to apply constraints to the denoised output during sampling.
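The SNR weighting mentioned in the table refers to the signal-to-noise ratio $\mathrm{SNR}(t) = \bar{\alpha}_t / (1 - \bar{\alpha}_t)$; an MSE on $\epsilon$ and an MSE on $x_0$ differ by exactly this factor, which is why the same nominal loss weights the timesteps very differently depending on the parameterization. A quick check (toy tensors; `eps_hat` is a hypothetical imperfect noise prediction, not a real network):

```python
import torch

torch.manual_seed(0)
alpha_bar_t = torch.tensor(0.8)          # arbitrary example value
snr = alpha_bar_t / (1 - alpha_bar_t)    # SNR(t) = alpha_bar / (1 - alpha_bar)
c, s = alpha_bar_t.sqrt(), (1 - alpha_bar_t).sqrt()

x0 = torch.randn(8)
eps = torch.randn(8)
x_t = c * x0 + s * eps

# An imperfect noise prediction and the x0 estimate it implies
eps_hat = eps + 0.1 * torch.randn(8)
x0_hat = (x_t - s * eps_hat) / c

loss_eps = ((eps_hat - eps) ** 2).mean()
loss_x0 = ((x0_hat - x0) ** 2).mean()
ratio = (loss_eps / loss_x0).item()      # equals SNR(t)
```

So a uniform loss on $\epsilon$ implicitly down-weights low-noise timesteps when measured in $x_0$ space, and vice versa.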
PyTorch Implementation
Let's implement all three parameterizations:
```python
import torch
import torch.nn as nn
from enum import Enum
from typing import Tuple


class PredictionType(Enum):
    EPSILON = "epsilon"    # Predict noise
    X_START = "x_start"    # Predict clean image
    VELOCITY = "velocity"  # Predict velocity


class ParameterizationHelper:
    """Converts between different parameterizations."""

    def __init__(self, alphas_bar: torch.Tensor):
        """
        Args:
            alphas_bar: Cumulative product of alphas (T,)
        """
        self.alphas_bar = alphas_bar
        self.sqrt_alphas_bar = torch.sqrt(alphas_bar)
        self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - alphas_bar)

    def _extract(self, tensor: torch.Tensor, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Extract per-timestep values for the batch and reshape for broadcasting."""
        values = tensor[t]
        while values.dim() < x.dim():
            values = values.unsqueeze(-1)
        return values

    def predict_x_start_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
        """x_0 = (x_t - sqrt(1 - alpha_bar) * eps) / sqrt(alpha_bar)"""
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_t)
        sqrt_one_minus_alpha_bar = self._extract(self.sqrt_one_minus_alphas_bar, t, x_t)
        return (x_t - sqrt_one_minus_alpha_bar * eps) / sqrt_alpha_bar

    def predict_eps_from_x_start(self, x_t: torch.Tensor, t: torch.Tensor, x_start: torch.Tensor) -> torch.Tensor:
        """eps = (x_t - sqrt(alpha_bar) * x_0) / sqrt(1 - alpha_bar)"""
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_t)
        sqrt_one_minus_alpha_bar = self._extract(self.sqrt_one_minus_alphas_bar, t, x_t)
        return (x_t - sqrt_alpha_bar * x_start) / sqrt_one_minus_alpha_bar

    def predict_x_start_from_v(self, x_t: torch.Tensor, t: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """x_0 = sqrt(alpha_bar) * x_t - sqrt(1 - alpha_bar) * v"""
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_t)
        sqrt_one_minus_alpha_bar = self._extract(self.sqrt_one_minus_alphas_bar, t, x_t)
        return sqrt_alpha_bar * x_t - sqrt_one_minus_alpha_bar * v

    def predict_eps_from_v(self, x_t: torch.Tensor, t: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """eps = sqrt(1 - alpha_bar) * x_t + sqrt(alpha_bar) * v"""
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_t)
        sqrt_one_minus_alpha_bar = self._extract(self.sqrt_one_minus_alphas_bar, t, x_t)
        return sqrt_one_minus_alpha_bar * x_t + sqrt_alpha_bar * v

    def predict_v_from_x_start_and_eps(self, x_start: torch.Tensor, eps: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """v = sqrt(alpha_bar) * eps - sqrt(1 - alpha_bar) * x_0"""
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_start)
        sqrt_one_minus_alpha_bar = self._extract(self.sqrt_one_minus_alphas_bar, t, x_start)
        return sqrt_alpha_bar * eps - sqrt_one_minus_alpha_bar * x_start

    def get_x_start_and_eps(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        model_output: torch.Tensor,
        prediction_type: PredictionType,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert any prediction type to (x_start, eps).

        Returns:
            (x_start, eps): Both derived from the model output.
        """
        if prediction_type == PredictionType.EPSILON:
            eps = model_output
            x_start = self.predict_x_start_from_eps(x_t, t, eps)
        elif prediction_type == PredictionType.X_START:
            x_start = model_output
            eps = self.predict_eps_from_x_start(x_t, t, x_start)
        elif prediction_type == PredictionType.VELOCITY:
            v = model_output
            x_start = self.predict_x_start_from_v(x_t, t, v)
            eps = self.predict_eps_from_v(x_t, t, v)
        else:
            raise ValueError(f"Unknown prediction type: {prediction_type}")
        return x_start, eps


class DiffusionModel(nn.Module):
    """Diffusion model with configurable parameterization."""

    def __init__(
        self,
        network: nn.Module,
        betas: torch.Tensor,
        prediction_type: PredictionType = PredictionType.EPSILON,
    ):
        super().__init__()
        self.network = network
        self.prediction_type = prediction_type

        # Register buffers so the schedule moves with the module (.to(), .cuda(), ...)
        alphas = 1.0 - betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas_bar", alphas_bar)

        self.param_helper = ParameterizationHelper(alphas_bar)

    def get_training_target(self, x_start: torch.Tensor, eps: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Get the regression target for the configured prediction type."""
        if self.prediction_type == PredictionType.EPSILON:
            return eps
        elif self.prediction_type == PredictionType.X_START:
            return x_start
        elif self.prediction_type == PredictionType.VELOCITY:
            return self.param_helper.predict_v_from_x_start_and_eps(x_start, eps, t)
        else:
            raise ValueError(f"Unknown prediction type: {self.prediction_type}")

    def training_step(self, x_start: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Compute the training loss.

        Args:
            x_start: Clean images (B, C, H, W)
            t: Timesteps (B,)

        Returns:
            MSE loss
        """
        # Sample noise
        eps = torch.randn_like(x_start)

        # Create noisy sample: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps
        sqrt_alpha_bar = self.param_helper._extract(
            self.param_helper.sqrt_alphas_bar, t, x_start
        )
        sqrt_one_minus_alpha_bar = self.param_helper._extract(
            self.param_helper.sqrt_one_minus_alphas_bar, t, x_start
        )
        x_t = sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * eps

        # Get model prediction and the matching target
        model_output = self.network(x_t, t)
        target = self.get_training_target(x_start, eps, t)

        return nn.functional.mse_loss(model_output, target)


# Example usage
if __name__ == "__main__":
    # Create schedule
    T = 1000
    betas = torch.linspace(0.0001, 0.02, T)
    alphas_bar = torch.cumprod(1 - betas, dim=0)

    helper = ParameterizationHelper(alphas_bar)

    # Test conversions
    x_t = torch.randn(2, 3, 32, 32)
    t = torch.tensor([100, 500])
    eps = torch.randn_like(x_t)

    # Round-trip test: eps -> x_start -> eps
    x_start = helper.predict_x_start_from_eps(x_t, t, eps)
    eps_reconstructed = helper.predict_eps_from_x_start(x_t, t, x_start)
    print(f"Eps reconstruction error: {(eps - eps_reconstructed).abs().max():.6f}")

    # Test velocity parameterization
    v = helper.predict_v_from_x_start_and_eps(x_start, eps, t)
    x_start_from_v = helper.predict_x_start_from_v(x_t, t, v)
    print(f"x_start from v error: {(x_start - x_start_from_v).abs().max():.6f}")
```
Key Takeaways
- Three equivalent parameterizations: epsilon-prediction, x0-prediction, and v-prediction all allow computing the posterior mean
- Epsilon-prediction (DDPM default): Predicts the noise, consistent magnitude across timesteps, connects to score matching
- x0-prediction: Directly predicts the clean image, interpretable but harder at high noise levels
- v-prediction: Balanced approach, better training stability, especially useful for distillation
- Conversion formulas: You can always convert between parameterizations using the closed-form relationships
Looking Ahead: In the next section, we'll derive the training objective that trains the network to match the true posterior, showing why the simple MSE loss works so well.