Chapter 10

Checkpointing and Model Selection

Training Pipeline

Introduction

Training transformers can take hours or days. Checkpointing allows you to save progress, recover from failures, and select the best model. This section covers comprehensive checkpoint management strategies.


What to Save in a Checkpoint

Essential Components

πŸ“text
1A complete checkpoint contains:
2
3β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4β”‚                     CHECKPOINT FILE                          β”‚
5β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
6β”‚  model_state_dict      β”‚ All model weights                  β”‚
7β”‚  optimizer_state_dict  β”‚ Optimizer momentum, etc.           β”‚
8β”‚  scheduler_state_dict  β”‚ LR scheduler state                 β”‚
9β”‚  epoch                 β”‚ Current epoch number               β”‚
10β”‚  global_step           β”‚ Total training steps               β”‚
11β”‚  best_metric           β”‚ Best validation score              β”‚
12β”‚  config                β”‚ Model/training configuration       β”‚
13β”‚  rng_states            β”‚ Random number generator states     β”‚
14β”‚  scaler_state_dict     β”‚ Mixed precision scaler (optional)  β”‚
15β”‚  metrics_history       β”‚ Training/validation metrics        β”‚
16β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Why Each Component Matters

πŸ“text
1MODEL STATE DICT:
2─────────────────
3Contains all learnable parameters (weights, biases).
4Required to restore model for inference or continued training.
5
6Example keys:
7  'encoder.layers.0.self_attn.W_q.weight'
8  'decoder.embedding.weight'
9  'output_projection.bias'
10
11
12OPTIMIZER STATE DICT:
13─────────────────────
14Contains momentum buffers (Adam: m, v for each parameter).
15Without this, training restarts from "cold" optimizer.
16
17For Adam, this can be 2x the model size!
18
19Example:
20  state[param_id] = {'step': 1000, 'm': tensor, 'v': tensor}
21
22
23SCHEDULER STATE DICT:
24─────────────────────
25Contains current step count for learning rate.
26Without this, LR schedule restarts from beginning.
27
28Example:
29  {'_step_count': 10000, 'last_epoch': 50}
30
31
32RNG STATES:
33──────────
34Random states for Python, NumPy, PyTorch (CPU & GPU).
35Ensures exact reproducibility when resuming.
36
37Components:
38  - random.getstate()
39  - np.random.get_state()
40  - torch.get_rng_state()
41  - torch.cuda.get_rng_state_all() (if GPU)
42
43
44METRICS HISTORY:
45────────────────
46Training and validation metrics over time.
47Useful for plotting learning curves.
48
49Example:
50  {'train_loss': [2.5, 2.1, 1.8, ...],
51   'val_loss': [2.4, 2.0, 1.7, ...],
52   'val_bleu': [5.0, 12.0, 18.0, ...]}
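Two of the details above can be checked directly in a few lines. This is a minimal sketch assuming a recent PyTorch release; `scheduler_counters` and `rng_roundtrip_matches` are hypothetical helper names, and `_step_count` is a private scheduler attribute whose behaviour may vary between versions. The first helper steps a `LambdaLR` and reads back the counters stored in its state dict; the second saves every RNG state, draws random numbers (including a dropout mask), restores the states, and draws again.

```python
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def scheduler_counters(num_steps: int) -> dict:
    """Step a LambdaLR `num_steps` times and report the counters it saves."""
    model = nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Constant multiplier: only the counters matter here
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
    for _ in range(num_steps):
        optimizer.step()   # no gradients exist, so parameters stay unchanged
        scheduler.step()
    state = scheduler.state_dict()
    return {'last_epoch': state['last_epoch'], '_step_count': state['_step_count']}


def rng_roundtrip_matches() -> bool:
    """Save all RNG states, draw, restore, draw again: the draws should be identical."""
    saved = (random.getstate(), np.random.get_state(), torch.get_rng_state())

    def draw():
        # A dropout mask stands in for the randomness a training step consumes
        mask = F.dropout(torch.ones(16), p=0.5, training=True)
        return random.random(), float(np.random.rand()), mask

    first = draw()
    random.setstate(saved[0])
    np.random.set_state(saved[1])
    torch.set_rng_state(saved[2])
    second = draw()
    return (first[0] == second[0] and first[1] == second[1]
            and torch.equal(first[2], second[2]))
```

If the second draw ever differs from the first in a real training script, some source of randomness is not covered by the saved states.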

Checkpoint Manager Implementation

Complete CheckpointManager Class

```python
import os
import random
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


@dataclass
class CheckpointConfig:
    """Configuration for checkpoint management."""
    checkpoint_dir: str = "checkpoints"
    save_total_limit: int = 5           # Keep N most recent
    save_best_only: bool = False        # Only save if best
    metric_name: str = "val_loss"       # Metric to track
    metric_mode: str = "min"            # "min" or "max"
    save_optimizer: bool = True         # Include optimizer/scheduler/scaler state
    save_rng_state: bool = True         # Include RNG states


class CheckpointManager:
    """
    Comprehensive checkpoint management for transformer training.

    Features:
    - Automatic best model tracking
    - Checkpoint rotation (keep N most recent)
    - Safe atomic saves (prevents corruption)
    - Easy resume from latest or best

    Args:
        config: CheckpointConfig instance
        model: The model to checkpoint
        optimizer: Optimizer instance
        scheduler: Learning rate scheduler (may be None)
        scaler: GradScaler for mixed precision (optional)

    Example:
        >>> manager = CheckpointManager(config, model, optimizer, scheduler)
        >>> manager.save(epoch=10, step=50000, metrics={'val_loss': 1.5})
        >>> manager.load_best()
    """

    def __init__(
        self,
        config: CheckpointConfig,
        model: nn.Module,
        optimizer: optim.Optimizer,
        scheduler: Any,
        scaler: Optional[Any] = None
    ):
        if config.metric_mode not in ('min', 'max'):
            raise ValueError("metric_mode must be 'min' or 'max'")

        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.scaler = scaler

        # Create checkpoint directory
        self.checkpoint_dir = Path(config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Tracking
        self.best_metric = float('inf') if config.metric_mode == 'min' else float('-inf')
        self.checkpoints: List[Path] = []
        self.metrics_history: Dict[str, List[float]] = {}

        # Pick up checkpoints left by a previous run (for rotation and resume)
        self._scan_existing_checkpoints()

    def _scan_existing_checkpoints(self):
        """Scan directory for existing checkpoints, oldest first."""
        existing = sorted(
            self.checkpoint_dir.glob("checkpoint_*.pt"),
            key=lambda p: p.stat().st_mtime
        )
        self.checkpoints = list(existing)

    def _is_better(self, metric: float) -> bool:
        """Check if metric is better than best."""
        if self.config.metric_mode == 'min':
            return metric < self.best_metric
        return metric > self.best_metric

    def _get_rng_states(self) -> Dict[str, Any]:
        """Capture all random states."""
        states = {
            'python': random.getstate(),
            'numpy': np.random.get_state(),
            'torch': torch.get_rng_state(),
        }
        if torch.cuda.is_available():
            states['cuda'] = torch.cuda.get_rng_state_all()
        return states

    def _set_rng_states(self, states: Dict[str, Any]):
        """Restore all random states."""
        random.setstate(states['python'])
        np.random.set_state(states['numpy'])
        torch.set_rng_state(states['torch'])
        if 'cuda' in states and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(states['cuda'])

    @staticmethod
    def _atomic_save(obj: Any, path: Path):
        """Write to a temp file, then atomically replace the target."""
        temp_path = path.with_name(path.name + '.tmp')
        torch.save(obj, temp_path)
        # os.replace is atomic on the same filesystem and, unlike
        # Path.rename, also overwrites an existing target on Windows
        os.replace(temp_path, path)

    def save(
        self,
        epoch: int,
        step: int,
        metrics: Dict[str, float],
        model_config: Optional[Dict[str, Any]] = None,
        extra_state: Optional[Dict[str, Any]] = None
    ) -> Optional[Path]:
        """
        Save a checkpoint.

        Args:
            epoch: Current epoch number
            step: Global step count
            metrics: Dictionary of metrics (must include config.metric_name)
            model_config: Model configuration dictionary
            extra_state: Any additional state to save

        Returns:
            Path to saved checkpoint, or None if not saved
        """
        # Fail loudly: silently defaulting to 0 would look like a new
        # "best" value under metric_mode='min'
        if self.config.metric_name not in metrics:
            raise KeyError(f"metrics must include '{self.config.metric_name}'")
        current_metric = metrics[self.config.metric_name]
        is_best = self._is_better(current_metric)

        # Record history for every evaluation, even ones not written to disk
        for key, value in metrics.items():
            self.metrics_history.setdefault(key, []).append(value)

        # Check if we should save
        if self.config.save_best_only and not is_best:
            return None

        # Update best metric
        if is_best:
            self.best_metric = current_metric

        # Build checkpoint
        checkpoint = {
            'epoch': epoch,
            'global_step': step,
            'model_state_dict': self.model.state_dict(),
            'best_metric': self.best_metric,
            'metrics': metrics,
            'metrics_history': self.metrics_history,
            'timestamp': datetime.now().isoformat(),
        }

        # Optional components
        if self.config.save_optimizer:
            checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()
            if self.scheduler is not None:
                checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
            if self.scaler is not None:
                checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        if self.config.save_rng_state:
            checkpoint['rng_states'] = self._get_rng_states()

        if model_config is not None:
            checkpoint['model_config'] = model_config

        if extra_state is not None:
            checkpoint['extra_state'] = extra_state

        # Generate filename (zero-padded so names sort chronologically)
        filename = f"checkpoint_epoch{epoch:04d}_step{step:08d}.pt"
        filepath = self.checkpoint_dir / filename

        # Safe atomic save (write to temp, then replace)
        self._atomic_save(checkpoint, filepath)

        if filepath not in self.checkpoints:
            self.checkpoints.append(filepath)

        # Save best model separately (copy to temp, then atomic replace)
        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            temp_best = best_path.with_name(best_path.name + '.tmp')
            shutil.copyfile(filepath, temp_best)
            os.replace(temp_best, best_path)
            print(f"  New best model! {self.config.metric_name}: {current_metric:.4f}")

        # Rotate checkpoints (keep N most recent)
        self._rotate_checkpoints()

        return filepath

    def _rotate_checkpoints(self):
        """Remove old checkpoints beyond save_total_limit."""
        while len(self.checkpoints) > self.config.save_total_limit:
            oldest = self.checkpoints.pop(0)
            if oldest.exists():
                oldest.unlink()

    def load(
        self,
        checkpoint_path: Optional[Path] = None,
        load_optimizer: bool = True,
        load_rng_state: bool = True
    ) -> Dict[str, Any]:
        """
        Load a checkpoint.

        Args:
            checkpoint_path: Specific checkpoint to load (None = latest)
            load_optimizer: Whether to load optimizer/scheduler/scaler state
            load_rng_state: Whether to load RNG states

        Returns:
            Checkpoint dictionary with metadata
        """
        if checkpoint_path is None:
            checkpoint_path = self.get_latest_checkpoint()
        else:
            checkpoint_path = Path(checkpoint_path)

        if checkpoint_path is None or not checkpoint_path.exists():
            raise FileNotFoundError(f"No checkpoint found at {checkpoint_path}")

        print(f"Loading checkpoint: {checkpoint_path}")
        # weights_only=False: full checkpoints hold NumPy RNG state and other
        # Python objects that the weights-only unpickler (the default since
        # PyTorch 2.6) can reject. Only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # Load model
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state
        if load_optimizer and 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            if self.scheduler is not None and 'scheduler_state_dict' in checkpoint:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            if self.scaler is not None and 'scaler_state_dict' in checkpoint:
                self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        # Load RNG states
        if load_rng_state and 'rng_states' in checkpoint:
            self._set_rng_states(checkpoint['rng_states'])

        # Restore tracking
        self.best_metric = checkpoint.get('best_metric', self.best_metric)
        self.metrics_history = checkpoint.get('metrics_history', {})

        return checkpoint

    def load_best(self, **kwargs) -> Dict[str, Any]:
        """Load the best checkpoint."""
        best_path = self.checkpoint_dir / "best_model.pt"
        return self.load(best_path, **kwargs)

    def get_latest_checkpoint(self) -> Optional[Path]:
        """Get path to most recent checkpoint."""
        if not self.checkpoints:
            self._scan_existing_checkpoints()
        return self.checkpoints[-1] if self.checkpoints else None

    def get_best_checkpoint(self) -> Optional[Path]:
        """Get path to best checkpoint."""
        best_path = self.checkpoint_dir / "best_model.pt"
        return best_path if best_path.exists() else None
```

Model Selection Strategies

Choosing the Best Checkpoint

πŸ“text
1STRATEGY COMPARISON:
2────────────────────
3
41. BEST VALIDATION LOSS:
5   Pros: Simple, reliable
6   Cons: May overfit to validation set
7   Use: Default choice for most tasks
8
92. BEST BLEU SCORE:
10   Pros: Direct optimization for task metric
11   Cons: BLEU is noisy, may select overfitted model
12   Use: When BLEU is your primary metric
13
143. CHECKPOINT AVERAGING:
15   Pros: Reduces variance, often improves by 0.5-1 BLEU
16   Cons: More complex, requires storing multiple checkpoints
17   Use: For final model submission/deployment
18
194. EXPONENTIAL MOVING AVERAGE (EMA):
20   Pros: Smooth weights during training
21   Cons: Requires maintaining EMA copy
22   Use: When training is noisy
23
245. ENSEMBLE:
25   Pros: Best quality, reduces variance
26   Cons: Slower inference (multiple forward passes)
27   Use: When quality >> speed
28
29
30RECOMMENDED WORKFLOW:
31────────────────────
32
331. During training:
34   - Save checkpoints every N steps
35   - Track validation loss AND BLEU
36
372. After training:
38   - Find best single checkpoint by BLEU
39   - Average top-5 checkpoints by validation loss
40   - Compare both on test set
41
423. For deployment:
43   - Usually: averaged checkpoint
44   - If latency critical: single best checkpoint
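The ensemble option can be sketched as averaging each model's predicted distribution. This is an illustrative sketch rather than any particular library's API: `ensemble_log_probs` is a hypothetical name, and it assumes every model is already in eval mode and returns logits over the same vocabulary. It averages probabilities; averaging log-probabilities instead (a geometric mean) is another common choice.

```python
import math
from typing import List

import torch
import torch.nn as nn


@torch.no_grad()
def ensemble_log_probs(models: List[nn.Module], inputs: torch.Tensor) -> torch.Tensor:
    """
    Log of the averaged predicted probabilities across several models.

    Assumes each model is in eval mode and maps `inputs` to logits of shape
    (..., vocab_size). Cost: one forward pass per model.
    """
    log_probs = torch.stack([m(inputs).log_softmax(dim=-1) for m in models])
    # log(mean_i p_i) = logsumexp_i(log p_i) - log(n), computed without underflow
    return torch.logsumexp(log_probs, dim=0) - math.log(len(models))
```

During beam search, a function like this would take the place of the single model's log-softmax at every decoding step, which is where the extra forward passes come from.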

Checkpoint Averaging Implementation

```python
from pathlib import Path
from typing import Dict, List, Optional

import torch
import torch.nn as nn


class ModelSelector:
    """
    Strategies for selecting the best model from training.
    """

    @staticmethod
    def average_checkpoints(
        checkpoints: List[Path],
        model: Optional[nn.Module] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Average weights from multiple checkpoints.

        This often improves generalization. The checkpoints should come from
        the same run (same architecture and parameter names).

        Args:
            checkpoints: List of checkpoint paths
            model: Optional model used to check that the averaged keys match

        Returns:
            Averaged state dict (original dtypes preserved)
        """
        if not checkpoints:
            raise ValueError("Need at least one checkpoint to average")
        print(f"Averaging {len(checkpoints)} checkpoints...")

        # Running sum: only one checkpoint is held in memory at a time
        summed: Dict[str, torch.Tensor] = {}
        dtypes: Dict[str, torch.dtype] = {}
        for cp_path in checkpoints:
            checkpoint = torch.load(cp_path, map_location='cpu', weights_only=False)
            for key, tensor in checkpoint['model_state_dict'].items():
                if tensor.is_floating_point():
                    # Accumulate in float32 so fp16/bf16 weights don't lose precision
                    value = tensor.float()
                    summed[key] = summed[key] + value if key in summed else value
                    dtypes[key] = tensor.dtype
                else:
                    # Integer buffers (e.g. counters) can't be averaged; keep the latest
                    summed[key] = tensor

        averaged_state = {
            key: (value / len(checkpoints)).to(dtypes[key]) if key in dtypes else value
            for key, value in summed.items()
        }

        if model is not None and set(averaged_state) != set(model.state_dict()):
            raise KeyError("Averaged checkpoint keys do not match the model")

        return averaged_state

    @staticmethod
    def select_top_k_by_metric(
        checkpoints: List[Path],
        metric_name: str = "val_loss",
        mode: str = "min",
        k: int = 5
    ) -> List[Path]:
        """
        Select top K checkpoints by metric.

        Args:
            checkpoints: List of checkpoint paths
            metric_name: Metric to use
            mode: "min" or "max"
            k: Number to select

        Returns:
            List of top K checkpoint paths
        """
        checkpoint_metrics = []
        worst = float('inf') if mode == 'min' else float('-inf')

        for cp_path in checkpoints:
            checkpoint = torch.load(cp_path, map_location='cpu', weights_only=False)
            metrics = checkpoint.get('metrics', {})
            # Checkpoints without the metric sort last
            value = metrics.get(metric_name, worst)
            checkpoint_metrics.append((cp_path, value))

        # Sort by metric
        reverse = (mode == 'max')
        sorted_cps = sorted(checkpoint_metrics, key=lambda x: x[1], reverse=reverse)

        return [cp[0] for cp in sorted_cps[:k]]
```
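A convenient unit test for averaging code: for a model that is linear in its weights, such as a single `nn.Linear`, the model built from averaged weights produces exactly the average of the individual models' outputs. The sketch below saves a few toy checkpoints to a temporary directory and checks that property; `average_files` is a hypothetical, stripped-down restatement of the averaging step, not a replacement for `ModelSelector.average_checkpoints`. For a transformer the equivalence no longer holds because of the non-linearities, which is one reason averaging is usually applied to checkpoints from late in a single run, where the weights stay close together.

```python
import tempfile
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn as nn


def average_files(paths: List[Path]) -> Dict[str, torch.Tensor]:
    """Minimal weight averaging over files of the form {'model_state_dict': ...}."""
    total: Dict[str, torch.Tensor] = {}
    for path in paths:
        # These toy files contain only tensors, so the weights-only loader suffices
        state = torch.load(path, map_location='cpu', weights_only=True)['model_state_dict']
        for key, tensor in state.items():
            total[key] = total[key] + tensor.float() if key in total else tensor.float()
    return {key: value / len(paths) for key, value in total.items()}


def averaging_matches_output_mean(num_checkpoints: int = 3) -> bool:
    """For a linear model: output of averaged weights == mean of outputs."""
    torch.manual_seed(0)
    x = torch.randn(4, 3)
    models = [nn.Linear(3, 2) for _ in range(num_checkpoints)]
    with tempfile.TemporaryDirectory() as tmp:
        paths = []
        for i, m in enumerate(models):
            path = Path(tmp) / f"checkpoint_{i}.pt"
            torch.save({'model_state_dict': m.state_dict()}, path)
            paths.append(path)
        averaged = nn.Linear(3, 2)
        averaged.load_state_dict(average_files(paths))
    with torch.no_grad():
        mean_of_outputs = torch.stack([m(x) for m in models]).mean(dim=0)
        output_of_mean = averaged(x)
    return torch.allclose(mean_of_outputs, output_of_mean, atol=1e-5)
```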

Exponential Moving Average (EMA)

EMA for Smoother Models

```python
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


class EMAModel:
    """
    Exponential Moving Average of model weights.

    Maintains a shadow copy of weights that is a moving average
    of training weights. Often produces better final models.

    EMA update: shadow = decay * shadow + (1 - decay) * current

    Only parameters are tracked; buffers (e.g. BatchNorm running stats)
    are left as they are in the model.

    Args:
        model: Model to track
        decay: EMA decay rate (0.999 - 0.9999 typical)
        device: Device for EMA weights (default: same as each parameter)

    Example:
        >>> ema = EMAModel(model, decay=0.9999)
        >>> for batch in dataloader:
        ...     loss = train_step(model, batch)
        ...     ema.update()
        >>> ema.apply()    # Use EMA weights for evaluation/inference
        >>> ema.restore()  # Switch back to the raw training weights
    """

    def __init__(
        self,
        model: nn.Module,
        decay: float = 0.9999,
        device: Optional[torch.device] = None
    ):
        self.model = model
        self.decay = decay
        self.device = device

        # Create shadow parameters
        self.shadow: Dict[str, torch.Tensor] = {}
        self.backup: Dict[str, torch.Tensor] = {}

        for name, param in model.named_parameters():
            if param.requires_grad:
                # detach().clone(): an independent copy outside the autograd graph
                shadow = param.detach().clone()
                self.shadow[name] = shadow.to(device) if device is not None else shadow

    @torch.no_grad()
    def update(self):
        """Update shadow weights with current model weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                shadow = self.shadow[name]
                # In-place EMA update; .to() handles a shadow kept on another device
                shadow.mul_(self.decay).add_(
                    param.detach().to(shadow.device), alpha=1 - self.decay
                )

    @torch.no_grad()
    def apply(self):
        """Apply EMA weights to model (backup current weights)."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.detach().clone()
                param.copy_(self.shadow[name])  # copy_ also handles device transfer

    @torch.no_grad()
    def restore(self):
        """Restore original weights (undo apply)."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.copy_(self.backup[name])
        self.backup = {}

    def state_dict(self) -> Dict[str, Any]:
        """Get EMA state dict for checkpointing."""
        return {
            'shadow': dict(self.shadow),
            'decay': self.decay
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load EMA state, placing tensors on self.device or each parameter's device."""
        params = dict(self.model.named_parameters())
        self.shadow = {
            name: tensor.to(self.device if self.device is not None else params[name].device)
            for name, tensor in state_dict['shadow'].items()
        }
        self.decay = state_dict.get('decay', self.decay)
```
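The decay value matters a lot early in training. Unrolling the update rule shows that after T updates the shadow still carries a share of decay**T of the weights it started from: about 0.905 for decay = 0.9999 after 1,000 steps, versus about 0.368 for 0.999. A common remedy is to ramp the decay up over time. The sketch below uses min(max_decay, (1 + step) / (10 + step)), the formula TensorFlow's `tf.train.ExponentialMovingAverage` applies when it is given `num_updates`; the function names are hypothetical.

```python
def init_weight_fraction(decay: float, steps: int) -> float:
    """
    Share of the initial weights still present in the EMA after `steps` updates.

    Unrolling shadow_t = decay * shadow_{t-1} + (1 - decay) * w_t gives
    shadow_T = decay**T * w_0 + sum_{t=1..T} (1 - decay) * decay**(T - t) * w_t,
    so the initial weights keep a coefficient of decay**T.
    """
    return decay ** steps


def warmup_decay(step: int, max_decay: float = 0.9999) -> float:
    """Decay ramp in the style of TensorFlow's ExponentialMovingAverage(num_updates=...)."""
    return min(max_decay, (1 + step) / (10 + step))


if __name__ == "__main__":
    for decay in (0.999, 0.9999):
        share = init_weight_fraction(decay, 1000)
        print(f"decay={decay}: {share:.3f} of the initial weights left after 1000 steps")
    for step in (0, 10, 100, 1000, 100000):
        print(f"step {step}: effective decay {warmup_decay(step):.5f}")
```

Because `EMAModel.update` reads `self.decay` on every call, the ramp can be applied by setting `ema.decay = warmup_decay(step)` just before `ema.update()`.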

Resume Training Protocol

Complete Resume Implementation

πŸ“text
1RESUME CHECKLIST:
2─────────────────
3
4βœ“ Load model state dict
5βœ“ Load optimizer state dict (critical for Adam momentum!)
6βœ“ Load scheduler state dict (for correct LR)
7βœ“ Load RNG states (for reproducibility)
8βœ“ Load metrics history (for plotting)
9βœ“ Verify epoch and step counts
10βœ“ Confirm batch ordering (same shuffle seed)
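For the last checklist item, one approach is to derive the shuffle order from a seed plus the epoch number, so a restarted job can rebuild exactly the same order and skip the batches it already processed. This is a sketch under assumptions: `make_loader`, `remaining_batches`, and `base_seed` are hypothetical, and the number of batches already done within the epoch would have to be stored in the checkpoint (for example through `extra_state`).

```python
import itertools
from typing import List

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def make_loader(dataset: Dataset, epoch: int, base_seed: int = 1234,
                batch_size: int = 4) -> DataLoader:
    """Shuffle with a generator seeded from (base_seed, epoch): same epoch, same order."""
    generator = torch.Generator()
    generator.manual_seed(base_seed + epoch)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=generator)


def remaining_batches(dataset: Dataset, epoch: int, batches_done: int) -> List[torch.Tensor]:
    """Rebuild the epoch's loader and skip the batches consumed before the interruption."""
    loader = make_loader(dataset, epoch)
    # islice still loads the skipped batches; for large datasets, skipping at
    # the sampler (index) level is cheaper
    return [x for (x,) in itertools.islice(loader, batches_done, None)]
```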
```python
import random
from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def resume_training(
    checkpoint_path: str,
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: Any,
    scaler: Optional[Any] = None
) -> Dict[str, Any]:
    """
    Resume training from checkpoint.

    Move the model to its training device and build the optimizer and
    scheduler for it *before* calling this function.

    Returns:
        Dictionary with resume state (epoch, step, etc.)
    """
    print(f"Resuming from {checkpoint_path}")

    # Load checkpoint (weights_only=False: it holds NumPy RNG state and other
    # Python objects; only load checkpoints you trust)
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Restore model
    model.load_state_dict(checkpoint['model_state_dict'])

    # Restore optimizer (CRITICAL for Adam!). Optimizer.load_state_dict moves
    # buffers such as exp_avg/exp_avg_sq onto each parameter's device, so no
    # manual device transfer is needed.
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Restore scheduler
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Restore mixed precision scaler
    if scaler is not None and 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])

    # Restore RNG states for reproducibility
    if 'rng_states' in checkpoint:
        rng_states = checkpoint['rng_states']
        random.setstate(rng_states['python'])
        np.random.set_state(rng_states['numpy'])
        torch.set_rng_state(rng_states['torch'])
        if 'cuda' in rng_states and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(rng_states['cuda'])

    resume_state = {
        # Values recorded at save time; if the checkpoint was written at the
        # end of an epoch, continue from epoch + 1
        'epoch': checkpoint['epoch'],
        'global_step': checkpoint['global_step'],
        'best_metric': checkpoint.get('best_metric'),
        'metrics_history': checkpoint.get('metrics_history', {})
    }

    print(f"  Resumed at epoch {resume_state['epoch']}, "
          f"step {resume_state['global_step']}")

    return resume_state
```
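An end-to-end test of the protocol: train for a fixed number of steps in one go, then repeat the run with an interruption in the middle (save, build fresh objects, restore, continue) and compare the final weights. This is a self-contained sketch with a toy model, synthetic data, and hypothetical helper names (`build`, `run`, `final_params`); it saves to an in-memory buffer instead of a file. On CPU the resumed run is expected to match bit for bit; on GPU, non-deterministic kernels can introduce small differences, so a tolerance is more appropriate there. Skipping the optimizer and scheduler state (a "cold" restart) makes the runs diverge.

```python
import io
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

TOTAL_STEPS = 20
_data_gen = torch.Generator()
_data_gen.manual_seed(0)
# Fixed synthetic batches indexed by step, independent of the global RNG
X = torch.randn(TOTAL_STEPS, 4, 8, generator=_data_gen)
Y = torch.randn(TOTAL_STEPS, 4, 1, generator=_data_gen)


def build(seed: int):
    torch.manual_seed(seed)
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.1), nn.Linear(16, 1))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    # Linear warmup over 5 steps, so a restarted schedule would change the result
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: min(1.0, (s + 1) / 5))
    return model, optimizer, scheduler


def run(model, optimizer, scheduler, start: int, stop: int):
    model.train()  # dropout active: consumes the global torch RNG
    for step in range(start, stop):
        loss = F.mse_loss(model(X[step]), Y[step])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()


def final_params(resume_at: Optional[int] = None, restore_optimizer: bool = True) -> torch.Tensor:
    model, optimizer, scheduler = build(seed=0)
    if resume_at is None:
        run(model, optimizer, scheduler, 0, TOTAL_STEPS)
    else:
        run(model, optimizer, scheduler, 0, resume_at)
        buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'rng_state': torch.get_rng_state()}, buffer)
        buffer.seek(0)
        # Fresh objects built with a different seed simulate a new process
        model, optimizer, scheduler = build(seed=123)
        ckpt = torch.load(buffer, weights_only=False)  # trusted: produced just above
        model.load_state_dict(ckpt['model_state_dict'])
        if restore_optimizer:
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        torch.set_rng_state(ckpt['rng_state'])
        run(model, optimizer, scheduler, resume_at, TOTAL_STEPS)
    return torch.cat([p.detach().flatten() for p in model.parameters()])
```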

Export for Deployment

Preparing Models for Production

πŸ“text
1DEPLOYMENT FORMATS:
2───────────────────
3
41. PYTORCH CHECKPOINT (inference-only):
5   Pros: Native PyTorch, easy to load
6   Cons: Requires PyTorch runtime
7   Use: Python deployment, research
8
9   torch.save({'model_state_dict': model.state_dict()}, 'model.pt')
10
112. TORCHSCRIPT:
12   Pros: No Python dependency, optimized
13   Cons: Some ops not supported
14   Use: Production, mobile, C++ deployment
15
16   traced = torch.jit.trace(model, inputs)
17   traced.save('model.pt')
18
193. ONNX:
20   Pros: Framework agnostic, wide support
21   Cons: May lose some functionality
22   Use: Cross-platform, TensorRT, ONNX Runtime
23
24   torch.onnx.export(model, inputs, 'model.onnx')
25
26
27DEPLOYMENT CHECKLIST:
28─────────────────────
29
30βœ“ Remove dropout (model.eval())
31βœ“ Remove unused layers
32βœ“ Benchmark latency
33βœ“ Verify output quality
34βœ“ Test edge cases
35βœ“ Document input/output format
36βœ“ Version the model
37βœ“ Set up monitoring
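Before shipping a TorchScript export, it is worth checking that the saved file reproduces the eager model, including on an input shape different from the tracing example. A minimal sketch using `torch.jit.trace`, `ScriptModule.save`, and `torch.jit.load` on a toy model; `export_and_reload` and `exported_matches_eager` are hypothetical names. Tracing records only the operations executed for the example input, so models with data-dependent control flow need more care, and newer PyTorch releases steer users toward `torch.export`, so check the documentation for your version.

```python
import os
import tempfile

import torch
import torch.nn as nn


def export_and_reload(model: nn.Module, example: torch.Tensor, path: str) -> torch.jit.ScriptModule:
    """Trace in eval mode, save to `path`, and load the file back."""
    model.eval()  # dropout off, so the traced graph is the inference graph
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    traced.save(path)
    return torch.jit.load(path)


def exported_matches_eager() -> bool:
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.1), nn.Linear(16, 4))
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model_traced.pt")
        loaded = export_and_reload(model, torch.randn(2, 8), path)
        fresh = torch.randn(5, 8)  # different batch size than the tracing example
        with torch.no_grad():
            return torch.allclose(model(fresh), loaded(fresh), atol=1e-6)
```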

Export Inference Checkpoint

```python
import os

import torch


class ModelExporter:
    """
    Export trained models for deployment.
    """

    @staticmethod
    def export_inference_checkpoint(
        checkpoint_path: str,
        output_path: str,
        include_config: bool = True
    ):
        """
        Export minimal checkpoint for inference only.

        Removes optimizer, scheduler, scaler, and RNG states plus metrics
        history. Much smaller file size (with Adam, the optimizer state alone
        is about twice the size of the weights).

        Args:
            checkpoint_path: Full training checkpoint
            output_path: Output path for inference checkpoint
            include_config: Whether to include model config
        """
        # weights_only=False: the training checkpoint holds non-tensor objects
        # (e.g. RNG state); only load checkpoints you trust
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        inference_checkpoint = {
            'model_state_dict': checkpoint['model_state_dict'],
        }

        if include_config and 'model_config' in checkpoint:
            inference_checkpoint['model_config'] = checkpoint['model_config']

        torch.save(inference_checkpoint, output_path)

        # Report size reduction
        original_size = os.path.getsize(checkpoint_path) / 1024 / 1024
        new_size = os.path.getsize(output_path) / 1024 / 1024

        print("Exported inference checkpoint:")
        print(f"  Original: {original_size:.1f} MB")
        print(f"  Inference: {new_size:.1f} MB")
        print(f"  Reduction: {(1 - new_size / original_size) * 100:.1f}%")
```
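A rough check of the expected savings: Adam keeps two buffers (`exp_avg`, `exp_avg_sq`) per parameter, so a checkpoint with model and optimizer state should be about three times the size of the inference-only file, before smaller extras such as RNG state and metrics. The sketch below measures this for a toy `nn.Linear`; `full_vs_inference_size` is a hypothetical helper, and real checkpoints will also carry scheduler, scaler, and metadata entries.

```python
import os
import tempfile

import torch
import torch.nn as nn


def full_vs_inference_size(model: nn.Linear) -> float:
    """Return full-checkpoint size divided by inference-only size (Adam optimizer)."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(2, model.in_features)).sum().backward()
    optimizer.step()  # Adam creates exp_avg / exp_avg_sq lazily on the first step
    with tempfile.TemporaryDirectory() as tmp:
        full_path = os.path.join(tmp, "full.pt")
        slim_path = os.path.join(tmp, "inference.pt")
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, full_path)
        torch.save({'model_state_dict': model.state_dict()}, slim_path)
        return os.path.getsize(full_path) / os.path.getsize(slim_path)
```

For example, `print(f"{full_vs_inference_size(nn.Linear(1024, 1024)):.2f}x")` reports the ratio for a larger layer, where fixed serialization overhead matters even less.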

Summary

Checkpointing Key Points

| Component | Purpose | Required for Resume |
|---|---|---|
| model_state_dict | Model weights | Yes |
| optimizer_state_dict | Momentum buffers | Yes (for training) |
| scheduler_state_dict | LR schedule position | Yes (for training) |
| rng_states | Reproducibility | Optional |
| metrics_history | Learning curves | Optional |

Model Selection Strategies

| Strategy | Best For | Quality Gain |
|---|---|---|
| Best validation loss | Default choice | Baseline |
| Best BLEU | Task-specific | Variable |
| Checkpoint averaging | Production | +0.5-1 BLEU |
| EMA | Noisy training | Smoother model |

Best Practices

  • Always save optimizer state for training continuation
  • Use atomic saves to prevent corruption
  • Rotate checkpoints to manage disk space
  • Average checkpoints for final model
  • Export inference-only checkpoint for deployment

Exercises

Implementation

  • Implement cyclical checkpoint saving (save every K steps, but also save at specific loss thresholds).
  • Add support for distributed training checkpoints (gather from all ranks).
  • Implement checkpoint comparison tool (diff two checkpoints).

Analysis

  • Compare single best vs averaged checkpoint on test set.
  • Experiment with different EMA decay values.

In the next chapter, we'll cover Evaluation Metrics: specifically, BLEU score computation and other metrics for measuring translation quality.
