Chapter 6

Debugging and Visualization

Building the Diffusion Model

Learning Objectives

By the end of this section, you will:

  1. Diagnose common training issues from loss curves and samples
  2. Visualize the diffusion process to understand model behavior
  3. Evaluate model quality using FID, IS, and other metrics
  4. Build debugging tools for faster iteration

Diagnosing Training Issues

Reading Loss Curves

The training loss tells you a lot about what's happening:

| Loss Behavior | Likely Cause | Solution |
|---|---|---|
| Loss stays high | Wrong scaling, bad architecture | Check [-1, 1] scaling, model output shape |
| Loss decreases then plateaus | Normal convergence | Continue training or increase capacity |
| Loss spikes randomly | Learning rate too high | Reduce LR, add gradient clipping |
| Loss becomes NaN | Numerical instability | Lower LR, check for division by zero |
| Loss oscillates | LR too high or batch too small | Reduce LR, increase batch size |
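Two of the fixes in the table, gradient clipping and a lower learning rate, are one-liners in PyTorch. A minimal sketch of a training step with clipping; `loss_fn` and the surrounding objects are placeholders, not part of the trainer built earlier:

```python
import torch

def training_step(model, optimizer, loss_fn, batch, max_grad_norm=1.0):
    """One training step with global-norm gradient clipping."""
    optimizer.zero_grad()
    loss = loss_fn(model, batch)
    loss.backward()
    # Rescale gradients so their global norm is at most max_grad_norm;
    # this tames the random loss spikes from the table above
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()
```

Clipping only changes steps whose gradient norm exceeds the threshold, so it is cheap insurance: well-behaved steps are untouched.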

Analyzing Sample Quality

```python
import torch
import matplotlib.pyplot as plt

def analyze_samples(trainer, num_samples=16):
    """Generate and analyze sample quality at different stages."""

    # Generate samples
    samples = trainer.sample(num_samples, use_ema=True)

    # Basic statistics
    print(f"Sample range: [{samples.min():.3f}, {samples.max():.3f}]")
    print(f"Sample mean: {samples.mean():.3f}")
    print(f"Sample std: {samples.std():.3f}")

    # Visualize
    fig, axes = plt.subplots(4, 4, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        if i < len(samples):
            img = samples[i].permute(1, 2, 0).cpu().numpy()
            ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig('samples.png')

    return samples

# What to look for:
# - Range should be [0, 1] after scaling
# - Mean should be ~0.5
# - Images should have recognizable structure
# - No repeated patterns (mode collapse)
```

Sample Quality Issues

| Issue | Appearance | Likely Cause |
|---|---|---|
| Blurry samples | Soft, lacking detail | Undertrained, low capacity, or bad EMA |
| Noisy samples | Grainy, visible noise | Undertrained or wrong timestep handling |
| Color artifacts | Unnatural colors | Wrong normalization or data augmentation |
| Repeated patterns | Same structure everywhere | Mode collapse, check diversity |
| Blank/uniform | No structure at all | Training not working, check loss |
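A coarse screen for the "repeated patterns" row: if the smallest pairwise distance between generated samples is near zero, the model is likely emitting duplicates. This hypothetical helper works in raw pixel space; the LPIPS-based diversity metric later in this section is more perceptually meaningful:

```python
import torch

def min_pairwise_distance(samples):
    """Smallest per-pixel MSE between any two samples in a batch.

    Near-zero values suggest duplicated outputs (possible mode collapse).
    """
    flat = samples.reshape(len(samples), -1)
    # Pairwise squared L2 distances, normalized by pixel count
    d2 = torch.cdist(flat, flat, p=2) ** 2 / flat.shape[1]
    d2.fill_diagonal_(float('inf'))  # ignore self-distances
    return d2.min().item()
```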

Visualizing the Diffusion Process

Forward Process Visualization

```python
def visualize_forward_process(ddpm, image, num_steps=10):
    """
    Show how an image gets progressively noisier.
    """
    fig, axes = plt.subplots(1, num_steps + 1, figsize=(20, 3))

    # Original image
    axes[0].imshow(image.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5)
    axes[0].set_title('t=0')
    axes[0].axis('off')

    # Progressive noising
    timesteps = torch.linspace(0, ddpm.timesteps - 1, num_steps).long()

    for i, t in enumerate(timesteps):
        t_batch = t.unsqueeze(0).to(image.device)
        x_t, _ = ddpm.q_sample(image.unsqueeze(0), t_batch)

        img = x_t[0].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
        axes[i + 1].imshow(img.clip(0, 1))
        axes[i + 1].set_title(f't={t.item()}')
        axes[i + 1].axis('off')

    plt.tight_layout()
    plt.savefig('forward_process.png')
    return fig
```

Reverse Process Visualization

```python
@torch.no_grad()
def visualize_reverse_process(ddpm, image_size=64, num_snapshots=10):
    """
    Show how noise transforms into an image during sampling.
    """
    device = next(ddpm.model.parameters()).device

    # Start from noise
    x = torch.randn(1, 3, image_size, image_size, device=device)

    # Collect snapshots
    snapshot_times = torch.linspace(ddpm.timesteps - 1, 0, num_snapshots).long()
    snapshots = []

    for t in range(ddpm.timesteps - 1, -1, -1):
        t_batch = torch.tensor([t], device=device)
        x = ddpm.p_sample(x, t_batch)

        if t in snapshot_times:
            snapshots.append((t, x.clone()))

    # Visualize
    fig, axes = plt.subplots(1, len(snapshots), figsize=(20, 3))
    for i, (t, sample) in enumerate(snapshots):
        img = sample[0].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
        axes[i].imshow(img.clip(0, 1))
        axes[i].set_title(f't={t}')
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig('reverse_process.png')
    return fig
```

Noise Prediction Visualization

```python
@torch.no_grad()
def visualize_noise_prediction(ddpm, image, timesteps=(0, 250, 500, 750, 999)):
    """
    Show what noise the model predicts at different timesteps.
    """
    device = next(ddpm.model.parameters()).device
    image = image.unsqueeze(0).to(device)

    fig, axes = plt.subplots(len(timesteps), 3, figsize=(12, 4 * len(timesteps)))

    for i, t_val in enumerate(timesteps):
        t = torch.tensor([t_val], device=device)

        # Get noisy image and true noise
        x_t, true_noise = ddpm.q_sample(image, t)

        # Predict noise
        pred_noise = ddpm.model(x_t, t)

        # Noisy image
        axes[i, 0].imshow(x_t[0].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5)
        axes[i, 0].set_title(f't={t_val}: Noisy Image')

        # True noise
        noise_vis = true_noise[0].permute(1, 2, 0).cpu().numpy() * 0.25 + 0.5
        axes[i, 1].imshow(noise_vis.clip(0, 1))
        axes[i, 1].set_title('True Noise')

        # Predicted noise
        pred_vis = pred_noise[0].permute(1, 2, 0).cpu().numpy() * 0.25 + 0.5
        axes[i, 2].imshow(pred_vis.clip(0, 1))
        axes[i, 2].set_title('Predicted Noise')

        for ax in axes[i]:
            ax.axis('off')

    plt.tight_layout()
    plt.savefig('noise_prediction.png')
    return fig
```

What Good Predictions Look Like

At high timesteps (t ≈ 1000), x_t is almost pure noise, so the predicted and true noise maps should look nearly identical (both essentially random). At low timesteps (t ≈ 0), most of the noise is hidden by the image, so exact prediction is harder; a well-trained model's prediction should still track the coarse structure of the true noise rather than leaking image content into the prediction.
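You can turn this qualitative check into numbers by measuring the noise-prediction MSE in coarse timestep bins. A sketch that assumes the same `ddpm.q_sample` / `ddpm.model` interface used by the visualizations above; a healthy model shows low loss at high t (x_t is mostly noise, so the target is easy) and higher loss near t = 0:

```python
import torch

@torch.no_grad()
def per_timestep_loss(ddpm, images, num_bins=10):
    """Mean noise-prediction MSE in coarse timestep bins, low t first."""
    edges = torch.linspace(0, ddpm.timesteps, num_bins + 1).long()
    losses = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Sample random timesteps inside this bin
        t = torch.randint(int(lo), int(hi), (images.shape[0],),
                          device=images.device)
        x_t, noise = ddpm.q_sample(images, t)
        pred = ddpm.model(x_t, t)
        losses.append(((pred - noise) ** 2).mean().item())
    return losses
```

Plotting these bin averages over training is a quick way to see whether one part of the timestep range is lagging behind the rest.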

Evaluation Metrics

Common metrics for evaluating diffusion models:

Frechet Inception Distance (FID)

FID measures the similarity between generated and real image distributions:

```python
# Using pytorch-fid or clean-fid
from pytorch_fid import fid_score

def compute_fid(real_path, generated_path, device='cuda'):
    """
    Compute FID between two image directories.

    Lower FID = better quality (more similar to real data)
    FID < 10: Excellent
    FID 10-50: Good
    FID 50-100: Moderate
    FID > 100: Poor
    """
    fid = fid_score.calculate_fid_given_paths(
        [real_path, generated_path],
        batch_size=50,
        device=device,
        dims=2048,  # InceptionV3 features
    )
    return fid

# Generate samples for FID
def generate_for_fid(trainer, output_dir, num_samples=10000, batch_size=64):
    """Generate samples for FID evaluation."""
    import os
    from torchvision.utils import save_image

    os.makedirs(output_dir, exist_ok=True)

    generated = 0
    while generated < num_samples:
        batch = min(batch_size, num_samples - generated)
        samples = trainer.sample(batch, use_ema=True)

        for i, sample in enumerate(samples):
            save_image(sample, f"{output_dir}/{generated + i:05d}.png")

        generated += batch

    print(f"Generated {num_samples} samples to {output_dir}")
```

Inception Score (IS)

Inception Score measures both quality (confident predictions) and diversity:

```python
from torchmetrics.image.inception import InceptionScore

def compute_is(generated_images, device='cuda'):
    """
    Compute Inception Score.

    Higher IS = better (more confident, diverse predictions)
    Typical good values: IS > 10 for ImageNet-like data
    """
    is_metric = InceptionScore(normalize=True).to(device)

    # Add images in batches
    for i in range(0, len(generated_images), 64):
        batch = generated_images[i:i+64].to(device)
        is_metric.update(batch)

    mean, std = is_metric.compute()
    return mean.item(), std.item()
```

Sample Diversity

```python
def compute_lpips_diversity(samples, device='cuda'):
    """
    Measure diversity using LPIPS distance between samples.
    Higher = more diverse.
    """
    import lpips

    lpips_fn = lpips.LPIPS(net='alex').to(device)

    n = len(samples)
    distances = []

    for i in range(n):
        for j in range(i + 1, n):
            d = lpips_fn(samples[i:i+1], samples[j:j+1])
            distances.append(d.item())

    mean_distance = sum(distances) / len(distances)
    return mean_distance  # Higher = more diverse
```
| Metric | Measures | Better Is | Typical Good Values |
|---|---|---|---|
| FID | Distribution similarity | Lower | < 50 (< 10 excellent) |
| IS | Quality + diversity | Higher | > 10 |
| LPIPS diversity | Sample variation | Higher | > 0.3 |
| Precision | Sample quality | Higher | > 0.7 |
| Recall | Mode coverage | Higher | > 0.5 |
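Precision and recall in the table are usually computed with the k-nearest-neighbor procedure of Kynkäänniemi et al. (2019) on InceptionV3 features. A simplified sketch that accepts any `[N, D]` feature tensors (the feature extraction step is left out): a generated sample counts toward precision if it falls inside the k-NN radius of some real sample, and a real sample counts toward recall if some generated sample covers it.

```python
import torch

def knn_precision_recall(real_feats, fake_feats, k=3):
    """Simplified k-NN precision/recall on [N, D] feature tensors."""
    def radii(feats):
        d = torch.cdist(feats, feats)
        # k-th nearest-neighbor distance (index 0 is the self-distance)
        return d.sort(dim=1).values[:, k]

    real_r, fake_r = radii(real_feats), radii(fake_feats)
    d = torch.cdist(fake_feats, real_feats)  # [N_fake, N_real]

    # Precision: fraction of fakes inside some real sample's k-NN ball
    precision = (d <= real_r.unsqueeze(0)).any(dim=1).float().mean().item()
    # Recall: fraction of reals inside some fake sample's k-NN ball
    recall = (d <= fake_r.unsqueeze(1)).any(dim=0).float().mean().item()
    return precision, recall
```

The published metric adds feature-network details and careful hold-out handling, so treat this as a rough diagnostic rather than a benchmark-comparable number.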

Debugging Tools

Quick Sanity Checks

```python
def sanity_check(model, ddpm, dataloader, device='cuda'):
    """
    Quick checks to verify training setup is correct.
    """
    model.to(device)

    # Get a batch
    images = next(iter(dataloader)).to(device)
    images = images * 2 - 1  # Scale to [-1, 1]

    print("=== Input Check ===")
    print(f"  Shape: {images.shape}")
    print(f"  Range: [{images.min():.3f}, {images.max():.3f}]")
    print(f"  Mean: {images.mean():.3f}")

    # Check forward process
    t = torch.randint(0, 1000, (images.shape[0],), device=device)
    x_t, noise = ddpm.q_sample(images, t)

    print("\n=== Forward Process ===")
    print(f"  x_t range: [{x_t.min():.3f}, {x_t.max():.3f}]")
    print(f"  noise range: [{noise.min():.3f}, {noise.max():.3f}]")

    # Check model output
    with torch.no_grad():
        pred_noise = model(x_t, t)

    print("\n=== Model Output ===")
    print(f"  Shape: {pred_noise.shape} (should match input)")
    print(f"  Range: [{pred_noise.min():.3f}, {pred_noise.max():.3f}]")
    print(f"  Mean: {pred_noise.mean():.3f}")

    # Check loss
    loss = ddpm.training_loss(images)
    print("\n=== Loss ===")
    print(f"  Value: {loss.item():.4f}")
    print(f"  Is NaN: {torch.isnan(loss).item()}")

    # Check gradients
    loss.backward()
    grad_norms = []
    for name, p in model.named_parameters():
        if p.grad is not None:
            grad_norms.append(p.grad.norm().item())

    print("\n=== Gradients ===")
    print(f"  Max norm: {max(grad_norms):.4f}")
    print(f"  Min norm: {min(grad_norms):.6f}")
    print(f"  Any NaN: {any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None)}")

    return True
```

Training Dashboard

```python
class TrainingMonitor:
    """Simple training monitor with periodic checks."""

    def __init__(self, log_every=100, sample_every=1000):
        self.log_every = log_every
        self.sample_every = sample_every
        self.losses = []
        self.step = 0

    def log(self, loss, lr=None):
        self.losses.append(loss)
        self.step += 1

        if self.step % self.log_every == 0:
            recent = self.losses[-self.log_every:]
            msg = f"Step {self.step}: loss={sum(recent)/len(recent):.4f}"
            if lr is not None:
                msg += f", lr={lr:.2e}"
            print(msg)

    def should_sample(self):
        return self.step % self.sample_every == 0

    def plot_loss(self, save_path='loss_curve.png'):
        import matplotlib.pyplot as plt

        # Smooth with a trailing moving average over the last `window` steps
        window = min(100, max(1, len(self.losses) // 10))
        smoothed = []
        for i in range(len(self.losses)):
            chunk = self.losses[max(0, i - window + 1):i + 1]
            smoothed.append(sum(chunk) / len(chunk))

        plt.figure(figsize=(10, 5))
        plt.plot(smoothed)
        plt.xlabel('Step')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.savefig(save_path)
        plt.close()
```

Chapter Summary

Congratulations! You've completed Chapter 6: Building the Diffusion Model. You now have:

  1. A complete DDPM implementation: Forward process, reverse process, noise schedule
  2. A production training loop: EMA, mixed precision, checkpointing, logging
  3. Sampling algorithms: Full DDPM sampling with classifier-free guidance
  4. Practical training tips: Hyperparameters, scaling, common pitfalls
  5. Debugging and evaluation tools: Visualizations, FID, sanity checks

Coming Up: Chapter 7

In Chapter 7, we'll explore improved sampling methods: DDIM for fast deterministic sampling, DPM-Solver for even faster generation, and advanced techniques like ancestral sampling and guidance interpolation. These methods can reduce sampling from 1000 steps to just 20-50 with minimal quality loss!

With the complete diffusion model built and trained, you're ready to generate high-quality images. The foundation you've built here applies to all modern diffusion models, from unconditional generation to text-to-image systems like Stable Diffusion.