Learning Objectives
By the end of this section, you will be able to:
- Build a complete, production-ready training script with proper argument parsing and configuration
- Integrate logging frameworks (Weights & Biases, TensorBoard) for experiment tracking
- Implement robust checkpointing for training resumption and model saving
- Run training end-to-end on standard datasets like CIFAR-10
Script Overview
A complete training script needs to handle several concerns: argument parsing, model initialization, data loading, the training loop itself, logging, checkpointing, and graceful error handling. We'll build each component and then combine them into a single, runnable script.
Design Principle: Your training script should be reproducible. Given the same configuration and random seed, it should produce identical results. This requires careful handling of randomness and configuration.
Argument Parsing
🐍python
1import argparse
2
3def parse_args():
4 """Parse command-line arguments for training."""
5 parser = argparse.ArgumentParser(
6 description="Train a diffusion model",
7 formatter_class=argparse.ArgumentDefaultsHelpFormatter,
8 )
9
10 # Data arguments
11 data_group = parser.add_argument_group("Data")
12 data_group.add_argument(
13 "--dataset", type=str, default="cifar10",
14 choices=["mnist", "cifar10", "celeba", "custom"],
15 help="Dataset to train on",
16 )
17 data_group.add_argument(
18 "--data_dir", type=str, default="./data",
19 help="Directory for dataset storage",
20 )
21 data_group.add_argument(
22 "--image_size", type=int, default=32,
23 help="Image resolution (square)",
24 )
25
26 # Model arguments
27 model_group = parser.add_argument_group("Model")
28 model_group.add_argument(
29 "--base_channels", type=int, default=64,
30 help="Base channel count for U-Net",
31 )
32 model_group.add_argument(
33 "--channel_mult", type=int, nargs="+", default=[1, 2, 4],
34 help="Channel multipliers for each resolution",
35 )
36 model_group.add_argument(
37 "--attention_resolutions", type=int, nargs="+", default=[16, 8],
38 help="Resolutions where attention is applied",
39 )
40 model_group.add_argument(
41 "--dropout", type=float, default=0.0,
42 help="Dropout probability",
43 )
44
45 # Diffusion arguments
46 diff_group = parser.add_argument_group("Diffusion")
47 diff_group.add_argument(
48 "--timesteps", type=int, default=1000,
49 help="Number of diffusion timesteps",
50 )
51 diff_group.add_argument(
52 "--schedule", type=str, default="cosine",
53 choices=["linear", "cosine", "sqrt"],
54 help="Noise schedule type",
55 )
56
57 # Training arguments
58 train_group = parser.add_argument_group("Training")
59 train_group.add_argument("--batch_size", type=int, default=64)
60 train_group.add_argument("--epochs", type=int, default=100)
61 train_group.add_argument("--lr", type=float, default=2e-4)
62 train_group.add_argument("--warmup_steps", type=int, default=1000)
63 train_group.add_argument("--grad_accumulation", type=int, default=1)
64 train_group.add_argument("--use_amp", action="store_true")
65 train_group.add_argument("--use_ema", action="store_true")
66 train_group.add_argument("--ema_decay", type=float, default=0.9999)
67
68 # Logging arguments
69 log_group = parser.add_argument_group("Logging")
70 log_group.add_argument("--log_dir", type=str, default="./logs")
71 log_group.add_argument("--exp_name", type=str, default="diffusion")
72 log_group.add_argument("--wandb", action="store_true")
73 log_group.add_argument("--sample_every", type=int, default=5)
74 log_group.add_argument("--save_every", type=int, default=10)
75
76 # Other
77 parser.add_argument("--seed", type=int, default=42)
78 parser.add_argument("--num_workers", type=int, default=4)
79 parser.add_argument("--resume", type=str, default=None)
80
81 return parser.parse_args()Model and Diffusion Setup
🐍python
1import torch
2import torch.nn as nn
3from torch.utils.data import DataLoader
4from torchvision import datasets, transforms
5import numpy as np
6
def set_seed(seed: int):
    """Seed all random number generators used during training.

    Seeds Python's built-in ``random`` module, NumPy, and PyTorch (CPU and
    all CUDA devices) so runs with the same configuration are reproducible.
    The original version skipped the ``random`` module, contradicting the
    section's reproducibility guidance.

    Args:
        seed: Seed value applied to every generator.
    """
    import random  # local import keeps this snippet self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without CUDA: this is a no-op on CPU-only builds.
    torch.cuda.manual_seed_all(seed)
