Chapter 3

Deriving the True Reverse Distribution

The Reverse Diffusion Process

Learning Objectives

By the end of this section, you will be able to:

  1. Derive the tractable posterior $q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)$ using Bayes' theorem
  2. Compute the posterior mean $\tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0)$ and variance $\tilde{\beta}_t$
  3. Understand why conditioning on $\mathbf{x}_0$ transforms an intractable problem into a tractable one
  4. Implement the posterior formulas in PyTorch

The Bayesian Setup

In Section 3.1, we saw that the true reverse $q(\mathbf{x}_{t-1}|\mathbf{x}_t)$ is intractable because it requires marginalizing over the unknown data distribution. However, if we condition on knowing the original data $\mathbf{x}_0$, the situation changes dramatically.

The Key Insight: Given $\mathbf{x}_0$, both the forward marginal $q(\mathbf{x}_t|\mathbf{x}_0)$ and the transition $q(\mathbf{x}_t|\mathbf{x}_{t-1})$ are Gaussian. Bayes' theorem applied to Gaussians yields another Gaussian!

We want to compute:

$$q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) = \frac{q(\mathbf{x}_t|\mathbf{x}_{t-1}, \mathbf{x}_0)\, q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)}$$

What We Know

From Chapter 2, we have closed-form expressions for all terms:

| Distribution | Formula | Mean | Variance |
| --- | --- | --- | --- |
| $q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$ | $\mathcal{N}(\sqrt{1-\beta_t}\,\mathbf{x}_{t-1},\; \beta_t \mathbf{I})$ | $\sqrt{1-\beta_t}\,\mathbf{x}_{t-1}$ | $\beta_t$ |
| $q(\mathbf{x}_{t-1} \vert \mathbf{x}_0)$ | $\mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}\,\mathbf{x}_0,\; (1-\bar{\alpha}_{t-1})\mathbf{I})$ | $\sqrt{\bar{\alpha}_{t-1}}\,\mathbf{x}_0$ | $1-\bar{\alpha}_{t-1}$ |
| $q(\mathbf{x}_t \vert \mathbf{x}_0)$ | $\mathcal{N}(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\; (1-\bar{\alpha}_t)\mathbf{I})$ | $\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0$ | $1-\bar{\alpha}_t$ |
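As a quick sanity check on these closed forms, composing the one-step transitions should reproduce the marginal. The Monte Carlo sketch below uses plain NumPy in a toy 1-D setup; the schedule length, $x_0$ value, and sample count are arbitrary demo choices:

```python
import numpy as np

# Monte Carlo check: iterating the one-step transition q(x_t | x_{t-1})
# many times should reproduce the closed-form marginal q(x_T | x_0).
rng = np.random.default_rng(0)
T = 50
betas = np.linspace(1e-4, 0.02, T)   # linear schedule (demo assumption)
alphas_bar = np.cumprod(1.0 - betas)

x0 = 1.5
n = 200_000
x = np.full(n, x0)
for beta in betas:
    # x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise
    x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * rng.standard_normal(n)

# Closed form: mean sqrt(alpha_bar_T) * x0, variance 1 - alpha_bar_T
print(x.mean(), np.sqrt(alphas_bar[-1]) * x0)   # should agree closely
print(x.var(), 1.0 - alphas_bar[-1])
```

The empirical mean and variance of the simulated samples match the table's formulas to within Monte Carlo error.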

Markov Property

Note that $q(\mathbf{x}_t|\mathbf{x}_{t-1}, \mathbf{x}_0) = q(\mathbf{x}_t|\mathbf{x}_{t-1})$ due to the Markov property - the forward process depends only on the previous step.

Deriving the Posterior

Let's work through the derivation step by step. For clarity, we'll work with a single dimension; the multivariate case follows identically due to the diagonal covariance structure.

Step 1: Write Out the Gaussians

Using Bayes' theorem:

$$q(x_{t-1}|x_t, x_0) \propto q(x_t|x_{t-1})\, q(x_{t-1}|x_0)$$

The Gaussian PDFs are:

$$q(x_t|x_{t-1}) \propto \exp\left(-\frac{(x_t - \sqrt{1-\beta_t}\,x_{t-1})^2}{2\beta_t}\right)$$

$$q(x_{t-1}|x_0) \propto \exp\left(-\frac{(x_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\,x_0)^2}{2(1-\bar{\alpha}_{t-1})}\right)$$

Step 2: Combine the Exponents

Since we multiply PDFs, we add the exponents:

$$\log q(x_{t-1}|x_t, x_0) = -\frac{(x_t - \sqrt{1-\beta_t}\,x_{t-1})^2}{2\beta_t} - \frac{(x_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\,x_0)^2}{2(1-\bar{\alpha}_{t-1})} + C$$

Step 3: Complete the Square

The key technique is to identify this as a quadratic in $x_{t-1}$. After expanding and collecting terms, we get a form:

$$\propto \exp\left(-\frac{(x_{t-1} - \tilde{\mu}_t)^2}{2\tilde{\beta}_t}\right)$$

This is the PDF of a Gaussian with mean $\tilde{\mu}_t$ and variance $\tilde{\beta}_t$.
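For completeness, collecting the coefficient of $x_{t-1}^2$ in the combined exponent gives the posterior precision directly (using $\alpha_t = 1-\beta_t$ and $\alpha_t\bar{\alpha}_{t-1} = \bar{\alpha}_t$):

$$\frac{1}{\tilde{\beta}_t} = \frac{1-\beta_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}} = \frac{\alpha_t(1-\bar{\alpha}_{t-1}) + \beta_t}{\beta_t(1-\bar{\alpha}_{t-1})} = \frac{1-\bar{\alpha}_t}{\beta_t(1-\bar{\alpha}_{t-1})}$$

since $\alpha_t(1-\bar{\alpha}_{t-1}) + \beta_t = \alpha_t - \bar{\alpha}_t + 1 - \alpha_t = 1-\bar{\alpha}_t$. Inverting recovers the posterior variance $\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$.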

The Product of Two Gaussians: When you multiply two Gaussian PDFs (as functions of the same variable), the result is proportional to another Gaussian. This is the magic that makes the posterior tractable.

The Posterior Mean

After completing the square, the posterior mean is:

$$\tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,\mathbf{x}_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\,\mathbf{x}_t$$

This is a weighted average of $\mathbf{x}_0$ (the clean data) and $\mathbf{x}_t$ (the noisy observation).

Breaking Down the Coefficients

| Coefficient | Expression | Interpretation |
| --- | --- | --- |
| Weight of $\mathbf{x}_0$ | $\sqrt{\bar{\alpha}_{t-1}}\,\beta_t \,/\, (1 - \bar{\alpha}_t)$ | Increases as $t$ decreases (cleaner signals get more weight) |
| Weight of $\mathbf{x}_t$ | $\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1}) \,/\, (1 - \bar{\alpha}_t)$ | Decreases as $t$ decreases |

Alternative Form

Using $\alpha_t = 1 - \beta_t$, the posterior mean can also be written as:

$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}\right)$$

where $\boldsymbol{\epsilon}$ is the noise that was added. This form motivates predicting the noise!
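The equivalence of the two mean forms is easy to confirm numerically. Below is a minimal NumPy sketch (toy schedule, arbitrary 1-D values) that constructs $x_t$ from $x_0$ and $\epsilon$ via the forward marginal, then evaluates both expressions:

```python
import numpy as np

# Check that the weighted-average form and the epsilon form of the
# posterior mean agree. Schedule and values are arbitrary demo choices.
rng = np.random.default_rng(1)
T = 100
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = np.cumprod(alphas)

t = 40
x0 = rng.standard_normal()
eps = rng.standard_normal()
# Forward marginal: x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps
xt = np.sqrt(alphas_bar[t]) * x0 + np.sqrt(1.0 - alphas_bar[t]) * eps

# Form 1: weighted average of x_0 and x_t
coef1 = np.sqrt(alphas_bar[t - 1]) * betas[t] / (1.0 - alphas_bar[t])
coef2 = np.sqrt(alphas[t]) * (1.0 - alphas_bar[t - 1]) / (1.0 - alphas_bar[t])
mu_weighted = coef1 * x0 + coef2 * xt

# Form 2: expressed through the added noise epsilon
mu_eps = (xt - betas[t] / np.sqrt(1.0 - alphas_bar[t]) * eps) / np.sqrt(alphas[t])

assert np.isclose(mu_weighted, mu_eps)
print("both forms of the posterior mean agree")
```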

The Posterior Variance

The posterior variance is:

$$\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t$$

Properties of the Posterior Variance

  • Always less than $\beta_t$: The conditioning on $\mathbf{x}_0$ reduces uncertainty
  • Depends only on the schedule: It doesn't depend on the actual values of $\mathbf{x}_t$ or $\mathbf{x}_0$
  • Approaches $\beta_t$ for large $t$: When $\bar{\alpha}_t \approx \bar{\alpha}_{t-1} \approx 0$
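The completed square can also be checked numerically. The NumPy sketch below (toy schedule, arbitrarily chosen 1-D values) multiplies the two Gaussian factors by adding their precisions and compares the result against the closed-form mean and variance derived above:

```python
import numpy as np

# Toy 1-D check: product of the two Gaussian factors vs. the closed-form
# posterior. Schedule and inputs are arbitrary demo values.
T = 10
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = np.cumprod(alphas)

t = 5                 # array index standing in for timestep t
x0, xt = 0.7, -0.3    # arbitrary clean / noisy values

# Factor 1: q(x_t | x_{t-1}), viewed as a Gaussian *in x_{t-1}*:
# mean x_t / sqrt(alpha_t), variance beta_t / alpha_t
m1, v1 = xt / np.sqrt(alphas[t]), betas[t] / alphas[t]
# Factor 2: q(x_{t-1} | x_0)
m2, v2 = np.sqrt(alphas_bar[t - 1]) * x0, 1.0 - alphas_bar[t - 1]

# Product of Gaussians: precisions add, means combine precision-weighted
var_prod = 1.0 / (1.0 / v1 + 1.0 / v2)
mean_prod = var_prod * (m1 / v1 + m2 / v2)

# Closed-form posterior mean and variance from the text
coef1 = np.sqrt(alphas_bar[t - 1]) * betas[t] / (1.0 - alphas_bar[t])
coef2 = np.sqrt(alphas[t]) * (1.0 - alphas_bar[t - 1]) / (1.0 - alphas_bar[t])
mean_closed = coef1 * x0 + coef2 * xt
var_closed = (1.0 - alphas_bar[t - 1]) / (1.0 - alphas_bar[t]) * betas[t]

assert np.isclose(mean_prod, mean_closed)
assert np.isclose(var_prod, var_closed)
print("product-of-Gaussians check passed")
```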

Special Cases

| Timestep | Posterior Variance | Interpretation |
| --- | --- | --- |
| $t = 1$ | $\tilde{\beta}_1 = 0$ | No variance - $\mathbf{x}_0$ is determined exactly |
| $t$ near $0$ | Small $\tilde{\beta}_t$ | High confidence in prediction |
| $t$ near $T$ | $\tilde{\beta}_t \approx \beta_t$ | Similar to forward variance |
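These special cases can be verified directly from the schedule. A short NumPy sketch (assuming the same linear schedule used in the implementation later in this section):

```python
import numpy as np

# Verify the special cases of beta_tilde_t for a linear schedule.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas_bar = np.cumprod(1.0 - betas)
alphas_bar_prev = np.concatenate([[1.0], alphas_bar[:-1]])

beta_tilde = (1.0 - alphas_bar_prev) / (1.0 - alphas_bar) * betas

assert beta_tilde[0] == 0.0               # t = 1: x_0 is determined exactly
assert np.all(beta_tilde <= betas)        # never exceeds the forward variance
assert beta_tilde[-1] / betas[-1] > 0.99  # approaches beta_t for large t
```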

Building Intuition

The posterior distribution tells us: given where we started ($\mathbf{x}_0$) and where we are now ($\mathbf{x}_t$), what is the distribution of the previous state ($\mathbf{x}_{t-1}$)?

Why Is This Useful for Training?

During training, we do have access to $\mathbf{x}_0$ - it's the training data! This means we can:

  1. Compute the exact posterior $q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)$
  2. Train our network to match this posterior
  3. At generation time, use the learned network, since we don't have $\mathbf{x}_0$

The Training-Generation Gap: During training, we know $\mathbf{x}_0$ and can compute the exact target. At generation time, we don't know $\mathbf{x}_0$, so we use the learned model to approximate the posterior mean.

PyTorch Implementation

Let's implement the posterior calculations in PyTorch:

```python
import torch


class DiffusionPosterior:
    """Computes the tractable posterior q(x_{t-1} | x_t, x_0)."""

    def __init__(self, betas: torch.Tensor):
        """
        Args:
            betas: Noise schedule (T,) tensor
        """
        self.betas = betas
        self.alphas = 1.0 - betas

        # Cumulative products: alpha_bar_t and alpha_bar_{t-1}
        self.alphas_bar = torch.cumprod(self.alphas, dim=0)
        self.alphas_bar_prev = torch.cat([
            torch.tensor([1.0], device=betas.device),
            self.alphas_bar[:-1]
        ])

        # Precompute the posterior variance:
        # beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        self.posterior_variance = (
            (1.0 - self.alphas_bar_prev) /
            (1.0 - self.alphas_bar) *
            self.betas
        )
        # The formula yields exactly 0 at the first index (alpha_bar_prev = 1).
        # Substitute beta_1 so that code taking logs stays finite; sampling
        # never uses this value because no noise is added at t = 0.
        self.posterior_variance[0] = self.betas[0]

        # Coefficients for the posterior mean:
        # posterior_mean = coef1 * x_0 + coef2 * x_t
        self.posterior_mean_coef1 = (
            torch.sqrt(self.alphas_bar_prev) * self.betas /
            (1.0 - self.alphas_bar)
        )
        self.posterior_mean_coef2 = (
            torch.sqrt(self.alphas) * (1.0 - self.alphas_bar_prev) /
            (1.0 - self.alphas_bar)
        )

    def get_posterior_mean(
        self,
        x_0: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the posterior mean mu_tilde_t(x_t, x_0).

        Args:
            x_0: Clean data (B, C, H, W)
            x_t: Noisy data at timestep t (B, C, H, W)
            t: Timestep indices (B,)

        Returns:
            Posterior mean (B, C, H, W)
        """
        # Extract coefficients for the batch
        coef1 = self.posterior_mean_coef1[t]  # (B,)
        coef2 = self.posterior_mean_coef2[t]  # (B,)

        # Reshape for broadcasting: (B,) -> (B, 1, 1, 1)
        while coef1.dim() < x_0.dim():
            coef1 = coef1.unsqueeze(-1)
            coef2 = coef2.unsqueeze(-1)

        return coef1 * x_0 + coef2 * x_t

    def get_posterior_variance(self, t: torch.Tensor) -> torch.Tensor:
        """
        Get the posterior variance beta_tilde_t.

        Args:
            t: Timestep indices (B,)

        Returns:
            Posterior variance (B,)
        """
        return self.posterior_variance[t]

    def sample_posterior(
        self,
        x_0: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """
        Sample from q(x_{t-1} | x_t, x_0).

        Args:
            x_0: Clean data (B, C, H, W)
            x_t: Noisy data at timestep t (B, C, H, W)
            t: Timestep indices (B,)

        Returns:
            Sample x_{t-1} (B, C, H, W)
        """
        mean = self.get_posterior_mean(x_0, x_t, t)
        variance = self.get_posterior_variance(t)

        # Reshape variance for broadcasting
        while variance.dim() < x_0.dim():
            variance = variance.unsqueeze(-1)

        # Sample: x_{t-1} = mean + sqrt(variance) * noise
        noise = torch.randn_like(x_0)

        # Don't add noise at t = 0
        nonzero_mask = (t > 0).float()
        while nonzero_mask.dim() < x_0.dim():
            nonzero_mask = nonzero_mask.unsqueeze(-1)

        return mean + nonzero_mask * torch.sqrt(variance) * noise


# Example usage
if __name__ == "__main__":
    T = 1000
    # Linear schedule
    betas = torch.linspace(0.0001, 0.02, T)

    posterior = DiffusionPosterior(betas)

    # Simulate some data
    x_0 = torch.randn(4, 3, 32, 32)  # Clean images
    t = torch.randint(1, T, (4,))    # Random timesteps

    # Get noisy versions via the forward marginal q(x_t | x_0)
    alphas_bar = posterior.alphas_bar[t]
    while alphas_bar.dim() < x_0.dim():
        alphas_bar = alphas_bar.unsqueeze(-1)

    epsilon = torch.randn_like(x_0)
    x_t = torch.sqrt(alphas_bar) * x_0 + torch.sqrt(1 - alphas_bar) * epsilon

    # Compute the posterior
    mean = posterior.get_posterior_mean(x_0, x_t, t)
    variance = posterior.get_posterior_variance(t)
    x_t_minus_1 = posterior.sample_posterior(x_0, x_t, t)

    print(f"x_0 shape: {x_0.shape}")
    print(f"Posterior mean shape: {mean.shape}")
    print(f"Posterior variance: {variance}")
    print(f"Sampled x_(t-1) shape: {x_t_minus_1.shape}")
```

Key Takeaways

  1. Conditioning makes it tractable: The posterior $q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)$ is Gaussian because all involved distributions are Gaussian
  2. Posterior mean formula: $\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\mathbf{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t$
  3. Posterior variance formula: $\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$
  4. Weighted average interpretation: The posterior mean interpolates between $\mathbf{x}_0$ and $\mathbf{x}_t$
  5. Training target: This posterior serves as the target our network learns to approximate

Looking Ahead: In the next section, we'll see how to parameterize the learned reverse process $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ and the different choices for what the network should predict.