Learning Objectives
By the end of this section, you will be able to:
- Derive the tractable posterior q(x_{t-1} | x_t, x_0) using Bayes' theorem
- Compute the posterior mean mu_tilde_t and variance beta-tilde_t
- Understand why conditioning on x_0 transforms an intractable problem into a tractable one
- Implement the posterior formulas in PyTorch
The Bayesian Setup
In Section 3.1, we saw that the true reverse distribution q(x_{t-1} | x_t) is intractable because it requires marginalizing over the unknown data distribution. However, if we condition on knowing the original data x_0, the situation changes dramatically.
The Key Insight: Given x_0, both the forward marginal q(x_{t-1} | x_0) and the transition q(x_t | x_{t-1}) are Gaussian. Bayes' theorem applied to Gaussians yields another Gaussian!
We want to compute the posterior q(x_{t-1} | x_t, x_0).
What We Know
From Chapter 2, we have closed-form expressions for all terms:
| Distribution | Formula | Mean | Variance |
|---|---|---|---|
| q(x_t \| x_{t-1}) | N(sqrt(1-beta_t) x_{t-1}, beta_t I) | sqrt(1-beta_t) x_{t-1} | beta_t |
| q(x_{t-1} \| x_0) | N(sqrt(alpha-bar_{t-1}) x_0, (1-alpha-bar_{t-1}) I) | sqrt(alpha-bar_{t-1}) x_0 | 1 - alpha-bar_{t-1} |
| q(x_t \| x_0) | N(sqrt(alpha-bar_t) x_0, (1-alpha-bar_t) I) | sqrt(alpha-bar_t) x_0 | 1 - alpha-bar_t |
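As a quick consistency check, these three formulas fit together: pushing q(x_{t-1} | x_0) through one forward step must reproduce the marginal q(x_t | x_0). A minimal numeric sketch, assuming the linear schedule used later in this section:

```python
import torch

# Sanity check: composing the one-step transition variance with the
# q(x_{t-1} | x_0) variance must give the q(x_t | x_0) variance:
# (1 - beta_t)(1 - alpha_bar_{t-1}) + beta_t == 1 - alpha_bar_t
T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_bar[:-1]])

composed = alphas * (1.0 - alphas_bar_prev) + betas
assert torch.allclose(composed, 1.0 - alphas_bar)
```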
Markov Property
Because the forward process is Markov, conditioning on x_0 does not change the one-step transition: q(x_t | x_{t-1}, x_0) = q(x_t | x_{t-1}). This lets us reuse the known forward formulas in the Bayes computation below.
Deriving the Posterior
Let's work through the derivation step by step. For clarity, we'll work with a single dimension; the multivariate case follows identically due to the diagonal covariance structure.
Step 1: Write Out the Gaussians
Using Bayes' theorem (and the Markov property to drop x_0 from the transition):
q(x_{t-1} | x_t, x_0) = q(x_t | x_{t-1}) q(x_{t-1} | x_0) / q(x_t | x_0)
The Gaussian PDFs are, up to normalizing constants (writing alpha_t = 1 - beta_t):
q(x_t | x_{t-1}) ∝ exp(-(x_t - sqrt(alpha_t) x_{t-1})^2 / (2 beta_t))
q(x_{t-1} | x_0) ∝ exp(-(x_{t-1} - sqrt(alpha-bar_{t-1}) x_0)^2 / (2 (1 - alpha-bar_{t-1})))
q(x_t | x_0) ∝ exp(-(x_t - sqrt(alpha-bar_t) x_0)^2 / (2 (1 - alpha-bar_t)))
Step 2: Combine the Exponents
Since we multiply (and divide) PDFs, we add (and subtract) the exponents:
-1/2 [ (x_t - sqrt(alpha_t) x_{t-1})^2 / beta_t + (x_{t-1} - sqrt(alpha-bar_{t-1}) x_0)^2 / (1 - alpha-bar_{t-1}) - (x_t - sqrt(alpha-bar_t) x_0)^2 / (1 - alpha-bar_t) ]
Step 3: Complete the Square
The key technique is to identify this as a quadratic in x_{t-1}. After expanding and collecting terms, we get a form:
-1/2 [ (alpha_t / beta_t + 1 / (1 - alpha-bar_{t-1})) x_{t-1}^2 - 2 (sqrt(alpha_t) x_t / beta_t + sqrt(alpha-bar_{t-1}) x_0 / (1 - alpha-bar_{t-1})) x_{t-1} + C(x_t, x_0) ]
where C(x_t, x_0) collects the terms that do not involve x_{t-1}. This is the PDF of a Gaussian with mean mu_tilde_t(x_t, x_0) and variance beta-tilde_t.
The Product of Two Gaussians: When you multiply two Gaussian PDFs (as functions of the same variable), the result is proportional to another Gaussian. This is the magic that makes the posterior tractable.
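The product property is easy to verify numerically. The sketch below uses two arbitrary (hypothetical) 1-D Gaussians and checks that their normalized product matches the Gaussian whose precision is the sum of the two precisions:

```python
import math
import torch

x = torch.linspace(-10.0, 10.0, 2001, dtype=torch.float64)

def gauss(x, mu, var):
    """Gaussian PDF with mean mu and variance var."""
    return torch.exp(-(x - mu) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

# Two arbitrary Gaussians (illustrative values)
mu1, var1 = 1.0, 2.0
mu2, var2 = -0.5, 0.5
product = gauss(x, mu1, var1) * gauss(x, mu2, var2)

# Predicted parameters of the product (precisions add,
# means combine precision-weighted)
var_post = 1.0 / (1.0 / var1 + 1.0 / var2)
mu_post = var_post * (mu1 / var1 + mu2 / var2)

# Normalize the numerical product and compare to the predicted Gaussian
dx = (x[1] - x[0]).item()
product_norm = product / (product.sum().item() * dx)
assert torch.allclose(product_norm, gauss(x, mu_post, var_post), atol=1e-6)
```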
The Posterior Mean
After completing the square, the posterior mean is:
mu_tilde_t(x_t, x_0) = [sqrt(alpha-bar_{t-1}) beta_t / (1 - alpha-bar_t)] x_0 + [sqrt(alpha_t) (1 - alpha-bar_{t-1}) / (1 - alpha-bar_t)] x_t
This is a weighted average of x_0 (the clean data) and x_t (the noisy observation).
Breaking Down the Coefficients
| Coefficient | Expression | Interpretation |
|---|---|---|
| Weight of x_0 | sqrt(alpha-bar_{t-1}) beta_t / (1 - alpha-bar_t) | Increases as t decreases (cleaner signals get more weight) |
| Weight of x_t | sqrt(alpha_t)(1 - alpha-bar_{t-1}) / (1 - alpha-bar_t) | Decreases as t decreases |
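To see these trends concretely, here is an illustrative computation of both coefficients under an assumed linear schedule:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_bar[:-1]])

# Weight of x_0 and weight of x_t in the posterior mean
coef_x0 = torch.sqrt(alphas_bar_prev) * betas / (1.0 - alphas_bar)
coef_xt = torch.sqrt(alphas) * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)

# At t = 1 the posterior mean is exactly x_0 (weights 1 and 0);
# at large t the mean mostly tracks x_t.
print(coef_x0[0].item(), coef_xt[0].item())
print(coef_x0[-1].item(), coef_xt[-1].item())
```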
Alternative Form
Substituting x_0 = (x_t - sqrt(1 - alpha-bar_t) epsilon) / sqrt(alpha-bar_t) gives:
mu_tilde_t = (1 / sqrt(alpha_t)) (x_t - (beta_t / sqrt(1 - alpha-bar_t)) epsilon)
where epsilon is the noise that was added to produce x_t. This form motivates predicting the noise!
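The noise-based form of the mean can be checked against the weighted-average form numerically; a small sketch, assuming a linear schedule and an arbitrary interior timestep:

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_bar[:-1]])

t = 500  # an arbitrary interior timestep
x_0 = torch.randn(10, dtype=torch.float64)
epsilon = torch.randn(10, dtype=torch.float64)
x_t = torch.sqrt(alphas_bar[t]) * x_0 + torch.sqrt(1 - alphas_bar[t]) * epsilon

# Weighted-average form in (x_0, x_t)
mean_v1 = (torch.sqrt(alphas_bar_prev[t]) * betas[t] / (1 - alphas_bar[t]) * x_0
           + torch.sqrt(alphas[t]) * (1 - alphas_bar_prev[t]) / (1 - alphas_bar[t]) * x_t)

# Noise-based form in (x_t, epsilon)
mean_v2 = (x_t - betas[t] / torch.sqrt(1 - alphas_bar[t]) * epsilon) / torch.sqrt(alphas[t])

assert torch.allclose(mean_v1, mean_v2)
```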
The Posterior Variance
The posterior variance is:
beta-tilde_t = ((1 - alpha-bar_{t-1}) / (1 - alpha-bar_t)) beta_t
Properties of the Posterior Variance
- Always less than beta_t: conditioning on x_0 reduces uncertainty, since 1 - alpha-bar_{t-1} < 1 - alpha-bar_t
- Depends only on the schedule: it doesn't depend on the actual values of x_0 or x_t
- Approaches beta_t for large t: when alpha-bar_{t-1} ≈ alpha-bar_t ≈ 0, the prefactor approaches 1
Special Cases
| Timestep | Posterior Variance | Interpretation |
|---|---|---|
| t = 1 | beta-tilde_1 = 0 | No variance - x_0 is determined exactly |
| t near 0 | Small beta-tilde_t | High confidence in prediction |
| t near T | beta-tilde_t near beta_t | Similar to forward variance |
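These special cases can be confirmed directly; a short sketch, again assuming the linear schedule:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)
alphas_bar_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_bar[:-1]])

beta_tilde = (1.0 - alphas_bar_prev) / (1.0 - alphas_bar) * betas

# beta_tilde_1 is exactly 0 (alpha_bar_0 = 1 pins x_0);
# beta_tilde_T is nearly equal to beta_T.
print(beta_tilde[0].item())
print(beta_tilde[-1].item(), betas[-1].item())
```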
Building Intuition
The posterior distribution tells us: given where we started (x_0) and where we are now (x_t), what is the distribution of the previous state (x_{t-1})?
Why Is This Useful for Training?
During training, we do have access to x_0 - it's the training data! This means we can:
- Compute the exact posterior q(x_{t-1} | x_t, x_0)
- Train our network to match this posterior
- At generation time, use the learned network, since we no longer have x_0
The Training-Generation Gap: During training, we know x_0 and can compute the exact target. At generation time, we don't know x_0, so we use the learned model to approximate the posterior mean.
PyTorch Implementation
Let's implement the posterior calculations in PyTorch:
```python
import torch


class DiffusionPosterior:
    """Computes the tractable posterior q(x_{t-1} | x_t, x_0)."""

    def __init__(self, betas: torch.Tensor):
        """
        Args:
            betas: Noise schedule (T,) tensor
        """
        self.betas = betas
        self.alphas = 1.0 - betas

        # Cumulative products: alpha_bar_t, and alpha_bar_{t-1} with alpha_bar_0 = 1
        self.alphas_bar = torch.cumprod(self.alphas, dim=0)
        self.alphas_bar_prev = torch.cat([
            torch.ones(1, dtype=betas.dtype, device=betas.device),
            self.alphas_bar[:-1]
        ])

        # Precompute posterior variance
        # beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        self.posterior_variance = (
            (1.0 - self.alphas_bar_prev) /
            (1.0 - self.alphas_bar) *
            self.betas
        )
        # beta_tilde at the first step is exactly 0. Replace it with beta_1 so
        # any downstream log(variance) stays finite; the sampler below masks
        # out the noise at t = 0, so this entry never affects samples.
        self.posterior_variance[0] = self.betas[0]

        # Coefficients for the posterior mean:
        # posterior_mean = coef1 * x_0 + coef2 * x_t
        self.posterior_mean_coef1 = (
            torch.sqrt(self.alphas_bar_prev) * self.betas /
            (1.0 - self.alphas_bar)
        )
        self.posterior_mean_coef2 = (
            torch.sqrt(self.alphas) * (1.0 - self.alphas_bar_prev) /
            (1.0 - self.alphas_bar)
        )

    def get_posterior_mean(
        self,
        x_0: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute posterior mean mu_tilde_t(x_t, x_0).

        Args:
            x_0: Clean data (B, C, H, W)
            x_t: Noisy data at timestep t (B, C, H, W)
            t: Timestep indices (B,)

        Returns:
            Posterior mean (B, C, H, W)
        """
        # Extract coefficients for the batch
        coef1 = self.posterior_mean_coef1[t]  # (B,)
        coef2 = self.posterior_mean_coef2[t]  # (B,)

        # Reshape for broadcasting: (B,) -> (B, 1, 1, 1)
        while coef1.dim() < x_0.dim():
            coef1 = coef1.unsqueeze(-1)
            coef2 = coef2.unsqueeze(-1)

        return coef1 * x_0 + coef2 * x_t

    def get_posterior_variance(self, t: torch.Tensor) -> torch.Tensor:
        """
        Get posterior variance beta_tilde_t.

        Args:
            t: Timestep indices (B,)

        Returns:
            Posterior variance (B,)
        """
        return self.posterior_variance[t]

    def sample_posterior(
        self,
        x_0: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """
        Sample from q(x_{t-1} | x_t, x_0).

        Args:
            x_0: Clean data (B, C, H, W)
            x_t: Noisy data at timestep t (B, C, H, W)
            t: Timestep indices (B,)

        Returns:
            Sample x_{t-1} (B, C, H, W)
        """
        mean = self.get_posterior_mean(x_0, x_t, t)
        variance = self.get_posterior_variance(t)

        # Reshape variance for broadcasting
        while variance.dim() < x_0.dim():
            variance = variance.unsqueeze(-1)

        # Sample: x_{t-1} = mean + sqrt(variance) * noise
        noise = torch.randn_like(x_0)

        # Don't add noise at t = 0
        nonzero_mask = (t > 0).float()
        while nonzero_mask.dim() < x_0.dim():
            nonzero_mask = nonzero_mask.unsqueeze(-1)

        return mean + nonzero_mask * torch.sqrt(variance) * noise


# Example usage
if __name__ == "__main__":
    T = 1000
    # Linear schedule
    betas = torch.linspace(0.0001, 0.02, T)

    posterior = DiffusionPosterior(betas)

    # Simulate some data
    x_0 = torch.randn(4, 3, 32, 32)  # Clean images
    t = torch.randint(1, T, (4,))    # Random timesteps

    # Get a noisy version via the closed-form forward marginal
    alphas_bar = posterior.alphas_bar[t]
    while alphas_bar.dim() < x_0.dim():
        alphas_bar = alphas_bar.unsqueeze(-1)

    epsilon = torch.randn_like(x_0)
    x_t = torch.sqrt(alphas_bar) * x_0 + torch.sqrt(1 - alphas_bar) * epsilon

    # Compute the posterior
    mean = posterior.get_posterior_mean(x_0, x_t, t)
    variance = posterior.get_posterior_variance(t)
    x_t_minus_1 = posterior.sample_posterior(x_0, x_t, t)

    print(f"x_0 shape: {x_0.shape}")
    print(f"Posterior mean shape: {mean.shape}")
    print(f"Posterior variance: {variance}")
    print(f"Sampled x_(t-1) shape: {x_t_minus_1.shape}")
```

Key Takeaways
- Conditioning makes it tractable: the posterior q(x_{t-1} | x_t, x_0) is Gaussian because all the distributions involved are Gaussian
- Posterior mean formula: mu_tilde_t = [sqrt(alpha-bar_{t-1}) beta_t / (1 - alpha-bar_t)] x_0 + [sqrt(alpha_t) (1 - alpha-bar_{t-1}) / (1 - alpha-bar_t)] x_t
- Posterior variance formula: beta-tilde_t = ((1 - alpha-bar_{t-1}) / (1 - alpha-bar_t)) beta_t
- Weighted average interpretation: the posterior mean interpolates between x_0 and x_t
- Training target: this posterior serves as the target our network learns to approximate
Looking Ahead: In the next section, we'll see how to parameterize the learned reverse process and the different choices for what the network should predict.