12
13
def create_dataset(args):
    """Build the training dataset selected by ``args.dataset``.

    Args:
        args: Namespace providing ``dataset`` (name) and ``data_dir``
            (download/storage root).

    Returns:
        A ``(dataset, in_channels)`` pair: the torchvision dataset and the
        number of image channels it yields.

    Raises:
        ValueError: If ``args.dataset`` has no branch here (note that
            "celeba" and "custom" are accepted by the CLI but not
            implemented in this snippet).
    """
    name = args.dataset

    if name == "cifar10":
        pipeline = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Map [0, 1] pixels to [-1, 1], the range the diffusion uses.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        cifar = datasets.CIFAR10(
            root=args.data_dir, train=True,
            download=True, transform=pipeline,
        )
        return cifar, 3

    if name == "mnist":
        pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        mnist = datasets.MNIST(
            root=args.data_dir, train=True,
            download=True, transform=pipeline,
        )
        return mnist, 1

    raise ValueError(f"Unknown dataset: {args.dataset}")
40
41
class EMA:
    """Tracks an exponential moving average (EMA) of model weights.

    ``shadow`` holds the averaged copy of every trainable parameter, keyed
    by parameter name. ``apply_shadow``/``restore`` temporarily swap the
    averaged weights into the live model (e.g. for sampling) and back.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.model = model
        self.decay = decay
        # Initialize the average at the current weights.
        self.shadow = {
            name: p.data.clone()
            for name, p in model.named_parameters()
            if p.requires_grad
        }
        # Live weights stashed while the shadow copy is applied.
        self.backup = {}

    @torch.no_grad()
    def update(self):
        """Blend the current parameters into the running average."""
        d = self.decay
        for name, p in self.model.named_parameters():
            if p.requires_grad:
                # shadow <- d * shadow + (1 - d) * current
                self.shadow[name] = d * self.shadow[name] + (1 - d) * p.data

    def apply_shadow(self):
        """Swap EMA weights into the model, stashing the live weights."""
        for name, p in self.model.named_parameters():
            if p.requires_grad:
                self.backup[name] = p.data
                p.data = self.shadow[name]

    def restore(self):
        """Put the stashed live weights back and clear the stash."""
        for name, p in self.model.named_parameters():
            if p.requires_grad:
                p.data = self.backup[name]
        self.backup = {}
74
75
76class GaussianDiffusion:
77 """Gaussian diffusion process."""
78
79 def __init__(self, timesteps=1000, schedule="cosine", device="cuda"):
80 self.timesteps = timesteps
81 self.device = device
82
83 if schedule == "linear":
84 betas = torch.linspace(1e-4, 0.02, timesteps)
85 elif schedule == "cosine":
86 s = 0.008
87 steps = torch.linspace(0, timesteps, timesteps + 1)
88 alphas_cumprod = torch.cos((steps / timesteps + s) / (1 + s) * np.pi / 2) ** 2
89 alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
90 betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
91 betas = torch.clamp(betas, 0.0001, 0.9999)
92 else:
93 betas = torch.linspace(1e-4**0.5, 0.02**0.5, timesteps) ** 2
94
95 self.betas = betas.to(device)
96 self.alphas = (1 - self.betas)
97 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
98 self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
99 self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
100
101 def add_noise(self, x0, noise, t):
102 sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
103 sqrt_one_minus = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
104 return sqrt_alpha * x0 + sqrt_one_minus * noise
105
106 def loss(self, model, x0, t=None):
107 batch_size = x0.shape[0]
108 if t is None:
109 t = torch.randint(0, self.timesteps, (batch_size,), device=self.device)
110 noise = torch.randn_like(x0)
111 noisy_x = self.add_noise(x0, noise, t)
112 predicted = model(noisy_x, t)
113 return nn.functional.mse_loss(predicted, noise)Training Loop
🐍python
1from torch.cuda.amp import autocast, GradScaler
2from tqdm import tqdm
3
def train_epoch(model, dataloader, optimizer, diffusion, scaler, args, ema=None):
    """Train for one epoch and return the mean (unscaled) loss.

    Supports mixed precision (``args.use_amp``) and gradient accumulation
    (``args.grad_accumulation`` micro-batches per optimizer step).

    Args:
        model: Noise-prediction network.
        dataloader: Yields image batches or (image, label) tuples.
        optimizer: Optimizer stepping the model parameters.
        diffusion: Object exposing ``loss(model, images)``.
        scaler: ``GradScaler`` (may be constructed with enabled=False).
        args: Namespace with ``device``, ``use_amp``, ``grad_accumulation``.
        ema: Optional EMA tracker, updated after each optimizer step.

    Returns:
        Average per-batch loss over the epoch (float).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Start from a clean slate. Without this, gradients left over from an
    # epoch whose batch count was not divisible by grad_accumulation would
    # silently leak into the first optimizer step of this epoch.
    optimizer.zero_grad()

    pbar = tqdm(dataloader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        images = images.to(args.device)

        with autocast(enabled=args.use_amp):
            loss = diffusion.loss(model, images)
            # Scale down so accumulated gradients sum to the true mean.
            loss = loss / args.grad_accumulation

        scaler.scale(loss).backward()

        # Step only every grad_accumulation micro-batches.
        if (batch_idx + 1) % args.grad_accumulation == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if ema is not None:
                ema.update()

        # Undo the accumulation scaling for the reported metric.
        total_loss += loss.item() * args.grad_accumulation
        num_batches += 1
        pbar.set_postfix({"loss": total_loss / num_batches})

    return total_loss / num_batches
34
35
@torch.no_grad()
def generate_samples(model, diffusion, args, ema=None, num_samples=16):
    """Run the full reverse (ancestral DDPM) process and return images.

    Temporarily swaps in the EMA weights when ``ema`` is given, and leaves
    the model back in training mode afterwards. Output values are clamped
    to the data range [-1, 1].
    """
    model.eval()
    if ema is not None:
        ema.apply_shadow()

    num_channels = 1 if args.dataset == "mnist" else 3
    sample_shape = (num_samples, num_channels, args.image_size, args.image_size)
    # Start from pure Gaussian noise x_T.
    x = torch.randn(sample_shape, device=args.device)

    for step in range(diffusion.timesteps - 1, -1, -1):
        step_batch = torch.full((sample_shape[0],), step, device=args.device, dtype=torch.long)

        with autocast(enabled=args.use_amp):
            eps = model(x, step_batch)

        alpha = diffusion.alphas[step]
        alpha_bar = diffusion.alphas_cumprod[step]
        beta = diffusion.betas[step]

        # No noise is injected on the final step (t == 0).
        z = torch.randn_like(x) if step > 0 else 0
        x = (1 / torch.sqrt(alpha)) * (
            x - beta / torch.sqrt(1 - alpha_bar) * eps
        ) + torch.sqrt(beta) * z

    if ema is not None:
        ema.restore()

    model.train()
    return x.clamp(-1, 1)

Logging and Metrics
🐍python
1import wandb
2from torch.utils.tensorboard import SummaryWriter
3from pathlib import Path
4import torchvision.utils as vutils
5
class Logger:
    """Unified experiment logger: TensorBoard always, W&B optionally.

    Scalars and image grids go to TensorBoard under
    ``<log_dir>/<exp_name>/tensorboard``; sample grids are additionally
    written as PNGs under ``<log_dir>/<exp_name>/samples``.
    """

    def __init__(self, args):
        self.args = args
        self.step = 0  # monotonically increasing log step
        self.log_dir = Path(args.log_dir) / args.exp_name
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.use_wandb = bool(args.wandb)
        if self.use_wandb:
            wandb.init(
                project="diffusion",
                name=args.exp_name,
                config=vars(args),  # record the full run configuration
            )

        self.writer = SummaryWriter(log_dir=str(self.log_dir / "tensorboard"))

    def log(self, metrics: dict):
        """Record a dict of scalar metrics, advancing the internal step."""
        self.step += 1
        for key, value in metrics.items():
            self.writer.add_scalar(key, value, self.step)
        if self.use_wandb:
            wandb.log(metrics, step=self.step)

    def log_images(self, name: str, images):
        """Log a grid of images to all backends and save it as a PNG."""
        grid = vutils.make_grid(images, nrow=4, normalize=True, value_range=(-1, 1))
        self.writer.add_image(name, grid, self.step)
        if self.use_wandb:
            wandb.log({name: wandb.Image(grid)}, step=self.step)

        save_path = self.log_dir / "samples" / f"{name}_{self.step:06d}.png"
        save_path.parent.mkdir(parents=True, exist_ok=True)
        vutils.save_image(grid, str(save_path))

    def finish(self):
        """Flush and close all logging backends."""
        self.writer.close()
        if self.use_wandb:
            wandb.finish()

Checkpointing
🐍python
1from pathlib import Path
2
def save_checkpoint(epoch, model, optimizer, scheduler, ema, scaler, args):
    """Write a resumable training checkpoint.

    Saves to ``<log_dir>/<exp_name>/checkpoints/checkpoint_<epoch>.pt`` and
    mirrors the same payload to ``latest.pt`` for easy resumption. Optional
    components (scheduler, AMP scaler, EMA shadow weights) are included
    only when the caller passed them.
    """
    ckpt_dir = Path(args.log_dir) / args.exp_name / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "args": vars(args),  # keep the run config alongside the weights
    }
    if scheduler is not None:
        state["scheduler_state_dict"] = scheduler.state_dict()
    if scaler is not None:
        state["scaler_state_dict"] = scaler.state_dict()
    if ema is not None:
        state["ema_shadow"] = ema.shadow

    path = ckpt_dir / f"checkpoint_{epoch:04d}.pt"
    # Write the numbered checkpoint, then mirror it to latest.pt.
    for target in (path, ckpt_dir / "latest.pt"):
        torch.save(state, target)
    print(f"Saved checkpoint to {path}")
