Chapter 3

Parameterizing the Reverse Process

The Reverse Diffusion Process

Learning Objectives

By the end of this section, you will be able to:

  1. Explain why we need to parameterize the reverse process mean
  2. Compare epsilon-prediction, x0-prediction, and v-prediction parameterizations
  3. Derive the relationships between different parameterizations
  4. Choose the appropriate parameterization for different use cases

The Parameterization Problem

From Section 3.2, we know the tractable posterior has mean:

$$\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,\mathbf{x}_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\,\mathbf{x}_t$$

The problem: During generation, we don't have $\mathbf{x}_0$! We need to learn a neural network to predict something that lets us compute the mean.

Key Insight: We have three equivalent ways to compute the posterior mean: predict the noise, predict the clean image, or predict a combination of the two. Each choice has different training dynamics and advantages.
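To make the posterior-mean coefficients concrete, here is a minimal numeric sketch that evaluates both coefficients at a single timestep. The linear beta schedule and the particular timestep are assumptions chosen only for illustration:

```python
import torch

# Illustrative linear beta schedule (any schedule works the same way)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

t = 500  # an arbitrary interior timestep

# Coefficient of x_0 and of x_t in the posterior mean
coef_x0 = alphas_bar[t - 1].sqrt() * betas[t] / (1 - alphas_bar[t])
coef_xt = alphas[t].sqrt() * (1 - alphas_bar[t - 1]) / (1 - alphas_bar[t])
print(f"x0 coefficient: {coef_x0:.4f}, x_t coefficient: {coef_xt:.4f}")
```

Both coefficients are positive, and the $\mathbf{x}_t$ coefficient stays below 1, so the mean is a damped blend of the current noisy sample and the clean image.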

The Learned Reverse Process

We parameterize the reverse process as:

$$p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1};\, \boldsymbol{\mu}_\theta(\mathbf{x}_t, t),\, \sigma_t^2\mathbf{I})$$

The question is: how should we parameterize $\boldsymbol{\mu}_\theta$?


Epsilon-Prediction

The most common parameterization, introduced in DDPM (Ho et al., 2020), is to predict the noise that was added.

The Key Relationship

From the forward process, we know:

$$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}$$

Rearranging for $\mathbf{x}_0$:

$$\mathbf{x}_0 = \frac{\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}}{\sqrt{\bar{\alpha}_t}}$$
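A quick numeric check of this inversion, with random tensors standing in for images and an assumed linear schedule:

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(4, 3, 8, 8)   # stand-in for clean images
eps = torch.randn_like(x0)
a = alphas_bar[300]

# Forward: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps
x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps

# Invert for x0 given the true noise
x0_rec = (x_t - (1 - a).sqrt() * eps) / a.sqrt()
print(f"max reconstruction error: {(x0 - x0_rec).abs().max():.2e}")
```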

Substituting into the Posterior Mean

Substituting this into the posterior mean formula and simplifying:

$$\tilde{\boldsymbol{\mu}}_t = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\boldsymbol{\epsilon}\right)$$

Epsilon-Prediction: Train the network $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ to predict the noise $\boldsymbol{\epsilon}$. Then compute:

$$\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\right)\,$$
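The substitution can be verified numerically: computing the posterior mean from $(\mathbf{x}_0, \mathbf{x}_t)$ directly and via the simplified epsilon form gives the same tensor. A linear schedule is again assumed for illustration:

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

t = 400
a_t, ab_t, ab_prev = alphas[t], alphas_bar[t], alphas_bar[t - 1]

x0 = torch.randn(2, 3, 8, 8)
eps = torch.randn_like(x0)
x_t = ab_t.sqrt() * x0 + (1 - ab_t).sqrt() * eps

# Posterior mean written in terms of (x0, x_t)
mu_from_x0 = (ab_prev.sqrt() * betas[t] / (1 - ab_t)) * x0 \
           + (a_t.sqrt() * (1 - ab_prev) / (1 - ab_t)) * x_t

# Posterior mean in the simplified epsilon form
mu_from_eps = (x_t - betas[t] / (1 - ab_t).sqrt() * eps) / a_t.sqrt()

print(f"max difference: {(mu_from_x0 - mu_from_eps).abs().max():.2e}")
```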

Why Predict Noise?

  • Consistent magnitude: The noise $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ has unit variance regardless of timestep
  • Denoising interpretation: The network learns to "see" what noise was added
  • Connection to score: Predicting noise is equivalent to predicting the score function (Section 3.5)
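The score connection in the last bullet can be checked with autograd on the Gaussian forward kernel $q(\mathbf{x}_t|\mathbf{x}_0)$, whose score equals $-\boldsymbol{\epsilon}/\sqrt{1-\bar{\alpha}_t}$. A small sketch, dropping the normalizing constant since it does not depend on $\mathbf{x}_t$:

```python
import torch

torch.manual_seed(0)
alphas_bar = torch.cumprod(1 - torch.linspace(1e-4, 0.02, 1000), dim=0)
a = alphas_bar[500]

x0 = torch.randn(8)
eps = torch.randn(8)
x_t = (a.sqrt() * x0 + (1 - a).sqrt() * eps).requires_grad_(True)

# log q(x_t | x_0) up to a constant: Gaussian with mean sqrt(alpha_bar)*x0
# and variance (1 - alpha_bar)
log_q = -0.5 * ((x_t - a.sqrt() * x0).pow(2) / (1 - a)).sum()
(score,) = torch.autograd.grad(log_q, x_t)

# The score equals -eps / sqrt(1 - alpha_bar)
print(f"max deviation: {(score + eps / (1 - a).sqrt()).abs().max():.2e}")
```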

x0-Prediction

An alternative is to directly predict the clean image $\mathbf{x}_0$:

$$\hat{\mathbf{x}}_0 = f_\theta(\mathbf{x}_t, t)$$

Computing the Mean

Using the posterior mean formula directly:

$$\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,\hat{\mathbf{x}}_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\,\mathbf{x}_t$$

Advantages of x0-Prediction

  • Interpretable output: The network output is directly the denoised image
  • Better for low-step sampling: With few denoising steps, x0-prediction can produce cleaner results
  • Useful for guidance: Can apply constraints directly to the predicted $\hat{\mathbf{x}}_0$
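As an illustration of the guidance point, here is a hypothetical helper (`clamp_x0_step` is an invented name, not from any library) that applies a pixel-range constraint to the implied $\hat{\mathbf{x}}_0$ and re-derives a consistent noise estimate:

```python
import torch

def clamp_x0_step(x_t, a_bar, eps_pred):
    """Recover the x0 implied by an epsilon estimate, clamp it to the
    valid pixel range [-1, 1], then re-derive the epsilon that is
    consistent with the clamped x0. `eps_pred` stands in for a
    network output."""
    x0_hat = (x_t - (1 - a_bar).sqrt() * eps_pred) / a_bar.sqrt()
    x0_hat = x0_hat.clamp(-1.0, 1.0)  # the constraint
    eps_hat = (x_t - a_bar.sqrt() * x0_hat) / (1 - a_bar).sqrt()
    return x0_hat, eps_hat

# Demo with random stand-ins
torch.manual_seed(0)
a_bar = torch.tensor(0.5)
x_t = torch.randn(2, 3, 8, 8)
x0_hat, eps_hat = clamp_x0_step(x_t, a_bar, torch.randn_like(x_t))
```

By construction, remixing the clamped $\hat{\mathbf{x}}_0$ with the re-derived $\hat{\boldsymbol{\epsilon}}$ reproduces the same $\mathbf{x}_t$, so the constraint does not break consistency with the forward process.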

Disadvantages

  • Varying target magnitude: The clean image statistics depend on the data distribution
  • Harder at high noise: Predicting $\mathbf{x}_0$ from nearly pure noise is very difficult

Equivalence

x0-prediction and epsilon-prediction are mathematically equivalent. Given one, you can compute the other:

$$\hat{\mathbf{x}}_0 = \frac{\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}_\theta}{\sqrt{\bar{\alpha}_t}}$$

$$\boldsymbol{\epsilon}_\theta = \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\,\hat{\mathbf{x}}_0}{\sqrt{1-\bar{\alpha}_t}}$$

v-Prediction (Velocity)

Introduced by Salimans & Ho (2022), v-prediction uses a combination of $\mathbf{x}_0$ and $\boldsymbol{\epsilon}$:

$$\mathbf{v} = \sqrt{\bar{\alpha}_t}\,\boldsymbol{\epsilon} - \sqrt{1-\bar{\alpha}_t}\,\mathbf{x}_0$$

Intuition: The Velocity

Think of the forward process as a rotation in the plane spanned by $\mathbf{x}_0$ and $\boldsymbol{\epsilon}$:

$$\mathbf{x}_t = \cos(\phi_t)\,\mathbf{x}_0 + \sin(\phi_t)\,\boldsymbol{\epsilon}$$

where $\cos(\phi_t) = \sqrt{\bar{\alpha}_t}$ and $\sin(\phi_t) = \sqrt{1-\bar{\alpha}_t}$.

The velocity $\mathbf{v}$ is the derivative of this path:

$$\mathbf{v} = \frac{d\mathbf{x}_t}{d\phi_t} = -\sin(\phi_t)\,\mathbf{x}_0 + \cos(\phi_t)\,\boldsymbol{\epsilon}$$

Recovering x0 and epsilon

Given the predicted velocity $\mathbf{v}_\theta(\mathbf{x}_t, t)$:

$$\hat{\mathbf{x}}_0 = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\,\mathbf{v}_\theta$$

$$\hat{\boldsymbol{\epsilon}} = \sqrt{1-\bar{\alpha}_t}\,\mathbf{x}_t + \sqrt{\bar{\alpha}_t}\,\mathbf{v}_\theta$$
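These recovery formulas are simply the inverse of the rotation; a quick numeric round trip (random stand-ins, assumed linear schedule) confirms they reproduce the original $\mathbf{x}_0$ and $\boldsymbol{\epsilon}$:

```python
import torch

torch.manual_seed(0)
alphas_bar = torch.cumprod(1 - torch.linspace(1e-4, 0.02, 1000), dim=0)
a = alphas_bar[600]

x0 = torch.randn(2, 3, 8, 8)
eps = torch.randn_like(x0)
x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps
v = a.sqrt() * eps - (1 - a).sqrt() * x0   # definition of velocity

# Invert the rotation to recover x0 and eps from (x_t, v)
x0_rec = a.sqrt() * x_t - (1 - a).sqrt() * v
eps_rec = (1 - a).sqrt() * x_t + a.sqrt() * v

print(f"x0 error:  {(x0 - x0_rec).abs().max():.2e}")
print(f"eps error: {(eps - eps_rec).abs().max():.2e}")
```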

Advantages of v-Prediction

  • Balanced across timesteps: The velocity has similar magnitude at all timesteps
  • Better training stability: Gradients are more consistent
  • Improved FID scores: Empirically produces better samples

Comparing Parameterizations

| Aspect | epsilon-prediction | x0-prediction | v-prediction |
| --- | --- | --- | --- |
| Output | Noise $\boldsymbol{\epsilon}$ | Clean image $\mathbf{x}_0$ | Velocity $\mathbf{v}$ |
| Target magnitude | Unit variance (consistent) | Data-dependent (varies) | Approximately unit (consistent) |
| Low noise ($t$ near 0) | Harder (small signal) | Easier (clear target) | Balanced |
| High noise ($t$ near $T$) | Easier (large signal) | Harder (pure noise input) | Balanced |
| Used in | DDPM, Stable Diffusion | DDIM, Imagen (partial) | Progressive Distillation |
| Loss weighting | Often needs SNR weighting | Often needs SNR weighting | More uniform naturally |

Practical Recommendation: Start with epsilon-prediction (it's the most common). Consider v-prediction for improved stability or when using few sampling steps. Use x0-prediction when you need to apply constraints to the denoised output during sampling.
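The loss-weighting row can be made concrete. Because the conversions between parameterizations are linear in the prediction (for fixed $\mathbf{x}_t$), the same prediction error measured in v-space equals $(\mathrm{SNR}+1)$ times the x0-space error and $(1 + 1/\mathrm{SNR})$ times the epsilon-space error, where $\mathrm{SNR} = \bar{\alpha}_t/(1-\bar{\alpha}_t)$. A small sketch, with a synthetic perturbation standing in for a network's error:

```python
import torch

torch.manual_seed(0)
alphas_bar = torch.cumprod(1 - torch.linspace(1e-4, 0.02, 1000), dim=0)
a = alphas_bar[700]
snr = a / (1 - a)

x_t = torch.randn(2, 3, 8, 8)
v_true = torch.randn_like(x_t)
v_pred = v_true + 0.1 * torch.randn_like(x_t)  # perturbed "prediction"

def to_x0(v):   # x0 implied by (x_t, v)
    return a.sqrt() * x_t - (1 - a).sqrt() * v

def to_eps(v):  # eps implied by (x_t, v)
    return (1 - a).sqrt() * x_t + a.sqrt() * v

loss_v = (v_pred - v_true).pow(2).mean()
loss_x0 = (to_x0(v_pred) - to_x0(v_true)).pow(2).mean()
loss_eps = (to_eps(v_pred) - to_eps(v_true)).pow(2).mean()

print(f"loss_v / loss_x0  = {loss_v / loss_x0:.4f}  vs  SNR+1   = {snr + 1:.4f}")
print(f"loss_v / loss_eps = {loss_v / loss_eps:.4f}  vs  1+1/SNR = {1 + 1/snr:.4f}")
```

This is why a plain MSE on $\mathbf{v}$ behaves like an SNR-dependent reweighting of the epsilon or x0 losses.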

PyTorch Implementation

Let's implement all three parameterizations:

```python
import torch
import torch.nn as nn
from enum import Enum
from typing import Tuple

class PredictionType(Enum):
    EPSILON = "epsilon"      # Predict noise
    X_START = "x_start"      # Predict clean image
    VELOCITY = "velocity"    # Predict velocity

class ParameterizationHelper:
    """Converts between different parameterizations."""

    def __init__(self, alphas_bar: torch.Tensor):
        """
        Args:
            alphas_bar: Cumulative product of alphas (T,)
        """
        self.alphas_bar = alphas_bar
        self.sqrt_alphas_bar = torch.sqrt(alphas_bar)
        self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - alphas_bar)

    def _extract(self, tensor: torch.Tensor, t: torch.Tensor, x: torch.Tensor):
        """Extract values for batch and reshape for broadcasting."""
        values = tensor[t]
        while values.dim() < x.dim():
            values = values.unsqueeze(-1)
        return values

    def predict_x_start_from_eps(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        eps: torch.Tensor
    ) -> torch.Tensor:
        """
        x_0 = (x_t - sqrt(1-alpha_bar) * eps) / sqrt(alpha_bar)
        """
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_t)
        sqrt_one_minus_alpha_bar = self._extract(
            self.sqrt_one_minus_alphas_bar, t, x_t
        )
        return (x_t - sqrt_one_minus_alpha_bar * eps) / sqrt_alpha_bar

    def predict_eps_from_x_start(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        x_start: torch.Tensor
    ) -> torch.Tensor:
        """
        eps = (x_t - sqrt(alpha_bar) * x_0) / sqrt(1-alpha_bar)
        """
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_t)
        sqrt_one_minus_alpha_bar = self._extract(
            self.sqrt_one_minus_alphas_bar, t, x_t
        )
        return (x_t - sqrt_alpha_bar * x_start) / sqrt_one_minus_alpha_bar

    def predict_x_start_from_v(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        v: torch.Tensor
    ) -> torch.Tensor:
        """
        x_0 = sqrt(alpha_bar) * x_t - sqrt(1-alpha_bar) * v
        """
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_t)
        sqrt_one_minus_alpha_bar = self._extract(
            self.sqrt_one_minus_alphas_bar, t, x_t
        )
        return sqrt_alpha_bar * x_t - sqrt_one_minus_alpha_bar * v

    def predict_eps_from_v(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        v: torch.Tensor
    ) -> torch.Tensor:
        """
        eps = sqrt(1-alpha_bar) * x_t + sqrt(alpha_bar) * v
        """
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_t)
        sqrt_one_minus_alpha_bar = self._extract(
            self.sqrt_one_minus_alphas_bar, t, x_t
        )
        return sqrt_one_minus_alpha_bar * x_t + sqrt_alpha_bar * v

    def predict_v_from_x_start_and_eps(
        self,
        x_start: torch.Tensor,
        eps: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """
        v = sqrt(alpha_bar) * eps - sqrt(1-alpha_bar) * x_0
        """
        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_start)
        sqrt_one_minus_alpha_bar = self._extract(
            self.sqrt_one_minus_alphas_bar, t, x_start
        )
        return sqrt_alpha_bar * eps - sqrt_one_minus_alpha_bar * x_start

    def get_x_start_and_eps(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        model_output: torch.Tensor,
        prediction_type: PredictionType
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert any prediction type to x_start and eps.

        Returns:
            (x_start, eps): Both derived from model output
        """
        if prediction_type == PredictionType.EPSILON:
            eps = model_output
            x_start = self.predict_x_start_from_eps(x_t, t, eps)

        elif prediction_type == PredictionType.X_START:
            x_start = model_output
            eps = self.predict_eps_from_x_start(x_t, t, x_start)

        elif prediction_type == PredictionType.VELOCITY:
            v = model_output
            x_start = self.predict_x_start_from_v(x_t, t, v)
            eps = self.predict_eps_from_v(x_t, t, v)

        else:
            raise ValueError(f"Unknown prediction type: {prediction_type}")

        return x_start, eps


class DiffusionModel(nn.Module):
    """Diffusion model with configurable parameterization."""

    def __init__(
        self,
        network: nn.Module,
        betas: torch.Tensor,
        prediction_type: PredictionType = PredictionType.EPSILON
    ):
        super().__init__()
        self.network = network
        self.prediction_type = prediction_type

        # Register buffers
        alphas = 1.0 - betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas_bar", alphas_bar)

        self.param_helper = ParameterizationHelper(alphas_bar)

    def get_training_target(
        self,
        x_start: torch.Tensor,
        eps: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """Get the target for the given prediction type."""
        if self.prediction_type == PredictionType.EPSILON:
            return eps
        elif self.prediction_type == PredictionType.X_START:
            return x_start
        elif self.prediction_type == PredictionType.VELOCITY:
            return self.param_helper.predict_v_from_x_start_and_eps(
                x_start, eps, t
            )
        else:
            raise ValueError(f"Unknown prediction type: {self.prediction_type}")

    def training_step(
        self,
        x_start: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute training loss.

        Args:
            x_start: Clean images (B, C, H, W)
            t: Timesteps (B,)

        Returns:
            MSE loss
        """
        # Sample noise
        eps = torch.randn_like(x_start)

        # Create noisy sample
        sqrt_alpha_bar = self.param_helper._extract(
            self.param_helper.sqrt_alphas_bar, t, x_start
        )
        sqrt_one_minus_alpha_bar = self.param_helper._extract(
            self.param_helper.sqrt_one_minus_alphas_bar, t, x_start
        )
        x_t = sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * eps

        # Get model prediction
        model_output = self.network(x_t, t)

        # Get target
        target = self.get_training_target(x_start, eps, t)

        # Compute loss
        loss = nn.functional.mse_loss(model_output, target)
        return loss


# Example usage
if __name__ == "__main__":
    # Create schedule
    T = 1000
    betas = torch.linspace(0.0001, 0.02, T)
    alphas_bar = torch.cumprod(1 - betas, dim=0)

    helper = ParameterizationHelper(alphas_bar)

    # Test conversions
    x_t = torch.randn(2, 3, 32, 32)
    t = torch.tensor([100, 500])
    eps = torch.randn_like(x_t)

    # Round-trip test: eps -> x_start -> eps
    x_start = helper.predict_x_start_from_eps(x_t, t, eps)
    eps_reconstructed = helper.predict_eps_from_x_start(x_t, t, x_start)
    print(f"Eps reconstruction error: {(eps - eps_reconstructed).abs().max():.6f}")

    # Test velocity parameterization
    v = helper.predict_v_from_x_start_and_eps(x_start, eps, t)
    x_start_from_v = helper.predict_x_start_from_v(x_t, t, v)
    print(f"x_start from v error: {(x_start - x_start_from_v).abs().max():.6f}")
```

Key Takeaways

  1. Three equivalent parameterizations: epsilon-prediction, x0-prediction, and v-prediction all allow computing the posterior mean
  2. Epsilon-prediction (DDPM default): Predicts the noise, consistent magnitude across timesteps, connects to score matching
  3. x0-prediction: Directly predicts the clean image, interpretable but harder at high noise levels
  4. v-prediction: Balanced approach, better training stability, especially useful for distillation
  5. Conversion formulas: You can always convert between parameterizations using the closed-form relationships

Looking Ahead: In the next section, we'll derive the training objective that trains the network to match the true posterior, showing why the simple MSE loss works so well.