Chapter 3

The Training Objective Derivation

The Reverse Diffusion Process

Learning Objectives

By the end of this section, you will be able to:

  1. Derive the Evidence Lower Bound (ELBO) for diffusion models from first principles
  2. Decompose the ELBO into reconstruction, prior, and denoising terms
  3. Simplify to the practical loss: $L_{\text{simple}} = \mathbb{E}[\|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\|^2]$
  4. Understand every step of the derivation

Starting with the Variational Bound

Our goal is to maximize the log-likelihood of the data:

$$\log p_\theta(\mathbf{x}_0)$$

Since this is intractable (it requires marginalizing over all latent variables $\mathbf{x}_{1:T}$), we derive a variational lower bound (ELBO).

Introducing the Variational Distribution

We introduce the forward process $q(\mathbf{x}_{1:T}|\mathbf{x}_0)$ as our variational distribution:

$$\log p_\theta(\mathbf{x}_0) = \log \int p_\theta(\mathbf{x}_{0:T})\, d\mathbf{x}_{1:T}$$

Multiplying and dividing by $q$:

$$= \log \int \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)} q(\mathbf{x}_{1:T}|\mathbf{x}_0)\, d\mathbf{x}_{1:T}$$

By Jensen's inequality (since log is concave):

$$\geq \int q(\mathbf{x}_{1:T}|\mathbf{x}_0) \log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}\, d\mathbf{x}_{1:T}$$

$$= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}\left[\log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}\right] =: \text{ELBO}$$

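The ELBO: Since $\log p_\theta(\mathbf{x}_0) = \text{ELBO} + D_{\text{KL}}(q(\mathbf{x}_{1:T}|\mathbf{x}_0) \| p_\theta(\mathbf{x}_{1:T}|\mathbf{x}_0))$ and the KL divergence is non-negative, the ELBO never exceeds the log-likelihood. The gap between them is exactly this KL divergence between $q$ and the true posterior. Raising the ELBO therefore either raises the log-likelihood or tightens the bound.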
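A toy model makes the bound and its gap concrete. The minimal sketch below assumes a one-dimensional latent-variable model chosen purely for illustration (it is not a diffusion model): $z \sim \mathcal{N}(0, 1)$ and $x \mid z \sim \mathcal{N}(z, 1)$. For this model $p(x) = \mathcal{N}(0, 2)$ and the true posterior is $\mathcal{N}(x/2, 1/2)$. The code computes the log-evidence, the ELBO, and the KL to the true posterior in closed form for three choices of $q(z) = \mathcal{N}(m, s^2)$. The function names are hypothetical.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def log_evidence(x: float) -> float:
    """Exact log p(x) for the toy model, where p(x) = N(0, 2)."""
    return -0.5 * math.log(2.0 * math.pi * 2.0) - x * x / 4.0


def elbo(x: float, m: float, s2: float) -> float:
    """ELBO = E_q[log p(x|z) + log p(z) - log q(z)] for q(z) = N(m, s2), in closed form."""
    e_log_lik = -0.5 * LOG_2PI - 0.5 * ((x - m) ** 2 + s2)   # E_q[log N(x; z, 1)]
    e_log_prior = -0.5 * LOG_2PI - 0.5 * (m * m + s2)        # E_q[log N(z; 0, 1)]
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * s2)    # -E_q[log q(z)]
    return e_log_lik + e_log_prior + entropy


def kl_to_posterior(x: float, m: float, s2: float) -> float:
    """KL(q || p(z|x)), where the true posterior is p(z|x) = N(x/2, 1/2)."""
    post_mean, post_var = x / 2.0, 0.5
    return 0.5 * (math.log(post_var / s2) + (s2 + (m - post_mean) ** 2) / post_var - 1.0)


x = 1.3
for m, s2 in [(0.65, 0.5), (0.0, 1.0), (1.0, 0.2)]:
    gap = log_evidence(x) - elbo(x, m, s2)
    print(f"q = N({m}, {s2}): ELBO = {elbo(x, m, s2):.4f}, "
          f"gap = {gap:.4f}, KL = {kl_to_posterior(x, m, s2):.4f}")
```

The first $q$ equals the true posterior, so its gap is zero. For the other two, the printed gap matches the KL divergence.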

ELBO Decomposition

Now we expand the joint distributions:

$$p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T) \prod_{t=1}^{T} p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$$

$$q(\mathbf{x}_{1:T}|\mathbf{x}_0) = \prod_{t=1}^{T} q(\mathbf{x}_t|\mathbf{x}_{t-1})$$
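The forward factorization is a Markov chain of Gaussian steps. Assuming the standard forward step $q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}(\sqrt{\alpha_t}\mathbf{x}_{t-1}, \beta_t\mathbf{I})$, chaining the steps keeps $q(\mathbf{x}_t \mid \mathbf{x}_0)$ Gaussian. The sketch below propagates its mean coefficient and variance one step at a time and checks them against $\sqrt{\bar{\alpha}_t}$ and $1-\bar{\alpha}_t$, the closed form used in the training algorithm. It assumes the linear schedule from $10^{-4}$ to $0.02$ with $T = 1000$, the setting in this section's implementation.

```python
import math

T = 1000
# Linear schedule, mirroring torch.linspace(1e-4, 0.02, T) in this section's implementation
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

mean_coef, var = 1.0, 0.0      # x_0 itself: coefficient 1 on x_0, zero variance
alpha_bar = 1.0
max_err = 0.0
for beta in betas:
    alpha = 1.0 - beta
    # One forward step x_t = sqrt(alpha_t) x_{t-1} + sqrt(beta_t) noise:
    # the mean coefficient is scaled by sqrt(alpha_t), the variance becomes alpha_t var + beta_t
    mean_coef *= math.sqrt(alpha)
    var = alpha * var + beta
    alpha_bar *= alpha
    # Compare with the closed form q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I)
    max_err = max(max_err, abs(mean_coef - math.sqrt(alpha_bar)), abs(var - (1.0 - alpha_bar)))

print(f"alpha_bar_T = {alpha_bar:.2e}, max deviation from closed form = {max_err:.1e}")
```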

Substituting into the ELBO:

$$\text{ELBO} = \mathbb{E}_q\left[\log \frac{p(\mathbf{x}_T) \prod_{t=1}^{T} p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{\prod_{t=1}^{T} q(\mathbf{x}_t|\mathbf{x}_{t-1})}\right]$$

Rearranging the Terms

Rewriting each forward step with $t \ge 2$ via Bayes' rule (after conditioning it on $\mathbf{x}_0$) and collapsing the resulting telescoping sum, the ELBO decomposes into:

$$\begin{aligned}
\text{ELBO} = {} & \underbrace{\mathbb{E}_q[\log p_\theta(\mathbf{x}_0|\mathbf{x}_1)]}_{L_0:\ \text{Reconstruction}} - \underbrace{D_{\text{KL}}(q(\mathbf{x}_T|\mathbf{x}_0) \| p(\mathbf{x}_T))}_{L_T:\ \text{Prior matching}} \\
& - \sum_{t=2}^{T} \underbrace{\mathbb{E}_q\left[D_{\text{KL}}(q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) \| p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t))\right]}_{L_{t-1}:\ \text{Denoising matching}}
\end{aligned}$$
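For reference, here is the standard algebra behind this decomposition, written out. For $t \ge 2$ the Markov property lets each forward step be conditioned on $\mathbf{x}_0$ and inverted with Bayes' rule. The leftover ratios of marginals then telescope:

```latex
\begin{aligned}
q(\mathbf{x}_t \mid \mathbf{x}_{t-1})
  &= q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0)
   = \frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\, q(\mathbf{x}_t \mid \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}, \qquad t \ge 2, \\[4pt]
\log \frac{p(\mathbf{x}_T) \prod_{t=1}^{T} p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}{\prod_{t=1}^{T} q(\mathbf{x}_t \mid \mathbf{x}_{t-1})}
  &= \log p(\mathbf{x}_T)
   + \log \frac{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}{q(\mathbf{x}_1 \mid \mathbf{x}_0)}
   + \sum_{t=2}^{T} \log \frac{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}
   + \underbrace{\sum_{t=2}^{T} \log \frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}{q(\mathbf{x}_t \mid \mathbf{x}_0)}}_{=\, \log q(\mathbf{x}_1 \mid \mathbf{x}_0) - \log q(\mathbf{x}_T \mid \mathbf{x}_0)} \\[4pt]
  &= \log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)
   + \log \frac{p(\mathbf{x}_T)}{q(\mathbf{x}_T \mid \mathbf{x}_0)}
   + \sum_{t=2}^{T} \log \frac{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}
\end{aligned}
```

Take the expectation under $q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)$. The middle term becomes $-D_{\text{KL}}(q(\mathbf{x}_T \mid \mathbf{x}_0) \| p(\mathbf{x}_T))$, and each summand becomes $-\mathbb{E}_q[D_{\text{KL}}(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) \| p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t))]$, which gives the three terms above.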

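| Term | Name | Interpretation |
|---|---|---|
| $L_0$ | Reconstruction | How well can we reconstruct $\mathbf{x}_0$ from $\mathbf{x}_1$? |
| $L_T$ | Prior matching | How close is $q(\mathbf{x}_T \mid \mathbf{x}_0)$ to the standard Gaussian prior $p(\mathbf{x}_T)$? |
| $L_{t-1}$ | Denoising matching | How well does $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ match the true reverse posterior $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$? |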

The L_T Term

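For large $T$, $q(\mathbf{x}_T|\mathbf{x}_0) = \mathcal{N}(\sqrt{\bar{\alpha}_T}\mathbf{x}_0, (1-\bar{\alpha}_T)\mathbf{I}) \approx \mathcal{N}(\mathbf{0}, \mathbf{I})$, and $p(\mathbf{x}_T) = \mathcal{N}(\mathbf{0}, \mathbf{I})$, so this term is close to zero. More importantly, with a fixed noise schedule it contains no learnable parameters. It is therefore constant with respect to $\theta$ and is ignored during training.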
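A quick numerical check of how small this term is. The sketch assumes the linear schedule from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$ with $T = 1000$ (the setting used in this section's implementation) and data scaled to $[-1, 1]$. It computes the per-dimension KL between $q(\mathbf{x}_T \mid \mathbf{x}_0)$ and $\mathcal{N}(0, 1)$ for a few pixel values, including the worst case $|x_0| = 1$.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

alpha_bar_T = 1.0
for beta in betas:
    alpha_bar_T *= 1.0 - beta


def prior_kl_per_dim(x0: float, alpha_bar: float) -> float:
    """KL(N(sqrt(a) x0, 1 - a) || N(0, 1)) for a single dimension."""
    mean, var = math.sqrt(alpha_bar) * x0, 1.0 - alpha_bar
    return 0.5 * (var + mean * mean - 1.0 - math.log(var))


# Pixel values are assumed to lie in [-1, 1]; |x0| = 1 gives the largest KL.
worst = max(prior_kl_per_dim(x0, alpha_bar_T) for x0 in (-1.0, 0.0, 1.0))
print(f"alpha_bar_T = {alpha_bar_T:.2e}, worst-case KL per dimension = {worst:.2e}")
```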

The Denoising Term

The key term is $L_{t-1}$. It measures how well our learned reverse step $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ matches the true reverse posterior $q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)$.

KL Between Gaussians

Both distributions are Gaussian:

  • True posterior: $q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\tilde{\boldsymbol{\mu}}_t, \tilde{\beta}_t\mathbf{I})$
  • Learned reverse: $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t) = \mathcal{N}(\boldsymbol{\mu}_\theta, \sigma_t^2\mathbf{I})$

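Neither variance depends on $\theta$ ($\sigma_t^2$ is fixed, for example to $\beta_t$ or $\tilde{\beta}_t$). The KL divergence between these two isotropic Gaussians therefore depends on $\theta$ only through the means: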

$$D_{\text{KL}}(q \| p_\theta) = \frac{1}{2\sigma_t^2} \|\tilde{\boldsymbol{\mu}}_t - \boldsymbol{\mu}_\theta\|^2 + C$$

where $C$ depends only on $\tilde{\beta}_t$, $\sigma_t^2$, and the data dimension, not on $\theta$.
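A small sketch checks this claim numerically with illustrative numbers (not from the source). For isotropic Gaussians with fixed variances $\tilde{\beta}_t$ and $\sigma_t^2$, subtracting $\frac{1}{2\sigma_t^2}\|\tilde{\boldsymbol{\mu}}_t - \boldsymbol{\mu}_\theta\|^2$ from the exact KL should leave the same constant $C$ for every $\boldsymbol{\mu}_\theta$.

```python
import math


def kl_isotropic(mu_q, var_q, mu_p, var_p):
    """Exact KL(N(mu_q, var_q I) || N(mu_p, var_p I)) for d-dimensional isotropic Gaussians."""
    d = len(mu_q)
    sq_dist = sum((a - b) ** 2 for a, b in zip(mu_q, mu_p))
    return 0.5 * (d * var_q / var_p - d + d * math.log(var_p / var_q) + sq_dist / var_p)


beta_tilde, sigma2 = 0.012, 0.02          # illustrative fixed variances
mu_tilde = [0.3, -1.2, 0.5]                # illustrative true posterior mean
residuals = []
for mu_theta in ([0.3, -1.2, 0.5], [0.0, 0.0, 0.0], [1.5, -0.7, 2.0]):
    sq = sum((a - b) ** 2 for a, b in zip(mu_tilde, mu_theta))
    kl = kl_isotropic(mu_tilde, beta_tilde, mu_theta, sigma2)
    residuals.append(kl - sq / (2.0 * sigma2))   # the constant C, if the claim holds

print([f"{c:.6f}" for c in residuals])
```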

Substituting the Parameterization

Using the epsilon-prediction parameterization from Section 3.3:

$$\tilde{\boldsymbol{\mu}}_t = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}\right)$$

$$\boldsymbol{\mu}_\theta = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\right)$$
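The $\boldsymbol{\epsilon}$-form of $\tilde{\boldsymbol{\mu}}_t$ follows from the posterior mean written in terms of $\mathbf{x}_0$ and $\mathbf{x}_t$: $\tilde{\boldsymbol{\mu}}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\mathbf{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t$ (Ho et al., 2020). Substituting $\mathbf{x}_0 = (\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon})/\sqrt{\bar{\alpha}_t}$ gives the expression above. The scalar sketch below checks that the two forms agree at several timesteps. It assumes the linear schedule with $T = 1000$, and the values of $x_0$ and $\epsilon$ are arbitrary.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars, prod = [], 1.0
for a in alphas:
    prod *= a
    alpha_bars.append(prod)

x0, eps = 0.7, -1.3                        # arbitrary scalar data value and noise
max_err = 0.0
for i in (0, 1, 10, 500, T - 1):           # 0-based index i corresponds to timestep t = i + 1
    beta, alpha, ab = betas[i], alphas[i], alpha_bars[i]
    ab_prev = alpha_bars[i - 1] if i > 0 else 1.0    # alpha_bar_0 = 1 by convention
    x_t = math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * eps
    # Posterior mean written with x_0 and x_t
    mu_x0_form = (math.sqrt(ab_prev) * beta / (1.0 - ab)) * x0 \
        + (math.sqrt(alpha) * (1.0 - ab_prev) / (1.0 - ab)) * x_t
    # Posterior mean written with x_t and the noise (epsilon-parameterization)
    mu_eps_form = (x_t - beta / math.sqrt(1.0 - ab) * eps) / math.sqrt(alpha)
    max_err = max(max_err, abs(mu_x0_form - mu_eps_form))

print(f"max |difference| between the two forms = {max_err:.2e}")
```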

The difference is:

$$\tilde{\boldsymbol{\mu}}_t - \boldsymbol{\mu}_\theta = \frac{\beta_t}{\sqrt{\alpha_t}\sqrt{1-\bar{\alpha}_t}}(\boldsymbol{\epsilon}_\theta - \boldsymbol{\epsilon})$$

Therefore:

$$L_{t-1} = \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1-\bar{\alpha}_t)} \|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\|^2\right]$$
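The same kind of check confirms the weighting. The sketch takes illustrative vectors for $\boldsymbol{\epsilon}$ and a hypothetical prediction $\boldsymbol{\epsilon}_\theta$ at one timestep. It verifies that $\frac{1}{2\sigma_t^2}\|\tilde{\boldsymbol{\mu}}_t - \boldsymbol{\mu}_\theta\|^2$ equals $w(t)\,\|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\|^2$ with $w(t) = \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}$. It assumes the linear schedule and the common fixed choice $\sigma_t^2 = \beta_t$.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

i = 249                                   # 0-based index, i.e. timestep t = 250
alpha_bar = 1.0
for beta_s in betas[: i + 1]:
    alpha_bar *= 1.0 - beta_s
beta, alpha = betas[i], 1.0 - betas[i]
sigma2 = beta                             # assumed fixed reverse variance sigma_t^2 = beta_t

x0 = [0.4, -0.9, 1.1]                     # illustrative clean sample
eps = [0.5, -1.0, 0.2]                    # true noise
eps_theta = [0.3, -0.6, 0.0]              # hypothetical network prediction
x_t = [math.sqrt(alpha_bar) * a + math.sqrt(1.0 - alpha_bar) * e for a, e in zip(x0, eps)]

# Both means in the epsilon-parameterization, using the true and the predicted noise
coef = beta / math.sqrt(1.0 - alpha_bar)
mu_tilde = [(x - coef * e) / math.sqrt(alpha) for x, e in zip(x_t, eps)]
mu_theta = [(x - coef * e) / math.sqrt(alpha) for x, e in zip(x_t, eps_theta)]

kl_mean_term = sum((a - b) ** 2 for a, b in zip(mu_tilde, mu_theta)) / (2.0 * sigma2)
w_t = beta ** 2 / (2.0 * sigma2 * alpha * (1.0 - alpha_bar))
weighted_mse = w_t * sum((a - b) ** 2 for a, b in zip(eps, eps_theta))
print(f"w(t) = {w_t:.5f}, KL mean term = {kl_mean_term:.6e}, weighted MSE = {weighted_mse:.6e}")
```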

The Key Result: Minimizing the KL divergence between the true and learned reverse steps is equivalent to minimizing the squared error between the true noise and the predicted noise, up to a time-dependent weighting.

The Simplified Loss

Ho et al. (2020) found that dropping the time-dependent weighting and training with the unweighted loss gives better sample quality in practice:

$$L_{\text{simple}} = \mathbb{E}_{t, \mathbf{x}_0, \boldsymbol{\epsilon}}\left[\|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}, t)\|^2\right]$$

The Training Algorithm

  1. Sample a training image $\mathbf{x}_0 \sim q(\mathbf{x}_0)$
  2. Sample a random timestep $t \sim \text{Uniform}(\{1, \dots, T\})$
  3. Sample noise $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
  4. Create the noisy sample: $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}$
  5. Compute the loss: $\|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\|^2$
  6. Take a gradient descent step on $\theta$ using the gradient of this loss
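The steps above map directly onto code. Below is a framework-free sketch of steps 1 to 5, assuming a toy "image" stored as a list of floats so that no dependencies are needed. `predict_noise` is a hypothetical stand-in for $\boldsymbol{\epsilon}_\theta$, and the 1-based timestep $t$ is stored at 0-based index $t - 1$. The gradient step (step 6) is left to an autodiff framework, as in the PyTorch implementation later in this section.

```python
import math
import random

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bars, prod = [], 1.0
for beta in betas:
    prod *= 1.0 - beta
    alpha_bars.append(prod)                           # alpha_bars[t - 1] holds alpha_bar_t


def training_loss(x0, predict_noise, rng):
    """Steps 2-5 for one sample x0 (a list of floats); returns (loss, t, eps)."""
    t = rng.randint(1, T)                             # step 2: t ~ Uniform({1, ..., T}), inclusive
    eps = [rng.gauss(0.0, 1.0) for _ in x0]           # step 3: eps ~ N(0, I)
    ab = alpha_bars[t - 1]                            # 1-based t -> 0-based index
    x_t = [math.sqrt(ab) * a + math.sqrt(1.0 - ab) * e for a, e in zip(x0, eps)]  # step 4
    pred = predict_noise(x_t, t)
    loss = sum((e - p) ** 2 for e, p in zip(eps, pred))   # step 5: squared error
    return loss, t, eps


rng = random.Random(0)
x0 = [0.2, -0.5, 0.9, 0.0]                            # step 1: one toy "image" with 4 pixels


def zero_model(x_t, t):
    """Hypothetical stand-in for eps_theta that always predicts zero noise."""
    return [0.0] * len(x_t)


loss, t, eps = training_loss(x0, zero_model, rng)
print(f"t = {t}, loss = {loss:.4f}")
```

With the zero predictor, the loss is simply $\|\boldsymbol{\epsilon}\|^2$. A perfect predictor would drive it to zero.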

Sampling t Uniformly

The simple loss samples t uniformly from 1 to T. This is different from importance-weighted sampling, which we'll discuss in Chapter 4. Uniform sampling works well but may under-emphasize certain timesteps.

Why the Simple Loss Works

The simple loss ignores the weighting $w(t) = \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1-\bar{\alpha}_t)}$. Why does this work?

Empirical Observations

  • Sample quality: In Ho et al.'s experiments, the unweighted loss gave better sample quality (lower FID) than the true VLB
  • Stability: Avoids numerical issues at extreme timesteps
  • Simplicity: One less hyperparameter to tune

Theoretical Interpretation

The simple loss can be viewed as:

  • Denoising score matching: Learning the score function at all noise levels equally
  • Re-weighted VLB: Relative to the VLB, it down-weights the low-noise (small $t$) terms, so the network concentrates on the harder denoising tasks at larger $t$
  • Multi-scale denoising: Training a hierarchy of denoisers
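| Loss type | Weighting $w(t)$ | Emphasis |
|---|---|---|
| VLB ($L_{t-1}$) | $\beta_t^2 / (2\sigma_t^2 \alpha_t (1-\bar{\alpha}_t))$ | Depends on the schedule (for the linear schedule with $\sigma_t^2 = \beta_t$, largest at small $t$) |
| Simple ($L_{\text{simple}}$) | $1$ (uniform) | All timesteps equally |
| SNR-weighted | $\text{SNR}(t) = \bar{\alpha}_t / (1-\bar{\alpha}_t)$ | High-SNR (low-noise) timesteps |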
Practical Insight: While the VLB is theoretically motivated, the simple loss works better in practice for perceptual quality. The VLB optimizes likelihood, but likelihood doesn't always correlate with visual quality.
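To see how these weightings differ across timesteps, here is a sketch that evaluates them on the linear schedule used in this section. It assumes the fixed reverse variance $\sigma_t^2 = \beta_t$, so $w(t) = \beta_t / (2\alpha_t(1-\bar{\alpha}_t))$. Under these assumptions the VLB weight is about 0.5 at $t = 1$ and about 0.01 at $t = T$. Relative to the VLB, the unweighted loss therefore shifts emphasis toward the high-noise end.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

vlb_w, snr_w, alpha_bar = [], [], 1.0
for beta in betas:
    alpha = 1.0 - beta
    alpha_bar *= alpha
    sigma2 = beta                                     # assumed fixed variance sigma_t^2 = beta_t
    vlb_w.append(beta ** 2 / (2.0 * sigma2 * alpha * (1.0 - alpha_bar)))
    snr_w.append(alpha_bar / (1.0 - alpha_bar))       # SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
simple_w = [1.0] * T                                  # L_simple: every timestep weighted equally

for t in (1, 10, 100, 500, 1000):                     # 1-based timesteps
    i = t - 1
    print(f"t={t:4d}  VLB={vlb_w[i]:.4f}  simple={simple_w[i]:.1f}  SNR={snr_w[i]:.3e}")
```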

PyTorch Implementation

Let's implement the complete training objective:

```python
import torch
import torch.nn as nn
from typing import Dict, Optional


class DiffusionLoss:
    """Computes the diffusion training loss."""

    def __init__(
        self,
        betas: torch.Tensor,
        loss_type: str = "simple",  # "simple", "vlb", or "hybrid"
        vlb_weight: float = 0.001
    ):
        """
        Args:
            betas: Noise schedule (T,); index i holds beta for timestep t = i + 1
            loss_type: Type of loss to compute
            vlb_weight: Weight for VLB term in hybrid loss
        """
        self.loss_type = loss_type
        self.vlb_weight = vlb_weight

        # Precompute schedule values
        self.betas = betas
        self.alphas = 1.0 - betas
        self.alphas_bar = torch.cumprod(self.alphas, dim=0)
        self.alphas_bar_prev = torch.cat([
            torch.ones(1, dtype=betas.dtype, device=betas.device),  # alpha_bar_0 = 1
            self.alphas_bar[:-1]
        ])

        # sqrt values for forward process
        self.sqrt_alphas_bar = torch.sqrt(self.alphas_bar)
        self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alphas_bar)

        # Posterior variance beta_tilde_t (zero at t = 1)
        self.posterior_variance = (
            (1.0 - self.alphas_bar_prev) /
            (1.0 - self.alphas_bar) *
            self.betas
        )

        # VLB weights: beta_t^2 / (2 * sigma_t^2 * alpha_t * (1 - alpha_bar_t))
        # Using the fixed reverse variance sigma_t^2 = beta_t
        sigma2 = self.betas
        self.vlb_weights = (
            self.betas ** 2 /
            (2.0 * sigma2 * self.alphas * (1.0 - self.alphas_bar))
        )
        # Edge case: index 0 (t = 1) belongs to the reconstruction term L_0, and its
        # weight would divide by zero if sigma_t^2 = posterior_variance were used,
        # so reuse the neighbouring weight
        self.vlb_weights[0] = self.vlb_weights[1]

    def _extract(self, tensor: torch.Tensor, t: torch.Tensor, x: torch.Tensor):
        """Gather per-timestep values and reshape to (B, 1, ..., 1) for broadcasting."""
        values = tensor.to(x.device)[t]
        while values.dim() < x.dim():
            values = values.unsqueeze(-1)
        return values

    def q_sample(
        self,
        x_start: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward diffusion: sample x_t from q(x_t | x_0).

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alpha_bar = self._extract(self.sqrt_alphas_bar, t, x_start)
        sqrt_one_minus_alpha_bar = self._extract(
            self.sqrt_one_minus_alphas_bar, t, x_start
        )

        return sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise

    def compute_loss(
        self,
        model: nn.Module,
        x_start: torch.Tensor,
        t: Optional[torch.Tensor] = None,
        noise: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the training loss.

        Args:
            model: The noise prediction network
            x_start: Clean images (B, C, H, W)
            t: Timestep indices (B,) in [0, T). If None, sampled uniformly.
            noise: Noise to add (B, C, H, W). If None, sampled.

        Returns:
            Dictionary with loss terms
        """
        B = x_start.shape[0]
        device = x_start.device
        T = len(self.betas)

        # Sample timesteps if not provided (0-based: index i is timestep t = i + 1)
        if t is None:
            t = torch.randint(0, T, (B,), device=device)

        # Sample noise if not provided
        if noise is None:
            noise = torch.randn_like(x_start)

        # Create noisy samples
        x_t = self.q_sample(x_start, t, noise)

        # Get model prediction
        pred_noise = model(x_t, t)

        # Simple loss: unweighted MSE per sample, averaged over all non-batch dims
        reduce_dims = list(range(1, x_start.dim()))
        mse_loss = (noise - pred_noise).pow(2).mean(dim=reduce_dims)

        if self.loss_type == "simple":
            loss = mse_loss.mean()
            return {"loss": loss, "mse": mse_loss.mean()}

        elif self.loss_type == "vlb":
            # Weighted by VLB coefficients: one weight per sample, shape (B,)
            weights = self.vlb_weights.to(device)[t]
            weighted_loss = (weights * mse_loss).mean()
            return {"loss": weighted_loss, "mse": mse_loss.mean()}

        elif self.loss_type == "hybrid":
            # Simple loss plus a small VLB-weighted term
            simple_loss = mse_loss.mean()
            weights = self.vlb_weights.to(device)[t]
            vlb_loss = (weights * mse_loss).mean()
            loss = simple_loss + self.vlb_weight * vlb_loss
            return {
                "loss": loss,
                "simple": simple_loss,
                "vlb": vlb_loss,
                "mse": mse_loss.mean()
            }

        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")


class SimpleDiffusionTrainer:
    """Simple training loop for diffusion models."""

    def __init__(
        self,
        model: nn.Module,
        betas: torch.Tensor,
        lr: float = 2e-4,
        loss_type: str = "simple"
    ):
        self.model = model
        self.loss_fn = DiffusionLoss(betas, loss_type)
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def train_step(self, x_start: torch.Tensor) -> Dict[str, float]:
        """Single training step."""
        self.model.train()
        self.optimizer.zero_grad()

        losses = self.loss_fn.compute_loss(self.model, x_start)
        losses["loss"].backward()
        self.optimizer.step()

        return {k: v.item() for k, v in losses.items()}


# Example usage
if __name__ == "__main__":
    # Simple test network
    class SimpleNoisePredictor(nn.Module):
        def __init__(self, channels=3, dim=64, num_timesteps=1000):
            super().__init__()
            # One learned embedding per timestep index (requires dim >= channels)
            self.time_embed = nn.Embedding(num_timesteps, dim)
            self.net = nn.Sequential(
                nn.Conv2d(channels, dim, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(dim, dim, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(dim, channels, 3, padding=1),
            )

        def forward(self, x, t):
            # Very simple - a real U-Net would be much more complex.
            # Use the first C embedding channels as a per-channel, time-dependent bias.
            time_emb = self.time_embed(t)[:, : x.shape[1], None, None]
            return self.net(x) + time_emb

    # Setup
    T = 1000
    betas = torch.linspace(0.0001, 0.02, T)
    model = SimpleNoisePredictor(num_timesteps=T)

    # Test loss computation
    loss_fn = DiffusionLoss(betas, loss_type="simple")
    x_start = torch.randn(4, 3, 32, 32)
    losses = loss_fn.compute_loss(model, x_start)

    print(f"Simple loss: {losses['loss'].item():.4f}")
    print(f"MSE: {losses['mse'].item():.4f}")

    # Test hybrid loss
    loss_fn_hybrid = DiffusionLoss(betas, loss_type="hybrid", vlb_weight=0.001)
    losses_hybrid = loss_fn_hybrid.compute_loss(model, x_start)
    print(f"Hybrid loss: {losses_hybrid['loss'].item():.4f}")

    # One optimizer step with the trainer
    trainer = SimpleDiffusionTrainer(model, betas)
    print(trainer.train_step(x_start))
```

Key Takeaways

  1. ELBO derivation: We maximize a lower bound on log p(x_0) using the forward process as variational distribution
  2. Three terms: Reconstruction (L_0), prior matching (L_T, typically ignored), and denoising matching (L_{t-1})
  3. KL to MSE: With fixed variances, the KL divergence between the Gaussians reduces to a scaled squared distance between the means, i.e. a time-weighted MSE on the noise
  4. Simple loss: $L_{\text{simple}} = \mathbb{E}[\|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\|^2]$ gives better sample quality than the weighted VLB in practice
  5. Training is simple: Sample x_0, t, noise; predict noise; compute MSE
Looking Ahead: In the next section, we'll connect the epsilon-prediction formulation to score matching, revealing that the network is learning the gradient of the log probability density.