26
27
28def load_checkpoint(path, model, optimizer, scheduler, ema, scaler):
29 """Load checkpoint. Returns start epoch."""
30 ckpt = torch.load(path, map_location="cpu")
31
32 model.load_state_dict(ckpt["model_state_dict"])
33 optimizer.load_state_dict(ckpt["optimizer_state_dict"])
34
35 if scheduler and "scheduler_state_dict" in ckpt:
36 scheduler.load_state_dict(ckpt["scheduler_state_dict"])
37 if scaler and "scaler_state_dict" in ckpt:
38 scaler.load_state_dict(ckpt["scaler_state_dict"])
39 if ema and "ema_shadow" in ckpt:
40 ema.shadow = ckpt["ema_shadow"]
41
42 print(f"Resumed from epoch {ckpt['epoch']}")
43 return ckpt["epoch"] + 1Complete Training Script
Here's how all components come together in a complete script:
🐍python
1#!/usr/bin/env python3
2"""
3Complete diffusion model training script.
4
5Usage:
6 python train.py --dataset cifar10 --epochs 100 --use_amp --use_ema --wandb
7"""
8
9import torch
10import numpy as np
11from torch.cuda.amp import GradScaler
12from torch.utils.data import DataLoader
13
14# Import all the components defined above
15# from model import UNet
16# from utils import parse_args, create_dataset, GaussianDiffusion, EMA, Logger
17
18
19def main():
20 args = parse_args()
21
22 # Setup
23 args.device = "cuda" if torch.cuda.is_available() else "cpu"
24 set_seed(args.seed)
25 print(f"Training on {args.device}")
26
27 # Data
28 dataset, in_channels = create_dataset(args)
29 dataloader = DataLoader(
30 dataset,
31 batch_size=args.batch_size,
32 shuffle=True,
33 num_workers=args.num_workers,
34 pin_memory=True,
35 drop_last=True,
36 )
37
38 # Model (replace with your UNet)
39 model = UNet(
40 in_channels=in_channels,
41 base_channels=args.base_channels,
42 channel_multipliers=tuple(args.channel_mult),
43 attention_resolutions=tuple(args.attention_resolutions),
44 ).to(args.device)
45
46 print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
47
48 # Diffusion
49 diffusion = GaussianDiffusion(
50 timesteps=args.timesteps,
51 schedule=args.schedule,
52 device=args.device,
53 )
54
55 # Optimizer
56 optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
57
58 # EMA
59 ema = EMA(model, args.ema_decay) if args.use_ema else None
60
61 # AMP scaler
62 scaler = GradScaler(enabled=args.use_amp)
63
64 # Logger
65 logger = Logger(args)
66
67 # Resume if specified
68 start_epoch = 0
69 if args.resume:
70 start_epoch = load_checkpoint(args.resume, model, optimizer, None, ema, scaler)
71
72 # Training loop
73 print(f"Starting training for {args.epochs} epochs...")
74
75 for epoch in range(start_epoch, args.epochs):
76 loss = train_epoch(model, dataloader, optimizer, diffusion, scaler, args, ema)
77
78 logger.log({"epoch": epoch, "loss": loss})
79 print(f"Epoch {epoch + 1}/{args.epochs}: Loss = {loss:.4f}")
80
81 if (epoch + 1) % args.sample_every == 0:
82 samples = generate_samples(model, diffusion, args, ema)
83 logger.log_images("samples", samples)
84
85 if (epoch + 1) % args.save_every == 0:
86 save_checkpoint(epoch, model, optimizer, None, ema, scaler, args)
87
88 # Final save
89 save_checkpoint(args.epochs - 1, model, optimizer, None, ema, scaler, args)
90 logger.finish()
91 print("Training complete!")
92
93
94if __name__ == "__main__":
95 main()Running the Script
Save all components and run with:
⚡bash
python train.py --dataset cifar10 --epochs 100 --use_amp --use_ema --wandb
Key Takeaways
- Use argparse: Make your script configurable from the command line with sensible defaults.
- Implement EMA: Exponential moving average typically produces better generation quality than final weights.
- Log everything: Use W&B or TensorBoard to track loss curves and generated samples.
- Checkpoint regularly: Save model, optimizer, and EMA weights so you can resume after interruptions.
- Set seeds: For reproducibility, set random seeds for PyTorch, NumPy, and Python.
Looking Ahead: In the next section, we'll dive deep into training monitoring - tracking FID during training, visualizing samples, and building monitoring dashboards.