Chapter 6
Section 31 of 76

Sampling Algorithm

Building the Diffusion Model

Learning Objectives

By the end of this section, you will:

  1. Understand the DDPM sampling algorithm and its mathematical foundation
  2. Implement efficient sampling with progress tracking and batching
  3. Apply classifier-free guidance for controllable generation
  4. Optimize sampling speed with various techniques
  5. Handle memory efficiently when generating many samples

From Training to Generation

Training teaches the model to predict noise. Sampling uses this ability to generate new images by starting from noise and iteratively denoising. This section covers the practical details of turning a trained model into a working image generator.

Sampling Overview

The DDPM sampling process works by reversing the forward diffusion. Given a trained noise prediction model $\epsilon_\theta$, we:

  1. Start with pure Gaussian noise $x_T \sim \mathcal{N}(0, I)$
  2. For each timestep, from t = T down to t = 1 (indices T-1 down to 0 in the 0-indexed code below):
       • Predict the noise in the current sample
       • Use this prediction to compute a less noisy version
       • Add a small amount of fresh Gaussian noise, except on the final step (this is what makes DDPM sampling stochastic)

The key equation for each step is:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right) + \sigma_t z$$

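where $z \sim \mathcal{N}(0, I)$ and $\sigma_t$ is the standard deviation of the reverse step. The two usual choices are $\sigma_t^2 = \beta_t$ (used in the code below) and the true posterior variance $\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$. On the final step, no noise is added ($z = 0$).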
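To get a feel for how different the two variance choices are, the sketch below computes both for every timestep. It assumes a linear beta schedule from 1e-4 to 0.02 over 1000 steps (a common DDPM default, not necessarily the `NoiseSchedule` from earlier sections); `linear_betas` and `reverse_step_variances` are hypothetical helper names.

```python
import torch


def linear_betas(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """Hypothetical linear schedule (a common DDPM default); swap in your own betas."""
    return torch.linspace(beta_start, beta_end, T, dtype=torch.float64)


def reverse_step_variances(betas: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the two usual choices of sigma_t^2 for the DDPM reverse step."""
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # alpha_bar_{t-1}; before the first step it is defined as 1 (no noise yet)
    alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=betas.dtype), alphas_cumprod[:-1]])

    # Choice 1: sigma_t^2 = beta_t
    upper = betas
    # Choice 2: variance of the true posterior q(x_{t-1} | x_t, x_0)
    posterior = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    return upper, posterior


betas = linear_betas()
upper, posterior = reverse_step_variances(betas)
for t in (0, 1, 10, 100, 999):
    print(f"t={t:4d}  sqrt(beta_t)={upper[t].sqrt().item():.5f}  "
          f"sqrt(beta_tilde_t)={posterior[t].sqrt().item():.5f}")
```

The two choices nearly coincide at large t and diverge near t = 0, where the posterior variance shrinks to zero. Since the final step adds no noise anyway, the simpler $\sigma_t^2 = \beta_t$ is a reasonable default.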


DDPM Sampling Algorithm

Here's the complete DDPM sampling algorithm:

Complete DDPM Sampling (sampling.py)

Walkthrough of the code below:

  • Sampling function: The main sampling function generates images from pure noise by iteratively applying the reverse process.
  • Initialize from noise: We start from pure Gaussian noise $x_T \sim \mathcal{N}(0, I)$, the starting point of the reverse process.
  • Reverse iteration: We iterate from t = T-1 down to t = 0, applying one denoising step at each timestep. Order matters: we go backwards.
  • Predict noise: The U-Net predicts the noise present in $x_t$ given the current timestep $t$.
  • Compute $x_{t-1}$: Using the predicted noise and the schedule parameters, we compute the next (less noisy) sample.
  • Stochastic noise: DDPM adds fresh noise at every step except the last. This draws $x_{t-1}$ from the Gaussian reverse distribution $p_\theta(x_{t-1} \mid x_t)$ instead of returning only its mean.
```python
# sampling.py
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import Callable, Optional


@torch.no_grad()
def ddpm_sample(
    model: nn.Module,
    schedule: "NoiseSchedule",
    shape: tuple,
    device: str = "cuda",
    clip_denoised: bool = True,
    progress_callback: Optional[Callable] = None,
) -> torch.Tensor:
    """
    Generate samples using the DDPM reverse process.

    Args:
        model: Trained noise prediction U-Net
        schedule: Noise schedule with precomputed 1-D tensors
            (betas, alphas, alphas_cumprod) and an int `timesteps`
        shape: Output shape (batch_size, channels, height, width)
        device: Device to generate on
        clip_denoised: Whether to clip the predicted x_0 to [-1, 1]
        progress_callback: Optional callback(step, total_steps, x)

    Returns:
        samples: Generated images, approximately in [-1, 1]
    """
    model.eval()  # Disable dropout and use running BatchNorm statistics
    batch_size = shape[0]

    # Start from pure noise x_T ~ N(0, I)
    x = torch.randn(shape, device=device)

    # Iterate from T-1 down to 0 (0-indexed timesteps)
    timesteps = list(range(schedule.timesteps))[::-1]

    for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
        # Same timestep for every sample in the batch
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)

        # Predict the noise in x_t
        predicted_noise = model(x, t_batch)

        # Schedule values for this timestep
        alpha = schedule.alphas[t]
        alpha_bar = schedule.alphas_cumprod[t]
        beta = schedule.betas[t]

        if clip_denoised:
            # x_0 = (x_t - sqrt(1 - alpha_bar) * eps) / sqrt(alpha_bar)
            pred_x0 = (x - torch.sqrt(1 - alpha_bar) * predicted_noise) / torch.sqrt(alpha_bar)
            pred_x0 = pred_x0.clamp(-1, 1)

            # Re-derive the noise that is consistent with the clipped x_0
            predicted_noise = (x - torch.sqrt(alpha_bar) * pred_x0) / torch.sqrt(1 - alpha_bar)

        # Mean of p(x_{t-1} | x_t):
        # mu = (1 / sqrt(alpha)) * (x_t - beta / sqrt(1 - alpha_bar) * eps)
        mean = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_bar)) * predicted_noise
        )

        if t > 0:
            # sigma_t^2 = beta_t (the simpler of the two DDPM variance choices)
            sigma = torch.sqrt(beta)
            x = mean + sigma * torch.randn_like(x)
        else:
            x = mean  # No noise on the final step

        if progress_callback is not None:
            progress_callback(i, len(timesteps), x)

    return x


# Alternative: return intermediate samples for visualization
@torch.no_grad()
def ddpm_sample_with_intermediates(
    model: nn.Module,
    schedule: "NoiseSchedule",
    shape: tuple,
    device: str = "cuda",
    save_every: int = 100,
) -> tuple[torch.Tensor, list]:
    """
    Sample while keeping every `save_every`-th state for visualization.

    Returns the final sample and a list of (t, x_t) pairs stored on the CPU.
    """
    model.eval()
    intermediates = []
    x = torch.randn(shape, device=device)

    for t in tqdm(range(schedule.timesteps - 1, -1, -1), desc="Sampling"):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

        # Same denoising logic as ddpm_sample (without clipping)
        predicted_noise = model(x, t_batch)
        alpha = schedule.alphas[t]
        alpha_bar = schedule.alphas_cumprod[t]
        beta = schedule.betas[t]

        mean = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_bar)) * predicted_noise
        )

        if t > 0:
            x = mean + torch.sqrt(beta) * torch.randn_like(x)
        else:
            x = mean

        # t == 0 is always saved, since 0 % save_every == 0
        if t % save_every == 0:
            # Move to CPU so intermediates do not accumulate in GPU memory
            intermediates.append((t, x.cpu()))

    return x, intermediates
```

Stochastic vs Deterministic

The noise added at each step ($\sigma_t z$) makes DDPM sampling stochastic: the same starting noise $x_T$ can produce different outputs. DDIM (covered in Chapter 7) uses a different update rule that can omit this noise, making sampling deterministic.
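A toy experiment makes the distinction concrete: run the reverse chain several times from one fixed $x_T$ and compare the results. This is a sketch under stated assumptions: `eps_model` is a hypothetical stand-in for a trained U-Net, the 50-step linear schedule is illustrative, and the deterministic branch simply drops the noise term (it is not DDIM, which also changes the update rule).

```python
import torch


def toy_reverse_chain(x_T, betas, eps_model, generator=None, stochastic=True):
    """Toy DDPM reverse chain from a fixed x_T, using sigma_t^2 = beta_t."""
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    x = x_T.clone()
    for t in range(len(betas) - 1, -1, -1):
        eps = eps_model(x, t)
        mean = (x - betas[t] / torch.sqrt(1 - alphas_cumprod[t]) * eps) / torch.sqrt(alphas[t])
        if stochastic and t > 0:
            # Per-step noise is drawn from `generator`, so its seed controls the output
            z = torch.randn(x.shape, generator=generator, dtype=x.dtype)
            x = mean + torch.sqrt(betas[t]) * z
        else:
            x = mean  # Deterministic: keep only the mean
    return x


def eps_model(x, t):
    """Hypothetical stand-in for a trained noise predictor."""
    return 0.1 * x


torch.manual_seed(0)
x_T = torch.randn(4, 8)                 # One fixed starting noise
betas = torch.linspace(1e-4, 0.02, 50)  # Short toy schedule

a = toy_reverse_chain(x_T, betas, eps_model, torch.Generator().manual_seed(1))
b = toy_reverse_chain(x_T, betas, eps_model, torch.Generator().manual_seed(2))
c = toy_reverse_chain(x_T, betas, eps_model, torch.Generator().manual_seed(1))
d1 = toy_reverse_chain(x_T, betas, eps_model, stochastic=False)
d2 = toy_reverse_chain(x_T, betas, eps_model, stochastic=False)
print(torch.equal(a, b), torch.equal(a, c), torch.equal(d1, d2))
```

Runs `a` and `b` share $x_T$ but draw different per-step noise, so they end in different places; `a` and `c` replay the same noise and match exactly. This is also why fixing only $x_T$ is not enough to make DDPM samples reproducible.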

Step-by-Step Visualization

The following visualization shows the denoising process in action. Watch how the signal gradually emerges from noise as we iterate through timesteps:

[Figure: Step-by-Step Denoising Animation. A 1D signal is progressively denoised from pure noise to clean data with the stochastic DDPM reverse process; curves show the current x_t, the target x_0, and the predicted x_0.]

What's Happening at Each Step?

1. Predict Noise (or x0)

The neural network takes the current noisy signal x_t and timestep t, then predicts either the noise epsilon or the clean signal x_0.

2. Compute Mean

Using the prediction, we compute the mean of p(x_{t-1}|x_t), which tells us where to move in the next step.

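3. Add Noise (DDPM) or Not (DDIM)

DDPM adds scaled Gaussian noise, which samples from the reverse distribution p(x_{t-1}|x_t) instead of taking only its mean. DDIM can skip this noise, which makes sampling deterministic; its speedup comes separately, from being able to skip timesteps.

4. Update x

We update x_{t-1} = mean + sigma * z (DDPM), or apply DDIM's deterministic update (shown below), moving one step closer to the clean data.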

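Key Equations

DDPM reverse step:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right) + \sigma_t z$$

DDIM reverse step (deterministic):

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \, \hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}} \, \epsilon_\theta(x_t, t), \qquad \hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t} \, \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}$$

Here the second DDIM term is the "direction pointing to $x_t$".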
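A quick numerical check of these formulas: if $x_t$ is built from a known $x_0$ and noise via the forward process, and the "model" returns that exact noise, then converting noise to $\hat{x}_0$ (and back) should be lossless, and one DDIM step should land exactly on the forward-process point at $t-1$ with the same noise. This is a sketch; `predict_x0`, `predict_eps`, and `ddim_step` are illustrative names, and the $\bar{\alpha}$ values are arbitrary.

```python
import torch


def predict_x0(x_t, eps, alpha_bar_t):
    """Invert x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps for x_0."""
    return (x_t - torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)


def predict_eps(x_t, x0, alpha_bar_t):
    """The inverse conversion: the noise implied by x_t and an x_0 estimate."""
    return (x_t - torch.sqrt(alpha_bar_t) * x0) / torch.sqrt(1 - alpha_bar_t)


def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev):
    """Deterministic DDIM update from t to t-1."""
    x0_hat = predict_x0(x_t, eps, alpha_bar_t)
    return torch.sqrt(alpha_bar_prev) * x0_hat + torch.sqrt(1 - alpha_bar_prev) * eps


torch.manual_seed(0)
x0 = torch.rand(2, 3) * 2 - 1          # "Clean" data in [-1, 1]
noise = torch.randn(2, 3)
alpha_bar_t, alpha_bar_prev = torch.tensor(0.5), torch.tensor(0.6)

x_t = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise
x_prev = ddim_step(x_t, noise, alpha_bar_t, alpha_bar_prev)
```

The same two conversions are what `clip_denoised` in `ddpm_sample` relies on: clip $\hat{x}_0$, then recover the noise consistent with the clipped value.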

Key observations from the visualization:

  • Early steps (high t): The model makes large corrections to gross structure
  • Middle steps: Medium-scale features emerge
  • Late steps (low t): Fine details are refined
  • t=0: The final clean image (or in this case, signal)
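This coarse-to-fine pattern tracks how much of the original signal survives at each timestep. Since $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$, the two coefficients below measure signal and noise. The sketch assumes the common linear schedule (betas from 1e-4 to 0.02, T = 1000); your schedule may differ.

```python
import torch

# Assumed linear schedule; substitute your own NoiseSchedule values
T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Coefficients of x_0 and of the noise in x_t
signal = alphas_cumprod.sqrt()
noise = (1.0 - alphas_cumprod).sqrt()

for t in (999, 750, 500, 250, 100, 10, 0):
    print(f"t={t:4d}  signal={signal[t].item():.3f}  noise={noise[t].item():.3f}")
```

Near t = T almost nothing of $x_0$ remains, so the model can only commit to coarse structure; near t = 0 the signal dominates and the remaining corrections are small details.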

Sampling Optimizations

DDPM sampling is slow because it requires ~1000 forward passes through the U-Net. Here are techniques to speed it up:

1. Reduced Steps

```python
import torch
import torch.nn as nn
from tqdm import tqdm


@torch.no_grad()
def ddpm_sample_strided(
    model: nn.Module,
    schedule: "NoiseSchedule",
    shape: tuple,
    num_inference_steps: int = 250,  # Reduced from 1000
    device: str = "cuda",
) -> torch.Tensor:
    """
    Sample using a subset of timesteps for faster generation.

    Warning: This is a simple approximation. For proper reduced-step
    sampling, use DDIM (Chapter 7) instead.
    """
    model.eval()
    # Strided timestep sequence, e.g. [996, 992, ..., 4, 0] for 250 of 1000 steps
    step_size = max(1, schedule.timesteps // num_inference_steps)
    timesteps = list(range(0, schedule.timesteps, step_size))[::-1]

    x = torch.randn(shape, device=device)

    for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        predicted_noise = model(x, t_batch)

        # Estimate x_0 from the current sample
        alpha_bar = schedule.alphas_cumprod[t]
        pred_x0 = (x - torch.sqrt(1 - alpha_bar) * predicted_noise) / torch.sqrt(alpha_bar)
        pred_x0 = pred_x0.clamp(-1, 1)

        if i == len(timesteps) - 1:
            # Last step: return the x_0 estimate instead of re-noising it
            x = pred_x0
        else:
            # Jump to the next timestep in the sequence by re-noising pred_x0
            next_t = timesteps[i + 1]
            alpha_bar_next = schedule.alphas_cumprod[next_t]
            x = torch.sqrt(alpha_bar_next) * pred_x0 + \
                torch.sqrt(1 - alpha_bar_next) * torch.randn_like(x)

    return x
```

Strided Sampling Quality

Simple strided sampling degrades quality significantly. For proper fast sampling, use DDIM or other advanced samplers (Chapter 7). The above is only for understanding the concept.

2. Batched Sampling

```python
import torch
from tqdm import tqdm


def sample_batch_efficient(
    ddpm: "DDPM",
    total_samples: int,
    batch_size: int,
    image_size: int,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Generate many samples efficiently by batching.

    Assumes `ddpm.sample` accepts batch_size, image_size, device and
    show_progress keyword arguments.
    """
    all_samples = []
    num_batches = (total_samples + batch_size - 1) // batch_size  # Ceiling division

    for batch_idx in tqdm(range(num_batches), desc="Generating batches"):
        # The last batch may be smaller
        current_batch_size = min(batch_size, total_samples - batch_idx * batch_size)

        samples = ddpm.sample(
            batch_size=current_batch_size,
            image_size=image_size,
            device=device,
            show_progress=False,  # Disable the inner per-timestep progress bar
        )

        # Move to CPU so finished samples do not occupy GPU memory
        all_samples.append(samples.cpu())

        # Periodically release cached GPU memory (a no-op if CUDA is unused)
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()

    return torch.cat(all_samples, dim=0)
```

3. Model Compilation

```python
import torch
import torch.nn as nn


def create_compiled_sampler(model: nn.Module, device: str = "cuda") -> nn.Module:
    """
    Wrap the model with torch.compile for faster sampling (PyTorch 2.0+).
    """
    model = model.to(device)
    model.eval()

    # "reduce-overhead" uses CUDA graphs to cut per-call launch overhead,
    # which helps when the same forward pass runs ~1000 times
    compiled_model = torch.compile(model, mode="reduce-overhead")

    return compiled_model


# Usage (`unet`, `schedule`, and `ddpm_sample` come from the code above)
compiled_model = create_compiled_sampler(unet)
# The first call is slow (compilation); subsequent calls are faster
samples = ddpm_sample(compiled_model, schedule, (16, 3, 64, 64))
```
| Technique | Typical speedup (vs. 1000 steps) | Quality impact | When to use |
|---|---|---|---|
| Full 1000 steps | 1x (baseline) | Best | Final-quality samples |
| 250 strided steps | ~4x | Noticeable degradation | Quick previews only |
| DDIM, 50 steps | ~20x | Minimal | Production (Chapter 7) |
| torch.compile | 1.2-2x | None | Always for inference |
| Half precision | 1.5-2x | Minimal | With careful implementation |
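One way to read "with careful implementation" for half precision: run only the U-Net forward pass under autocast, and keep the sample `x` and all schedule arithmetic in float32. The sketch below shows that split; `predict_noise_mixed_precision` and `TinyEpsModel` are hypothetical names, and the choice of float16 on CUDA versus bfloat16 on CPU is an assumption about your hardware.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_noise_mixed_precision(model: nn.Module, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Run the noise predictor under autocast and return a float32 result."""
    device_type = x.device.type
    # float16 is the usual choice on CUDA GPUs; CPU autocast uses bfloat16
    dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    with torch.autocast(device_type=device_type, dtype=dtype):
        eps = model(x, t)
    # Hand the sampler float32 so the update arithmetic stays in float32
    return eps.float()


class TinyEpsModel(nn.Module):
    """Hypothetical stand-in for the U-Net: ignores t and applies one conv."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


model = TinyEpsModel().eval()
x = torch.randn(2, 3, 16, 16)
t = torch.full((2,), 500, dtype=torch.long)
eps = predict_noise_mixed_precision(model, x, t)

# Why the schedule math stays in float32: just below 1.0, float16 values are
# about 5e-4 apart, so alpha_bar_0 = 1 - 1e-4 rounds to exactly 1.0
alpha_bar_0 = torch.tensor(1 - 1e-4, dtype=torch.float16)
print(1 - alpha_bar_0)  # Zero, so dividing by sqrt(1 - alpha_bar) would blow up
```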

Classifier-Free Guidance

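Classifier-Free Guidance (CFG) improves sample quality by amplifying the conditioning signal. It is standard practice for conditional generation (text-to-image, class-conditional) and requires a model trained with the condition randomly dropped, so that the same network can also make unconditional predictions: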

$$\tilde{\epsilon}_\theta(x_t, t, c) = \epsilon_\theta(x_t, t, \emptyset) + w \cdot \left( \epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \emptyset) \right)$$

where $w$ is the guidance scale (typically 3-15), $c$ is the condition, and $\emptyset$ is the null condition. With $w = 1$ this reduces to the ordinary conditional prediction; $w > 1$ extrapolates further away from the unconditional prediction.
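The guidance formula is plain tensor arithmetic, so it is easy to check its limiting cases on toy values. `apply_cfg` is an illustrative helper name; the two-element tensors stand in for full noise predictions.

```python
import torch


def apply_cfg(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, w: float) -> torch.Tensor:
    """Classifier-free guidance: eps_uncond + w * (eps_cond - eps_uncond)."""
    return eps_uncond + w * (eps_cond - eps_uncond)


eps_uncond = torch.tensor([0.0, 1.0])
eps_cond = torch.tensor([1.0, 1.0])

print(apply_cfg(eps_uncond, eps_cond, 0.0))  # Equals eps_uncond
print(apply_cfg(eps_uncond, eps_cond, 1.0))  # Equals eps_cond
print(apply_cfg(eps_uncond, eps_cond, 7.5))  # First entry pushed to 7.5, second unchanged
```

Where the conditional and unconditional predictions agree (the second entry), guidance has no effect at any scale; it only amplifies the components that the condition changes.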

```python
import torch
import torch.nn as nn
from tqdm import tqdm


@torch.no_grad()
def sample_with_cfg(
    model: nn.Module,
    schedule: "NoiseSchedule",
    shape: tuple,
    condition: torch.Tensor,
    guidance_scale: float = 7.5,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Sample with classifier-free guidance.

    Args:
        model: Conditional noise prediction model, called as model(x, t, cond)
        schedule: Noise schedule
        shape: Output shape
        condition: Conditioning input (e.g., class label, text embedding)
        guidance_scale: CFG weight (higher = stronger conditioning)
        device: Device

    Returns:
        Guided samples
    """
    batch_size = shape[0]
    x = torch.randn(shape, device=device)

    # Null condition: must match what the model saw when the condition was
    # dropped during training. An all-zeros embedding is assumed here; for
    # integer class labels, use a dedicated null index (0 is a real class).
    null_condition = torch.zeros_like(condition)

    for t in tqdm(range(schedule.timesteps - 1, -1, -1), desc="Sampling"):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)

        # Run the model twice: with and without the condition
        # (can be batched for efficiency, see below)
        noise_cond = model(x, t_batch, condition)
        noise_uncond = model(x, t_batch, null_condition)

        # Apply guidance: eps_uncond + w * (eps_cond - eps_uncond)
        predicted_noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # Standard DDPM denoising step
        alpha = schedule.alphas[t]
        alpha_bar = schedule.alphas_cumprod[t]
        beta = schedule.betas[t]

        mean = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_bar)) * predicted_noise
        )

        if t > 0:
            x = mean + torch.sqrt(beta) * torch.randn_like(x)
        else:
            x = mean

    return x


# Optimized: batch both predictions into a single forward pass
@torch.no_grad()
def sample_with_cfg_batched(
    model: nn.Module,
    schedule: "NoiseSchedule",
    shape: tuple,
    condition: torch.Tensor,
    guidance_scale: float = 7.5,
    device: str = "cuda",
) -> torch.Tensor:
    """CFG with batched conditional/unconditional predictions.

    Doubles the effective batch size, so activation memory per step
    roughly doubles.
    """
    batch_size = shape[0]
    x = torch.randn(shape, device=device)
    null_condition = torch.zeros_like(condition)  # Same assumption as above

    for t in tqdm(range(schedule.timesteps - 1, -1, -1), desc="Sampling"):
        # Stack conditional and unconditional inputs along the batch dimension
        x_doubled = torch.cat([x, x], dim=0)
        t_doubled = torch.full((batch_size * 2,), t, device=device, dtype=torch.long)
        cond_doubled = torch.cat([condition, null_condition], dim=0)

        # Single forward pass
        noise_pred = model(x_doubled, t_doubled, cond_doubled)

        # First half is conditional, second half unconditional (matches the cat order)
        noise_cond, noise_uncond = noise_pred.chunk(2)
        predicted_noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # Denoising step (same as before)
        alpha = schedule.alphas[t]
        alpha_bar = schedule.alphas_cumprod[t]
        beta = schedule.betas[t]

        mean = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_bar)) * predicted_noise
        )

        if t > 0:
            x = mean + torch.sqrt(beta) * torch.randn_like(x)
        else:
            x = mean

    return x
```
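The null condition in the CFG samplers only means something if the model was trained on it. For class-conditional models, one common setup is to reserve an extra label index (here `num_classes`) as the null label and, during training, replace each label with it at random. The sketch below assumes that setup: `drop_labels_for_cfg` is a hypothetical helper, `p_uncond = 0.1` is an assumed (commonly used) dropout rate, and the model's label embedding would need `num_classes + 1` rows.

```python
from typing import Optional

import torch


def drop_labels_for_cfg(
    labels: torch.Tensor,
    null_label: int,
    p_uncond: float = 0.1,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Replace each class label with `null_label` with probability `p_uncond`."""
    drop = torch.rand(labels.shape, generator=generator, device=labels.device) < p_uncond
    return torch.where(drop, torch.full_like(labels, null_label), labels)


num_classes = 10
labels = torch.randint(0, num_classes, (1000,), generator=torch.Generator().manual_seed(0))
dropped = drop_labels_for_cfg(labels, null_label=num_classes, p_uncond=0.1,
                              generator=torch.Generator().manual_seed(1))
frac_null = (dropped == num_classes).float().mean().item()
print(f"fraction replaced by the null label: {frac_null:.3f}")

# At sampling time, the matching null condition is the same index, not zeros
condition = torch.tensor([3, 7])
null_condition = torch.full_like(condition, num_classes)
```

With this setup, `torch.zeros_like(condition)` in the samplers would be swapped for `torch.full_like(condition, num_classes)`.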

Guidance Scale Selection

guidance_scale=7.5 is a common starting point for text-to-image models. Higher values (10-15) give more adherence to the condition but can reduce diversity and cause artifacts. Lower values (3-5) give more variety but weaker conditioning. Tune based on your use case.

Batch Sampling Strategies

When generating many samples (e.g., for evaluation), efficient batching is important:

```python
import torch
from pathlib import Path
from torchvision.utils import save_image


def generate_samples_for_evaluation(
    trainer: "DiffusionTrainer",
    output_dir: str,
    num_samples: int = 10000,
    batch_size: int = 64,
    image_size: int = 64,
) -> None:
    """
    Generate many samples for FID/IS evaluation.

    Assumes trainer.sample returns images in [0, 1], the range save_image
    expects; if it returns [-1, 1], rescale with (samples + 1) / 2 first.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    samples_generated = 0

    while samples_generated < num_samples:
        current_batch = min(batch_size, num_samples - samples_generated)

        # Generate a batch with the EMA weights
        samples = trainer.sample(
            num_samples=current_batch,
            image_size=image_size,
            use_ema=True,
        )

        # Save individual images with globally unique, zero-padded indices
        for i, sample in enumerate(samples):
            idx = samples_generated + i
            save_image(sample, output_path / f"sample_{idx:05d}.png")

        samples_generated += current_batch
        print(f"Generated {samples_generated}/{num_samples} samples")

        # Memory management
        del samples
        torch.cuda.empty_cache()

    print(f"Saved {num_samples} samples to {output_dir}")


def generate_grid(
    trainer: "DiffusionTrainer",
    nrow: int = 8,
    ncol: int = 8,
    image_size: int = 64,
) -> torch.Tensor:
    """Generate an (nrow * H, ncol * W, 3) image grid for visualization."""
    num_samples = nrow * ncol

    samples = trainer.sample(
        num_samples=num_samples,
        image_size=image_size,
        use_ema=True,
    )

    # (N, C, H, W) -> (nrow, ncol, C, H, W)
    grid = samples.reshape(nrow, ncol, 3, image_size, image_size)
    grid = grid.permute(0, 3, 1, 4, 2)  # (nrow, H, ncol, W, C)
    grid = grid.reshape(nrow * image_size, ncol * image_size, 3)

    return grid


@torch.no_grad()
def sample_with_seed(
    trainer: "DiffusionTrainer",
    seed: int,
    num_samples: int = 16,
    image_size: int = 64,
) -> torch.Tensor:
    """Generate reproducible samples with a fixed seed (same hardware/software)."""
    # Seed the global RNGs too: p_sample presumably draws its per-step noise
    # from them (e.g. via torch.randn_like), so seeding only x_T is not enough
    torch.manual_seed(seed)

    generator = torch.Generator(device=trainer.device)
    generator.manual_seed(seed)

    # Start from seeded noise
    x = torch.randn(
        num_samples, 3, image_size, image_size,
        device=trainer.device,
        generator=generator,
    )

    # Sample from the fixed starting point using the EMA weights
    with trainer.ema.average_parameters():
        for t in range(trainer.ddpm.timesteps - 1, -1, -1):
            t_batch = torch.full((num_samples,), t, device=trainer.device, dtype=torch.long)
            x = trainer.ddpm.p_sample(x, t_batch)

    return ((x + 1) / 2).clamp(0, 1)  # Scale from [-1, 1] to [0, 1]
```

Summary

In this section, we covered the complete DDPM sampling algorithm:

  1. Basic algorithm: Iterative denoising from $x_T$ to $x_0$
  2. Implementation details: Clipping, progress tracking, intermediate storage
  3. Optimizations: Strided steps, batching, model compilation
  4. Classifier-free guidance: Amplifying conditioning for better quality
  5. Batch generation: Efficient sampling for evaluation and visualization

Coming Up Next

In the next section, we'll cover practical training tips: hyperparameter tuning, scaling to larger datasets, hardware considerations, and common pitfalls to avoid.

Sampling is where a trained noise predictor becomes an image generator. The full 1000-step chain is slow, but it produces high-quality samples. In Chapter 7, we'll explore faster samplers like DDIM that achieve similar quality in 50-100 steps.