Chapter 6

The Complete DDPM Class

Building the Diffusion Model

Learning Objectives

By the end of this section, you will:

  1. Understand the complete DDPM algorithm and how forward/reverse processes work together
  2. Implement the noise schedule with precomputed alpha and beta values
  3. Build the forward process that adds noise to training data
  4. Implement the reverse process that removes noise step by step
  5. Assemble a complete DDPM class ready for training and sampling

From Architecture to Algorithm

In Chapter 5, we built the U-Net that predicts noise. Now we build the diffusion model itself: the algorithms that use the U-Net to transform data into noise (forward) and noise into data (reverse). The DDPM class we build here is the complete system for training and generating images.

DDPM: The Big Picture

Denoising Diffusion Probabilistic Models (DDPM) work by:

  1. Forward process (training): Gradually add Gaussian noise to real images until they become pure noise
  2. Reverse process (generation): Learn to reverse this process, gradually removing noise to create images from scratch

The key insight is that both processes are Markov chains: each step depends only on the previous step. The forward process is fixed and its marginal $q(x_t \mid x_0)$ has a closed form, while the reverse process is learned by the U-Net.

The Two Distributions

DDPM defines two probability distributions:

  • $q(x_t \mid x_{t-1})$: The forward process (fixed, adds noise)
  • $p_\theta(x_{t-1} \mid x_t)$: The reverse process (learned, removes noise)

Both are Gaussian distributions:

```text
Forward: q(x_t | x_{t-1}) = N(x_t; sqrt(1-beta_t) * x_{t-1}, beta_t * I)

Reverse: p_theta(x_{t-1} | x_t) = N(x_{t-1}; mu_theta(x_t, t), sigma_t^2 * I)
```

Why Gaussian?

Gaussian distributions are closed under addition: the sum of two independent Gaussian random variables is again Gaussian, with their means and variances adding. This property makes it possible to derive closed-form expressions for the forward process and the training objective. It also means we can chain many small noise steps into one large step.
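
Writing $\alpha_t = 1 - \beta_t$, the forward definition above says $x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\, \epsilon_t$. Chaining two such steps with independent $\epsilon_1, \epsilon_2 \sim \mathcal{N}(0, I)$ shows the closure property at work:

$$
\begin{aligned}
x_1 &= \sqrt{\alpha_1}\, x_0 + \sqrt{1-\alpha_1}\, \epsilon_1 \\
x_2 &= \sqrt{\alpha_2}\, x_1 + \sqrt{1-\alpha_2}\, \epsilon_2 \\
    &= \sqrt{\alpha_1 \alpha_2}\, x_0 + \sqrt{\alpha_2 (1-\alpha_1)}\, \epsilon_1 + \sqrt{1-\alpha_2}\, \epsilon_2 \\
    &= \sqrt{\alpha_1 \alpha_2}\, x_0 + \sqrt{1-\alpha_1 \alpha_2}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
\end{aligned}
$$

The two noise terms merge into a single Gaussian because their variances add: $\alpha_2(1-\alpha_1) + (1-\alpha_2) = 1 - \alpha_1\alpha_2$. Repeating the argument for $t$ steps gives the closed form with $\bar{\alpha}_t = \alpha_1 \alpha_2 \cdots \alpha_t$.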

The Noise Schedule

The noise schedule defines how much noise is added at each timestep. It is controlled by $\beta_t$, the variance of the noise added at step $t$.

Key Schedule Parameters

| Parameter | Definition | Role |
|---|---|---|
| beta_t | Noise variance at step t | How much noise to add at each step |
| alpha_t | 1 - beta_t | How much signal is preserved at each step |
| alpha_bar_t | Product alpha_1 * ... * alpha_t | Total signal preserved from x_0 to x_t |
| sqrt(alpha_bar_t) | Square root of the cumulative product | Coefficient of x_0 in the forward process |
| sqrt(1 - alpha_bar_t) | Square root of (1 - cumulative product) | Coefficient of the noise in the forward process |

The genius of DDPM is that we can compute $x_t$ directly from $x_0$ without iterating through all intermediate steps:

$$x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
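
The closed form can be sanity-checked without any tensors. The plain-Python sketch below (helper names are hypothetical, and $x_0$ is a single scalar for simplicity) tracks the mean and variance of $x_t$ by applying $q(x_t \mid x_{t-1})$ one step at a time, then compares them with $\sqrt{\bar{\alpha}_t}\, x_0$ and $1 - \bar{\alpha}_t$ from the formula above, using the linear schedule with the original DDPM endpoints.

```python
import math


def linear_betas(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule with the original DDPM endpoints."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def iterate_forward_moments(x0, betas):
    """Mean and variance of x_t | x_0, applying q(x_t | x_{t-1}) step by step."""
    mean, var = x0, 0.0
    for beta in betas:
        alpha = 1.0 - beta
        # x_t = sqrt(alpha) * x_{t-1} + sqrt(beta) * eps, with eps independent of x_{t-1}
        mean = math.sqrt(alpha) * mean
        var = alpha * var + beta
    return mean, var


def closed_form_moments(x0, betas):
    """Mean and variance of q(x_t | x_0) computed directly from alpha_bar_t."""
    alpha_bar = math.prod(1.0 - b for b in betas)
    return math.sqrt(alpha_bar) * x0, 1.0 - alpha_bar


betas = linear_betas()
x0 = 0.7
for t in (1, 250, 1000):
    it_mean, it_var = iterate_forward_moments(x0, betas[:t])
    cf_mean, cf_var = closed_form_moments(x0, betas[:t])
    print(f"t={t:4d}  iterative=({it_mean:.6f}, {it_var:.6f})  "
          f"closed form=({cf_mean:.6f}, {cf_var:.6f})")
```

Both columns match to floating-point precision, and at $t = 1000$ the variance is essentially 1, which is why $x_T$ can be treated as pure noise.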

Noise Schedule Implementation
noise_schedule.py

  • Noise Schedule Setup: The noise schedule defines how noise is added over time. We precompute all schedule parameters for efficient training and sampling.
  • Linear Beta Schedule: The original DDPM uses a linear schedule from beta_start=0.0001 to beta_end=0.02. These values were found empirically to work well.
  • Alpha Values: alpha_t = 1 - beta_t represents how much signal is preserved at each step. Higher alpha means less noise added.
  • Cumulative Alpha: alpha_bar_t is the cumulative product of all alphas up to t. This lets us jump directly to any timestep without iterating.
  • Posterior Variance: The posterior variance is used in the reverse process. It is the variance of the forward-process posterior q(x_{t-1} | x_t, x_0).
```python
import math

import torch


class NoiseSchedule:
    """Precompute and store all noise schedule parameters."""

    def __init__(
        self,
        timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        schedule_type: str = "linear",
    ):
        self.timesteps = timesteps

        # Compute beta schedule
        if schedule_type == "linear":
            self.betas = torch.linspace(beta_start, beta_end, timesteps)
        elif schedule_type == "cosine":
            self.betas = self._cosine_schedule(timesteps)
        else:
            raise ValueError(f"Unknown schedule: {schedule_type}")

        # Compute alpha values
        self.alphas = 1.0 - self.betas

        # Cumulative product of alphas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Previous cumulative product (for t=0, use 1.0)
        self.alphas_cumprod_prev = torch.cat([
            torch.tensor([1.0]),
            self.alphas_cumprod[:-1]
        ])

        # Precompute values for forward process
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Precompute values for reverse process
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # Posterior variance of q(x_{t-1} | x_t, x_0), used in the reverse process
        # (it is 0 at t=0 because alphas_cumprod_prev[0] = 1)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) /
            (1.0 - self.alphas_cumprod)
        )

    def _cosine_schedule(self, timesteps: int, s: float = 0.008):
        """
        Cosine schedule from 'Improved DDPM' paper.
        Provides smoother noise addition than linear.
        """
        steps = timesteps + 1
        t = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((t / timesteps) + s) / (1 + s) * math.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        # Improved DDPM clips beta_t to at most 0.999 to avoid singularities near t = T
        return torch.clamp(betas, 0.0001, 0.999)

    def get(self, values: torch.Tensor, t: torch.Tensor, x_shape):
        """
        Extract values at timestep t and reshape for broadcasting.
        t must be an integer (torch.long) tensor of shape [B].
        """
        batch_size = t.shape[0]
        # Move the schedule tensor to t's device so this also works for GPU batches
        out = values.to(t.device).gather(-1, t)
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
```
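
The `get` helper is what lets a single batch mix different timesteps. Below is a minimal standalone sketch of the same gather-then-reshape trick on a toy table; the table values and shapes are arbitrary and chosen only for illustration.

```python
import torch

# Toy schedule table with T = 10 entries (arbitrary illustration values)
values = torch.linspace(0.1, 1.0, 10)

# A batch of 4 images [B, C, H, W], each with its own timestep
x = torch.randn(4, 3, 8, 8)
t = torch.tensor([0, 3, 3, 9])  # int64, as gather requires

# Same steps as NoiseSchedule.get: pick values[t[i]] for each sample,
# then reshape to [B, 1, 1, 1] so it broadcasts over C, H and W
coeff = values.gather(-1, t).reshape(t.shape[0], *((1,) * (x.dim() - 1)))
scaled = coeff * x  # every image is multiplied by its own coefficient

print(coeff.shape)   # torch.Size([4, 1, 1, 1])
print(scaled.shape)  # torch.Size([4, 3, 8, 8])
```

Without the reshape, `coeff` would have shape `[4]` and PyTorch would try to broadcast it against the last (width) dimension instead of the batch dimension.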

Linear vs Cosine Schedule

The original DDPM used a linear schedule, but the Improved DDPM paper (Nichol & Dhariwal, 2021) found that a cosine schedule works better, particularly for low-resolution images such as 32x32 and 64x64:

| Schedule | Pros | Cons |
|---|---|---|
| Linear | Simple, well-studied | Destroys information too quickly; the last timesteps are almost pure noise and contribute little |
| Cosine | Smoother noise addition, better images | Slightly more complex to compute |

Choosing a Schedule

For most applications, the cosine schedule is recommended. It keeps $\bar{\alpha}_t$ close to 1 for longer, so more image structure survives the early timesteps, which makes it easier for the model to learn. The linear schedule is fine for experimentation and matches the original DDPM paper.
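
To see what "preserves more image structure" means numerically, the plain-Python sketch below computes $\bar{\alpha}_t$ (the fraction of signal variance left at step $t$) for both schedules. It mirrors the formulas in noise_schedule.py but ignores the beta clipping in `_cosine_schedule`; the helper names are hypothetical.

```python
import math


def linear_alphas_cumprod(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t for the linear schedule, for t = 1..T (list index t - 1)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    out, prod = [], 1.0
    for i in range(timesteps):
        prod *= 1.0 - (beta_start + i * step)
        out.append(prod)
    return out


def cosine_alphas_cumprod(timesteps=1000, s=0.008):
    """alpha_bar_t = f(t) / f(0) with f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2)."""
    def f(t):
        return math.cos((t / timesteps + s) / (1 + s) * math.pi / 2) ** 2
    f0 = f(0)
    return [f(t) / f0 for t in range(1, timesteps + 1)]


T = 1000
linear = linear_alphas_cumprod(T)
cosine = cosine_alphas_cumprod(T)
for frac in (0.1, 0.25, 0.5, 0.75):
    i = round(frac * T) - 1  # list index of timestep t = frac * T
    print(f"t/T={frac:.2f}  linear alpha_bar={linear[i]:.3f}  cosine alpha_bar={cosine[i]:.3f}")
```

At every checkpoint the cosine schedule retains more signal; with the linear schedule, most of the signal is already gone by the halfway point.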

Forward Diffusion Process

The forward process $q(x_t \mid x_0)$ adds noise to clean images. During training, we sample random timesteps and compute $x_t$ directly:

Forward Process: Adding Noise
forward_process.py

  • q_sample Method: This implements q(x_t|x_0), the forward process that adds noise to data. It's the core of training.
  • Extract Schedule Values: We gather the sqrt(alpha_bar) and sqrt(1 - alpha_bar) values for the specific timesteps in the batch. This vectorized operation handles a different t for each sample.
  • Sample Gaussian Noise: We sample standard Gaussian noise epsilon ~ N(0, I). This is the noise we'll train the network to predict.
  • Reparameterization: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * epsilon. This is the closed-form forward process.
  • Return Noisy Image and Noise: We return both x_t (the input to the model) and epsilon (the target for the loss).
```python
def q_sample(
    self,
    x_0: torch.Tensor,
    t: torch.Tensor,
    noise: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Forward diffusion process: q(x_t | x_0)

    Given clean images x_0 and timesteps t, compute noisy images x_t.

    Args:
        x_0: Clean images [B, C, H, W], scaled to [-1, 1]
        t: Timesteps [B], integers from 0 to T-1
        noise: Optional pre-sampled noise [B, C, H, W]

    Returns:
        x_t: Noisy images [B, C, H, W]
        noise: The noise that was added (for computing loss)
    """
    # Sample noise if not provided
    if noise is None:
        noise = torch.randn_like(x_0)

    # Get schedule values for timestep t
    sqrt_alpha_bar = self.schedule.get(
        self.schedule.sqrt_alphas_cumprod, t, x_0.shape
    )
    sqrt_one_minus_alpha_bar = self.schedule.get(
        self.schedule.sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )

    # Compute x_t using reparameterization
    # x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise
    x_t = sqrt_alpha_bar * x_0 + sqrt_one_minus_alpha_bar * noise

    return x_t, noise
```

Key properties of the forward process:

  • Closed-form: We can compute $x_t$ directly without iterating through $x_1, x_2, \ldots, x_{t-1}$
  • Gaussian: $x_t$ is always Gaussian distributed given $x_0$
  • Signal decay: As $t \to T$, $\bar{\alpha}_t \to 0$, so $x_T \approx \mathcal{N}(0, I)$
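
These properties can be checked on a toy batch. The sketch below rebuilds the linear schedule inline and uses a hypothetical `q_sample_demo` function with the same formula as `q_sample`, giving each sample in the batch its own timestep; the constant "images" are illustrative only.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)


def q_sample_demo(x_0, t, noise):
    """Closed-form forward process: sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise."""
    # One alpha_bar per sample, reshaped to [B, 1, 1, 1] for broadcasting
    a_bar = alphas_cumprod.gather(-1, t).view(-1, *([1] * (x_0.dim() - 1)))
    return a_bar.sqrt() * x_0 + (1.0 - a_bar).sqrt() * noise


torch.manual_seed(0)
x_0 = torch.full((3, 1, 64, 64), 0.5)  # three identical constant "images"
t = torch.tensor([0, 499, 999])         # a different timestep for each sample
noise = torch.randn_like(x_0)
x_t = q_sample_demo(x_0, t, noise)

# As t grows, each sample's mean moves from 0.5 toward 0 and its std toward 1
for i in range(3):
    print(f"t={t[i].item():3d}  mean={x_t[i].mean().item():+.3f}  "
          f"std={x_t[i].std().item():.3f}")
```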

Input Scaling

DDPM assumes inputs are scaled to [-1, 1], not [0, 1]. This centering around zero is important because the noise $\epsilon$ has mean zero. Always scale your images before training: x = 2 * x - 1 (for images already in [0, 1]).
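
A minimal sketch of the scaling in both directions, assuming raw uint8 pixel tensors; the helper names `to_model_range` and `to_image_range` are hypothetical.

```python
import torch


def to_model_range(images_uint8: torch.Tensor) -> torch.Tensor:
    """Map uint8 pixels in [0, 255] to floats in [-1, 1]."""
    x = images_uint8.float() / 255.0  # [0, 1]
    return 2.0 * x - 1.0              # [-1, 1]


def to_image_range(x: torch.Tensor) -> torch.Tensor:
    """Map model outputs in [-1, 1] back to [0, 1], clamping any overshoot."""
    return ((x + 1.0) / 2.0).clamp(0.0, 1.0)


pixels = torch.tensor([0, 128, 255], dtype=torch.uint8)
scaled = to_model_range(pixels)
print(scaled)
print(to_image_range(scaled))
```

If images are loaded with torchvision's `ToTensor()`, they are already floats in [0, 1], so only the `2 * x - 1` step is needed.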

Reverse Diffusion Process

The reverse process $p_\theta(x_{t-1} \mid x_t)$ learns to remove noise. The U-Net predicts the noise $\epsilon_\theta(x_t, t)$, and we use this prediction to compute $x_{t-1}$:

Reverse Process: Removing Noise
reverse_process.py

  • p_sample Method: This implements one step of p_theta(x_{t-1}|x_t), the learned reverse process. It denoises from x_t to x_{t-1}.
  • Predict Noise: The U-Net takes the noisy image and timestep and outputs its prediction of the noise that was added.
  • Compute Predicted x_0: From the predicted noise, we estimate x_0 using the reparameterization formula solved for x_0.
  • Clip x_0: Clipping the predicted x_0 to [-1, 1] often improves sample quality. This is optional but commonly used.
  • Compute Posterior Mean: The mean of the forward-process posterior q(x_{t-1}|x_t, x_0) depends on both x_t and the predicted x_0. It is derived from Bayes' theorem.
  • Add Noise (except t=0): For t > 0, we add Gaussian noise scaled by the posterior standard deviation. At t=0, we return the mean directly.
```python
@torch.no_grad()
def p_sample(
    self,
    x_t: torch.Tensor,
    t: torch.Tensor,
    clip_denoised: bool = True,
) -> torch.Tensor:
    """
    Reverse diffusion step: p_theta(x_{t-1} | x_t)

    Given noisy images x_t and timesteps t, compute less noisy x_{t-1}.

    Args:
        x_t: Noisy images [B, C, H, W]
        t: Current timesteps [B]
        clip_denoised: Whether to clip predicted x_0 to [-1, 1]

    Returns:
        x_{t-1}: Slightly denoised images [B, C, H, W]
    """
    # Predict noise using the U-Net
    predicted_noise = self.model(x_t, t)

    # Get schedule values
    alpha = self.schedule.get(self.schedule.alphas, t, x_t.shape)
    alpha_bar = self.schedule.get(self.schedule.alphas_cumprod, t, x_t.shape)
    alpha_bar_prev = self.schedule.get(self.schedule.alphas_cumprod_prev, t, x_t.shape)
    beta = self.schedule.get(self.schedule.betas, t, x_t.shape)

    # Predict x_0 from x_t and predicted noise
    # x_0 = (x_t - sqrt(1-alpha_bar) * noise) / sqrt(alpha_bar)
    predicted_x0 = (
        x_t - torch.sqrt(1 - alpha_bar) * predicted_noise
    ) / torch.sqrt(alpha_bar)

    # Optionally clip predicted x_0
    if clip_denoised:
        predicted_x0 = torch.clamp(predicted_x0, -1, 1)

    # Compute posterior mean
    # mu = (sqrt(alpha_bar_prev) * beta * x_0 + sqrt(alpha) * (1-alpha_bar_prev) * x_t)
    #      / (1 - alpha_bar)
    posterior_mean = (
        torch.sqrt(alpha_bar_prev) * beta * predicted_x0 +
        torch.sqrt(alpha) * (1 - alpha_bar_prev) * x_t
    ) / (1 - alpha_bar)

    # Get posterior variance
    posterior_var = self.schedule.get(self.schedule.posterior_variance, t, x_t.shape)

    # Sample x_{t-1}
    # For t > 0, add noise; for t = 0, just return the mean
    noise = torch.randn_like(x_t)

    # Create mask for t > 0
    nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))

    # x_{t-1} = mean + sqrt(variance) * noise (only for t > 0)
    x_prev = posterior_mean + nonzero_mask * torch.sqrt(posterior_var) * noise

    return x_prev
```

Understanding the Reverse Step

The reverse step involves three key computations:

  1. Predict noise: The U-Net outputs $\epsilon_\theta(x_t, t)$
  2. Estimate x_0: Using the predicted noise, estimate what the clean image would look like
  3. Compute posterior: Combine $x_t$ and the predicted $\hat{x}_0$ to get the distribution of $x_{t-1}$

The posterior mean formula comes from Bayes' theorem applied to Gaussian distributions. It's a weighted combination of where we are ($x_t$) and where we think we're going ($\hat{x}_0$).
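
The same mean can be written directly in terms of the predicted noise, $\mu = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_\theta\right)$, which is the form used in the DDPM paper's sampling algorithm. The plain-Python sketch below checks numerically that the two forms agree when $\hat{x}_0$ is not clipped; the helper names and the schedule values are hypothetical, chosen only for illustration.

```python
import math


def posterior_mean_from_x0(x_t, x0_hat, alpha, alpha_bar, alpha_bar_prev):
    """Form used in p_sample: weighted combination of the x_0 estimate and x_t."""
    beta = 1.0 - alpha
    return (math.sqrt(alpha_bar_prev) * beta * x0_hat
            + math.sqrt(alpha) * (1.0 - alpha_bar_prev) * x_t) / (1.0 - alpha_bar)


def posterior_mean_from_eps(x_t, eps_hat, alpha, alpha_bar):
    """Equivalent form written in terms of the predicted noise."""
    beta = 1.0 - alpha
    return (x_t - beta / math.sqrt(1.0 - alpha_bar) * eps_hat) / math.sqrt(alpha)


# Arbitrary values for a mid-trajectory step, plus an arbitrary x_t and noise guess
alpha, alpha_bar_prev = 0.98, 0.60
alpha_bar = alpha * alpha_bar_prev
x_t, eps_hat = 0.3, -1.2

# Estimate x_0 exactly as p_sample does, but without clipping
x0_hat = (x_t - math.sqrt(1.0 - alpha_bar) * eps_hat) / math.sqrt(alpha_bar)

m1 = posterior_mean_from_x0(x_t, x0_hat, alpha, alpha_bar, alpha_bar_prev)
m2 = posterior_mean_from_eps(x_t, eps_hat, alpha, alpha_bar)
print(m1, m2)  # equal up to floating-point error
```

The noise-based form is where `sqrt_recip_alphas` ($1/\sqrt{\alpha_t}$) from `NoiseSchedule` would come in. With `clip_denoised=True` the two forms no longer agree, which is why `p_sample` goes through $\hat{x}_0$: it gives the clipping step something to act on.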

Why Predict Noise?

We could train the model to predict $x_0$ directly or to predict $x_{t-1}$. Ho et al. found that predicting noise works best empirically. It also has a nice interpretation: the model learns what "doesn't belong" in the image at each noise level.
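
Because $x_t$, $x_0$ and $\epsilon$ are tied together by the forward equation, a noise prediction can always be converted into an $x_0$ estimate and back, so the choice of target is about what the network finds easiest to learn. A small plain-Python sketch with hypothetical helper names and arbitrary numbers:

```python
import math


def eps_to_x0(x_t, eps, alpha_bar):
    """Solve x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps for x_0."""
    return (x_t - math.sqrt(1.0 - alpha_bar) * eps) / math.sqrt(alpha_bar)


def x0_to_eps(x_t, x0, alpha_bar):
    """Solve the same equation for eps."""
    return (x_t - math.sqrt(alpha_bar) * x0) / math.sqrt(1.0 - alpha_bar)


alpha_bar = 0.35
x0, eps = 0.8, 0.5
x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps

print(eps_to_x0(x_t, eps, alpha_bar))  # recovers x0
print(x0_to_eps(x_t, x0, alpha_bar))   # recovers eps
```

Note the division by $\sqrt{\bar{\alpha}_t}$ in `eps_to_x0`: with the linear schedule, $\sqrt{\bar{\alpha}_T} \approx 0.006$, so near $t = T$ a small error in the predicted noise becomes a large error in $\hat{x}_0$. That amplification is one reason the clipping step in `p_sample` helps.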

Complete DDPM Class

Now let's assemble everything into a complete DDPM class:

Complete DDPM Implementation
ddpm.py

  • DDPM Class: The DDPM class encapsulates the entire diffusion model: noise schedule, forward process, reverse process, and sampling.
  • Constructor Parameters: Key parameters are model (the U-Net), timesteps (typically 1000), beta_start/beta_end (noise schedule bounds), and schedule_type (linear or cosine).
  • Register Buffers: We register schedule parameters as buffers so they're automatically moved to the correct device with the model and saved in its state_dict.
  • Cosine Schedule Option: The cosine schedule (from Improved DDPM) provides smoother noise addition, which can improve image quality.
  • Training Loss: The training loss is simple: MSE between predicted noise and actual noise. This is the simplified objective from Ho et al., a reweighted form of the ELBO.
  • Sampling Loop: To generate images, we start from pure noise and apply p_sample at every timestep from t = T-1 down to t = 0 (the 0-indexed equivalent of T down to 1).
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from tqdm import tqdm

from noise_schedule import NoiseSchedule  # the class defined in noise_schedule.py

class DDPM(nn.Module):
    """
    Denoising Diffusion Probabilistic Model.

    Combines:
    - Noise schedule (defines the forward process)
    - U-Net model (predicts noise for reverse process)
    - Forward process (q_sample)
    - Reverse process (p_sample)
    - Training loss
    - Sampling procedure
    """

    def __init__(
        self,
        model: nn.Module,
        timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        schedule_type: str = "linear",
    ):
        super().__init__()

        self.model = model
        self.timesteps = timesteps

        # Initialize noise schedule
        self.schedule = NoiseSchedule(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            schedule_type=schedule_type,
        )

        # Register schedule tensors as buffers (moved with .to(device), saved in state_dict)
        self.register_buffer('betas', self.schedule.betas)
        self.register_buffer('alphas', self.schedule.alphas)
        self.register_buffer('alphas_cumprod', self.schedule.alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', self.schedule.alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', self.schedule.sqrt_alphas_cumprod)
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             self.schedule.sqrt_one_minus_alphas_cumprod)
        self.register_buffer('posterior_variance', self.schedule.posterior_variance)

    def q_sample(
        self,
        x_0: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward process: add noise to x_0 to get x_t."""
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alpha_bar = self._extract(self.sqrt_alphas_cumprod, t, x_0.shape)
        sqrt_one_minus_alpha_bar = self._extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
        )

        x_t = sqrt_alpha_bar * x_0 + sqrt_one_minus_alpha_bar * noise
        return x_t, noise

    @torch.no_grad()
    def p_sample(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        clip_denoised: bool = True,
    ) -> torch.Tensor:
        """Reverse process: denoise x_t to get x_{t-1}."""
        # Predict noise
        pred_noise = self.model(x_t, t)

        # Get schedule values
        alpha = self._extract(self.alphas, t, x_t.shape)
        alpha_bar = self._extract(self.alphas_cumprod, t, x_t.shape)
        alpha_bar_prev = self._extract(self.alphas_cumprod_prev, t, x_t.shape)
        beta = self._extract(self.betas, t, x_t.shape)

        # Predict x_0
        pred_x0 = (x_t - torch.sqrt(1 - alpha_bar) * pred_noise) / torch.sqrt(alpha_bar)
        if clip_denoised:
            pred_x0 = pred_x0.clamp(-1, 1)

        # Posterior mean
        posterior_mean = (
            torch.sqrt(alpha_bar_prev) * beta * pred_x0 +
            torch.sqrt(alpha) * (1 - alpha_bar_prev) * x_t
        ) / (1 - alpha_bar)

        # Posterior variance
        posterior_var = self._extract(self.posterior_variance, t, x_t.shape)

        # Sample (no noise is added at t = 0)
        noise = torch.randn_like(x_t)
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))
        x_prev = posterior_mean + nonzero_mask * torch.sqrt(posterior_var) * noise

        return x_prev

    def training_loss(
        self,
        x_0: torch.Tensor,
        noise: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the training loss (simplified ELBO).

        Args:
            x_0: Clean images [B, C, H, W]
            noise: Optional pre-sampled noise

        Returns:
            loss: Scalar loss value
        """
        batch_size = x_0.shape[0]
        device = x_0.device

        # Sample random timesteps
        t = torch.randint(0, self.timesteps, (batch_size,), device=device)

        # Add noise to get x_t
        x_t, noise = self.q_sample(x_0, t, noise)

        # Predict noise
        pred_noise = self.model(x_t, t)

        # MSE loss between predicted and actual noise
        loss = F.mse_loss(pred_noise, noise)

        return loss

    @torch.no_grad()
    def sample(
        self,
        batch_size: int,
        image_size: int,
        channels: int = 3,
        device: str = "cuda",
        show_progress: bool = True,
    ) -> torch.Tensor:
        """
        Generate samples by running the reverse process.

        Args:
            batch_size: Number of images to generate
            image_size: Size of generated images (assumes square)
            channels: Number of image channels
            device: Device to generate on
            show_progress: Whether to show progress bar

        Returns:
            samples: Generated images [B, C, H, W] in [-1, 1]
        """
        # Start from pure noise
        x = torch.randn(batch_size, channels, image_size, image_size, device=device)

        # Iteratively denoise
        timesteps = list(range(self.timesteps))[::-1]  # T-1, T-2, ..., 0
        if show_progress:
            timesteps = tqdm(timesteps, desc="Sampling")

        for t in timesteps:
            t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
            x = self.p_sample(x, t_batch)

        return x

    def _extract(
        self,
        values: torch.Tensor,
        t: torch.Tensor,
        x_shape: tuple,
    ) -> torch.Tensor:
        """Extract values at timestep t and reshape for broadcasting."""
        batch_size = t.shape[0]
        out = values.gather(-1, t)
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
```

Using the DDPM Class

Here's how to use the DDPM class for training and sampling:

```python
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

# Create U-Net and DDPM
# (UNet is the model from Chapter 5; DDPM is the class defined above)
unet = UNet(
    image_size=64,
    base_channels=128,
    channel_mults=(1, 2, 2, 4),
    num_res_blocks=2,
    attention_resolutions=(16, 8),
)

ddpm = DDPM(
    model=unet,
    timesteps=1000,
    schedule_type="cosine",  # Use improved schedule
)

# Move to GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
ddpm = ddpm.to(device)

# Optimizer (only the U-Net has trainable parameters; schedule tensors are buffers)
optimizer = AdamW(ddpm.parameters(), lr=2e-4)

# Training loop (simplified)
def train_step(images):
    """Single training step."""
    ddpm.train()  # generate_samples() switches to eval mode, so switch back here

    # Scale images to [-1, 1]
    images = images.to(device)
    images = 2 * images - 1  # Assuming images are in [0, 1]

    # Compute loss
    loss = ddpm.training_loss(images)

    # Update model, clipping the gradient norm to 1.0 for stability
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=1.0)
    optimizer.step()

    return loss.item()

# Sampling
@torch.no_grad()
def generate_samples(num_samples=16):
    """Generate images from the trained model."""
    ddpm.eval()

    # Generate samples
    samples = ddpm.sample(
        batch_size=num_samples,
        image_size=64,
        channels=3,
        device=device,
    )

    # Scale back to [0, 1]
    samples = (samples + 1) / 2
    samples = samples.clamp(0, 1)

    return samples

# Example usage
# for epoch in range(num_epochs):
#     for batch in dataloader:  # if batches are (images, labels) pairs, pass batch[0]
#         loss = train_step(batch)
#         print(f"Loss: {loss:.4f}")
#
# samples = generate_samples(16)
# save_images(samples, "generated.png")
```

Training Tips

| Aspect | Recommendation |
|---|---|
| Learning rate | 1e-4 to 2e-4 (AdamW) |
| Batch size | 64-256, depending on GPU memory |
| EMA | Use an exponential moving average of the weights (decay = 0.9999) |
| Gradient clipping | Clip the gradient norm to 1.0 for stability |
| Training time | ~500K-1M steps for good results on 64x64 |
| Image scaling | Always scale to [-1, 1] |

EMA is Important

Production diffusion models almost always sample from an Exponential Moving Average (EMA) of the model weights. The EMA model usually produces significantly better samples than the raw training weights. We'll cover this in detail in the training section.
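
As a preview, here is a minimal sketch of the EMA update itself, assuming a plain parameter-wise average with the decay from the Training Tips table. `ema_update` is a hypothetical helper, and a toy `nn.Linear` stands in for the DDPM module.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(ema_model: nn.Module, model: nn.Module, decay: float = 0.9999) -> None:
    """ema = decay * ema + (1 - decay) * current, parameter by parameter."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
    # Buffers (such as the noise schedule) are copied rather than averaged
    for ema_b, b in zip(ema_model.buffers(), model.buffers()):
        ema_b.copy_(b)


model = nn.Linear(4, 2)                  # stand-in for the DDPM module
ema_model = copy.deepcopy(model).eval()  # frozen copy used only for sampling
ema_model.requires_grad_(False)

before = ema_model.weight.detach().clone()
with torch.no_grad():
    model.weight.add_(1.0)               # stand-in for an optimizer step
ema_update(ema_model, model, decay=0.9)  # small decay so the effect is visible
```

In a training loop, `ema_update` would be called right after `optimizer.step()`, and sampling would use the EMA copy instead of the training model.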

Summary

In this section, we built a complete DDPM implementation:

  1. Noise schedule: Precomputed alpha, beta, and related values for efficient forward and reverse processes
  2. Forward process: $q(x_t \mid x_0)$ adds noise using the closed-form reparameterization trick
  3. Reverse process: $p_\theta(x_{t-1} \mid x_t)$ removes noise one step at a time using the U-Net's predictions
  4. Training loss: Simple MSE between predicted and actual noise
  5. Sampling: Iterate through all timesteps from T down to 1 (indices T-1 down to 0 in code) to generate images

Coming Up Next

In the next section, we'll dive deep into the training loop: how to efficiently train DDPM on real datasets, implement EMA, handle mixed precision, and monitor training progress. We'll also discuss common issues and how to debug them.

The DDPM class we built is the foundation for all diffusion-based generation. The same principles apply to more advanced models like DDIM, Stable Diffusion, and DALL-E 2, which build upon this core algorithm with various improvements.