Introduction
Training transformers can take hours or days. Checkpointing allows you to save progress, recover from failures, and select the best model. This section covers comprehensive checkpoint management strategies.
What to Save in a Checkpoint
Essential Components
```text
A complete checkpoint contains:

┌──────────────────────────────────────────────────────────┐
│                     CHECKPOINT FILE                      │
├──────────────────────┬───────────────────────────────────┤
│ model_state_dict     │ All model weights                 │
│ optimizer_state_dict │ Optimizer momentum, etc.          │
│ scheduler_state_dict │ LR scheduler state                │
│ epoch                │ Current epoch number              │
│ global_step          │ Total training steps              │
│ best_metric          │ Best validation score             │
│ model_config         │ Model/training configuration      │
│ rng_states           │ Random number generator states    │
│ scaler_state_dict    │ Mixed precision scaler (optional) │
│ metrics_history      │ Training/validation metrics       │
└──────────────────────┴───────────────────────────────────┘
```
Why Each Component Matters
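The sketch below makes these components concrete. It uses a toy `nn.Linear` model as a stand-in for a transformer, an assumption made for brevity (any `nn.Module` behaves the same way). It takes one Adam step and then inspects each state dict. PyTorch stores Adam's m and v buffers under the names `exp_avg` and `exp_avg_sq`.

```python
import torch
import torch.nn as nn

# Toy stand-in for a transformer (assumption: any nn.Module works the same way)
torch.manual_seed(0)
model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

# One training step so the optimizer allocates its per-parameter state
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()
scheduler.step()

model_sd = model.state_dict()
optim_sd = optimizer.state_dict()
sched_sd = scheduler.state_dict()

print(list(model_sd.keys()))                 # parameter names
print(list(optim_sd.keys()))                 # 'state' and 'param_groups'
print(sorted(optim_sd['state'][0].keys()))   # Adam buffers for the first parameter
print(sched_sd['last_epoch'], sched_sd.get('_step_count'))

# Adam keeps two buffers shaped like each parameter: about 2x the weights
n_weights = sum(t.numel() for t in model_sd.values())
n_adam = sum(
    buf.numel()
    for state in optim_sd['state'].values()
    for name, buf in state.items()
    if name in ('exp_avg', 'exp_avg_sq')
)
print(n_weights, n_adam)
```

Both counts are element counts, so the 2x ratio holds regardless of dtype as long as the buffers match the parameters' precision.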
```text
MODEL STATE DICT:
─────────────────
Contains all learnable parameters (weights, biases) and persistent buffers.
Required to restore the model for inference or continued training.

Example keys:
  'encoder.layers.0.self_attn.W_q.weight'
  'decoder.embedding.weight'
  'output_projection.bias'


OPTIMIZER STATE DICT:
─────────────────────
Contains per-parameter optimizer state (for Adam: the first and second
moment estimates m and v, stored as exp_avg and exp_avg_sq).
Without it, training restarts with a "cold" optimizer.

For Adam, this state is about 2x the size of the model weights
(two buffers with the same shape as each parameter)!

Example:
  optimizer.state_dict()['state'][param_id] =
      {'step': tensor(1000.), 'exp_avg': <m>, 'exp_avg_sq': <v>}


SCHEDULER STATE DICT:
─────────────────────
Records how far the learning-rate schedule has progressed.
Without it, the LR schedule restarts from the beginning.

Example (after 10,000 calls to scheduler.step()):
  {'last_epoch': 10000, '_step_count': 10001, 'base_lrs': [...], ...}


RNG STATES:
───────────
Random states for Python, NumPy, and PyTorch (CPU & GPU).
Ensures exact reproducibility when resuming.

Components:
  - random.getstate()
  - np.random.get_state()
  - torch.get_rng_state()
  - torch.cuda.get_rng_state_all()  (if GPU)


METRICS HISTORY:
────────────────
Training and validation metrics over time.
Useful for plotting learning curves.

Example:
  {'train_loss': [2.5, 2.1, 1.8, ...],
   'val_loss':   [2.4, 2.0, 1.7, ...],
   'val_bleu':   [5.0, 12.0, 18.0, ...]}
```
Checkpoint Manager Implementation
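A checkpoint manager's most important guarantee is that a crash during a write never destroys the last good checkpoint. The usual pattern is to write to a temporary file and then rename it over the target. Here is a minimal standalone sketch of that pattern. `atomic_torch_save` is a hypothetical helper name. The atomicity assumes a POSIX filesystem, with the temp file in the same directory as the target so that `os.replace` is a same-filesystem rename.

```python
import os
import tempfile
from pathlib import Path

import torch


def atomic_torch_save(obj, path) -> None:
    """Hypothetical helper: write obj to a temp file next to path, then swap it in."""
    path = Path(path)
    # Temp file in the same directory so os.replace is a same-filesystem rename
    fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(obj, f)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes reach disk before the rename
        os.replace(tmp_name, path)  # readers see either the old file or the new one
    except BaseException:
        # Remove the partial temp file on any failure (including KeyboardInterrupt)
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


# Usage: overwrite the same checkpoint twice; no *.tmp files are left behind
with tempfile.TemporaryDirectory() as d:
    target = Path(d) / "checkpoint.pt"
    atomic_torch_save({"step": 1}, target)
    atomic_torch_save({"step": 2}, target)
    loaded = torch.load(target)
    leftovers = sorted(p.name for p in Path(d).iterdir())
    print(loaded, leftovers)
```

The `fsync` call trades a little save time for durability on power loss. Drop it if you only care about process crashes.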
Complete CheckpointManager Class
```python
import random
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


@dataclass
class CheckpointConfig:
    """Configuration for checkpoint management."""
    checkpoint_dir: str = "checkpoints"
    save_total_limit: int = 5          # Keep N most recent
    save_best_only: bool = False       # Only save if best
    metric_name: str = "val_loss"      # Metric to track
    metric_mode: str = "min"           # "min" or "max"
    save_optimizer: bool = True        # Include optimizer state
    save_rng_state: bool = True        # Include RNG states


class CheckpointManager:
    """
    Comprehensive checkpoint management for transformer training.

    Features:
    - Automatic best model tracking
    - Checkpoint rotation (keep N most recent)
    - Safe atomic saves (prevents corruption)
    - Easy resume from latest or best

    Args:
        config: CheckpointConfig instance
        model: The model to checkpoint
        optimizer: Optimizer instance
        scheduler: Learning rate scheduler
        scaler: GradScaler for mixed precision (optional)

    Example:
        >>> manager = CheckpointManager(config, model, optimizer, scheduler)
        >>> manager.save(epoch=10, step=50000, metrics={'val_loss': 1.5})
        >>> manager.load_best()
    """

    def __init__(
        self,
        config: CheckpointConfig,
        model: nn.Module,
        optimizer: optim.Optimizer,
        scheduler: Any,
        scaler: Optional[Any] = None
    ):
        if config.metric_mode not in ('min', 'max'):
            raise ValueError(f"metric_mode must be 'min' or 'max', got {config.metric_mode!r}")

        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.scaler = scaler

        # Create checkpoint directory
        self.checkpoint_dir = Path(config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Tracking
        self.best_metric = float('inf') if config.metric_mode == 'min' else float('-inf')
        self.checkpoints: List[Path] = []
        self.metrics_history: Dict[str, List[float]] = {}

        # Load existing checkpoints info
        self._scan_existing_checkpoints()

    def _scan_existing_checkpoints(self):
        """Scan directory for existing checkpoints (oldest first)."""
        pattern = "checkpoint_*.pt"
        existing = sorted(
            self.checkpoint_dir.glob(pattern),
            key=lambda p: p.stat().st_mtime
        )
        self.checkpoints = list(existing)

    def _is_better(self, metric: float) -> bool:
        """Check if metric is better than best."""
        if self.config.metric_mode == 'min':
            return metric < self.best_metric
        return metric > self.best_metric

    def _get_rng_states(self) -> Dict[str, Any]:
        """Capture all random states."""
        states = {
            'python': random.getstate(),
            'numpy': np.random.get_state(),
            'torch': torch.get_rng_state(),
        }
        if torch.cuda.is_available():
            states['cuda'] = torch.cuda.get_rng_state_all()
        return states

    def _set_rng_states(self, states: Dict[str, Any]):
        """Restore all random states."""
        random.setstate(states['python'])
        np.random.set_state(states['numpy'])
        torch.set_rng_state(states['torch'])
        if 'cuda' in states and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(states['cuda'])

    def save(
        self,
        epoch: int,
        step: int,
        metrics: Dict[str, float],
        model_config: Optional[Dict[str, Any]] = None,
        extra_state: Optional[Dict[str, Any]] = None
    ) -> Optional[Path]:
        """
        Save a checkpoint.

        Args:
            epoch: Current epoch number
            step: Global step count
            metrics: Dictionary of metrics (must include metric_name)
            model_config: Model configuration dictionary
            extra_state: Any additional state to save

        Returns:
            Path to saved checkpoint, or None if not saved
        """
        # Fail loudly: a silent default (e.g. 0) would be treated as a new best
        if self.config.metric_name not in metrics:
            raise KeyError(
                f"metrics must include '{self.config.metric_name}', got {sorted(metrics)}"
            )
        current_metric = metrics[self.config.metric_name]
        is_best = self._is_better(current_metric)

        # Check if we should save
        if self.config.save_best_only and not is_best:
            return None

        # Update best metric
        if is_best:
            self.best_metric = current_metric

        # Update metrics history
        for key, value in metrics.items():
            self.metrics_history.setdefault(key, []).append(value)

        # Build checkpoint
        checkpoint = {
            'epoch': epoch,
            'global_step': step,
            'model_state_dict': self.model.state_dict(),
            'best_metric': self.best_metric,
            'metrics': metrics,
            'metrics_history': self.metrics_history,
            'timestamp': datetime.now().isoformat(),
        }

        # Optional components
        if self.config.save_optimizer:
            checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()
        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        if self.config.save_rng_state:
            checkpoint['rng_states'] = self._get_rng_states()

        if model_config is not None:
            checkpoint['model_config'] = model_config

        if extra_state is not None:
            checkpoint['extra_state'] = extra_state

        # Generate filename
        filename = f"checkpoint_epoch{epoch:04d}_step{step:08d}.pt"
        filepath = self.checkpoint_dir / filename

        # Safe atomic save: write to a temp file, then rename over the target.
        # Path.replace (os.replace) overwrites an existing file and is atomic
        # on POSIX when both paths are on the same filesystem.
        temp_path = filepath.with_suffix('.pt.tmp')
        torch.save(checkpoint, temp_path)
        temp_path.replace(filepath)

        # Avoid duplicate entries if the same epoch/step is saved twice
        if filepath in self.checkpoints:
            self.checkpoints.remove(filepath)
        self.checkpoints.append(filepath)

        # Save best model separately (copy to a temp file, then swap it in)
        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            best_tmp = best_path.with_suffix('.pt.tmp')
            shutil.copy(filepath, best_tmp)
            best_tmp.replace(best_path)
            print(f"  New best model! {self.config.metric_name}: {current_metric:.4f}")

        # Rotate checkpoints (keep N most recent)
        self._rotate_checkpoints()

        return filepath

    def _rotate_checkpoints(self):
        """Remove old checkpoints beyond save_total_limit."""
        while len(self.checkpoints) > self.config.save_total_limit:
            oldest = self.checkpoints.pop(0)
            if oldest.exists():
                oldest.unlink()

    def load(
        self,
        checkpoint_path: Optional[Path] = None,
        load_optimizer: bool = True,
        load_rng_state: bool = True
    ) -> Dict[str, Any]:
        """
        Load a checkpoint.

        Args:
            checkpoint_path: Specific checkpoint to load (None = latest)
            load_optimizer: Whether to load optimizer state
            load_rng_state: Whether to load RNG states

        Returns:
            Checkpoint dictionary with metadata
        """
        if checkpoint_path is None:
            checkpoint_path = self.get_latest_checkpoint()

        if checkpoint_path is None or not Path(checkpoint_path).exists():
            raise FileNotFoundError(f"No checkpoint found at {checkpoint_path}")

        print(f"Loading checkpoint: {checkpoint_path}")
        # weights_only=False: full checkpoints contain NumPy RNG state and other
        # non-tensor objects (the default became True in PyTorch 2.6).
        # Only load checkpoint files you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # Load model
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state (load_state_dict moves buffers to each param's device)
        if load_optimizer and 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.scheduler is not None and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.scaler is not None and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        # Load RNG states
        if load_rng_state and 'rng_states' in checkpoint:
            self._set_rng_states(checkpoint['rng_states'])

        # Restore tracking
        self.best_metric = checkpoint.get('best_metric', self.best_metric)
        self.metrics_history = checkpoint.get('metrics_history', {})

        return checkpoint

    def load_best(self, **kwargs) -> Dict[str, Any]:
        """Load the best checkpoint."""
        best_path = self.checkpoint_dir / "best_model.pt"
        return self.load(best_path, **kwargs)

    def get_latest_checkpoint(self) -> Optional[Path]:
        """Get path to most recent checkpoint."""
        if not self.checkpoints:
            self._scan_existing_checkpoints()
        return self.checkpoints[-1] if self.checkpoints else None

    def get_best_checkpoint(self) -> Optional[Path]:
        """Get path to best checkpoint."""
        best_path = self.checkpoint_dir / "best_model.pt"
        return best_path if best_path.exists() else None
```
Model Selection Strategies
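The first two strategies below each reduce to an argmin or argmax over a recorded metric, and different metrics can point at different checkpoints. Here is a small pure-Python sketch over a dict in the `metrics_history` format described earlier. The numbers are made up for illustration (they are not results), and `best_index` is a hypothetical helper.

```python
from typing import Dict, List


def best_index(history: Dict[str, List[float]], metric: str, mode: str = "min") -> int:
    """Return the index of the best entry for `metric` ("min" for losses, "max" for BLEU)."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    values = history[metric]
    pick = min if mode == "min" else max
    # min/max over indices keyed by value; ties resolve to the earliest index
    return pick(range(len(values)), key=values.__getitem__)


# Illustrative numbers only: loss keeps improving while BLEU peaks one epoch earlier
history = {
    "val_loss": [2.4, 2.0, 1.7, 1.65, 1.62],
    "val_bleu": [5.0, 12.0, 18.0, 19.5, 19.1],
}
by_loss = best_index(history, "val_loss", "min")
by_bleu = best_index(history, "val_bleu", "max")
print(by_loss, by_bleu)
```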
Choosing the Best Checkpoint
```text
STRATEGY COMPARISON:
────────────────────

1. BEST VALIDATION LOSS:
   Pros: Simple, reliable
   Cons: May overfit to the validation set
   Use:  Default choice for most tasks

2. BEST BLEU SCORE:
   Pros: Directly optimizes the task metric
   Cons: BLEU is noisy; may select an overfitted model
   Use:  When BLEU is your primary metric

3. CHECKPOINT AVERAGING:
   Pros: Reduces variance; often improves BLEU by 0.5-1 point
   Cons: More complex; requires storing multiple checkpoints
   Use:  For final model submission/deployment

4. EXPONENTIAL MOVING AVERAGE (EMA):
   Pros: Smooths weights during training
   Cons: Requires keeping an extra copy of the weights
   Use:  When training is noisy

5. ENSEMBLE:
   Pros: Best quality, reduces variance
   Cons: Slower inference (one forward pass per model)
   Use:  When quality matters far more than speed


RECOMMENDED WORKFLOW:
─────────────────────

1. During training:
   - Save checkpoints every N steps
   - Track validation loss AND BLEU

2. After training:
   - Find the best single checkpoint by BLEU
   - Average the top-5 checkpoints by validation loss
   - Compare both on the test set

3. For deployment:
   - Usually: averaged checkpoint (same inference cost as a single model)
   - Ensemble only if its extra latency is acceptable
```
Checkpoint Averaging Implementation
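At its core, checkpoint averaging is an element-wise mean over tensors with matching names. The toy sketch below uses two hand-made state dicts (illustrative values, not real checkpoints) to show the arithmetic. It also shows why integer tensors, such as step counters, are better copied than averaged.

```python
import torch

# Two toy "checkpoints" of the same architecture (illustrative values)
sd_a = {"w": torch.tensor([1.0, 2.0]), "steps": torch.tensor(10)}
sd_b = {"w": torch.tensor([3.0, 6.0]), "steps": torch.tensor(20)}

averaged = {}
for key in sd_a:
    tensors = [sd_a[key], sd_b[key]]
    if tensors[0].is_floating_point():
        # Weights: element-wise mean, computed in float32 and cast back
        stacked = torch.stack([t.float() for t in tensors])
        averaged[key] = stacked.mean(dim=0).to(tensors[0].dtype)
    else:
        # Integer counters: averaging is meaningless, keep the latest value
        averaged[key] = tensors[-1].clone()

print(averaged)
```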
```python
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn as nn


class ModelSelector:
    """
    Strategies for selecting the best model from training.
    """

    @staticmethod
    def average_checkpoints(
        checkpoints: List[Path],
        model: nn.Module
    ) -> Dict[str, torch.Tensor]:
        """
        Average weights from multiple checkpoints.

        This often improves generalization. Floating-point tensors are
        averaged (accumulated in float32, then cast back to their original
        dtype); integer buffers are taken from the last checkpoint.

        Args:
            checkpoints: List of checkpoint paths
            model: Model for architecture reference (used to validate keys)

        Returns:
            Averaged state dict
        """
        if not checkpoints:
            raise ValueError("Need at least one checkpoint to average")
        print(f"Averaging {len(checkpoints)} checkpoints...")

        expected_keys = set(model.state_dict().keys())
        sums: Dict[str, torch.Tensor] = {}
        latest: Dict[str, torch.Tensor] = {}

        # Running sum: only one checkpoint is held in memory at a time
        for cp_path in checkpoints:
            checkpoint = torch.load(cp_path, map_location='cpu', weights_only=False)
            state = checkpoint['model_state_dict']
            if set(state.keys()) != expected_keys:
                raise ValueError(f"{cp_path} does not match the model architecture")
            for key, tensor in state.items():
                latest[key] = tensor
                if tensor.is_floating_point():
                    if key in sums:
                        sums[key] += tensor.float()
                    else:
                        # clone() so the in-place += never touches the loaded tensor
                        sums[key] = tensor.float().clone()

        # Divide by the count and restore each tensor's original dtype
        averaged_state = {}
        for key, tensor in latest.items():
            if key in sums:
                averaged_state[key] = (sums[key] / len(checkpoints)).to(tensor.dtype)
            else:
                averaged_state[key] = tensor.clone()

        return averaged_state

    @staticmethod
    def select_top_k_by_metric(
        checkpoints: List[Path],
        metric_name: str = "val_loss",
        mode: str = "min",
        k: int = 5
    ) -> List[Path]:
        """
        Select top K checkpoints by metric.

        Args:
            checkpoints: List of checkpoint paths
            metric_name: Metric to use
            mode: "min" or "max"
            k: Number to select

        Returns:
            List of top K checkpoint paths
        """
        checkpoint_metrics = []
        worst = float('inf') if mode == 'min' else float('-inf')

        for cp_path in checkpoints:
            checkpoint = torch.load(cp_path, map_location='cpu', weights_only=False)
            metrics = checkpoint.get('metrics', {})
            # Checkpoints missing the metric sort last
            value = metrics.get(metric_name, worst)
            checkpoint_metrics.append((cp_path, value))

        # Sort by metric
        reverse = (mode == 'max')
        sorted_cps = sorted(checkpoint_metrics, key=lambda x: x[1], reverse=reverse)

        return [cp[0] for cp in sorted_cps[:k]]
```
Exponential Moving Average (EMA)
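One useful way to read the decay value is this: with decay d, each update's contribution shrinks by a factor of d per step. The shadow therefore averages over roughly 1/(1 - d) recent steps, about 1,000 for 0.999 and 10,000 for 0.9999. It also lags. Starting from 0 and tracking a constant value c, after n updates the shadow equals c(1 - d^n). Here is a pure-Python check of that formula, using a scalar in place of a weight tensor:

```python
def ema_track(decay: float, target: float, n_steps: int, start: float = 0.0) -> float:
    """Apply shadow = decay * shadow + (1 - decay) * target for n_steps (scalar stand-in)."""
    shadow = start
    for _ in range(n_steps):
        shadow = decay * shadow + (1 - decay) * target
    return shadow


for decay in (0.999, 0.9999):
    window = 1 / (1 - decay)
    after_1k = ema_track(decay, target=1.0, n_steps=1000)
    closed_form = 1 - decay ** 1000
    print(f"decay={decay}: ~{window:.0f}-step window, "
          f"after 1000 steps shadow={after_1k:.3f} (closed form {closed_form:.3f})")
```

With 0.999 the shadow has covered about 63% of the gap after 1,000 steps. With 0.9999 it has covered under 10%, so a high decay needs a correspondingly long run before the EMA weights are meaningful.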
EMA for Smoother Models
```python
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


class EMAModel:
    """
    Exponential Moving Average of model weights.

    Maintains a shadow copy of weights that is a moving average
    of training weights. Often produces better final models.

    EMA update: shadow = decay * shadow + (1 - decay) * current

    Only parameters with requires_grad are tracked; buffers (e.g. BatchNorm
    running statistics) are left untouched.

    Args:
        model: Model to track
        decay: EMA decay rate (0.999 - 0.9999 typical)
        device: Device for EMA weights (default: each parameter's device)

    Example:
        >>> ema = EMAModel(model, decay=0.9999)
        >>> for batch in dataloader:
        ...     loss = train_step(model, batch)
        ...     ema.update()
        >>> ema.apply()     # Use EMA weights for evaluation/inference
        >>> evaluate(model)
        >>> ema.restore()   # Switch back to the training weights
    """

    def __init__(
        self,
        model: nn.Module,
        decay: float = 0.9999,
        device: Optional[torch.device] = None
    ):
        self.model = model
        self.decay = decay
        self.device = device

        # Create shadow parameters
        self.shadow: Dict[str, torch.Tensor] = {}
        self.backup: Dict[str, torch.Tensor] = {}

        for name, param in model.named_parameters():
            if param.requires_grad:
                target = param.device if device is None else device
                self.shadow[name] = param.detach().clone().to(target)

    @torch.no_grad()
    def update(self):
        """Update shadow weights in place with current model weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                shadow = self.shadow[name]
                # shadow = decay * shadow + (1 - decay) * param
                shadow.mul_(self.decay).add_(
                    param.detach().to(shadow.device), alpha=1 - self.decay
                )

    @torch.no_grad()
    def apply(self):
        """Apply EMA weights to model (backup current weights)."""
        if self.backup:
            raise RuntimeError("EMA weights already applied; call restore() first")
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.detach().clone()
                param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self):
        """Restore original weights (undo apply)."""
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.copy_(self.backup[name])
        self.backup = {}

    def state_dict(self) -> Dict[str, Any]:
        """Get EMA state dict for checkpointing (copies, since update() is in place)."""
        return {
            'shadow': {name: t.clone() for name, t in self.shadow.items()},
            'decay': self.decay
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load EMA state from checkpoint onto the tracking device."""
        params = dict(self.model.named_parameters())
        self.shadow = {
            name: t.to(params[name].device if self.device is None else self.device)
            for name, t in state_dict['shadow'].items()
        }
        self.decay = state_dict.get('decay', self.decay)
```
Resume Training Protocol
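The reason RNG states belong in the checkpoint is easiest to see with a round trip. Capture the states, draw some random numbers, restore, and draw again. This sketch covers only the CPU generators. On GPU you would also use `torch.cuda.get_rng_state_all` and `torch.cuda.set_rng_state_all`, as the CheckpointManager above does. `capture_rng`, `restore_rng` and `draw` are illustrative names.

```python
import random

import numpy as np
import torch


def capture_rng() -> dict:
    """Snapshot the CPU random generators used during training."""
    return {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.get_rng_state(),
    }


def restore_rng(states: dict) -> None:
    random.setstate(states["python"])
    np.random.set_state(states["numpy"])
    torch.set_rng_state(states["torch"])


def draw() -> tuple:
    # One draw from each generator (stand-ins for dropout, augmentation, shuffling)
    return random.random(), float(np.random.rand()), torch.rand(3).tolist()


random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

snapshot = capture_rng()
first = draw()          # what the original run saw after the checkpoint
restore_rng(snapshot)   # what a resumed run does
second = draw()
print(first == second)
```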
Complete Resume Implementation
```text
RESUME CHECKLIST:
─────────────────

☐ Load model state dict
☐ Load optimizer state dict (critical for Adam's moment estimates!)
☐ Load scheduler state dict (for the correct LR)
☐ Load RNG states (for reproducibility)
☐ Load metrics history (for plotting)
☐ Verify epoch and step counts
☐ Restore data order (same shuffle seed; skip batches already seen in a partial epoch)
```
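Restoring RNG states does not rewind a DataLoader that was interrupted mid-epoch. One common approach is sketched below. It assumes each epoch's shuffle comes from a generator seeded with `base_seed + epoch`, a hypothetical convention that the code in this section does not set up. On resume, you regenerate that epoch's permutation and skip the batches already consumed.

```python
from itertools import islice
from typing import Iterator, List

import torch


def epoch_batches(num_samples: int, batch_size: int, epoch: int,
                  base_seed: int = 0) -> Iterator[List[int]]:
    """Yield index batches for one epoch; the order depends only on (base_seed, epoch)."""
    g = torch.Generator()
    g.manual_seed(base_seed + epoch)  # hypothetical per-epoch seeding convention
    order = torch.randperm(num_samples, generator=g).tolist()
    for start in range(0, num_samples, batch_size):
        yield order[start:start + batch_size]


def resume_batches(num_samples: int, batch_size: int, epoch: int,
                   batches_done: int, base_seed: int = 0) -> Iterator[List[int]]:
    """Same order as the interrupted run, minus the batches already trained on."""
    return islice(epoch_batches(num_samples, batch_size, epoch, base_seed), batches_done, None)


full = list(epoch_batches(num_samples=10, batch_size=3, epoch=2))
resumed = list(resume_batches(10, 3, epoch=2, batches_done=2))
print(full)
print(resumed)  # same as full[2:]
```

With a `torch.utils.data.DataLoader`, the same idea applies if you pass a `torch.Generator` through its `generator` argument and re-seed it at the start of each epoch. The number of batches done can be derived from `global_step`.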
```python
import random
from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def resume_training(
    checkpoint_path: str,
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: Any,
    scaler: Optional[Any] = None
) -> Dict[str, Any]:
    """
    Resume training from checkpoint.

    Call this after the model has been moved to its training device and the
    optimizer has been built from the model's parameters.

    Returns:
        Dictionary with resume state (epoch, step, etc.)
    """
    print(f"Resuming from {checkpoint_path}")

    # Load checkpoint. weights_only=False because it also holds non-tensor
    # objects (NumPy RNG state); only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Restore model
    model.load_state_dict(checkpoint['model_state_dict'])

    # Restore optimizer (CRITICAL for Adam!). Optimizer.load_state_dict moves
    # state buffers such as exp_avg/exp_avg_sq to each parameter's device,
    # so no manual .to(device) loop is needed.
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Restore scheduler
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Restore mixed precision scaler
    if scaler is not None and 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])

    # Restore RNG states for exact reproducibility
    if 'rng_states' in checkpoint:
        rng_states = checkpoint['rng_states']
        random.setstate(rng_states['python'])
        np.random.set_state(rng_states['numpy'])
        torch.set_rng_state(rng_states['torch'])
        if 'cuda' in rng_states and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(rng_states['cuda'])

    resume_state = {
        # Epoch stored at save time; if saved at epoch end, continue from epoch + 1
        'epoch': checkpoint['epoch'],
        'global_step': checkpoint['global_step'],
        'best_metric': checkpoint.get('best_metric'),
        'metrics_history': checkpoint.get('metrics_history', {})
    }

    print(f"  Resumed at epoch {resume_state['epoch']}, "
          f"step {resume_state['global_step']}")

    return resume_state
```
Export for Deployment
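Whatever format you ship, the first check is that the exported weights reproduce the training model's outputs. The sketch below uses a toy `nn.Sequential` and an in-memory buffer as stand-ins for your model and output file. It calls `eval()` on both models, because dropout would otherwise change the outputs on every call.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Toy stand-in; a real export would rebuild the model from its saved config
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.1), nn.Linear(16, 4))


torch.manual_seed(0)
trained = make_model().eval()  # eval() disables dropout
buffer = io.BytesIO()
torch.save({"model_state_dict": trained.state_dict()}, buffer)  # inference-only payload

buffer.seek(0)
restored = make_model()
restored.load_state_dict(torch.load(buffer)["model_state_dict"])  # tensors only
restored.eval()

x = torch.randn(5, 8)
with torch.no_grad():
    max_diff = (trained(x) - restored(x)).abs().max().item()
print(f"max abs difference: {max_diff}")
```

For TorchScript or ONNX exports the same comparison applies, but expect small floating-point differences rather than exact equality.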
Preparing Models for Production
```text
DEPLOYMENT FORMATS:
───────────────────

1. PYTORCH CHECKPOINT (inference-only):
   Pros: Native PyTorch, easy to load
   Cons: Requires the PyTorch runtime and the model's Python class
   Use:  Python deployment, research

   torch.save({'model_state_dict': model.state_dict()}, 'model.pt')

2. TORCHSCRIPT:
   Pros: Runs without Python (e.g. from C++ via LibTorch), optimized
   Cons: Some ops and Python features unsupported; tracing records only
         the control-flow path taken by the example inputs
   Use:  Production, mobile, C++ deployment

   traced = torch.jit.trace(model, example_inputs)
   traced.save('model_traced.pt')

3. ONNX:
   Pros: Framework agnostic, wide support
   Cons: May lose some functionality
   Use:  Cross-platform, TensorRT, ONNX Runtime

   torch.onnx.export(model, example_inputs, 'model.onnx')


DEPLOYMENT CHECKLIST:
─────────────────────

☐ Switch to inference mode (model.eval() disables dropout)
☐ Remove unused layers
☐ Benchmark latency
☐ Verify output quality
☐ Test edge cases
☐ Document input/output format
☐ Version the model
☐ Set up monitoring
```
Export Inference Checkpoint
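How much smaller is an inference-only file? With Adam, a training checkpoint stores the weights plus two buffers of the same shape. Dropping the optimizer state alone should therefore give roughly a 3x reduction when weights and buffers share a dtype. Mixed-precision weights or extra saved state change the ratio. Here is a quick in-memory measurement on a toy model, where `serialized_mb` is an illustrative helper:

```python
import io

import torch
import torch.nn as nn


def serialized_mb(obj) -> float:
    """Size of torch.save(obj) in megabytes, measured in memory."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return len(buffer.getvalue()) / 1024 / 1024


torch.manual_seed(0)
model = nn.Linear(512, 512)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(4, 512)).sum().backward()
optimizer.step()  # allocates exp_avg and exp_avg_sq

training_ckpt = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
inference_ckpt = {"model_state_dict": model.state_dict()}

full_mb = serialized_mb(training_ckpt)
slim_mb = serialized_mb(inference_ckpt)
print(f"training: {full_mb:.2f} MB, inference: {slim_mb:.2f} MB, "
      f"ratio {full_mb / slim_mb:.2f}x")
```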
```python
import os

import torch


class ModelExporter:
    """
    Export trained models for deployment.
    """

    @staticmethod
    def export_inference_checkpoint(
        checkpoint_path: str,
        output_path: str,
        include_config: bool = True
    ):
        """
        Export minimal checkpoint for inference only.

        Removes optimizer state, RNG states, and other training data.
        With Adam, the result is often roughly a third of the original size.

        Args:
            checkpoint_path: Full training checkpoint
            output_path: Output path for inference checkpoint
            include_config: Whether to include model config
        """
        # weights_only=False: the full training checkpoint also contains
        # non-tensor objects (e.g. NumPy RNG state). Only load trusted files.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        inference_checkpoint = {
            'model_state_dict': checkpoint['model_state_dict'],
        }

        if include_config and 'model_config' in checkpoint:
            inference_checkpoint['model_config'] = checkpoint['model_config']

        torch.save(inference_checkpoint, output_path)

        # Report size reduction
        original_size = os.path.getsize(checkpoint_path) / 1024 / 1024
        new_size = os.path.getsize(output_path) / 1024 / 1024

        print("Exported inference checkpoint:")
        print(f"  Original:  {original_size:.1f} MB")
        print(f"  Inference: {new_size:.1f} MB")
        print(f"  Reduction: {(1 - new_size / original_size) * 100:.1f}%")
```
Summary
Checkpointing Key Points
| Component | Purpose | Required for Resume |
|---|---|---|
| model_state_dict | Model weights | Yes |
| optimizer_state_dict | Momentum buffers | Yes (for training) |
| scheduler_state_dict | LR schedule position | Yes (for training) |
| rng_states | Reproducibility | Optional |
| metrics_history | Learning curves | Optional |
Model Selection Strategies
| Strategy | Best For | Quality Gain |
|---|---|---|
| Best validation loss | Default choice | Baseline |
| Best BLEU | Task-specific | Variable |
| Checkpoint averaging | Production | +0.5-1 BLEU |
| EMA | Noisy training | Smoother model |
Best Practices
- Always save optimizer state for training continuation
- Use atomic saves to prevent corruption
- Rotate checkpoints to manage disk space
- Average checkpoints for final model
- Export inference-only checkpoint for deployment
Exercises
Implementation
- Implement cyclical checkpoint saving (save every K steps, but also save at specific loss thresholds).
- Add support for distributed training checkpoints (gather from all ranks).
- Implement checkpoint comparison tool (diff two checkpoints).
Analysis
- Compare single best vs averaged checkpoint on test set.
- Experiment with different EMA decay values.
In the next chapter, we'll cover Evaluation Metrics: BLEU score computation and other metrics for measuring translation quality.