Learning Objectives
By the end of this section, you will:
- Diagnose common training issues from loss curves and samples
- Visualize the diffusion process to understand model behavior
- Evaluate model quality using FID, IS, and other metrics
- Build debugging tools for faster iteration
Diagnosing Training Issues
Reading Loss Curves
The training loss tells you a lot about what's happening:
| Loss Behavior | Likely Cause | Solution |
|---|---|---|
| Loss stays high | Wrong scaling, bad architecture | Check [-1,1] scaling, model output shape |
| Loss decreases then plateaus | Normal convergence | Continue training or increase capacity |
| Loss spikes randomly | Learning rate too high | Reduce LR, add gradient clipping |
| Loss becomes NaN | Numerical instability | Lower LR, check for division by zero |
| Loss oscillates | LR too high or batch too small | Reduce LR, increase batch size |
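Several rows in the table come down to stabilizing the optimizer step. Here is a minimal sketch of a training step with gradient clipping and a NaN guard; the function name and `max_norm` default are illustrative, not part of the trainer built earlier:

```python
import torch

def clipped_step(model, loss, optimizer, max_norm=1.0):
    """One optimizer step with gradient clipping and a NaN guard.

    Returns the pre-clip gradient norm, or None if the batch was skipped.
    """
    optimizer.zero_grad(set_to_none=True)
    if not torch.isfinite(loss):
        return None  # skip this batch entirely rather than corrupt weights
    loss.backward()
    # Rescale gradients so their total norm is at most max_norm
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return grad_norm.item()
```

Skipping non-finite batches outright is often safer than letting a single bad batch poison the weights (and the EMA copy downstream of them).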
Analyzing Sample Quality
```python
import torch
import matplotlib.pyplot as plt

def analyze_samples(trainer, num_samples=16):
    """Generate and analyze sample quality at different stages."""

    # Generate samples
    samples = trainer.sample(num_samples, use_ema=True)

    # Basic statistics
    print(f"Sample range: [{samples.min():.3f}, {samples.max():.3f}]")
    print(f"Sample mean: {samples.mean():.3f}")
    print(f"Sample std: {samples.std():.3f}")

    # Visualize in a 4x4 grid
    fig, axes = plt.subplots(4, 4, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        if i < len(samples):
            img = samples[i].permute(1, 2, 0).cpu().numpy()
            ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig('samples.png')

    return samples

# What to look for:
# - Range should be [0, 1] after scaling
# - Mean should be ~0.5
# - Images should have recognizable structure
# - No repeated patterns (mode collapse)
```

Sample Quality Issues
| Issue | Appearance | Likely Cause |
|---|---|---|
| Blurry samples | Soft, lacking detail | Undertrained, low capacity, or bad EMA |
| Noisy samples | Grainy, visible noise | Undertrained or wrong timestep handling |
| Color artifacts | Unnatural colors | Wrong normalization or data augmentation |
| Repeated patterns | Same structure everywhere | Mode collapse, check diversity |
| Blank/uniform | No structure at all | Training not working, check loss |
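Before reaching for a learned metric, a crude pixel-space check can flag the "repeated patterns" failure mode. A sketch; the function name is ours, and what counts as "near zero" depends on your data:

```python
import torch

def pairwise_pixel_diversity(samples):
    """Mean pairwise L2 distance between flattened samples.

    A value near zero suggests mode collapse: the model emits
    near-identical images regardless of the starting noise.
    """
    flat = samples.flatten(start_dim=1)       # (N, C*H*W)
    dists = torch.cdist(flat, flat, p=2)      # (N, N) pairwise distances
    n = flat.shape[0]
    # Average over the off-diagonal (i != j) pairs; diagonal is zero
    return (dists.sum() / (n * (n - 1))).item()
```

This only detects exact or near-exact duplication; the LPIPS-based diversity metric later in this section is far more sensitive to perceptual repetition.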
Visualizing the Diffusion Process
Forward Process Visualization
```python
def visualize_forward_process(ddpm, image, num_steps=10):
    """
    Show how an image gets progressively noisier.
    """
    fig, axes = plt.subplots(1, num_steps + 1, figsize=(20, 3))

    # Original image (data is in [-1, 1]; rescale to [0, 1] for display)
    axes[0].imshow(image.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5)
    axes[0].set_title('t=0')
    axes[0].axis('off')

    # Progressive noising
    timesteps = torch.linspace(0, ddpm.timesteps - 1, num_steps).long()

    for i, t in enumerate(timesteps):
        t_batch = t.unsqueeze(0).to(image.device)
        x_t, _ = ddpm.q_sample(image.unsqueeze(0), t_batch)

        img = x_t[0].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
        axes[i + 1].imshow(img.clip(0, 1))
        axes[i + 1].set_title(f't={t.item()}')
        axes[i + 1].axis('off')

    plt.tight_layout()
    plt.savefig('forward_process.png')
    return fig
```

Reverse Process Visualization
```python
@torch.no_grad()
def visualize_reverse_process(ddpm, image_size=64, num_snapshots=10):
    """
    Show how noise transforms into an image during sampling.
    """
    device = next(ddpm.model.parameters()).device

    # Start from noise
    x = torch.randn(1, 3, image_size, image_size, device=device)

    # Timesteps at which to save snapshots
    snapshot_times = torch.linspace(ddpm.timesteps - 1, 0, num_snapshots).long()
    snapshots = []

    for t in range(ddpm.timesteps - 1, -1, -1):
        t_batch = torch.tensor([t], device=device)
        x = ddpm.p_sample(x, t_batch)

        if t in snapshot_times:
            snapshots.append((t, x.clone()))

    # Visualize
    fig, axes = plt.subplots(1, len(snapshots), figsize=(20, 3))
    for i, (t, sample) in enumerate(snapshots):
        img = sample[0].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
        axes[i].imshow(img.clip(0, 1))
        axes[i].set_title(f't={t}')
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig('reverse_process.png')
    return fig
```

Noise Prediction Visualization
```python
@torch.no_grad()
def visualize_noise_prediction(ddpm, image, timesteps=(0, 250, 500, 750, 999)):
    """
    Show what noise the model predicts at different timesteps.
    """
    device = next(ddpm.model.parameters()).device
    image = image.unsqueeze(0).to(device)

    fig, axes = plt.subplots(len(timesteps), 3, figsize=(12, 4 * len(timesteps)))

    for i, t_val in enumerate(timesteps):
        t = torch.tensor([t_val], device=device)

        # Get noisy image and true noise
        x_t, true_noise = ddpm.q_sample(image, t)

        # Predict noise
        pred_noise = ddpm.model(x_t, t)

        # Noisy image
        axes[i, 0].imshow(x_t[0].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5)
        axes[i, 0].set_title(f't={t_val}: Noisy Image')

        # True noise (compressed into [0, 1] for display)
        noise_vis = true_noise[0].permute(1, 2, 0).cpu().numpy() * 0.25 + 0.5
        axes[i, 1].imshow(noise_vis.clip(0, 1))
        axes[i, 1].set_title('True Noise')

        # Predicted noise
        pred_vis = pred_noise[0].permute(1, 2, 0).cpu().numpy() * 0.25 + 0.5
        axes[i, 2].imshow(pred_vis.clip(0, 1))
        axes[i, 2].set_title('Predicted Noise')

        for ax in axes[i]:
            ax.axis('off')

    plt.tight_layout()
    plt.savefig('noise_prediction.png')
    return fig
```

What Good Predictions Look Like
At high timesteps (t≈1000), x_t is almost pure noise, so the predicted and true noise should look nearly identical (both essentially random). At low timesteps (t≈0), the noise contribution is tiny and the prediction task is harder: the predicted noise often shows faint image structure, and large structured errors here translate directly into artifacts in the final samples.
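One way to quantify this is to bucket the training loss by timestep; a bin that stays high shows which part of the schedule the model struggles with. A sketch, assuming the `ddpm` interface used in this chapter (`timesteps`, `q_sample`, and a noise-predicting `model`):

```python
import torch

@torch.no_grad()
def loss_by_timestep(ddpm, images, num_bins=10):
    """Average noise-prediction MSE per timestep bin.

    Returns one mean loss per bin, ordered from low t to high t.
    """
    device = images.device
    bin_edges = torch.linspace(0, ddpm.timesteps, num_bins + 1).long()
    losses = []
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        # Sample timesteps uniformly within this bin
        t = torch.randint(lo.item(), hi.item(), (images.shape[0],), device=device)
        x_t, noise = ddpm.q_sample(images, t)
        pred = ddpm.model(x_t, t)
        losses.append(torch.mean((pred - noise) ** 2).item())
    return losses
```

A healthy model typically shows higher loss in the low-t bins (the harder prediction task) with a smooth profile; a sharp spike in one bin is worth investigating.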
Evaluation Metrics
Common metrics for evaluating diffusion models:
Frechet Inception Distance (FID)
FID measures the similarity between generated and real image distributions:
```python
# Using pytorch-fid (clean-fid has a similar interface)
import os

from pytorch_fid import fid_score
from torchvision.utils import save_image

def compute_fid(real_path, generated_path, device='cuda'):
    """
    Compute FID between two image directories.

    Lower FID = better quality (more similar to real data):
      FID < 10: excellent
      FID 10-50: good
      FID 50-100: moderate
      FID > 100: poor
    """
    fid = fid_score.calculate_fid_given_paths(
        [real_path, generated_path],
        batch_size=50,
        device=device,
        dims=2048,  # InceptionV3 features
    )
    return fid

def generate_for_fid(trainer, output_dir, num_samples=10000, batch_size=64):
    """Generate samples for FID evaluation."""
    os.makedirs(output_dir, exist_ok=True)

    generated = 0
    while generated < num_samples:
        batch = min(batch_size, num_samples - generated)
        samples = trainer.sample(batch, use_ema=True)

        for i, sample in enumerate(samples):
            save_image(sample, f"{output_dir}/{generated + i:05d}.png")

        generated += batch

    print(f"Generated {num_samples} samples to {output_dir}")
```

Inception Score (IS)
Inception Score measures both quality (confident predictions) and diversity:
```python
from torchmetrics.image.inception import InceptionScore

def compute_is(generated_images, device='cuda'):
    """
    Compute Inception Score.

    Higher IS = better (more confident, diverse predictions).
    Typical good values: IS > 10 for ImageNet-like data.
    """
    # normalize=True tells the metric to expect float images in [0, 1]
    is_metric = InceptionScore(normalize=True).to(device)

    # Add images in batches
    for i in range(0, len(generated_images), 64):
        batch = generated_images[i:i+64].to(device)
        is_metric.update(batch)

    mean, std = is_metric.compute()
    return mean.item(), std.item()
```

Sample Diversity
```python
import lpips

def compute_lpips_diversity(samples, device='cuda'):
    """
    Measure diversity using LPIPS distance between samples.
    Higher = more diverse.

    Note: LPIPS expects inputs scaled to [-1, 1].
    """
    lpips_fn = lpips.LPIPS(net='alex').to(device)

    n = len(samples)
    distances = []

    # Average LPIPS over all pairs of samples (O(n^2) forward passes)
    for i in range(n):
        for j in range(i + 1, n):
            d = lpips_fn(samples[i:i+1], samples[j:j+1])
            distances.append(d.item())

    return sum(distances) / len(distances)  # Higher = more diverse
```

| Metric | Measures | Better Is | Typical Good Values |
|---|---|---|---|
| FID | Distribution similarity | Lower | < 50 (< 10 excellent) |
| IS | Quality + diversity | Higher | > 10 |
| LPIPS diversity | Sample variation | Higher | > 0.3 |
| Precision | Sample quality | Higher | > 0.7 |
| Recall | Mode coverage | Higher | > 0.5 |
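The precision and recall rows refer to the improved precision/recall metric of Kynkäänniemi et al. (2019), which this section does not otherwise implement. A minimal k-NN sketch over pre-computed feature tensors; the function names and `k` default are illustrative:

```python
import torch

def precision_recall(real_feats, fake_feats, k=3):
    """Minimal sketch of improved precision/recall.

    Precision: fraction of fake features inside some real feature's
    k-NN radius; recall: the reverse. Inputs are (N, D) feature
    tensors, e.g. InceptionV3 activations.
    """
    def knn_radius(feats):
        d = torch.cdist(feats, feats)           # (N, N) pairwise distances
        # k+1 because each point's nearest neighbour is itself (distance 0)
        return d.kthvalue(k + 1, dim=1).values  # per-point k-NN radius

    def coverage(a, b, b_radius):
        # Fraction of points in `a` within some point of `b`'s radius
        d = torch.cdist(a, b)                   # (Na, Nb)
        return (d <= b_radius.unsqueeze(0)).any(dim=1).float().mean().item()

    precision = coverage(fake_feats, real_feats, knn_radius(real_feats))
    recall = coverage(real_feats, fake_feats, knn_radius(fake_feats))
    return precision, recall
```

Precision drops when samples drift off the data manifold (poor quality); recall drops when modes of the data are missing (poor coverage). The two together disambiguate failure modes that FID alone conflates.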
Debugging Tools
Quick Sanity Checks
```python
def sanity_check(model, ddpm, dataloader, device='cuda'):
    """
    Quick checks to verify training setup is correct.
    """
    model.to(device)

    # Get a batch (assumes the loader yields image tensors in [0, 1])
    images = next(iter(dataloader)).to(device)
    images = images * 2 - 1  # Scale to [-1, 1]

    print("=== Input Check ===")
    print(f"  Shape: {images.shape}")
    print(f"  Range: [{images.min():.3f}, {images.max():.3f}]")
    print(f"  Mean: {images.mean():.3f}")

    # Check forward process
    t = torch.randint(0, ddpm.timesteps, (images.shape[0],), device=device)
    x_t, noise = ddpm.q_sample(images, t)

    print("\n=== Forward Process ===")
    print(f"  x_t range: [{x_t.min():.3f}, {x_t.max():.3f}]")
    print(f"  noise range: [{noise.min():.3f}, {noise.max():.3f}]")

    # Check model output
    with torch.no_grad():
        pred_noise = model(x_t, t)

    print("\n=== Model Output ===")
    print(f"  Shape: {pred_noise.shape} (should match input)")
    print(f"  Range: [{pred_noise.min():.3f}, {pred_noise.max():.3f}]")
    print(f"  Mean: {pred_noise.mean():.3f}")

    # Check loss
    loss = ddpm.training_loss(images)
    print("\n=== Loss ===")
    print(f"  Value: {loss.item():.4f}")
    print(f"  Is NaN: {torch.isnan(loss).item()}")

    # Check gradients
    loss.backward()
    grad_norms = [
        p.grad.norm().item()
        for p in model.parameters()
        if p.grad is not None
    ]

    print("\n=== Gradients ===")
    print(f"  Max norm: {max(grad_norms):.4f}")
    print(f"  Min norm: {min(grad_norms):.6f}")
    print(f"  Any NaN: {any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None)}")

    return True
```

Training Dashboard
```python
class TrainingMonitor:
    """Simple training monitor with periodic checks."""

    def __init__(self, log_every=100, sample_every=1000):
        self.log_every = log_every
        self.sample_every = sample_every
        self.losses = []
        self.step = 0

    def log(self, loss, lr=None):
        self.losses.append(loss)
        self.step += 1

        if self.step % self.log_every == 0:
            recent = self.losses[-self.log_every:]
            print(f"Step {self.step}: loss={sum(recent)/len(recent):.4f}")

    def should_sample(self):
        return self.step % self.sample_every == 0

    def plot_loss(self, save_path='loss_curve.png'):
        import matplotlib.pyplot as plt

        # Smooth the loss curve with a trailing moving average
        window = min(100, len(self.losses) // 10)
        if window > 1:
            smoothed = []
            for i in range(len(self.losses)):
                chunk = self.losses[max(0, i - window + 1):i + 1]
                smoothed.append(sum(chunk) / len(chunk))
        else:
            smoothed = self.losses

        plt.figure(figsize=(10, 5))
        plt.plot(smoothed)
        plt.xlabel('Step')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.savefig(save_path)
        plt.close()
```

Chapter Summary
Congratulations! You've completed Chapter 6: Building the Diffusion Model. You now have:
- A complete DDPM implementation: Forward process, reverse process, noise schedule
- A production training loop: EMA, mixed precision, checkpointing, logging
- Sampling algorithms: Full DDPM sampling with classifier-free guidance
- Practical training tips: Hyperparameters, scaling, common pitfalls
- Debugging and evaluation tools: Visualizations, FID, sanity checks
Coming Up: Chapter 7
In Chapter 7, we'll explore improved sampling methods: DDIM for fast deterministic sampling, DPM-Solver for even faster generation, and advanced techniques like ancestral sampling and guidance interpolation. These methods can reduce sampling from 1000 steps to just 20-50 with minimal quality loss!
With the complete diffusion model built and trained, you're ready to generate high-quality images. The foundation you've built here applies to all modern diffusion models, from unconditional generation to text-to-image systems like Stable Diffusion.