Chapter 9

The Training Loop

Training Neural Networks

Learning Objectives

By the end of this section, you will be able to:

  1. Understand the training loop: The fundamental cycle of forward pass → loss computation → backward pass → weight update that drives all neural network learning
  2. Distinguish epochs, batches, and iterations: Master the terminology and understand why we process data in batches rather than all at once
  3. Implement training in PyTorch: Write clean, idiomatic PyTorch training code that handles all the essential components
  4. Monitor training progress: Use loss curves and metrics to understand whether training is progressing well or encountering problems
  5. Apply best practices: Implement proper data splitting, gradient handling, and model evaluation
Why This Matters: The training loop is the heartbeat of deep learning. Every neural network—from simple classifiers to GPT-4—learns through the same fundamental cycle. Mastering this loop gives you the foundation to train any model.

The Big Picture

What Does "Training" Mean?

Training a neural network means finding the values of its parameters (weights and biases) that minimize some measure of error on the training data. This is an optimization problem: we have millions of parameters and we want to find the combination that makes our model perform best.

But unlike traditional optimization problems where we might have a closed-form solution, neural networks have highly non-convex loss landscapes with millions of dimensions. We can't simply "solve" for the optimal parameters—instead, we must iteratively improve them through gradient descent.

The Fundamental Insight

The training loop is based on a simple but powerful idea: if we can compute how much each parameter contributes to the error (the gradient), we can adjust each parameter to reduce that error. Repeat this millions of times, and the network learns.

\theta_{t+1} = \theta_t - \eta \cdot \nabla_\theta \mathcal{L}(\theta_t)

Where:

  • θ represents all model parameters (weights and biases)
  • η is the learning rate (how big of a step to take)
  • ℒ is the loss function (measures how wrong we are)
  • ∇θℒ is the gradient (direction of steepest increase in loss)

Why the Minus Sign?

We subtract the gradient because the gradient points in the direction of steepest increase. To minimize loss, we go in the opposite direction.
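A minimal sketch of this update rule on a one-parameter problem, L(θ) = θ², whose gradient is 2θ (pure Python, no framework):

```python
# Minimize L(theta) = theta^2, whose gradient is 2 * theta.
theta = 5.0          # initial parameter value
eta = 0.1            # learning rate
for step in range(50):
    grad = 2 * theta            # dL/dtheta
    theta = theta - eta * grad  # the update rule: step against the gradient
print(theta)  # very close to the minimum at 0
```

Each step shrinks θ by a factor of (1 − 2η), so the parameter converges geometrically toward the minimum.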

The Four Steps of Training

Every training iteration consists of exactly four steps, repeated over and over until the model converges or we run out of patience:

Step 1: Forward Pass

Pass input data through the network to compute predictions. Each layer transforms its input:

\mathbf{x} \xrightarrow{\text{Layer 1}} \mathbf{h}_1 \xrightarrow{\text{Layer 2}} \mathbf{h}_2 \xrightarrow{\cdots} \hat{\mathbf{y}}

During the forward pass, PyTorch builds a computational graph that tracks every operation. This graph is essential for the backward pass.
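You can see the graph being recorded by inspecting grad_fn, the backward node PyTorch attaches to every tensor produced by a tracked operation:

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
h = x * 3.0   # tracked: a multiply node is recorded
y = h.sum()   # tracked: a sum node is recorded

# Each intermediate result remembers the operation that produced it
print(y.grad_fn)                 # a SumBackward0 node
print(y.grad_fn.next_functions)  # links back to the multiply node, and so on to x
```

Leaf tensors you create yourself have grad_fn = None; only results of operations carry a backward node.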

Step 2: Compute Loss

Compare the model's predictions ŷ to the true labels y using a loss function. Common choices include:

Loss Function | Formula | Use Case
MSE (L2) | ½(ŷ − y)² | Regression
Binary Cross-Entropy | −y·log(ŷ) − (1−y)·log(1−ŷ) | Classification
Negative Log-Likelihood | −log p(y|x) | Probabilistic models
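A quick look at two of these losses in PyTorch; note that nn.CrossEntropyLoss expects raw logits and integer class labels, not probabilities:

```python
import torch
import torch.nn as nn

# Regression: predictions and targets have the same shape
mse = nn.MSELoss()
pred = torch.tensor([2.5, 0.0])
target = torch.tensor([3.0, -0.5])
print(mse(pred, target))  # mean of squared differences -> tensor(0.2500)

# Classification: CrossEntropyLoss takes logits of shape (batch, classes)
# and integer labels of shape (batch,) -- no softmax needed beforehand
ce = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 3.0, 0.3]])
labels = torch.tensor([0, 1])
print(ce(logits, labels))  # small, since the correct class has the largest logit
```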

Step 3: Backward Pass (Backpropagation)

Compute gradients of the loss with respect to every parameter. PyTorch does this automatically using reverse-mode automatic differentiation:

🐍backward.py
# loss is a scalar tensor
loss.backward()  # Computes all gradients

# After this, every parameter has a .grad attribute
print(model.layer1.weight.grad.shape)  # Same shape as weight

The key insight: gradients flow backward through the computational graph, applying the chain rule at each step. This was covered in detail in Chapter 8.
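A tiny example makes the chain rule concrete: for z = (2x)², autograd composes dz/dy = 2y with dy/dx = 2 to get dz/dx = 8x:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = 2 * x     # dy/dx = 2
z = y ** 2    # dz/dy = 2y

z.backward()  # chain rule: dz/dx = (2y)(2) = 8x

print(x.grad)  # 8 * 3 = tensor(24.)
```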

Step 4: Update Weights

Adjust each parameter in the direction that reduces the loss:

🐍update.py
# Simple gradient descent update
for param in model.parameters():
    param.data -= learning_rate * param.grad

# Or more elegantly with an optimizer
optimizer.step()  # Applies the update rule

Don't Forget to Zero Gradients!

PyTorch accumulates gradients by default. Before each backward pass, you must call optimizer.zero_grad() or gradients will add up across iterations, causing incorrect updates.
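The accumulation is easy to see on a single parameter:

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
for step in range(3):
    loss = (2 * w).sum()  # d(loss)/dw = 2 every time
    loss.backward()
    print(w.grad.item())  # 2.0, 4.0, 6.0 -- gradients pile up!

w.grad.zero_()            # what optimizer.zero_grad() does for each parameter
print(w.grad.item())      # 0.0
```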


Quick Check

In what order do the four steps of training occur?


Epochs, Batches, and Iterations

Training data is typically too large to process all at once. We divide it into smaller pieces and introduce three key concepts:

Definitions

Term | Definition | Example
Batch | A subset of training samples processed together | 32 images at once
Iteration | One forward-backward-update cycle (one batch) | Processing one batch
Epoch | One complete pass through the entire training set | Seeing all 50,000 images once

Why Use Batches?

  1. Memory constraints: Processing 50,000 images at once would require hundreds of GB of GPU memory. Batches make training feasible.
  2. Better gradient estimates: The gradient computed over a batch is a noisy estimate of the true gradient. Some noise is actually helpful—it helps escape local minima (this is called "stochastic gradient descent").
  3. Faster convergence: Updating weights after every batch (rather than after the whole dataset) means more frequent updates and faster learning.
  4. Hardware efficiency: GPUs are optimized for parallel operations on batches of data. Processing 32 samples is barely slower than processing 1.

The Math

Given N training samples and batch size B:

\text{Iterations per epoch} = \left\lceil \frac{N}{B} \right\rceil
\text{Total iterations} = \text{Epochs} \times \left\lceil \frac{N}{B} \right\rceil

For example, with 50,000 training samples and batch size 32:

  • Iterations per epoch: ⌈50,000 / 32⌉ = 1,563
  • 10 epochs = 15,630 total iterations (weight updates)
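The same arithmetic in code, using math.ceil to account for the partial final batch:

```python
import math

def total_iterations(num_samples: int, batch_size: int, epochs: int) -> int:
    """Total weight updates: epochs * ceil(N / B)."""
    per_epoch = math.ceil(num_samples / batch_size)
    return epochs * per_epoch

print(total_iterations(50_000, 32, 1))   # 1563 iterations per epoch
print(total_iterations(50_000, 32, 10))  # 15630 total weight updates
```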


Quick Check

You have 10,000 training samples and use batch size 64. How many iterations are in 5 epochs?


Train, Validation, and Test Splits

Before training, we must carefully divide our data into three disjoint sets. Each serves a distinct purpose in the model development process:

A common split is 70% training, 15% validation, 15% test; with 1,000 samples that is 700 / 150 / 150.

Training Set

Used to update model weights via backpropagation. The model sees this data multiple times (once per epoch). Larger training sets generally lead to better models.

Validation Set

Used to tune hyperparameters and detect overfitting. Evaluated after each epoch but never used for gradient updates. Helps decide when to stop training.

Test Set

Used only once at the very end to evaluate the final model. Provides an unbiased estimate of real-world performance. Never peek at test data during development!
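One way to produce the three splits in code is torch.utils.data.random_split; the 70/15/15 ratio and the toy dataset below are just illustrative choices:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset: 1,000 samples with 8 features each and binary labels
dataset = TensorDataset(torch.randn(1000, 8), torch.randint(0, 2, (1000,)))

# 70% / 15% / 15% split; fix the generator for reproducibility
generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(
    dataset, [700, 150, 150], generator=generator
)

print(len(train_set), len(val_set), len(test_set))  # 700 150 150
```

Each split can then be wrapped in its own DataLoader; only the training loader needs shuffle=True.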


Why Three Sets?

The fundamental problem we're solving is: how will our model perform on data it has never seen? We can't use training data to answer this because the model has seen it.

Split | Purpose | When Used | Affects Model?
Training | Update weights via backprop | Every iteration | Yes (directly)
Validation | Tune hyperparameters, detect overfitting | After each epoch | Yes (indirectly)
Test | Final unbiased evaluation | Once, at the very end | No

Never Peek at Test Data

The test set must remain completely untouched until final evaluation. If you use test performance to make any decisions (choosing architectures, tuning hyperparameters, deciding when to stop), you've contaminated it and your reported performance is overly optimistic.

The Validation Set Dilemma

The validation set helps us detect overfitting: when training loss keeps decreasing but validation loss starts increasing, the model is memorizing training data instead of learning general patterns.

But there's a subtle issue: by choosing the model/hyperparameters that work best on validation data, we're indirectly fitting to the validation set. This is why we need a separate test set for final evaluation.

Cross-Validation

When data is scarce, use k-fold cross-validation: split training data into k parts, train on k-1 parts, validate on the remaining part, and rotate. This gives more reliable validation estimates without needing a larger dataset.
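The fold rotation can be sketched in plain Python; this index generator is an illustrative helper, not a library API:

```python
def k_fold_indices(n: int, k: int):
    """Yield (train_indices, val_indices) for each of k folds."""
    # Distribute any remainder across the first n % k folds
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = list(range(start, start + size))                      # hold out this fold
        train = list(range(0, start)) + list(range(start + size, n))  # train on the rest
        yield train, val
        start += size

for train_idx, val_idx in k_fold_indices(10, 5):
    print(len(train_idx), len(val_idx))  # 8 2 for each of the 5 folds
```

Every sample serves as validation data exactly once, so the k validation scores can be averaged for a more reliable estimate.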

PyTorch Implementation

Let's implement a complete, production-quality training loop in PyTorch. We'll build it step by step, explaining each component:

The Complete Training Loop

Training One Epoch
Training Mode

model.train() enables training-specific behaviors: dropout is active, batch norm uses batch statistics. Always call this before training!

Iterate Over Batches

DataLoader automatically batches data, shuffles it, and handles multi-worker loading. Each iteration yields (inputs, targets) tuples; for image data, inputs.shape is (batch_size, channels, H, W).

Move to Device

Tensors must be on the same device as the model. .to(device) moves data to the GPU if available, enabling parallel computation. A typical choice: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu').

Forward Pass

Pass inputs through the model to get predictions. PyTorch automatically builds a computational graph for backpropagation.

Compute Loss

The loss function measures how far predictions are from targets. Common choices: nn.CrossEntropyLoss for classification, nn.MSELoss for regression.

Zero Gradients

CRITICAL: PyTorch accumulates gradients by default. Without this call, gradients from previous batches would add up, causing incorrect updates.

Backward Pass

Compute gradients of the loss with respect to all parameters via backpropagation. After this, every param.grad is populated.

Update Weights

The optimizer applies its update rule (SGD, Adam, etc.) to all parameters using their gradients. This is where learning happens!

Track Loss

.item() extracts the Python number from a 0-dimensional tensor. We accumulate losses to report average loss per epoch.

🐍train.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> float:
    """Train model for one epoch, return average loss."""
    model.train()  # Set to training mode
    total_loss = 0.0
    num_batches = len(dataloader)

    for inputs, targets in dataloader:
        # Move data to device (CPU or GPU)
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        # Backward pass
        optimizer.zero_grad()  # Clear old gradients
        loss.backward()        # Compute new gradients

        # Update weights
        optimizer.step()

        # Accumulate loss for logging
        total_loss += loss.item()

    return total_loss / num_batches

The Validation Loop

Validation Loop
Disable Gradients

The @torch.no_grad() decorator tells PyTorch not to track operations for backprop. This saves memory and speeds up inference. (You can also use the with torch.no_grad(): context manager.)

Evaluation Mode

model.eval() disables dropout and uses running statistics for batch normalization instead of batch statistics. Critical for consistent evaluation!

Get Predictions

For classification, the predicted class is the one with the highest output (logit). Given outputs of shape (batch_size, num_classes), argmax(dim=1) finds the max along the class dimension.

Count Correct

Compare predictions to targets element-wise and sum the True values. .item() extracts the Python int for accumulation.

🐍validate.py
@torch.no_grad()  # Disable gradient computation
def validate(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: nn.Module,
    device: torch.device,
) -> tuple[float, float]:
    """Evaluate model on validation set, return (loss, accuracy)."""
    model.eval()  # Set to evaluation mode
    total_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass only (no backward!)
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        total_loss += loss.item()

        # Compute accuracy
        predictions = outputs.argmax(dim=1)
        correct += (predictions == targets).sum().item()
        total += targets.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy

Putting It All Together

Complete Training Pipeline
Device Selection

Automatically use the GPU if available. A single device line handles both CPU and GPU training without code changes.

Move Model to Device

The model must be on the same device as the data. .to(device) moves all parameters to the specified device.

Choose Optimizer

Adam is a good default: it adapts learning rates per-parameter and handles momentum. We'll explore optimizers in detail in the next section.

Track Best Model

Keep track of the best validation loss seen so far. We save the model when it improves, implementing a form of early stopping.

Save Best Model

state_dict() contains all learnable parameters. Saving the best model (not the final model) often gives better generalization.

🐍train_full.py
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int,
    learning_rate: float = 0.001,
) -> dict:
    """Complete training pipeline."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    history = {"train_loss": [], "val_loss": [], "val_accuracy": []}
    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        # Training phase
        train_loss = train_one_epoch(
            model, train_loader, loss_fn, optimizer, device
        )

        # Validation phase
        val_loss, val_accuracy = validate(
            model, val_loader, loss_fn, device
        )

        # Record history
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(val_accuracy)

        # Save best model (based on validation loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pt")

        # Print progress
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2%}")

    return history

Monitoring Training Progress

Watching loss curves is essential for understanding what's happening during training. Different curve patterns indicate different problems:

Healthy Training

  • Both losses decrease: Training is progressing well
  • Gap remains small: Model generalizes well
  • Curves plateau together: Model has converged

Warning Signs

Pattern | Diagnosis | Solution
Val loss increases while train loss decreases | Overfitting | More data, regularization, early stopping
Both losses plateau high | Underfitting | Bigger model, more training, check data
Loss oscillates wildly | Learning rate too high | Reduce learning rate
Loss doesn't decrease at all | Learning rate too low OR a bug | Increase LR, check gradients
Loss becomes NaN | Numerical instability | Gradient clipping, smaller LR, check inputs

Key Metrics to Track

  1. Training loss: Should decrease (eventually plateaus)
  2. Validation loss: Should decrease, then may increase (overfitting)
  3. Generalization gap: val_loss - train_loss (should stay small)
  4. Task-specific metrics: Accuracy, F1, BLEU, etc. (what we actually care about)
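Given a history dict like the one the training pipeline above records, the gap is a one-liner; the loss values below are hypothetical, for illustration only:

```python
# Hypothetical loss history for illustration
history = {
    "train_loss": [1.20, 0.80, 0.55, 0.40, 0.30],
    "val_loss":   [1.25, 0.90, 0.70, 0.65, 0.72],
}

# Generalization gap per epoch: val_loss - train_loss
gaps = [v - t for v, t in zip(history["val_loss"], history["train_loss"])]
print(gaps)  # a widening gap in later epochs hints at overfitting

# The epoch with the lowest validation loss is the natural checkpoint to keep
best_epoch = min(range(len(history["val_loss"])), key=history["val_loss"].__getitem__)
print(best_epoch)  # here: epoch index 3
```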


Quick Check

You observe that training loss is 0.05 but validation loss is 0.85. What is this called?


Common Training Patterns

Pattern 1: Train/Eval Mode Toggle

Always switch between model.train() and model.eval():

🐍mode_toggle.py
# Training phase
model.train()
for batch in train_loader:
    ...  # training code

# Evaluation phase
model.eval()
with torch.no_grad():
    for batch in val_loader:
        ...  # evaluation code

Pattern 2: Gradient Clipping

Prevent exploding gradients by clipping their norm:

🐍grad_clip.py
# Clip gradients to max norm of 1.0
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

Pattern 3: Learning Rate Scheduling

Adjust learning rate during training (covered in detail in Section 9.3):

🐍lr_schedule.py
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(num_epochs):
    train_one_epoch(...)
    scheduler.step()  # Reduce LR every 30 epochs
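A quick way to see what StepLR does is to trace the learning rate with a throwaway parameter and optimizer (step_size=2 here just to keep the demo short):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

for epoch in range(6):
    optimizer.step()   # training would normally happen here
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # lr drops 10x every 2 epochs
```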

Pattern 4: Checkpointing

Save complete training state to resume later:

🐍checkpoint.py
# Save checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pt')

# Load checkpoint
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

Logging Frameworks

For real projects, use logging frameworks like TensorBoard, Weights & Biases, or MLflow instead of print statements. They provide real-time visualization, hyperparameter tracking, and experiment comparison.

See Chapter 8 Section 6

For a deep dive into gradient clipping theory and when it's needed (vanishing/exploding gradients), see Chapter 8 Section 6: Gradient Flow Analysis.

GPU Training

Training on GPUs is essential for any non-trivial deep learning project. GPUs can perform thousands of matrix operations in parallel, making training 10-100x faster than CPUs.

Checking GPU Availability

🐍gpu_check.py
import torch

# Check if CUDA (NVIDIA GPU) is available
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"Current device: {torch.cuda.current_device()}")

# Check for Apple Silicon (MPS)
print(f"MPS available: {torch.backends.mps.is_available()}")

Device-Agnostic Code

The best practice is to write code that works on any device (CPU, CUDA, or MPS) without modification:

🐍device_agnostic.py
import torch

def get_device() -> torch.device:
    """Get the best available device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")

device = get_device()
print(f"Using device: {device}")

# Move model to device
model = MyModel()
model = model.to(device)

# Move data to device in training loop
for inputs, targets in dataloader:
    inputs = inputs.to(device)
    targets = targets.to(device)

    outputs = model(inputs)
    # ... rest of training loop

Data and Model Must Be on Same Device

A common error is forgetting to move data to the GPU. You'll get an error like Expected all tensors to be on the same device. Always ensure both model and data are on the same device.

Memory Management

GPU memory is limited. Here's how to monitor and manage it:

🐍memory.py
import torch

# Check memory usage (CUDA only)
if torch.cuda.is_available():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Free unused cached memory
torch.cuda.empty_cache()

# Monitor peak memory usage
torch.cuda.reset_peak_memory_stats()
# ... run training code ...
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

Common GPU Issues and Solutions

Issue | Symptom | Solution
Out of memory | CUDA out of memory error | Reduce batch size, use gradient checkpointing, use mixed precision
Device mismatch | Tensors on different devices | Ensure all tensors are on the same device with .to(device)
Slow data loading | GPU utilization low | Increase num_workers in DataLoader, use pin_memory=True
Memory leak | Memory grows each epoch | Detach tensors before storing: loss.item() instead of loss
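The "mixed precision" fix in the table can be sketched with PyTorch's AMP utilities. This is a minimal sketch, not the only way to set it up; GradScaler(enabled=False) makes every scaler call a pass-through, so the same code also runs on a CPU-only machine:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # no-op when disabled

inputs = torch.randn(8, 10, device=device)
targets = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device.type, enabled=use_amp):
    outputs = model(inputs)          # runs in reduced precision on GPU
    loss = loss_fn(outputs, targets)

scaler.scale(loss).backward()  # scale the loss to avoid gradient underflow
scaler.step(optimizer)         # unscales gradients, then steps the optimizer
scaler.update()
print(loss.item())
```

Reduced-precision activations roughly halve memory use, which is why AMP is a standard response to out-of-memory errors.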

Optimizing Data Loading

The DataLoader has options specifically for GPU training:

🐍dataloader_gpu.py
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,      # Parallel data loading
    pin_memory=True,    # Faster CPU→GPU transfer
    prefetch_factor=2,  # Prefetch 2 batches per worker
)

# pin_memory=True allocates data in page-locked memory,
# enabling faster and asynchronous transfer to GPU

See Chapter 7 Section 1

For comprehensive coverage of DataLoader parameters and optimization, see Chapter 7 Section 1: DataLoader Fundamentals.

Saving and Loading Models

Saving models is essential for resuming training, deploying to production, and sharing with others. PyTorch provides several options depending on your needs.

Saving Model Weights Only

For inference or transfer learning, save just the learned parameters:

🐍save_weights.py
import torch

# Save model weights
torch.save(model.state_dict(), "model_weights.pt")

# Load model weights
model = MyModel()  # Create model architecture first
model.load_state_dict(torch.load("model_weights.pt"))
model.eval()  # Set to evaluation mode

Recommended Approach

Saving state_dict() is the recommended approach because it's more portable. The file only contains the weights, not the code, so you can load it into a different (but compatible) model architecture.
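What a state_dict actually contains is easy to inspect; for a small Sequential model it is just an ordered mapping from parameter names to tensors:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# state_dict maps parameter names to tensors -- weights only, no code
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (8, 4)
# 0.bias (8,)
# 2.weight (2, 8)
# 2.bias (2,)
```

Note that the ReLU layer contributes nothing: it has no learnable parameters, so only the two Linear layers appear.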

Saving Complete Checkpoints

To resume training, you need to save not just the model, but also the optimizer state, epoch number, and any other training state:

🐍checkpoint.py
import torch

def save_checkpoint(model, optimizer, scheduler, epoch, loss, best_val_loss, path):
    """Save complete training state."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
        "loss": loss,
        "best_val_loss": best_val_loss,
    }, path)
    print(f"Checkpoint saved: {path}")

def load_checkpoint(path, model, optimizer, scheduler=None):
    """Load complete training state."""
    checkpoint = torch.load(path)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    if scheduler and checkpoint["scheduler_state_dict"]:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]

    print(f"Resumed from epoch {epoch}")
    return epoch, loss

Best Model Saving Strategy

A common pattern is to save the model whenever validation performance improves:

🐍best_model.py
import torch

best_val_loss = float("inf")

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, loss_fn, device)

    # Save if validation loss improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "val_loss": val_loss,
            "val_acc": val_acc,
        }, "best_model.pt")
        print(f"New best model saved (val_loss: {val_loss:.4f})")

    # Also save periodic checkpoints
    if (epoch + 1) % 10 == 0:
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, f"checkpoint_epoch_{epoch+1}.pt")

Loading Models for Inference

When deploying a model, use map_location to handle device differences:

🐍inference.py
import torch

# Load on CPU (works everywhere)
model = MyModel()
model.load_state_dict(
    torch.load("model_weights.pt", map_location="cpu")
)
model.eval()

# Or load to a specific device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(
    torch.load("model_weights.pt", map_location=device)
)
model.to(device)
model.eval()

# Inference
with torch.no_grad():
    predictions = model(inputs.to(device))

Security Note

torch.load() uses pickle internally, which can execute arbitrary code. Only load models from trusted sources! For untrusted models, use torch.load(..., weights_only=True) (PyTorch 2.0+).

Common Saving Patterns

Use Case | What to Save | File Extension
Inference only | model.state_dict() | .pt or .pth
Resume training | Full checkpoint (model + optimizer + epoch) | .pt or .ckpt
Transfer learning | model.state_dict() | .pt
Model comparison | model.state_dict() + metrics | .pt
Production deployment | TorchScript or ONNX export | .pt or .onnx

Summary

The training loop is the core algorithm that enables all neural network learning. Let's review what we've covered:

Concept | Key Point
Training Loop | Forward → Loss → Backward → Update, repeated for each batch
Epoch | One complete pass through the training dataset
Batch | Subset of data processed together (trades memory for parallelism)
Iteration | One training loop cycle, processing one batch
Training Set | Used to compute gradients and update weights
Validation Set | Used to tune hyperparameters and detect overfitting
Test Set | Used once at the end for unbiased evaluation
Loss Curves | Primary tool for diagnosing training progress
GPU Training | Use .to(device) for model and data, pin_memory for faster transfer
Checkpointing | Save state_dict() for inference, full checkpoint for resuming training

Essential PyTorch Functions

🐍summary.py
model.train()           # Enable training mode
model.eval()            # Enable evaluation mode
optimizer.zero_grad()   # Clear accumulated gradients
loss.backward()         # Compute gradients via backprop
optimizer.step()        # Update parameters
torch.no_grad()         # Disable gradient tracking
model.to(device)        # Move model to CPU/GPU

Exercises

Conceptual Questions

  1. Why do we call optimizer.zero_grad() before loss.backward()? What would happen if we forgot this call for several batches?
  2. Explain the difference between model.train() and model.eval(). Which layers behave differently in each mode?
  3. You have 45,000 training samples and use batch size 128. How many iterations occur in 3 epochs? How many weight updates?
  4. Why is it problematic to use test set performance to decide when to stop training?

Solution Hints

  1. Q1: Gradients accumulate. After N batches, each gradient would be N times too large, causing unstable training.
  2. Q2: Dropout is active in train mode, disabled in eval. BatchNorm uses batch statistics in train, running statistics in eval.
  3. Q3: Iterations per epoch = ⌈45,000/128⌉ = 352. Total = 3 × 352 = 1,056 iterations = 1,056 weight updates.
  4. Q4: This "leaks" test information into model selection, making reported performance overly optimistic.

Coding Exercises

  1. Implement early stopping: Modify the training loop to stop training if validation loss doesn't improve for N epochs (the "patience"). Return the best model, not the final model.
  2. Add a progress bar: Use tqdm to add a progress bar that shows batch progress within each epoch and displays current loss.
  3. Implement gradient accumulation: For large effective batch sizes on limited GPU memory, accumulate gradients over K mini-batches before calling optimizer.step().
  4. Build a training dashboard: Use matplotlib to create a live-updating plot showing training and validation loss after each epoch.

Coding Exercise Hints

  • Exercise 1: Track epochs_without_improvement. Reset when val_loss improves, increment otherwise. Stop when it exceeds patience.
  • Exercise 2: Wrap dataloader with tqdm(dataloader, desc=f"Epoch {epoch}"). Use pbar.set_postfix(loss=loss.item()).
  • Exercise 3: Only call zero_grad() every K batches. Only call step() every K batches. Divide loss by K for proper averaging.
  • Exercise 4: Use plt.ion() for interactive mode. Call plt.pause(0.01) after each update.

In the next section, we'll explore optimizers—the algorithms that determine exactly how weights are updated based on gradients. You'll learn why Adam often works better than vanilla SGD and when to use each optimizer.