Learning Objectives
By the end of this section, you will:
- Understand gradient accumulation and why it's useful
- Derive the equivalence between accumulated and large batches
- Know when accumulation is beneficial vs. detrimental
- Implement gradient accumulation correctly in PyTorch
- Integrate with mixed precision and other techniques
Why This Matters: Larger batch sizes often improve training stability and final model quality, but GPU memory limits how large a batch can be. Gradient accumulation lets you simulate any batch size by accumulating gradients across multiple mini-batches before updating weights.
What is Gradient Accumulation?
Gradient accumulation simulates large batches by summing gradients across multiple forward-backward passes.
The Core Idea
Instead of one large batch of 512 samples, process 4 mini-batches of 128 samples each, accumulating gradients before the optimizer step:
```
Standard Training (batch_size=512):
┌─────────────────────────────────────────────┐
│                512 samples                  │
│                     ↓                       │
│                Forward Pass                 │
│                     ↓                       │
│                Backward Pass                │
│                     ↓                       │
│               Optimizer Step                │
└─────────────────────────────────────────────┘

Gradient Accumulation (batch_size=128, steps=4):
┌───────────┬───────────┬───────────┬───────────┐
│ 128 batch │ 128 batch │ 128 batch │ 128 batch │
│     ↓     │     ↓     │     ↓     │     ↓     │
│  Forward  │  Forward  │  Forward  │  Forward  │
│     ↓     │     ↓     │     ↓     │     ↓     │
│ Backward  │ Backward  │ Backward  │ Backward  │
│     ↓     │     ↓     │     ↓     │     ↓     │
│  (accum)  │  (accum)  │  (accum)  │  (step)   │
└───────────┴───────────┴───────────┴───────────┘
                                          ↓
                                Optimizer Step (once)
```
Memory Advantage
The key benefit: each mini-batch fits in GPU memory, but the effective batch size is the sum of all mini-batches:
| Mini-batch | Accumulation Steps | Effective Batch |
|---|---|---|
| 128 | 2 | 256 |
| 128 | 4 | 512 |
| 256 | 2 | 512 |
| 256 | 4 | 1024 |
The effective batch size is b × K, where b is the mini-batch size and K is the number of accumulation steps.
The Mathematics
Gradient accumulation is mathematically equivalent to using a larger batch, provided the loss is a mean over samples and the model has no batch-size-dependent layers (BatchNorm, for example, computes statistics per mini-batch, so its behavior differs slightly).
Gradient Averaging
The gradient for a batch of size B is the average of per-sample gradients:

$$\nabla L = \frac{1}{B} \sum_{i=1}^{B} \nabla \ell(x_i, y_i)$$

For K accumulated mini-batches of size b (so B = Kb):

$$\nabla L = \frac{1}{Kb} \sum_{k=1}^{K} \sum_{i=1}^{b} \nabla \ell(x_{k,i}, y_{k,i}) = \frac{1}{K} \sum_{k=1}^{K} g_k$$

where $g_k$ is the mean gradient of mini-batch $k$.
Loss Scaling for Accumulation
Since PyTorch's loss.backward() adds to existing gradients rather than averaging them, we must scale each mini-batch loss before the backward pass:

$$\tilde{L}_k = \frac{L_k}{K}$$

After K backward passes, the accumulated gradient is $\frac{1}{K}\sum_{k} g_k$, which matches the large-batch gradient exactly.
Divide Loss, Not Gradients
Scale the loss before backward(), not the gradients after. This is more efficient and numerically stable.
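The equivalence is easy to verify numerically. The sketch below compares the gradient of one large batch against four accumulated mini-batches with the loss divided by K, using a toy linear model (the model, shapes, and K=4 here are illustrative, not part of the AMNL configuration):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)
criterion = torch.nn.MSELoss()  # mean reduction over the batch

# Gradient of one large batch of 8 samples
model.zero_grad()
criterion(model(x), y).backward()
big = [p.grad.clone() for p in model.parameters()]

# Accumulated: 4 mini-batches of 2, each loss scaled by 1/K
model.zero_grad()
K = 4
for xb, yb in zip(x.chunk(K), y.chunk(K)):
    (criterion(model(xb), yb) / K).backward()  # gradients sum in .grad
acc = [p.grad.clone() for p in model.parameters()]

# The two gradients agree up to floating-point error
for g_big, g_acc in zip(big, acc):
    assert torch.allclose(g_big, g_acc, atol=1e-6)
```

Note that no optimizer step occurs between the backward passes, so all mini-batches see the same weights; that is exactly what makes the sums equal.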
When to Use Gradient Accumulation
Gradient accumulation is not always beneficial.
Good Use Cases
| Scenario | Why Accumulation Helps |
|---|---|
| Large model, limited GPU memory | Enables training at all |
| Batch size affects convergence | Simulate larger batches |
| Stable gradient estimation needed | Reduce gradient variance |
| Multi-GPU alternative | Cheaper than more GPUs |
Poor Use Cases
| Scenario | Why Accumulation Hurts |
|---|---|
| Already large batch fits | Just slows down training |
| Need frequent updates | Delays adaptation to data |
| Small dataset | May overfit with large effective batch |
| Time-constrained | Increases training time |
AMNL Configuration
For our RUL prediction model:
| Parameter | Value | Effective |
|---|---|---|
| Mini-batch size | 256 | — |
| Accumulation steps | 2 | — |
| Effective batch | 512 | 256 × 2 |
Why 2 Accumulation Steps?
With ~20K training samples, effective batch size 512 gives ~40 updates per epoch. This is enough for stable learning while maintaining frequent feedback. More accumulation (e.g., K=4 or K=8) would reduce updates too much.
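The update count follows directly from the sizes above (a quick sanity check; the ~20K sample count is the approximate figure stated in the text):

```python
# Updates per epoch = dataset size // effective batch size
num_samples = 20_000          # approximate training set size
mini_batch = 256
accumulation_steps = 2

effective_batch = mini_batch * accumulation_steps   # 512
updates_per_epoch = num_samples // effective_batch  # 39, i.e. ~40

print(effective_batch, updates_per_epoch)
```

Doubling K to 4 would halve this to ~20 updates per epoch, which is why larger accumulation factors were rejected.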
Implementation
Correct gradient accumulation implementation in PyTorch.
Basic Gradient Accumulation
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_with_accumulation(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    accumulation_steps: int = 2,
    device: torch.device = torch.device('cuda')
) -> float:
    """
    Training with gradient accumulation.

    Args:
        accumulation_steps: Number of mini-batches to accumulate

    Returns:
        Average training loss
    """
    model.train()
    total_loss = 0.0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Forward pass
        pred = model(x)
        loss = criterion(pred, y)

        # Scale loss for accumulation
        loss = loss / accumulation_steps

        # Backward pass (gradients accumulate)
        loss.backward()

        # Update weights every K steps
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps

    # Handle remaining batches (if dataset size not divisible by K)
    if (batch_idx + 1) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / len(train_loader)
```
With Gradient Clipping
```python
def train_with_accumulation_and_clipping(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    accumulation_steps: int = 2,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Gradient accumulation with gradient clipping.

    Clip gradients AFTER accumulation is complete,
    BEFORE the optimizer step.
    """
    model.train()
    total_loss = 0.0
    total_grad_norm = 0.0
    num_updates = 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Forward pass
        pred = model(x)
        loss = criterion(pred, y) / accumulation_steps

        # Backward pass (accumulates gradients)
        loss.backward()

        # Update weights every K steps
        if (batch_idx + 1) % accumulation_steps == 0:
            # Clip accumulated gradients
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=max_grad_norm
            )

            optimizer.step()
            optimizer.zero_grad()

            total_grad_norm += grad_norm.item()
            num_updates += 1

        total_loss += loss.item() * accumulation_steps

    return {
        'loss': total_loss / len(train_loader),
        'avg_grad_norm': total_grad_norm / max(num_updates, 1)
    }
```
With Mixed Precision
```python
from torch.cuda.amp import autocast, GradScaler

def train_with_accumulation_mixed_precision(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: GradScaler,
    accumulation_steps: int = 2,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Gradient accumulation with mixed precision and clipping.

    Order of operations:
    1. Forward (FP16 with autocast)
    2. Scale loss and backward
    3. Repeat for K steps
    4. Unscale gradients
    5. Clip gradients
    6. Optimizer step (skips if inf/NaN)
    7. Update scaler
    """
    model.train()
    total_loss = 0.0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Forward with autocast
        with autocast():
            pred = model(x)
            loss = criterion(pred, y) / accumulation_steps

        # Backward with scaled gradients
        scaler.scale(loss).backward()

        # Update weights every K steps
        if (batch_idx + 1) % accumulation_steps == 0:
            # Unscale before clipping
            scaler.unscale_(optimizer)

            # Clip in FP32 space
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=max_grad_norm
            )

            # Step (skips if inf/NaN gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps

    return {'loss': total_loss / len(train_loader)}
```
Common Mistakes
| Mistake | Problem | Fix |
|---|---|---|
| Forget to scale loss | Gradients KΓ too large | loss = loss / K |
| zero_grad() every batch | Clears accumulated gradients | Only zero_grad() after step() |
| Clip before accumulation complete | Clips partial gradient | Clip only when (idx+1) % K == 0 |
| Ignore remainder batches | Some data never updates weights | Handle (idx+1) % K != 0 at end |
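The second mistake in the table is easy to demonstrate. The sketch below uses a toy linear model (purely illustrative) to show that calling zero_grad() between backward passes discards the first mini-batch's contribution:

```python
import torch

model = torch.nn.Linear(2, 1)
x = torch.ones(2, 2)

# Wrong: zero_grad() between backward passes wipes accumulated gradients
model.zero_grad()
model(x).sum().backward()
model.zero_grad()                 # <- mistake: clears the first contribution
model(x).sum().backward()
wrong = model.weight.grad.clone()

# Right: let gradients accumulate; zero only after the optimizer step
model.zero_grad()
model(x).sum().backward()
model(x).sum().backward()         # gradients add into .grad
right = model.weight.grad.clone()

# The correct accumulated gradient is twice the incorrectly-zeroed one
assert torch.allclose(right, 2 * wrong)
```

With the erroneous zero_grad(), the "accumulated" gradient is just the last mini-batch's gradient, so the effective batch size silently collapses back to the mini-batch size.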
Summary
In this section, we covered gradient accumulation:
- Purpose: Simulate large batches without extra memory
- Mathematics: Exactly equivalent to large-batch training
- Loss scaling: Divide by K before backward()
- Update timing: Only step() after K accumulations
- AMNL uses: K = 2 for effective batch 512
| Parameter | Value |
|---|---|
| Mini-batch size | 256 |
| Accumulation steps (K) | 2 |
| Effective batch size | 512 |
| Updates per epoch | ~40 |
Looking Ahead: We have covered the main training enhancements. The final section addresses reproducibility: ensuring experiments can be exactly replicated by controlling all sources of randomness.