Chapter 6

Training Loop Implementation

Building the Diffusion Model

Learning Objectives

By the end of this section, you will:

  1. Implement a production-ready training loop for diffusion models
  2. Understand and implement EMA (Exponential Moving Average) for better samples
  3. Use mixed precision training for faster training and lower memory usage
  4. Handle checkpointing for resumable training
  5. Debug common training issues like NaN losses and poor convergence

Training is Where Theory Meets Practice

Having a correct DDPM implementation is necessary but not sufficient. The training loop determines whether your model actually learns. This section covers the engineering details that make the difference between a model that works and one that produces noise.

Training Overview

The diffusion training loop has a simple structure:

🐍python
# Simplified training loop structure (pseudocode)
# for each batch of images x_0:
#     1. Sample random timesteps: t ~ Uniform{0, ..., T-1}, one per image
#     2. Add noise: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon,
#        with epsilon ~ N(0, I)
#     3. Predict noise: epsilon_pred = model(x_t, t)
#     4. Compute loss: L = MSE(epsilon_pred, epsilon)
#     5. Update model: loss.backward(), then optimizer.step()
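To make steps 1 to 4 concrete, here is a minimal sketch in plain Python that treats a single scalar as the "image", so it runs without PyTorch. The linear beta schedule (1e-4 to 0.02 over T = 1000 steps) and the helper names `make_alpha_bar` and `q_sample_scalar` are assumptions made for illustration and are not part of this chapter's DDPM class. The constant zero prediction stands in for an untrained network.

```python
import math
import random

def make_alpha_bar(T=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t for an assumed linear beta schedule."""
    alpha_bar, running = [], 1.0
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar

def q_sample_scalar(x_0, t, alpha_bar, rng):
    """Step 2: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon."""
    epsilon = rng.gauss(0.0, 1.0)
    x_t = math.sqrt(alpha_bar[t]) * x_0 + math.sqrt(1.0 - alpha_bar[t]) * epsilon
    return x_t, epsilon

rng = random.Random(0)
alpha_bar = make_alpha_bar()
x_0 = 0.5                                  # one "pixel", already scaled to [-1, 1]
t = rng.randrange(len(alpha_bar))          # step 1: t ~ Uniform{0, ..., T-1}
x_t, epsilon = q_sample_scalar(x_0, t, alpha_bar, rng)
epsilon_pred = 0.0                         # step 3: placeholder for model(x_t, t)
loss = (epsilon_pred - epsilon) ** 2       # step 4: squared error for this sample
print(f"t={t}, x_t={x_t:.3f}, loss={loss:.3f}")
```

In a real loop the same arithmetic is applied to whole image tensors, and step 5 backpropagates the loss through the network.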

However, production training requires additional components:

| Component | Purpose | Impact |
| --- | --- | --- |
| EMA | Smooth model weights for sampling | Major quality improvement |
| Gradient clipping | Prevent exploding gradients | Training stability |
| Mixed precision | Faster training, less memory | Up to ~2x speedup on GPUs with Tensor Cores |
| Learning rate warmup | Stable early training | Prevents early divergence |
| Checkpointing | Resume interrupted training | Essential for long runs |
| Logging | Monitor training progress | Early detection of issues |

Dataset Preparation

Proper data preprocessing is crucial for diffusion models:

🐍python
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

class ImageDataset(Dataset):
    """Dataset for loading images for diffusion training."""

    def __init__(
        self,
        image_dir: str,
        image_size: int = 64,
        augment: bool = True,
    ):
        self.image_paths = [
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]

        # Build transform pipeline
        transform_list = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]

        if augment:
            transform_list.extend([
                transforms.RandomHorizontalFlip(),
            ])

        transform_list.extend([
            transforms.ToTensor(),
            # Scale to [-1, 1] - CRITICAL for diffusion!
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        self.transform = transforms.Compose(transform_list)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        return self.transform(image)


def create_dataloader(
    image_dir: str,
    batch_size: int = 64,
    image_size: int = 64,
    num_workers: int = 4,
) -> DataLoader:
    """Create a DataLoader for diffusion training."""

    dataset = ImageDataset(image_dir, image_size)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,  # Important for consistent batch sizes
    )


# Example usage
# dataloader = create_dataloader("./data/images", batch_size=64, image_size=64)

Normalization is Critical

Diffusion models expect inputs in [-1, 1]. The normalization Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) maps the [0, 1] tensors produced by ToTensor() to [-1, 1] by computing (x - 0.5) / 0.5 for each channel. If your images are already in [-1, 1], skip the normalization. Incorrect scaling is a common source of training failures.
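The scaling and its inverse are simple enough to verify by hand. The sketch below uses plain Python scalars; `to_model_range` and `to_display_range` are hypothetical helper names, not torchvision functions. The inverse is the same `(samples + 1) / 2` step that the sampling code in this section applies before visualization.

```python
def to_model_range(pixel_01):
    """Map a [0, 1] value to [-1, 1], as Normalize(mean=0.5, std=0.5) does."""
    return (pixel_01 - 0.5) / 0.5

def to_display_range(pixel_pm1):
    """Inverse map from [-1, 1] back to [0, 1] for visualization."""
    return (pixel_pm1 + 1.0) / 2.0

for value in (0.0, 0.25, 0.5, 1.0):
    scaled = to_model_range(value)
    print(f"{value:.2f} -> {scaled:+.2f} -> {to_display_range(scaled):.2f}")
```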

Exponential Moving Average

Exponential Moving Average (EMA) maintains a smoothed version of the model weights. This is critical for diffusion models because:

  • Training weights have noise from stochastic optimization
  • EMA weights average over many training steps, reducing this noise
  • Samples from EMA weights are typically noticeably better than samples from the raw training weights

The EMA update formula is:

$$\theta_{\text{EMA}}^{(t+1)} = \beta \cdot \theta_{\text{EMA}}^{(t)} + (1 - \beta) \cdot \theta^{(t+1)}$$

where $\beta$ is the decay rate (typically 0.9999).
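A scalar example may help build intuition for the update rule. The sketch below applies the formula to a single "weight" that oscillates around 1.0, imitating optimizer noise. The oscillation pattern and the decay of 0.99 (chosen so the effect is visible within 200 steps) are assumptions for illustration only.

```python
def ema_update(shadow, current, decay):
    """One EMA step: shadow = decay * shadow + (1 - decay) * current."""
    return decay * shadow + (1.0 - decay) * current

# A noisy "weight" that alternates around 1.0, as optimizer noise might cause.
weights = [1.0 + (0.2 if step % 2 == 0 else -0.2) for step in range(200)]

shadow = weights[0]                # initialise the average from the first value
for w in weights[1:]:
    shadow = ema_update(shadow, w, decay=0.99)

print(f"last raw weight: {weights[-1]:.3f}")
print(f"EMA weight:      {shadow:.3f}")
```

The raw weight always sits 0.2 away from the center, while the averaged value ends up much closer to it.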

EMA Implementation
EMA Class

Exponential Moving Average maintains a smoothed copy of model weights. This reduces noise in samples and is essential for high-quality generation.

Decay Rate

The decay rate (typically 0.9999) controls how quickly the EMA updates. Higher values = slower updates = smoother but slower to adapt.

Shadow Parameters

We maintain a separate copy of all parameters. These 'shadow' parameters are updated as exponential moving averages of the training parameters.

Update Formula

EMA update: shadow = decay * shadow + (1 - decay) * current. This blends old and new values, giving more weight to the history.

Context Manager

The context manager temporarily swaps EMA weights into the model for sampling, then restores original weights for continued training.

🐍ema.py
import torch
import torch.nn as nn
from contextlib import contextmanager

class EMA:
    """
    Exponential Moving Average of model parameters.

    Maintains shadow copies of all parameters and updates them
    as exponential moving averages of the training parameters.
    """

    def __init__(
        self,
        model: nn.Module,
        decay: float = 0.9999,
        warmup_steps: int = 2000,
    ):
        """
        Args:
            model: The model to track
            decay: EMA decay rate (higher = slower updates)
            warmup_steps: Steps before starting EMA (use training weights)
        """
        self.model = model
        self.decay = decay
        self.warmup_steps = warmup_steps
        self.step = 0

        # Create shadow copies of all trainable parameters
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    @torch.no_grad()
    def update(self):
        """Update EMA parameters after an optimizer step."""
        self.step += 1

        # During warmup, just copy parameters directly
        if self.step < self.warmup_steps:
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    self.shadow[name].copy_(param.data)
            return

        # After warmup, use EMA update
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # shadow = decay * shadow + (1 - decay) * current
                self.shadow[name].mul_(self.decay).add_(
                    param.data, alpha=1 - self.decay
                )

    @contextmanager
    def average_parameters(self):
        """
        Context manager to temporarily use EMA parameters.

        Usage:
            with ema.average_parameters():
                samples = model.sample(...)
        """
        # Store original parameters and swap in the EMA weights
        original = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                original[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

        try:
            yield
        finally:
            # Restore original parameters
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    param.data.copy_(original[name])

    def state_dict(self):
        """Get EMA state for checkpointing."""
        return {
            # Clone so later updates do not alter a state dict held in memory
            'shadow': {name: t.clone() for name, t in self.shadow.items()},
            'step': self.step,
        }

    def load_state_dict(self, state_dict):
        """Load EMA state from checkpoint."""
        # Copy in place so the shadow tensors stay on the model's device
        for name, tensor in state_dict['shadow'].items():
            self.shadow[name].copy_(tensor)
        self.step = state_dict['step']

EMA Decay Rate

The standard decay rate is 0.9999. Because 0.9999^10000 ≈ e^(-1) ≈ 0.37, each EMA parameter is ~63% determined by the last 10,000 updates (the last 5,000 contribute only about 39%). For smaller datasets or shorter training, you might use 0.999 (faster adaptation). For very large runs, 0.99999 can work better.
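The relationship between the decay rate and the averaging window follows from the update rule: after N updates, the weight remaining on older history is decay^N, so the most recent N updates contribute 1 - decay^N. The short sketch below evaluates this for the three decay rates mentioned; `recent_weight` and `effective_horizon` are hypothetical helper names.

```python
def recent_weight(decay, num_updates):
    """Fraction of the EMA contributed by the most recent num_updates steps."""
    return 1.0 - decay ** num_updates

def effective_horizon(decay):
    """Rule-of-thumb averaging window, 1 / (1 - decay) updates."""
    return 1.0 / (1.0 - decay)

for decay in (0.999, 0.9999, 0.99999):
    horizon = round(effective_horizon(decay))
    print(f"decay={decay}: horizon ~{horizon} updates, "
          f"weight of last {horizon} updates = {recent_weight(decay, horizon):.2f}")
```

One way to use this is to pick the decay so that the horizon is a modest fraction of the total number of training steps. With a very short run and a decay of 0.9999, the average would still be dominated by early weights.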

Complete Training Loop

Here's a production-ready training loop that combines all the pieces:

Production Training Loop
Trainer Class

The Trainer class encapsulates all training logic: data loading, optimization, EMA updates, checkpointing, and logging.

Gradient Accumulation

Gradient accumulation allows training with effectively larger batch sizes than GPU memory permits. We accumulate gradients over multiple mini-batches.

Learning Rate Warmup

Warmup gradually increases the learning rate from 0 to the target over the first few thousand steps. This prevents early training instabilities.

Main Training Loop

The training loop iterates over epochs and batches. Each step: forward pass, loss computation, backward pass, optimizer step, EMA update.

Gradient Clipping

Clip gradients to prevent exploding gradients. A max norm of 1.0 is typical. This is essential for stable training.

EMA Update

After each optimizer step, we update the EMA weights. This must happen after optimizer.step() to use the updated model weights.

🐍trainer.py
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from ema import EMA  # the EMA class defined in ema.py above

class DiffusionTrainer:
    """
    Complete training pipeline for diffusion models.

    Features:
    - Gradient accumulation for larger effective batch sizes
    - Mixed precision training (AMP)
    - EMA weight tracking
    - Learning rate warmup and scheduling
    - Gradient clipping
    - Checkpointing
    - Progress logging
    """

    def __init__(
        self,
        model: nn.Module,
        ddpm: 'DDPM',
        dataloader: torch.utils.data.DataLoader,
        lr: float = 2e-4,
        warmup_steps: int = 5000,
        grad_clip: float = 1.0,
        grad_accumulation_steps: int = 1,
        ema_decay: float = 0.9999,
        device: str = "cuda",
        use_amp: bool = True,
    ):
        self.model = model.to(device)
        self.ddpm = ddpm.to(device)
        self.dataloader = dataloader
        self.device = device

        # Optimizer
        self.optimizer = AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999))

        # Learning rate scheduler with warmup
        self.warmup_steps = warmup_steps
        self.lr = lr

        # Gradient settings
        self.grad_clip = grad_clip
        self.grad_accumulation_steps = grad_accumulation_steps

        # EMA
        self.ema = EMA(model, decay=ema_decay)

        # Mixed precision
        self.use_amp = use_amp
        self.scaler = GradScaler() if use_amp else None

        # Training state
        self.global_step = 0

    def get_lr(self) -> float:
        """Get current learning rate with warmup."""
        if self.global_step < self.warmup_steps:
            return self.lr * self.global_step / self.warmup_steps
        return self.lr

    def train_epoch(self) -> dict:
        """Train for one epoch."""
        self.model.train()

        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(self.dataloader, desc="Training")

        for batch_idx, images in enumerate(progress_bar):
            images = images.to(self.device)

            # Forward pass; autocast is a no-op when use_amp is False
            with autocast(enabled=self.use_amp):
                loss = self.ddpm.training_loss(images)
                # Divide so the accumulated gradient is an average, not a sum
                loss = loss / self.grad_accumulation_steps

            # Backward pass
            if self.use_amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step the optimizer once every grad_accumulation_steps batches
            if (batch_idx + 1) % self.grad_accumulation_steps == 0:
                # Update learning rate
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.get_lr()

                # Clip gradients (unscale first so the threshold applies
                # to the true gradient values)
                if self.use_amp:
                    self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.grad_clip
                )

                # Optimizer step
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()

                # EMA update
                self.ema.update()

                self.global_step += 1

            # Logging (undo the accumulation scaling for display)
            total_loss += loss.item() * self.grad_accumulation_steps
            num_batches += 1

            progress_bar.set_postfix({
                'loss': total_loss / num_batches,
                'lr': self.get_lr(),
                'step': self.global_step,
            })

        return {
            'loss': total_loss / num_batches,
            'steps': self.global_step,
        }

    @torch.no_grad()
    def sample(
        self,
        num_samples: int = 16,
        image_size: int = 64,
        use_ema: bool = True,
    ) -> torch.Tensor:
        """Generate samples, optionally using EMA weights."""
        self.model.eval()

        if use_ema:
            with self.ema.average_parameters():
                samples = self.ddpm.sample(
                    batch_size=num_samples,
                    image_size=image_size,
                    device=self.device,
                )
        else:
            samples = self.ddpm.sample(
                batch_size=num_samples,
                image_size=image_size,
                device=self.device,
            )

        # Scale to [0, 1] for visualization
        samples = (samples + 1) / 2
        return samples.clamp(0, 1)
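Two pieces of the trainer are easy to check by hand: the warmup schedule and the reason the loss is divided by `grad_accumulation_steps`. The sketch below is a plain-Python illustration under simplifying assumptions (a one-parameter least-squares model and equally sized micro-batches). `warmup_lr` mirrors `get_lr`, and `grad_mse` is a hypothetical helper.

```python
def warmup_lr(step, base_lr=2e-4, warmup_steps=5000):
    """Linear warmup from 0 to base_lr, then constant (mirrors get_lr)."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr

def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

# Warmup: the learning rate rises linearly, then stays at the target.
for step in (0, 1250, 2500, 5000, 20000):
    print(f"step {step:>5}: lr = {warmup_lr(step):.2e}")

# Accumulation: averaging the gradients of two micro-batches reproduces
# the gradient of the full batch.
xs, ys, w = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0], 0.5
full_batch_grad = grad_mse(w, xs, ys)

accum_steps = 2
accumulated_grad = 0.0
for i in range(accum_steps):
    chunk = slice(2 * i, 2 * i + 2)
    accumulated_grad += grad_mse(w, xs[chunk], ys[chunk]) / accum_steps

print(full_batch_grad, accumulated_grad)
```

Without the division, the accumulated gradient would be `accum_steps` times larger, which behaves like an unintended increase in the learning rate.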

Mixed Precision Training

Mixed precision training uses 16-bit floats (FP16) for most computations while keeping critical operations in 32-bit (FP32). Benefits:

  • Up to ~2x faster training on GPUs with Tensor Cores
  • Roughly 50% less GPU memory for activations, since FP16 values take 2 bytes instead of 4
  • Allows larger batch sizes
🐍python
import torch
from torch.cuda.amp import GradScaler, autocast

# Initialize scaler
scaler = GradScaler()

# Training step with mixed precision
def train_step_amp(model, ddpm, images, optimizer):
    optimizer.zero_grad()

    # Forward pass in FP16
    with autocast():
        loss = ddpm.training_loss(images)

    # Backward pass with gradient scaling
    scaler.scale(loss).backward()

    # Unscale gradients for clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Optimizer step with automatic scaling
    scaler.step(optimizer)
    scaler.update()

    return loss.item()


# Sampling should use FP32 for quality
@torch.no_grad()
def sample_fp32(ddpm, batch_size, image_size, device):
    # Don't use autocast for sampling
    return ddpm.sample(batch_size, image_size, device=device)
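The reason for `scaler.scale(loss)` is that small gradients can fall below the smallest value FP16 can represent and silently become zero. The sketch below illustrates this with Python's `struct` module, whose `"e"` format rounds to IEEE 754 half precision. The gradient value and the scale factor of 65536 are illustrative assumptions, and `to_fp16` is a hypothetical helper.

```python
import struct

def to_fp16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]

tiny_grad = 1e-8      # smaller than FP16's smallest positive value (about 6e-8)
scale = 65536.0       # illustrative loss scale (an assumption for this sketch)

unscaled = to_fp16(tiny_grad)            # underflows to zero
scaled = to_fp16(tiny_grad * scale)      # representable once scaled up
recovered = scaled / scale               # unscale in higher precision

print(f"FP16 without scaling:        {unscaled}")
print(f"FP16 with scaling, unscaled: {recovered:.3e}")
```

Multiplying the loss by the scale multiplies every gradient by the same factor, and unscaling before clipping and the optimizer step restores the true magnitudes.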

Sampling in FP32

While training benefits from FP16, sampling should use FP32. The iterative nature of sampling can accumulate small numerical errors in FP16, leading to degraded image quality. Don't wrap sampling code in autocast.
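A toy example shows how repeated small updates can be lost in FP16. Near 1.0, adjacent FP16 values are about 0.001 apart, so adding 0.0001 and rounding returns the same number every time. The sketch below is a simplified stand-in for an iterative sampler, not a model of any specific one; `to_fp16` and `accumulate` are hypothetical helpers built on the `struct` module's half-precision format.

```python
import struct

def to_fp16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]

def accumulate(start, increment, steps, half_precision):
    """Add `increment` to `start` repeatedly, optionally rounding to FP16."""
    total = start
    for _ in range(steps):
        total = total + increment
        if half_precision:
            total = to_fp16(total)   # round after every step, as FP16 storage would
    return total

full = accumulate(1.0, 1e-4, 1000, half_precision=False)
half = accumulate(1.0, 1e-4, 1000, half_precision=True)

print(f"full precision: {full:.4f}")
print(f"half precision: {half:.4f}")
```

In full precision the total reaches about 1.1, while the FP16 version never leaves 1.0. Real samplers apply larger updates than this, but the same rounding affects the small corrections made over hundreds of steps.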

Checkpointing and Resumption

Diffusion training is slow (often days or weeks). Proper checkpointing is essential:

🐍python
import os
import torch
from torchvision.utils import save_image

def save_checkpoint(
    trainer: 'DiffusionTrainer',
    epoch: int,
    checkpoint_dir: str,
):
    """Save a complete training checkpoint."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        # Model weights
        'model_state_dict': trainer.model.state_dict(),

        # Optimizer state (for momentum, Adam states)
        'optimizer_state_dict': trainer.optimizer.state_dict(),

        # EMA state
        'ema_state_dict': trainer.ema.state_dict(),

        # AMP scaler state
        'scaler_state_dict': trainer.scaler.state_dict() if trainer.scaler else None,

        # Training progress (epoch is the last completed epoch)
        'epoch': epoch,
        'global_step': trainer.global_step,
    }

    path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt')
    torch.save(checkpoint, path)

    # Also save latest
    latest_path = os.path.join(checkpoint_dir, 'checkpoint_latest.pt')
    torch.save(checkpoint, latest_path)

    print(f"Saved checkpoint to {path}")


def load_checkpoint(
    trainer: 'DiffusionTrainer',
    checkpoint_path: str,
) -> int:
    """Load a checkpoint and return the epoch to resume from."""
    checkpoint = torch.load(checkpoint_path, map_location=trainer.device)

    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    trainer.ema.load_state_dict(checkpoint['ema_state_dict'])

    if trainer.scaler and checkpoint['scaler_state_dict']:
        trainer.scaler.load_state_dict(checkpoint['scaler_state_dict'])

    trainer.global_step = checkpoint['global_step']

    print(f"Resumed from epoch {checkpoint['epoch']}, step {checkpoint['global_step']}")

    # The saved epoch already finished, so training continues with the next one
    return checkpoint['epoch'] + 1


# Usage in training loop
def train(trainer, num_epochs, checkpoint_dir, resume_from=None):
    start_epoch = 0

    if resume_from:
        start_epoch = load_checkpoint(trainer, resume_from)

    for epoch in range(start_epoch, num_epochs):
        metrics = trainer.train_epoch()
        print(f"Epoch {epoch}: loss={metrics['loss']:.4f}")

        # Save checkpoint every epoch
        save_checkpoint(trainer, epoch, checkpoint_dir)

        # Generate samples periodically
        if epoch % 10 == 0:
            samples = trainer.sample(16)  # already scaled to [0, 1]
            save_image(samples, f"{checkpoint_dir}/samples_epoch{epoch}.png", nrow=4)

Checkpoint Strategy

Save checkpoints at least every epoch. For long epochs, consider saving every N steps as well. Always keep the last few checkpoints in case a recent one is corrupted. Many teams keep checkpoints at regular intervals (e.g., every 10K steps) for evaluation.
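One way to implement this retention policy is a small cleanup function that runs after each save. The sketch below is an assumption-laden example rather than part of the trainer: `prune_checkpoints`, `keep_last`, and `keep_every` are hypothetical names. It relies only on the `checkpoint_epoch{N}.pt` naming used by `save_checkpoint` and leaves other files such as `checkpoint_latest.pt` untouched.

```python
import os
import re
import tempfile

def prune_checkpoints(checkpoint_dir, keep_last=3, keep_every=10):
    """Delete old epoch checkpoints, keeping recent ones and periodic milestones."""
    pattern = re.compile(r"checkpoint_epoch(\d+)\.pt$")
    epochs = sorted(
        int(m.group(1))
        for name in os.listdir(checkpoint_dir)
        if (m := pattern.match(name))
    )
    recent = set(epochs[-keep_last:]) if keep_last > 0 else set()
    removed = []
    for epoch in epochs:
        if epoch in recent or epoch % keep_every == 0:
            continue  # keep recent checkpoints and every keep_every-th epoch
        os.remove(os.path.join(checkpoint_dir, f"checkpoint_epoch{epoch}.pt"))
        removed.append(epoch)
    return removed

# Demonstration with empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in range(25):
        open(os.path.join(tmp, f"checkpoint_epoch{epoch}.pt"), "w").close()
    prune_checkpoints(tmp)
    print(sorted(os.listdir(tmp)))
```

Keeping several recent checkpoints guards against a corrupted final write, and the periodic milestones remain available for later evaluation.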

Training Monitoring

Monitor these metrics during training:

| Metric | What to Watch | Warning Signs |
| --- | --- | --- |
| Loss | Should decrease steadily | Sudden spikes, NaN, stuck at a high value |
| Gradient norm | Should be stable; at most the clip threshold (1.0 here) after clipping | Very large values before clipping |
| Learning rate | Warmup, then constant or decay | N/A |
| EMA update | Should match the decay schedule | EMA diverging from training weights |
| Sample quality | Visual inspection every N steps | Blurry images, artifacts, mode collapse |
🐍python
import torch
import wandb
from torch.cuda.amp import autocast

def train_with_logging(trainer, num_epochs, project_name="diffusion"):
    # Assumes the trainer was created with use_amp=True (trainer.scaler is set)
    # Initialize logging
    wandb.init(project=project_name, config={
        'lr': trainer.lr,
        'batch_size': trainer.dataloader.batch_size,
        'warmup_steps': trainer.warmup_steps,
        'ema_decay': trainer.ema.decay,
    })

    for epoch in range(num_epochs):
        trainer.model.train()

        for batch_idx, images in enumerate(trainer.dataloader):
            images = images.to(trainer.device)

            # Apply the warmup schedule before stepping
            for param_group in trainer.optimizer.param_groups:
                param_group['lr'] = trainer.get_lr()

            # Training step
            with autocast():
                loss = trainer.ddpm.training_loss(images)

            trainer.scaler.scale(loss).backward()
            trainer.scaler.unscale_(trainer.optimizer)

            # clip_grad_norm_ returns the total norm measured BEFORE clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(
                trainer.model.parameters(), trainer.grad_clip
            )

            trainer.scaler.step(trainer.optimizer)
            trainer.scaler.update()
            trainer.optimizer.zero_grad()
            trainer.ema.update()
            trainer.global_step += 1

            # Log metrics
            wandb.log({
                'train/loss': loss.item(),
                'train/grad_norm': grad_norm.item(),
                'train/lr': trainer.get_lr(),
                'train/step': trainer.global_step,
            })

        # Log samples periodically
        if epoch % 10 == 0:
            samples = trainer.sample(16, use_ema=True)
            wandb.log({
                'samples': [wandb.Image(s) for s in samples],
                'epoch': epoch,
            })
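When reading the `train/grad_norm` curve, keep in mind that `torch.nn.utils.clip_grad_norm_` returns the total norm computed before clipping, so logged values above `grad_clip` are expected. The plain-Python sketch below imitates the rescaling on a flat list of gradient values; `clip_grad_norm` here is a hypothetical stand-in written for illustration, not the PyTorch function.

```python
import math

def clip_grad_norm(grads, max_norm):
    """Scale a flat list of gradients so their L2 norm is at most max_norm.

    Returns (clipped_grads, norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm

clipped, norm_before = clip_grad_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(f"norm before: {norm_before:.3f}, norm after: {norm_after:.3f}")

# Gradients already inside the threshold are left unchanged.
small, _ = clip_grad_norm([0.3, 0.4], max_norm=1.0)
print(small)
```

Clipping preserves the direction of the gradient and only shrinks its length, which is why a persistently large pre-clip norm is a warning sign even when training appears stable.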

Debugging Common Issues

Here are common training issues and how to fix them:

NaN Loss

🐍python
import torch

# Problem: Loss becomes NaN
# Causes:
# 1. Learning rate too high
# 2. Gradient explosion
# 3. Division by zero in schedule
# 4. Bad initialization

# Solution: Add debugging
def safe_training_step(model, ddpm, images, optimizer):
    loss = ddpm.training_loss(images)

    # Check for NaN
    if torch.isnan(loss):
        print("NaN loss detected!")
        print(f"  Image stats: min={images.min()}, max={images.max()}")

        # Check model outputs on a single image at a random timestep
        with torch.no_grad():
            t = torch.randint(0, 1000, (1,), device=images.device)
            x_t, _ = ddpm.q_sample(images[:1], t)
            pred = model(x_t, t)
            print(f"  Pred stats: min={pred.min()}, max={pred.max()}")

        raise ValueError("NaN loss - check inputs and model")

    return loss

Loss Not Decreasing

  • Check that images are scaled to [-1, 1]
  • Verify the noise schedule is correct
  • Ensure timesteps are sampled uniformly
  • Check that the model output shape matches input shape
  • Try a lower learning rate
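The first two checks in this list can be automated with a few assertions run before training starts. The sketch below works on plain Python lists so that it stays independent of any particular DDPM class. `validate_schedule`, `validate_pixel_range`, and the `final_max` threshold of 0.01 are assumptions for illustration.

```python
def validate_schedule(alpha_bar, final_max=0.01):
    """Return a list of human-readable problems found in a cumulative schedule."""
    problems = []
    if not all(0.0 < a < 1.0 for a in alpha_bar):
        problems.append("alpha_bar must lie strictly between 0 and 1")
    if any(later >= earlier for earlier, later in zip(alpha_bar, alpha_bar[1:])):
        problems.append("alpha_bar must decrease as t grows")
    if alpha_bar and alpha_bar[-1] > final_max:
        problems.append("alpha_bar at the last step is too large: "
                        "x_T is not close to pure noise")
    return problems

def validate_pixel_range(pixels, low=-1.0, high=1.0):
    """Return a warning string if pixels do not look like [-1, 1] data."""
    lo, hi = min(pixels), max(pixels)
    if lo < low or hi > high:
        return f"pixels outside [{low}, {high}]: min={lo}, max={hi}"
    if lo >= 0.0:
        return "no negative pixels: data may still be in [0, 1]"
    return None

print(validate_schedule([0.9, 0.5, 0.1, 0.001]))   # healthy schedule
print(validate_schedule([0.9, 0.95, 0.5]))         # not decreasing, ends too high
print(validate_pixel_range([0.0, 0.3, 1.0]))       # looks like unnormalized data
```

With real tensors, the same checks can be applied to `alpha_bar.tolist()` and to the flattened values of one batch.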

Poor Sample Quality

🐍python
import torch
import torch.nn.functional as F

# Debugging poor samples (assumes an existing, trained `trainer`)

# 1. Check if using EMA (major impact!)
samples_no_ema = trainer.sample(16, use_ema=False)
samples_ema = trainer.sample(16, use_ema=True)
# EMA samples should typically be noticeably better

# 2. Check intermediate timesteps
@torch.no_grad()
def debug_sampling(ddpm, device):
    """Visualize denoising at different stages."""
    x = torch.randn(1, 3, 64, 64, device=device)

    checkpoints = [999, 750, 500, 250, 100, 50, 0]
    images = []

    for t in range(999, -1, -1):
        t_batch = torch.tensor([t], device=device)
        x = ddpm.p_sample(x, t_batch)

        if t in checkpoints:
            images.append(x.clone())

    return images  # Should show gradual denoising

# 3. Check if model is predicting noise correctly
@torch.no_grad()
def check_noise_prediction(ddpm, images, device):
    """Verify noise prediction at various timesteps."""
    images = images.to(device)

    for t_val in [0, 100, 500, 900]:
        t = torch.tensor([t_val], device=device)
        x_t, true_noise = ddpm.q_sample(images[:1], t)

        pred_noise = ddpm.model(x_t, t)
        error = F.mse_loss(pred_noise, true_noise)

        print(f"t={t_val}: MSE={error.item():.4f}")
    # Expect the MSE to be highest near t=0, where x_t contains little noise,
    # and to fall as t grows. Values near 1.0 at every t (what an all-zero
    # prediction would score) suggest the model has not learned.

Always Use EMA for Sampling

One of the most common mistakes is forgetting to use EMA weights for sampling. The difference is often dramatic: EMA samples are typically sharper and more coherent. If your samples look bad, first check if you're using EMA.

Summary

In this section, we built a production-ready training pipeline:

  1. Dataset preparation: Proper loading and scaling to [-1, 1]
  2. EMA: Essential for high-quality samples, with decay rate ~0.9999
  3. Complete training loop: Gradient accumulation, clipping, warmup, logging
  4. Mixed precision: up to ~2x speedup with AMP, but sample in FP32
  5. Checkpointing: Save all state for seamless resumption
  6. Debugging: Strategies for NaN loss, poor convergence, bad samples

Coming Up Next

In the next section, we'll explore the loss function in depth: the mathematical derivation of the simplified loss, alternative parameterizations (v-prediction, x0-prediction), and loss weighting strategies that can improve sample quality.

With this training infrastructure, you're ready to train diffusion models on your own datasets. The key is patience - diffusion models need many steps to converge, but the results are worth the wait.