Chapter 6

Practical Training Tips

Building the Diffusion Model

Learning Objectives

By the end of this section, you will:

  1. Choose optimal hyperparameters for training diffusion models
  2. Tune learning rate and batch size for your hardware and dataset
  3. Understand hardware trade-offs and optimize for your setup
  4. Scale training to larger datasets efficiently
  5. Avoid common pitfalls that cause training to fail

Practical Experience Matters

Training diffusion models is as much art as science. The tips in this section come from hands-on experience with real training runs. These practical insights can save you weeks of debugging and wasted compute.

Key Hyperparameters

Here are the most important hyperparameters for diffusion training and their typical values:

| Hyperparameter | Typical Value | Range | Notes |
|---|---|---|---|
| Learning rate | 2e-4 | 1e-4 to 3e-4 | AdamW with warmup |
| Batch size | 64-256 | 32-512 | Larger = smoother gradients |
| EMA decay | 0.9999 | 0.999-0.99999 | Higher for longer training |
| Timesteps T | 1000 | 250-4000 | 1000 is standard |
| Dropout | 0.0-0.1 | 0-0.2 | 0.1 for regularization |
| Warmup steps | 5000 | 1000-10000 | ~5% of total training |
| Gradient clip | 1.0 | 0.5-2.0 | Essential for stability |
| Beta schedule | cosine | linear/cosine | Cosine usually better |

Start with Defaults

The default hyperparameters from the DDPM and Improved DDPM papers work well for most cases. Start with these values and only tune if you have specific issues or requirements.
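
For reference, the defaults above can be collected into a single config dict (the key names are illustrative, not a fixed API):

```python
# Default hyperparameters for diffusion training (values from the table above;
# the dict keys are illustrative, not a fixed API)
DEFAULT_CONFIG = {
    "learning_rate": 2e-4,      # AdamW with warmup
    "batch_size": 128,          # medium batch: good balance
    "ema_decay": 0.9999,        # raise toward 0.99999 for longer runs
    "timesteps": 1000,          # standard T
    "dropout": 0.1,             # light regularization
    "warmup_steps": 5000,       # ~5% of total training
    "grad_clip": 1.0,           # essential for stability
    "beta_schedule": "cosine",  # usually better than linear
}
```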

Learning Rate Tuning

The learning rate is the most critical hyperparameter. Here's a practical guide:

Signs Your Learning Rate is Too High

  • Loss oscillates wildly or doesn't decrease
  • NaN values appear in loss or gradients
  • Model diverges (loss goes to infinity)
  • Samples look completely noisy even after training

Signs Your Learning Rate is Too Low

  • Loss decreases very slowly
  • Model takes much longer to converge
  • Samples are blurry after expected training time
```python
# Learning rate finder for diffusion models
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader


def find_lr(
    model: nn.Module,
    ddpm: 'DDPM',
    dataloader: DataLoader,
    device: str = "cuda",
    start_lr: float = 1e-7,
    end_lr: float = 1e-2,
    num_steps: int = 200,
):
    """
    Simple learning rate range test.
    Plot loss vs LR to find the optimal range.
    """
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=start_lr)

    # Exponential LR growth from start_lr to end_lr
    gamma = (end_lr / start_lr) ** (1 / num_steps)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    lrs, losses = [], []

    for step, images in enumerate(dataloader):
        if step >= num_steps:
            break

        images = images.to(device)
        loss = ddpm.training_loss(images)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        lrs.append(scheduler.get_last_lr()[0])
        losses.append(loss.item())

        if losses[-1] > 10 * losses[0]:  # Diverging -- stop early
            break

    return lrs, losses

# Use: find where the loss is lowest before it starts increasing,
# then set the LR to ~1/10 of that value.
```

Recommended approach: Start with 2e-4 for AdamW. If training is unstable, reduce to 1e-4. If training is too slow, try 3e-4.
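The table above pairs this LR with warmup. A pure-function sketch of linear warmup (constant after the warmup window; cosine decay afterward is a common alternative):

```python
def warmup_lr(step: int, base_lr: float = 2e-4, warmup_steps: int = 5000) -> float:
    """Linear warmup from ~0 to base_lr over warmup_steps, then constant.

    A minimal sketch; in PyTorch you would wrap this in
    torch.optim.lr_scheduler.LambdaLR.
    """
    return base_lr * min(1.0, (step + 1) / warmup_steps)
```

Plugging this into `LambdaLR` (as a multiplicative factor, with `base_lr` set on the optimizer) reproduces the same schedule.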


Batch Size Considerations

Batch size affects both training dynamics and memory usage:

| Batch Size | Effect | Memory | Recommendation |
|---|---|---|---|
| Small (32-64) | More noise, faster iterations | Low | For debugging, small GPUs |
| Medium (128-256) | Good balance | Moderate | Standard choice |
| Large (512+) | Smoother gradients, slower convergence | High | Multi-GPU training |

Scaling Learning Rate with Batch Size

When increasing batch size, scale the learning rate proportionally (linear scaling rule):

```python
# Linear scaling rule
base_batch_size = 64
base_lr = 2e-4

actual_batch_size = 256
scaled_lr = base_lr * (actual_batch_size / base_batch_size)
# scaled_lr = 8e-4

# With gradient accumulation
micro_batch_size = 64    # Fits in GPU memory
accumulation_steps = 4   # Effective batch = 256
effective_batch_size = micro_batch_size * accumulation_steps
scaled_lr = base_lr * (effective_batch_size / base_batch_size)
```

Don't Scale Indefinitely

Linear scaling works up to a point. For very large batch sizes (>1024), the scaling becomes sublinear. Also, very large batch sizes can hurt generalization. Stick to 128-512 for most cases.
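One common heuristic is to scale linearly up to a cap and by square root beyond it. A sketch (the 512 cap and the square-root rule are illustrative choices, not prescriptions from the DDPM papers):

```python
import math

def scaled_lr(base_lr: float, base_bs: int, bs: int, linear_cap: int = 512) -> float:
    """Linear LR scaling up to linear_cap, square-root scaling beyond.

    A heuristic sketch; the cap and sqrt rule are illustrative assumptions.
    """
    if bs <= linear_cap:
        return base_lr * bs / base_bs
    lr_at_cap = base_lr * linear_cap / base_bs
    return lr_at_cap * math.sqrt(bs / linear_cap)
```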

Number of Timesteps

The number of diffusion timesteps T affects both training and sampling:

| T | Training Effect | Sampling Effect | Use Case |
|---|---|---|---|
| 250-500 | Coarser discretization | Faster but lower quality | Quick experiments |
| 1000 | Standard discretization | Good quality, 1000 steps | Default choice |
| 2000-4000 | Finer discretization | Higher quality potential | High-quality generation |

Key insight: You can train with 1000 timesteps and sample with fewer using DDIM (Chapter 7). Training timesteps affect learning quality; sampling timesteps affect generation speed/quality trade-off.
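The sub-sampling idea can be previewed with a simple stride over the training timesteps (a sketch of the schedule only; Chapter 7 covers the actual DDIM update rule):

```python
def subsample_timesteps(train_steps: int = 1000, sample_steps: int = 50) -> list:
    """Pick an evenly strided subset of the training timesteps for sampling.

    Returned in descending order (sampling runs from noisy to clean).
    """
    stride = train_steps // sample_steps
    return list(range(0, train_steps, stride))[::-1]
```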


Hardware Considerations

GPU Memory Requirements

```python
# Approximate memory usage for 64x64 diffusion training
# (batch_size=64, model ~35M params)
#
# Model parameters:        ~140 MB (FP32)
# Optimizer states (Adam): ~280 MB
# Gradients:               ~140 MB
# Activations:             ~2-4 GB (depends on depth)
# Total:                   ~3-5 GB minimum
#
# For larger images (256x256), activations grow 16x!
# May need gradient checkpointing or a smaller batch.
```
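The static part of that budget (everything except activations) follows directly from the parameter count. A quick estimator under the same FP32 assumptions (the helper name is illustrative):

```python
def estimate_train_memory_mb(num_params: float, bytes_per_param: int = 4) -> dict:
    """Rough FP32 training-memory estimate from the parameter count.

    Covers params, Adam states (exp_avg + exp_avg_sq), and gradients.
    Activations are excluded -- they depend on depth, batch, and resolution.
    """
    params_mb = num_params * bytes_per_param / 1e6
    return {
        "params_mb": params_mb,
        "adam_states_mb": 2 * params_mb,
        "grads_mb": params_mb,
        "total_static_mb": 4 * params_mb,
    }
```

For the ~35M-parameter model above this gives ~140 MB of params and ~560 MB of static state, matching the breakdown in the comment block.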

Multi-GPU Training

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed(rank, world_size):
    """Initialize distributed training.

    Assumes MASTER_ADDR/MASTER_PORT are set by the launcher.
    """
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)

def train_distributed(rank, world_size, args):
    setup_distributed(rank, world_size)

    # Create model and wrap in DDP
    model = UNet(...).to(rank)
    model = DDP(model, device_ids=[rank])

    ddpm = DDPM(model=model, ...).to(rank)

    # Distributed sampler so each rank sees a disjoint shard of the data
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_per_gpu)

    # Training loop (same as single GPU)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Reshuffle shards each epoch
        for batch in dataloader:
            # ... training step ...
            pass

# Launch with: torchrun --nproc_per_node=4 train.py
# (torchrun replaces the deprecated torch.distributed.launch)
```

Start Simple

Start with single-GPU training. Multi-GPU adds complexity (synchronization, debugging, reproducibility); scale up only when a single GPU is genuinely limiting.

Scaling to Larger Datasets

As datasets grow, consider these optimizations:

Data Loading

```python
# Efficient data loading for large datasets
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,            # Increase for large datasets
    pin_memory=True,          # Faster host-to-GPU transfer
    prefetch_factor=2,        # Batches prefetched per worker
    persistent_workers=True,  # Keep workers alive between epochs
    drop_last=True,           # Consistent batch sizes
)

# For very large datasets, use webdataset or another streaming format
import webdataset as wds

dataset = (
    wds.WebDataset(shards_pattern)
    .shuffle(1000)            # Shuffle within a 1000-sample buffer
    .decode("pil")
    .to_tuple("image.jpg")
    .map_tuple(transform)     # Apply transform to each tuple element
)
```

Training Duration Guidelines

| Dataset Size | Image Resolution | Training Time | Steps |
|---|---|---|---|
| 10K images | 64x64 | ~1 day (V100) | ~200K steps |
| 50K images | 64x64 | ~3-5 days | ~500K steps |
| 1M images | 64x64 | ~2-3 weeks | ~1M steps |
| 50K images | 256x256 | ~1-2 weeks | ~500K steps |
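To map a step budget onto your own hardware, measure throughput for a few hundred steps and convert (the 2.3 steps/sec in the usage note is an illustrative figure, not a benchmark):

```python
def training_days(total_steps: int, steps_per_sec: float) -> float:
    """Convert a step budget and measured throughput into wall-clock days."""
    return total_steps / steps_per_sec / 86400  # 86400 seconds per day
```

For example, 200K steps at roughly 2.3 steps/sec works out to about one day, consistent with the first table row.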

Common Pitfalls

1. Wrong Image Scaling

```python
# WRONG: Images in [0, 1]
images = images  # Still 0-1!
loss = ddpm.training_loss(images)  # Model expects [-1, 1]

# CORRECT: Scale to [-1, 1]
images = images * 2 - 1  # Now in [-1, 1]
loss = ddpm.training_loss(images)
```
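A cheap guard catches this mistake before a run is wasted. A minimal sketch (the helper name and tolerance are my own; it assumes the batch exposes tensor-like `min()`/`max()`):

```python
def check_image_range(images, lo: float = -1.0, hi: float = 1.0):
    """Fail fast if a batch is not scaled to the model's expected range.

    Assumes `images` has tensor-like .min()/.max() methods.
    """
    mn, mx = float(images.min()), float(images.max())
    if mn < lo - 1e-3 or mx > hi + 1e-3:
        raise ValueError(f"images outside [{lo}, {hi}]: min={mn:.3f}, max={mx:.3f}")
    if mn >= 0.0:
        # All-nonnegative data almost certainly means it is still in [0, 1]
        raise ValueError("images look like [0, 1]; scale with images * 2 - 1 first")
```

Call it once on the first batch of every run.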

2. Forgetting EMA for Sampling

```python
# WRONG: Sampling with the raw training weights
samples = ddpm.sample(batch_size=16, ...)  # Low quality!

# CORRECT: Sampling with the EMA weights
with ema.average_parameters():
    samples = ddpm.sample(batch_size=16, ...)  # Much better!
```
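The mechanics behind `average_parameters` are simple. A minimal sketch over a plain dict of values (libraries such as torch_ema operate on `nn.Module` parameters; this strips that away to show the idea):

```python
import copy
from contextlib import contextmanager

class EMA:
    """Minimal exponential moving average over a dict of parameter values."""

    def __init__(self, params: dict, decay: float = 0.9999):
        self.decay = decay
        self.shadow = copy.deepcopy(params)

    def update(self, params: dict):
        """shadow <- decay * shadow + (1 - decay) * current, per parameter."""
        d = self.decay
        for k, v in params.items():
            self.shadow[k] = d * self.shadow[k] + (1 - d) * v

    @contextmanager
    def average_parameters(self, params: dict):
        """Temporarily swap EMA weights in, restoring training weights after."""
        backup = copy.deepcopy(params)
        params.update(self.shadow)
        try:
            yield
        finally:
            params.update(backup)
```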

3. Not Saving Optimizer State

```python
# WRONG: Only saving model weights
torch.save(model.state_dict(), "checkpoint.pt")
# When resuming, Adam momentum is lost!

# CORRECT: Save everything
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'ema': ema.state_dict(),
    'step': global_step,
}
torch.save(checkpoint, "checkpoint.pt")
```
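Resuming is the mirror image. A sketch of the restore side (assumes the checkpoint dict layout above and objects exposing `load_state_dict`; load the dict first, e.g. with `torch.load`):

```python
def restore_from_checkpoint(ckpt: dict, model, optimizer, ema) -> int:
    """Restore model, optimizer, and EMA state; return the step to resume at.

    Assumes ckpt was saved with the layout shown above.
    """
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    ema.load_state_dict(ckpt["ema"])
    return ckpt["step"]
```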

4. Incorrect Timestep Handling

```python
# WRONG: Timesteps as floats
t = torch.rand(batch_size) * 1000  # Float!
noise_pred = model(x_t, t)  # Embedding lookups expect integer indices

# CORRECT: Timesteps as integers
t = torch.randint(0, 1000, (batch_size,))  # Integer
noise_pred = model(x_t, t)

# WRONG: t starts at 1
t = torch.randint(1, 1001, (batch_size,))  # Skips t=0 and can hit t=1000 (out of range)!

# CORRECT: t starts at 0
t = torch.randint(0, 1000, (batch_size,))  # Includes t=0
```

5. Memory Leaks

```python
# WRONG: Storing tensors that carry gradients
all_losses = []
for batch in dataloader:
    loss = ddpm.training_loss(batch)
    all_losses.append(loss)  # Keeps the entire computation graph alive!

# CORRECT: Detach with .item()
all_losses = []
for batch in dataloader:
    loss = ddpm.training_loss(batch)
    all_losses.append(loss.item())  # Just the Python number
```

Summary

Key takeaways for practical diffusion training:

  1. Start with defaults: LR=2e-4, batch_size=64-128, T=1000, EMA=0.9999
  2. Scale LR with batch size: Linear scaling up to ~512, then sublinear
  3. Use cosine schedule: Almost always better than linear
  4. Always use EMA: Essential for good sample quality
  5. Scale images to [-1, 1]: The most common source of training failures
  6. Save complete checkpoints: Model, optimizer, EMA, step count

Coming Up Next

In the final section of this chapter, we'll cover debugging and visualization: how to diagnose training issues, visualize the diffusion process, and evaluate model quality.

Training diffusion models is an iterative process. These tips will help you avoid common issues and get to good results faster. When in doubt, start simple and add complexity only when needed.