Learning Objectives
By the end of this section, you will be able to:
- Configure DataLoaders for optimal GPU utilization with proper batch sizes, workers, and prefetching
- Optimize memory usage through gradient accumulation, mixed precision, and efficient data loading
- Set up distributed data loading for multi-GPU training with proper synchronization
- Build a complete, production-ready training pipeline that scales from single GPU to multi-node clusters
The Big Picture
A well-designed training pipeline is the backbone of successful diffusion model training. The difference between a naive pipeline and an optimized one can mean the difference between weeks of training and days - or between a model that fits in memory and one that crashes.
The Goal: Keep the GPU at 100% utilization. Every moment the GPU waits for data is wasted compute. Your pipeline should load and preprocess the next batch while the GPU processes the current one.
In this section, we'll build a training pipeline that maximizes efficiency across all hardware configurations, from a single consumer GPU to a cluster of A100s.
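Before optimizing anything, it helps to confirm whether data loading is actually the bottleneck. One rough way — a sketch, not part of the pipeline built below; `measure_loader_wait` is a hypothetical helper — is to time how long the loop spends waiting on the loader versus the total loop time:

```python
import time

def measure_loader_wait(loader, max_batches: int = 50):
    """Rough check: how long does the training loop spend waiting for data?

    If wait_time dominates total time, the GPU is starving and the
    DataLoader (workers, prefetching, storage) is the bottleneck.
    """
    wait_time = 0.0
    start = time.perf_counter()
    it = iter(loader)
    for _ in range(max_batches):
        t0 = time.perf_counter()
        try:
            _ = next(it)  # time only the fetch
        except StopIteration:
            break
        wait_time += time.perf_counter() - t0
        # ... your training step would run here ...
    total = time.perf_counter() - start
    return wait_time, total
```

Run this with different `num_workers` settings; the worker count that minimizes the wait fraction is a good starting point.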
DataLoader Configuration
Essential DataLoader Parameters
The PyTorch DataLoader has many parameters that significantly affect training performance:
```python
from torch.utils.data import DataLoader

def create_training_dataloader(
    dataset,
    batch_size: int = 32,
    num_workers: int = 4,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 2,
) -> DataLoader:
    """
    Create an optimized DataLoader for diffusion model training.

    Args:
        dataset: PyTorch Dataset object
        batch_size: Number of samples per batch
        num_workers: Number of parallel data loading processes
        pin_memory: Pin memory for faster CPU->GPU transfer
        persistent_workers: Keep workers alive between epochs
        prefetch_factor: Number of batches to prefetch per worker

    Returns:
        Configured DataLoader
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,  # Shuffle for training
        num_workers=num_workers,
        pin_memory=pin_memory,  # Speed up GPU transfer
        drop_last=True,  # Drop incomplete final batch
        persistent_workers=persistent_workers and num_workers > 0,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
    )
```

Choosing the Right Batch Size
Batch size affects both training dynamics and memory usage:
| Resolution | GPU Memory | Recommended Batch Size |
|---|---|---|
| 32x32 (CIFAR) | 8 GB | 128-256 |
| 32x32 (CIFAR) | 16 GB | 256-512 |
| 64x64 | 8 GB | 32-64 |
| 64x64 | 16 GB | 64-128 |
| 128x128 | 16 GB | 16-32 |
| 256x256 | 24 GB | 8-16 |
| 256x256 | 40 GB | 16-32 |
| 512x512 | 80 GB | 4-8 |
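The table gives starting points; you can also probe the limit empirically by attempting a forward/backward pass and halving the candidate size on out-of-memory errors. A sketch under assumptions — `probe_batch_size` is a hypothetical helper, and it assumes the section's `model(images, t)` call signature; the real limit will be somewhat lower once optimizer state is allocated:

```python
import torch

def probe_batch_size(model, image_size: int, start: int = 256, device: str = "cuda"):
    """Halve the candidate batch size until a forward/backward pass fits.

    Complements analytical estimates by measuring the real model, but it
    runs without optimizer state, so leave extra headroom for that.
    """
    model = model.to(device)
    bs = start
    while bs >= 1:
        try:
            x = torch.randn(bs, 3, image_size, image_size, device=device)
            t = torch.randint(0, 1000, (bs,), device=device)
            out = model(x, t)
            out.float().pow(2).mean().backward()  # dummy loss, just to allocate grads
            model.zero_grad(set_to_none=True)
            return bs
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()  # release the failed allocation
            bs //= 2
    return 1
```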
```python
import torch

def estimate_batch_size(
    model,
    image_size: int,
    channels: int = 3,
    safety_factor: float = 0.8,
    device: str = "cuda",
) -> int:
    """
    Estimate maximum batch size that fits in GPU memory.

    This is a rough estimate - actual usage depends on model architecture,
    optimizer state, and activation checkpointing.

    Args:
        model: The diffusion model
        image_size: Target image size (square)
        channels: Number of image channels (3 for RGB)
        safety_factor: Fraction of GPU memory to use (0.8 = 80%)
        device: Device to run on

    Returns:
        Estimated batch size
    """
    # Get available GPU memory
    if device == "cuda":
        gpu_memory = torch.cuda.get_device_properties(0).total_memory
    else:
        return 32  # Default for CPU

    # Model parameter memory (float32 = 4 bytes each)
    param_memory = sum(p.numel() * 4 for p in model.parameters())

    # Optimizer memory (Adam keeps 2 extra moments per parameter)
    optimizer_memory = param_memory * 2

    # Gradient memory
    gradient_memory = param_memory

    # Fixed overhead
    fixed_memory = param_memory + optimizer_memory + gradient_memory

    # Available for activations and data
    available = (gpu_memory * safety_factor) - fixed_memory

    # Per-sample input memory: channels * height * width * 4 bytes
    image_memory = channels * image_size * image_size * 4

    # Activations are typically 10-50x the input size for a U-Net
    activation_factor = 30  # Conservative estimate
    per_sample_memory = image_memory * activation_factor

    # Estimate batch size
    batch_size = int(available / per_sample_memory)

    # Clamp to a reasonable range
    batch_size = max(1, min(batch_size, 512))

    # Round down to a power of 2 for efficiency
    batch_size = 2 ** int(torch.log2(torch.tensor(batch_size, dtype=torch.float32)))

    return batch_size


# Usage
# batch_size = estimate_batch_size(model, image_size=64)
# print(f"Estimated batch size: {batch_size}")
```

Optimal Number of Workers
The number of data loading workers depends on your CPU cores and I/O speed:
```python
import multiprocessing

import torch

def get_optimal_workers(
    max_workers: int = 8,
    workers_per_gpu: int = 4,
) -> int:
    """
    Determine optimal number of DataLoader workers.

    Rules of thumb:
    - Start with 4 workers per GPU
    - Don't exceed available CPU cores
    - For NVMe SSDs, more workers help; for HDDs, fewer is better

    Args:
        max_workers: Maximum workers to use
        workers_per_gpu: Target workers per GPU

    Returns:
        Recommended number of workers
    """
    # Get CPU count
    cpu_count = multiprocessing.cpu_count()

    # Get GPU count (if available)
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
    else:
        gpu_count = 1

    # Calculate target
    target_workers = workers_per_gpu * gpu_count

    # Clamp to available CPUs and the configured maximum
    workers = min(target_workers, cpu_count - 1, max_workers)

    return max(0, workers)


# Example
num_workers = get_optimal_workers()
print(f"Using {num_workers} data loading workers")
```

Memory Optimization
Gradient Accumulation
When your target batch size doesn't fit in memory, use gradient accumulation to simulate larger batches:
```python
import torch
from torch.utils.data import DataLoader

def train_with_gradient_accumulation(
    model,
    dataloader: DataLoader,
    optimizer,
    criterion,
    accumulation_steps: int = 4,
    device: str = "cuda",
) -> float:
    """
    Training loop with gradient accumulation.

    Effective batch size = batch_size * accumulation_steps

    Args:
        model: The diffusion model
        dataloader: Training data loader
        optimizer: Optimizer
        criterion: Loss function
        accumulation_steps: Number of steps to accumulate
        device: Device to train on

    Returns:
        Average loss for the epoch
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        # Move data to device
        images = batch.to(device)

        # Sample timesteps and noise
        batch_size = images.shape[0]
        t = torch.randint(0, 1000, (batch_size,), device=device)
        noise = torch.randn_like(images)

        # Forward pass
        # (This is simplified - actual diffusion training is more complex)
        predicted_noise = model(images, t)
        loss = criterion(predicted_noise, noise)

        # Normalize loss by accumulation steps
        loss = loss / accumulation_steps
        loss.backward()

        # Only step the optimizer every accumulation_steps batches
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps
        num_batches += 1

    # Handle remaining gradients at the end of the epoch
    if (batch_idx + 1) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / num_batches


# Example: Train with effective batch size of 128 using 4x32 accumulation
# train_with_gradient_accumulation(
#     model, dataloader, optimizer, criterion,
#     accumulation_steps=4,  # 32 * 4 = 128 effective batch size
# )
```

Mixed Precision Training
Mixed precision (FP16/BF16) nearly doubles training speed and halves memory usage on modern GPUs:
```python
import torch
from torch.cuda.amp import GradScaler, autocast

def train_epoch_mixed_precision(
    model,
    dataloader,
    optimizer,
    diffusion,
    device: str = "cuda",
    scaler: GradScaler = None,
) -> float:
    """
    Train one epoch with mixed precision.

    Mixed precision:
    - Forward pass uses FP16/BF16 for speed
    - Gradients are scaled to prevent underflow
    - Optimizer step uses FP32 for stability

    Args:
        model: The diffusion model
        dataloader: Training data loader
        optimizer: Optimizer
        diffusion: Diffusion scheduler with loss computation
        device: Device to train on
        scaler: GradScaler for mixed precision

    Returns:
        Average loss for the epoch
    """
    if scaler is None:
        scaler = GradScaler()

    model.train()
    total_loss = 0.0

    for batch in dataloader:
        images = batch.to(device)
        batch_size = images.shape[0]

        # Sample timesteps
        t = torch.randint(0, 1000, (batch_size,), device=device)

        optimizer.zero_grad()

        # Forward pass with autocast
        with autocast():
            # Sample noise and create noisy images
            noise = torch.randn_like(images)
            noisy_images = diffusion.add_noise(images, noise, t)

            # Predict noise
            predicted_noise = model(noisy_images, t)

            # Compute loss
            loss = torch.nn.functional.mse_loss(predicted_noise, noise)

        # Backward pass with scaled gradients
        scaler.scale(loss).backward()

        # Unscale, step, and update the scale factor
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    return total_loss / len(dataloader)


# Initialize training with mixed precision
scaler = GradScaler()

# Training loop
for epoch in range(num_epochs):
    loss = train_epoch_mixed_precision(
        model, train_loader, optimizer, diffusion,
        scaler=scaler,
    )
    print(f"Epoch {epoch}: Loss = {loss:.4f}")
```

BFloat16 vs Float16
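In short: FP16 has more mantissa bits (better precision) but a narrow exponent range, which is why it needs gradient scaling; BF16 keeps FP32's 8-bit exponent, so it rarely underflows and needs no GradScaler, but it requires Ampere-generation GPUs (A100, RTX 30-series) or newer. A small sketch for choosing the autocast dtype accordingly (`pick_amp_dtype` is a hypothetical helper):

```python
import torch

def pick_amp_dtype() -> torch.dtype:
    """Prefer BF16 where the hardware supports it; otherwise fall back to FP16."""
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

amp_dtype = pick_amp_dtype()
# BF16's exponent range matches FP32, so no GradScaler is needed;
# with FP16, keep the GradScaler from the example above.
use_scaler = amp_dtype is torch.float16

# Then wrap the forward pass in autocast with the chosen dtype:
# with torch.autocast(device_type="cuda", dtype=amp_dtype):
#     ...forward pass and loss computation...
```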
Distributed Training Setup
DistributedDataParallel (DDP)
For multi-GPU training, PyTorch's DistributedDataParallel is the standard approach:
```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup_distributed(rank: int, world_size: int):
    """
    Initialize distributed training.

    Args:
        rank: This process's rank (0 to world_size-1)
        world_size: Total number of processes
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize process group
    dist.init_process_group(
        backend='nccl',  # Use NCCL for GPUs
        rank=rank,
        world_size=world_size,
    )

    # Set device for this process
    torch.cuda.set_device(rank)


def cleanup_distributed():
    """Clean up distributed training."""
    dist.destroy_process_group()


def create_distributed_dataloader(
    dataset,
    batch_size: int,
    num_workers: int = 4,
) -> DataLoader:
    """
    Create a DataLoader for distributed training.

    The DistributedSampler ensures each GPU sees different data.
    """
    sampler = DistributedSampler(
        dataset,
        shuffle=True,
        drop_last=True,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,  # Use sampler instead of shuffle
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )


def train_distributed(rank: int, world_size: int, args):
    """
    Main training function for each process.

    Args:
        rank: This process's rank
        world_size: Total number of processes
        args: Training arguments
    """
    setup_distributed(rank, world_size)

    # Create model and move to this process's GPU
    model = create_model(args).to(rank)

    # Wrap model with DDP
    model = DDP(model, device_ids=[rank])

    # Create distributed dataloader
    dataset = create_dataset(args)
    dataloader = create_distributed_dataloader(
        dataset,
        batch_size=args.batch_size,  # Per-GPU batch size
    )

    # Create optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    # Training loop
    for epoch in range(args.epochs):
        # IMPORTANT: Set epoch so the sampler shuffles differently each epoch
        dataloader.sampler.set_epoch(epoch)

        for batch in dataloader:
            # Training step...
            pass

        # Only save checkpoints from rank 0
        if rank == 0:
            save_checkpoint(model.module, optimizer, epoch)

    cleanup_distributed()


# Launch distributed training
if __name__ == "__main__":
    import torch.multiprocessing as mp

    world_size = torch.cuda.device_count()
    mp.spawn(
        train_distributed,
        args=(world_size, args),
        nprocs=world_size,
        join=True,
    )
```

Using torchrun for Distributed Training
The modern way to launch distributed training is with `torchrun`:
```bash
# Single node, 4 GPUs
torchrun --nproc_per_node=4 train.py --batch_size=32

# Multi-node (2 nodes, 8 GPUs each)
# On node 0:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 \
    --master_addr="192.168.1.1" --master_port=12355 train.py

# On node 1:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 \
    --master_addr="192.168.1.1" --master_port=12355 train.py
```

```python
# In train.py, use environment variables set by torchrun
import os

import torch
import torch.distributed as dist

def setup_from_env():
    """
    Set up distributed training from torchrun environment variables.
    """
    dist.init_process_group(backend='nccl')

    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    return local_rank


# Training script compatible with torchrun
def main():
    local_rank = setup_from_env()
    world_size = dist.get_world_size()

    print(f"Running on rank {local_rank} of {world_size}")

    # ... rest of the training code ...

if __name__ == "__main__":
    main()
```

Efficient Data Loading
Prefetching and Pipelining
The DataLoader automatically prefetches data, but you can optimize further:
```python
import torch
from torch.utils.data import DataLoader

class PrefetchLoader:
    """
    Wrapper that prefetches data to GPU asynchronously.

    This overlaps CPU->GPU transfer with GPU computation.
    """

    def __init__(self, loader, device: str = "cuda"):
        self.loader = loader
        self.device = device
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        first = True
        batch = None

        for next_batch in self.loader:
            # Copy the upcoming batch to the GPU on a side stream
            with torch.cuda.stream(self.stream):
                if isinstance(next_batch, torch.Tensor):
                    next_batch = next_batch.to(
                        self.device, non_blocking=True
                    )
                elif isinstance(next_batch, (list, tuple)):
                    next_batch = [
                        b.to(self.device, non_blocking=True)
                        if isinstance(b, torch.Tensor) else b
                        for b in next_batch
                    ]

            if not first:
                yield batch
            else:
                first = False

            # Make the main stream wait for the copy before using the batch
            torch.cuda.current_stream().wait_stream(self.stream)
            batch = next_batch

        yield batch

    def __len__(self):
        return len(self.loader)


# Usage
train_loader = DataLoader(dataset, batch_size=32, num_workers=4)
prefetch_loader = PrefetchLoader(train_loader, device="cuda")

for batch in prefetch_loader:
    # batch is already on the GPU
    # ... training step ...
    pass
```

Memory-Mapped Datasets
For very large datasets, memory-mapping can reduce memory usage:
```python
import numpy as np
import torch
from torch.utils.data import Dataset

class MemoryMappedDataset(Dataset):
    """
    Dataset that uses memory-mapped numpy arrays.

    Memory mapping lets you work with datasets larger than RAM
    by loading only the accessed portions into memory.
    """

    def __init__(
        self,
        data_path: str,
        transform=None,
    ):
        """
        Args:
            data_path: Path to .npy file with shape [N, C, H, W]
            transform: Optional transform to apply
        """
        # Memory-map the array (not loaded into RAM)
        self.data = np.load(data_path, mmap_mode='r')
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Only this sample is loaded into memory
        image = self.data[idx].copy()  # copy() to avoid modifying the mmap

        # Convert to tensor
        image = torch.from_numpy(image).float()

        # Normalize to [-1, 1] if not already
        if image.max() > 1:
            image = image / 127.5 - 1.0

        if self.transform:
            image = self.transform(image)

        return image


# Create the memory-mapped file (do this once)
def create_memmap_dataset(image_folder: str, output_path: str, image_size: int):
    """Preprocess images and save as a memory-mapped numpy array."""
    from pathlib import Path

    from PIL import Image
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])

    # Find all images
    image_paths = list(Path(image_folder).rglob("*.jpg"))
    n_images = len(image_paths)

    # Create the memory-mapped output file
    shape = (n_images, 3, image_size, image_size)
    mmap = np.lib.format.open_memmap(
        output_path, mode='w+', dtype=np.float32, shape=shape
    )

    # Process and save images
    for i, path in enumerate(image_paths):
        image = Image.open(path).convert('RGB')
        tensor = transform(image)
        mmap[i] = tensor.numpy()

        if i % 1000 == 0:
            print(f"Processed {i}/{n_images} images")

    mmap.flush()
    print(f"Saved to {output_path}")
```

Complete Training Pipeline
Here's a complete, production-ready training pipeline that incorporates all the techniques we've discussed:
```python
import os
from dataclasses import dataclass

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


@dataclass
class TrainingConfig:
    """Training configuration."""
    # Data
    data_dir: str = "./data"
    image_size: int = 64
    channels: int = 3

    # Training
    batch_size: int = 32  # Per-GPU batch size
    accumulation_steps: int = 1  # Gradient accumulation
    num_epochs: int = 100
    learning_rate: float = 1e-4
    warmup_steps: int = 1000

    # Model
    timesteps: int = 1000

    # Optimization
    use_amp: bool = True  # Mixed precision
    num_workers: int = 4

    # Distributed
    distributed: bool = False

    # Checkpointing
    checkpoint_dir: str = "./checkpoints"
    save_every: int = 10  # Save every N epochs


class TrainingPipeline:
    """
    Complete training pipeline for diffusion models.

    Features:
    - Distributed training support
    - Mixed precision training
    - Gradient accumulation
    - Automatic checkpointing
    - Learning rate warmup
    """

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.device = self._setup_device()
        self.scaler = GradScaler() if config.use_amp else None

        # Will be set in setup()
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.train_loader = None
        self.diffusion = None

    def _setup_device(self) -> torch.device:
        """Set up the device and distributed training if needed."""
        if self.config.distributed:
            dist.init_process_group(backend='nccl')
            local_rank = int(os.environ.get('LOCAL_RANK', 0))
            torch.cuda.set_device(local_rank)
            return torch.device(f'cuda:{local_rank}')
        elif torch.cuda.is_available():
            return torch.device('cuda')
        else:
            return torch.device('cpu')

    @property
    def is_main_process(self) -> bool:
        """Check if this is the main process (for logging/saving)."""
        if self.config.distributed:
            return dist.get_rank() == 0
        return True

    @property
    def world_size(self) -> int:
        """Get the number of processes."""
        if self.config.distributed:
            return dist.get_world_size()
        return 1

    def setup(self, model, dataset, diffusion):
        """
        Set up training components.

        Args:
            model: The diffusion model (nn.Module)
            dataset: Training dataset
            diffusion: Diffusion scheduler
        """
        self.diffusion = diffusion

        # Move model to device
        self.model = model.to(self.device)

        # Wrap in DDP if distributed
        if self.config.distributed:
            self.model = DDP(
                self.model,
                device_ids=[self.device.index],
                output_device=self.device.index,
            )

        # Create dataloader
        if self.config.distributed:
            sampler = DistributedSampler(dataset, shuffle=True)
            self.train_loader = DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                sampler=sampler,
                num_workers=self.config.num_workers,
                pin_memory=True,
                drop_last=True,
            )
        else:
            self.train_loader = DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                shuffle=True,
                num_workers=self.config.num_workers,
                pin_memory=True,
                drop_last=True,
                persistent_workers=self.config.num_workers > 0,
            )

        # Create optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=0.01,
        )

        # Create scheduler with warmup
        total_steps = (
            len(self.train_loader) *
            self.config.num_epochs //
            self.config.accumulation_steps
        )
        self.scheduler = self._create_scheduler(total_steps)

    def _create_scheduler(self, total_steps: int):
        """Create a learning rate scheduler with warmup."""
        from torch.optim.lr_scheduler import LambdaLR

        warmup_steps = self.config.warmup_steps

        def lr_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            return 1.0  # Constant after warmup

        return LambdaLR(self.optimizer, lr_lambda)

    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch."""
        self.model.train()

        if self.config.distributed:
            self.train_loader.sampler.set_epoch(epoch)

        total_loss = 0.0
        num_steps = 0

        for batch_idx, batch in enumerate(self.train_loader):
            loss = self._train_step(batch, batch_idx)
            total_loss += loss
            num_steps += 1

        return total_loss / num_steps

    def _train_step(self, batch, batch_idx: int) -> float:
        """Execute one training step."""
        images = batch.to(self.device)
        batch_size = images.shape[0]

        # Sample timesteps
        t = torch.randint(
            0, self.config.timesteps, (batch_size,),
            device=self.device
        )

        # Sample noise
        noise = torch.randn_like(images)

        # Add noise to images
        noisy_images = self.diffusion.add_noise(images, noise, t)

        # Forward pass with optional autocast
        with autocast(enabled=self.config.use_amp):
            predicted_noise = self.model(noisy_images, t)
            loss = nn.functional.mse_loss(predicted_noise, noise)
            loss = loss / self.config.accumulation_steps

        # Backward pass
        if self.scaler is not None:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step optimizer (with accumulation)
        if (batch_idx + 1) % self.config.accumulation_steps == 0:
            if self.scaler is not None:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            self.optimizer.zero_grad()
            self.scheduler.step()

        return loss.item() * self.config.accumulation_steps

    def save_checkpoint(self, epoch: int, loss: float):
        """Save a training checkpoint (main process only)."""
        if not self.is_main_process:
            return

        os.makedirs(self.config.checkpoint_dir, exist_ok=True)

        # Get the model state (unwrap DDP if needed)
        model_state = (
            self.model.module.state_dict()
            if self.config.distributed
            else self.model.state_dict()
        )

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            'config': self.config,
        }

        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        path = os.path.join(
            self.config.checkpoint_dir,
            f'checkpoint_epoch_{epoch:04d}.pt'
        )
        torch.save(checkpoint, path)
        print(f"Saved checkpoint to {path}")

    def train(self):
        """Main training loop."""
        for epoch in range(self.config.num_epochs):
            loss = self.train_epoch(epoch)

            if self.is_main_process:
                print(f"Epoch {epoch+1}/{self.config.num_epochs}, Loss: {loss:.4f}")

            if (epoch + 1) % self.config.save_every == 0:
                self.save_checkpoint(epoch + 1, loss)

        # Final checkpoint
        self.save_checkpoint(self.config.num_epochs, loss)

        if self.config.distributed:
            dist.destroy_process_group()


# Usage example
if __name__ == "__main__":
    config = TrainingConfig(
        data_dir="./data/cifar10",
        image_size=32,
        batch_size=64,
        num_epochs=100,
    )

    pipeline = TrainingPipeline(config)

    # Create your model, dataset, and diffusion scheduler
    # model = create_model(...)
    # dataset = create_dataset(...)
    # diffusion = create_diffusion(...)

    # pipeline.setup(model, dataset, diffusion)
    # pipeline.train()
```

Key Takeaways
- Optimize DataLoader: Use pinned memory, multiple workers, and persistent workers for maximum GPU utilization.
- Use mixed precision: FP16/BF16 nearly doubles throughput and halves memory usage with minimal quality impact.
- Enable gradient accumulation: Simulate larger batch sizes when GPU memory is limited by accumulating gradients.
- Scale with DDP: Use DistributedDataParallel and torchrun for multi-GPU training with near-linear scaling.
- Prefetch data: Overlap data loading with GPU computation to eliminate I/O bottlenecks.
Looking Ahead: In the next chapter, we'll configure the model architecture, including channel multipliers, attention resolutions, and other hyperparameters that determine model capacity and training cost.