Chapter 10
Section 50 of 76

Building the Training Pipeline

Dataset Preparation

Learning Objectives

By the end of this section, you will be able to:

  1. Configure DataLoaders for optimal GPU utilization with proper batch sizes, workers, and prefetching
  2. Optimize memory usage through gradient accumulation, mixed precision, and efficient data loading
  3. Set up distributed data loading for multi-GPU training with proper synchronization
  4. Build a complete, production-ready training pipeline that scales from single GPU to multi-node clusters

The Big Picture

A well-designed training pipeline is the backbone of successful diffusion model training. The difference between a naive pipeline and an optimized one can mean the difference between weeks of training and days, or between a model that fits in memory and one that crashes with out-of-memory errors.

The Goal: Keep the GPU as close to 100% utilization as possible. Every moment the GPU waits for data is wasted compute, so your pipeline should load and preprocess the next batch while the GPU processes the current one.

In this section, we'll build a training pipeline that runs efficiently on everything from a single consumer GPU to a cluster of A100s.


DataLoader Configuration

Essential DataLoader Parameters

The PyTorch DataLoader has many parameters that significantly affect training performance:

```python
from torch.utils.data import DataLoader


def create_training_dataloader(
    dataset,
    batch_size: int = 32,
    num_workers: int = 4,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 2,
) -> DataLoader:
    """
    Create an optimized DataLoader for diffusion model training.

    Args:
        dataset: PyTorch Dataset object
        batch_size: Number of samples per batch
        num_workers: Number of parallel data loading processes
        pin_memory: Pin memory for faster CPU->GPU transfer
        persistent_workers: Keep workers alive between epochs
        prefetch_factor: Number of batches to prefetch per worker

    Returns:
        Configured DataLoader
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,              # Shuffle for training
        num_workers=num_workers,
        pin_memory=pin_memory,     # Speed up GPU transfer
        drop_last=True,            # Drop incomplete final batch
        # Both options below require worker processes (num_workers > 0).
        # Passing prefetch_factor=None with num_workers=0 needs PyTorch >= 2.0.
        persistent_workers=persistent_workers and num_workers > 0,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
    )
```
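To check whether these settings actually keep the GPU fed, it helps to time the loader on its own, with no model attached. The sketch below is a minimal illustration; `measure_loader_throughput` is our own helper name, not a PyTorch API, and the random `TensorDataset` stands in for a real image dataset.

```python
import itertools
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def measure_loader_throughput(loader, num_batches: int = 50, warmup: int = 5) -> float:
    """Return batches per second when iterating `loader` without any model.

    The first `warmup` batches are skipped because worker start-up makes them slow.
    """
    iterator = iter(loader)
    for _ in itertools.islice(iterator, warmup):
        pass  # discard warm-up batches
    start = time.perf_counter()
    count = sum(1 for _ in itertools.islice(iterator, num_batches))
    elapsed = time.perf_counter() - start
    return count / elapsed if elapsed > 0 else float("inf")


# Synthetic stand-in for an image dataset (random tensors, no disk I/O)
dataset = TensorDataset(torch.randn(256, 3, 32, 32))
loader = DataLoader(dataset, batch_size=32, num_workers=0)
print(f"{measure_loader_throughput(loader):.1f} batches/s")
```

If the loader delivers batches much faster than your training loop consumes them, data loading is not the bottleneck; if the two rates are close, more workers or cheaper preprocessing are the first things to try.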

Choosing the Right Batch Size

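Batch size affects both training dynamics and memory usage. The recommendations below are rough starting points; the actual limit depends on model size, precision, and whether activation checkpointing is used:

| Resolution | GPU memory | Recommended batch size |
|---|---|---|
| 32x32 (CIFAR) | 8 GB | 128-256 |
| 32x32 (CIFAR) | 16 GB | 256-512 |
| 64x64 | 8 GB | 32-64 |
| 64x64 | 16 GB | 64-128 |
| 128x128 | 16 GB | 16-32 |
| 256x256 | 24 GB | 8-16 |
| 256x256 | 40 GB | 16-32 |
| 512x512 | 80 GB | 4-8 |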
```python
import torch


def estimate_batch_size(
    model,
    image_size: int,
    channels: int = 3,
    safety_factor: float = 0.8,
    device: str = "cuda",
) -> int:
    """
    Estimate maximum batch size that fits in GPU memory.

    This is a rough estimate; actual usage depends on model architecture,
    optimizer state, and activation checkpointing.

    Args:
        model: The diffusion model
        image_size: Target image size (square)
        channels: Number of image channels (3 for RGB)
        safety_factor: Fraction of GPU memory to use (0.8 = 80%)
        device: Device to run on

    Returns:
        Estimated batch size
    """
    # Total memory of the current GPU
    if device.startswith("cuda") and torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        gpu_memory = props.total_memory
    else:
        return 32  # Default for CPU

    # Parameter memory (float32 = 4 bytes per parameter)
    param_memory = sum(p.numel() * 4 for p in model.parameters())

    # Optimizer memory (Adam/AdamW keep two moment tensors per parameter)
    optimizer_memory = param_memory * 2

    # Gradient memory (one gradient per parameter)
    gradient_memory = param_memory

    # Fixed overhead: about 16 bytes per parameter in total
    fixed_memory = param_memory + optimizer_memory + gradient_memory

    # Available for activations and data
    available = (gpu_memory * safety_factor) - fixed_memory

    # Estimate per-sample memory (very rough)
    # Images: channels * height * width * 4 bytes
    image_memory = channels * image_size * image_size * 4

    # Activations are typically 10-50x the input size for U-Net
    activation_factor = 30  # Conservative estimate
    per_sample_memory = image_memory * activation_factor

    # Estimate batch size
    batch_size = int(available / per_sample_memory)

    # Clamp to reasonable range
    batch_size = max(1, min(batch_size, 512))

    # Round down to a power of 2 for efficiency
    batch_size = 1 << (batch_size.bit_length() - 1)

    return batch_size


# Usage
# batch_size = estimate_batch_size(model, image_size=64)
# print(f"Estimated batch size: {batch_size}")
```
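Because the analytic estimate is rough, a common complement is to measure directly: try increasing batch sizes until one runs out of memory, then binary-search between the last success and the first failure. The sketch below assumes memory use grows monotonically with batch size and that the model is called as `model(x, t)` like elsewhere in this chapter; `find_max_batch_size` and `make_cuda_probe` are our own hypothetical helpers. The probe does not include optimizer state (allocated on the first optimizer step), so leave some headroom.

```python
from typing import Callable

import torch


def find_max_batch_size(fits: Callable[[int], bool], start: int = 1, limit: int = 1024) -> int:
    """Largest batch size in [start, limit] for which fits(batch_size) is True.

    Doubles until fits() fails, then binary-searches. Returns 0 if `start` fails.
    """
    if not fits(start):
        return 0
    low, high = start, None  # low is known to fit; high is the first known failure
    candidate = start
    while candidate < limit:
        candidate = min(candidate * 2, limit)
        if fits(candidate):
            low = candidate
        else:
            high = candidate
            break
    if high is None:
        return low  # even `limit` fits
    while high - low > 1:  # invariant: low fits, high does not
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low


def make_cuda_probe(model, image_size: int, channels: int = 3, num_timesteps: int = 1000):
    """Build a fits() callback that runs one forward/backward pass on the model's device."""
    device = next(model.parameters()).device

    def fits(batch_size: int) -> bool:
        try:
            x = torch.randn(batch_size, channels, image_size, image_size, device=device)
            t = torch.randint(0, num_timesteps, (batch_size,), device=device)
            model(x, t).float().pow(2).mean().backward()
            return True
        except torch.cuda.OutOfMemoryError:
            return False
        finally:
            model.zero_grad(set_to_none=True)  # free gradients between probes
            torch.cuda.empty_cache()

    return fits


# Usage (model already on the GPU):
# batch_size = find_max_batch_size(make_cuda_probe(model, image_size=64), limit=512)
```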

Optimal Number of Workers

The number of data loading workers depends on your CPU cores and I/O speed:

```python
import multiprocessing

import torch


def get_optimal_workers(
    max_workers: int = 8,
    workers_per_gpu: int = 4,
) -> int:
    """
    Determine optimal number of DataLoader workers.

    Rules of thumb:
    - Start with 4 workers per GPU
    - Don't exceed available CPU cores
    - For NVMe SSDs, more workers help; for HDDs, fewer is better

    Args:
        max_workers: Maximum workers to use
        workers_per_gpu: Target workers per GPU

    Returns:
        Recommended number of workers
    """
    # Get CPU count
    cpu_count = multiprocessing.cpu_count()

    # Get GPU count (if available)
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
    else:
        gpu_count = 1

    # Calculate target
    target_workers = workers_per_gpu * gpu_count

    # Clamp to available CPUs (leaving one for the main process) and max
    workers = min(target_workers, cpu_count - 1, max_workers)

    return max(0, workers)


# Example
num_workers = get_optimal_workers()
print(f"Using {num_workers} data loading workers")
```
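get_optimal_workers returns a total across all GPUs, which fits a single-process DataLoader. Under DDP, however, every GPU's process builds its own DataLoader, so the per-loader count should divide the node's CPUs among those processes. A minimal sketch under that assumption; `workers_per_process` is a hypothetical helper, and it reads `LOCAL_WORLD_SIZE`, which torchrun sets to the number of processes on the node.

```python
import os


def workers_per_process(workers_per_gpu: int = 4, local_world_size: int = 1, reserve_cpus: int = 1) -> int:
    """DataLoader workers for ONE training process when each GPU has its own process."""
    try:
        cpus = len(os.sched_getaffinity(0))  # CPUs this process may run on (Linux)
    except AttributeError:
        cpus = os.cpu_count() or 1  # fallback on platforms without sched_getaffinity
    # All processes on the node share the CPUs, each with its own DataLoader
    budget = max(0, cpus - reserve_cpus) // max(1, local_world_size)
    return min(workers_per_gpu, budget)


local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
print(f"Workers per process: {workers_per_process(4, local_world_size)}")
```

On Linux, sched_getaffinity also respects CPU limits imposed by taskset or some container runtimes, which os.cpu_count() ignores.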

Memory Optimization

Gradient Accumulation

When your target batch size doesn't fit in memory, use gradient accumulation to simulate larger batches:

```python
import torch
from torch.utils.data import DataLoader


def train_with_gradient_accumulation(
    model,
    dataloader: DataLoader,
    optimizer,
    criterion,
    diffusion,
    accumulation_steps: int = 4,
    num_timesteps: int = 1000,
    device: str = "cuda",
) -> float:
    """
    Training loop with gradient accumulation.

    Effective batch size = batch_size * accumulation_steps

    Args:
        model: The diffusion model, called as model(noisy_images, t)
        dataloader: Training data loader yielding image tensors
        optimizer: Optimizer
        criterion: Loss function (e.g. nn.MSELoss())
        diffusion: Diffusion scheduler providing add_noise(images, noise, t)
        accumulation_steps: Number of micro-batches per optimizer step
        num_timesteps: Number of diffusion timesteps
        device: Device to train on

    Returns:
        Average loss for the epoch
    """
    model.train()
    optimizer.zero_grad()  # Start from clean gradients
    total_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        # Move data to device
        images = batch.to(device)

        # Sample timesteps and noise, then corrupt the images
        batch_size = images.shape[0]
        t = torch.randint(0, num_timesteps, (batch_size,), device=device)
        noise = torch.randn_like(images)
        noisy_images = diffusion.add_noise(images, noise, t)

        # Forward pass: predict the noise that was added
        predicted_noise = model(noisy_images, t)
        loss = criterion(predicted_noise, noise)

        # Divide by accumulation steps so the accumulated gradient is an average
        (loss / accumulation_steps).backward()

        # Only step optimizer every accumulation_steps micro-batches
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item()
        num_batches += 1

    # Apply leftover gradients from an incomplete final group
    if num_batches % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / max(num_batches, 1)


# Example: Train with effective batch size of 128 using 4x32 accumulation
# train_with_gradient_accumulation(
#     model, dataloader, optimizer, criterion, diffusion,
#     accumulation_steps=4,  # 32 * 4 = 128 effective batch size
# )
```
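Why divide each micro-batch loss by accumulation_steps? With a mean-reduced loss and equal-sized micro-batches, the full-batch mean equals the average of the micro-batch means, so the summed gradients match a single large-batch gradient. The small check below illustrates this on a toy linear model (not a diffusion model):

```python
import torch
import torch.nn as nn


def accumulated_grads(model, x, y, accumulation_steps):
    """Gradients from equal-sized micro-batches, each loss divided by accumulation_steps."""
    model.zero_grad()
    for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
        loss = nn.functional.mse_loss(model(xb), yb) / accumulation_steps
        loss.backward()  # gradients add up in .grad
    return [p.grad.clone() for p in model.parameters()]


def full_batch_grads(model, x, y):
    """Gradients from a single pass over the whole batch."""
    model.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    return [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
model = nn.Linear(8, 4)
x, y = torch.randn(32, 8), torch.randn(32, 4)

grads_accumulated = accumulated_grads(model, x, y, accumulation_steps=4)
grads_full = full_batch_grads(model, x, y)
match = all(torch.allclose(a, b, atol=1e-5) for a, b in zip(grads_accumulated, grads_full))
print(match)  # True
```

The equivalence is exact only when micro-batches are equal in size and no layer depends on batch statistics; BatchNorm, for example, sees the smaller micro-batches and behaves differently.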

Mixed Precision Training

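Mixed precision (FP16/BF16) runs most of the forward and backward pass in 16-bit arithmetic while keeping the weights in FP32. On GPUs with Tensor Cores this can substantially increase throughput and reduce activation memory; the exact gains depend on the model: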

```python
import torch
from torch.cuda.amp import autocast, GradScaler
# Recent PyTorch releases deprecate torch.cuda.amp in favor of
# torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda").


def train_epoch_mixed_precision(
    model,
    dataloader,
    optimizer,
    diffusion,
    device: str = "cuda",
    scaler: GradScaler = None,
    num_timesteps: int = 1000,
) -> float:
    """
    Train one epoch with mixed precision.

    Mixed precision:
    - Eligible ops in the forward pass run in FP16 inside autocast
    - The loss is scaled before backward so small FP16 gradients don't underflow
    - Weights and optimizer state stay in FP32

    Args:
        model: The diffusion model
        dataloader: Training data loader
        optimizer: Optimizer
        diffusion: Diffusion scheduler providing add_noise(images, noise, t)
        device: Device to train on
        scaler: GradScaler for mixed precision (reuse one across epochs)
        num_timesteps: Number of diffusion timesteps

    Returns:
        Average loss for the epoch
    """
    if scaler is None:
        scaler = GradScaler()

    model.train()
    total_loss = 0.0

    for batch in dataloader:
        images = batch.to(device, non_blocking=True)
        batch_size = images.shape[0]

        # Sample timesteps
        t = torch.randint(0, num_timesteps, (batch_size,), device=device)

        optimizer.zero_grad(set_to_none=True)

        # Forward pass with autocast
        with autocast():
            # Sample noise and create noisy images
            noise = torch.randn_like(images)
            noisy_images = diffusion.add_noise(images, noise, t)

            # Predict noise
            predicted_noise = model(noisy_images, t)

            # Compute loss
            loss = torch.nn.functional.mse_loss(predicted_noise, noise)

        # Backward pass with scaled loss
        scaler.scale(loss).backward()

        # Unscale gradients and step (skipped if they contain inf/NaN)
        scaler.step(optimizer)
        # Adjust the scale factor for the next iteration
        scaler.update()

        total_loss += loss.item()

    return total_loss / len(dataloader)


# Usage (model, train_loader, optimizer, diffusion, num_epochs defined elsewhere):
# scaler = GradScaler()  # create once so its scale persists across epochs
# for epoch in range(num_epochs):
#     loss = train_epoch_mixed_precision(
#         model, train_loader, optimizer, diffusion,
#         scaler=scaler,
#     )
#     print(f"Epoch {epoch}: Loss = {loss:.4f}")
```
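Diffusion training scripts often clip the gradient norm. With a GradScaler the gradients are multiplied by the loss scale, so they must be unscaled with scaler.unscale_ before clipping; otherwise the threshold is compared against scaled values. A minimal sketch of that ordering (`amp_step_with_clipping` is our own helper name); the demo disables the scaler so it runs on CPU, in which case each scaler call simply passes through.

```python
import torch
import torch.nn as nn


def amp_step_with_clipping(loss, model, optimizer, scaler, max_grad_norm: float = 1.0):
    """One optimizer step with loss scaling and gradient-norm clipping."""
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # gradients are now in their true units
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)      # skipped if gradients contain inf/NaN
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    return grad_norm


torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled so this demo runs on CPU
before = model.weight.detach().clone()

loss = model(torch.randn(16, 4)).pow(2).mean()
norm = amp_step_with_clipping(loss, model, optimizer, scaler, max_grad_norm=0.5)
print(f"pre-clip gradient norm: {norm.item():.4f}")
```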

BFloat16 vs Float16

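On Ampere GPUs (RTX 30xx, A100) and newer, consider torch.bfloat16 instead of float16. BF16 keeps the same 8-bit exponent as FP32, so its dynamic range matches FP32 and gradient scaling is rarely needed; the trade-off is lower precision (7 mantissa bits versus 10 for FP16).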
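A BF16 step therefore looks like the FP16 version minus the GradScaler. The sketch below is a minimal illustration with a toy linear model rather than a diffusion model; `bf16_train_step` is a hypothetical helper, and the demo falls back to CPU autocast (which also supports bfloat16) when no BF16-capable GPU is present.

```python
import torch
import torch.nn as nn


def bf16_train_step(model, inputs, targets, optimizer, device_type: str) -> float:
    """One optimizer step with BF16 autocast and no GradScaler."""
    optimizer.zero_grad(set_to_none=True)
    # Eligible ops (e.g. matmuls) run in bfloat16; parameters stay float32
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        prediction = model(inputs)
    # Compute the loss in float32, outside autocast
    loss = nn.functional.mse_loss(prediction.float(), targets)
    loss.backward()  # no loss scaling
    optimizer.step()
    return loss.item()


# Use CUDA BF16 when the GPU supports it; otherwise run the demo on CPU
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

torch.manual_seed(0)
model = nn.Linear(16, 16).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
inputs = torch.randn(8, 16, device=device)
targets = torch.randn(8, 16, device=device)

loss = bf16_train_step(model, inputs, targets, optimizer, device.type)
print(f"loss={loss:.4f}, weight dtype={model.weight.dtype}")
```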

Distributed Training Setup

DistributedDataParallel (DDP)

For multi-GPU training, PyTorch's DistributedDataParallel is the standard approach:

```python
import argparse
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def setup_distributed(rank: int, world_size: int):
    """
    Initialize distributed training on a single machine.

    Args:
        rank: This process's rank (0 to world_size-1)
        world_size: Total number of processes
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize process group
    dist.init_process_group(
        backend='nccl',  # Use NCCL for GPUs
        rank=rank,
        world_size=world_size,
    )

    # Set device for this process
    torch.cuda.set_device(rank)


def cleanup_distributed():
    """Clean up distributed training."""
    dist.destroy_process_group()


def create_distributed_dataloader(
    dataset,
    batch_size: int,
    num_workers: int = 4,
) -> DataLoader:
    """
    Create a DataLoader for distributed training.

    The DistributedSampler gives each process a disjoint shard of the data.
    """
    sampler = DistributedSampler(
        dataset,
        shuffle=True,
        drop_last=True,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,  # Use sampler instead of shuffle=True
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )


def train_distributed(rank: int, world_size: int, args):
    """
    Main training function, run once per process.

    Args:
        rank: This process's rank
        world_size: Total number of processes
        args: Training arguments
    """
    setup_distributed(rank, world_size)

    # create_model, create_dataset and save_checkpoint stand in for your own code
    model = create_model(args).to(rank)

    # Wrap model with DDP
    model = DDP(model, device_ids=[rank])

    # Per-GPU batch size; the global batch is args.batch_size * world_size
    dataset = create_dataset(args)
    dataloader = create_distributed_dataloader(
        dataset,
        batch_size=args.batch_size,
    )

    # Create optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    # Training loop
    for epoch in range(args.epochs):
        # IMPORTANT: without set_epoch, every epoch reuses the same shuffle order
        dataloader.sampler.set_epoch(epoch)

        for batch in dataloader:
            # Training step...
            pass

        # Only save checkpoint from rank 0 (.module unwraps DDP)
        if rank == 0:
            save_checkpoint(model.module, optimizer, epoch)

    cleanup_distributed()


# Launch distributed training
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=100)
    args = parser.parse_args()

    world_size = torch.cuda.device_count()
    # mp.spawn passes the process index (used as the rank) as the first argument
    mp.spawn(
        train_distributed,
        args=(world_size, args),
        nprocs=world_size,
        join=True,
    )
```
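To see what the DistributedSampler actually hands each rank, you can construct it with explicit num_replicas and rank, which does not require an initialized process group. This toy sketch uses a plain list of 10 items as the dataset (the sampler only needs its length) and simulates two ranks in one process:

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))  # stand-in dataset; only len() is used
world_size = 2


def indices_per_rank(epoch: int) -> list:
    """Indices each simulated rank would receive in the given epoch."""
    shards = []
    for rank in range(world_size):
        # Explicit num_replicas/rank: no call to init_process_group needed
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0
        )
        sampler.set_epoch(epoch)  # seeds the shuffle with seed + epoch
        shards.append(list(sampler))
    return shards


print(indices_per_rank(0))  # two disjoint halves of one shuffled order
print(indices_per_rank(1))  # a different shuffle for the next epoch
```

Every rank draws the same permutation and then takes every world_size-th index starting at its own rank, which is why the shards never overlap and why all ranks must call set_epoch with the same value.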

Using torchrun for Distributed Training

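The recommended way to launch distributed training is torchrun. It starts one process per GPU and sets RANK, LOCAL_RANK and WORLD_SIZE for each, so the script does not call mp.spawn itself: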

```bash
# Single node, 4 GPUs
torchrun --nproc_per_node=4 train.py --batch_size=32

# Multi-node (2 nodes, 8 GPUs each)
# On node 0:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 \
    --master_addr="192.168.1.1" --master_port=12355 train.py

# On node 1:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 \
    --master_addr="192.168.1.1" --master_port=12355 train.py
```
```python
# In train.py, use the environment variables set by torchrun
import os

import torch
import torch.distributed as dist


def setup_from_env() -> int:
    """
    Set up distributed training from torchrun environment variables.

    torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, so
    init_process_group needs no explicit rank or address.
    """
    dist.init_process_group(backend='nccl')

    local_rank = int(os.environ['LOCAL_RANK'])  # GPU index on this node
    torch.cuda.set_device(local_rank)

    return local_rank


# Training script compatible with torchrun
def main():
    local_rank = setup_from_env()
    rank = dist.get_rank()  # global rank across all nodes
    world_size = dist.get_world_size()

    print(f"Running on rank {rank} of {world_size} (local rank {local_rank})")

    # ... rest of training code ...

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
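It is often convenient for the same script to run both under torchrun and as a plain `python train.py`. One way is to read torchrun's variables with single-process defaults and only initialize the process group when there is more than one process. A minimal sketch; `DistInfo` and `read_dist_info` are our own hypothetical names.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

import torch
import torch.distributed as dist


@dataclass(frozen=True)
class DistInfo:
    rank: int         # global rank across all nodes
    local_rank: int   # GPU index on this node
    world_size: int   # total number of processes

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    @property
    def is_main_process(self) -> bool:
        return self.rank == 0


def read_dist_info(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read torchrun's variables, defaulting to a single process when absent."""
    env = os.environ if env is None else env
    return DistInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


info = read_dist_info()
if info.is_distributed:
    # Only reached when launched by torchrun with more than one process
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(info.local_rank)
print(info)
```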

Efficient Data Loading

Prefetching and Pipelining

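The DataLoader already prefetches batches in its worker processes (prefetch_factor batches per worker), but the host-to-GPU copy still happens inside the training loop. A wrapper can overlap that copy with GPU computation: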

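```python
import torch
from torch.utils.data import DataLoader


class PrefetchLoader:
    """
    Wrapper that copies the next batch to the GPU on a side CUDA stream.

    This overlaps the CPU->GPU transfer of the next batch with GPU computation
    on the current one. Use pin_memory=True on the wrapped DataLoader:
    non_blocking copies from pageable memory are effectively synchronous.
    """

    def __init__(self, loader, device: str = "cuda"):
        self.loader = loader
        self.device = device
        self.stream = torch.cuda.Stream()

    def _to_device(self, batch):
        if isinstance(batch, torch.Tensor):
            return batch.to(self.device, non_blocking=True)
        if isinstance(batch, (list, tuple)):
            return [
                b.to(self.device, non_blocking=True)
                if isinstance(b, torch.Tensor) else b
                for b in batch
            ]
        return batch

    @staticmethod
    def _record_stream(batch):
        # Tell the caching allocator these tensors are used on the compute
        # stream, so their memory isn't reused by the side stream too early
        items = batch if isinstance(batch, list) else [batch]
        for item in items:
            if isinstance(item, torch.Tensor):
                item.record_stream(torch.cuda.current_stream())

    def __iter__(self):
        first = True
        batch = None

        for next_batch in self.loader:
            # Start copying the next batch on the side stream
            with torch.cuda.stream(self.stream):
                next_batch = self._to_device(next_batch)

            if not first:
                yield batch  # compute on this batch overlaps the copy above
            else:
                first = False

            # Compute stream waits until the copy has finished
            torch.cuda.current_stream().wait_stream(self.stream)
            self._record_stream(next_batch)
            batch = next_batch

        if not first:
            yield batch  # final batch (nothing is yielded for an empty loader)

    def __len__(self):
        return len(self.loader)


# Usage (dataset defined elsewhere)
# train_loader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)
# prefetch_loader = PrefetchLoader(train_loader, device="cuda")
#
# for batch in prefetch_loader:
#     # batch is already on GPU
#     ...  # training step
```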
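PrefetchLoader handles a single tensor or a flat list/tuple and returns lists. Batches that are dicts, nested tuples, or namedtuples need a recursive move that preserves their structure. A minimal sketch; `move_to_device` is our own helper name, and dict-like batches come back as plain dicts.

```python
from collections import namedtuple
from collections.abc import Mapping

import torch


def move_to_device(batch, device, non_blocking: bool = True):
    """Recursively move tensors in nested dicts/lists/tuples to `device`."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=non_blocking)
    if isinstance(batch, Mapping):
        return {k: move_to_device(v, device, non_blocking) for k, v in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
        return type(batch)(*(move_to_device(v, device, non_blocking) for v in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(v, device, non_blocking) for v in batch)
    return batch  # strings, numbers, None, ... are left untouched


Sample = namedtuple("Sample", ["image", "label"])  # example batch type
batch = {"pair": Sample(torch.ones(2, 3), torch.tensor([1, 0])), "names": ["a", "b"]}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
moved = move_to_device(batch, device)
print(type(moved["pair"]).__name__, moved["pair"].image.device)
```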

Memory-Mapped Datasets

For datasets too large to hold in RAM, memory-mapping lets each sample be read from disk only when it is accessed:

```python
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset


class MemoryMappedDataset(Dataset):
    """
    Dataset backed by a memory-mapped .npy array.

    Memory mapping lets you work with datasets larger than RAM: the OS
    pages in only the portions of the file that are actually accessed.
    """

    def __init__(
        self,
        data_path: str,
        transform=None,
    ):
        """
        Args:
            data_path: Path to .npy file with shape [N, C, H, W], stored either
                as uint8 in [0, 255] or as float32 in [0, 1]
            transform: Optional transform to apply
        """
        # Memory-map the array (not loaded into RAM)
        self.data = np.load(data_path, mmap_mode='r')
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Reads only this sample. np.array copies it into a writable array;
        # torch.from_numpy warns about read-only memory-mapped views.
        image = torch.from_numpy(np.array(self.data[idx], dtype=np.float32))

        # Normalize to [-1, 1] according to how the file was stored
        if self.data.dtype == np.uint8:
            image = image / 127.5 - 1.0  # [0, 255] -> [-1, 1]
        else:
            image = image * 2.0 - 1.0    # [0, 1] -> [-1, 1]

        if self.transform:
            image = self.transform(image)

        return image


# Create the memory-mapped file (do this once)
def create_memmap_dataset(image_folder: str, output_path: str, image_size: int):
    """Preprocess images and save them as a float32 array in [0, 1]."""
    from PIL import Image
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize(image_size),      # shorter side -> image_size
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),              # float32 in [0, 1]
    ])

    # Find all images (sorted for a reproducible order)
    image_paths = sorted(Path(image_folder).rglob("*.jpg"))
    n_images = len(image_paths)
    if n_images == 0:
        raise ValueError(f"No .jpg files found under {image_folder}")

    # Create memory-mapped output
    shape = (n_images, 3, image_size, image_size)
    out = np.lib.format.open_memmap(
        output_path, mode='w+', dtype=np.float32, shape=shape
    )

    # Process and save images
    for i, path in enumerate(image_paths):
        image = Image.open(path).convert('RGB')
        out[i] = transform(image).numpy()

        if i % 1000 == 0:
            print(f"Processed {i}/{n_images} images")

    out.flush()
    print(f"Saved to {output_path}")
```
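Storing the array as float32 uses 4 bytes per value. Since JPEG sources are 8-bit, storing uint8 instead makes the file 4x smaller at the cost of rounding the resized pixels, and conversion to [-1, 1] can happen per sample, as MemoryMappedDataset does for uint8 files. A self-contained round trip with random images in a temporary directory; `to_model_range` is our own helper name.

```python
import os
import tempfile

import numpy as np
import torch


def to_model_range(sample: np.ndarray) -> torch.Tensor:
    """Map a uint8 sample in [0, 255] to a float32 tensor in [-1, 1]."""
    # astype makes a writable float32 copy of the memory-mapped slice
    return torch.from_numpy(sample.astype(np.float32)) / 127.5 - 1.0


rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 3, 8, 8), dtype=np.uint8)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "images_u8.npy")
    out = np.lib.format.open_memmap(path, mode="w+", dtype=np.uint8, shape=images.shape)
    out[:] = images
    out.flush()
    del out  # close the writable map

    data = np.load(path, mmap_mode="r")  # read-only, paged in on access
    sample = to_model_range(data[2])
    del data  # release the file before the directory is removed

print(f"uint8: {images.nbytes} bytes, float32: {images.astype(np.float32).nbytes} bytes")
```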

Complete Training Pipeline

Here's a complete training pipeline that combines the techniques discussed in this section:

```python
import contextlib
import os
from dataclasses import asdict, dataclass

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


@dataclass
class TrainingConfig:
    """Training configuration."""
    # Data
    data_dir: str = "./data"
    image_size: int = 64
    channels: int = 3

    # Training
    batch_size: int = 32         # Per-GPU batch size
    accumulation_steps: int = 1  # Gradient accumulation
    num_epochs: int = 100
    learning_rate: float = 1e-4
    warmup_steps: int = 1000

    # Model
    timesteps: int = 1000

    # Optimization
    use_amp: bool = True         # Mixed precision
    num_workers: int = 4

    # Distributed
    distributed: bool = False

    # Checkpointing
    checkpoint_dir: str = "./checkpoints"
    save_every: int = 10         # Save every N epochs


class TrainingPipeline:
    """
    Complete training pipeline for diffusion models.

    Features:
    - Distributed training support
    - Mixed precision training
    - Gradient accumulation (no DDP all-reduce on non-stepping micro-batches)
    - Automatic checkpointing
    - Learning rate warmup
    """

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.device = self._setup_device()
        # torch.cuda.amp only has an effect on CUDA devices
        self.use_amp = config.use_amp and self.device.type == "cuda"
        self.scaler = GradScaler() if self.use_amp else None

        # Will be set in setup()
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.train_loader = None
        self.diffusion = None

    def _setup_device(self) -> torch.device:
        """Setup device and distributed training if needed."""
        if self.config.distributed:
            dist.init_process_group(backend='nccl')
            local_rank = int(os.environ.get('LOCAL_RANK', 0))
            torch.cuda.set_device(local_rank)
            return torch.device(f'cuda:{local_rank}')
        elif torch.cuda.is_available():
            return torch.device('cuda')
        else:
            return torch.device('cpu')

    @property
    def is_main_process(self) -> bool:
        """Check if this is the main process (for logging/saving)."""
        if self.config.distributed:
            return dist.get_rank() == 0
        return True

    @property
    def world_size(self) -> int:
        """Get number of processes."""
        if self.config.distributed:
            return dist.get_world_size()
        return 1

    def setup(self, model, dataset, diffusion):
        """
        Setup training components.

        Args:
            model: The diffusion model (nn.Module)
            dataset: Training dataset
            diffusion: Diffusion scheduler providing add_noise(images, noise, t)
        """
        self.diffusion = diffusion

        # Move model to device, and wrap in DDP if distributed
        self.model = model.to(self.device)
        if self.config.distributed:
            self.model = DDP(
                self.model,
                device_ids=[self.device.index],
                output_device=self.device.index,
            )

        # Create dataloader (a DistributedSampler replaces shuffle=True)
        sampler = (
            DistributedSampler(dataset, shuffle=True, drop_last=True)
            if self.config.distributed else None
        )
        self.train_loader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=sampler is None,
            sampler=sampler,
            num_workers=self.config.num_workers,
            pin_memory=True,
            drop_last=True,
            persistent_workers=self.config.num_workers > 0,
        )

        # Create optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=0.01,
        )

        # Create scheduler with warmup
        total_steps = (
            len(self.train_loader) *
            self.config.num_epochs //
            self.config.accumulation_steps
        )
        self.scheduler = self._create_scheduler(total_steps)

    def _create_scheduler(self, total_steps: int):
        """Create learning rate scheduler with linear warmup.

        total_steps is unused by this constant-after-warmup schedule but is
        available for schedules that decay over training.
        """
        warmup_steps = self.config.warmup_steps

        def lr_lambda(step):
            if step < warmup_steps:
                # step is 0 for the first update; +1 avoids a step at lr = 0
                return (step + 1) / warmup_steps
            return 1.0  # Constant after warmup

        return LambdaLR(self.optimizer, lr_lambda)

    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch."""
        self.model.train()

        if self.config.distributed:
            self.train_loader.sampler.set_epoch(epoch)

        total_loss = 0.0
        num_steps = 0

        for batch_idx, batch in enumerate(self.train_loader):
            total_loss += self._train_step(batch, batch_idx)
            num_steps += 1

        return total_loss / max(num_steps, 1)

    def _train_step(self, batch, batch_idx: int) -> float:
        """Execute one training micro-step."""
        images = batch.to(self.device, non_blocking=True)
        batch_size = images.shape[0]

        # Sample timesteps and noise, then add noise to images
        t = torch.randint(
            0, self.config.timesteps, (batch_size,),
            device=self.device
        )
        noise = torch.randn_like(images)
        noisy_images = self.diffusion.add_noise(images, noise, t)

        is_update_step = (batch_idx + 1) % self.config.accumulation_steps == 0
        # Under DDP, skip the gradient all-reduce on micro-batches that don't step
        sync_context = (
            self.model.no_sync()
            if self.config.distributed and not is_update_step
            else contextlib.nullcontext()
        )

        with sync_context:
            # Forward pass with optional autocast
            with autocast(enabled=self.use_amp):
                predicted_noise = self.model(noisy_images, t)
                loss = nn.functional.mse_loss(predicted_noise, noise)
                loss = loss / self.config.accumulation_steps

            # Backward pass
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

        # Step optimizer (with accumulation)
        if is_update_step:
            if self.scaler is not None:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            self.optimizer.zero_grad(set_to_none=True)
            self.scheduler.step()

        return loss.item() * self.config.accumulation_steps

    def save_checkpoint(self, epoch: int, loss: float):
        """Save training checkpoint (main process only)."""
        if not self.is_main_process:
            return

        os.makedirs(self.config.checkpoint_dir, exist_ok=True)

        # Get model state (unwrap DDP if needed)
        model_state = (
            self.model.module.state_dict()
            if self.config.distributed
            else self.model.state_dict()
        )

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            # Plain dict so torch.load(weights_only=True) can read it
            'config': asdict(self.config),
        }

        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        path = os.path.join(
            self.config.checkpoint_dir,
            f'checkpoint_epoch_{epoch:04d}.pt'
        )
        torch.save(checkpoint, path)
        print(f"Saved checkpoint to {path}")

    def train(self):
        """Main training loop."""
        loss = float("nan")
        last_saved_epoch = None

        for epoch in range(self.config.num_epochs):
            loss = self.train_epoch(epoch)

            if self.is_main_process:
                print(f"Epoch {epoch+1}/{self.config.num_epochs}, Loss: {loss:.4f}")

            if (epoch + 1) % self.config.save_every == 0:
                self.save_checkpoint(epoch + 1, loss)
                last_saved_epoch = epoch + 1

        # Final checkpoint, unless the loop just saved it
        if last_saved_epoch != self.config.num_epochs:
            self.save_checkpoint(self.config.num_epochs, loss)

        if self.config.distributed:
            dist.destroy_process_group()


# Usage example
if __name__ == "__main__":
    config = TrainingConfig(
        data_dir="./data/cifar10",
        image_size=32,
        batch_size=64,
        num_epochs=100,
    )

    pipeline = TrainingPipeline(config)

    # Create your model, dataset, and diffusion scheduler
    # model = create_model(...)
    # dataset = create_dataset(...)
    # diffusion = create_diffusion(...)

    # pipeline.setup(model, dataset, diffusion)
    # pipeline.train()
```
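The pipeline saves checkpoints but has no matching loader, and train() always starts from epoch 0. The sketch below restores a checkpoint with the same keys that save_checkpoint writes; `load_checkpoint` is our own helper, and resuming would additionally need a hypothetical start_epoch, e.g. `for epoch in range(start_epoch, num_epochs)`. Recent PyTorch releases default torch.load to weights_only=True, which rejects arbitrary pickled objects such as a config dataclass, so checkpoints holding only tensors, numbers, and plain dicts load cleanly.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR


def load_checkpoint(path, model, optimizer, scheduler=None, scaler=None, map_location="cpu") -> int:
    """Restore state saved with the keys used above; returns the stored epoch.

    The stored epoch counts completed epochs, so it is also the 0-based index
    of the next epoch to run.
    """
    checkpoint = torch.load(path, map_location=map_location)
    target = model.module if hasattr(model, "module") else model  # unwrap DDP
    target.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if scheduler is not None and "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    if scaler is not None and "scaler_state_dict" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state_dict"])
    return checkpoint["epoch"]


def make_components():
    """Toy stand-ins for the real model, optimizer and scheduler."""
    model = nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = LambdaLR(optimizer, lambda step: 1.0)
    return model, optimizer, scheduler


# Round trip: take one step, save, then restore into fresh objects
torch.manual_seed(0)
model, optimizer, scheduler = make_components()
model(torch.randn(2, 4)).sum().backward()
optimizer.step()
scheduler.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint_epoch_0010.pt")
    torch.save({
        "epoch": 10,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": 0.05,
    }, path)
    new_model, new_optimizer, new_scheduler = make_components()
    start_epoch = load_checkpoint(path, new_model, new_optimizer, new_scheduler)

print(start_epoch)  # 10
```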

Key Takeaways

  1. Optimize DataLoader: Use pin_memory=True, multiple workers, and persistent workers to keep the GPU supplied with data.
  2. Use mixed precision: FP16/BF16 can substantially raise throughput and cut activation memory on GPUs with Tensor Cores, typically with minimal impact on quality.
  3. Enable gradient accumulation: Simulate larger batch sizes when GPU memory is limited by accumulating gradients.
  4. Scale with DDP: Use DistributedDataParallel and torchrun for multi-GPU training, and call sampler.set_epoch every epoch so each epoch is shuffled differently.
  5. Prefetch data: Overlap data loading and host-to-GPU copies with GPU computation so the GPU does not wait on I/O.
Looking Ahead: In the next chapter, we'll configure the model architecture, including channel multipliers, attention resolutions, and other hyperparameters that determine model capacity and training cost.