Chapter 2
Section 14 of 76

Implementing the Forward Process

The Forward Diffusion Process

Learning Objectives

By the end of this section, you will be able to:

  1. Build a production-ready forward diffusion class with proper PyTorch integration
  2. Use register_buffer so that pre-computed schedule tensors move to the GPU together with the module
  3. Implement multiple noise schedules in a unified framework
  4. Use the forward process to generate noisy training inputs

Complete Implementation

Here is a production-ready implementation that incorporates all the concepts from this chapter: noise schedules, closed-form sampling, and pre-computed coefficients for the reverse process.

Complete Forward Diffusion Implementation (gaussian_diffusion.py)
  • Dataclass Config: A dataclass keeps all hyperparameters in one place with type annotations. Note that the annotations are checked by static tools such as mypy, not at runtime.
  • GaussianDiffusion Class: Inherits from nn.Module, so it can be moved to the GPU with .to(device) and integrates with the rest of the PyTorch ecosystem.
  • Register Buffers: Buffers are tensors that are saved and loaded with the model (they appear in its state_dict) but are not trainable parameters. They move with the module when you call .to(device).
  • Reciprocal Terms: Pre-compute sqrt(1/alpha_bar) and sqrt(1/alpha_bar - 1) so that x_0 can be recovered cheaply from x_t and the predicted noise.
  • Posterior Quantities: Pre-compute the posterior mean coefficients and variance used by the reverse process during sampling.
  • Schedule Computation: A factory method that computes the betas for the selected schedule type.
  • Linear Schedule: Uses torch.linspace to get evenly spaced beta values.
  • Cosine Schedule: Computes alpha_bar directly from a squared cosine, then derives each beta as 1 - alpha_bar_t / alpha_bar_{t-1}. Uses float64 for numerical precision.
  • Quadratic Schedule: Spaces sqrt(beta) evenly with torch.linspace, then squares the result.
  • q_sample Method: The main forward sampling method. Takes clean data and timesteps and returns the noised data together with the noise that was added.
  • Extract Method: Helper that gathers values from a 1D tensor at the batch's timestep indices and reshapes them for broadcasting against image tensors.
  • SNR Method: Computes the signal-to-noise ratio alpha_bar / (1 - alpha_bar) at given timesteps. Useful for analysis and advanced loss weighting.
  • Predict x_0: Given x_t and the predicted noise, recovers x_0 by solving the reparameterization formula for x_0.
  • Posterior Distribution: Computes the mean and variance of the true posterior q(x_{t-1}|x_t, x_0), which the reverse process learns to approximate.
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple
from dataclasses import dataclass

@dataclass
class DiffusionConfig:
    """Configuration for the diffusion process."""
    num_timesteps: int = 1000
    schedule: str = "cosine"  # "linear", "cosine", "quadratic"
    beta_start: float = 0.0001
    beta_end: float = 0.02


class GaussianDiffusion(nn.Module):
    """
    Complete forward diffusion process implementation.

    This class handles:
    - Noise schedule computation (linear, cosine, quadratic)
    - Efficient closed-form sampling at any timestep
    - Pre-computation of all derived quantities
    - GPU-compatible tensor operations
    """

    def __init__(self, config: DiffusionConfig):
        super().__init__()
        self.config = config
        self.T = config.num_timesteps

        # Compute the noise schedule
        betas = self._compute_schedule(config)

        # Register as buffers: not trainable parameters, but they are saved in the
        # state_dict and move with the module on .to(device)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", 1.0 - betas)
        self.register_buffer("alpha_bars", torch.cumprod(self.alphas, dim=0))

        # Pre-compute useful quantities
        self.register_buffer("sqrt_alpha_bars", torch.sqrt(self.alpha_bars))
        self.register_buffer("sqrt_one_minus_alpha_bars", torch.sqrt(1.0 - self.alpha_bars))
        self.register_buffer("sqrt_recip_alpha_bars", torch.sqrt(1.0 / self.alpha_bars))
        self.register_buffer("sqrt_recip_alpha_bars_minus_one",
                           torch.sqrt(1.0 / self.alpha_bars - 1.0))

        # For posterior computation (used in reverse process); alpha_bar_{-1} := 1
        alpha_bars_prev = torch.cat([torch.ones(1, dtype=betas.dtype), self.alpha_bars[:-1]])
        self.register_buffer("alpha_bars_prev", alpha_bars_prev)

        # Posterior variance: (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        posterior_variance = betas * (1.0 - alpha_bars_prev) / (1.0 - self.alpha_bars)
        self.register_buffer("posterior_variance", posterior_variance)
        # The variance is 0 at the first step, so clamp before taking the log
        self.register_buffer("posterior_log_variance_clipped",
                           torch.log(torch.clamp(posterior_variance, min=1e-20)))

        # Posterior mean coefficients (multiply x_0 and x_t respectively)
        self.register_buffer("posterior_mean_coef1",
            betas * torch.sqrt(alpha_bars_prev) / (1.0 - self.alpha_bars))
        self.register_buffer("posterior_mean_coef2",
            (1.0 - alpha_bars_prev) * torch.sqrt(self.alphas) / (1.0 - self.alpha_bars))

    def _compute_schedule(self, config: DiffusionConfig) -> torch.Tensor:
        """Compute the beta schedule based on configuration."""
        T = config.num_timesteps

        if config.schedule == "linear":
            return torch.linspace(config.beta_start, config.beta_end, T)

        elif config.schedule == "cosine":
            s = 0.008
            steps = torch.arange(T + 1, dtype=torch.float64) / T
            alpha_bars = torch.cos((steps + s) / (1 + s) * math.pi / 2) ** 2
            alpha_bars = alpha_bars / alpha_bars[0]
            betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
            return torch.clamp(betas, min=0.0001, max=0.999).float()

        elif config.schedule == "quadratic":
            return torch.linspace(config.beta_start**0.5, config.beta_end**0.5, T) ** 2

        else:
            raise ValueError(f"Unknown schedule: {config.schedule}")

    def q_sample(
        self,
        x_0: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample from q(x_t | x_0) using the closed-form formula.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x_0: Clean data [B, C, H, W]
            t: Timesteps [B] (int64, same device as the module)
            noise: Optional pre-sampled noise [B, C, H, W]

        Returns:
            x_t: Noised data
            noise: The noise that was added
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alpha_bar = self._extract(self.sqrt_alpha_bars, t, x_0.shape)
        sqrt_one_minus_alpha_bar = self._extract(self.sqrt_one_minus_alpha_bars, t, x_0.shape)

        x_t = sqrt_alpha_bar * x_0 + sqrt_one_minus_alpha_bar * noise
        return x_t, noise

    def _extract(self, a: torch.Tensor, t: torch.Tensor, x_shape: tuple) -> torch.Tensor:
        """Extract values from 'a' at indices 't' and reshape to [B, 1, ..., 1] for broadcasting."""
        batch_size = t.shape[0]
        out = a.gather(-1, t)  # t must be int64 and on the same device as 'a'
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))

    def get_snr(self, t: torch.Tensor) -> torch.Tensor:
        """Get signal-to-noise ratio alpha_bar / (1 - alpha_bar) at timestep t."""
        alpha_bar = self._extract(self.alpha_bars, t, t.shape)
        return alpha_bar / (1.0 - alpha_bar)

    def predict_x0_from_noise(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Recover x_0 from x_t and predicted noise."""
        sqrt_recip_alpha_bar = self._extract(self.sqrt_recip_alpha_bars, t, x_t.shape)
        sqrt_recip_alpha_bar_minus_one = self._extract(self.sqrt_recip_alpha_bars_minus_one, t, x_t.shape)
        return sqrt_recip_alpha_bar * x_t - sqrt_recip_alpha_bar_minus_one * noise

    def q_posterior_mean_variance(
        self,
        x_0: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the mean and variance of q(x_{t-1} | x_t, x_0).

        This is the true posterior; the learned reverse process is trained to match it.
        """
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_0 +
            self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance


# Example instantiation
config = DiffusionConfig(num_timesteps=1000, schedule="cosine")
diffusion = GaussianDiffusion(config)

# Test on a batch
x_0 = torch.randn(4, 3, 64, 64)
t = torch.randint(0, config.num_timesteps, (4,))
x_t, noise = diffusion.q_sample(x_0, t)

print(f"Config: T={config.num_timesteps}, schedule={config.schedule}")
print(f"x_0 shape: {x_0.shape}, x_t shape: {x_t.shape}")
print(f"At t={t[0].item()}: alpha_bar={diffusion.alpha_bars[t[0]].item():.4f}")
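The three schedules spend their noise budget very differently over the same number of steps. The standalone sketch below copies the schedule formulas from _compute_schedule into a helper, make_betas (our own name, not part of the listing), assumes the DiffusionConfig defaults (T = 1000, beta_start = 0.0001, beta_end = 0.02), and prints alpha_bar and the SNR at a few timesteps.

```python
import math
import torch


def make_betas(schedule: str, T: int = 1000,
               beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """Standalone copy of the schedule logic in _compute_schedule (illustrative helper)."""
    if schedule == "linear":
        return torch.linspace(beta_start, beta_end, T)
    if schedule == "cosine":
        s = 0.008
        steps = torch.arange(T + 1, dtype=torch.float64) / T
        f = torch.cos((steps + s) / (1 + s) * math.pi / 2) ** 2
        f = f / f[0]
        betas = 1 - f[1:] / f[:-1]  # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
        return torch.clamp(betas, min=1e-4, max=0.999).float()
    if schedule == "quadratic":
        return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, T) ** 2
    raise ValueError(f"Unknown schedule: {schedule}")


# alpha_bar_t = prod_{s <= t} (1 - beta_s), same indexing as GaussianDiffusion.alpha_bars
alpha_bars = {name: torch.cumprod(1.0 - make_betas(name), dim=0)
              for name in ("linear", "cosine", "quadratic")}

for name, ab in alpha_bars.items():
    snr = ab / (1.0 - ab)  # same formula as get_snr
    row = ", ".join(f"t={t}: {ab[t].item():.4f} (SNR {snr[t].item():.3g})"
                    for t in (0, 250, 500, 750, 999))
    print(f"{name:>9}  {row}")
```

With these defaults the cosine schedule retains noticeably more signal at the midpoint than the linear one, which is the usual motivation for choosing it.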

Usage Examples

Basic Sampling

The most common operation is sampling noisy versions of clean data for training:

  • Sample timesteps uniformly: $t \sim \text{Uniform}\{0, 1, \ldots, T-1\}$
  • Add noise: use q_sample(x_0, t) to get $\mathbf{x}_t$ and the noise $\boldsymbol{\epsilon}$
  • Train: the network predicts $\boldsymbol{\epsilon}$ from $\mathbf{x}_t$ and $t$
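The first two steps can be sketched without the full class. This minimal version assumes a linear schedule, works in float64 so the round trip is exact to many digits, and defines its own helpers extract and q_sample, which mirror _extract and q_sample in the listing. It also checks that solving the closed form for x_0, as predict_x0_from_noise does, recovers the clean data when the true noise is known.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)


def extract(a: torch.Tensor, t: torch.Tensor, x_shape) -> torch.Tensor:
    """Pick a[t] for each batch element and reshape to [B, 1, 1, ...] for broadcasting."""
    return a.gather(-1, t).reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))


def q_sample(x_0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Closed-form sample: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
    ab = extract(alpha_bars, t, x_0.shape)
    return ab.sqrt() * x_0 + (1.0 - ab).sqrt() * noise


torch.manual_seed(0)
x_0 = torch.randn(8, 3, 16, 16, dtype=torch.float64)  # stand-in for a batch of images
t = torch.randint(0, T, (x_0.shape[0],))              # t ~ Uniform{0, ..., T-1}, int64
noise = torch.randn_like(x_0)                         # epsilon ~ N(0, I)
x_t = q_sample(x_0, t, noise)

# Invert the closed form: x_0 = sqrt(1/alpha_bar) * x_t - sqrt(1/alpha_bar - 1) * noise
ab = extract(alpha_bars, t, x_t.shape)
x_0_rec = (1.0 / ab).sqrt() * x_t - (1.0 / ab - 1.0).sqrt() * noise

print("t:", t.tolist())
print("max reconstruction error:", (x_0_rec - x_0).abs().max().item())
```

During training the model only ever sees an estimate of the noise, so this exact recovery holds only for the true $\boldsymbol{\epsilon}$.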

GPU Compatibility

Because GaussianDiffusion inherits from nn.Module and stores its schedule with register_buffer, moving it to the GPU takes a single .to(device) call.

Device Management

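All registered buffers move automatically when you call diffusion.to(device). The q_sample method works correctly as long as x_0, t, and any supplied noise are on the same device as the module; _extract indexes the buffers with t, so passing a CPU t to a module on CUDA raises a device-mismatch error.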
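A minimal sketch of what register_buffer provides, using a hypothetical ScheduleHolder module that stores a schedule the same way GaussianDiffusion does. It uses CUDA when available and otherwise stays on the CPU.

```python
import torch
import torch.nn as nn


class ScheduleHolder(nn.Module):
    """Hypothetical minimal module that stores a noise schedule as buffers."""

    def __init__(self, T: int = 1000):
        super().__init__()
        betas = torch.linspace(1e-4, 0.02, T)
        self.register_buffer("betas", betas)
        self.register_buffer("alpha_bars", torch.cumprod(1.0 - betas, dim=0))


holder = ScheduleHolder()

# Buffers are not parameters, so an optimizer never sees or updates them...
print("parameters:", len(list(holder.parameters())))         # 0
# ...but they are in the state_dict, so they are saved and loaded with the model.
print("state_dict keys:", list(holder.state_dict().keys()))  # ['betas', 'alpha_bars']

# .to(device) moves buffers together with the module.
device = "cuda" if torch.cuda.is_available() else "cpu"
holder = holder.to(device)
print("alpha_bars device:", holder.alpha_bars.device)

# .to(dtype) also casts floating-point buffers.
holder64 = ScheduleHolder().to(torch.float64)
print("dtype after .to(float64):", holder64.alpha_bars.dtype)

# Index tensors must live on the same device as the buffers they index.
t = torch.randint(0, 1000, (4,), device=holder.alpha_bars.device)
print("alpha_bar at t:", holder.alpha_bars[t].tolist())
```

A plain tensor attribute (self.alpha_bars = ...) would do neither: it would stay on the CPU after .to("cuda") and would be missing from the state_dict.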

Integration with Training

Here is how the forward process integrates into a training loop:

Training Step Pseudocode:
  1. Load a batch of clean images $\mathbf{x}_0$
  2. Sample random timesteps $t \sim \text{Uniform}\{0, 1, \ldots, T-1\}$
  3. Sample noise $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
  4. Compute $\mathbf{x}_t$ using diffusion.q_sample(x_0, t, noise)
  5. Predict the noise: $\hat{\boldsymbol{\epsilon}} = \text{model}(\mathbf{x}_t, t)$
  6. Compute the loss: $\mathcal{L} = \|\boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}\|^2$
  7. Backpropagate and update model weights
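The pseudocode maps onto a single optimizer step. The sketch below is an illustration under stated assumptions: the data are 2-D points rather than images, the schedule is linear, and TinyEpsModel is a hypothetical toy MLP (not a model from this course) that receives t as one scaled scalar feature. It uses F.mse_loss, the mean of the squared error, which differs from the squared norm in step 6 only by a constant factor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)  # fixed forward process, never trained


class TinyEpsModel(nn.Module):
    """Hypothetical toy noise predictor for 2-D points."""

    def __init__(self, dim: int = 2, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, hidden), nn.SiLU(), nn.Linear(hidden, dim))

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        t_feat = (t.float() / T).unsqueeze(-1)  # crude timestep conditioning, shape [B, 1]
        return self.net(torch.cat([x_t, t_feat], dim=-1))


def training_step(model: nn.Module, optimizer: torch.optim.Optimizer, x_0: torch.Tensor) -> float:
    B = x_0.shape[0]
    t = torch.randint(0, T, (B,), device=x_0.device)        # step 2
    noise = torch.randn_like(x_0)                           # step 3
    ab = alpha_bars.to(x_0.device)[t].unsqueeze(-1)         # [B, 1], broadcasts over features
    x_t = ab.sqrt() * x_0 + (1.0 - ab).sqrt() * noise       # step 4 (closed-form q_sample)
    eps_hat = model(x_t, t)                                 # step 5
    loss = F.mse_loss(eps_hat, noise)                       # step 6 (mean over elements)
    optimizer.zero_grad()
    loss.backward()                                         # step 7
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = TinyEpsModel()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x_0 = torch.randn(128, 2)  # step 1: a batch of "clean" data
loss = training_step(model, opt, x_0)
print(f"loss: {loss:.4f}")
```

In an image model, step 4 would call diffusion.q_sample(x_0, t, noise) instead of the inline formula, and the network would typically be a U-Net with a learned timestep embedding.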

Note that the forward process is not trained: it is a fixed mathematical operation with no learnable parameters. Only the neural network that predicts the noise is trained.


Key Takeaways

  1. Use nn.Module: Enables GPU compatibility and PyTorch ecosystem integration
  2. Register buffers: For tensors that are not parameters but should move with the model
  3. Pre-compute everything: All derived quantities should be computed once at initialization
  4. Unified interface: Support multiple schedules through configuration
  5. The extract pattern: Use gather + reshape for batch-indexed coefficient lookup

Chapter Complete! You now understand the forward diffusion process from theory to implementation. In Chapter 3, we'll tackle the reverse process: how to generate new samples by learning to denoise.