Learning Objectives
By the end of this section, you will:
- Know what to save in a complete checkpoint
- Understand when to checkpoint: best model vs. periodic
- Design checkpoint structure for full recoverability
- Implement save and load functions in PyTorch
- Resume training from interrupted runs
Why This Matters: Training deep learning models can take hours or days. Without proper checkpointing, a crash, power outage, or preemption means starting from scratch. Good checkpointing saves the best model for deployment and enables seamless recovery from any interruption.
What to Checkpoint
A complete checkpoint must contain everything needed to resume training or deploy the model.
Essential Components
| Component | Purpose | Required For |
|---|---|---|
| Model state_dict | Learned weights | Both inference and training |
| Optimizer state_dict | Momentum, adaptive LR states | Resume training |
| Epoch number | Training progress | Resume training |
| Best metrics | Selection criterion | Model comparison |
| EMA shadow weights | Smoothed weights | Stable inference |
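In practice, these components are bundled into a single dictionary and written with `torch.save`. A minimal round-trip sketch (an in-memory buffer stands in for a real checkpoint file; the model and metric values are placeholders):

```python
import io
import torch
import torch.nn as nn

# Sketch: bundle the components above into one object and read it back.
model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters())

checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 12,
    'best_rmse': 14.2,
}

# An in-memory buffer stands in for a checkpoint file on disk.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

loaded = torch.load(buffer, map_location='cpu')
print(sorted(loaded.keys()))
# ['best_rmse', 'epoch', 'model_state_dict', 'optimizer_state_dict']
```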
Model State Dictionary
The model's state_dict() contains all learnable parameters:
```python
# Model state_dict structure
model_state = model.state_dict()

# Example keys for DualTaskEnhancedModel:
# 'feature_extractor.conv1.weight'
# 'feature_extractor.conv1.bias'
# 'feature_extractor.lstm.weight_ih_l0'
# 'rul_head.0.weight'
# 'health_head.0.weight'
# ... all learnable parameters
```
Optimizer State Dictionary
AdamW maintains per-parameter state that must be preserved:
```python
# Optimizer state_dict structure
optimizer_state = optimizer.state_dict()

# Contains:
# - 'state': Per-parameter momentum and variance estimates
#   - 'exp_avg': First moment (momentum)
#   - 'exp_avg_sq': Second moment (for adaptive LR)
#   - 'step': Number of updates
# - 'param_groups': Learning rate and hyperparameters

# Example:
# optimizer_state['state'][0] = {
#     'step': 1000,
#     'exp_avg': tensor(...),
#     'exp_avg_sq': tensor(...)
# }
```
Why Optimizer State Matters
Without optimizer state, AdamW restarts from scratch. This means losing all accumulated momentum and variance estimates built over potentially hundreds of epochs. The result: training instability and degraded final performance.
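The effect is easy to verify: a freshly constructed AdamW has empty per-parameter state, while `load_state_dict` restores the step counts and moment buffers. A small sketch (the toy model and step count are illustrative):

```python
import torch
import torch.nn as nn

# Sketch: AdamW's accumulated state survives a state_dict round trip.
model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Take a few steps so the optimizer accumulates momentum/variance state.
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

saved = optimizer.state_dict()

# A fresh optimizer starts with no per-parameter state...
fresh = torch.optim.AdamW(model.parameters(), lr=1e-3)
assert len(fresh.state_dict()['state']) == 0

# ...but restoring the state_dict brings back the step count and both moments.
fresh.load_state_dict(saved)
restored = fresh.state_dict()['state'][0]
print(int(restored['step']), sorted(restored.keys()))
```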
EMA Shadow Weights
For EMA-based inference, save the shadow parameters:
```python
# EMA shadow weights
if ema is not None:
    ema_shadow = copy.deepcopy(ema.shadow)
    # shadow is a dict: {param_name: smoothed_tensor}
```
When to Checkpoint
Different checkpointing strategies serve different purposes.
Best Model Checkpoint
Save when the validation metric improves; this is the checkpoint you deploy:
```python
# Save best model during training
if rmse_last < best_rmse_last_cycle:
    best_rmse_last_cycle = rmse_last
    best_epoch = epoch

    # Create complete checkpoint
    model.eval()
    best_model_state = {
        'model_state_dict': copy.deepcopy(model.state_dict()),
        'optimizer_state_dict': copy.deepcopy(optimizer.state_dict()),
        'epoch': epoch,
        'rmse': rmse_last,
        'metrics': copy.deepcopy(eval_results)
    }

    # Include EMA if using
    if ema:
        best_model_state['ema_shadow'] = copy.deepcopy(ema.shadow)

    logger.info(f"New best model! RMSE: {rmse_last:.2f}")
```
Periodic Checkpoints
Save at regular intervals for recovery from interruption:
```python
# Periodic checkpoint (every N epochs)
CHECKPOINT_INTERVAL = 50

if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_rmse': best_rmse_last_cycle,
        'history': dict(history),
    }
    if ema:
        checkpoint['ema_shadow'] = ema.shadow

    torch.save(checkpoint, f'checkpoints/epoch_{epoch+1}.pt')
    logger.info(f"Periodic checkpoint saved at epoch {epoch+1}")
```
Interrupt Recovery
Handle keyboard interrupts gracefully:
```python
try:
    for epoch in range(start_epoch, epochs):
        # Training loop
        train_epoch(...)

except KeyboardInterrupt:
    logger.warning("Training interrupted by user")

    # Save emergency checkpoint
    emergency_checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_model_state': best_model_state,
        'history': dict(history),
    }
    torch.save(emergency_checkpoint, 'checkpoints/interrupted.pt')
    logger.info("Emergency checkpoint saved")

finally:
    # Always save best model at end
    if best_model_state:
        torch.save(best_model_state, 'models/best_model.pt')
```

| Strategy | Frequency | Purpose |
|---|---|---|
| Best model | On improvement | Deployment |
| Periodic | Every N epochs | Crash recovery |
| Interrupt | On Ctrl+C | Emergency recovery |
| Final | End of training | Complete archive |
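These strategies typically run side by side in one training loop. A minimal sketch of the per-epoch decision (the helper name and values are illustrative, not AMNL code):

```python
# Sketch: which checkpoint types to write this epoch (lower metric = better).
# `checkpoint_actions` and its arguments are illustrative names.
def checkpoint_actions(epoch: int, metric: float, best_metric: float,
                       interval: int = 50) -> list:
    actions = []
    if metric < best_metric:         # validation improved -> deployment candidate
        actions.append('best')
    if (epoch + 1) % interval == 0:  # periodic crash-recovery snapshot
        actions.append('periodic')
    return actions

print(checkpoint_actions(49, 14.2, 15.0))  # ['best', 'periodic']
print(checkpoint_actions(10, 16.0, 15.0))  # []
```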
Checkpoint Structure
A well-designed checkpoint includes everything needed for both training and deployment.
Complete Checkpoint Schema
```python
# Complete checkpoint structure for AMNL
checkpoint = {
    # Model weights (required for inference)
    'model_state_dict': model.state_dict(),

    # Optimizer state (required for training resume)
    'optimizer_state_dict': optimizer.state_dict(),

    # Scheduler state (required for LR schedule continuation)
    'scheduler_state_dict': scheduler.state_dict(),

    # Dataset information
    'dataset': dataset_name,

    # Training state
    'epoch': epoch,
    'best_epoch': best_epoch,
    'best_rmse': best_rmse_last_cycle,

    # Metrics at checkpoint
    'final_metrics': convert_numpy_types(final_results),

    # Complete training history
    'training_history': convert_numpy_types(dict(history)),

    # Model configuration
    'model_config': {
        'version': 'V7',
        'input_size': input_size,
        'sequence_length': 30,
        'hidden_size': 256,
        'epochs_trained': epoch + 1,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'random_seed': random_seed,
        'amnl_config': {
            'loss': 'weighted MSE (linear)',
            'task_weights': '0.5/0.5',
            'scheduler': 'ReduceLROnPlateau',
            'warmup': '10 epochs',
            'early_stopping': '80 epochs'
        }
    },

    # Data preprocessing parameters (critical for inference!)
    'scaler_params': train_dataset.get_scaler_params(),

    # EMA weights (optional but recommended)
    'ema_shadow': ema.shadow if ema else None,
}
```
Scaler Parameters
Never forget to save data normalization parameters! Without the same scaler used during training, test-time predictions will be meaningless. The scaler_params should include means, standard deviations, and any per-condition normalization statistics.
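How those saved statistics are applied at inference can be sketched in plain Python. The `'mean'`/`'std'` keys below are an assumption about what `get_scaler_params()` returns, not the actual AMNL schema:

```python
# Sketch: applying training-time z-score statistics to a raw sensor window.
# The 'mean'/'std' keys are assumed; check get_scaler_params() for the real schema.
def normalize(window, scaler_params):
    mean, std = scaler_params['mean'], scaler_params['std']
    return [[(x - m) / s for x, m, s in zip(row, mean, std)] for row in window]

scaler_params = {'mean': [10.0, 0.0], 'std': [2.0, 1.0]}  # loaded from checkpoint
window = [[12.0, 1.0], [8.0, -1.0]]                       # raw sensor readings
print(normalize(window, scaler_params))  # [[1.0, 1.0], [-1.0, -1.0]]
```

Applying different statistics at test time (or skipping normalization entirely) silently shifts every input feature, which is why the scaler belongs inside the checkpoint rather than alongside it.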
Checkpoint File Organization
```
models/
└── nasa_cmapss_v7/
    ├── FD001_v7_seed42.pth           # Final model (deployment)
    ├── FD001_v7_seed42_history.json  # Training history
    └── checkpoints/
        ├── epoch_50.pt               # Periodic checkpoint
        ├── epoch_100.pt              # Periodic checkpoint
        └── best_epoch_87.pt          # Best model checkpoint

logs/
└── FD001_v7_20240115_143022.log      # Training log
```
Implementation
Complete save and load functions for the AMNL training pipeline.
Save Checkpoint Function
```python
import copy
from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn as nn

def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    epoch: int,
    metrics: Dict,
    history: Dict,
    filepath: str,
    ema: Optional[ExponentialMovingAverage] = None,
    scaler_params: Optional[Dict] = None,
    model_config: Optional[Dict] = None,
    is_best: bool = False
):
    """
    Save a complete training checkpoint.

    Args:
        model: Trained model
        optimizer: Optimizer with accumulated state
        scheduler: Learning rate scheduler
        epoch: Current epoch number
        metrics: Evaluation metrics at this epoch
        history: Complete training history
        filepath: Path to save checkpoint
        ema: Optional EMA tracker
        scaler_params: Data normalization parameters
        model_config: Model architecture configuration
        is_best: Whether this is the best model so far
    """
    # Create directory if needed
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    # Build checkpoint dictionary
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': copy.deepcopy(model.state_dict()),
        'optimizer_state_dict': copy.deepcopy(optimizer.state_dict()),
        'scheduler_state_dict': copy.deepcopy(scheduler.state_dict()),
        'metrics': convert_numpy_types(metrics),
        'history': convert_numpy_types(dict(history)),
        'is_best': is_best,
    }

    # Add optional components
    if ema is not None:
        checkpoint['ema_shadow'] = copy.deepcopy(ema.shadow)

    if scaler_params is not None:
        checkpoint['scaler_params'] = scaler_params

    if model_config is not None:
        checkpoint['model_config'] = model_config

    # Save checkpoint
    torch.save(checkpoint, filepath)

    return filepath
```
Load Checkpoint Function
```python
def load_checkpoint(
    filepath: str,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler = None,
    ema: Optional[ExponentialMovingAverage] = None,
    device: torch.device = torch.device('cpu')
) -> Dict:
    """
    Load a training checkpoint.

    Args:
        filepath: Path to checkpoint file
        model: Model to load weights into
        optimizer: Optional optimizer to restore state
        scheduler: Optional scheduler to restore state
        ema: Optional EMA to restore shadow weights
        device: Device to map tensors to

    Returns:
        Dictionary with checkpoint metadata
    """
    # Load checkpoint (on PyTorch >= 2.6, pass weights_only=False if the
    # checkpoint stores non-tensor objects such as numpy arrays)
    checkpoint = torch.load(filepath, map_location=device)

    # Load model weights
    model.load_state_dict(checkpoint['model_state_dict'])

    # Load optimizer state if provided
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Load scheduler state if provided
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Load EMA shadow if provided
    if ema is not None and 'ema_shadow' in checkpoint:
        ema.shadow = checkpoint['ema_shadow']

    # Return metadata for training resume
    return {
        'epoch': checkpoint.get('epoch', 0),
        'metrics': checkpoint.get('metrics', {}),
        'history': checkpoint.get('history', {}),
        'model_config': checkpoint.get('model_config', {}),
        'scaler_params': checkpoint.get('scaler_params', None),
        'is_best': checkpoint.get('is_best', False),
    }
```
Resume Training
```python
from collections import defaultdict
from typing import Dict, Tuple

def resume_training(
    checkpoint_path: str,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: torch.device
) -> Tuple[int, Dict, Dict]:
    """
    Resume training from a checkpoint.

    Returns:
        Tuple of (start_epoch, history, metrics)
    """
    print(f"Resuming from checkpoint: {checkpoint_path}")

    # Load checkpoint
    metadata = load_checkpoint(
        checkpoint_path,
        model,
        optimizer,
        scheduler,
        device=device
    )

    start_epoch = metadata['epoch'] + 1
    history = defaultdict(list, metadata.get('history', {}))

    print(f"Resuming from epoch {start_epoch}")
    print(f"Previous best RMSE: {metadata['metrics'].get('RMSE_last_cycle', 'N/A')}")

    return start_epoch, history, metadata['metrics']


# Usage example
if args.resume:
    start_epoch, history, prev_metrics = resume_training(
        args.resume_path,
        model,
        optimizer,
        scheduler,
        device
    )
else:
    start_epoch = 0
    history = defaultdict(list)

# Continue training from start_epoch
for epoch in range(start_epoch, total_epochs):
    train_epoch(...)
```
Final Model Save
```python
# Save final model for deployment
os.makedirs(output_dir, exist_ok=True)
model_path = f'{output_dir}/{dataset_name}_v7_seed{random_seed}.pth'

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'dataset': dataset_name,
    'final_metrics': convert_numpy_types(final_results),
    'training_history': convert_numpy_types(dict(history)),
    'best_epoch': best_epoch,
    'model_config': {
        'version': 'V7',
        'input_size': input_size,
        'sequence_length': 30,
        'hidden_size': 256,
        'epochs_trained': epoch + 1,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'random_seed': random_seed,
        'amnl_config': {
            'loss': 'weighted MSE (linear)',
            'task_weights': '0.5/0.5',
            'scheduler': 'ReduceLROnPlateau',
            'warmup': '10 epochs',
            'early_stopping': '80 epochs'
        }
    },
    'scaler_params': train_dataset.get_scaler_params()
}, model_path)

logger.info(f"Model saved to {model_path}")
```
Best Model Loading at Training End
Always restore the best model before final evaluation and saving. The model at the last epoch is rarely the best due to overfitting or noise.
```python
# At end of training, restore best model
if best_model_state:
    logger.info(f"Loading best model from epoch {best_epoch+1}")
    model.load_state_dict(best_model_state['model_state_dict'])

    # Also restore EMA if used
    if ema and 'ema_shadow' in best_model_state:
        ema.shadow = best_model_state['ema_shadow']
        ema.apply_shadow(model)  # Apply EMA for final evaluation
```
Summary
In this section, we covered checkpointing strategy:
- What to save: Model, optimizer, scheduler, EMA, scaler params
- When to save: Best model, periodic, on interrupt
- Checkpoint structure: Complete metadata for recovery
- Save/load functions: Reusable checkpoint utilities
- Resume training: Continue from any checkpoint
| Component | Why Essential |
|---|---|
| model_state_dict | Learned weights for inference |
| optimizer_state_dict | Momentum/variance for training |
| scheduler_state_dict | LR schedule continuation |
| ema_shadow | Smoothed weights for stable inference |
| scaler_params | Data normalization for deployment |
| training_history | Analysis and visualization |
Looking Ahead: We have all the pieces: training loop, validation, monitoring, and checkpointing. The final section brings everything together in the complete training script walkthrough, the production-ready code that achieves state-of-the-art on NASA C-MAPSS.
With checkpointing complete, we assemble the full training script.