Learning Objectives
By the end of this section, you will:
- Choose optimal hyperparameters for training diffusion models
- Tune learning rate and batch size for your hardware and dataset
- Understand hardware trade-offs and optimize for your setup
- Scale training to larger datasets efficiently
- Avoid common pitfalls that cause training to fail
Practical Experience Matters
Key Hyperparameters
Here are the most important hyperparameters for diffusion training and their typical values:
| Hyperparameter | Typical Value | Range | Notes |
|---|---|---|---|
| Learning rate | 2e-4 | 1e-4 to 3e-4 | AdamW with warmup |
| Batch size | 64-256 | 32-512 | Larger = smoother gradients |
| EMA decay | 0.9999 | 0.999-0.99999 | Higher for longer training |
| Timesteps T | 1000 | 250-4000 | 1000 is standard |
| Dropout | 0.0-0.1 | 0-0.2 | 0.1 for regularization |
| Warmup steps | 5000 | 1000-10000 | 5% of total training |
| Gradient clip | 1.0 | 0.5-2.0 | Essential for stability |
| Beta schedule | cosine | linear/cosine | Cosine usually better |
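The defaults in the table can be collected into a single config object so they travel together through your training script. This is a minimal sketch; the class and field names are illustrative, not from any particular library:

```python
from dataclasses import dataclass


@dataclass
class DiffusionTrainConfig:
    # Values taken from the hyperparameter table above
    lr: float = 2e-4               # AdamW with warmup
    batch_size: int = 128          # 64-256 typical
    ema_decay: float = 0.9999      # higher for longer training
    num_timesteps: int = 1000      # T
    dropout: float = 0.1           # 0 to 0.2
    warmup_steps: int = 5000
    grad_clip: float = 1.0
    beta_schedule: str = "cosine"  # "linear" or "cosine"


cfg = DiffusionTrainConfig()
print(cfg.lr, cfg.beta_schedule)
```

Keeping all defaults in one dataclass makes it easy to log the exact configuration alongside each checkpoint.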
Start with Defaults
Learning Rate Tuning
The learning rate is the most critical hyperparameter. Here's a practical guide:
Signs Your Learning Rate is Too High
- Loss oscillates wildly or doesn't decrease
- NaN values appear in loss or gradients
- Model diverges (loss goes to infinity)
- Samples look completely noisy even after training
Signs Your Learning Rate is Too Low
- Loss decreases very slowly
- Model takes much longer to converge
- Samples are blurry after expected training time
```python
# Learning rate finder for diffusion models
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader


def find_lr(
    model: nn.Module,
    ddpm: 'DDPM',
    dataloader: DataLoader,
    device: str = "cuda",
    start_lr: float = 1e-7,
    end_lr: float = 1e-2,
    num_steps: int = 200,
):
    """
    Simple learning rate range test.
    Plot loss vs LR to find the optimal range.
    """
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=start_lr)

    # Exponential LR growth from start_lr to end_lr over num_steps
    gamma = (end_lr / start_lr) ** (1 / num_steps)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    lrs, losses = [], []

    for step, images in enumerate(dataloader):
        if step >= num_steps:
            break

        images = images.to(device)
        loss = ddpm.training_loss(images)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        lrs.append(scheduler.get_last_lr()[0])
        losses.append(loss.item())

        if loss.item() > 10 * losses[0]:  # Diverging
            break

    return lrs, losses

# Use: Find where loss is lowest before it starts increasing
# Set LR to ~1/10 of that value
```
Recommended approach: Start with 2e-4 for AdamW. If training is unstable, reduce to 1e-4. If training is too slow, try 3e-4.
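The selection rule in the comments above (lowest loss before divergence, backed off by 10x) can be automated on the `lrs`/`losses` arrays the finder returns. A minimal sketch; the smoothing window size is an arbitrary choice:

```python
def pick_lr(lrs, losses, window=5):
    """Pick a learning rate from a range test: smooth the loss curve,
    take the LR at the minimum, then back off by 10x for safety."""
    # Trailing moving average to suppress batch-to-batch noise
    smoothed = []
    for i in range(len(losses)):
        lo = max(0, i - window + 1)
        smoothed.append(sum(losses[lo:i + 1]) / (i + 1 - lo))
    best = min(range(len(smoothed)), key=smoothed.__getitem__)
    return lrs[best] / 10


# Synthetic sweep: loss falls, bottoms out near index 35, then rises
lrs = [1e-7 * 10 ** (i / 10) for i in range(50)]
losses = [abs(i - 35) * 0.1 + 1.0 for i in range(50)]
print(pick_lr(lrs, losses))
```

Always eyeball the loss-vs-LR plot as well; an automated pick can land on a noise spike in a short sweep.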
Batch Size Considerations
Batch size affects both training dynamics and memory usage:
| Batch Size | Effect | Memory | Recommendation |
|---|---|---|---|
| Small (32-64) | More noise, faster iterations | Low | For debugging, small GPUs |
| Medium (128-256) | Good balance | Moderate | Standard choice |
| Large (512+) | Smoother gradients, slower convergence | High | Multi-GPU training |
Scaling Learning Rate with Batch Size
When increasing batch size, scale the learning rate proportionally (linear scaling rule):
```python
# Linear scaling rule
base_batch_size = 64
base_lr = 2e-4

actual_batch_size = 256
scaled_lr = base_lr * (actual_batch_size / base_batch_size)
# scaled_lr = 8e-4

# With gradient accumulation
micro_batch_size = 64     # Fits in GPU memory
accumulation_steps = 4    # Effective batch = 256
effective_batch_size = micro_batch_size * accumulation_steps
scaled_lr = base_lr * (effective_batch_size / base_batch_size)
```
Don't Scale Indefinitely
The linear rule holds only up to moderate batch sizes; beyond roughly 512, scale the learning rate sublinearly instead.
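The gradient-accumulation variables above can be expanded into a working training step. This is a minimal sketch: the `nn.Linear` model and MSE loss are stand-ins for the real U-Net and `ddpm.training_loss`, and the sizes are shrunk so it runs anywhere:

```python
import torch
from torch import nn
from torch.optim import AdamW

# Hypothetical tiny model standing in for the diffusion U-Net
model = nn.Linear(16, 16)
optimizer = AdamW(model.parameters(), lr=8e-4)

micro_batch_size = 8
accumulation_steps = 4  # effective batch = 32
batches = [torch.randn(micro_batch_size, 16) for _ in range(8)]

n_updates = 0
optimizer.zero_grad()
for step, x in enumerate(batches):
    loss = (model(x) - x).pow(2).mean()       # stand-in for ddpm.training_loss
    (loss / accumulation_steps).backward()     # average grads over micro-batches

    if (step + 1) % accumulation_steps == 0:
        # Clip once per effective batch (grad clip value from the table)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        n_updates += 1

print(n_updates)  # 2 optimizer updates from 8 micro-batches
```

Note the loss is divided by `accumulation_steps` before `backward()` so the accumulated gradient matches what a single large batch would produce.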
Number of Timesteps
The number of diffusion timesteps affects both training and sampling:
| T | Training Effect | Sampling Effect | Use Case |
|---|---|---|---|
| 250-500 | Coarser discretization | Faster but lower quality | Quick experiments |
| 1000 | Standard discretization | Good quality, 1000 steps | Default choice |
| 2000-4000 | Finer discretization | Higher quality potential | High-quality generation |
Key insight: You can train with 1000 timesteps and sample with fewer using DDIM (Chapter 7). Training timesteps affect learning quality; sampling timesteps affect generation speed/quality trade-off.
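The train-long, sample-short point can be made concrete: DDIM-style samplers iterate over a strided subset of the training timesteps. A sketch of the simple uniform-stride scheme (real implementations may use other spacings):

```python
def sampling_timesteps(train_steps: int = 1000, sample_steps: int = 50):
    """Uniformly strided subset of training timesteps, descending,
    as iterated by DDIM-style samplers."""
    stride = train_steps // sample_steps
    ts = list(range(0, train_steps, stride))  # [0, 20, 40, ..., 980]
    return ts[::-1]                           # sample from t=980 down to t=0


ts = sampling_timesteps(1000, 50)
print(len(ts), ts[0], ts[-1])  # 50 980 0
```

The model was trained on all 1000 timesteps, so it can denoise at any of the 50 selected ones; quality degrades gracefully as `sample_steps` shrinks.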
Hardware Considerations
GPU Memory Requirements
```python
# Approximate memory usage for 64x64 diffusion training
# (batch_size=64, model ~35M params)

# Model parameters: ~140 MB (FP32)
# Optimizer states (Adam): ~280 MB
# Gradients: ~140 MB
# Activations: ~2-4 GB (depends on depth)
# Total: ~3-5 GB minimum

# For larger images (256x256), activations grow 16x!
# May need gradient checkpointing or smaller batch
```
Multi-GPU Training
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_distributed(rank, world_size):
    """Initialize distributed training."""
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)


def train_distributed(rank, world_size, args):
    setup_distributed(rank, world_size)

    # Create model and wrap in DDP
    model = UNet(...).to(rank)
    model = DDP(model, device_ids=[rank])

    ddpm = DDPM(model=model, ...).to(rank)

    # Distributed sampler so each rank sees a disjoint shard of the data
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_per_gpu)

    # Training loop (same as single GPU)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # reshuffle differently each epoch
        for batch in dataloader:
            # ... training step ...
            pass

# Launch with: torchrun --nproc_per_node=4 train.py
```
Start Simple
Scaling to Larger Datasets
As datasets grow, consider these optimizations:
Data Loading
```python
# Efficient data loading for large datasets
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,            # Increase for large datasets
    pin_memory=True,          # Faster CPU-to-GPU transfer
    prefetch_factor=2,        # Batches prefetched per worker
    persistent_workers=True,  # Keep workers alive between epochs
    drop_last=True,           # Consistent batch sizes
)

# For very large datasets, use webdataset or other streaming formats
import webdataset as wds

dataset = (
    wds.WebDataset(shards_pattern)
    .shuffle(1000)
    .decode("pil")
    .to_tuple("image.jpg")
    .map(transform)
)
```
Training Duration Guidelines
| Dataset Size | Image Resolution | Training Time | Steps |
|---|---|---|---|
| 10K images | 64x64 | ~1 day (V100) | ~200K steps |
| 50K images | 64x64 | ~3-5 days | ~500K steps |
| 1M images | 64x64 | ~2-3 weeks | ~1M steps |
| 50K images | 256x256 | ~1-2 weeks | ~500K steps |
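As a sanity check on the table, step counts, batch size, and dataset size relate through simple arithmetic. A quick helper (the function name is illustrative; batch size 64 matches the defaults above):

```python
def epochs_seen(num_images: int, batch_size: int, num_steps: int) -> float:
    """How many passes over the dataset a given step budget implies."""
    return num_steps * batch_size / num_images


# 10K images, 200K steps at batch 64 -> 1280 passes over the data
print(epochs_seen(10_000, 64, 200_000))
```

Diffusion models routinely see each image hundreds or thousands of times, because every pass pairs the image with fresh noise and fresh timesteps.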
Common Pitfalls
1. Wrong Image Scaling
```python
# WRONG: Images in [0, 1]
images = images  # Still 0-1!
loss = ddpm.training_loss(images)  # Model expects [-1, 1]

# CORRECT: Scale to [-1, 1]
images = images * 2 - 1  # Now in [-1, 1]
loss = ddpm.training_loss(images)
```
2. Forgetting EMA for Sampling
```python
# WRONG: Using raw training weights
samples = ddpm.sample(batch_size=16, ...)  # Low quality!

# CORRECT: Using EMA weights
with ema.average_parameters():
    samples = ddpm.sample(batch_size=16, ...)  # Much better!
```
3. Not Saving Optimizer State
```python
# WRONG: Only saving model weights
torch.save(model.state_dict(), "checkpoint.pt")
# When resuming, Adam momentum is lost!

# CORRECT: Save everything
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'ema': ema.state_dict(),
    'step': global_step,
}
torch.save(checkpoint, "checkpoint.pt")
```
4. Incorrect Timestep Handling
```python
# WRONG: Timesteps as floats
t = torch.rand(batch_size) * 1000  # Float!
noise_pred = model(x_t, t)  # May break integer-indexed embeddings

# CORRECT: Timesteps as integers
t = torch.randint(0, 1000, (batch_size,))  # Integer
noise_pred = model(x_t, t)

# WRONG: t starts at 1
t = torch.randint(1, 1001, (batch_size,))  # Skips t=0 and can emit out-of-range t=1000!

# CORRECT: t starts at 0
t = torch.randint(0, 1000, (batch_size,))  # Includes t=0
```
5. Memory Leaks
```python
# WRONG: Storing tensors with gradients
all_losses = []
for batch in dataloader:
    loss = ddpm.training_loss(batch)
    all_losses.append(loss)  # Keeps the entire computation graph alive!

# CORRECT: Detach or use .item()
all_losses = []
for batch in dataloader:
    loss = ddpm.training_loss(batch)
    all_losses.append(loss.item())  # Just the Python number
```
Summary
Key takeaways for practical diffusion training:
- Start with defaults: LR=2e-4, batch_size=64-128, T=1000, EMA=0.9999
- Scale LR with batch size: Linear scaling up to ~512, then sublinear
- Use cosine schedule: Almost always better than linear
- Always use EMA: Essential for good sample quality
- Scale images to [-1, 1]: The most common source of training failures
- Save complete checkpoints: Model, optimizer, EMA, step count
Coming Up Next
Training diffusion models is an iterative process. These tips will help you avoid common issues and get to good results faster. When in doubt, start simple and add complexity only when needed.