Chapter 13
Section 65 of 104

Mixed Precision Training

Training Enhancements

Learning Objectives

By the end of this section, you will:

  1. Understand mixed precision training and its benefits
  2. Compare FP16 and FP32 number representations
  3. Explain gradient scaling to prevent underflow
  4. Implement mixed precision with PyTorch AMP
  5. Know when to use and when to avoid mixed precision
Why This Matters: Mixed precision training uses 16-bit floating point (FP16) for most operations, which can speed up training by 2-3× and reduce memory usage by 50%. Modern GPUs have dedicated Tensor Cores that accelerate FP16 computation, making this a "free" performance boost when done correctly.

What is Mixed Precision?

Mixed precision training uses both 16-bit and 32-bit floating point numbers strategically.

The Core Idea

Different parts of training have different precision requirements:

| Operation | Precision | Reason |
| --- | --- | --- |
| Forward pass | FP16 | Speed, memory savings |
| Backward pass | FP16 | Speed, memory savings |
| Loss computation | FP32 | Numerical accuracy |
| Weight updates | FP32 | Prevent drift |
| Master weights | FP32 | Accumulation accuracy |

Why "Mixed"?

We use FP16 where possible for speed, but maintain FP32 master copies of weights. This gives us the best of both worlds:

πŸ“text
1Mixed Precision Training Flow:
2
3β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4β”‚                    FP32 Master Weights                   β”‚
5β”‚                         (stored)                         β”‚
6β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
7                         β”‚ Cast to FP16
8                         β–Ό
9β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
10β”‚              FP16 Working Weights                        β”‚
11β”‚                (for forward/backward)                    β”‚
12β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
13                         β”‚
14         β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
15         β–Ό                               β–Ό
16    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”                    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”
17    β”‚ Forward β”‚                    β”‚Backward β”‚
18    β”‚  Pass   │───→ FP32 Loss ───→│  Pass   β”‚
19    β”‚ (FP16)  β”‚                    β”‚ (FP16)  β”‚
20    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                    β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”˜
21                                        β”‚ FP16 gradients
22                                        β”‚ (scaled)
23                                        β–Ό
24                              β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
25                              β”‚  Unscale to FP32 β”‚
26                              β”‚  Update master   β”‚
27                              β”‚    weights       β”‚
28                              β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Hardware Requirement

Mixed precision provides the most benefit on GPUs with Tensor Cores (NVIDIA Volta, Turing, Ampere, Ada). On older GPUs or CPUs, the benefits are minimal or nonexistent.
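A quick runtime check can estimate whether Tensor Cores are present. This is a heuristic sketch, not an official API: it relies on the fact that Tensor Cores first shipped with Volta (compute capability 7.0), and `has_tensor_cores` is our own helper name:

```python
import torch

def has_tensor_cores() -> bool:
    """Heuristic: Tensor Cores first appeared with Volta (compute capability 7.0)."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 7
```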


FP16 vs FP32 Arithmetic

Understanding the difference between precision formats is crucial.

Number Representation

| Property | FP32 | FP16 |
| --- | --- | --- |
| Total bits | 32 | 16 |
| Sign bits | 1 | 1 |
| Exponent bits | 8 | 5 |
| Mantissa bits | 23 | 10 |
| Max value | ~3.4 × 10³⁸ | ~65,504 |
| Min positive | ~1.2 × 10⁻³⁸ | ~6 × 10⁻⁸ |
| Precision | ~7 decimal digits | ~3 decimal digits |

Precision Limitations
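
Python's `struct` module supports the IEEE 754 half-precision format (`'e'`), which makes FP16's limits easy to demonstrate without a GPU. A small illustrative sketch (`to_fp16` is our helper, not a PyTorch function):

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision ('e' format)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

print(to_fp16(1e-8))    # 0.0    -> underflows: below FP16's smallest subnormal (~6e-8)
print(to_fp16(2049.0))  # 2048.0 -> integers above 2048 are no longer exactly representable
print(to_fp16(1.0001))  # 1.0    -> only ~3 decimal digits of precision
```

Each of these values survives FP32 unchanged; in FP16 the information is simply gone.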

Why FP16 is Faster

| Aspect | FP32 | FP16 |
| --- | --- | --- |
| Memory bandwidth | 1× | 2× (half the bytes) |
| Tensor Core ops | 1× | 8-16× faster |
| Model memory | 100% | 50% |
| Batch size limit | 1× | 2× (larger batches possible) |
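
The memory row is simple arithmetic: FP32 stores each value in 4 bytes, FP16 in 2. A quick sketch (the 125M parameter count is an illustrative assumption, not a specific model):

```python
params = 125_000_000          # illustrative parameter count

fp32_bytes = params * 4       # 4 bytes per FP32 value
fp16_bytes = params * 2       # 2 bytes per FP16 value

print(f"FP32: {fp32_bytes / 2**20:.0f} MiB")   # 477 MiB
print(f"FP16: {fp16_bytes / 2**20:.0f} MiB")   # 238 MiB
```

In practice the savings come mostly from activations and the FP16 working copies; the FP32 master weights described above add back some fixed overhead.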

Gradient Scaling

Gradient scaling prevents FP16 gradients from underflowing to zero.

The Underflow Problem

Small gradients (e.g., 10⁻⁸) underflow to zero in FP16. Once zero, they contribute nothing to weight updatesβ€”the model stops learning.

Solution: Scale Up Before Backward

Multiply the loss by a large scale factor before backpropagation. This scales all gradients proportionally, moving them into FP16's representable range:

$$\text{scaled\_loss} = \text{loss} \times S$$

Then unscale before applying updates:

$$\text{true\_gradient} = \frac{\text{scaled\_gradient}}{S}$$
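
The rescue is easy to see by simulating FP16 storage with Python's `struct` module (a self-contained sketch; `to_fp16` is our helper, and S = 65536 is one common scale factor):

```python
import struct

def to_fp16(x: float) -> float:
    """Simulate storing a value in IEEE 754 half precision."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

S = 65536                    # scale factor, 2**16
grad = 1e-8                  # a small gradient

print(to_fp16(grad))         # 0.0 -- underflows without scaling
scaled = to_fp16(grad * S)   # ~6.55e-4, comfortably inside FP16's range
recovered = scaled / S       # unscale in FP32, recovering ~1e-8
```

Without scaling the gradient is lost; with scaling it survives the FP16 round trip and is recovered (to FP16's relative precision) after unscaling.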

Dynamic Loss Scaling

PyTorch's GradScaler automatically adjusts the scale factor:

| Situation | Action | Effect |
| --- | --- | --- |
| Gradients underflow | Increase scale | Prevent future underflow |
| Gradients overflow (inf/NaN) | Decrease scale, skip step | Recover from overflow |
| Gradients normal | Gradually increase scale | Push toward optimal |
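
The rule behind this table fits in a few lines. The defaults below (growth factor 2.0, backoff factor 0.5, growth interval 2000 steps) match GradScaler's documented defaults, but the function itself is our simplified reimplementation, not the PyTorch source:

```python
def update_scale(scale: float, found_inf: bool, growth_tracker: int,
                 growth_factor: float = 2.0, backoff_factor: float = 0.5,
                 growth_interval: int = 2000) -> tuple:
    """One step of GradScaler-style dynamic loss scaling (simplified sketch)."""
    if found_inf:
        # Overflow: halve the scale and restart the streak of clean steps
        return scale * backoff_factor, 0
    growth_tracker += 1
    if growth_tracker == growth_interval:
        # A long streak of finite gradients: try a larger scale
        return scale * growth_factor, 0
    return scale, growth_tracker

print(update_scale(65536.0, found_inf=True, growth_tracker=500))    # (32768.0, 0)
print(update_scale(65536.0, found_inf=False, growth_tracker=1999))  # (131072.0, 0)
```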

Default Starting Scale

PyTorch starts with scale = 2¹⁶ = 65,536 and adjusts dynamically. You rarely need to tune this—the automatic scaling handles most cases.


Implementation

PyTorch provides automatic mixed precision (AMP) through the torch.cuda.amp module.

Basic Mixed Precision Training

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler

def train_mixed_precision(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    epochs: int,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Training with automatic mixed precision.

    Uses FP16 for forward/backward, FP32 for weight updates.
    """
    # Initialize gradient scaler
    scaler = GradScaler()

    history = {'train_loss': []}

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        for batch in train_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()

            # Forward pass with autocast (FP16)
            with autocast():
                pred = model(x)
                loss = criterion(pred, y)

            # Backward pass with scaled gradients
            scaler.scale(loss).backward()

            # Unscale gradients and update weights
            scaler.step(optimizer)

            # Update scale factor for next iteration
            scaler.update()

            epoch_loss += loss.item()

        history['train_loss'].append(epoch_loss / len(train_loader))

    return history
```

With Gradient Clipping

When using gradient clipping with mixed precision, unscale before clipping:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler

def train_mixed_precision_with_clipping(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
):
    """
    Mixed precision training with gradient clipping.

    The key is to unscale gradients BEFORE clipping.
    """
    scaler = GradScaler()

    for batch in train_loader:
        x, y = batch
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()

        # Forward with autocast
        with autocast():
            pred = model(x)
            loss = criterion(pred, y)

        # Backward with scaled gradients
        scaler.scale(loss).backward()

        # Unscale gradients BEFORE clipping
        scaler.unscale_(optimizer)

        # Now clip in FP32 space
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=max_grad_norm
        )

        # Step (skips if gradients contain inf/NaN)
        scaler.step(optimizer)

        # Update scale
        scaler.update()
```

Complete Training Loop with All Enhancements

```python
def train_epoch_complete(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: GradScaler,
    ema: ExponentialMovingAverage,
    max_grad_norm: float = 1.0,
    accumulation_steps: int = 2,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Complete training epoch with all enhancements:
    - Mixed precision (autocast + GradScaler)
    - Gradient clipping
    - Gradient accumulation
    - EMA weight tracking
    """
    model.train()
    total_loss = 0.0
    num_steps = 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Zero gradients at start of accumulation cycle
        if batch_idx % accumulation_steps == 0:
            optimizer.zero_grad()

        # Forward pass with mixed precision
        with autocast():
            pred = model(x)
            loss = criterion(pred, y)
            # Scale loss for accumulation
            loss = loss / accumulation_steps

        # Backward with scaled gradients
        scaler.scale(loss).backward()

        # Update weights after accumulation cycle
        if (batch_idx + 1) % accumulation_steps == 0:
            # Unscale gradients before clipping
            scaler.unscale_(optimizer)

            # Clip gradients in FP32 space
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=max_grad_norm
            )

            # Optimizer step (skipped if inf/NaN gradients)
            scaler.step(optimizer)
            scaler.update()

            # Update EMA weights
            ema.update(model)

            num_steps += 1

        total_loss += loss.item() * accumulation_steps

    return {
        'loss': total_loss / len(train_loader),
        'steps': num_steps
    }
```

MPS (Apple Silicon) Limitation

Mixed precision with GradScaler is designed for NVIDIA CUDA GPUs. On Apple Silicon (MPS), use autocast only without GradScaler, as MPS handles precision differently.
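
One portable pattern is to construct the scaler conditionally: GradScaler(enabled=False) turns scale(), step(), and update() into pass-throughs, so a single training loop can run on CUDA, MPS, or CPU. A sketch of that pattern:

```python
import torch
from torch.cuda.amp import GradScaler

device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# Scaling is only active on CUDA; elsewhere the scaler becomes a no-op,
# and scaler.scale(loss) simply returns the loss unchanged.
scaler = GradScaler(enabled=(device.type == 'cuda'))
```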


Summary

In this section, we covered mixed precision training:

  1. Mixed precision: Use FP16 for speed, FP32 for accuracy
  2. FP16 limitations: Range ~6 × 10⁻⁸ to 65,504, ~3 decimal digits
  3. Gradient scaling: Multiply loss before backward, divide before update
  4. autocast: Automatic FP16 for forward/backward
  5. GradScaler: Dynamic scaling to prevent underflow/overflow

| Component | Purpose |
| --- | --- |
| autocast() | Run forward/backward in FP16 |
| GradScaler | Dynamic loss scaling |
| scaler.scale(loss) | Apply scale before backward |
| scaler.unscale_() | Unscale before clipping |
| scaler.step() | Update weights (skip if inf/NaN) |
| scaler.update() | Adjust scale for next iteration |

Looking Ahead: Mixed precision lets us use larger batch sizes. But sometimes we need even larger effective batches than GPU memory allows. The next section covers gradient accumulationβ€”a technique for simulating large batches without extra memory.

With mixed precision understood, we explore gradient accumulation.