Chapter 11
20 min read
Section 52 of 76

Complete Training Script

Training the Model

Learning Objectives

By the end of this section, you will be able to:

  1. Build a complete, production-ready training script with proper argument parsing and configuration
  2. Integrate logging frameworks (Weights & Biases, TensorBoard) for experiment tracking
  3. Implement robust checkpointing for training resumption and model saving
  4. Run training end-to-end on standard datasets like CIFAR-10

Script Overview

A complete training script needs to handle several concerns: argument parsing, model initialization, data loading, the training loop itself, logging, checkpointing, and graceful error handling. We'll build each component and then combine them into a single, runnable script.

Design Principle: Your training script should be reproducible. Given the same configuration and random seed, it should produce identical results. This requires careful handling of randomness and configuration.

Argument Parsing

🐍python
1import argparse
2
def parse_args(argv=None):
    """Parse command-line arguments for training.

    Args:
        argv: Optional list of argument strings. Defaults to None, in
            which case argparse reads ``sys.argv[1:]`` as usual. Passing
            an explicit list makes this function unit-testable.

    Returns:
        argparse.Namespace holding the full training configuration.
    """
    parser = argparse.ArgumentParser(
        description="Train a diffusion model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Data arguments
    data_group = parser.add_argument_group("Data")
    data_group.add_argument(
        "--dataset", type=str, default="cifar10",
        choices=["mnist", "cifar10", "celeba", "custom"],
        help="Dataset to train on",
    )
    data_group.add_argument(
        "--data_dir", type=str, default="./data",
        help="Directory for dataset storage",
    )
    data_group.add_argument(
        "--image_size", type=int, default=32,
        help="Image resolution (square)",
    )

    # Model arguments
    model_group = parser.add_argument_group("Model")
    model_group.add_argument(
        "--base_channels", type=int, default=64,
        help="Base channel count for U-Net",
    )
    model_group.add_argument(
        "--channel_mult", type=int, nargs="+", default=[1, 2, 4],
        help="Channel multipliers for each resolution",
    )
    model_group.add_argument(
        "--attention_resolutions", type=int, nargs="+", default=[16, 8],
        help="Resolutions where attention is applied",
    )
    model_group.add_argument(
        "--dropout", type=float, default=0.0,
        help="Dropout probability",
    )

    # Diffusion arguments
    diff_group = parser.add_argument_group("Diffusion")
    diff_group.add_argument(
        "--timesteps", type=int, default=1000,
        help="Number of diffusion timesteps",
    )
    diff_group.add_argument(
        "--schedule", type=str, default="cosine",
        choices=["linear", "cosine", "sqrt"],
        help="Noise schedule type",
    )

    # Training arguments
    train_group = parser.add_argument_group("Training")
    train_group.add_argument("--batch_size", type=int, default=64)
    train_group.add_argument("--epochs", type=int, default=100)
    train_group.add_argument("--lr", type=float, default=2e-4)
    train_group.add_argument("--warmup_steps", type=int, default=1000)
    train_group.add_argument("--grad_accumulation", type=int, default=1)
    train_group.add_argument("--use_amp", action="store_true")
    train_group.add_argument("--use_ema", action="store_true")
    train_group.add_argument("--ema_decay", type=float, default=0.9999)

    # Logging arguments
    log_group = parser.add_argument_group("Logging")
    log_group.add_argument("--log_dir", type=str, default="./logs")
    log_group.add_argument("--exp_name", type=str, default="diffusion")
    log_group.add_argument("--wandb", action="store_true")
    log_group.add_argument("--sample_every", type=int, default=5)
    log_group.add_argument("--save_every", type=int, default=10)

    # Other
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--resume", type=str, default=None)

    return parser.parse_args(argv)

Model and Diffusion Setup

🐍python
1import torch
2import torch.nn as nn
3from torch.utils.data import DataLoader
4from torchvision import datasets, transforms
5import numpy as np
6
def set_seed(seed: int) -> None:
    """Set random seeds for reproducibility.

    Seeds PyTorch (CPU and all CUDA devices), NumPy, and Python's
    built-in ``random`` module. The original version omitted the
    ``random`` module, so any stdlib randomness (e.g. in data shuffles
    done outside torch) was not reproducible.

    Args:
        seed: Seed value applied to every RNG.
    """
    import random  # local import keeps this snippet self-contained

    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op on CPU-only machines
    np.random.seed(seed)
12
13
def create_dataset(args):
    """Build the training dataset selected by ``args.dataset``.

    Downloads the data into ``args.data_dir`` if it is not already there.

    Returns:
        Tuple of (dataset, in_channels).

    Raises:
        ValueError: for dataset names without a loader in this snippet
            (``celeba`` and ``custom`` are accepted by the CLI but not
            implemented here).
    """
    if args.dataset == "cifar10":
        pipeline = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Map pixel values from [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_set = datasets.CIFAR10(
            root=args.data_dir, train=True,
            download=True, transform=pipeline,
        )
        return train_set, 3

    if args.dataset == "mnist":
        pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_set = datasets.MNIST(
            root=args.data_dir, train=True,
            download=True, transform=pipeline,
        )
        return train_set, 1

    raise ValueError(f"Unknown dataset: {args.dataset}")
40
41
class EMA:
    """Exponential Moving Average of model parameters.

    Maintains a ``shadow`` copy of every trainable parameter, blended
    toward the live weights on each ``update()``. ``apply_shadow()`` /
    ``restore()`` temporarily swap the averaged weights into the model
    (e.g. for sampling) and then put the live weights back.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.model = model
        self.decay = decay
        # shadow: name -> averaged tensor (initialized to current weights)
        self.shadow = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        # backup: live tensors stashed while shadow weights are applied
        self.backup = {}

    @torch.no_grad()
    def update(self):
        """Blend each shadow tensor toward the current live weights."""
        d = self.decay
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            self.shadow[name] = d * self.shadow[name] + (1 - d) * param.data

    def apply_shadow(self):
        """Swap averaged weights into the model, saving the live ones."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            self.backup[name] = param.data
            param.data = self.shadow[name]

    def restore(self):
        """Undo ``apply_shadow``: put the live weights back."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}
74
75
class GaussianDiffusion:
    """Gaussian diffusion process.

    Precomputes the beta/alpha schedules used by the forward noising
    process (``add_noise``) and provides the standard epsilon-prediction
    MSE training loss.
    """

    def __init__(self, timesteps=1000, schedule="cosine", device="cuda"):
        self.timesteps = timesteps
        self.device = device

        if schedule == "linear":
            betas = torch.linspace(1e-4, 0.02, timesteps)
        elif schedule == "cosine":
            # Cosine schedule: derive betas from the desired cumulative
            # alpha curve, then clamp to avoid degenerate steps.
            s = 0.008
            grid = torch.linspace(0, timesteps, timesteps + 1)
            curve = torch.cos((grid / timesteps + s) / (1 + s) * np.pi / 2) ** 2
            curve = curve / curve[0]
            betas = 1 - curve[1:] / curve[:-1]
            betas = torch.clamp(betas, 0.0001, 0.9999)
        else:
            # "sqrt": linear spacing in sqrt(beta) space.
            betas = torch.linspace(1e-4**0.5, 0.02**0.5, timesteps) ** 2

        # Cache every derived quantity used during training/sampling.
        self.betas = betas.to(device)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

    def add_noise(self, x0, noise, t):
        """Sample q(x_t | x_0): mix clean images with noise at step ``t``."""
        a = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        b = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return a * x0 + b * noise

    def loss(self, model, x0, t=None):
        """Epsilon-prediction MSE loss on a batch of clean images.

        Args:
            model: Callable taking (noisy_images, timesteps).
            x0: Clean image batch.
            t: Optional timestep batch; sampled uniformly when None.
        """
        if t is None:
            t = torch.randint(0, self.timesteps, (x0.shape[0],), device=self.device)
        eps = torch.randn_like(x0)
        x_t = self.add_noise(x0, eps, t)
        return nn.functional.mse_loss(model(x_t, t), eps)

Training Loop

🐍python
1from torch.cuda.amp import autocast, GradScaler
2from tqdm import tqdm
3
def train_epoch(model, dataloader, optimizer, diffusion, scaler, args, ema=None):
    """Train for one epoch.

    Supports mixed precision (``args.use_amp``) and gradient accumulation
    (``args.grad_accumulation`` micro-batches per optimizer step).

    Fixes over the naive version:
      * gradients are cleared at the start of the epoch, so a leftover
        partial-accumulation window from a previous epoch cannot leak in;
      * a trailing partial window is flushed at the end instead of being
        silently dropped when the batch count isn't a multiple of
        ``grad_accumulation``.

    Returns:
        Mean (un-scaled) loss over the epoch, or 0.0 for an empty loader.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Start from clean gradients.
    optimizer.zero_grad()

    pbar = tqdm(dataloader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        # Datasets may yield (image, label) tuples; only images are used.
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        images = images.to(args.device)

        with autocast(enabled=args.use_amp):
            loss = diffusion.loss(model, images)
            # Scale down so accumulated gradients average over the window.
            loss = loss / args.grad_accumulation

        scaler.scale(loss).backward()

        if (batch_idx + 1) % args.grad_accumulation == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if ema is not None:
                ema.update()

        # Undo the accumulation scaling for reporting.
        total_loss += loss.item() * args.grad_accumulation
        num_batches += 1
        pbar.set_postfix({"loss": total_loss / num_batches})

    # Flush a trailing partial accumulation window so its gradients
    # contribute instead of being discarded.
    if num_batches % args.grad_accumulation != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        if ema is not None:
            ema.update()

    return total_loss / max(num_batches, 1)
34
35
@torch.no_grad()
def generate_samples(model, diffusion, args, ema=None, num_samples=16):
    """Generate sample images with the DDPM ancestral sampler.

    If ``ema`` is given, sampling runs with the averaged weights. The
    weight swap is wrapped in try/finally so the live weights are always
    restored and the model returned to train mode, even if the model
    call raises mid-sampling (the original left the model corrupted in
    that case).

    Returns:
        Tensor of shape (num_samples, C, H, W), clamped to [-1, 1].
    """
    model.eval()
    if ema is not None:
        ema.apply_shadow()

    try:
        channels = 1 if args.dataset == "mnist" else 3
        shape = (num_samples, channels, args.image_size, args.image_size)
        x = torch.randn(shape, device=args.device)

        # Reverse diffusion: t = T-1, ..., 0.
        for t in reversed(range(diffusion.timesteps)):
            t_batch = torch.full((shape[0],), t, device=args.device, dtype=torch.long)

            with autocast(enabled=args.use_amp):
                pred_noise = model(x, t_batch)

            alpha = diffusion.alphas[t]
            alpha_bar = diffusion.alphas_cumprod[t]
            beta = diffusion.betas[t]

            # No fresh noise on the final (t == 0) step.
            noise = torch.randn_like(x) if t > 0 else 0
            x = (1 / torch.sqrt(alpha)) * (
                x - beta / torch.sqrt(1 - alpha_bar) * pred_noise
            ) + torch.sqrt(beta) * noise
    finally:
        # Always undo the EMA swap and leave the model in train mode.
        if ema is not None:
            ema.restore()
        model.train()

    return x.clamp(-1, 1)

Logging and Metrics

🐍python
1import wandb
2from torch.utils.tensorboard import SummaryWriter
3from pathlib import Path
4import torchvision.utils as vutils
5
class Logger:
    """Unified logger for W&B and TensorBoard.

    Always writes TensorBoard events under ``<log_dir>/<exp_name>``;
    additionally mirrors metrics and images to Weights & Biases when
    ``args.wandb`` is set. Sample grids are also saved as PNG files.
    """

    def __init__(self, args):
        self.args = args
        self.step = 0
        self.log_dir = Path(args.log_dir) / args.exp_name
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.use_wandb = bool(args.wandb)
        if self.use_wandb:
            wandb.init(
                project="diffusion",
                name=args.exp_name,
                config=vars(args),
            )

        self.writer = SummaryWriter(log_dir=str(self.log_dir / "tensorboard"))

    def log(self, metrics: dict):
        """Record scalar metrics; each call advances the global step."""
        self.step += 1
        for name, value in metrics.items():
            self.writer.add_scalar(name, value, self.step)
        if self.use_wandb:
            wandb.log(metrics, step=self.step)

    def log_images(self, name: str, images):
        """Log a 4-column image grid (images expected in [-1, 1])."""
        grid = vutils.make_grid(images, nrow=4, normalize=True, value_range=(-1, 1))
        self.writer.add_image(name, grid, self.step)
        if self.use_wandb:
            wandb.log({name: wandb.Image(grid)}, step=self.step)

        # Also keep a PNG copy on disk for quick inspection.
        save_path = self.log_dir / "samples" / f"{name}_{self.step:06d}.png"
        save_path.parent.mkdir(parents=True, exist_ok=True)
        vutils.save_image(grid, str(save_path))

    def finish(self):
        """Flush and close all logging backends."""
        self.writer.close()
        if self.use_wandb:
            wandb.finish()

Checkpointing

🐍python
1from pathlib import Path
2
def save_checkpoint(epoch, model, optimizer, scheduler, ema, scaler, args):
    """Save training checkpoint.

    Writes an epoch-numbered file plus ``latest.pt`` under
    ``<log_dir>/<exp_name>/checkpoints``. Optional components
    (scheduler, AMP scaler, EMA shadow weights) are included only
    when present.
    """
    ckpt_dir = Path(args.log_dir) / args.exp_name / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "args": vars(args),
    }
    if scheduler is not None:
        state["scheduler_state_dict"] = scheduler.state_dict()
    if scaler is not None:
        state["scaler_state_dict"] = scaler.state_dict()
    if ema is not None:
        state["ema_shadow"] = ema.shadow

    path = ckpt_dir / f"checkpoint_{epoch:04d}.pt"
    torch.save(state, path)
    # Duplicate as "latest" so --resume never needs an epoch number.
    torch.save(state, ckpt_dir / "latest.pt")
    print(f"Saved checkpoint to {path}")
26
27
def load_checkpoint(path, model, optimizer, scheduler, ema, scaler):
    """Load a training checkpoint and restore all available state.

    Model and optimizer state are always restored; scheduler, AMP
    scaler, and EMA shadow weights are restored only when both the
    object and its saved state exist.

    Returns:
        The epoch to resume from (saved epoch + 1).
    """
    ckpt = torch.load(path, map_location="cpu")

    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    # Optional components: restore each one only if present on both sides.
    for obj, key in ((scheduler, "scheduler_state_dict"),
                     (scaler, "scaler_state_dict")):
        if obj and key in ckpt:
            obj.load_state_dict(ckpt[key])

    if ema and "ema_shadow" in ckpt:
        ema.shadow = ckpt["ema_shadow"]

    print(f"Resumed from epoch {ckpt['epoch']}")
    return ckpt["epoch"] + 1

Complete Training Script

Here's how all components come together in a complete script:

🐍python
1#!/usr/bin/env python3
2"""
3Complete diffusion model training script.
4
5Usage:
6    python train.py --dataset cifar10 --epochs 100 --use_amp --use_ema --wandb
7"""
8
9import torch
10import numpy as np
11from torch.cuda.amp import GradScaler
12from torch.utils.data import DataLoader
13
14# Import all the components defined above
15# from model import UNet
16# from utils import parse_args, create_dataset, GaussianDiffusion, EMA, Logger
17
18
def main():
    """Entry point: parse config, build every component, run training."""
    args = parse_args()

    # Setup — pick the device and seed all RNGs before building anything,
    # so weight init and data shuffling are reproducible.
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    set_seed(args.seed)
    print(f"Training on {args.device}")

    # Data
    dataset, in_channels = create_dataset(args)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,   # speeds up host->GPU transfers
        drop_last=True,    # keeps batch shapes constant across steps
    )

    # Model (replace with your UNet)
    model = UNet(
        in_channels=in_channels,
        base_channels=args.base_channels,
        channel_multipliers=tuple(args.channel_mult),
        attention_resolutions=tuple(args.attention_resolutions),
    ).to(args.device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Diffusion
    diffusion = GaussianDiffusion(
        timesteps=args.timesteps,
        schedule=args.schedule,
        device=args.device,
    )

    # Optimizer
    # NOTE(review): --warmup_steps is parsed but no LR scheduler is built
    # here (None is passed to checkpointing below) — confirm whether
    # warmup was meant to be wired in.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)

    # EMA — averaged weights typically sample better than the live ones.
    ema = EMA(model, args.ema_decay) if args.use_ema else None

    # AMP scaler (acts as a pass-through when --use_amp is off).
    scaler = GradScaler(enabled=args.use_amp)

    # Logger
    logger = Logger(args)

    # Resume if specified — restores model/optimizer/EMA/scaler state and
    # returns the epoch to continue from.
    start_epoch = 0
    if args.resume:
        start_epoch = load_checkpoint(args.resume, model, optimizer, None, ema, scaler)

    # Training loop
    print(f"Starting training for {args.epochs} epochs...")

    for epoch in range(start_epoch, args.epochs):
        loss = train_epoch(model, dataloader, optimizer, diffusion, scaler, args, ema)

        logger.log({"epoch": epoch, "loss": loss})
        print(f"Epoch {epoch + 1}/{args.epochs}: Loss = {loss:.4f}")

        # Periodic sampling: quick visual check on generation quality.
        if (epoch + 1) % args.sample_every == 0:
            samples = generate_samples(model, diffusion, args, ema)
            logger.log_images("samples", samples)

        # Periodic checkpoint: enables resumption after interruption.
        if (epoch + 1) % args.save_every == 0:
            save_checkpoint(epoch, model, optimizer, None, ema, scaler, args)

    # Final save
    save_checkpoint(args.epochs - 1, model, optimizer, None, ema, scaler, args)
    logger.finish()
    print("Training complete!")


if __name__ == "__main__":
    main()

Running the Script

Save all components and run with:
bash
1python train.py --dataset cifar10 --epochs 100 --use_amp --use_ema --wandb

Key Takeaways

  1. Use argparse: Make your script configurable from the command line with sensible defaults.
  2. Implement EMA: Exponential moving average typically produces better generation quality than final weights.
  3. Log everything: Use W&B or TensorBoard to track loss curves and generated samples.
  4. Checkpoint regularly: Save model, optimizer, and EMA weights so you can resume after interruptions.
  5. Set seeds: For reproducibility, set random seeds for PyTorch, NumPy, and Python.
Looking Ahead: In the next section, we'll dive deep into training monitoring - tracking FID during training, visualizing samples, and building monitoring dashboards.