Learning Objectives
By the end of this section, you will:
- Understand exploding gradients in deep networks
- Compare clipping methods (value vs. norm clipping)
- Choose appropriate clipping thresholds
- Monitor gradient statistics during training
- Implement gradient clipping in PyTorch
Why This Matters: Deep recurrent networks like our BiLSTM are susceptible to exploding gradients, where gradient magnitudes grow exponentially through layers. Gradient clipping prevents this by bounding gradient norms, ensuring stable training even with aggressive learning rates.
The Exploding Gradient Problem
Gradients can explode in deep networks due to repeated multiplication.
Why Gradients Explode
Consider backpropagation through L layers. By the chain rule, the gradient reaching the first hidden state is a product of per-layer Jacobians:

∂loss/∂h_1 = ∂loss/∂h_L · J_L · J_{L-1} · … · J_2,  where J_l = ∂h_l/∂h_{l-1}

If each Jacobian has eigenvalues λ > 1, the product grows exponentially: its norm scales like λ^(L-1).

For λ = 1.1 and L = 50: 1.1^49 ≈ 107. Gradients arrive roughly 100× larger than expected.
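To see the compounding numerically, here is a toy calculation (a standalone sketch, not from the chapter's codebase) of the growth factor λ^(L−1):

```python
# Toy illustration: growth factor of the gradient norm after backprop
# through num_layers layers whose Jacobians each scale gradients by lam.
def gradient_growth(lam: float, num_layers: int) -> float:
    """Return lam ** (num_layers - 1), the compounded scaling factor."""
    return lam ** (num_layers - 1)

print(gradient_growth(1.1, 50))  # ~107: exploding
print(gradient_growth(0.9, 50))  # ~0.006: the mirror-image vanishing case
```

The same multiplication that explodes gradients for λ > 1 shrinks them toward zero for λ < 1, which is the vanishing-gradient side of the same coin.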
Symptoms of Exploding Gradients
| Symptom | Indicator | Severity |
|---|---|---|
| Loss spikes | Sudden large loss increases | Moderate |
| NaN loss | Loss becomes undefined | Critical |
| Weights explode | Very large weight values | Severe |
| Oscillating loss | Loss jumps up and down | Moderate |
| Training fails | No convergence | Critical |
When Clipping is Most Important
| Architecture | Clipping Need | Reason |
|---|---|---|
| Deep LSTMs (>2 layers) | High | Long gradient paths |
| Transformers | High | Attention can amplify |
| CNNs (deep) | Moderate | Residuals help |
| Shallow networks | Low | Short paths |
AMNL Uses Clipping
Our 2-layer BiLSTM with attention benefits from gradient clipping. We use max_norm = 1.0 to ensure stable training across all datasets.
Gradient Clipping Methods
There are two main approaches to gradient clipping.
Method 1: Gradient Norm Clipping (Recommended)
Scale all gradients together if their combined L2 norm exceeds a threshold:

if ‖g‖_2 > max_norm:  g ← g · (max_norm / ‖g‖_2)

This preserves the direction of the gradient while limiting its magnitude: after clipping, the global norm is exactly max_norm.
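To make the operation concrete, here is a minimal pure-Python sketch (the helper name `clip_by_global_norm` is ours; PyTorch's built-in equivalent is `torch.nn.utils.clip_grad_norm_`):

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient vectors so their combined L2 norm
    is at most max_norm; the direction is unchanged."""
    total_norm = math.sqrt(sum(g_i * g_i for g in grads for g_i in g))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [[g_i * scale for g_i in g] for g in grads]
    return grads, total_norm

clipped, norm = clip_by_global_norm([[3.0, 4.0]], max_norm=1.0)
print(norm)     # 5.0 — the pre-clipping global norm
print(clipped)  # ≈ [[0.6, 0.8]] — same direction, scaled to norm 1.0
```

Note that the pre-clipping norm is returned alongside the clipped gradients; PyTorch's `clip_grad_norm_` does the same, which is what makes gradient-norm monitoring cheap.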
Method 2: Value Clipping
Clip each gradient element independently to a range [-c, c]:

g_i ← max(-c, min(c, g_i))

Problems: this changes the gradient direction (large components are truncated while small ones pass through unchanged) and treats all parameters the same. Not recommended for most applications.
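A quick pure-Python sketch of why element-wise clipping distorts direction (the helper is ours for illustration; PyTorch exposes this operation as `torch.nn.utils.clip_grad_value_`):

```python
def clip_by_value(grad, clip_value):
    """Clamp each gradient element independently to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grad]

print(clip_by_value([3.0, 0.5], 1.0))  # [1.0, 0.5]
# The original gradient pointed along [3.0, 0.5]; the clipped one points
# along [1.0, 0.5] — a different direction, unlike norm clipping.
```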
Comparison
| Aspect | Norm Clipping | Value Clipping |
|---|---|---|
| Direction | Preserved | Changed |
| Relative magnitudes | Preserved | Distorted |
| Threshold meaning | Global norm bound | Per-element bound |
| Recommended | Yes | Rarely |
Threshold Selection
The clipping threshold (max_norm) determines how often and how strongly gradients are rescaled, so it directly trades stability against optimization speed.
Guidelines for max_norm
| max_norm | Effect | Use Case |
|---|---|---|
| 0.1 | Very aggressive | Extremely unstable training |
| 0.5 | Aggressive | Unstable training |
| 1.0 | Standard | Most cases (recommended) |
| 5.0 | Permissive | Stable training |
| 10.0 | Very permissive | Already stable |
Empirical Selection
Monitor gradient norms during training to choose the threshold:
```python
# Monitor gradient norms during training
import numpy as np
import torch
import torch.nn as nn

def compute_gradient_norm(model: nn.Module) -> float:
    """Compute total L2 gradient norm across all parameters."""
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5

# In the training loop
for epoch in range(num_epochs):
    grad_norms = []
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()

        # Monitor before clipping
        grad_norms.append(compute_gradient_norm(model))

        # Clip and step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    # Analyze
    print(f"Epoch {epoch}: "
          f"Mean grad norm: {np.mean(grad_norms):.2f}, "
          f"Max: {np.max(grad_norms):.2f}, "
          f"Clipped: {sum(g > 1.0 for g in grad_norms) / len(grad_norms):.1%}")
```

Threshold Selection
Set max_norm so that 10-30% of batches are clipped. If almost every batch is clipped, the threshold is too low. If clipping never happens, it is too high (or not needed).
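The rule above can be automated from logged norms. This hypothetical helper (`suggest_max_norm` is not a library function) picks a threshold so that roughly the target fraction of steps would have been clipped:

```python
def suggest_max_norm(grad_norms, target_clip_ratio=0.2):
    """Pick a clipping threshold so roughly target_clip_ratio of the
    logged steps would have been clipped (norms above the threshold)."""
    ranked = sorted(grad_norms)
    idx = min(int(len(ranked) * (1.0 - target_clip_ratio)), len(ranked) - 1)
    return ranked[idx]

# Norms logged over one epoch (made-up numbers for illustration)
norms = [0.4, 0.5, 0.6, 0.8, 1.2, 0.7, 0.9, 2.5, 0.6, 0.5]
print(suggest_max_norm(norms, target_clip_ratio=0.2))  # 1.2
```

In other words, the suggested threshold is a percentile of the observed pre-clipping norms; re-run it after any change to the architecture or learning rate, since the norm distribution shifts.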
Implementation
PyTorch provides built-in gradient clipping utilities.
Basic Gradient Clipping
```python
import torch.nn.utils as utils

# Training loop with gradient clipping
for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass
    outputs = model(batch.x)
    loss = criterion(outputs, batch.y)

    # Backward pass
    loss.backward()

    # Gradient clipping (BEFORE optimizer.step())
    utils.clip_grad_norm_(
        model.parameters(),
        max_norm=1.0,
        norm_type=2  # L2 norm (default)
    )

    # Update parameters
    optimizer.step()
```

With Gradient Norm Logging
```python
import torch
import torch.nn as nn

def train_step(
    model: nn.Module,
    batch: tuple,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    max_grad_norm: float = 1.0
) -> dict:
    """
    Single training step with gradient clipping and monitoring.

    Returns:
        Dictionary with loss and gradient statistics
    """
    optimizer.zero_grad()

    # Forward
    x, target = batch
    pred = model(x)
    loss = criterion(pred, target)

    # Backward
    loss.backward()

    # clip_grad_norm_ clips in place and returns the total norm
    # measured BEFORE clipping, so one call gives us both pieces
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=max_grad_norm
    ).item()

    # Update
    optimizer.step()

    return {
        'loss': loss.item(),
        'grad_norm_before': grad_norm,
        'grad_norm_after': min(grad_norm, max_grad_norm),
        'was_clipped': grad_norm > max_grad_norm
    }
```

Complete Training Loop with All Optimizations
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    criterion: nn.Module,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Complete training epoch with all optimization techniques.
    """
    model.train()

    total_loss = 0.0
    total_grad_norm = 0.0
    num_clipped = 0
    num_steps = 0

    for batch in dataloader:
        x, y = batch
        x, y = x.to(device), y.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        pred = model(x)
        loss = criterion(pred, y)

        # Backward pass
        loss.backward()

        # Gradient clipping (returns the pre-clipping norm)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=max_grad_norm
        )

        # Track statistics
        total_loss += loss.item()
        total_grad_norm += grad_norm.item()
        if grad_norm > max_grad_norm:
            num_clipped += 1
        num_steps += 1

        # Optimizer step
        optimizer.step()

        # Scheduler step (if per-batch)
        if scheduler is not None:
            scheduler.step()

    return {
        'loss': total_loss / num_steps,
        'avg_grad_norm': total_grad_norm / num_steps,
        'clip_ratio': num_clipped / num_steps,
        'lr': optimizer.param_groups[0]['lr']
    }
```

Summary
In this section, we covered gradient clipping:
- Exploding gradients: Caused by repeated multiplication in deep networks
- Norm clipping: Scale gradients to bound their L2 norm
- Threshold: max_norm = 1.0 is a good default
- Monitoring: Track gradient norms to diagnose issues
- Timing: Clip after backward(), before optimizer.step()
| Parameter | Value |
|---|---|
| Clipping method | Gradient norm (L2) |
| max_norm | 1.0 |
| When to apply | After loss.backward() |
| Target clip ratio | 10-30% of batches |
Chapter Complete: You now have a complete optimization toolkit: AdamW optimizer, learning rate warmup, cosine annealing, adaptive weight decay, and gradient clipping. The next chapter covers training enhancements—techniques like EMA model averaging, early stopping, and mixed precision training.
With the optimizer fully configured, we explore advanced training techniques.