Learning Objectives
By the end of this section, you will be able to:
- Understand the training loop: The fundamental cycle of forward pass → loss computation → backward pass → weight update that drives all neural network learning
- Distinguish epochs, batches, and iterations: Master the terminology and understand why we process data in batches rather than all at once
- Implement training in PyTorch: Write clean, idiomatic PyTorch training code that handles all the essential components
- Monitor training progress: Use loss curves and metrics to understand whether training is progressing well or encountering problems
- Apply best practices: Implement proper data splitting, gradient handling, and model evaluation
Why This Matters: The training loop is the heartbeat of deep learning. Every neural network—from simple classifiers to GPT-4—learns through the same fundamental cycle. Mastering this loop gives you the foundation to train any model.
The Big Picture
What Does "Training" Mean?
Training a neural network means finding the values of its parameters (weights and biases) that minimize some measure of error on the training data. This is an optimization problem: we have millions of parameters and we want to find the combination that makes our model perform best.
But unlike traditional optimization problems where we might have a closed-form solution, neural networks have highly non-convex loss landscapes with millions of dimensions. We can't simply "solve" for the optimal parameters—instead, we must iteratively improve them through gradient descent.
The Fundamental Insight
The training loop is based on a simple but powerful idea: if we can compute how much each parameter contributes to the error (the gradient), we can adjust each parameter to reduce that error. Repeat this millions of times, and the network learns.
In symbols, each update applies the gradient descent rule:

θ ← θ − η∇L(θ)

Where:
- θ represents all model parameters (weights and biases)
- η is the learning rate (how big of a step to take)
- L is the loss function (measures how wrong we are)
- ∇L(θ) is the gradient (direction of steepest increase in loss)
Why the Minus Sign? The gradient ∇L(θ) points in the direction of steepest *increase* in loss, so we subtract it—stepping opposite to the gradient is what decreases the loss.
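A minimal numeric sketch of this update rule, using a one-parameter toy loss L(θ) = θ² (whose gradient is 2θ) rather than a real network:

```python
def gradient_step(theta, lr):
    """One gradient descent update on L(theta) = theta^2."""
    grad = 2 * theta          # dL/dtheta for L = theta^2
    return theta - lr * grad  # step against the gradient

theta = 4.0
for _ in range(50):
    theta = gradient_step(theta, lr=0.1)
print(theta)  # shrinks toward the minimum at 0
```

Each step multiplies θ by (1 − 2·lr), so repeated updates drive the parameter toward the loss minimum—the same mechanism, scaled to millions of parameters, is what trains a neural network.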
The Four Steps of Training
Every training iteration consists of exactly four steps, repeated over and over until the model converges or we run out of patience:
Step 1: Forward Pass
Pass input data through the network to compute predictions. Each layer transforms its input, typically as hₗ = f(Wₗhₗ₋₁ + bₗ), where f is the layer's activation function.
During the forward pass, PyTorch builds a computational graph that tracks every operation. This graph is essential for the backward pass.
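This graph is visible on the tensors themselves: each intermediate result records the operation that produced it in its `grad_fn` attribute. A small sketch (the shapes here are arbitrary, for illustration):

```python
import torch

# A tiny forward pass through one linear layer + ReLU
x = torch.randn(4, 3)                      # batch of 4 samples, 3 features
W = torch.randn(3, 2, requires_grad=True)  # trainable weight
b = torch.zeros(2, requires_grad=True)     # trainable bias

h = x @ W + b      # linear transformation
y = torch.relu(h)  # nonlinearity

# Each tensor produced from trainable parameters remembers its op
print(y.shape)     # torch.Size([4, 2])
print(y.grad_fn)   # a ReluBackward node in the computational graph
```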
Step 2: Compute Loss
Compare the model's predictions to the true labels using a loss function. Common choices include:
| Loss Function | Formula | Use Case |
|---|---|---|
| MSE (L2) | ½(ŷ - y)² | Regression |
| Cross-Entropy (binary) | -y·log(ŷ) - (1-y)·log(1-ŷ) | Classification |
| Negative Log-Likelihood | -log(p(y|x)) | Probabilistic models |
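In PyTorch these losses are available in `torch.nn.functional` (or as `nn.Module` classes). A small sketch with made-up numbers:

```python
import torch
import torch.nn.functional as F

# Regression: mean squared error between predictions and targets
pred = torch.tensor([2.5, 0.0, 2.0])
target = torch.tensor([3.0, -0.5, 2.0])
mse = F.mse_loss(pred, target)  # mean of squared differences

# Classification: cross-entropy over raw logits (softmax is built in)
logits = torch.tensor([[2.0, 0.5, 0.1]])  # one sample, three classes
label = torch.tensor([0])                 # true class index
ce = F.cross_entropy(logits, label)

print(mse.item(), ce.item())
```

Note that `F.cross_entropy` expects unnormalized logits, not probabilities—it applies log-softmax internally for numerical stability.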
Step 3: Backward Pass (Backpropagation)
Compute gradients of the loss with respect to every parameter. PyTorch does this automatically using reverse-mode automatic differentiation:
```python
# loss is a scalar tensor
loss.backward()  # Computes all gradients

# After this, every parameter has a .grad attribute
print(model.layer1.weight.grad.shape)  # Same shape as weight
```

The key insight: gradients flow backward through the computational graph, applying the chain rule at each step. This was covered in detail in Chapter 8.
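A runnable sketch of `backward()` and gradient accumulation—calling `backward()` twice without zeroing adds the second gradient on top of the first, which is exactly why training loops must clear gradients between iterations:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

loss = (w * 2).sum()    # dloss/dw = 2
loss.backward()
first = w.grad.item()   # 2.0

loss = (w * 2).sum()
loss.backward()         # adds another 2.0 on top of the stale gradient
second = w.grad.item()  # 4.0 — accumulated, not replaced!

w.grad.zero_()          # what optimizer.zero_grad() does per parameter
print(first, second, w.grad.item())  # 2.0 4.0 0.0
```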
Step 4: Update Weights
Adjust each parameter in the direction that reduces the loss:
```python
# Simple gradient descent update
for param in model.parameters():
    param.data -= learning_rate * param.grad

# Or more elegantly with an optimizer
optimizer.step()  # Applies the update rule
```

Don't Forget to Zero Gradients! Call optimizer.zero_grad() before each backward pass, or gradients will add up across iterations, causing incorrect updates.

Interactive: The Training Loop
Watch the training loop in action. Click through each step or let it play automatically to see how data flows through the network and back:
[Interactive animation: The Training Loop — data flows forward through the network (x → h₁ → h₂ → … → ŷ) to produce predictions, then gradients flow backward, one step per iteration.]

Quick Check
In what order do the four steps of training occur?
Epochs, Batches, and Iterations
Training data is typically too large to process all at once. We divide it into smaller pieces and introduce three key concepts:
Definitions
| Term | Definition | Example |
|---|---|---|
| Batch | A subset of training samples processed together | 32 images at once |
| Iteration | One forward-backward-update cycle (one batch) | Processing one batch |
| Epoch | One complete pass through the entire training set | Seeing all 50,000 images once |
Why Use Batches?
- Memory constraints: Processing 50,000 images at once would require hundreds of GB of GPU memory. Batches make training feasible.
- Better gradient estimates: The gradient computed over a batch is a noisy estimate of the true gradient. Some noise is actually helpful—it can help the optimizer escape poor local minima. Gradient descent on random mini-batches is called "stochastic gradient descent" (SGD).
- Faster convergence: Updating weights after every batch (rather than after the whole dataset) means more frequent updates and faster learning.
- Hardware efficiency: GPUs are optimized for parallel operations on batches of data. Processing 32 samples is barely slower than processing 1.
The Math
Given N training samples and batch size B, the number of iterations per epoch is ⌈N / B⌉ (the last batch may be smaller than B).
For example, with 50,000 training samples and batch size 32:
- Iterations per epoch: ⌈50,000 / 32⌉ = 1,563
- 10 epochs = 15,630 total iterations (weight updates)
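The same arithmetic in code, using ceiling division:

```python
import math

def iterations_per_epoch(num_samples: int, batch_size: int) -> int:
    # Ceiling division: a final, smaller batch covers the leftover samples
    return math.ceil(num_samples / batch_size)

n_iter = iterations_per_epoch(50_000, 32)
print(n_iter, 10 * n_iter)  # 1563 iterations/epoch, 15630 over 10 epochs
```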
[Interactive: Batches & Epochs Explorer — visualizes how the dataset is divided into batches and processed across iterations and epochs.]
Quick Check
You have 10,000 training samples and use batch size 64. How many iterations are in 5 epochs?
Train, Validation, and Test Splits
Before training, we must carefully divide our data into three disjoint sets. Each serves a distinct purpose in the model development process:
Train / Validation / Test Split
Understanding data partitioning for training
Training Set
Used to update model weights via backpropagation. The model sees this data multiple times (once per epoch). Larger training sets generally lead to better models.
Validation Set
Used to tune hyperparameters and detect overfitting. Evaluated after each epoch but never used for gradient updates. Helps decide when to stop training.
Test Set
Used only once at the very end to evaluate the final model. Provides an unbiased estimate of real-world performance. Never peek at test data during development!
Common Split Ratios
Typical splits are 80/10/10 or 70/15/15 for moderate-sized datasets. For very large datasets (millions of samples), the validation and test sets can be a much smaller fraction, since a few tens of thousands of samples already give stable estimates.
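A sketch of an 80/10/10 split using `torch.utils.data.random_split` (the toy dataset and seed below are placeholders; the fixed generator seed makes the split reproducible):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset of 1,000 samples with 10 features each
data = TensorDataset(torch.randn(1000, 10), torch.randint(0, 2, (1000,)))

# 80/10/10 split; a seeded generator makes it reproducible across runs
train_set, val_set, test_set = random_split(
    data, [800, 100, 100],
    generator=torch.Generator().manual_seed(42),
)
print(len(train_set), len(val_set), len(test_set))  # 800 100 100
```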
Why Three Sets?
The fundamental problem we're solving is: how will our model perform on data it has never seen? We can't use training data to answer this because the model has seen it.
| Split | Purpose | When Used | Affects Model? |
|---|---|---|---|
| Training | Update weights via backprop | Every iteration | Yes (directly) |
| Validation | Tune hyperparameters, detect overfitting | After each epoch | Yes (indirectly) |
| Test | Final unbiased evaluation | Once, at the very end | No |
Never Peek at Test Data
Every time you evaluate on the test set and act on the result, the test score stops being an unbiased estimate of real-world performance. Touch it exactly once, at the very end.
The Validation Set Dilemma
The validation set helps us detect overfitting: when training loss keeps decreasing but validation loss starts increasing, the model is memorizing training data instead of learning general patterns.
But there's a subtle issue: by choosing the model/hyperparameters that work best on validation data, we're indirectly fitting to the validation set. This is why we need a separate test set for final evaluation.
Cross-Validation
When data is too scarce for a fixed validation split, k-fold cross-validation rotates the validation role: the data is divided into k folds, the model is trained k times with a different fold held out for validation each time, and the results are averaged. It is standard in classical machine learning but often too expensive for large deep learning models.
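A minimal index-level sketch of the k-fold idea (for simplicity this assumes the sample count divides evenly by k):

```python
def k_fold_indices(n_samples: int, k: int):
    """Yield (train_idx, val_idx) pairs for k-fold cross-validation."""
    fold_size = n_samples // k
    indices = list(range(n_samples))
    for fold in range(k):
        start, stop = fold * fold_size, (fold + 1) * fold_size
        val_idx = indices[start:stop]                 # held-out fold
        train_idx = indices[:start] + indices[stop:]  # everything else
        yield train_idx, val_idx

for train_idx, val_idx in k_fold_indices(10, 5):
    print(len(train_idx), len(val_idx))  # 8 2, for each of the 5 folds
```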
PyTorch Implementation
Let's implement a complete, production-quality training loop in PyTorch. We'll build it step by step, explaining each component:
The Complete Training Loop
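A minimal sketch of a reusable per-epoch training function following the four steps above. The tiny `nn.Linear` model and random data at the bottom are placeholders for illustration only:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, criterion, optimizer, device="cpu"):
    """One full pass over the training data; returns the mean batch loss."""
    model.train()                           # enable training-mode behavior
    total_loss = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()               # 0. clear stale gradients
        outputs = model(inputs)             # 1. forward pass
        loss = criterion(outputs, targets)  # 2. compute loss
        loss.backward()                     # 3. backward pass
        optimizer.step()                    # 4. update weights
        total_loss += loss.item()           # .item() detaches from the graph
    return total_loss / len(loader)

# Toy demo: 64 random samples, 2-class linear classifier
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=16, shuffle=True)
model = nn.Linear(10, 2)
avg_loss = train_one_epoch(model, loader, nn.CrossEntropyLoss(),
                           torch.optim.SGD(model.parameters(), lr=0.1))
print(f"mean batch loss: {avg_loss:.4f}")
```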
The Validation Loop
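A matching validation sketch: same forward pass, but in `eval()` mode, inside `torch.no_grad()`, and with no weight updates. The toy data and model are again placeholders:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def validate(model, loader, criterion, device="cpu"):
    """Evaluate without updating weights; returns (mean loss, accuracy)."""
    model.eval()                  # disable dropout, freeze batchnorm stats
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():         # no graph needed: saves memory and time
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item()
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    return total_loss / len(loader), correct / total

# Toy demo
dataset = TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,)))
loader = DataLoader(dataset, batch_size=16)
model = nn.Linear(10, 2)
val_loss, val_acc = validate(model, loader, nn.CrossEntropyLoss())
print(f"val_loss={val_loss:.4f}, val_acc={val_acc:.2%}")
```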
Putting It All Together
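Combining both phases into a complete, self-contained loop that records per-epoch losses. Everything here (model size, epoch count, synthetic data) is illustrative:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy setup: a small classifier on synthetic data
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_loader = DataLoader(
    TensorDataset(torch.randn(128, 10), torch.randint(0, 2, (128,))),
    batch_size=16, shuffle=True)
val_loader = DataLoader(
    TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,))),
    batch_size=16)

history = {"train_loss": [], "val_loss": []}
for epoch in range(3):
    # --- training phase ---
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # --- validation phase ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            val_loss += criterion(model(x), y).item()
    history["train_loss"].append(train_loss / len(train_loader))
    history["val_loss"].append(val_loss / len(val_loader))
    print(f"epoch {epoch}: train={history['train_loss'][-1]:.4f} "
          f"val={history['val_loss'][-1]:.4f}")
```

Keeping a `history` dict like this is what makes the loss-curve diagnostics in the next section possible.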
Monitoring Training Progress
Watching loss curves is essential for understanding what's happening during training. Different curve patterns indicate different problems:
Healthy Training
- Both losses decrease: Training is progressing well
- Gap remains small: Model generalizes well
- Curves plateau together: Model has converged
Warning Signs
| Pattern | Diagnosis | Solution |
|---|---|---|
| Val loss increases while train loss decreases | Overfitting | More data, regularization, early stopping |
| Both losses plateau high | Underfitting | Bigger model, more training, check data |
| Loss oscillates wildly | Learning rate too high | Reduce learning rate |
| Loss doesn't decrease at all | Learning rate too low OR bug | Increase LR, check gradients |
| Loss becomes NaN | Numerical instability | Gradient clipping, smaller LR, check inputs |
Key Metrics to Track
- Training loss: Should decrease (eventually plateaus)
- Validation loss: Should decrease, then may increase (overfitting)
- Generalization gap: val_loss - train_loss (should stay small)
- Task-specific metrics: Accuracy, F1, BLEU, etc. (what we actually care about)
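The generalization gap is trivial to compute from recorded losses. The numbers below are hypothetical, chosen to show a gap widening as overfitting sets in:

```python
# Hypothetical per-epoch losses (for illustration only)
train_loss = [0.90, 0.55, 0.40, 0.33, 0.30]
val_loss   = [0.95, 0.60, 0.48, 0.47, 0.52]

# Generalization gap per epoch: val_loss - train_loss
gaps = [v - t for t, v in zip(train_loss, val_loss)]
print([round(g, 2) for g in gaps])  # a widening gap hints at overfitting
```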
Interactive: Loss Curves
Experiment with the training simulation below. Adjust the learning rate and observe how it affects both training and validation loss curves:
[Interactive: Loss Curve Visualization — click "Train" to run a training simulation and watch training and validation loss evolve over epochs at the chosen learning rate.]
Quick Check
You observe that training loss is 0.05 but validation loss is 0.85. What is this called?
Common Training Patterns
Pattern 1: Train/Eval Mode Toggle
Always switch between model.train() and model.eval():
```python
# Training phase
model.train()
for batch in train_loader:
    ...  # training code

# Evaluation phase
model.eval()
with torch.no_grad():
    for batch in val_loader:
        ...  # evaluation code
```

Pattern 2: Gradient Clipping
Prevent exploding gradients by clipping their norm:
```python
# Clip gradients to max norm of 1.0
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

Pattern 3: Learning Rate Scheduling
Adjust learning rate during training (covered in detail in Section 9.3):
```python
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(num_epochs):
    train_one_epoch(...)
    scheduler.step()  # Multiplies the LR by 0.1 every 30 epochs
```

Pattern 4: Checkpointing
Save complete training state to resume later:
```python
# Save checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pt')

# Load checkpoint
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
```

Logging Frameworks
See Chapter 8 Section 6
GPU Training
Training on GPUs is essential for any non-trivial deep learning project. GPUs can perform thousands of matrix operations in parallel, making training 10-100x faster than CPUs.
Checking GPU Availability
```python
import torch

# Check if CUDA (NVIDIA GPU) is available
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"Current device: {torch.cuda.current_device()}")

# Check for Apple Silicon (MPS)
print(f"MPS available: {torch.backends.mps.is_available()}")
```

Device-Agnostic Code
The best practice is to write code that works on any device (CPU, CUDA, or MPS) without modification:
```python
import torch

def get_device() -> torch.device:
    """Get the best available device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")

device = get_device()
print(f"Using device: {device}")

# Move model to device
model = MyModel()
model = model.to(device)

# Move data to device in training loop
for inputs, targets in dataloader:
    inputs = inputs.to(device)
    targets = targets.to(device)

    outputs = model(inputs)
    # ... rest of training loop
```

Data and Model Must Be on Same Device
If the model is on the GPU but a batch is still on the CPU (or vice versa), PyTorch raises "Expected all tensors to be on the same device". Always ensure both model and data are on the same device.

Memory Management
GPU memory is limited. Here's how to monitor and manage it:
```python
import torch

if torch.cuda.is_available():
    # Check memory usage (CUDA only)
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

    # Free unused cached memory
    torch.cuda.empty_cache()

    # Monitor peak memory usage
    torch.cuda.reset_peak_memory_stats()
    # ... run training code ...
    print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
```

Common GPU Issues and Solutions
| Issue | Symptom | Solution |
|---|---|---|
| Out of memory | CUDA out of memory error | Reduce batch size, use gradient checkpointing, use mixed precision |
| Device mismatch | Tensors on different devices | Ensure all tensors are on same device with .to(device) |
| Slow data loading | GPU utilization low | Increase num_workers in DataLoader, use pin_memory=True |
| Memory leak | Memory grows each epoch | Detach tensors before storing: loss.item() instead of loss |
Optimizing Data Loading
The DataLoader has options specifically for GPU training:
```python
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,      # Parallel data loading
    pin_memory=True,    # Faster CPU→GPU transfer
    prefetch_factor=2,  # Prefetch 2 batches per worker
)

# pin_memory=True allocates data in page-locked memory,
# enabling faster and asynchronous transfer to GPU
```

See Chapter 7 Section 1
Saving and Loading Models
Saving models is essential for resuming training, deploying to production, and sharing with others. PyTorch provides several options depending on your needs.
Saving Model Weights Only
For inference or transfer learning, save just the learned parameters:
```python
import torch

# Save model weights
torch.save(model.state_dict(), "model_weights.pt")

# Load model weights
model = MyModel()  # Create model architecture first
model.load_state_dict(torch.load("model_weights.pt"))
model.eval()  # Set to evaluation mode
```

Recommended Approach
Saving the state_dict() is the recommended approach because it's more portable. The file only contains the weights, not the code, so you can load it into a different (but compatible) model architecture.

Saving Complete Checkpoints
To resume training, you need to save not just the model, but also the optimizer state, epoch number, and any other training state:
```python
import torch

def save_checkpoint(model, optimizer, scheduler, epoch, loss, path):
    """Save complete training state."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
        "loss": loss,
    }, path)
    print(f"Checkpoint saved: {path}")

def load_checkpoint(path, model, optimizer, scheduler=None):
    """Load complete training state."""
    checkpoint = torch.load(path)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    if scheduler and checkpoint["scheduler_state_dict"]:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]

    print(f"Resumed from epoch {epoch}")
    return epoch, loss
```

Best Model Saving Strategy
A common pattern is to save the model whenever validation performance improves:
```python
import torch

best_val_loss = float("inf")

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_acc = validate(model, val_loader)

    # Save if validation loss improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "val_loss": val_loss,
            "val_acc": val_acc,
        }, "best_model.pt")
        print(f"New best model saved (val_loss: {val_loss:.4f})")

    # Also save periodic checkpoints
    if (epoch + 1) % 10 == 0:
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, f"checkpoint_epoch_{epoch+1}.pt")
```

Loading Models for Inference
When deploying a model, use map_location to handle device differences:
```python
import torch

# Load on CPU (works everywhere)
model = MyModel()
model.load_state_dict(
    torch.load("model_weights.pt", map_location="cpu")
)
model.eval()

# Or load to a specific device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(
    torch.load("model_weights.pt", map_location=device)
)
model.to(device)
model.eval()

# Inference
with torch.no_grad():
    predictions = model(inputs.to(device))
```

Security Note
torch.load() uses pickle internally, which can execute arbitrary code. Only load models from trusted sources! For untrusted models, use torch.load(..., weights_only=True) (PyTorch 2.0+).

Common Saving Patterns
| Use Case | What to Save | File Extension |
|---|---|---|
| Inference only | model.state_dict() | .pt or .pth |
| Resume training | Full checkpoint (model + optimizer + epoch) | .pt or .ckpt |
| Transfer learning | model.state_dict() | .pt |
| Model comparison | model.state_dict() + metrics | .pt |
| Production deployment | TorchScript or ONNX export | .pt or .onnx |
Summary
The training loop is the core algorithm that enables all neural network learning. Let's review what we've covered:
| Concept | Key Point |
|---|---|
| Training Loop | Forward → Loss → Backward → Update, repeated for each batch |
| Epoch | One complete pass through the training dataset |
| Batch | Subset of data processed together (trades memory for parallelism) |
| Iteration | One training loop cycle, processing one batch |
| Training Set | Used to compute gradients and update weights |
| Validation Set | Used to tune hyperparameters and detect overfitting |
| Test Set | Used once at the end for unbiased evaluation |
| Loss Curves | Primary tool for diagnosing training progress |
| GPU Training | Use .to(device) for model and data, pin_memory for faster transfer |
| Checkpointing | Save state_dict() for inference, full checkpoint for resuming training |
Essential PyTorch Functions
```python
model.train()          # Enable training mode
model.eval()           # Enable evaluation mode
optimizer.zero_grad()  # Clear accumulated gradients
loss.backward()        # Compute gradients via backprop
optimizer.step()       # Update parameters
torch.no_grad()        # Disable gradient tracking (use as context manager)
model.to(device)       # Move model to CPU/GPU
```

Exercises
Conceptual Questions
- Why do we call optimizer.zero_grad() before loss.backward()? What would happen if we forgot this call for several batches?
- Explain the difference between model.train() and model.eval(). Which layers behave differently in each mode?
- You have 45,000 training samples and use batch size 128. How many iterations occur in 3 epochs? How many weight updates?
- Why is it problematic to use test set performance to decide when to stop training?
Solution Hints
- Q1: Gradients accumulate. After N batches, each gradient would be N times too large, causing unstable training.
- Q2: Dropout is active in train mode, disabled in eval. BatchNorm uses batch statistics in train, running statistics in eval.
- Q3: Iterations per epoch = ⌈45,000/128⌉ = 352. Total = 3 × 352 = 1,056 iterations = 1,056 weight updates.
- Q4: This "leaks" test information into model selection, making reported performance overly optimistic.
Coding Exercises
- Implement early stopping: Modify the training loop to stop training if validation loss doesn't improve for N epochs (the "patience"). Return the best model, not the final model.
- Add a progress bar: Use tqdm to add a progress bar that shows batch progress within each epoch and displays current loss.
- Implement gradient accumulation: For large effective batch sizes on limited GPU memory, accumulate gradients over K mini-batches before calling optimizer.step().
- Build a training dashboard: Use matplotlib to create a live-updating plot showing training and validation loss after each epoch.
Coding Exercise Hints
- Exercise 1: Track epochs_without_improvement. Reset when val_loss improves, increment otherwise. Stop when it exceeds patience.
- Exercise 2: Wrap the dataloader with tqdm(dataloader, desc=f"Epoch {epoch}"). Use pbar.set_postfix(loss=loss.item()).
- Exercise 3: Only call zero_grad() and step() every K batches. Divide the loss by K for proper averaging.
- Exercise 4: Use plt.ion() for interactive mode. Call plt.pause(0.01) after each update.
In the next section, we'll explore optimizers—the algorithms that determine exactly how weights are updated based on gradients. You'll learn why Adam often works better than vanilla SGD and when to use each optimizer.