Chapter 12

Gradient Clipping

Optimization Strategy

Learning Objectives

By the end of this section, you will:

  1. Understand exploding gradients in deep networks
  2. Compare clipping methods (value vs. norm clipping)
  3. Choose appropriate clipping thresholds
  4. Monitor gradient statistics during training
  5. Implement gradient clipping in PyTorch
Why This Matters: Deep recurrent networks like our BiLSTM are susceptible to exploding gradients, where gradient magnitudes grow exponentially through layers. Gradient clipping prevents this by bounding gradient norms, ensuring stable training even with aggressive learning rates.

The Exploding Gradient Problem

Gradients can explode in deep networks due to repeated multiplication.

Why Gradients Explode

Consider backpropagation through L layers:

$$\frac{\partial \mathcal{L}}{\partial \theta_1} = \frac{\partial \mathcal{L}}{\partial h_L} \cdot \prod_{l=1}^{L-1} \frac{\partial h_{l+1}}{\partial h_l} \cdot \frac{\partial h_1}{\partial \theta_1}$$

If each Jacobian $\partial h_{l+1}/\partial h_l$ has eigenvalues of magnitude $\lambda > 1$, the product grows exponentially:

$$\left\|\prod_{l=1}^{L-1} \frac{\partial h_{l+1}}{\partial h_l}\right\| \approx \lambda^{L-1}$$

For λ = 1.1 and L = 50: 1.1^49 ≈ 107. Gradients arriving at the first layer are roughly 100× larger than expected.
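The blow-up is easy to verify numerically. The helper below (an illustrative sketch, not part of the chapter's codebase) just evaluates the $\lambda^{L-1}$ growth factor:

```python
def gradient_scale(lambda_: float, num_layers: int) -> float:
    """Growth of the Jacobian product: lambda_ raised to (num_layers - 1)."""
    return lambda_ ** (num_layers - 1)

print(f"{gradient_scale(1.1, 50):.1f}")   # ~107x amplification: exploding gradients
print(f"{gradient_scale(0.9, 50):.4f}")   # ~0.006x: the mirror-image vanishing case
```

The same formula with λ slightly below 1 shows why the vanishing-gradient problem is the flip side of the same multiplication.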

Symptoms of Exploding Gradients

| Symptom | Indicator | Severity |
|---|---|---|
| Loss spikes | Sudden large loss increases | Moderate |
| NaN loss | Loss becomes undefined | Critical |
| Weights explode | Very large weight values | Severe |
| Oscillating loss | Loss jumps up and down | Moderate |
| Training fails | No convergence | Critical |

When Clipping is Most Important

| Architecture | Clipping Need | Reason |
|---|---|---|
| Deep LSTMs (>2 layers) | High | Long gradient paths |
| Transformers | High | Attention can amplify gradients |
| Deep CNNs | Moderate | Residual connections help |
| Shallow networks | Low | Short gradient paths |

AMNL Uses Clipping

Our 2-layer BiLSTM with attention benefits from gradient clipping. We use max_norm = 1.0 to ensure stable training across all datasets.


Gradient Clipping Methods

There are two main approaches to gradient clipping.

Method 1: Norm Clipping

Scale all gradients together if their combined norm exceeds a threshold:

$$g \leftarrow \begin{cases} g & \text{if } \|g\| \leq \text{max\_norm} \\ g \cdot \dfrac{\text{max\_norm}}{\|g\|} & \text{if } \|g\| > \text{max\_norm} \end{cases}$$

This preserves the direction of the gradient while limiting its magnitude.
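A quick sanity check of this property (a standalone NumPy sketch with a made-up gradient vector, separate from the training code below):

```python
import numpy as np

def clip_by_norm(g: np.ndarray, max_norm: float) -> np.ndarray:
    """Rescale g so its L2 norm is at most max_norm, preserving direction."""
    norm = np.linalg.norm(g)
    if norm > max_norm:
        return g * (max_norm / norm)
    return g

g = np.array([3.0, 4.0])           # norm 5
clipped = clip_by_norm(g, 1.0)
print(clipped)                     # [0.6 0.8] -- norm 1, same direction
```

The clipped vector is an exact scalar multiple of the original, so the update direction is untouched.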

Method 2: Value Clipping

Clip each gradient element independently:

$$g_i \leftarrow \text{clip}(g_i, -\text{clip\_value}, +\text{clip\_value})$$

Problems: value clipping changes the gradient direction and treats every parameter identically regardless of scale. It is not recommended for most applications.
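The direction change is easy to see concretely (again a NumPy sketch with a hypothetical two-element gradient):

```python
import numpy as np

def clip_by_value(g: np.ndarray, clip_value: float) -> np.ndarray:
    """Clip each element of g independently to [-clip_value, +clip_value]."""
    return np.clip(g, -clip_value, clip_value)

g = np.array([0.5, 4.0])            # mostly points along the second axis
clipped = clip_by_value(g, 1.0)
print(clipped)                      # [0.5 1. ]
# Only the large component is truncated, so the clipped vector points
# in a noticeably different direction than the original gradient.
```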

Comparison

| Aspect | Norm Clipping | Value Clipping |
|---|---|---|
| Direction | Preserved | Changed |
| Relative magnitudes | Preserved | Distorted |
| Threshold meaning | Global norm bound | Per-element bound |
| Recommended | Yes | Rarely |

Threshold Selection

Choosing the right clipping threshold is important.

Guidelines for max_norm

| max_norm | Effect | Use Case |
|---|---|---|
| 0.1 | Very aggressive | Extremely unstable training |
| 0.5 | Aggressive | Unstable training |
| 1.0 | Standard | Most cases (recommended) |
| 5.0 | Permissive | Stable training |
| 10.0 | Very permissive | Already stable |

Empirical Selection

Monitor gradient norms during training to choose the threshold:

```python
import numpy as np
import torch
import torch.nn as nn

def compute_gradient_norm(model: nn.Module) -> float:
    """Compute the total L2 gradient norm across all parameters."""
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5

# In the training loop
for epoch in range(num_epochs):
    grad_norms = []
    for batch in dataloader:
        x, y = batch
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()

        # Monitor before clipping
        grad_norms.append(compute_gradient_norm(model))

        # Clip and step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    # Analyze
    print(f"Epoch {epoch}: "
          f"mean grad norm: {np.mean(grad_norms):.2f}, "
          f"max: {np.max(grad_norms):.2f}, "
          f"clipped: {sum(g > 1.0 for g in grad_norms) / len(grad_norms):.1%}")
```

Target Clip Ratio

Set max_norm so that 10-30% of batches are clipped. If almost every batch is clipped, the threshold is too low. If clipping never happens, it is too high (or not needed).
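One simple way to land in that 10-30% band (a heuristic sketch; the percentile approach and the helper name are assumptions, not from this chapter) is to set max_norm at a high percentile of the gradient norms logged during a warmup run:

```python
import numpy as np

def suggest_max_norm(grad_norms: list[float], clip_fraction: float = 0.2) -> float:
    """Pick max_norm so roughly clip_fraction of batches would be clipped.

    clip_fraction=0.2 places the threshold at the 80th percentile of
    the observed gradient norms.
    """
    return float(np.percentile(grad_norms, 100 * (1 - clip_fraction)))

# Hypothetical per-batch norms logged with compute_gradient_norm
norms = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.2, 1.5, 3.0]
print(suggest_max_norm(norms, clip_fraction=0.2))
```

Re-check the achieved clip ratio after a few epochs, since gradient norms typically shrink as training stabilizes.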


Implementation

PyTorch provides built-in gradient clipping utilities.

Basic Gradient Clipping

```python
import torch.nn.utils as utils

# Training loop with gradient clipping
for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass
    outputs = model(batch.x)
    loss = criterion(outputs, batch.y)

    # Backward pass
    loss.backward()

    # Gradient clipping (BEFORE optimizer.step())
    utils.clip_grad_norm_(
        model.parameters(),
        max_norm=1.0,
        norm_type=2  # L2 norm (default)
    )

    # Update parameters
    optimizer.step()
```

With Gradient Norm Logging

```python
def train_step(
    model: nn.Module,
    batch: tuple,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    max_grad_norm: float = 1.0
) -> dict:
    """
    Single training step with gradient clipping and monitoring.

    Returns:
        Dictionary with loss and gradient statistics
    """
    optimizer.zero_grad()

    # Forward
    x, target = batch
    pred = model(x)
    loss = criterion(pred, target)

    # Backward
    loss.backward()

    # Clip; clip_grad_norm_ returns the total norm *before* clipping
    grad_norm_before = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=max_grad_norm
    ).item()

    # After clipping, the norm is capped at max_grad_norm
    grad_norm_after = min(grad_norm_before, max_grad_norm)

    # Update
    optimizer.step()

    return {
        'loss': loss.item(),
        'grad_norm_before': grad_norm_before,
        'grad_norm_after': grad_norm_after,
        'was_clipped': grad_norm_before > max_grad_norm
    }
```

Complete Training Loop with All Optimizations

```python
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    criterion: nn.Module,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Complete training epoch with all optimization techniques.
    """
    model.train()

    total_loss = 0.0
    total_grad_norm = 0.0
    num_clipped = 0
    num_steps = 0

    for batch in dataloader:
        x, y = batch
        x, y = x.to(device), y.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        pred = model(x)
        loss = criterion(pred, y)

        # Backward pass
        loss.backward()

        # Gradient clipping (returns the norm before clipping)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=max_grad_norm
        ).item()

        # Track statistics
        total_loss += loss.item()
        total_grad_norm += grad_norm
        if grad_norm > max_grad_norm:
            num_clipped += 1
        num_steps += 1

        # Optimizer step
        optimizer.step()

        # Scheduler step (if per-batch)
        if scheduler is not None:
            scheduler.step()

    return {
        'loss': total_loss / num_steps,
        'avg_grad_norm': total_grad_norm / num_steps,
        'clip_ratio': num_clipped / num_steps,
        'lr': optimizer.param_groups[0]['lr']
    }
```

Summary

In this section, we covered gradient clipping:

  1. Exploding gradients: Caused by repeated multiplication in deep networks
  2. Norm clipping: Scale gradients to bound their L2 norm
  3. Threshold: max_norm = 1.0 is a good default
  4. Monitoring: Track gradient norms to diagnose issues
  5. Timing: Clip after backward(), before optimizer.step()
| Parameter | Value |
|---|---|
| Clipping method | Gradient norm (L2) |
| max_norm | 1.0 |
| When to apply | After loss.backward() |
| Target clip ratio | 10-30% of batches |
Chapter Complete: You now have a complete optimization toolkit: AdamW optimizer, learning rate warmup, cosine annealing, adaptive weight decay, and gradient clipping. The next chapter covers training enhancements—techniques like EMA model averaging, early stopping, and mixed precision training.

With the optimizer fully configured, we explore advanced training techniques.