Chapter 13

Gradient Accumulation

Training Enhancements

Learning Objectives

By the end of this section, you will:

  1. Understand gradient accumulation and why it's useful
  2. Derive the equivalence between accumulated and large batches
  3. Know when accumulation is beneficial vs. detrimental
  4. Implement gradient accumulation correctly in PyTorch
  5. Integrate with mixed precision and other techniques
Why This Matters: Larger batch sizes often improve training stability and final model quality, but GPU memory limits how large a batch can be. Gradient accumulation lets you simulate any batch size by accumulating gradients across multiple mini-batches before updating weights.

What is Gradient Accumulation?

Gradient accumulation simulates large batches by summing gradients across multiple forward-backward passes.

The Core Idea

Instead of one large batch of 512 samples, process 4 mini-batches of 128 samples each, accumulating gradients before the optimizer step:

πŸ“text
1Standard Training (batch_size=512):
2β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
3β”‚                 512 samples                          β”‚
4β”‚                     β”‚                                β”‚
5β”‚            Forward Pass                              β”‚
6β”‚                     β”‚                                β”‚
7β”‚            Backward Pass                             β”‚
8β”‚                     β”‚                                β”‚
9β”‚            Optimizer Step                            β”‚
10β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
11
12Gradient Accumulation (batch_size=128, steps=4):
13β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
14β”‚ 128 batch β”‚ 128 batch β”‚ 128 batch β”‚ 128 batch β”‚
15β”‚     β”‚     β”‚     β”‚     β”‚     β”‚     β”‚     β”‚     β”‚
16β”‚  Forward  β”‚  Forward  β”‚  Forward  β”‚  Forward  β”‚
17β”‚     β”‚     β”‚     β”‚     β”‚     β”‚     β”‚     β”‚     β”‚
18β”‚ Backward  β”‚ Backward  β”‚ Backward  β”‚ Backward  β”‚
19β”‚     β”‚     β”‚     β”‚     β”‚     β”‚     β”‚     β”‚     β”‚
20β”‚  (accum)  β”‚  (accum)  β”‚  (accum)  β”‚  (step)   β”‚
21β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
22                                          β”‚
23                              Optimizer Step (once)

Memory Advantage

The key benefit: each mini-batch fits in GPU memory, but the effective batch size is the sum of all mini-batches:

| Mini-batch | Accumulation Steps | Effective Batch |
|------------|--------------------|-----------------|
| 128        | 2                  | 256             |
| 128        | 4                  | 512             |
| 256        | 2                  | 512             |
| 256        | 4                  | 1024            |

$$B_{\text{effective}} = B_{\text{mini}} \times K$$

Where K is the number of accumulation steps.


The Mathematics

Gradient accumulation is mathematically equivalent to using a larger batch.

Gradient Averaging

The gradient for a batch is the average of per-sample gradients:

$$g_B = \frac{1}{|B|} \sum_{i \in B} \nabla_\theta \mathcal{L}(x_i, y_i)$$

For K accumulated mini-batches of size b:

$$g_{\text{accum}} = \frac{1}{K} \sum_{k=1}^{K} g_{B_k} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{b} \sum_{i \in B_k} \nabla_\theta \mathcal{L}(x_i, y_i)$$
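Since the K mini-batches partition one large batch of |B| = Kb samples, the double sum collapses and the accumulated gradient equals the large-batch gradient exactly:

$$g_{\text{accum}} = \frac{1}{Kb} \sum_{k=1}^{K} \sum_{i \in B_k} \nabla_\theta \mathcal{L}(x_i, y_i) = \frac{1}{|B|} \sum_{i \in B} \nabla_\theta \mathcal{L}(x_i, y_i) = g_B$$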

Loss Scaling for Accumulation

Since PyTorch's loss.backward() adds to existing gradients, we must scale the loss:

$$\text{scaled\_loss} = \frac{\text{loss}}{K}$$

This ensures the accumulated gradient has the correct magnitude.
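In gradient terms: summing the backward passes of K scaled losses reproduces the averaged gradient defined above, where $\mathcal{L}_k$ denotes the mean loss of mini-batch k:

$$\sum_{k=1}^{K} \nabla_\theta \frac{\mathcal{L}_k}{K} = \frac{1}{K} \sum_{k=1}^{K} \nabla_\theta \mathcal{L}_k = g_{\text{accum}}$$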

Divide Loss, Not Gradients

Scale the loss before backward(), not the gradients after. This is more efficient and numerically stable.


When to Use Gradient Accumulation

Gradient accumulation is not always beneficial.

Good Use Cases

| Scenario | Why Accumulation Helps |
|----------|------------------------|
| Large model, limited GPU memory | Enables training at all |
| Batch size affects convergence | Simulates larger batches |
| Stable gradient estimation needed | Reduces gradient variance |
| Multi-GPU alternative | Cheaper than adding GPUs |

Poor Use Cases

| Scenario | Why Accumulation Hurts |
|----------|------------------------|
| Desired batch already fits in memory | Just slows down training |
| Frequent updates needed | Delays adaptation to the data |
| Small dataset | May overfit with a large effective batch |
| Time-constrained training | Increases wall-clock time |

AMNL Configuration

For our RUL prediction model:

| Parameter | Value |
|-----------|-------|
| Mini-batch size | 256 |
| Accumulation steps (K) | 2 |
| Effective batch size | 512 (= 256 × 2) |

Why 2 Accumulation Steps?

With ~20K training samples, effective batch size 512 gives ~40 updates per epoch. This is enough for stable learning while maintaining frequent feedback. More accumulation (e.g., K=4 or K=8) would reduce updates too much.
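The arithmetic behind these numbers can be checked directly (the ~20K sample count is taken from the text above; the exact figure is illustrative):

```python
# Back-of-envelope check of the AMNL configuration above.
n_samples = 20_000        # ~20K training samples (from the text)
mini_batch = 256
accum_steps = 2           # K

effective_batch = mini_batch * accum_steps
updates_per_epoch = n_samples // effective_batch

print(effective_batch)     # 512
print(updates_per_epoch)   # 39, i.e. ~40 updates per epoch
```

Doubling K to 4 would halve this to ~20 updates per epoch, which is why larger K values were avoided.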


Implementation

Correct gradient accumulation implementation in PyTorch.

Basic Gradient Accumulation

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_with_accumulation(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    accumulation_steps: int = 2,
    device: torch.device = torch.device('cuda')
) -> float:
    """
    Training with gradient accumulation.

    Args:
        accumulation_steps: Number of mini-batches to accumulate

    Returns:
        Average training loss
    """
    model.train()
    total_loss = 0.0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Forward pass
        pred = model(x)
        loss = criterion(pred, y)

        # Scale loss for accumulation
        loss = loss / accumulation_steps

        # Backward pass (gradients accumulate)
        loss.backward()

        # Update weights every K steps
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps

    # Handle remaining batches (if dataset size is not divisible by K)
    if (batch_idx + 1) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / len(train_loader)
```
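A quick, self-contained numerical check of the equivalence claim, using a throwaway `nn.Linear` on CPU (illustrative only, not part of the AMNL pipeline; assumes a mean-reduced loss and equal-sized mini-batches):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(512, 8)
y = torch.randn(512, 1)
model = nn.Linear(8, 1)
criterion = nn.MSELoss()   # mean reduction, as the derivation assumes
K = 4

# Gradient from one large batch of 512
model.zero_grad()
criterion(model(x), y).backward()
g_big = [p.grad.clone() for p in model.parameters()]

# Gradient from K accumulated mini-batches of 128, each loss divided by K
model.zero_grad()
for xb, yb in zip(x.chunk(K), y.chunk(K)):
    (criterion(model(xb), yb) / K).backward()
g_acc = [p.grad.clone() for p in model.parameters()]

match = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_big, g_acc))
print(match)  # True, up to float32 rounding
```

The two gradients agree to floating-point precision, confirming that accumulation with loss scaling reproduces large-batch training exactly.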

With Gradient Clipping

```python
def train_with_accumulation_and_clipping(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    accumulation_steps: int = 2,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Gradient accumulation with gradient clipping.

    Clip gradients AFTER accumulation is complete,
    BEFORE the optimizer step.
    """
    model.train()
    total_loss = 0.0
    total_grad_norm = 0.0
    num_updates = 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Forward pass
        pred = model(x)
        loss = criterion(pred, y) / accumulation_steps

        # Backward pass (accumulates gradients)
        loss.backward()

        # Update weights every K steps
        if (batch_idx + 1) % accumulation_steps == 0:
            # Clip accumulated gradients
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=max_grad_norm
            )

            optimizer.step()
            optimizer.zero_grad()

            total_grad_norm += grad_norm.item()
            num_updates += 1

        total_loss += loss.item() * accumulation_steps

    return {
        'loss': total_loss / len(train_loader),
        'avg_grad_norm': total_grad_norm / max(num_updates, 1)
    }
```

With Mixed Precision

```python
from torch.cuda.amp import autocast, GradScaler

def train_with_accumulation_mixed_precision(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: GradScaler,
    accumulation_steps: int = 2,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Gradient accumulation with mixed precision and clipping.

    Order of operations:
    1. Forward (FP16 with autocast)
    2. Scale loss and backward
    3. Repeat for K steps
    4. Unscale gradients
    5. Clip gradients
    6. Optimizer step (skips if inf/NaN)
    7. Update scaler
    """
    model.train()
    total_loss = 0.0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Forward with autocast
        with autocast():
            pred = model(x)
            loss = criterion(pred, y) / accumulation_steps

        # Backward with scaled gradients
        scaler.scale(loss).backward()

        # Update weights every K steps
        if (batch_idx + 1) % accumulation_steps == 0:
            # Unscale before clipping
            scaler.unscale_(optimizer)

            # Clip in FP32 space
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=max_grad_norm
            )

            # Step (skips if inf/NaN gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps

    return {'loss': total_loss / len(train_loader)}
```

Common Mistakes

| Mistake | Problem | Fix |
|---------|---------|-----|
| Forgetting to scale the loss | Gradients K× too large | `loss = loss / K` |
| `zero_grad()` every batch | Clears accumulated gradients | Only `zero_grad()` after `step()` |
| Clipping before accumulation completes | Clips a partial gradient | Clip only when `(idx + 1) % K == 0` |
| Ignoring remainder batches | Some data never updates weights | Handle `(idx + 1) % K != 0` after the loop |
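The `zero_grad()` mistake is easy to see on a toy scalar (illustrative only; the two stand-in losses have gradients 2 and 4 with respect to w):

```python
import torch

K = 2

def mini_batch_losses(w):
    # Stand-ins for two mini-batch losses; gradients w.r.t. w are 2 and 4
    return [2.0 * w, 4.0 * w]

# WRONG: zeroing inside the accumulation window
w_wrong = torch.tensor([1.0], requires_grad=True)
for loss in mini_batch_losses(w_wrong):
    if w_wrong.grad is not None:
        w_wrong.grad.zero_()      # discards earlier contributions
    (loss / K).backward()
print(w_wrong.grad)               # tensor([2.]) -- only the last batch survives

# RIGHT: let backward() accumulate, zero only after the optimizer step
w_right = torch.tensor([1.0], requires_grad=True)
for loss in mini_batch_losses(w_right):
    (loss / K).backward()
print(w_right.grad)               # tensor([3.]) -- average gradient (2 + 4) / 2
```

With premature zeroing, the "effective batch" silently shrinks back to a single mini-batch.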

Summary

In this section, we covered gradient accumulation:

  1. Purpose: Simulate large batches without extra memory
  2. Mathematics: Exactly equivalent to large-batch training
  3. Loss scaling: Divide by K before backward()
  4. Update timing: Only step() after K accumulations
  5. AMNL uses: K = 2 for effective batch 512
| Parameter | Value |
|-----------|-------|
| Mini-batch size | 256 |
| Accumulation steps (K) | 2 |
| Effective batch size | 512 |
| Updates per epoch | ~40 |
Looking Ahead: We have now covered the main training enhancements. The final section addresses reproducibility: ensuring experiments can be exactly replicated by controlling every source of randomness.