Learning Objectives
By the end of this section, you will:
- Understand the DDPM sampling algorithm and its mathematical foundation
- Implement efficient sampling with progress tracking and batching
- Apply classifier-free guidance for controllable generation
- Optimize sampling speed with various techniques
- Handle memory efficiently when generating many samples
From Training to Generation
Sampling Overview
The DDPM sampling process works by reversing the forward diffusion. Given a trained noise prediction model epsilon_theta(x_t, t), we:
- Start with pure Gaussian noise
- For each timestep from T-1 down to 0:
- Predict the noise in the current sample
- Use this prediction to compute a less noisy version
- Optionally add a small amount of noise (stochastic sampling)
The key equation for each step is:

x_{t-1} = (1 / sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon_theta(x_t, t)) + sigma_t * z

where z ~ N(0, I), alpha_t = 1 - beta_t, alpha_bar_t is the cumulative product of the alphas, and sigma_t = sqrt(beta_t) is the posterior standard deviation. The noise term sigma_t * z is only added for t > 0.
DDPM Sampling Algorithm
Here's the complete DDPM sampling algorithm:
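The code block below is a sketch of that algorithm: the full ancestral sampling loop, following the step equation above. It assumes a `NoiseSchedule` object exposing `timesteps`, `alphas`, `alphas_cumprod`, and `betas` (the same attributes the later examples in this section use); the `show_progress` flag is an illustrative addition.

```python
import torch
import torch.nn as nn
from tqdm import tqdm


@torch.no_grad()
def ddpm_sample(
    model: nn.Module,
    schedule: 'NoiseSchedule',
    shape: tuple,
    device: str = "cuda",
    show_progress: bool = True,
) -> torch.Tensor:
    """Full DDPM ancestral sampling: T denoising steps from pure noise."""
    # Start with pure Gaussian noise x_T
    x = torch.randn(shape, device=device)

    timesteps = range(schedule.timesteps - 1, -1, -1)
    if show_progress:
        timesteps = tqdm(timesteps, desc="Sampling")

    for t in timesteps:
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

        # 1. Predict the noise in the current sample
        predicted_noise = model(x, t_batch)

        # 2. Compute the posterior mean of p(x_{t-1} | x_t)
        alpha = schedule.alphas[t]
        alpha_bar = schedule.alphas_cumprod[t]
        beta = schedule.betas[t]
        mean = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_bar)) * predicted_noise
        )

        # 3. Add scaled noise (sigma_t = sqrt(beta_t)), except at the final step
        if t > 0:
            x = mean + torch.sqrt(beta) * torch.randn_like(x)
        else:
            x = mean

    return x
```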
Stochastic vs Deterministic

DDPM sampling is stochastic: fresh Gaussian noise is added at every step except the last, so different runs from the same starting noise produce different samples. Dropping that added noise gives a deterministic, DDIM-style update, covered in Chapter 7.
Step-by-Step Visualization
The following visualization shows the denoising process in action. Watch how the signal gradually emerges from noise as we iterate through timesteps:
Watch the diffusion model progressively denoise a 1D signal from pure noise to clean data
What's Happening at Each Step?
1. Predict Noise (or x0)
The neural network takes the current noisy signal x_t and timestep t, then predicts either the noise epsilon or the clean signal x_0.
2. Compute Mean
Using the prediction, we compute the mean of p(x_{t-1}|x_t), which tells us where to move in the next step.
3. Add Noise (DDPM) or Not (DDIM)
DDPM adds scaled Gaussian noise to maintain the Markov property. DDIM skips this for deterministic, faster sampling.
4. Update x
We update x_{t-1} = mean + sigma * z (DDPM) or x_{t-1} = mean (DDIM), moving one step closer to the clean data.
Key Equations

- Predicted clean signal: x0_hat = (x_t - sqrt(1 - alpha_bar_t) * epsilon_theta(x_t, t)) / sqrt(alpha_bar_t)
- Posterior mean: mean = (1 / sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon_theta(x_t, t))
- DDPM update: x_{t-1} = mean + sigma_t * z, with z ~ N(0, I)
- DDIM update: x_{t-1} = mean (no added noise)
Key observations from the visualization:
- Early steps (high t): The model makes large corrections to gross structure
- Middle steps: Medium-scale features emerge
- Late steps (low t): Fine details are refined
- t=0: The final clean image (or in this case, signal)
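This coarse-to-fine behavior can be read directly off the noise schedule: at high t the sample is almost all noise, at low t almost all signal. A small sketch, assuming a standard linear beta schedule (the specific schedule values are illustrative):

```python
import torch

# Linear beta schedule commonly used in DDPM (illustrative values)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)

for t in [999, 500, 50, 0]:
    signal = torch.sqrt(alphas_cumprod[t]).item()     # fraction of clean signal
    noise = torch.sqrt(1 - alphas_cumprod[t]).item()  # fraction of noise
    print(f"t={t:4d}  signal={signal:.3f}  noise={noise:.3f}")
```

At t near T the signal coefficient is close to zero, so each step reshapes gross structure; near t = 0 the sample is nearly clean and steps only touch fine detail.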
Sampling Optimizations
DDPM sampling is slow because it requires ~1000 forward passes through the U-Net. Here are techniques to speed it up:
1. Reduced Steps
```python
import torch
import torch.nn as nn
from tqdm import tqdm


@torch.no_grad()
def ddpm_sample_strided(
    model: nn.Module,
    schedule: 'NoiseSchedule',
    shape: tuple,
    num_inference_steps: int = 250,  # Reduced from 1000
    device: str = "cuda",
) -> torch.Tensor:
    """
    Sample using a subset of timesteps for faster generation.

    Warning: This is a simple approximation. For proper reduced-step
    sampling, use DDIM (Chapter 7) instead.
    """
    # Create strided timestep sequence
    step_size = schedule.timesteps // num_inference_steps
    timesteps = list(range(0, schedule.timesteps, step_size))[::-1]

    x = torch.randn(shape, device=device)

    for t in tqdm(timesteps, desc="Sampling"):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        predicted_noise = model(x, t_batch)

        # Use larger step
        alpha_bar = schedule.alphas_cumprod[t]
        next_t = max(0, t - step_size)
        alpha_bar_next = schedule.alphas_cumprod[next_t]

        # Approximate step: jump to the predicted x0, then re-noise to next_t
        pred_x0 = (x - torch.sqrt(1 - alpha_bar) * predicted_noise) / torch.sqrt(alpha_bar)
        pred_x0 = pred_x0.clamp(-1, 1)

        if next_t == 0:
            x = pred_x0  # Final step: return the prediction without fresh noise
        else:
            x = torch.sqrt(alpha_bar_next) * pred_x0 + \
                torch.sqrt(1 - alpha_bar_next) * torch.randn_like(x)

    return x
```

Strided Sampling Quality

Skipping steps this way noticeably degrades sample quality; use it only for quick previews, and prefer DDIM (Chapter 7) when you need fewer steps without the quality loss.
2. Batched Sampling
```python
import torch
import torch.nn as nn
from tqdm import tqdm


def sample_batch_efficient(
    model: nn.Module,
    ddpm: 'DDPM',
    total_samples: int,
    batch_size: int,
    image_size: int,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Generate many samples efficiently by batching.
    """
    all_samples = []
    num_batches = (total_samples + batch_size - 1) // batch_size  # Ceiling division

    for batch_idx in tqdm(range(num_batches), desc="Generating batches"):
        # Adjust last batch size
        current_batch_size = min(batch_size, total_samples - batch_idx * batch_size)

        samples = ddpm.sample(
            batch_size=current_batch_size,
            image_size=image_size,
            device=device,
            show_progress=False,  # Disable inner progress bar
        )

        # Move to CPU to free GPU memory
        all_samples.append(samples.cpu())

        # Clear CUDA cache periodically
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()

    return torch.cat(all_samples, dim=0)
```

3. Model Compilation
```python
import torch
import torch.nn as nn


def create_compiled_sampler(model: nn.Module, device: str = "cuda") -> nn.Module:
    """
    Create a compiled model for faster sampling (PyTorch 2.0+).
    """
    model = model.to(device)
    model.eval()

    # Compile the model; "reduce-overhead" targets repeated small-batch inference
    compiled_model = torch.compile(model, mode="reduce-overhead")

    # Note: run inference itself under torch.no_grad() inside the sampler
    return compiled_model


# Usage
compiled_model = create_compiled_sampler(unet)
# First call is slow (compilation), subsequent calls are faster
samples = ddpm_sample(compiled_model, schedule, (16, 3, 64, 64))
```

| Technique | Speedup | Quality Impact | When to Use |
|---|---|---|---|
| Full 1000 steps | 1x (baseline) | Best | Final quality samples |
| 250 strided steps | ~4x | Noticeable degradation | Quick preview only |
| DDIM 50 steps | ~20x | Minimal | Production (Chapter 7) |
| torch.compile | 1.2-2x | None | Always for inference |
| Half precision | 1.5-2x | Minimal | With careful implementation |
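Half precision (the last row of the table) is easiest to apply via PyTorch's autocast. A minimal sketch of wrapping one noise prediction in mixed precision; the dtype choices (float16 on GPU, bfloat16 on CPU) and the helper name are assumptions for illustration:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_noise_amp(model: nn.Module, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Run one noise prediction under mixed precision."""
    device_type = "cuda" if x.is_cuda else "cpu"
    # float16 is typical on GPU; bfloat16 is the safer choice on CPU
    dtype = torch.float16 if x.is_cuda else torch.bfloat16
    with torch.autocast(device_type=device_type, dtype=dtype):
        pred = model(x, t)
    # Cast back to float32 so the denoising arithmetic stays in full precision
    return pred.float()
```

Keeping the schedule arithmetic (means, variances) in float32 and only running the network in reduced precision is the "careful implementation" the table refers to: it avoids accumulating precision error over hundreds of steps.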
Classifier-Free Guidance
Classifier-Free Guidance (CFG) improves sample quality by amplifying the conditioning signal. It's essential for conditional generation (text-to-image, class-conditional):
epsilon_guided = epsilon_theta(x_t, t, null) + w * (epsilon_theta(x_t, t, c) - epsilon_theta(x_t, t, null))

where w is the guidance scale (typically 3-15), c is the condition, and null is the null (unconditional) condition.
```python
import torch
import torch.nn as nn
from tqdm import tqdm


@torch.no_grad()
def sample_with_cfg(
    model: nn.Module,
    schedule: 'NoiseSchedule',
    shape: tuple,
    condition: torch.Tensor,
    guidance_scale: float = 7.5,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Sample with classifier-free guidance.

    Args:
        model: Conditional noise prediction model
        schedule: Noise schedule
        shape: Output shape
        condition: Conditioning input (e.g., class label, text embedding)
        guidance_scale: CFG weight (higher = stronger conditioning)
        device: Device

    Returns:
        Guided samples
    """
    batch_size = shape[0]
    x = torch.randn(shape, device=device)

    # Null condition (the model was trained with condition dropout)
    null_condition = torch.zeros_like(condition)

    for t in tqdm(range(schedule.timesteps - 1, -1, -1)):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)

        # Run model twice: with and without condition
        # (can be batched for efficiency; see below)
        noise_cond = model(x, t_batch, condition)
        noise_uncond = model(x, t_batch, null_condition)

        # Apply guidance
        predicted_noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # Standard denoising step
        alpha = schedule.alphas[t]
        alpha_bar = schedule.alphas_cumprod[t]
        beta = schedule.betas[t]

        mean = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_bar)) * predicted_noise
        )

        if t > 0:
            x = mean + torch.sqrt(beta) * torch.randn_like(x)
        else:
            x = mean

    return x


# Optimized: batch both predictions into a single forward pass
@torch.no_grad()
def sample_with_cfg_batched(
    model: nn.Module,
    schedule: 'NoiseSchedule',
    shape: tuple,
    condition: torch.Tensor,
    guidance_scale: float = 7.5,
    device: str = "cuda",
) -> torch.Tensor:
    """CFG with batched conditional/unconditional predictions."""
    batch_size = shape[0]
    x = torch.randn(shape, device=device)
    null_condition = torch.zeros_like(condition)

    for t in tqdm(range(schedule.timesteps - 1, -1, -1)):
        # Batch both predictions
        x_doubled = torch.cat([x, x], dim=0)
        t_doubled = torch.full((batch_size * 2,), t, device=device, dtype=torch.long)
        cond_doubled = torch.cat([condition, null_condition], dim=0)

        # Single forward pass
        noise_pred = model(x_doubled, t_doubled, cond_doubled)

        # Split and apply guidance
        noise_cond, noise_uncond = noise_pred.chunk(2)
        predicted_noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # Denoise step (same as before)
        alpha = schedule.alphas[t]
        alpha_bar = schedule.alphas_cumprod[t]
        beta = schedule.betas[t]

        mean = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_bar)) * predicted_noise
        )

        if t > 0:
            x = mean + torch.sqrt(beta) * torch.randn_like(x)
        else:
            x = mean

    return x
```

Guidance Scale Selection

Low scales (near 1) behave almost unconditionally, preserving diversity but only weakly enforcing the condition; the common default of 7.5 balances fidelity and diversity; very high scales push samples hard toward the condition at the cost of diversity and can introduce saturation artifacts.
Batch Sampling Strategies
When generating many samples (e.g., for evaluation), efficient batching is important:
```python
import torch
from pathlib import Path
from torchvision.utils import save_image


def generate_samples_for_evaluation(
    trainer: 'DiffusionTrainer',
    output_dir: str,
    num_samples: int = 10000,
    batch_size: int = 64,
    image_size: int = 64,
) -> None:
    """
    Generate many samples for FID/IS evaluation.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    samples_generated = 0

    while samples_generated < num_samples:
        current_batch = min(batch_size, num_samples - samples_generated)

        # Generate batch
        samples = trainer.sample(
            num_samples=current_batch,
            image_size=image_size,
            use_ema=True,
        )

        # Save individual images
        for i, sample in enumerate(samples):
            idx = samples_generated + i
            save_image(sample, output_path / f"sample_{idx:05d}.png")

        samples_generated += current_batch
        print(f"Generated {samples_generated}/{num_samples} samples")

        # Memory management
        del samples
        torch.cuda.empty_cache()

    print(f"Saved {num_samples} samples to {output_dir}")


def generate_grid(
    trainer: 'DiffusionTrainer',
    nrow: int = 8,
    ncol: int = 8,
    image_size: int = 64,
) -> torch.Tensor:
    """Generate a grid of samples for visualization."""
    num_samples = nrow * ncol

    samples = trainer.sample(
        num_samples=num_samples,
        image_size=image_size,
        use_ema=True,
    )

    # Arrange in a (nrow*H, ncol*W, C) grid
    grid = samples.view(nrow, ncol, 3, image_size, image_size)
    grid = grid.permute(0, 3, 1, 4, 2)  # (nrow, H, ncol, W, C)
    grid = grid.reshape(nrow * image_size, ncol * image_size, 3)

    return grid


def sample_with_seed(
    trainer: 'DiffusionTrainer',
    seed: int,
    num_samples: int = 16,
    image_size: int = 64,
) -> torch.Tensor:
    """Generate reproducible samples with a fixed seed."""
    generator = torch.Generator(device=trainer.device)
    generator.manual_seed(seed)

    # Start from seeded noise
    x = torch.randn(
        num_samples, 3, image_size, image_size,
        device=trainer.device,
        generator=generator,
    )

    # Sample from the fixed starting point using EMA weights
    with trainer.ema.average_parameters():
        for t in range(trainer.ddpm.timesteps - 1, -1, -1):
            t_batch = torch.full((num_samples,), t, device=trainer.device, dtype=torch.long)
            x = trainer.ddpm.p_sample(x, t_batch)

    return (x + 1) / 2  # Scale to [0, 1]
```

Summary
In this section, we covered the complete DDPM sampling algorithm:
- Basic algorithm: Iterative denoising from x_T (pure noise) to x_0 (clean data)
- Implementation details: Clipping, progress tracking, intermediate storage
- Optimizations: Strided steps, batching, model compilation
- Classifier-free guidance: Amplifying conditioning for better quality
- Batch generation: Efficient sampling for evaluation and visualization
Coming Up Next
The sampling algorithm is where diffusion models create their magic. While 1000 steps can be slow, the quality is remarkable. In Chapter 7, we'll explore faster samplers like DDIM that achieve similar quality in 50-100 steps.