Learning Objectives
By the end of this section, you will:
- Implement a production-ready training loop for diffusion models
- Understand and implement EMA (Exponential Moving Average) for better samples
- Use mixed precision training for faster training and lower memory usage
- Handle checkpointing for resumable training
- Debug common training issues like NaN losses and poor convergence
Training is Where Theory Meets Practice
Training Overview
The diffusion training loop has a simple structure:
```
# Simplified training loop structure
for each batch of images:
    1. Sample random timesteps t ~ Uniform(0, T)
    2. Add noise: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
    3. Predict noise: epsilon_pred = model(x_t, t)
    4. Compute loss: L = MSE(epsilon_pred, epsilon)
    5. Update model: optimizer.step()
```
However, production training requires additional components:
| Component | Purpose | Impact |
|---|---|---|
| EMA | Smooth model weights for sampling | Major quality improvement |
| Gradient clipping | Prevent exploding gradients | Training stability |
| Mixed precision | Faster training, less memory | 2x speedup on modern GPUs |
| Learning rate warmup | Stable early training | Prevents early divergence |
| Checkpointing | Resume interrupted training | Essential for long runs |
| Logging | Monitor training progress | Early detection of issues |
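The five numbered steps above can be sketched directly in PyTorch. This is a minimal illustration only: the linear beta schedule and the `model(x_t, t)` signature are assumptions, and the production concerns from the table are deliberately omitted.

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)          # assumed linear noise schedule
alpha_bar = torch.cumprod(1.0 - betas, dim=0)  # cumulative product alpha_bar_t

def training_step(model, x0, optimizer):
    """One pass through steps 1-5: sample t, noise x0, predict, MSE, update."""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)        # 1. t ~ Uniform(0, T)
    eps = torch.randn_like(x0)                             # 2. draw noise
    ab = alpha_bar.to(x0.device)[t].view(b, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps           #    closed-form noising
    eps_pred = model(x_t, t)                               # 3. predict the noise
    loss = F.mse_loss(eps_pred, eps)                       # 4. MSE loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                                       # 5. update weights
    return loss.item()
```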
Dataset Preparation
Proper data preprocessing is crucial for diffusion models:
```python
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ImageDataset(Dataset):
    """Dataset for loading images for diffusion training."""

    def __init__(
        self,
        image_dir: str,
        image_size: int = 64,
        augment: bool = True,
    ):
        self.image_paths = [
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]

        # Build transform pipeline
        transform_list = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]

        if augment:
            transform_list.append(transforms.RandomHorizontalFlip())

        transform_list.extend([
            transforms.ToTensor(),
            # Scale to [-1, 1] - CRITICAL for diffusion!
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        self.transform = transforms.Compose(transform_list)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        return self.transform(image)


def create_dataloader(
    image_dir: str,
    batch_size: int = 64,
    image_size: int = 64,
    num_workers: int = 4,
) -> DataLoader:
    """Create a DataLoader for diffusion training."""
    dataset = ImageDataset(image_dir, image_size)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,  # Important for consistent batch sizes
    )


# Example usage
# dataloader = create_dataloader("./data/images", batch_size=64, image_size=64)
```
Normalization is Critical
`Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])` transforms [0, 1] tensors to [-1, 1]. If your images are already in [-1, 1], skip the normalization. Incorrect scaling is a common source of training failures.
Exponential Moving Average
Exponential Moving Average (EMA) maintains a smoothed version of the model weights. This is critical for diffusion models because:
- Training weights have noise from stochastic optimization
- EMA weights average over many training steps, reducing this noise
- Samples from EMA models are significantly better than from training models
The EMA update formula is:

theta_EMA <- beta * theta_EMA + (1 - beta) * theta

where beta is the decay rate (typically 0.9999).
EMA Decay Rate
A decay of beta = 0.9999 averages over roughly the last 1/(1 - beta) = 10,000 training steps. Higher decay gives smoother weights but adapts more slowly; for short training runs, a smaller value such as 0.999 often works better.
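Putting the update rule into code, a minimal EMA helper might look like this. It is a sketch: the `EMA` class name is illustrative, and the `state_dict`/`load_state_dict` round-trip is chosen to match how EMA state is checkpointed later in this section.

```python
import copy

import torch


class EMA:
    """Exponential moving average of a model's parameters."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.model = model
        self.decay = decay
        # The shadow copy holds the smoothed weights used for sampling
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self):
        """theta_ema <- decay * theta_ema + (1 - decay) * theta"""
        for ema_p, p in zip(self.shadow.parameters(), self.model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)
        # Copy buffers (e.g. BatchNorm running stats) directly, not averaged
        for ema_b, b in zip(self.shadow.buffers(), self.model.buffers()):
            ema_b.copy_(b)

    def state_dict(self):
        return self.shadow.state_dict()

    def load_state_dict(self, sd):
        self.shadow.load_state_dict(sd)
```

Call `ema.update()` once per optimizer step, and sample from `ema.shadow` rather than the raw model.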
Complete Training Loop
Here's a production-ready training loop that combines all the pieces:
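One possible shape for that loop is sketched below. The `DiffusionTrainer` name, its constructor arguments, and the `ddpm.training_loss(images)` / `ema.update()` calls are assumptions consistent with the other snippets in this section, not a fixed API.

```python
import torch
from torch.cuda.amp import GradScaler, autocast


class DiffusionTrainer:
    """Sketch of a training loop: warmup, AMP, clipping, EMA (assumed API)."""

    def __init__(self, model, ddpm, ema, dataloader, device,
                 lr=2e-4, warmup_steps=1000, grad_clip=1.0, use_amp=True):
        self.model, self.ddpm, self.ema = model, ddpm, ema
        self.dataloader, self.device = dataloader, device
        self.lr, self.warmup_steps, self.grad_clip = lr, warmup_steps, grad_clip
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.scaler = GradScaler() if use_amp else None
        self.global_step = 0

    def get_lr(self):
        # Linear warmup, then constant
        scale = min(1.0, (self.global_step + 1) / self.warmup_steps)
        return self.lr * scale

    def train_epoch(self):
        self.model.train()
        total, n = 0.0, 0
        for images in self.dataloader:
            images = images.to(self.device)
            for g in self.optimizer.param_groups:
                g['lr'] = self.get_lr()
            self.optimizer.zero_grad()
            if self.scaler is not None:
                # Mixed-precision path: scale, unscale for clipping, step
                with autocast():
                    loss = self.ddpm.training_loss(images)
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss = self.ddpm.training_loss(images)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                self.optimizer.step()
            self.ema.update()   # keep the EMA weights in sync every step
            self.global_step += 1
            total += loss.item()
            n += 1
        return {'loss': total / max(n, 1)}
```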
Mixed Precision Training
Mixed precision training uses 16-bit floats (FP16) for most computations while keeping critical operations in 32-bit (FP32). Benefits:
- ~2x faster training on modern GPUs (Tensor Cores)
- ~50% less GPU memory for activations
- Allows larger batch sizes
```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Initialize scaler
scaler = GradScaler()


# Training step with mixed precision
def train_step_amp(model, ddpm, images, optimizer):
    optimizer.zero_grad()

    # Forward pass in FP16
    with autocast():
        loss = ddpm.training_loss(images)

    # Backward pass with gradient scaling
    scaler.scale(loss).backward()

    # Unscale gradients for clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Optimizer step with automatic scaling
    scaler.step(optimizer)
    scaler.update()

    return loss.item()


# Sampling should use FP32 for quality
@torch.no_grad()
def sample_fp32(ddpm, batch_size, image_size, device):
    # Don't use autocast for sampling
    return ddpm.sample(batch_size, image_size, device=device)
```
Sampling in FP32
FP16 rounding errors compound over hundreds of denoising steps, so run the reverse process in full precision even when training with AMP.
Checkpointing and Resumption
Diffusion training is slow (often days or weeks). Proper checkpointing is essential:
```python
import os

import torch


def save_checkpoint(
    trainer: 'DiffusionTrainer',
    epoch: int,
    checkpoint_dir: str,
):
    """Save a complete training checkpoint."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        # Model weights
        'model_state_dict': trainer.model.state_dict(),

        # Optimizer state (momentum, Adam moments)
        'optimizer_state_dict': trainer.optimizer.state_dict(),

        # EMA state
        'ema_state_dict': trainer.ema.state_dict(),

        # AMP scaler state
        'scaler_state_dict': trainer.scaler.state_dict() if trainer.scaler else None,

        # Training progress
        'epoch': epoch,
        'global_step': trainer.global_step,
    }

    path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt')
    torch.save(checkpoint, path)

    # Also save latest
    latest_path = os.path.join(checkpoint_dir, 'checkpoint_latest.pt')
    torch.save(checkpoint, latest_path)

    print(f"Saved checkpoint to {path}")


def load_checkpoint(
    trainer: 'DiffusionTrainer',
    checkpoint_path: str,
) -> int:
    """Load a checkpoint and return the epoch to resume from."""
    checkpoint = torch.load(checkpoint_path, map_location=trainer.device)

    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    trainer.ema.load_state_dict(checkpoint['ema_state_dict'])

    if trainer.scaler and checkpoint['scaler_state_dict']:
        trainer.scaler.load_state_dict(checkpoint['scaler_state_dict'])

    trainer.global_step = checkpoint['global_step']

    print(f"Resumed from epoch {checkpoint['epoch']}, step {checkpoint['global_step']}")

    return checkpoint['epoch']


# Usage in training loop
def train(trainer, num_epochs, checkpoint_dir, resume_from=None):
    start_epoch = 0

    if resume_from:
        start_epoch = load_checkpoint(trainer, resume_from)

    for epoch in range(start_epoch, num_epochs):
        metrics = trainer.train_epoch()
        print(f"Epoch {epoch}: loss={metrics['loss']:.4f}")

        # Save checkpoint every epoch
        save_checkpoint(trainer, epoch, checkpoint_dir)

        # Generate samples periodically
        if epoch % 10 == 0:
            samples = trainer.sample(16)
            # save_images: an assumed helper that writes an image grid to disk
            save_images(samples, f"{checkpoint_dir}/samples_epoch{epoch}.png")
```
Checkpoint Strategy
This saves a per-epoch file plus a rolling checkpoint_latest.pt, so you can always resume from the most recent state while keeping older snapshots for comparison. Prune old per-epoch files if disk space becomes an issue.
Training Monitoring
Monitor these metrics during training:
| Metric | What to Watch | Warning Signs |
|---|---|---|
| Loss | Should decrease steadily | Sudden spikes, NaN, stuck at high value |
| Gradient norm | Should be stable (< 1 after clipping) | Very large values before clipping |
| Learning rate | Warmup then constant/decay | N/A |
| EMA update | Should match decay schedule | EMA diverging from training weights |
| Sample quality | Visual inspection every N steps | Blurry, artifacts, mode collapse |
```python
import torch
import wandb
from torch.cuda.amp import autocast


def train_with_logging(trainer, num_epochs, project_name="diffusion"):
    # Initialize logging
    wandb.init(project=project_name, config={
        'lr': trainer.lr,
        'batch_size': trainer.dataloader.batch_size,
        'warmup_steps': trainer.warmup_steps,
        'ema_decay': trainer.ema.decay,
    })

    for epoch in range(num_epochs):
        trainer.model.train()

        for batch_idx, images in enumerate(trainer.dataloader):
            images = images.to(trainer.device)

            # Training step
            with autocast():
                loss = trainer.ddpm.training_loss(images)

            trainer.scaler.scale(loss).backward()
            trainer.scaler.unscale_(trainer.optimizer)

            # Log gradient norm (clip_grad_norm_ returns the pre-clip norm)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                trainer.model.parameters(), trainer.grad_clip
            )

            trainer.scaler.step(trainer.optimizer)
            trainer.scaler.update()
            trainer.optimizer.zero_grad()
            trainer.ema.update()
            trainer.global_step += 1

            # Log metrics
            wandb.log({
                'train/loss': loss.item(),
                'train/grad_norm': grad_norm.item(),
                'train/lr': trainer.get_lr(),
                'train/step': trainer.global_step,
            })

        # Log samples periodically
        if epoch % 10 == 0:
            samples = trainer.sample(16, use_ema=True)
            wandb.log({
                'samples': [wandb.Image(s) for s in samples],
                'epoch': epoch,
            })
```
Debugging Common Issues
Here are common training issues and how to fix them:
NaN Loss
```python
# Problem: Loss becomes NaN
# Causes:
#   1. Learning rate too high
#   2. Gradient explosion
#   3. Division by zero in schedule
#   4. Bad initialization

import torch


# Solution: Add debugging
def safe_training_step(model, ddpm, images, optimizer):
    loss = ddpm.training_loss(images)

    # Check for NaN
    if torch.isnan(loss):
        print("NaN loss detected!")
        print(f"  Image stats: min={images.min()}, max={images.max()}")

        # Check model outputs
        with torch.no_grad():
            t = torch.randint(0, 1000, (1,), device=images.device)
            x_t, _ = ddpm.q_sample(images[:1], t)
            pred = model(x_t, t)
            print(f"  Pred stats: min={pred.min()}, max={pred.max()}")

        raise ValueError("NaN loss - check inputs and model")

    return loss
```
Loss Not Decreasing
- Check that images are scaled to [-1, 1]
- Verify the noise schedule is correct
- Ensure timesteps are sampled uniformly
- Check that the model output shape matches input shape
- Try a lower learning rate
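Several of these checks are easy to automate. The helper below is illustrative: `alpha_bar` stands for whatever cumulative-product schedule your DDPM uses, and the exact tolerances are assumptions.

```python
import torch


def sanity_check(images: torch.Tensor, alpha_bar: torch.Tensor):
    """Quick checks for common 'loss not decreasing' causes."""
    # Images must be scaled to [-1, 1]
    assert images.min() >= -1.001 and images.max() <= 1.001, (
        f"images in [{images.min():.3f}, {images.max():.3f}], expected [-1, 1]"
    )
    # A valid cumulative schedule lies in (0, 1] and decreases with t
    assert (alpha_bar > 0).all() and (alpha_bar <= 1).all(), "alpha_bar out of (0, 1]"
    assert (alpha_bar[1:] <= alpha_bar[:-1]).all(), "alpha_bar not decreasing"
    # Timesteps should be drawn uniformly over the full range [0, T)
    t = torch.randint(0, len(alpha_bar), (1000,))
    assert t.min() >= 0 and t.max() < len(alpha_bar)
    print("sanity checks passed")
```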
Poor Sample Quality
```python
import torch
import torch.nn.functional as F

# Debugging poor samples

# 1. Check if using EMA (major impact!)
samples_no_ema = trainer.sample(16, use_ema=False)
samples_ema = trainer.sample(16, use_ema=True)
# EMA samples should be significantly better


# 2. Check intermediate timesteps
def debug_sampling(ddpm, device):
    """Visualize denoising at different stages."""
    x = torch.randn(1, 3, 64, 64, device=device)

    checkpoints = [999, 750, 500, 250, 100, 50, 0]
    images = []

    for t in range(999, -1, -1):
        t_batch = torch.tensor([t], device=device)
        x = ddpm.p_sample(x, t_batch)

        if t in checkpoints:
            images.append(x.clone())

    return images  # Should show gradual denoising


# 3. Check if model is predicting noise correctly
def check_noise_prediction(ddpm, images, device):
    """Verify noise prediction at various timesteps."""
    images = images.to(device)

    for t_val in [0, 100, 500, 900]:
        t = torch.tensor([t_val], device=device)
        x_t, true_noise = ddpm.q_sample(images[:1], t)

        pred_noise = ddpm.model(x_t, t)
        error = F.mse_loss(pred_noise, true_noise)

        print(f"t={t_val}: MSE={error.item():.4f}")
        # Per-timestep MSE varies, but a large spike at any one
        # range of t points to a schedule or conditioning bug
```
Always Use EMA for Sampling
The quality gap between EMA and raw training weights is usually large; generate all evaluation and final samples from the EMA weights.
Summary
In this section, we built a production-ready training pipeline:
- Dataset preparation: Proper loading and scaling to [-1, 1]
- EMA: Essential for high-quality samples, with decay rate ~0.9999
- Complete training loop: Gradient accumulation, clipping, warmup, logging
- Mixed precision: 2x speedup with AMP, but sample in FP32
- Checkpointing: Save all state for seamless resumption
- Debugging: Strategies for NaN loss, poor convergence, bad samples
Coming Up Next
With this training infrastructure, you're ready to train diffusion models on your own datasets. The key is patience - diffusion models need many steps to converge, but the results are worth the wait.