Learning Objectives
By the end of this section, you will:
- Understand mixed precision training and its benefits
- Compare FP16 and FP32 number representations
- Explain gradient scaling to prevent underflow
- Implement mixed precision with PyTorch AMP
- Know when to use and when to avoid mixed precision
Why This Matters: Mixed precision training uses 16-bit floating point (FP16) for most operations, which can speed up training by 2-3× and reduce memory usage by 50%. Modern GPUs have dedicated Tensor Cores that accelerate FP16 computation, making this a "free" performance boost when done correctly.
What is Mixed Precision?
Mixed precision training uses both 16-bit and 32-bit floating point numbers strategically.
The Core Idea
Different parts of training have different precision requirements:
| Operation | Precision | Reason |
|---|---|---|
| Forward pass | FP16 | Speed, memory savings |
| Backward pass | FP16 | Speed, memory savings |
| Loss computation | FP32 | Numerical accuracy |
| Weight updates | FP32 | Prevent drift |
| Master weights | FP32 | Accumulation accuracy |
Why "Mixed"?
We use FP16 where possible for speed, but maintain FP32 master copies of weights. This gives us the best of both worlds:
```
Mixed Precision Training Flow:

┌─────────────────────────────────────────────────────┐
│                 FP32 Master Weights                 │
│                      (stored)                       │
└──────────────────────────┬──────────────────────────┘
                           │ Cast to FP16
                           ▼
┌─────────────────────────────────────────────────────┐
│                FP16 Working Weights                 │
│              (for forward/backward)                 │
└──────────────────────────┬──────────────────────────┘
                           │
           ┌───────────────┴───────────────┐
           ▼                               ▼
      ┌─────────┐                     ┌─────────┐
      │ Forward │                     │Backward │
      │  Pass   ├──► FP32 Loss ─────► │  Pass   │
      │ (FP16)  │                     │ (FP16)  │
      └─────────┘                     └────┬────┘
                                           │ FP16 gradients
                                           │ (scaled)
                                           ▼
                                 ┌──────────────────┐
                                 │ Unscale to FP32  │
                                 │ Update master    │
                                 │ weights          │
                                 └──────────────────┘
```
Hardware Requirement
Mixed precision provides the most benefit on GPUs with Tensor Cores (NVIDIA Volta, Turing, Ampere, Ada). On older GPUs or CPUs, the benefits are minimal or nonexistent.
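If you want to check your GPU programmatically, here is a small sketch. The threshold of compute capability 7.0 (Volta) as the first Tensor Core generation is an assumption worth verifying against NVIDIA's documentation for your exact card:

```python
import torch

def tensor_cores_available() -> bool:
    """Heuristic check: Tensor Cores first appeared on Volta (compute capability 7.0)."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 7
```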
FP16 vs FP32 Arithmetic
Understanding the difference between precision formats is crucial.
Number Representation
| Property | FP32 | FP16 |
|---|---|---|
| Total bits | 32 | 16 |
| Sign bit | 1 | 1 |
| Exponent bits | 8 | 5 |
| Mantissa bits | 23 | 10 |
| Max value | ~3.4 × 10³⁸ | ~65,504 |
| Min positive | ~1.2 × 10⁻³⁸ | ~6 × 10⁻⁸ |
| Precision | ~7 decimal digits | ~3 decimal digits |
Precision Limitations
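These limits are easy to see numerically. A quick sketch using NumPy's float16, which implements the same IEEE 754 half-precision format:

```python
import numpy as np

# ~3 decimal digits of precision: above 2048, consecutive integers are no
# longer representable, so 2049 rounds (ties-to-even) down to 2048.
print(np.float16(2049.0))   # 2048.0

# Values above ~65,504 overflow to infinity.
print(np.float16(70000.0))  # inf

# Values below ~6e-8 underflow to zero.
print(np.float16(1e-8))     # 0.0
```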
Why FP16 is Faster
| Aspect | FP32 | FP16 |
|---|---|---|
| Memory bandwidth | 1× | 2× (half the bytes) |
| Tensor Core ops | 1× | 8-16× faster |
| Model memory | 100% | 50% |
| Batch size limit | 1× | 2× (larger batches possible) |
Gradient Scaling
Gradient scaling prevents FP16 gradients from underflowing to zero.
The Underflow Problem
Small gradients (e.g., 10⁻⁸) underflow to zero in FP16. Once zero, they contribute nothing to weight updates—the model stops learning.
Solution: Scale Up Before Backward
Multiply the loss by a large scale factor before backpropagation. This scales all gradients proportionally, moving them into FP16's representable range. Before the weight update, the gradients are divided by the same factor (in FP32) to recover their true values.
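A numeric sketch of this round trip, using NumPy's float16 to stand in for FP16 gradients: a true gradient of 1e-8 vanishes in FP16, but after scaling by 2¹⁶ it lands comfortably in range, and dividing by the same factor in FP32 recovers it:

```python
import numpy as np

scale = 2.0 ** 16   # 65536, PyTorch's default starting scale
true_grad = 1e-8    # too small for FP16

# Without scaling: the gradient underflows to zero and learning stalls.
print(np.float16(true_grad))            # 0.0

# With scaling: the backward pass sees grad * scale, which FP16 can hold.
scaled = np.float16(true_grad * scale)  # roughly 6.554e-4

# Unscale in FP32 before the weight update to recover the true gradient.
recovered = np.float32(scaled) / scale
print(recovered)                        # approximately 1e-8
```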
Dynamic Loss Scaling
PyTorch's GradScaler automatically adjusts the scale factor:
| Situation | Action | Effect |
|---|---|---|
| Gradients underflow | Increase scale | Prevent future underflow |
| Gradients overflow (inf/NaN) | Decrease scale, skip step | Recover from overflow |
| Gradients normal | Gradually increase scale | Push towards optimal |
Default Starting Scale
PyTorch starts with scale = 2¹⁶ = 65536 and adjusts dynamically. You rarely need to tune this—the automatic scaling handles most cases.
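The policy in the table can be sketched in plain Python. The constants below mirror GradScaler's documented defaults (init_scale=2¹⁶, backoff_factor=0.5, growth_factor=2.0, growth_interval=2000), but this toy class is an illustration of the algorithm, not the real implementation:

```python
class LossScaleSketch:
    """Toy model of dynamic loss scaling (not the real GradScaler)."""

    def __init__(self, init_scale=2.0**16, growth=2.0, backoff=0.5, interval=2000):
        self.scale = init_scale
        self.growth, self.backoff, self.interval = growth, backoff, interval
        self.good_steps = 0

    def update(self, grads_finite: bool) -> bool:
        """Return True if the optimizer step should run this iteration."""
        if not grads_finite:
            # inf/NaN gradients: shrink the scale and skip the step.
            self.scale *= self.backoff
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.interval:
            # A long streak of healthy steps: try a larger scale.
            self.scale *= self.growth
            self.good_steps = 0
        return True
```

For example, one overflow halves the scale to 32768, and 2000 consecutive healthy steps double it back to 65536.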
Implementation
PyTorch Automatic Mixed Precision (AMP) implementation.
Basic Mixed Precision Training
```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader

def train_mixed_precision(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    epochs: int,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Training with automatic mixed precision.

    Uses FP16 for forward/backward, FP32 for weight updates.
    """
    # Initialize gradient scaler
    scaler = GradScaler()

    history = {'train_loss': []}

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        for batch in train_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()

            # Forward pass with autocast (FP16)
            with autocast():
                pred = model(x)
                loss = criterion(pred, y)

            # Backward pass with scaled gradients
            scaler.scale(loss).backward()

            # Unscale gradients and update weights
            scaler.step(optimizer)

            # Update scale factor for next iteration
            scaler.update()

            epoch_loss += loss.item()

        history['train_loss'].append(epoch_loss / len(train_loader))

    return history
```
With Gradient Clipping
When using gradient clipping with mixed precision, unscale before clipping:
```python
def train_mixed_precision_with_clipping(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
):
    """
    Mixed precision training with gradient clipping.

    The key is to unscale gradients BEFORE clipping.
    """
    scaler = GradScaler()

    for batch in train_loader:
        x, y = batch
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()

        # Forward with autocast
        with autocast():
            pred = model(x)
            loss = criterion(pred, y)

        # Backward with scaled gradients
        scaler.scale(loss).backward()

        # Unscale gradients BEFORE clipping
        scaler.unscale_(optimizer)

        # Now clip in FP32 space
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=max_grad_norm
        )

        # Step (skips if gradients contain inf/NaN)
        scaler.step(optimizer)

        # Update scale
        scaler.update()
```
Complete Training Loop with All Enhancements
```python
def train_epoch_complete(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: GradScaler,
    ema: ExponentialMovingAverage,
    max_grad_norm: float = 1.0,
    accumulation_steps: int = 2,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Complete training epoch with all enhancements:
    - Mixed precision (autocast + GradScaler)
    - Gradient clipping
    - Gradient accumulation
    - EMA weight tracking
    """
    model.train()
    total_loss = 0.0
    num_steps = 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Zero gradients at start of accumulation cycle
        if batch_idx % accumulation_steps == 0:
            optimizer.zero_grad()

        # Forward pass with mixed precision
        with autocast():
            pred = model(x)
            loss = criterion(pred, y)
            # Scale loss for accumulation
            loss = loss / accumulation_steps

        # Backward with scaled gradients
        scaler.scale(loss).backward()

        # Update weights after accumulation cycle
        if (batch_idx + 1) % accumulation_steps == 0:
            # Unscale gradients before clipping
            scaler.unscale_(optimizer)

            # Clip gradients in FP32 space
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=max_grad_norm
            )

            # Optimizer step (skipped if inf/NaN gradients)
            scaler.step(optimizer)
            scaler.update()

            # Update EMA weights
            ema.update(model)

            num_steps += 1

        total_loss += loss.item() * accumulation_steps

    return {
        'loss': total_loss / len(train_loader),
        'steps': num_steps
    }
```
MPS (Apple Silicon) Limitation
Mixed precision with GradScaler is designed for NVIDIA CUDA GPUs. On Apple Silicon (MPS), use autocast only without GradScaler, as MPS handles precision differently.
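A device-agnostic sketch of that advice: use autocast alone and skip GradScaler off CUDA. The dtype choices here are assumptions to verify against your PyTorch version: CPU autocast defaults to bfloat16, and float16 is used for MPS based on its native half-precision support:

```python
import torch

def autocast_context(device_type: str):
    """Pick an autocast context; no GradScaler on non-CUDA backends."""
    if device_type == "cuda":
        return torch.autocast("cuda", dtype=torch.float16)
    if device_type == "mps":
        # Assumption: FP16 autocast on MPS, without GradScaler.
        return torch.autocast("mps", dtype=torch.float16)
    return torch.autocast("cpu", dtype=torch.bfloat16)

# CPU demo: matmuls run in bfloat16 under autocast.
with autocast_context("cpu"):
    out = torch.mm(torch.randn(2, 3), torch.randn(3, 2))
```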
Summary
In this section, we covered mixed precision training:
- Mixed precision: Use FP16 for speed, FP32 for accuracy
- FP16 limitations: range ~6 × 10⁻⁸ to ~65,504, about 3 decimal digits of precision
- Gradient scaling: Multiply loss before backward, divide before update
- autocast: Automatic FP16 for forward/backward
- GradScaler: Dynamic scaling to prevent underflow/overflow
| Component | Purpose |
|---|---|
| autocast() | Run forward/backward in FP16 |
| GradScaler | Dynamic loss scaling |
| scaler.scale(loss) | Apply scale before backward |
| scaler.unscale_() | Unscale before clipping |
| scaler.step() | Update weights (skip if inf/NaN) |
| scaler.update() | Adjust scale for next iteration |
Looking Ahead: Mixed precision lets us use larger batch sizes. But sometimes we need even larger effective batches than GPU memory allows. The next section covers gradient accumulationβa technique for simulating large batches without extra memory.
With mixed precision understood, we explore gradient accumulation.