Chapter 14

Checkpointing Strategy

Complete Training Script

Learning Objectives

By the end of this section, you will:

  1. Know what to save in a complete checkpoint
  2. Understand when to checkpoint: best model vs. periodic
  3. Design checkpoint structure for full recoverability
  4. Implement save and load functions in PyTorch
  5. Resume training from interrupted runs
Why This Matters: Training deep learning models can take hours or days. Without proper checkpointing, a crash, power outage, or preemption means starting from scratch. Good checkpointing saves the best model for deployment and enables seamless recovery from any interruption.

What to Checkpoint

A complete checkpoint must contain everything needed to resume training or deploy the model.

Essential Components

| Component | Purpose | Required For |
| --- | --- | --- |
| Model state_dict | Learned weights | Both inference and training |
| Optimizer state_dict | Momentum, adaptive LR states | Resume training |
| Epoch number | Training progress | Resume training |
| Best metrics | Selection criterion | Model comparison |
| EMA shadow weights | Smoothed weights | Stable inference |

Model State Dictionary

The model's state_dict() contains all learnable parameters:

๐Ÿpython
1# Model state_dict structure
2model_state = model.state_dict()
3
4# Example keys for DualTaskEnhancedModel:
5# 'feature_extractor.conv1.weight'
6# 'feature_extractor.conv1.bias'
7# 'feature_extractor.lstm.weight_ih_l0'
8# 'rul_head.0.weight'
9# 'health_head.0.weight'
10# ... all learnable parameters

Optimizer State Dictionary

AdamW maintains per-parameter state that must be preserved:

๐Ÿpython
1# Optimizer state_dict structure
2optimizer_state = optimizer.state_dict()
3
4# Contains:
5# - 'state': Per-parameter momentum and variance estimates
6#   - 'exp_avg': First moment (momentum)
7#   - 'exp_avg_sq': Second moment (for adaptive LR)
8#   - 'step': Number of updates
9# - 'param_groups': Learning rate and hyperparameters
10
11# Example:
12# optimizer_state['state'][0] = {
13#     'step': 1000,
14#     'exp_avg': tensor(...),
15#     'exp_avg_sq': tensor(...)
16# }

Why Optimizer State Matters

Without optimizer state, AdamW restarts from scratch. This means losing all accumulated momentum and variance estimates built over potentially hundreds of epochs. The result: training instability and degraded final performance.
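
This is easy to demonstrate with a minimal sketch (a toy linear model, not the chapter's model): a freshly constructed AdamW has no per-parameter state at all until one is restored, while `load_state_dict` brings back the exact step count and moment buffers.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Take a few steps so the optimizer accumulates momentum/variance state
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

saved = optimizer.state_dict()

# A brand-new optimizer has empty state -- this is what "restarting
# from scratch" means for AdamW
fresh = torch.optim.AdamW(model.parameters(), lr=1e-3)
assert len(fresh.state_dict()['state']) == 0

# Restoring the state_dict recovers step count and both moment buffers
fresh.load_state_dict(saved)
restored = fresh.state_dict()['state'][0]
assert restored['step'] == saved['state'][0]['step']
assert torch.equal(restored['exp_avg'], saved['state'][0]['exp_avg'])
```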

EMA Shadow Weights

For EMA-based inference, save the shadow parameters:

๐Ÿpython
1# EMA shadow weights
2if ema is not None:
3    ema_shadow = copy.deepcopy(ema.shadow)
4    # shadow is a dict: {param_name: smoothed_tensor}
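
The ExponentialMovingAverage tracker used throughout this section is not a PyTorch built-in. A minimal sketch of what it could look like, matching the interface the checkpointing code relies on (`shadow` and `apply_shadow`); the decay value and the `restore` helper are assumptions, not part of the original pipeline:

```python
import torch
import torch.nn as nn

class ExponentialMovingAverage:
    """Minimal EMA tracker sketch: shadow is {param_name: smoothed_tensor}."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {name: p.detach().clone()
                       for name, p in model.named_parameters() if p.requires_grad}
        self.backup = {}

    def update(self, model: nn.Module):
        """Blend current weights into the shadow after each optimizer step."""
        with torch.no_grad():
            for name, p in model.named_parameters():
                if p.requires_grad:
                    self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    def apply_shadow(self, model: nn.Module):
        """Swap smoothed weights in for evaluation (back up raw weights first)."""
        self.backup = {name: p.detach().clone()
                       for name, p in model.named_parameters() if p.requires_grad}
        with torch.no_grad():
            for name, p in model.named_parameters():
                if p.requires_grad:
                    p.copy_(self.shadow[name])

    def restore(self, model: nn.Module):
        """Put the raw training weights back after evaluation."""
        with torch.no_grad():
            for name, p in model.named_parameters():
                if p.requires_grad:
                    p.copy_(self.backup[name])
```

Because `shadow` is a plain dict of tensors, it serializes cleanly inside a checkpoint and can be reassigned directly on load.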

When to Checkpoint

Different checkpointing strategies serve different purposes.

Best Model Checkpoint

Save when the validation metric improves; this is the checkpoint you deploy:

๐Ÿpython
1# Save best model during training
2if rmse_last < best_rmse_last_cycle:
3    best_rmse_last_cycle = rmse_last
4    best_epoch = epoch
5
6    # Create complete checkpoint
7    model.eval()
8    best_model_state = {
9        'model_state_dict': copy.deepcopy(model.state_dict()),
10        'optimizer_state_dict': copy.deepcopy(optimizer.state_dict()),
11        'epoch': epoch,
12        'rmse': rmse_last,
13        'metrics': copy.deepcopy(eval_results)
14    }
15
16    # Include EMA if using
17    if ema:
18        best_model_state['ema_shadow'] = copy.deepcopy(ema.shadow)
19
20    logger.info(f"New best model! RMSE: {rmse_last:.2f}")

Periodic Checkpoints

Save at regular intervals for recovery from interruption:

๐Ÿpython
1# Periodic checkpoint (every N epochs)
2CHECKPOINT_INTERVAL = 50
3
4if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
5    checkpoint = {
6        'epoch': epoch,
7        'model_state_dict': model.state_dict(),
8        'optimizer_state_dict': optimizer.state_dict(),
9        'scheduler_state_dict': scheduler.state_dict(),
10        'best_rmse': best_rmse_last_cycle,
11        'history': dict(history),
12    }
13    if ema:
14        checkpoint['ema_shadow'] = ema.shadow
15
16    torch.save(checkpoint, f'checkpoints/epoch_{epoch+1}.pt')
17    logger.info(f"Periodic checkpoint saved at epoch {epoch+1}")

Interrupt Recovery

Handle keyboard interrupts gracefully:

๐Ÿpython
1try:
2    for epoch in range(start_epoch, epochs):
3        # Training loop
4        train_epoch(...)
5
6except KeyboardInterrupt:
7    logger.warning("Training interrupted by user")
8
9    # Save emergency checkpoint
10    emergency_checkpoint = {
11        'epoch': epoch,
12        'model_state_dict': model.state_dict(),
13        'optimizer_state_dict': optimizer.state_dict(),
14        'scheduler_state_dict': scheduler.state_dict(),
15        'best_model_state': best_model_state,
16        'history': dict(history),
17    }
18    torch.save(emergency_checkpoint, 'checkpoints/interrupted.pt')
19    logger.info("Emergency checkpoint saved")
20
21finally:
22    # Always save best model at end
23    if best_model_state:
24        torch.save(best_model_state, 'models/best_model.pt')
| Strategy | Frequency | Purpose |
| --- | --- | --- |
| Best model | On improvement | Deployment |
| Periodic | Every N epochs | Crash recovery |
| Interrupt | On Ctrl+C | Emergency recovery |
| Final | End of training | Complete archive |

Checkpoint Structure

A well-designed checkpoint includes everything needed for both training and deployment.

Complete Checkpoint Schema

๐Ÿpython
1# Complete checkpoint structure for AMNL
2checkpoint = {
3    # Model weights (required for inference)
4    'model_state_dict': model.state_dict(),
5
6    # Optimizer state (required for training resume)
7    'optimizer_state_dict': optimizer.state_dict(),
8
9    # Scheduler state (required for LR schedule continuation)
10    'scheduler_state_dict': scheduler.state_dict(),
11
12    # Dataset information
13    'dataset': dataset_name,
14
15    # Training state
16    'epoch': epoch,
17    'best_epoch': best_epoch,
18    'best_rmse': best_rmse_last_cycle,
19
20    # Metrics at checkpoint
21    'final_metrics': convert_numpy_types(final_results),
22
23    # Complete training history
24    'training_history': convert_numpy_types(dict(history)),
25
26    # Model configuration
27    'model_config': {
28        'version': 'V7',
29        'input_size': input_size,
30        'sequence_length': 30,
31        'hidden_size': 256,
32        'epochs_trained': epoch + 1,
33        'batch_size': batch_size,
34        'learning_rate': learning_rate,
35        'random_seed': random_seed,
36        'amnl_config': {
37            'loss': 'weighted MSE (linear)',
38            'task_weights': '0.5/0.5',
39            'scheduler': 'ReduceLROnPlateau',
40            'warmup': '10 epochs',
41            'early_stopping': '80 epochs'
42        }
43    },
44
45    # Data preprocessing parameters (critical for inference!)
46    'scaler_params': train_dataset.get_scaler_params(),
47
48    # EMA weights (optional but recommended)
49    'ema_shadow': ema.shadow if ema else None,
50}

Scaler Parameters

Never forget to save data normalization parameters! Without the same scaler used during training, test-time predictions will be meaningless. The scaler_params should include means, standard deviations, and any per-condition normalization statistics.
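
As a sketch of why this matters at inference time: raw test features must be normalized with the *training* statistics pulled from the checkpoint, never with statistics computed on the test set. The `{'mean': ..., 'std': ...}` layout of scaler_params here is an assumption for illustration; the actual structure comes from `get_scaler_params()`.

```python
import numpy as np

def normalize_for_inference(features: np.ndarray, scaler_params: dict) -> np.ndarray:
    """Apply the training-time normalization to raw inference features.

    Hypothetical sketch: assumes scaler_params stores per-feature
    'mean' and 'std' arrays, as saved in the checkpoint.
    """
    mean = np.asarray(scaler_params['mean'])
    std = np.asarray(scaler_params['std'])
    return (features - mean) / (std + 1e-8)  # epsilon guards constant features

# Usage: scaler_params would come from checkpoint['scaler_params']
scaler_params = {'mean': [10.0, 0.5], 'std': [2.0, 0.1]}
x_raw = np.array([[12.0, 0.6]])
x_norm = normalize_for_inference(x_raw, scaler_params)
```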

Checkpoint File Organization

๐Ÿ“text
1models/
2โ”œโ”€โ”€ nasa_cmapss_v7/
3โ”‚   โ”œโ”€โ”€ FD001_v7_seed42.pth          # Final model (deployment)
4โ”‚   โ”œโ”€โ”€ FD001_v7_seed42_history.json # Training history
5โ”‚   โ””โ”€โ”€ checkpoints/
6โ”‚       โ”œโ”€โ”€ epoch_50.pt               # Periodic checkpoint
7โ”‚       โ”œโ”€โ”€ epoch_100.pt              # Periodic checkpoint
8โ”‚       โ””โ”€โ”€ best_epoch_87.pt          # Best model checkpoint
9
10logs/
11โ”œโ”€โ”€ FD001_v7_20240115_143022.log     # Training log

Implementation

Complete save and load functions for the AMNL training pipeline.

Save Checkpoint Function

๐Ÿpython
1import torch
2import copy
3import os
4from pathlib import Path
5from typing import Dict, Optional
6
7def save_checkpoint(
8    model: nn.Module,
9    optimizer: torch.optim.Optimizer,
10    scheduler,
11    epoch: int,
12    metrics: Dict,
13    history: Dict,
14    filepath: str,
15    ema: Optional[ExponentialMovingAverage] = None,
16    scaler_params: Optional[Dict] = None,
17    model_config: Optional[Dict] = None,
18    is_best: bool = False
19):
20    """
21    Save a complete training checkpoint.
22
23    Args:
24        model: Trained model
25        optimizer: Optimizer with accumulated state
26        scheduler: Learning rate scheduler
27        epoch: Current epoch number
28        metrics: Evaluation metrics at this epoch
29        history: Complete training history
30        filepath: Path to save checkpoint
31        ema: Optional EMA tracker
32        scaler_params: Data normalization parameters
33        model_config: Model architecture configuration
34        is_best: Whether this is the best model so far
35    """
36    # Create directory if needed
37    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
38
39    # Build checkpoint dictionary
40    checkpoint = {
41        'epoch': epoch,
42        'model_state_dict': copy.deepcopy(model.state_dict()),
43        'optimizer_state_dict': copy.deepcopy(optimizer.state_dict()),
44        'scheduler_state_dict': copy.deepcopy(scheduler.state_dict()),
45        'metrics': convert_numpy_types(metrics),
46        'history': convert_numpy_types(dict(history)),
47        'is_best': is_best,
48    }
49
50    # Add optional components
51    if ema is not None:
52        checkpoint['ema_shadow'] = copy.deepcopy(ema.shadow)
53
54    if scaler_params is not None:
55        checkpoint['scaler_params'] = scaler_params
56
57    if model_config is not None:
58        checkpoint['model_config'] = model_config
59
60    # Save checkpoint
61    torch.save(checkpoint, filepath)
62
63    return filepath

Load Checkpoint Function

๐Ÿpython
1def load_checkpoint(
2    filepath: str,
3    model: nn.Module,
4    optimizer: Optional[torch.optim.Optimizer] = None,
5    scheduler = None,
6    ema: Optional[ExponentialMovingAverage] = None,
7    device: torch.device = torch.device('cpu')
8) -> Dict:
9    """
10    Load a training checkpoint.
11
12    Args:
13        filepath: Path to checkpoint file
14        model: Model to load weights into
15        optimizer: Optional optimizer to restore state
16        scheduler: Optional scheduler to restore state
17        ema: Optional EMA to restore shadow weights
18        device: Device to map tensors to
19
20    Returns:
21        Dictionary with checkpoint metadata
22    """
23    # Load checkpoint
24    checkpoint = torch.load(filepath, map_location=device)
25
26    # Load model weights
27    model.load_state_dict(checkpoint['model_state_dict'])
28
29    # Load optimizer state if provided
30    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
31        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
32
33    # Load scheduler state if provided
34    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
35        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
36
37    # Load EMA shadow if provided
38    if ema is not None and 'ema_shadow' in checkpoint:
39        ema.shadow = checkpoint['ema_shadow']
40
41    # Return metadata for training resume
42    return {
43        'epoch': checkpoint.get('epoch', 0),
44        'metrics': checkpoint.get('metrics', {}),
45        'history': checkpoint.get('history', {}),
46        'model_config': checkpoint.get('model_config', {}),
47        'scaler_params': checkpoint.get('scaler_params', None),
48        'is_best': checkpoint.get('is_best', False),
49    }

Resume Training

๐Ÿpython
1def resume_training(
2    checkpoint_path: str,
3    model: nn.Module,
4    optimizer: torch.optim.Optimizer,
5    scheduler,
6    device: torch.device
7) -> Tuple[int, Dict, Dict]:
8    """
9    Resume training from a checkpoint.
10
11    Returns:
12        Tuple of (start_epoch, history, metrics)
13    """
14    print(f"Resuming from checkpoint: {checkpoint_path}")
15
16    # Load checkpoint
17    metadata = load_checkpoint(
18        checkpoint_path,
19        model,
20        optimizer,
21        scheduler,
22        device=device
23    )
24
25    start_epoch = metadata['epoch'] + 1
26    history = defaultdict(list, metadata.get('history', {}))
27
28    print(f"Resuming from epoch {start_epoch}")
29    print(f"Previous best RMSE: {metadata['metrics'].get('RMSE_last_cycle', 'N/A')}")
30
31    return start_epoch, history, metadata['metrics']
32
33
34# Usage example
35if args.resume:
36    start_epoch, history, prev_metrics = resume_training(
37        args.resume_path,
38        model,
39        optimizer,
40        scheduler,
41        device
42    )
43else:
44    start_epoch = 0
45    history = defaultdict(list)
46
47# Continue training from start_epoch
48for epoch in range(start_epoch, total_epochs):
49    train_epoch(...)

Final Model Save

๐Ÿpython
1# Save final model for deployment
2os.makedirs(output_dir, exist_ok=True)
3model_path = f'{output_dir}/{dataset_name}_v7_seed{random_seed}.pth'
4
5torch.save({
6    'model_state_dict': model.state_dict(),
7    'optimizer_state_dict': optimizer.state_dict(),
8    'dataset': dataset_name,
9    'final_metrics': convert_numpy_types(final_results),
10    'training_history': convert_numpy_types(dict(history)),
11    'best_epoch': best_epoch,
12    'model_config': {
13        'version': 'V7',
14        'input_size': input_size,
15        'sequence_length': 30,
16        'hidden_size': 256,
17        'epochs_trained': epoch + 1,
18        'batch_size': batch_size,
19        'learning_rate': learning_rate,
20        'random_seed': random_seed,
21        'amnl_config': {
22            'loss': 'weighted MSE (linear)',
23            'task_weights': '0.5/0.5',
24            'scheduler': 'ReduceLROnPlateau',
25            'warmup': '10 epochs',
26            'early_stopping': '80 epochs'
27        }
28    },
29    'scaler_params': train_dataset.get_scaler_params()
30}, model_path)
31
32logger.info(f"Model saved to {model_path}")

Best Model Loading at Training End

Always restore the best model before final evaluation and saving. The model at the last epoch is rarely the best due to overfitting or noise.

๐Ÿpython
1# At end of training, restore best model
2if best_model_state:
3    logger.info(f"Loading best model from epoch {best_epoch+1}")
4    model.load_state_dict(best_model_state['model_state_dict'])
5
6    # Also restore EMA if used
7    if ema and 'ema_shadow' in best_model_state:
8        ema.shadow = best_model_state['ema_shadow']
9        ema.apply_shadow(model)  # Apply EMA for final evaluation

Summary

In this section, we covered checkpointing strategy:

  1. What to save: Model, optimizer, scheduler, EMA, scaler params
  2. When to save: Best model, periodic, on interrupt
  3. Checkpoint structure: Complete metadata for recovery
  4. Save/load functions: Reusable checkpoint utilities
  5. Resume training: Continue from any checkpoint
| Component | Why Essential |
| --- | --- |
| model_state_dict | Learned weights for inference |
| optimizer_state_dict | Momentum/variance for training |
| scheduler_state_dict | LR schedule continuation |
| ema_shadow | Smoothed weights for stable inference |
| scaler_params | Data normalization for deployment |
| training_history | Analysis and visualization |

Looking Ahead: We have all the pieces: training loop, validation, monitoring, and checkpointing. The final section brings everything together in the complete training script walkthrough, the production-ready code that achieves state-of-the-art on NASA C-MAPSS.

With checkpointing complete, we assemble the full training script.