Chapter 19
Section 62 of 65

Gradient Clipping and Accumulation

Modern Training Techniques

Two Small Problems That Break Big Training Runs

Every modern training pipeline, from a 10-layer MLP on a laptop GPU to GPT-class transformers on 1024 H100s, has two failure modes that have nothing to do with modeling: gradients that suddenly become huge, and batches that are too small to learn anything stable. The first blows the optimizer off the loss surface in a single step. The second starves stochastic gradient descent of the variance reduction it needs and leaves expensive hardware underused.

Two small tricks address both problems without changing the model. They are the quiet infrastructure behind nearly every large-scale training run:

  1. Gradient clipping: bound the magnitude of the update so one bad step cannot destroy an hour of training.
  2. Gradient accumulation: simulate a big batch using many small forward/backward passes, so a large effective batch can be trained on a single GPU whose memory holds only a few examples at a time.

Both are mechanically trivial. What makes them interesting — and what we will unpack line-by-line — is the geometry of “just rescale the gradient”, the algebraic identity that makes accumulation exact, not approximate, and how the combination connects directly to the engineering behind Flash Attention, KV-cache-heavy LLM inference, and modern transformer scaling. The mathematics is small; the consequences are enormous.

One-sentence summary. Clipping bounds the size of the step; accumulation bounds the memory used to compute the step. Everything else in this section is commentary on those two sentences.

When Gradients Explode: The Training-Time Catastrophe

In chapter 16 we saw that, in a deep or recurrent network, the gradient of the loss with respect to an early parameter is a product of Jacobians, one for each layer or timestep. The same product that can vanish (Jacobians with singular values < 1) can also explode (singular values > 1). If the product of singular values across a 20-layer transformer or a 100-step RNN lands around $2^{100}$, a gradient that started at order 1 arrives at the first parameter at order $10^{30}$.

SGD then performs $w \leftarrow w - \eta\, g$ with $\eta = 10^{-3}$ and $\|g\| = 10^{30}$. The update is $\|\eta g\| = 10^{27}$. The weight vector, previously near the optimum, is hurled out to infinity. Every activation in the next forward pass is NaN. The run is dead.
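The arithmetic above is easy to reproduce. The sketch below makes one simplifying assumption: every layer Jacobian is $2I$ (all singular values equal to 2), with an arbitrary width of 4. Backpropagating a unit-norm gradient through 100 such layers multiplies its norm by exactly $2^{100} \approx 1.27 \times 10^{30}$, and a single SGD step with $\eta = 10^{-3}$ then moves the weights by roughly $10^{27}$.

```python
import numpy as np

d = 4                      # hypothetical layer width (any value works)
depth = 100                # number of layers / timesteps
J = 2.0 * np.eye(d)        # assumed Jacobian: every singular value equals 2

# Backpropagate a unit-norm upstream gradient through `depth` Jacobians.
g = np.ones(d) / np.sqrt(d)
for _ in range(depth):
    g = J.T @ g

growth = np.linalg.norm(g)          # 2**100, about 1.27e30
eta = 1e-3
update = np.linalg.norm(eta * g)    # about 1.27e27

print(f"gradient norm after {depth} layers: {growth:.3e}")
print(f"size of one SGD update (eta=1e-3): {update:.3e}")
```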

The interactive visualizer below shows this catastrophe on a small Rosenbrock-like surface. Turn off clipping and push the learning rate up — the trajectory leaves the screen in a few steps. Turn on norm clipping and the same learning rate lands safely in the valley.

[Interactive figure: loss-surface trajectory with and without norm clipping]
Why exploding gradients happen at all. Three recurring causes: (i) poor initialization that puts the initial Jacobians above 1, (ii) steep, cliff-like regions of the loss surface where local derivatives briefly become very large, and (iii) outlier data, such as a single batch with an atypical target that produces a huge loss spike. Clipping is not a substitute for good initialization or data cleaning, but it is a reliable last line of defense.

Gradient clipping is the idea: if the gradient is too big, make it smaller before taking the step. Two flavors exist — one naive and one that is mathematically principled.


Gradient Clipping by Value

The simplest possible clipping rule. Pick a threshold $c > 0$ and clamp every component of the gradient independently into the interval $[-c, +c]$, so that $g'_i = \max\bigl(-c,\, \min(c,\, g_i)\bigr)$ for every coordinate $i$.

Any gradient coordinate with magnitude above $c$ is snapped to $\pm c$. Any coordinate below passes through unchanged. This is what torch.nn.utils.clip_grad_value_ does.

It is easy to implement and cheap to run, but it has one mathematically ugly property: it changes the direction of the gradient. If the true descent direction has most of its mass in one coordinate that happens to be large, value clipping will amputate that coordinate down to $c$ while leaving the others alone, so the result points somewhere else entirely. In practice value clipping is used mostly for stability in specialized pipelines (e.g., some reinforcement learning setups) and is rarely the default in supervised training.

| Property | Clip-by-Value |
|---|---|
| Operates on | Each coordinate independently |
| Preserves direction? | No |
| Preserves norm bound? | Only per coordinate: $\lvert g'_i \rvert \le c$ |
| Global norm after clip | Up to $c\sqrt{d}$, where $d$ is the number of parameters |
| PyTorch call | torch.nn.utils.clip_grad_value_(params, c) |
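The last two rows of the table can be checked numerically. In this sketch the dimension, the threshold and the random gradient are arbitrary illustrative choices. Every coordinate of the gradient is larger than $c$ in magnitude, so value clipping maps each one to $\pm c$. The global norm therefore lands exactly on the worst case $c\sqrt{d}$, and the cosine with the original gradient drops below 1.

```python
import numpy as np

rng = np.random.default_rng(0)
d, c = 1000, 0.5

# Every coordinate has magnitude in [1, 3), so all of them exceed c.
g = rng.uniform(1.0, 3.0, size=d) * rng.choice([-1.0, 1.0], size=d)
g_val = np.clip(g, -c, c)

per_coord_ok = np.all(np.abs(g_val) <= c)                 # per-coordinate bound holds
global_norm = np.linalg.norm(g_val)                       # equals c * sqrt(d) here
cosine = g @ g_val / (np.linalg.norm(g) * global_norm)    # below 1: direction changed

print(per_coord_ok, global_norm, c * np.sqrt(d), cosine)
```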

Gradient Clipping by Norm (The Preferred Method)

The standard choice in modern pipelines (transformers, GANs, RNNs, diffusion models) is norm clipping. Pick a threshold $\tau > 0$ and rescale the whole gradient so that its Euclidean norm is at most $\tau$. In piecewise form, $g' = g$ when $\|g\|_2 \le \tau$, and $g' = (\tau / \|g\|_2)\, g$ otherwise. A more compact way to write the same thing is $g' = \min\bigl(1,\, \tau / \|g\|_2\bigr) \cdot g$; the multiplier is either $1$ (no clip) or $\tau/\|g\|_2 < 1$ (a uniform shrink). Either way:

  1. Direction is exactly preserved. The output is parallel to the input; only its length changes.
  2. Norm is exactly bounded. After clipping, $\|g'\|_2 \le \tau$, with equality whenever clipping triggers.
  3. Scale is smooth in the data. Small gradients pass through identically; large gradients shrink proportionally. No sharp per-coordinate cliffs.
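The compact formula translates into a branch-free function. This is a minimal sketch that checks the three properties on a random vector. The name clip_norm_compact is hypothetical and is not a library function. The checks confirm that the clipped gradient is parallel to the original, that its norm is at most $\tau$, and that a gradient already inside the ball is returned unchanged.

```python
import numpy as np

def clip_norm_compact(g, tau):
    """Return min(1, tau / ||g||_2) * g, leaving the zero vector as is."""
    norm = np.linalg.norm(g)
    factor = 1.0 if norm == 0.0 else min(1.0, tau / norm)
    return factor * g

rng = np.random.default_rng(1)
g = rng.normal(size=10) * 50.0      # a large gradient, norm far above tau
tau = 1.0

g_clip = clip_norm_compact(g, tau)
cos = g @ g_clip / (np.linalg.norm(g) * np.linalg.norm(g_clip))

print(np.linalg.norm(g_clip))  # close to 1.0: the norm is bounded by tau
print(cos)                     # close to 1.0: the direction is unchanged

small = 0.1 * g / np.linalg.norm(g)    # norm 0.1, already inside the ball
print(np.array_equal(clip_norm_compact(small, tau), small))  # passes through untouched
```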

In a real model the “gradient” is the concatenation of every parameter's gradient: a single vector of length equal to the total parameter count. Norm clipping uses the global L2 norm of that vector, namely $\|g\|_2 = \sqrt{\sum_{p \in \text{params}} \sum_{i} (\nabla_{p} L)_i^{2}}$, which is why PyTorch takes an iterable of parameters rather than one tensor.
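Here is a minimal NumPy sketch of the global norm. It assumes the per-parameter gradients arrive as a list of arrays; the shapes are made up, and clip_global_norm is a hypothetical helper. One scale factor is computed from all tensors together and applied to each of them, which gives the same result as concatenating them into a single vector first.

```python
import numpy as np

def clip_global_norm(grads, tau, eps=1e-6):
    """Scale every array in `grads` by one shared factor so that the
    concatenated gradient has L2 norm at most tau."""
    total = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, tau / (total + eps))
    return [g * scale for g in grads], total

rng = np.random.default_rng(2)
grads = [rng.normal(size=(4, 3)), rng.normal(size=(3,)), rng.normal(size=(2, 2, 2))]

clipped, total = clip_global_norm(grads, tau=1.0)

# The global norm equals the norm of one long concatenated vector.
flat = np.concatenate([g.ravel() for g in grads])
print(np.isclose(total, np.linalg.norm(flat)))                        # True
print(np.linalg.norm(np.concatenate([g.ravel() for g in clipped])))  # at most 1.0
```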

Choosing τ. A good default is $\tau = 1.0$ (transformers frequently use $\tau \in \{0.5, 1.0, 5.0\}$). Log the pre-clip $\|g\|_2$ every step: if it is always above τ, your learning rate or initialization is off; if it rarely reaches τ, clipping is a no-op and doing no harm.

Seeing Clipping Geometrically

The two clipping rules define two different feasible regions for the gradient vector. Value clipping restricts you to the axis-aligned square $[-c, c]^d$; norm clipping restricts you to the Euclidean ball $\{g : \|g\|_2 \le \tau\}$. The picture below is 2-D so you can see both. Move the red arrow outside the green feasible region and watch where the blue clipped arrow lands.

[Interactive figure: gradient-clipping visualizer, value vs. norm clipping in 2-D]

Two observations that the visualizer makes unmistakable:

  1. Norm clipping slides the tip of the arrow radially onto the boundary circle. The direction is the same; only the length changes.
  2. Value clipping snaps the tip to the nearest point of the square, which in general is not in the same direction as the original. Set $g = (8, 0.3)$ and use a threshold of 2 for both rules ($c = \tau = 2$): norm clipping gives approximately $(1.999,\, 0.075)$, while value clipping gives $(2,\, 0.3)$, a noticeably different direction.
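The numbers in the second observation can be reproduced in a few lines. The sketch below uses the same threshold of 2 for both rules and measures each vector's angle above the first axis with np.arctan2. Norm clipping keeps the original angle of about 2.1°, while value clipping raises it to about 8.5°.

```python
import numpy as np

g = np.array([8.0, 0.3])
threshold = 2.0                                       # c for value clipping, tau for norm clipping

g_norm = g * min(1.0, threshold / np.linalg.norm(g))  # norm clipping
g_value = np.clip(g, -threshold, threshold)           # value clipping

def angle_deg(v):
    """Angle of a 2-D vector above the first axis, in degrees."""
    return np.degrees(np.arctan2(v[1], v[0]))

print(np.round(g_norm, 3), angle_deg(g), angle_deg(g_norm))  # same angle as g
print(g_value, angle_deg(g_value))                           # noticeably steeper
```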

Watching Clipping Save a Training Run

The trajectory visualizer from the introduction is worth revisiting now that we have the norm-clipping formula. The loss here is a Rosenbrock-style narrow valley — a notoriously difficult surface where the gradient along the ridge is huge while the gradient along the valley is tiny. Without clipping, even a moderate learning rate overshoots the ridge and amplifies the error on the next step. With norm clipping, the direction toward the valley is preserved but the step length is bounded, so the optimizer glides in instead of ricocheting.
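The same effect can be sketched without the visualizer. This example assumes the standard Rosenbrock function $f(x, y) = (1 - x)^2 + 100\,(y - x^2)^2$, a start at $(-1.5, 2)$, a learning rate of 0.01 and $\tau = 1$. These are illustrative choices, not the settings behind the interactive figure. Plain gradient descent overshoots the ridge, each overshoot makes the next gradient larger, and the iterates overflow to inf within a handful of steps. The clipped run can move at most $\eta\tau = 0.01$ per step, so it stays finite and ends at a lower loss than it started.

```python
import numpy as np

def rosenbrock(w):
    x, y = w
    return (1.0 - x) ** 2 + 100.0 * (y - x ** 2) ** 2

def rosenbrock_grad(w):
    x, y = w
    return np.array([-2.0 * (1.0 - x) - 400.0 * x * (y - x ** 2),
                     200.0 * (y - x ** 2)])

def run_gd(lr, steps, tau=None, start=(-1.5, 2.0)):
    """Plain gradient descent, optionally with norm clipping at tau.
    Returns the final point and the number of steps actually taken."""
    w = np.array(start, dtype=float)
    with np.errstate(over="ignore", invalid="ignore"):
        for t in range(steps):
            g = rosenbrock_grad(w)
            if tau is not None:
                norm = np.linalg.norm(g)
                if norm > tau:
                    g = g * (tau / norm)
            w = w - lr * g
            if not np.all(np.isfinite(w)):
                return w, t + 1          # diverged: stop early
    return w, steps

w_plain, n_plain = run_gd(lr=0.01, steps=2000)
w_clip, n_clip = run_gd(lr=0.01, steps=2000, tau=1.0)

print("no clipping: ", w_plain, "after", n_plain, "steps")
print("clip tau=1.0:", w_clip, "loss", rosenbrock(w_clip))
```

Near the minimum, 0.01 is still too large a learning rate for the valley's curvature, so one would expect the clipped run to zigzag across the valley floor rather than settle smoothly. Clipping caps the damage of an aggressive learning rate; it does not make that learning rate correct.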

The timeline below plays a simulated 2000-step training run with occasional gradient spikes. Drag $\tau$ and watch the percentage of clipped steps change; it is the single most informative scalar to log during training.

[Interactive figure: gradient-norm timeline over a simulated 2000-step run]
Rule of thumb: if your loss curve occasionally spikes and recovers, your clipping is working. If it spikes and never recovers, your clipping is missing (or your τ is too large). If it is flat and boring, you might not need clipping at all — but it costs so little to include that almost every production pipeline does.
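The clip percentage is simple to compute from a log of pre-clip norms. The log below is synthetic: a log-normal background plus 20 injected spikes, both made-up stand-ins for a real run's history. The same log produces very different clip rates as τ moves, which is why logging the pre-clip norm makes the threshold easy to tune.

```python
import numpy as np

rng = np.random.default_rng(3)
steps = 2000

# Synthetic pre-clip norm log: log-normal background plus rare spikes.
norms = rng.lognormal(mean=-0.5, sigma=0.3, size=steps)
spike_steps = rng.choice(steps, size=20, replace=False)
norms[spike_steps] *= rng.uniform(10.0, 100.0, size=20)

for tau in (0.5, 1.0, 5.0):
    clipped_frac = np.mean(norms > tau)
    print(f"tau={tau:<4} clipped on {100 * clipped_frac:5.1f}% of steps")
```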

Python from Scratch: Building Both Clippers

Before reaching for PyTorch, we implement both clipping rules with nothing but NumPy. The entire algorithm for norm clipping is five lines. Reading it top to bottom is the best way to internalize the formula.

Clipping — NumPy implementation of both flavors
🐍clip_scratch.py
Line 1: import numpy as np

NumPy is all we need for the from-scratch version. np.clip performs element-wise clamping, np.linalg.norm computes the Euclidean length of a flattened array, and arithmetic on arrays is vectorized. No autograd is needed because we are clipping an already-computed gradient.

EXECUTION STATE
numpy = Numerical library — provides ndarray plus np.clip, np.linalg.norm and broadcasting. All the math we need.
as np = Alias. Lets us write np.clip(g, -c, c) instead of numpy.clip(g, -c, c).
Line 3 (comment): what g represents

In a real training loop, g would be the gradient of the loss with respect to one parameter, produced by loss.backward(). Here we use a hand-crafted example so the math is completely visible.

Line 4 (comment): shape is arbitrary

Clipping works for any shape because both operations are defined element-wise (clip-by-value) or on the flattened array (clip-by-norm). 2x3 is small enough to print.

Lines 5-6: g = np.array([[-5.20, 0.30, 8.90], [1.40, -12.5, 0.05]])

Construct a tiny 2×3 gradient matrix. The largest entries in magnitude are g[1,1] = −12.5 and g[0,2] = 8.9, with g[0,0] = −5.2 not far behind; the other three are small. These are the kind of spikes that kill training runs.

EXECUTION STATE
⬆ g (2×3) =
       c0     c1     c2
 r0  -5.20   0.30   8.90
 r1   1.40 -12.50   0.05
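→ notice = Three entries (0.30, 1.40, 0.05) are small. The other three are much larger, led by g[1,1] = −12.5, roughly an order of magnitude above the small ones. In a real net, spikes like this appear when a few units briefly see very large local derivatives, or when an outlier example produces a large loss.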
Line 8: # ---------- Clipping by VALUE ----------

Section header. Value clipping is the simpler of the two: treat each component in isolation and hard-cap its magnitude.

Line 9: def clip_by_value(g, c) → np.ndarray

A one-liner that clamps every element of g into the interval [−c, +c]. Big entries get snapped to ±c; small entries pass through unchanged.

EXECUTION STATE
⬇ input: g (2×3) =
       c0     c1     c2
 r0  -5.20   0.30   8.90
 r1   1.40 -12.50   0.05
⬇ input: c (scalar) = 2.0 — the per-element cap. Any entry above +c becomes +c; anything below −c becomes −c.
→ c purpose = A hyperparameter. Small c (e.g., 0.5) is aggressive clipping — destroys the direction but bounds every step tightly. Large c (e.g., 10) is loose — rarely triggers.
⬆ returns = np.ndarray of the same shape (2×3) where every element lies in [−c, +c].
Line 10 (comment): clamp semantics

Reminder that this is the defining equation of value clipping: g'ᵢ = max(−c, min(c, gᵢ)). Each coordinate is treated independently.

Line 11: return np.clip(g, -c, c)

np.clip does exactly max(−c, min(c, g)) element-wise, fully vectorized in C. No loop needed.

EXECUTION STATE
📚 np.clip(a, a_min, a_max) = NumPy function: returns an array where every element of a has been limited to the interval [a_min, a_max]. Example: np.clip([-3, 0, 5], -1, 1) = [-1, 0, 1].
⬇ arg 1: g = The input gradient array.
⬇ arg 2: -c = -2.0 = Lower bound. Anything smaller becomes −2.0.
⬇ arg 3: c = 2.0 = Upper bound. Anything larger becomes +2.0.
→ element-by-element trace = g[0,0] = -5.20 → -2.00 (saturated); g[0,1] = 0.30 → 0.30 (passed); g[0,2] = 8.90 → 2.00 (saturated); g[1,0] = 1.40 → 1.40 (passed); g[1,1] = -12.50 → -2.00 (saturated); g[1,2] = 0.05 → 0.05 (passed)
⬆ return: clipped (2×3) =
       c0     c1     c2
 r0  -2.00   0.30   2.00
 r1   1.40  -2.00   0.05
Line 13: # ---------- Clipping by NORM ----------

Section header. Norm clipping is the method used in essentially every modern training pipeline — it preserves the direction of the gradient and only rescales its length.

Line 14: def clip_by_norm(g, tau, eps=1e-6) → (np.ndarray, float, float)

Rescale g so that its L2 norm is at most tau. If ‖g‖ already ≤ tau, leave it alone. Otherwise multiply by the scale factor tau/‖g‖. Returns the clipped tensor along with the original norm and the applied scale so a training loop can log them.

EXECUTION STATE
⬇ input: g (2×3) =
       c0     c1     c2
 r0  -5.20   0.30   8.90
 r1   1.40 -12.50   0.05
⬇ input: tau = 5.0 — the maximum allowed L2 norm. The clipping threshold.
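⬇ input: eps = 1e-6 = A tiny constant added to the norm in the denominator. In this function the early return already handles the zero vector (its norm 0 is ≤ tau), so eps is a defensive habit that mirrors PyTorch's clip_grad_norm_, which always divides by (total + 1e-6).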
⬆ returns (tuple of 3) = (clipped gradient, original ‖g‖ pre-clip, scale factor applied). Logging the pre-clip norm is how you tune tau in practice.
Line 15 (comment): what norm we are computing

We use the GLOBAL L2 norm: flatten every parameter's gradient into one long vector and take its Euclidean length. This is what PyTorch's clip_grad_norm_ does by default and what matches the theory.

Line 16: total_norm = np.linalg.norm(g.ravel())

Compute the Euclidean length ‖g‖₂ = sqrt(Σᵢⱼ gᵢⱼ²). Ravel flattens the 2×3 matrix into a 6-vector so we get one scalar norm — not a per-row norm.

EXECUTION STATE
📚 np.linalg.norm(x) = Computes the vector 2-norm by default: sqrt(sum(x**2)). Also supports matrix norms and other orders via the `ord` argument.
📚 g.ravel() = Returns a flattened 1-D array with the same data (a view when the memory layout allows, otherwise a copy). For shape (2,3), ravel gives the length-6 vector [-5.20, 0.30, 8.90, 1.40, -12.50, 0.05].
→ computation = sum of squares = 5.20² + 0.30² + 8.90² + 1.40² + 12.50² + 0.05² = 27.04 + 0.09 + 79.21 + 1.96 + 156.25 + 0.0025 = 264.5525; sqrt(264.5525) = 16.2651
⬆ total_norm = 16.2651 — well above our threshold of 5.0, so clipping WILL trigger.
Line 17 (comment): early exit

If the norm is already below tau, no rescaling is needed. Returning immediately keeps gradients untouched when training is well-behaved.

Line 18: if total_norm <= tau:

Guard: in our example 16.2651 > 5.0 so the branch is FALSE and we skip the return. On a healthy training step this branch would be TRUE most of the time.

EXECUTION STATE
total_norm = 16.2651
tau = 5.0
→ 16.2651 <= 5.0 = False — proceed to the scaling branch below.
Line 19: return g, total_norm, 1.0

In the no-clip case, we return g unchanged, the measured norm, and scale = 1.0 (a no-op multiplier). For this example this line does NOT execute, but it is critical in the common case where gradients are tame.

EXECUTION STATE
⬆ return (if no clip) = (g, total_norm, 1.0) — g is NOT copied, so this is O(1).
Line 20 (comment): the rescale

Once we know ‖g‖ > tau, we compute the multiplicative factor that brings the norm down to exactly tau.

Line 21: scale = tau / (total_norm + eps)

The scale factor is τ / ‖g‖. Because τ < ‖g‖ on this branch, it always lies in (0, 1). The eps would only matter if this line were reached with a zero norm, which the early return above already prevents.

EXECUTION STATE
tau = 5.0 — our max allowed norm.
total_norm + eps = 16.2651 + 0.000001 ≈ 16.265101
→ division = 5.0 / 16.265101 = 0.307407
⬆ scale = 0.307407 — every element of g will be multiplied by this.
→ sanity check = new norm = scale · old norm = 0.307407 × 16.2651 = 5.0000 ✓ (matches tau exactly)
Line 22: return g * scale, total_norm, scale

Element-wise multiplication by the scalar `scale`. NumPy broadcasts the scalar across the (2×3) array, shrinking every component by the same factor — so the DIRECTION of g is preserved exactly.

EXECUTION STATE
g * scale = Broadcast multiplication. Every element of g is multiplied by 0.307407.
→ element-by-element = -5.20 × 0.307407 = -1.5985; 0.30 × 0.307407 = 0.0922; 8.90 × 0.307407 = 2.7359; 1.40 × 0.307407 = 0.4304; -12.50 × 0.307407 = -3.8426; 0.05 × 0.307407 = 0.0154
⬆ return[0]: clipped (2×3) =
         c0       c1       c2
 r0  -1.5985   0.0922   2.7359
 r1   0.4304  -3.8426   0.0154
⬆ return[1]: total_norm = 16.2651 — the PRE-clip norm. Useful for logging; spikes in this value are the first sign of training instability.
⬆ return[2]: scale = 0.307407 — the exact factor that was applied.
Line 24: g_val = clip_by_value(g, c=2.0)

Call the value-clip function. Each element of g is independently snapped into [−2.0, +2.0].

EXECUTION STATE
⬆ g_val (2×3) =
       c0     c1     c2
 r0  -2.00   0.30   2.00
 r1   1.40  -2.00   0.05
→ direction changed? = Yes. Raw direction was dominated by g[1,1] = -12.5 and g[0,2] = 8.9. After value-clip, every saturated element contributes equally. The vector now points in a different direction than the raw gradient.
Line 25: g_nrm, total, s = clip_by_norm(g, tau=5.0)

Call the norm-clip function and tuple-unpack the three return values. Because 16.2651 > 5.0, the scaling branch runs and every element shrinks by ≈ 0.307.

EXECUTION STATE
unpacking: g_nrm, total, s = Python tuple unpacking — assigns the three returned values to three names in one statement.
⬆ g_nrm (2×3) =
         c0       c1       c2
 r0  -1.5985   0.0922   2.7359
 r1   0.4304  -3.8426   0.0154
⬆ total = 16.2651
⬆ s = 0.3074
→ direction preserved? = Yes. g_nrm is exactly (tau/‖g‖)·g so it is parallel to g. Value-clip does NOT have this property.
Line 27: print('raw ||g|| =', np.linalg.norm(g))

Print the raw gradient's norm. This is the single most important thing to log during training — if this number spikes, everything after it is suspect.

EXECUTION STATE
⬆ printed = raw ||g|| = 16.26506991833292
Line 28: print('value-clipped:\n', g_val)

Print the value-clipped result. Notice the three saturated entries are now exactly ±2.0 while the three small ones pass through unchanged.

EXECUTION STATE
⬆ printed = value-clipped: [[-2. 0.3 2. ] [ 1.4 -2. 0.05]]
Lines 29-31: print('norm-clipped:\n', g_nrm, '\n||g\'|| =', np.linalg.norm(g_nrm), ' scale =', round(s, 4))

Print the norm-clipped result along with its post-clip norm and the scale factor. The post-clip norm should equal tau up to a tiny shortfall (about 3e-7 here) caused by the eps in the denominator.

EXECUTION STATE
⬆ printed = norm-clipped: [[-1.59851748 0.09222216 2.73592413] [ 0.43037009 -3.84259007 0.01537036]] ||g'|| = 4.999999692592794 scale = 0.3074
→ the key observation = The post-clip norm is 4.9999996 ≈ 5.0 = tau, exactly as designed. And every element shrank by the same factor so DIRECTION is preserved — the only thing that changed is length.
```python
import numpy as np

# A raw gradient tensor with a few exploding entries.
# Shape (2, 3), but it could be any parameter's gradient.
g = np.array([[-5.20,  0.30,  8.90],
              [ 1.40, -12.5,  0.05]])

# ---------- Clipping by VALUE ----------
def clip_by_value(g, c):
    # Clamp every element into the interval [-c, +c].
    return np.clip(g, -c, c)

# ---------- Clipping by NORM ----------
def clip_by_norm(g, tau, eps=1e-6):
    # Global L2 norm over ALL parameters flattened together.
    total_norm = np.linalg.norm(g.ravel())
    # If already below the threshold, leave g untouched.
    if total_norm <= tau:
        return g, total_norm, 1.0
    # Otherwise rescale so the new norm equals tau (up to eps).
    scale = tau / (total_norm + eps)
    return g * scale, total_norm, scale

g_val  = clip_by_value(g, c=2.0)
g_nrm, total, s = clip_by_norm(g, tau=5.0)

print("raw  ||g|| =", np.linalg.norm(g))
print("value-clipped:\n", g_val)
print("norm-clipped:\n", g_nrm,
      "\n||g'|| =", np.linalg.norm(g_nrm),
      " scale =", round(s, 4))
```

The script reports $\|g\| = 16.2651$, $\|g'_{\text{norm-clip}}\| = 5.0000$, and $\text{scale} = 0.3074$. The post-clip norm falls short of $\tau$ by only about $3 \times 10^{-7}$, the effect of the eps term in the denominator. Because the rescaling is a single global factor, every coordinate shrinks by the same amount, so the direction of the descent step is preserved up to floating-point rounding.


PyTorch Equivalent: clip_grad_value_ and clip_grad_norm_

The production equivalents live in torch.nn.utils. They are thin, in-place wrappers over the from-scratch code we just wrote, with two practical additions. First, they iterate over every parameter you pass in, and clip_grad_norm_ computes one global L2 norm across all of them before clipping. Second, they run on whatever device (CPU/GPU/TPU) the parameters live on.

clip_grad_norm_ — the PyTorch idiom
🐍clip_pytorch.py
Line 1: import torch

The PyTorch core. Unlike the NumPy version, PyTorch tracks every operation so we can call .backward() and get gradients automatically. Clipping is then applied to the resulting .grad tensors.

EXECUTION STATE
torch = Tensors, autograd, optimizers — everything we need to mirror the NumPy demo.
Line 2: import torch.nn as nn

Provides nn.Linear. We use a single linear layer so the gradient is a small 2×3 matrix that is easy to print.

EXECUTION STATE
nn.Linear = Learnable matrix W of shape (out_features, in_features). Used here to fabricate a concrete parameter to clip.
Line 4: torch.manual_seed(42)

Seed the RNG so the Linear layer's initial weights are the same every run. Without this, printed gradient values would change between executions and be useless for comparison.

EXECUTION STATE
📚 torch.manual_seed(s) = Seeds PyTorch's random number generators on all devices. Any subsequent call that uses randomness, like nn.Linear's default init, becomes deterministic.
⬇ arg: 42 = An arbitrary fixed integer. The value of the seed does not matter; only that it is fixed.
Line 6 (comment): what we are building

A toy model. We deliberately multiply the loss by 100 so that a single backward pass produces gradients with norm ≈ 794 — big enough to make clipping visibly interesting.

Line 7: model = nn.Linear(3, 2, bias=False)

Create one Linear layer mapping ℝ³ → ℝ². Its weight W is a (2, 3) tensor with 6 learnable parameters. This is the only parameter we will be clipping.

EXECUTION STATE
📚 nn.Linear(in_features, out_features, bias) = Creates a module with a weight tensor W of shape (out_features, in_features). Forward pass: y = x @ W.T + b (or just x @ W.T if bias=False).
⬇ arg 1: in_features = 3 = Input dimension — x is a 3-vector.
⬇ arg 2: out_features = 2 = Output dimension — y is a 2-vector.
⬇ arg 3: bias = False = Skip the bias term. Keeps the .grad tensor we print small and focused on the weight.
⬆ model.weight = Tensor shape (2, 3), init by Kaiming-uniform. This is what .grad will attach to.
Line 8: x = torch.tensor([[1.0, 2.0, 3.0]])

A single input example. Shape (1, 3). The numbers 1, 2, 3 were chosen so the resulting gradient has entries proportional to 1, 2, 3 — the math stays readable.

EXECUTION STATE
📚 torch.tensor(data) = Constructs a tensor from a Python list. Infers dtype (float32 by default for floats).
⬆ x = tensor([[1., 2., 3.]]) — shape (1, 3).
Line 9: y_true = torch.tensor([[0.0, 1.0]])

The target. Shape (1, 2). Its values are not important — only that the prediction y differs from it so we get a non-zero gradient.

EXECUTION STATE
⬆ y_true = tensor([[0., 1.]]) — shape (1, 2).
Line 11 (comment): where the big gradient comes from

The scale-up factor ×100 at the end of the loss is what produces the exploding gradient in this demo. Removing it would give a modest gradient that clipping would never trigger on.

Line 12: y = model(x)

Forward pass. Computes y = x @ W.T. Because requires_grad is on for W, PyTorch records this op in its computation graph so that backward() can later compute dL/dW.

EXECUTION STATE
model(x) = Equivalent to model.forward(x). Triggers autograd recording.
⬆ y = Some (1, 2) tensor — values depend on the random init, roughly O(1). Every dL/dW entry will be proportional to the x components.
Line 13: loss = ((y - y_true) ** 2).sum() * 100

Sum of squared errors, then multiplied by 100 to artificially inflate gradients. This is how we simulate the kind of spike that happens naturally in RNNs, transformers with badly initialized layers, or mixed-precision runs.

EXECUTION STATE
(y - y_true) ** 2 = Element-wise squared error, shape (1, 2).
.sum() = Reduces to a scalar so that .backward() (which needs a scalar) works.
* 100 = Amplifies the loss — and therefore dL/dW — by 100x. Without this, gradients would be O(1) and clipping at max_norm=1.0 would do nothing.
⬆ loss = A scalar tensor with grad_fn so PyTorch knows how to backpropagate it.
Line 14: loss.backward()

Runs autograd. PyTorch walks the computation graph in reverse, applying the chain rule, and accumulates dL/dW into model.weight.grad. After this line, every Parameter has its .grad populated.

EXECUTION STATE
📚 Tensor.backward() = Computes gradients of the tensor (which must be a scalar unless you pass a gradient arg) with respect to every tensor in the graph that has requires_grad=True.
→ post-backward state = model.weight.grad is now a tensor of the same shape as model.weight, containing dL/dW.
Line 16: print("BEFORE clipping:")

Header for the pre-clip diagnostic print. Splitting before/after makes the effect of clipping unmistakable.

Line 17: print(" grad =\n", model.weight.grad)

Show the raw gradient. Because x = [1, 2, 3] and the loss was scaled by 100, the gradient entries are very large (hundreds).

EXECUTION STATE
⬆ printed grad (2×3) = tensor([[ 198.8091, 397.6181, 596.4271], [ -74.6254, -149.2507, -223.8761]])
→ structure = Row i is proportional to (y − y_true)[i] times x. Since x = [1, 2, 3], every row has columns in the ratio 1 : 2 : 3. Row magnitudes differ because (y − y_true) differs.
Line 18: print(" ||grad|| =", model.weight.grad.norm().item())

Print the global L2 norm of the gradient. This is the scalar that clip_grad_norm_ cares about.

EXECUTION STATE
📚 Tensor.norm() = Default: L2 (Frobenius) norm over all elements — sqrt(sum(x**2)).
📚 .item() = Converts a zero-dim tensor to a Python float for printing.
⬆ printed = ||grad|| = 794.5537719726562
Line 20 (comment): clip_grad_norm_ is in-place

The trailing underscore in PyTorch means the op mutates its arguments in place. clip_grad_norm_ modifies each Parameter's .grad rather than returning a new tensor.

Lines 21-23: total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

The canonical PyTorch way to clip. It computes the GLOBAL L2 norm across every parameter in the iterable, and if it exceeds max_norm, rescales each parameter's .grad in place. Returns the pre-clip total norm.

EXECUTION STATE
📚 torch.nn.utils.clip_grad_norm_ = Implementation equivalent to: total = sqrt(sum(p.grad.norm()**2 for p in params)); scale = max_norm / (total + 1e-6); if scale < 1: each p.grad.mul_(scale). All in-place, all on-device.
⬇ arg 1: model.parameters() = An iterable over ALL learnable tensors in the model. Here that is just model.weight. For a real transformer this would be thousands of tensors — and they are all treated as ONE concatenated vector for norm computation.
⬇ arg 2: max_norm = 1.0 = The clipping threshold tau. If the global ‖grad‖ ≤ 1.0 nothing happens; otherwise every grad is scaled by 1.0/‖grad‖. Typical values in practice: 0.5, 1.0, or 5.0.
→ what the function did here = total = 794.5538; since 794.5538 > 1.0, scale = 1.0 / 794.5538 = 0.0012586; model.weight.grad *= 0.0012586 (in place)
⬆ total_norm (return value) = tensor(794.5538) — the PRE-clip norm. This is what you log to disk every step.
Line 25: print("\nAFTER clip_grad_norm_:")

Header for the post-clip print so side-by-side comparison is visually clear.

Line 26: print(" total_norm (pre-clip) =", total_norm.item())

Print the value that was returned. By the time this line runs, the gradients have already been rescaled, so the return value is the convenient record of the norm that clipping saved you from.

EXECUTION STATE
⬆ printed = total_norm (pre-clip) = 794.5537719726562
→ in practice = Plot this across steps. Healthy training: mostly < max_norm, with occasional spikes that clipping catches. Broken training: always clipped — that's a sign your learning rate is too high or init is bad.
Line 27: print(" grad =\n", model.weight.grad)

The IN-PLACE-modified gradient. Every entry has been multiplied by 1/794.5538 ≈ 0.001259, preserving direction exactly.

EXECUTION STATE
⬆ printed grad (2×3) = tensor([[ 0.2502, 0.5004, 0.7506], [-0.0939, -0.1878, -0.2818]])
→ direction check = Row 0 still in ratio 0.2502 : 0.5004 : 0.7506 = 1 : 2 : 3. Row 1 still in ratio 1 : 2 : 3. Matches the raw gradient's structure — only the magnitude changed.
Line 28: print(" ||grad|| =", model.weight.grad.norm().item())

Verify that post-clip norm equals max_norm to float precision. This is the contract that clip_grad_norm_ enforces.

EXECUTION STATE
⬆ printed = ||grad|| = 1.0
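→ contract = After clip_grad_norm_ returns, ‖grad‖ ≤ max_norm holds up to floating-point rounding, provided every gradient is finite. If any gradient contains inf or NaN, the total norm is non-finite and the clipped gradients contain NaN; pass error_if_nonfinite=True to raise an error instead.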
```python
import torch
import torch.nn as nn

torch.manual_seed(42)

# Tiny model: one Linear layer with an artificially large loss
model = nn.Linear(3, 2, bias=False)
x = torch.tensor([[1.0, 2.0, 3.0]])
y_true = torch.tensor([[0.0, 1.0]])

# Forward + backward: the gradient will be HUGE because of the x100 scale
y = model(x)
loss = ((y - y_true) ** 2).sum() * 100
loss.backward()

print("BEFORE clipping:")
print("  grad =\n", model.weight.grad)
print("  ||grad|| =", model.weight.grad.norm().item())

# Clip by global L2 norm. Modifies .grad in place.
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0
)

print("\nAFTER clip_grad_norm_:")
print("  total_norm (pre-clip) =", total_norm.item())
print("  grad =\n", model.weight.grad)
print("  ||grad|| =", model.weight.grad.norm().item())
```

The underscore at the end of clip_grad_norm_ is not cosmetic — it is the PyTorch convention for in-place operations. The function returns the pre-clip global norm as a tensor, so you can log it even after the gradients themselves have been rescaled.

clip_grad_norm_ vs clip_grad_value_. Both exist in torch.nn.utils. The norm variant is overwhelmingly preferred. Hugging Face's Trainer clips by global norm by default (max_grad_norm = 1.0), and Megatron-LM's --clip-grad option is also a global-norm threshold. Value clipping is relegated to niche RL setups.
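In a full training loop, the order of operations is what matters. You clip after backward() has filled every .grad and before optimizer.step() consumes them. The sketch below is a minimal PyTorch loop on a made-up regression problem; the model, data, learning rate and step count are placeholders. It logs the pre-clip norm that clip_grad_norm_ returns and reports the fraction of clipped steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder regression problem: model and data are illustrative only.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x = torch.randn(64, 8)
y = torch.randn(64, 1)

max_norm = 1.0
pre_clip_norms = []

for step in range(50):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()                                   # 1. gradients now live in p.grad
    total_norm = torch.nn.utils.clip_grad_norm_(      # 2. clip in place, before the step
        model.parameters(), max_norm=max_norm
    )
    pre_clip_norms.append(total_norm.item())          # 3. log the PRE-clip norm
    optimizer.step()                                  # 4. update with clipped gradients

clipped_fraction = sum(n > max_norm for n in pre_clip_norms) / len(pre_clip_norms)
print(f"clipped on {100 * clipped_fraction:.0f}% of steps")
```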

Adaptive Gradient Clipping (AGC): Per-Parameter Clipping

Global norm clipping treats every parameter in the network as one giant concatenated vector and applies a single scale factor. That is simple and has served transformers well, but it has a blind spot: it ignores how large each update is relative to the weights it modifies. A layer with small weights can receive an update that is large relative to those weights even when the global norm looks perfectly healthy, and a spike confined to one layer shrinks the updates of every other layer too. Adaptive Gradient Clipping (AGC), introduced by Brock, De, Smith and Simonyan in the 2021 NFNet paper (High-Performance Large-Scale Image Recognition Without Normalization, ICML 2021), fixes that by clipping each parameter block separately, based on the ratio of its gradient norm to its weight norm.

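Concretely, each layer's gradient is rescaled so that $g'_{\ell} = \min\left(1,\, \lambda \cdot \frac{\|W_{\ell}\|_F}{\|g_{\ell}\|_F}\right) \cdot g_{\ell}$, where $\lambda$ (typically $0.01$) is the adaptive clipping factor and $\|\cdot\|_F$ is the Frobenius norm. The guarantee, per layer, is $\|g'_{\ell}\|_F \le \lambda \cdot \|W_{\ell}\|_F$. (The NFNet paper itself applies this ratio unit-wise, per output row of each weight matrix, which it reports working better than the layer-wise form; the per-tensor version used here is the simpler variant.)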

Why this matters: AGC replaces BatchNorm's implicit scale-control for networks that drop normalization layers. The NFNet family — normalizer-free ResNets that match EfficientNet accuracy on ImageNet — only trains stably at scale because of AGC. The same idea has since appeared in V-JEPA and several normalizer-free vision backbones.

Adaptive Gradient Clipping (AGC) — NumPy reference
🐍agc.py
Line 1: import numpy as np

Pure NumPy — AGC is just Frobenius norms and elementwise scaling. No autograd needed; we are post-processing already-computed gradients, exactly like clip_grad_norm_.

EXECUTION STATE
numpy = Provides np.linalg.norm (Frobenius by default on matrices) and np.random.randn.
Line 3: def agc(grads, weights, lam=0.01, eps=1e-3) → list[np.ndarray]

Adaptive Gradient Clipping. Unlike global norm clipping which applies ONE scale factor to the concatenated gradient, AGC computes a separate scale factor for each parameter tensor based on the ratio of its gradient norm to its weight norm.

EXECUTION STATE
⬇ input: grads = A list of per-layer gradient tensors. Each entry is an ndarray with the same shape as the corresponding weight tensor.
⬇ input: weights = A list of per-layer weight tensors. Parallel to grads — weights[i] is the parameter whose gradient is grads[i].
⬇ input: lam = 0.01 = The adaptive clipping factor λ. A gradient tensor is allowed to have Frobenius norm up to λ · ‖W‖_F. 0.01 is the NFNet default; 0.16 is sometimes used for the final classifier layer.
⬇ input: eps = 1e-3 = A FLOOR on the weight norm. If a weight tensor happens to be near zero, eps prevents the threshold from collapsing to zero (which would clip every gradient to zero, killing that layer).
⬆ returns = A new list, same length as grads, where each entry has been (possibly) rescaled according to the per-layer AGC rule.
Line 4: docstring (AGC contract)

Summarises the rule: each per-layer gradient is rescaled so ‖gₗ‖_F ≤ λ · ‖Wₗ‖_F, with ‖Wₗ‖_F floored at eps. This is the core equation from the NFNet paper (Brock, De, Smith, Simonyan, ICML 2021).

Line 5: docstring, "Clips each parameter tensor's gradient"

Note the wording: PER-tensor, not per-element and not globally. This is the thing that makes AGC different from both clip_grad_value_ and clip_grad_norm_.

Line 6: docstring, "||g_l||_F <= lam * ||W_l||_F"

The guarantee. After AGC, every layer's gradient has Frobenius norm at most λ times its own weight's Frobenius norm. Layers with small weights get small allowed gradients; layers with large weights get large allowed gradients.

Line 7: docstring, "with a floor on ||W_l||"

Explains why eps exists: without a floor, a layer whose weights happen to be near zero at init would have max_g ≈ 0 and every gradient would be zeroed out, permanently stalling that layer.

Line 9: clipped = []

Start with an empty output list. We build a new list rather than mutating grads in place so the caller can compare before and after (matching the demo's print loop below).

EXECUTION STATE
clipped = [] — grows to len(grads) entries inside the loop.
Line 10: for g, w in zip(grads, weights):

Iterate over (gradient, weight) pairs in lock-step. zip stops at the shorter of the two iterables, but in this demo both lists have length 2 so we process both layers.

LOOP TRACE · 2 iterations
i=0 (first layer: 3×4 matrix)
w = W[0], shape (3, 4), seeded from randn
g = G[0], shape (3, 4), drawn from randn * 5.0 (intentionally HUGE)
w_norm = max(||W[0]||_F, 1e-3) ≈ 4.115
g_norm = ||G[0]||_F ≈ 21.592
max_g = 0.01 · 4.115 ≈ 0.04115, the adaptive threshold
g_norm > max_g? = 21.592 > 0.04115 → True, CLIP
scale = max_g / g_norm ≈ 0.04115 / 21.592 ≈ 0.001906
clipped[0] = g * scale = post-clip ‖g'‖_F = 0.04115, exactly λ · ‖W‖_F
i=1 (second layer: length-4 vector)
w = W[1], shape (4,), from randn
g = G[1], shape (4,), from randn * 0.001 (intentionally TINY)
w_norm = max(||W[1]||_F, 1e-3) ≈ 0.950
g_norm = ||G[1]||_F ≈ 0.00216
max_g = 0.01 · 0.950 ≈ 0.00950, the adaptive threshold
g_norm > max_g? = 0.00216 > 0.00950 → False, NO CLIP
clipped[1] = g = Unchanged. AGC does nothing to well-behaved layers.
Line 11: w_norm = max(np.linalg.norm(w), eps)

Compute the Frobenius norm of the weight tensor, but never let it drop below eps. np.linalg.norm on an ndarray returns the Frobenius norm by default (sqrt of sum of squared entries). max(..., 1e-3) is the FLOOR described in the docstring.

EXECUTION STATE
📚 np.linalg.norm(w) = For an ndarray returns sqrt(Σᵢ wᵢ²): the Frobenius norm for matrices, the L2 norm for vectors. One unified code path.
📚 max(a, b) = Python built-in. Returns the larger of two scalars. Here enforces w_norm ≥ eps.
→ layer 0 value = ‖W[0]‖_F ≈ 4.115 (much larger than eps=1e-3, so the floor does not trigger)
→ layer 1 value = ‖W[1]‖_F ≈ 0.950
Line 12: g_norm = np.linalg.norm(g)

Frobenius norm of THIS tensor's gradient. No floor here: a zero gradient is fine; it means the layer is done learning on this step.

EXECUTION STATE
→ layer 0 g_norm = ≈ 21.592, huge because we multiplied randn by 5.0
→ layer 1 g_norm = ≈ 0.00216, tiny because we multiplied randn by 0.001
Line 13: max_g = lam * w_norm

The ADAPTIVE threshold for this layer. Small weights → small allowed gradient; large weights → large allowed gradient. This is what makes AGC layer-aware: there is no single global τ.

EXECUTION STATE
→ layer 0 max_g = 0.01 × 4.115 ≈ 0.04115
→ layer 1 max_g = 0.01 × 0.950 ≈ 0.00950
→ why λ=0.01? = The NFNet ablations (Brock et al., Table 2) showed λ ∈ [0.01, 0.08] is the sweet spot for ResNet-like nets. Too small and every layer clips constantly, slowing training; too large and AGC stops catching spikes.
Line 14: if g_norm > max_g:

Only clip when the gradient actually exceeds the layer's adaptive threshold. In layer 0 this is TRUE (21.592 ≫ 0.04115). In layer 1 this is FALSE (0.00216 < 0.00950).

EXECUTION STATE
→ layer 0 branch = True → execute the rescale on line 15.
→ layer 1 branch = False → fall through to else on line 16, append g unchanged.
Line 15: clipped.append(g * (max_g / g_norm))

Rescale this tensor so its new Frobenius norm equals max_g exactly: the same mathematical shape as global norm clipping, but with a PER-LAYER threshold. Direction preserved, length bounded.

EXECUTION STATE
max_g / g_norm = The per-layer scale factor. For layer 0: 0.04115 / 21.592 ≈ 0.001906.
g * (max_g / g_norm) = Elementwise scaling by a scalar. Every entry of g shrinks by the same factor.
→ layer 0 result = ‖g'‖_F = g_norm × (max_g / g_norm) = max_g ≈ 0.04115. Post-clip norm equals the threshold exactly.
Line 16: else:

The gradient is already below the adaptive threshold — AGC leaves it alone. This is important: AGC is supposed to be a NO-OP on well-scaled layers.

Line 17: clipped.append(g)

Append the original gradient unchanged. No copy, no rescale.

EXECUTION STATE
→ layer 1 = clipped[1] is literally G[1]. ‖g'‖_F ≈ 0.00216 (same as pre-clip).
Line 18: return clipped

Hand back the list of per-layer rescaled gradients. In a real training loop this would be written back into each parameter's .grad in place.

EXECUTION STATE
⬆ returns = [clipped[0] (3×4, rescaled), clipped[1] (shape (4,), unchanged)]
Line 20: np.random.seed(0)

Seed the global NumPy RNG so every printed norm is reproducible.

EXECUTION STATE
📚 np.random.seed(s) = Seeds NumPy's legacy global RNG. Any subsequent np.random.randn is deterministic.
⬇ arg: 0 = Arbitrary fixed integer seed.
Line 21: comment (2-tensor network)

A stand-in for a real network. Layer 0 is a 3×4 weight matrix (12 params); layer 1 is a length-4 bias-like vector (4 params). Enough to show BOTH an aggressive clip AND a no-op on the same call.

Line 22: W = [np.random.randn(3, 4), np.random.randn(4,)]

Fabricate two random weight tensors: a 3×4 weight matrix and a length-4 bias, as in a tiny 3-in, 4-out linear layer written as x @ W (PyTorch's nn.Linear(3, 4) stores the same weight transposed, as 4×3).

EXECUTION STATE
📚 np.random.randn(*shape) = Samples from N(0, 1). randn(3, 4) returns a 3×4 ndarray; randn(4,) returns a length-4 vector.
⬆ W[0] (3×4) = Frobenius norm ≈ 4.115 (√12 ≈ 3.46 is the root-mean-square value for 12 standard-normal entries).
⬆ W[1] (4,) = Frobenius norm ≈ 0.950.
Line 23: G = [np.random.randn(3, 4) * 5.0, np.random.randn(4,) * 0.001]

Two deliberately opposite gradient tensors. The first is SCALED UP by 5×, simulating a layer with an exploding gradient. The second is SCALED DOWN by 1000×, simulating a well-behaved layer whose gradient sits below its threshold λ · ‖W[1]‖_F. AGC should hammer the first and ignore the second.

EXECUTION STATE
→ G[0] construction = randn(3,4) * 5.0, every entry amplified 5×.
→ G[1] construction = randn(4,) * 0.001, every entry shrunk 1000×.
⬆ G[0] (3×4) = ‖G[0]‖_F ≈ 21.592.
⬆ G[1] (4,) = ‖G[1]‖_F ≈ 0.00216.
Line 24: G_clipped = agc(G, W, lam=0.01)

Invoke the AGC function with the NFNet default λ. Layer 0 will be aggressively rescaled; layer 1 will pass through.

EXECUTION STATE
⬇ arg 1: G = The list of raw gradients.
⬇ arg 2: W = The list of parameter tensors, in the SAME order as G.
⬇ arg 3: lam = 0.01 = The adaptive factor, the NFNet default.
⬆ G_clipped[0] = G[0] × (0.04115 / 21.592) ≈ G[0] × 0.001906. Post-clip ‖·‖_F ≈ 0.04115.
⬆ G_clipped[1] = Exactly G[1]. No change. Post-clip ‖·‖_F ≈ 0.00216.
Line 25: for i, (g, gc, w) in enumerate(zip(G, G_clipped, W)):

Parallel iteration over the raw gradient, the clipped gradient, and the weight tensor, with a running index i for the print label. enumerate adds the index; zip fuses the three lists.

LOOP TRACE · 2 iterations
i=0
g = G[0] (3×4, ‖·‖_F≈21.592)
gc = G_clipped[0] (3×4, ‖·‖_F≈0.041)
w = W[0] (3×4, ‖·‖_F≈4.115)
i=1
g = G[1] (shape (4,), ‖·‖_F≈0.0022)
gc = G_clipped[1] (identical to g, ‖·‖_F≈0.0022)
w = W[1] (shape (4,), ‖·‖_F≈0.950)
Line 26: print(f"layer {i}: ||W||={np.linalg.norm(w):.3f}  "

First part of the f-string: the layer index and the weight Frobenius norm, formatted to 3 decimals.

EXECUTION STATE
f-string = Python formatted-string literal. {i} interpolates the loop index; {expr:.3f} formats a float to 3 decimals.
→ i=0 = layer 0: ||W||=4.115
→ i=1 = layer 1: ||W||=0.950
Line 27: f"||g||={np.linalg.norm(g):.3f} -> "

Second piece of the f-string: the pre-clip gradient norm, also to 3 decimals. The `-> ` is a visual separator to the post-clip norm that follows.

EXECUTION STATE
→ i=0 = ||g||=21.592 ->
→ i=1 = ||g||=0.002 ->
Line 28: f"||g'||={np.linalg.norm(gc):.3f}")

Third piece: the POST-clip gradient norm. This is the payoff: comparing pre and post tells you exactly what AGC did.

EXECUTION STATE
→ i=0 final print = layer 0: ||W||=4.115  ||g||=21.592 -> ||g'||=0.041 ← aggressively clipped to λ·‖W‖
→ i=1 final print = layer 1: ||W||=0.950  ||g||=0.002 -> ||g'||=0.002 ← unchanged (was already below λ·‖W‖)
→ the key insight = Layer 0 was clipped from 21.592 down to ≈ 0.041, a reduction of ~525×. Layer 1 was untouched. Global norm clipping with a single τ could never do this: either τ is small enough to tame layer 0, and then layer 1 is shrunk by the same factor, or τ is large enough to leave layer 1 alone, and then layer 0's spike passes through and can still derail training. AGC gives both layers what they need.
```python
import numpy as np

def agc(grads, weights, lam=0.01, eps=1e-3):
    """
    Adaptive Gradient Clipping (Brock et al. 2021).
    Clips each parameter tensor's gradient so that
    ||g_l||_F <= lam * ||W_l||_F, with a floor on ||W_l||.
    """
    clipped = []
    for g, w in zip(grads, weights):
        w_norm = max(np.linalg.norm(w), eps)    # floor keeps the threshold above 0
        g_norm = np.linalg.norm(g)
        max_g  = lam * w_norm                    # adaptive threshold
        if g_norm > max_g:
            clipped.append(g * (max_g / g_norm)) # rescale this tensor
        else:
            clipped.append(g)
    return clipped

np.random.seed(0)
# Pretend we have 2 parameter tensors from a tiny network.
W = [np.random.randn(3, 4), np.random.randn(4,)]
G = [np.random.randn(3, 4) * 5.0, np.random.randn(4,) * 0.001]
G_clipped = agc(G, W, lam=0.01)
for i, (g, gc, w) in enumerate(zip(G, G_clipped, W)):
    print(f"layer {i}: ||W||={np.linalg.norm(w):.3f}  "
          f"||g||={np.linalg.norm(g):.3f} -> "
          f"||g'||={np.linalg.norm(gc):.3f}")
```

AGC makes norm clipping layer-aware. Global norm clipping treats all parameters as one giant concatenated vector and applies one scale factor. AGC instead caps each parameter tensor's gradient at its own threshold $\lambda\,\|W_{\ell}\|_F$, scaling by $\lambda\,\|W_{\ell}\|_F / \|g_{\ell}\|_F$ only when that threshold is exceeded, so a fragile layer with small weights cannot be overwhelmed by a gradient coming from a layer with large weights. That per-layer fairness is what lets normalizer-free nets train at scale.
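For contrast, here is a minimal sketch of global-norm clipping (the same rule clip_grad_norm_ applies with its default L2 norm) on two made-up tensors of very different scale. Because there is only one scale factor, the well-behaved tensor is shrunk by exactly the same factor as the spiking one: the small tensor's norm drops from 0.0200 to about 0.0012 even though it was never the problem.

```python
import numpy as np


def global_norm_clip(grads, tau):
    """Global-norm clipping: a single scale factor for the whole list of tensors."""
    total = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = tau / total if total > tau else 1.0
    return [g * scale for g in grads]


big = np.full((3, 4), 5.0)     # made-up tensor with a gradient spike
small = np.full((4,), 0.01)    # made-up well-behaved tensor
clipped = global_norm_clip([big, small], tau=1.0)

for name, before, after in zip(["big", "small"], [big, small], clipped):
    print(f"{name}: {np.linalg.norm(before):.4f} -> {np.linalg.norm(after):.4f}")
```

The ratio between the two norms is unchanged by clipping, which is exactly the blind spot AGC's per-tensor threshold removes.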


Distributed Clipping: Reduce-Then-Clip vs Clip-Then-Reduce

In data-parallel distributed training (DDP), every rank computes a local gradient on its own micro-batch, then those gradients are averaged across ranks via an NCCL all-reduce before the optimizer step. Norm clipping can, in principle, be applied in two different places: before the all-reduce (each rank clips its local gradient) or after (clip the averaged gradient). These two orderings give DIFFERENT results.

The correct, canonical ordering is reduce-then-clip: all-reduce first, then clip the averaged gradient. With equal per-rank batch sizes, this is exactly what single-node training on the combined batch computes, so it preserves SGD's mathematical semantics.

Distributed clipping order — which comes first, reduce or clip?
distributed_clip_order.py
Line 1: import numpy as np

NumPy is enough — we simulate distributed training on one process. In real DDP the per-rank gradients come from separate GPUs and the `average` step is an NCCL all-reduce, but the math is identical.

EXECUTION STATE
numpy = Used for np.mean (simulated all-reduce), np.linalg.norm (L2 norm), and randn.
Line 3: comment (four fake ranks)

We pretend this one process is 4 data-parallel workers. In real training each 'rank' is a separate GPU that ran forward/backward on its own slice of the batch.

Line 4: np.random.seed(1)

Seed the RNG. Seed 1 (not 0) so the numbers differ from the AGC demo and there is no confusion between the two.

EXECUTION STATE
📚 np.random.seed(s) = Makes randn deterministic across runs.
⬇ arg: 1 = Arbitrary fixed integer.
Line 5: local_grads = [np.random.randn(2) * scale for scale in [5.0, 0.2, 0.1, 0.15]]

List comprehension that builds 4 length-2 gradient vectors, each scaled by a different factor. Rank 0 is deliberately scaled 5× — it saw an outlier batch and its local gradient is huge. The other three saw normal batches.

EXECUTION STATE
📚 list comprehension = [expr for var in iterable] builds a new list. Runs expr once per value.
→ rank 0: randn(2) * 5.0 = Approximately [8.12, -3.06] — ‖·‖₂ ≈ 8.68 — the spike.
→ rank 1: randn(2) * 0.2 = Tame, ‖·‖₂ ≈ O(0.2).
→ rank 2: randn(2) * 0.1 = Tamer still.
→ rank 3: randn(2) * 0.15 = Also small.
⬆ local_grads = A length-4 list of length-2 vectors.
Line 6: scale in [5.0, 0.2, 0.1, 0.15]

The four scale factors. Reading from left: rank 0 has a spike, ranks 1-3 are healthy. This is the common case you want clipping to survive.

Line 7: tau = 1.0

Our clipping threshold. τ = 1.0 is the value reported for GPT-3 and the LLaMA models, and a common default in LM training recipes.

EXECUTION STATE
tau = 1.0, the max allowed L2 norm on the (post-reduce) gradient.
Line 9: # -------- Option A: CLIP-then-REDUCE (WRONG) --------

Section header for the anti-pattern. Conceptually: each rank clips its LOCAL gradient, and only then do we average. We will see this loses information.

Line 10: def clip_local(g, tau) → np.ndarray

Vanilla norm clipping on a single gradient vector. Standard five-liner from the top of this section.

EXECUTION STATE
⬇ input: g = A length-2 rank-local gradient vector.
⬇ input: tau = The clipping threshold.
⬆ returns = Either g unchanged or g rescaled so ‖g‖₂ = tau.
Line 11: n = np.linalg.norm(g)

Compute ‖g‖₂ for this rank.

EXECUTION STATE
📚 np.linalg.norm(x) = L2 norm for vectors: sqrt(Σ xᵢ²).
Line 12: return g if n <= tau else g * (tau / n)

Python conditional expression: below threshold → pass through; above threshold → rescale so new norm equals tau exactly.

EXECUTION STATE
→ rank 0 path = n ≈ 8.68 > 1.0 → rescale by 1.0/8.68 ≈ 0.115. Post-clip ‖·‖ = 1.0.
→ ranks 1-3 = n < 1.0 already → pass through unchanged.
Line 14: per_rank_clipped = [clip_local(g, tau) for g in local_grads]

Apply clip_local to each rank INDEPENDENTLY. This is the wrong step: in real DDP you would be doing it inside a per-rank backward hook, which is exactly how some well-meaning code accidentally breaks training.

LOOP TRACE · 4 iterations
rank 0
g = ~[8.12, -3.06]
n = ~8.68
clip_local = rescaled to ‖·‖=1.0, roughly [0.936, -0.352]
rank 1
g = small (scale 0.2)
n = < 1.0
clip_local = unchanged
rank 2
g = smaller (scale 0.1)
n = < 1.0
clip_local = unchanged
rank 3
g = small (scale 0.15)
n = < 1.0
clip_local = unchanged
Line 15: g_A = np.mean(per_rank_clipped, axis=0)

Average ACROSS ranks. axis=0 means the reduction is over the outer (rank) axis, leaving a length-2 vector. This stands in for a DDP all-reduce followed by division by world-size.

EXECUTION STATE
📚 np.mean(a, axis=0) = Computes the mean along the given axis. For a list of 4 length-2 vectors stacked along axis 0, returns a length-2 vector.
→ computation = g_A = ¼ · ( clipped(rank0) + g(rank1) + g(rank2) + g(rank3) ) = ¼ · ( [0.936, -0.352] + [-0.106, -0.215] + [0.087, -0.230] + [0.262, -0.114] ) ≈ [0.295, -0.228]
⬆ g_A = An averaged gradient that has been INFLUENCED by clipping but also DILUTED by the three tame ranks.
Line 17: # -------- Option B: REDUCE-then-CLIP (CORRECT) --------

Section header for the canonical pattern. Average first, clip once. This is what clip_grad_norm_ does naturally after DDP's all-reduce has already happened.

Line 18: g_raw = np.mean(local_grads, axis=0)

Average the RAW per-rank gradients (no clipping yet). Standard DDP all-reduce + division by world-size.

EXECUTION STATE
→ computation = g_raw = ¼ · ( [8.12, -3.06] + [-0.106, -0.215] + [0.087, -0.230] + [0.262, -0.114] ) ≈ [2.09, -0.90]
⬆ g_raw = The ACTUAL full-batch gradient. Notice rank 0's spike has already been damped by the averaging (divided by 4).
Line 19: n_raw = np.linalg.norm(g_raw)

Compute the L2 norm of the already-reduced gradient. THIS is the scalar clip_grad_norm_ checks against tau.

EXECUTION STATE
→ n_raw = ≈ sqrt(2.09² + 0.90²) ≈ 2.28
Line 20: g_B = g_raw if n_raw <= tau else g_raw * (tau / n_raw)

Conditional: if averaging alone already brought us under τ, do nothing; otherwise rescale the averaged gradient so ‖·‖ = τ exactly.

EXECUTION STATE
→ branch = 2.28 > 1.0 → rescale by 1.0 / 2.28 ≈ 0.439.
⬆ g_B = g_raw * 0.439 ≈ [0.918, -0.397]. Post-clip ‖·‖ = 1.0.
Line 22: print("Option A (clip-then-reduce):", g_A, " ||.||=", np.linalg.norm(g_A))

Print Option A's averaged gradient and its norm. Note: Option A's norm ends up MUCH LESS than τ, because rank 0 was already clipped to 1.0 and then averaged with three small gradients, leaving ‖g_A‖ ≈ 0.37.

EXECUTION STATE
⬆ printed = Option A (clip-then-reduce): [ 0.2946... -0.2278...]  ||.||= 0.3724... (values rounded here)
Line 23: print("Option B (reduce-then-clip):", g_B, " ||.||=", np.linalg.norm(g_B))

Print Option B's result. Its norm equals τ = 1.0, exactly the contract of clip_grad_norm_.

EXECUTION STATE
⬆ printed = Option B (reduce-then-clip): [ 0.9178... -0.3970...]  ||.||= 1.0 (up to floating-point rounding)
Line 24: print("Difference:", np.linalg.norm(g_A - g_B))

The distance between the two results. A non-trivial number, about 0.65, showing that Option A and Option B are NOT the same update. Every training step would diverge between the two pipelines.

EXECUTION STATE
⬆ printed = Difference: ≈ 0.646
→ the key insight = Option A clipped rank 0's spike aggressively (~8.68 → 1.0) and then averaged it with three tame ranks (norms ≈ 0.24 to 0.29), ending up at ‖·‖ ≈ 0.37, an UNDERESTIMATE of the real step. Option B saw the already-averaged gradient (spike damped by 1/4), checked whether it still exceeds τ, and only then rescaled. Option B preserves the direction of the full-batch gradient; Option A does not, because clipping one rank before averaging changes that rank's weight relative to the others. Over thousands of steps these differences accumulate into visibly different models.
```python
import numpy as np

# Four fake "ranks" each hold a local gradient on their own micro-batch.
np.random.seed(1)
local_grads = [np.random.randn(2) * scale
               for scale in [5.0, 0.2, 0.1, 0.15]]  # rank 0 has a spike
tau = 1.0

# -------- Option A: CLIP-then-REDUCE (WRONG) --------
def clip_local(g, tau):
    n = np.linalg.norm(g)
    return g if n <= tau else g * (tau / n)

per_rank_clipped = [clip_local(g, tau) for g in local_grads]
g_A = np.mean(per_rank_clipped, axis=0)             # then average

# -------- Option B: REDUCE-then-CLIP (CORRECT) --------
g_raw = np.mean(local_grads, axis=0)                # average first
n_raw = np.linalg.norm(g_raw)
g_B   = g_raw if n_raw <= tau else g_raw * (tau / n_raw)

print("Option A (clip-then-reduce):", g_A, " ||.||=", np.linalg.norm(g_A))
print("Option B (reduce-then-clip):", g_B, " ||.||=", np.linalg.norm(g_B))
print("Difference:", np.linalg.norm(g_A - g_B))
```

PyTorch's torch.nn.utils.clip_grad_norm_ applied to a DDP model naturally does reduce-then-clip because the gradients are already all-reduced by the time your Python code runs. DeepSpeed and Megatron-LM follow the same convention. The only way to accidentally do clip-then-reduce is to manually call clip_grad_norm_ inside your per-rank backward hook — don't.
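The claim that reduce-then-clip matches single-node training can be checked on a toy problem. This sketch is a one-process simulation, not real DDP; the MSE model, sizes, and seed are arbitrary. It clips the full-batch gradient of 16 examples and compares that with 4 simulated ranks that each see 4 examples, average their local gradients (the stand-in for the all-reduce), and then clip once.

```python
import numpy as np


def mse_grad(w, X, y):
    """Mean-over-batch gradient of (X @ w - y)**2 with respect to w."""
    return (2.0 / X.shape[0]) * ((X @ w - y) @ X)


def clip(g, tau):
    n = np.linalg.norm(g)
    return g if n <= tau else g * (tau / n)


rng = np.random.default_rng(0)
X = rng.standard_normal((16, 3)) * 4.0     # larger inputs make it likely that the clip fires
y = rng.standard_normal(16)
w = np.zeros(3)
tau, world_size = 1.0, 4

# Single node: one batch of 16 examples, then clip.
g_single = clip(mse_grad(w, X, y), tau)

# Simulated DDP: each "rank" gets 4 examples; the all-reduce is a mean over ranks.
local = [mse_grad(w, Xr, yr)
         for Xr, yr in zip(np.split(X, world_size), np.split(y, world_size))]
g_ddp = clip(np.mean(local, axis=0), tau)   # reduce first, then clip once

print(np.allclose(g_single, g_ddp))
```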


The Other Problem: Tiny Batches on Tiny GPUs

The memory of a GPU during training holds, simultaneously: the parameters, the optimizer state (Adam keeps two moments per parameter, so parameters plus state come to ≈ 3× the parameter count), the activations from every layer of the forward pass (needed for backward), and the gradients themselves. For a 7B-parameter model trained in standard mixed precision with Adam (fp16 weights and gradients, plus fp32 master weights and two fp32 Adam moments, about 16 bytes per parameter), weights, gradients, and optimizer state already take ≈ 112 GB, before a single activation is stored.
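The ≈ 112 GB figure is simple per-parameter arithmetic. The sketch below assumes the common mixed-precision Adam layout (fp16 weights and gradients, plus fp32 master weights and two fp32 Adam moments, 16 bytes per parameter); pure bf16 training or 8-bit optimizer states would give smaller totals, and activations come on top.

```python
# Bytes per parameter under a common mixed-precision Adam setup (an assumption, see above).
BYTES_PER_PARAM = {
    "fp16 weights": 2,
    "fp16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam first moment": 4,
    "fp32 Adam second moment": 4,
}


def training_state_gb(n_params, bytes_per_param=BYTES_PER_PARAM):
    """Weights + gradients + optimizer state in GB (1 GB = 1e9 bytes); activations excluded."""
    return n_params * sum(bytes_per_param.values()) / 1e9


print(training_state_gb(7e9))   # → 112.0
```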

A single A100 has at most 80 GB; a high-end consumer card has 24 GB. Neither can hold that training state on its own, and even for smaller models whose state does fit, the activations of a large batch usually do not: memory that holds the activations of a few examples cannot hold a batch of 256.

But SGD needs a reasonable batch size. A batch of 1 gives gradient estimates with variance so high that optimization becomes glacial (or unstable). Batches of 256, 512, 1024 are routine in modern pipelines. How do you train with batch size 512 on a GPU that only has room for a handful of examples' activations at a time?

Gradient accumulation. Run the forward and backward pass on micro-batches that fit in memory, accumulating the gradients into a single buffer, and then — once you have summed up enough micro-batches to equal a full batch — call optimizer.step() exactly once. You used the memory of a small batch and got the gradient of a large batch.
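Choosing the number of micro-batches is plain arithmetic. Here is a small helper (the name is hypothetical) that also accounts for data parallelism, since every rank contributes its own micro-batch to each optimizer step; the micro-batch size of 8 is an assumed example of what fits in memory.

```python
def accumulation_steps(target_batch, micro_batch, world_size=1):
    """Micro-batches each GPU must accumulate per optimizer step.

    effective batch = micro_batch * accumulation_steps * world_size
    """
    per_pass = micro_batch * world_size        # examples processed per pass across all ranks
    if target_batch % per_pass != 0:
        raise ValueError(f"batch {target_batch} is not a multiple of {per_pass}")
    return target_batch // per_pass


print(accumulation_steps(512, 8))                 # one GPU that fits 8 examples → 64
print(accumulation_steps(512, 8, world_size=4))   # 4 data-parallel GPUs → 16
```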

Gradient Accumulation: The Math in One Line

Let the full batch consist of $N$ samples partitioned into $K$ equal-size micro-batches $B_1, \ldots, B_K$, each of size $m = N/K$. The loss on the full batch is the mean per-sample loss, which rewrites as $L(\theta) = \frac{1}{N}\sum_{i=1}^{N} \ell_i(\theta) = \frac{1}{K}\sum_{k=1}^{K} L_k(\theta)$, where the per-micro-batch loss is $L_k(\theta) = \frac{1}{m}\sum_{i \in B_k} \ell_i(\theta)$.

Because the gradient is linear, taking $\nabla_\theta$ of both sides gives $\nabla_\theta L = \frac{1}{K}\sum_{k=1}^{K}\nabla_\theta L_k = \sum_{k=1}^{K}\nabla_\theta\!\left(L_k / K\right)$. That second form is exactly what gradient accumulation computes. For each micro-batch $B_k$:

  1. Compute the scaled loss $L_k / K$.
  2. Call loss.backward(), which adds the gradient into the persistent .grad buffer (PyTorch accumulates into .grad by default — that is literally the trick).
  3. Do not step the optimizer yet. The parameters stay put.

After the $K$th micro-batch, .grad holds exactly $\sum_k \nabla L_k / K = \nabla L$, the full-batch gradient. Then we clip (optional), step the optimizer, and reset the buffer with zero_grad(). The parameters move once per virtual step.
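The recipe above (scaled gradient per micro-batch, then one clip, one step and one buffer reset per virtual step) can be sketched in NumPy, with plain SGD standing in for the optimizer. The linear-regression model, learning rate, τ, and seed are arbitrary choices for illustration. Because the clip sees the complete accumulated gradient, K = 1 and K = 4 should follow the same trajectory up to floating-point rounding.

```python
import numpy as np


def mse_grad(w, X, y):
    """Mean-over-batch gradient of (X @ w - y)**2 with respect to w."""
    return (2.0 / X.shape[0]) * ((X @ w - y) @ X)


def clip(g, tau):
    n = np.linalg.norm(g)
    return g if n <= tau else g * (tau / n)


def train(X, y, K, steps=5, lr=0.1, tau=1.0):
    """Plain SGD with K-way gradient accumulation and one clip per optimizer step."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        grad = np.zeros_like(w)                                   # zero_grad()
        for Xk, yk in zip(np.array_split(X, K), np.array_split(y, K)):
            grad += mse_grad(w, Xk, yk) / K                       # backward() on L_k / K
        grad = clip(grad, tau)                                    # clip the accumulated gradient
        w = w - lr * grad                                         # optimizer.step()
    return w


rng = np.random.default_rng(1)
X = rng.standard_normal((32, 3)) * 3.0
y = rng.standard_normal(32)
w_full = train(X, y, K=1)     # one batch of 32
w_accum = train(X, y, K=4)    # four micro-batches of 8
print(np.allclose(w_full, w_accum))
```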

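The division by $K$ is not optional. Without it, the accumulated gradient is $K \cdot \nabla L$: with plain SGD your effective learning rate is $K$× too large, and even with Adam, which largely cancels a constant gradient scale, the norm that gradient clipping compares against $\tau$ is $K$× too large. This is the single most common bug in home-brewed accumulation code.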
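How much a forgotten 1/K hurts depends on the optimizer. This sketch compares one plain-SGD update with the first step of Adam (bias-corrected, starting from zero moments) on a made-up gradient. SGD's step grows by exactly K; Adam divides by a running RMS of the gradient, so a constant factor cancels apart from the tiny ε. The gradient norm that a fixed clipping threshold sees still grows by K.

```python
import numpy as np


def sgd_update(g, lr=0.1):
    return -lr * g


def adam_first_update(g, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """First Adam step from zero moments, including bias correction."""
    m_hat = ((1 - beta1) * g) / (1 - beta1)
    v_hat = ((1 - beta2) * g * g) / (1 - beta2)
    return -lr * m_hat / (np.sqrt(v_hat) + eps)


K = 4
g = np.array([0.5, -2.0])    # correctly averaged accumulated gradient (made-up values)
g_bug = K * g                # what the buffer holds if the 1/K is forgotten

print(sgd_update(g_bug) / sgd_update(g))                 # SGD: a K times larger step
print(adam_first_update(g_bug) / adam_first_update(g))   # Adam: ratio ≈ 1, the scale cancels
print(np.linalg.norm(g_bug) / np.linalg.norm(g))         # what a fixed clip threshold sees: K
```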

Watching Accumulation Happen

Play the animation below. Four micro-batches produce four gradient vectors. Each is added into the buffer (no parameter update yet). Only after the last one is accumulated (and divided by $K$) does the optimizer step. The final buffer is exactly the gradient you would have gotten from running all 4 micro-batches as one big batch.

[Interactive animation: gradient accumulation flow]

Python from Scratch: Accumulating Gradients

The cleanest way to convince yourself that accumulation is exact, not an approximation, is to implement it in NumPy and verify numerically that it reproduces the full-batch gradient to machine precision.

Accumulation — NumPy reference implementation
accumulate_scratch.py
Line 1: import numpy as np

NumPy again — we want the cleanest possible demonstration that `accumulate then divide` equals `big batch`. Autograd would work too but would distract from the arithmetic.

EXECUTION STATE
numpy = Used here for np.random.randn, array slicing, and @ for matrix multiply.
Line 2: np.random.seed(0)

Seed the generator so X, Y and therefore every printed gradient are reproducible across runs.

EXECUTION STATE
📚 np.random.seed(s) = Seeds the legacy global RNG. Any subsequent np.random.* call is deterministic.
⬇ arg: 0 = Arbitrary fixed integer.
Line 4: comment (problem setup)

A batch of 8 training examples with 2 input features. We will compute the gradient of mean-squared-error on all 8 at once (the `big batch`), and then the same thing via 4 micro-batches of size 2.

Line 5: comment (micro-batching ratios)

effective batch size = micro_batch_size × accum_steps = 2 × 4 = 8. Accumulation lets us pretend we have enough memory for 8 when we only have enough for 2.

Line 6: X = np.random.randn(8, 2)

Standard-normal 8×2 input matrix. Each row is one training example with two features.

EXECUTION STATE
📚 np.random.randn(*shape) = Samples from the standard normal distribution N(0, 1). Example: np.random.randn(2) might give [-0.52, 1.13].
⬇ arg: 8 = Number of rows (training examples).
⬇ arg: 2 = Number of columns (features per example).
⬆ X (8×2) =
[[ 1.7641,  0.4002],
 [ 0.9787,  2.2409],
 [ 1.8676, -0.9773],
 [ 0.9501, -0.1514],
 [-0.1032,  0.4106],
 [ 0.1440,  1.4543],
 [ 0.7610,  0.1217],
 [ 0.4439,  0.3337]]
Line 7: Y = np.random.randn(8, 1)

Target values, one per training example. Shape (8, 1).

EXECUTION STATE
⬆ Y (8×1) = Generated from N(0,1); exact values don't matter for the demo — only that dimensions match.
Line 8: w = np.zeros((2,))

Initialize the parameter vector to zero. A fixed known starting point means both the full-batch path and the accumulation path evaluate gradients at the SAME w — a requirement for the equality we want to demonstrate.

EXECUTION STATE
📚 np.zeros(shape) = Returns an ndarray of zeros with the given shape.
⬆ w = [0., 0.] — initial parameter.
→ why start at zero? = Any fixed starting point works. Zero just makes the arithmetic very easy to verify by hand.
Line 9: ACCUM_STEPS = 4

The number of micro-batches we will accumulate into a single parameter update. 4 micro-batches of size 2 simulates one update on a batch of size 8.

EXECUTION STATE
ACCUM_STEPS = 4 = 4 micro-batches per optimizer step.
Line 11: comment (the per-sample gradient formula)

For the MSE loss L = (w·x − y)², dL/dw = 2 · (w·x − y) · x. The batch version averages this across the mini-batch, i.e. (2/|B|) · Σ (w·xᵢ − yᵢ) · xᵢ.

Line 12: def batch_grad(w, xb, yb) → np.ndarray

Compute the mean-over-batch gradient of MSE. Matches what optimizer.zero_grad(); loss = ((x@w − y)**2).mean(); loss.backward(); would give — no autograd needed because the math is one line.

EXECUTION STATE
⬇ input: w = [0., 0.] — current parameter vector.
⬇ input: xb (batch×2) = A slice of X. In Option A: the full 8 rows. In Option B: 2 rows at a time.
⬇ input: yb (batch×1) = Corresponding slice of Y.
⬆ returns = np.ndarray shape (2,) — the MEAN gradient over this batch.
Line 13: pred = xb @ w

Compute predictions xb · w for every row of xb at once. Shape (batch,) because w is (2,).

EXECUTION STATE
@ = NumPy matrix-multiply operator. For (batch × 2) @ (2,) the result is (batch,).
→ at w = [0, 0] = pred = xb @ [0, 0] = [0, 0, ...] (all zeros).
Line 14: err = pred - yb.ravel()

Residual: how wrong the prediction is. Ravel flattens yb from shape (batch, 1) to (batch,) so the subtraction broadcasts correctly.

EXECUTION STATE
yb.ravel() = Turns shape (batch,1) into shape (batch,). Required so pred − yb lines up.
→ at w = [0, 0] = err = 0 − yb.ravel() = -yb.ravel()
Line 15: comment (MSE gradient formula)

Reminder of the closed form: gradient of (1/n)·Σ (w·xᵢ − yᵢ)² w.r.t. w equals (2/n)·Σ (w·xᵢ − yᵢ)·xᵢ = (2/n)·(err · xb).

Line 16: return (2.0 / xb.shape[0]) * (err @ xb)

err @ xb is Σᵢ errᵢ · xᵢ, an (n,) · (n×2) → (2,) product. Multiplying by 2/n turns the sum into the mean gradient of the MSE.

EXECUTION STATE
xb.shape[0] = The number of rows in this batch. For Option A: 8. For Option B: 2.
err @ xb = (batch,) · (batch × 2) → (2,). The sum of err_i · x_i.
⬆ return: gradient (2,) = With w = 0 and the full batch, this comes out to ≈ [-0.7235, 0.0697].
Line 18: # ---- Option A: one big batch ----

Section header — we now take the whole dataset as ONE batch and compute the gradient in a single call. This is the gold standard we want accumulation to reproduce.

Line 19: g_full = batch_grad(w, X, Y)

Call batch_grad with xb = X (all 8 rows). This is mathematically the target: the accumulation path must produce the same vector.

EXECUTION STATE
⬆ g_full = ≈ array([-0.7235, 0.0697]), the reference gradient.
Line 21: # ---- Option B: K micro-batches, accumulated ----

Section header for the simulated low-memory path. Micro-batches of 2 are processed one at a time.

Line 22: g_accum = np.zeros_like(w)

Allocate the accumulation buffer, shape (2,), filled with zeros. This is the analogue of PyTorch's optimizer.zero_grad() — make sure no stale gradient from a previous step leaks in.

EXECUTION STATE
📚 np.zeros_like(a) = Returns an array of zeros with the same shape AND dtype as `a`. Safer than np.zeros((2,)) because it matches w's dtype automatically.
⬆ g_accum = array([0., 0.]) — empty buffer ready to accept additions.
Line 23: for step in range(ACCUM_STEPS):

Loop over the 4 micro-batches. Each iteration: compute the gradient on a slice of size 2, divide by 4, add into the buffer.

LOOP TRACE · 4 iterations
step=0
slice rows = 0:2
micro-batch grad = [-2.4348, -0.1381]
÷ 4 = [-0.6087, -0.0345]
buffer after = [-0.6087, -0.0345]
step=1
slice rows = 2:4
micro-batch grad = [ 0.2268, 0.1767]
÷ 4 = [ 0.0567, 0.0442]
buffer after = [-0.5520, 0.0096]
step=2
slice rows = 4:6
micro-batch grad = [-0.3577, 0.0977]
÷ 4 = [-0.0894, 0.0244]
buffer after = [-0.6414, 0.0341]
step=3
slice rows = 6:8
micro-batch grad = [-0.3284, 0.1425]
÷ 4 = [-0.0821, 0.0356]
buffer after = [-0.7235, 0.0697]
Line 24: xb = X[step*2:(step+1)*2]

Extract 2 rows of X for this micro-batch. When step=0: rows 0-1. When step=1: rows 2-3. And so on.

EXECUTION STATE
slicing X = step 0 → X[0:2] step 1 → X[2:4] step 2 → X[4:6] step 3 → X[6:8]
Line 25: yb = Y[step*2:(step+1)*2]

The matching target rows. Identical slicing on Y.

Line 26: # IMPORTANT: divide by ACCUM_STEPS so the SUM equals the MEAN

The single most-often-forgotten line of any accumulation implementation. Without this division the accumulated gradient is K× too large (K = ACCUM_STEPS), which for plain SGD means a K× larger effective learning rate, easily enough to make training diverge.

Line 27: g_accum += batch_grad(w, xb, yb) / ACCUM_STEPS

Compute the mean gradient of this micro-batch, shrink it by 1/K (K = ACCUM_STEPS), and add it into the running buffer. After the loop, g_accum is the mean of micro-batch means, which for equal-sized micro-batches is the mean over the whole batch.

EXECUTION STATE
batch_grad(w, xb, yb) = The mean gradient for this size-2 micro-batch.
/ ACCUM_STEPS = Divide by 4. Required because we are SUMMING four micro-batch means to simulate one big mean.
g_accum += = In-place addition into the running buffer. Equivalent to g_accum = g_accum + (...).
→ algebraic identity = g_accum = (1/K) Σ_k mean(micro_k) = mean(all samples) = g_full, provided each micro-batch has the same size.
Line 29: print("full-batch grad    :", g_full)

Print the reference gradient.

EXECUTION STATE
⬆ printed = full-batch grad    : [-0.72354  0.06968] (rounded here; NumPy prints 8 decimals)
Line 30: print("accumulated grad   :", g_accum)

Print the gradient produced by the accumulation loop. Should match g_full.

EXECUTION STATE
⬆ printed = accumulated grad   : [-0.72354  0.06968] (rounded here; NumPy prints 8 decimals)
Line 31: print("max abs difference :", np.max(np.abs(g_full - g_accum)))

Sanity check: the max element-wise difference should be zero (or machine-epsilon small). If this ever prints a non-tiny number, your division by K is missing or your micro-batches have unequal sizes.

EXECUTION STATE
⬆ printed = max abs difference : 0.0 (or a value on the order of 1e-16, depending on floating-point summation order)
→ what this proves = Accumulation with division by K is mathematically IDENTICAL to running the whole batch at once, not an approximation. Activation memory per forward/backward pass, however, is roughly K× smaller; parameter and optimizer-state memory are unchanged.
```python
import numpy as np
np.random.seed(0)

# Toy: linear regression y = w · x, 8 examples, 2 features.
# Full-batch size = 8, micro-batch size = 2, accum_steps = 4.
X = np.random.randn(8, 2)
Y = np.random.randn(8, 1)
w = np.zeros((2,))          # the parameter; both paths evaluate gradients here
ACCUM_STEPS = 4

# Per-sample gradient of (w·x − y)² w.r.t. w is 2·(w·x − y)·x
def batch_grad(w, xb, yb):
    pred = xb @ w                       # shape (batch,)
    err = pred - yb.ravel()             # shape (batch,)
    # Mean-over-batch gradient: (2 / |batch|) · Σ err · x
    return (2.0 / xb.shape[0]) * (err @ xb)

# ---- Option A: one big batch ----
g_full = batch_grad(w, X, Y)

# ---- Option B: K micro-batches, accumulated ----
g_accum = np.zeros_like(w)
for step in range(ACCUM_STEPS):
    xb = X[step*2:(step+1)*2]
    yb = Y[step*2:(step+1)*2]
    # IMPORTANT: divide by ACCUM_STEPS so the SUM equals the MEAN
    g_accum += batch_grad(w, xb, yb) / ACCUM_STEPS

print("full-batch grad    :", g_full)
print("accumulated grad   :", g_accum)
print("max abs difference :", np.max(np.abs(g_full - g_accum)))
```

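The last line prints 0.0, or at most a rounding-level residue on the order of 1e-16. In exact arithmetic the accumulated gradient equals the full-batch gradient, not approximately; in floating point the only difference is the order in which the same terms are summed. This is the guarantee that makes the trick usable in production: as long as the micro-batches have equal size and no layer couples samples within a batch (BatchNorm does), micro-batching does not change the optimization dynamics.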
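The identity relies on equal-sized micro-batches. The NumPy sketch below reuses the same toy gradient with a hypothetical split of 7 samples into micro-batches of sizes 2, 2, 2, 1. It shows that the plain 1/N rule then drifts away from the full-batch gradient, and that weighting each micro-batch mean by its share of the samples restores the equality.

```python
import numpy as np

# Same toy gradient as above: mean over whatever rows it receives.
def batch_grad(w, xb, yb):
    err = xb @ w - yb.ravel()
    return (2.0 / xb.shape[0]) * (err @ xb)

rng = np.random.default_rng(0)
X = rng.standard_normal((7, 2))   # 7 samples: no equal split into micro-batches of 2
Y = rng.standard_normal((7, 1))
w = np.zeros(2)
g_full = batch_grad(w, X, Y)

# Hypothetical uneven split: sizes 2, 2, 2, 1.
bounds = [(0, 2), (2, 4), (4, 6), (6, 7)]
n_total = X.shape[0]

# Naive rule: divide every micro-batch mean by the number of micro-batches.
g_naive = sum(batch_grad(w, X[a:b], Y[a:b]) / len(bounds) for a, b in bounds)

# Size-weighted rule: weight each micro-batch mean by (its size / total size).
g_weighted = sum(batch_grad(w, X[a:b], Y[a:b]) * (b - a) / n_total for a, b in bounds)

print("naive    max diff:", np.max(np.abs(g_full - g_naive)))     # clearly nonzero
print("weighted max diff:", np.max(np.abs(g_full - g_weighted)))  # rounding-level
```

In a PyTorch loop the same fix amounts to scaling each micro-batch loss by its size divided by the total number of samples in the virtual batch, instead of by 1/ACCUM_STEPS.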


PyTorch Idiom: loss.backward() Without Stepping

In PyTorch the trick is even cleaner because loss.backward() already accumulates into .grad. All you have to do is call it $K$ times without stepping, dividing the loss by $K$ each time. Here it is combined with clipping, in a minimal version of the recipe used in production:

Accumulation + clipping — a full one-step training loop
🐍accumulate_and_clip.py
Line 1: `import torch`

Core PyTorch — tensors and autograd.

EXECUTION STATE
torch = Used for Tensors, optim, nn.utils.clip_grad_norm_.
Line 2: `import torch.nn as nn`

Provides nn.Linear. We alias it `nn` by convention.

EXECUTION STATE
nn.Linear = The one learnable module we train in this demo.
Line 4: `torch.manual_seed(0)`

Seed the RNG so the nn.Linear initialization and the two torch.randn calls that create X and Y are reproducible.

EXECUTION STATE
📚 torch.manual_seed(s) = Seeds the random number generator on all devices (CPU and, if present, CUDA); subsequent randn/normal calls are reproducible on the same setup.
⬇ arg: 0 = Arbitrary fixed seed.
Line 6: `model = nn.Linear(2, 1, bias=False)`

A tiny linear regression model: 2 input features → 1 scalar output. Enough to illustrate accumulation + clipping.

EXECUTION STATE
📚 nn.Linear(in, out, bias) = Learnable W of shape (out, in). Forward: y = x @ W.T (+ b).
⬇ in_features = 2 = X has 2 features per row.
⬇ out_features = 1 = Scalar output.
⬇ bias = False = The only parameter is weight, so there is a single .grad of shape (1, 2), which is easy to print.
Line 7: `optim = torch.optim.SGD(model.parameters(), lr=0.1)`

Stock SGD optimizer. It will perform w ← w − lr · w.grad when we call optim.step(). Crucially, optim.step() looks at .grad — not at the individual micro-batch losses. That is exactly why accumulation works.

EXECUTION STATE
📚 torch.optim.SGD(params, lr) = Stochastic gradient descent. Each step: for p in params: p -= lr * p.grad.
⬇ arg 1: model.parameters() = Iterable over tensors that need updates.
⬇ arg 2: lr = 0.1 = Learning rate. Large enough that a single step actually moves the weights noticeably.
Line 8: `X = torch.randn(8, 2)`

An 8-example, 2-feature input matrix. We'll stream it through in 4 micro-batches of 2.

EXECUTION STATE
📚 torch.randn(*shape) = Samples from N(0, 1). Returns a float32 tensor by default.
⬆ X = Tensor shape (8, 2).
Line 9: `Y = torch.randn(8, 1)`

Matching targets, one per example.

EXECUTION STATE
⬆ Y = Tensor shape (8, 1).
Line 11: `ACCUM_STEPS = 4  # 4 micro-batches of 2 = effective batch 8`

The whole point: effective_batch = micro_batch × ACCUM_STEPS. To pretend we can fit batch=8 when we can only fit 2, we do 4 passes and sum their scaled gradients.

EXECUTION STATE
ACCUM_STEPS = 4.
Line 12: `MAX_NORM = 1.0  # gradient clipping threshold`

The tau we will pass to clip_grad_norm_. Clipping happens AFTER all micro-batches have contributed so it acts on the full-batch gradient, not on each micro-batch individually.

EXECUTION STATE
MAX_NORM = 1.0.
Line 14: `# One "virtual" optimizer step = ACCUM_STEPS micro-batch forward/backwards.`

From the optimizer's point of view we still perform exactly ONE step. It is the inside of that step that gets split into 4 micro-batch forward/backward passes.

Line 15: `optim.zero_grad()`

Reset every .grad to zero. CRITICAL: omitting this means the gradient from the previous optimizer step would be added to this step's gradient — an accidental second level of accumulation that would ruin training.

EXECUTION STATE
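📚 optim.zero_grad() = Clears the accumulator. Since PyTorch 2.0 the default (set_to_none=True) sets each p.grad to None instead of filling it with zeros; the older behavior was `for p in params: if p.grad is not None: p.grad.zero_()`. Either way, the next backward() starts from a clean slate.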
→ when to call = Once per VIRTUAL optimizer step, BEFORE the micro-batch loop. Not once per micro-batch.
Line 16: `for step in range(ACCUM_STEPS):`

Loop through the 4 micro-batches. Each iteration adds to model.weight.grad without ever updating model.weight.

LOOP TRACE · 4 iterations
step=0
slice rows = 0:2
micro-loss/ACCUM_STEPS = ~ mean_loss_micro_0 / 4
weight.grad after backward = running sum after step 0
step=1
slice rows = 2:4
weight.grad after backward = running sum after step 1
step=2
slice rows = 4:6
weight.grad after backward = running sum after step 2
step=3
slice rows = 6:8
weight.grad after backward = FINAL sum — equals full-batch gradient
Line 17: `xb = X[step*2:(step+1)*2]`

Slice 2 rows of X for this micro-batch.

EXECUTION STATE
step 0 = X[0:2]
step 1 = X[2:4]
step 2 = X[4:6]
step 3 = X[6:8]
Line 18: `yb = Y[step*2:(step+1)*2]`

Matching target slice.

Line 19: `yhat = model(xb)`

Forward pass on this micro-batch. Autograd records ops so that backward() can compute dL/dW.

EXECUTION STATE
model(xb) = Equivalent to xb @ model.weight.T. Shape (2, 1).
Line 20: `# Divide the loss by ACCUM_STEPS so that summed .grad = mean .grad.`

Comment marking the most often forgotten line of any accumulation recipe. If you omit the division, the accumulated gradient is ACCUM_STEPS times too large, which for plain SGD multiplies the effective learning rate by ACCUM_STEPS.

Line 21: `loss = ((yhat - yb) ** 2).mean() / ACCUM_STEPS`

Mean squared error on THIS micro-batch, divided by N. Because .backward() accumulates into .grad, summing over N micro-batches of (loss/N) is exactly the same as evaluating one big MSE over all samples.

EXECUTION STATE
((yhat - yb) ** 2).mean() = MSE on the current 2 samples.
/ ACCUM_STEPS = The key rescaling. Without this, the accumulated gradient is N × too big.
→ equivalence = (1/N) Σₖ mean(micro_k) = mean(all samples) — provided every micro-batch has the same size.
Line 22: `loss.backward()  # ACCUMULATES into .grad; does NOT overwrite`

Runs backprop on the (scaled) micro-batch loss. PyTorch's default behavior is that .grad is ADDED to, not replaced. That is literally the machinery gradient accumulation relies on.

EXECUTION STATE
📚 loss.backward() = Computes dloss/dparam for every Parameter in the graph and ADDS it to .grad.
→ why additive? = PyTorch's default is to accumulate, which lets you call backward() on several losses (for example, different heads) separately and have their gradients sum. Gradient accumulation reuses the same behavior to simulate a bigger batch.
→ therefore = After the loop, model.weight.grad = Σₖ dloss_k/dW = dloss_total/dW. No extra code needed.
Line 24: `# Now .grad holds the full-batch gradient. Clip, then step.`

The order matters. Clipping must happen AFTER the loop (so the full gradient has been accumulated) and BEFORE optim.step() (so the step actually uses the clipped gradient).

Line 25: `total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_NORM)`

Same clipping call as before, but now applied to the ACCUMULATED gradient. This is the correct place to clip — clipping per-micro-batch would be less faithful to the full-batch behavior.

EXECUTION STATE
📚 clip_grad_norm_ = Computes the global L2 norm over all p.grad and, only if it exceeds MAX_NORM, rescales every p.grad in place so that the global norm is at most MAX_NORM.
⬇ arg 1: model.parameters() = All learnable tensors.
⬇ arg 2: MAX_NORM = 1.0 = Threshold tau.
⬆ total_norm = Pre-clip global norm. Log this — it is the main health signal of training.
Line 26: `optim.step()`

Finally apply the update: w ← w − lr · w.grad. By this point .grad is the full-batch gradient (thanks to accumulation), possibly rescaled by clipping (thanks to the previous line). The weights change EXACTLY ONCE per virtual step.

EXECUTION STATE
📚 optim.step() = Walks params and applies the optimizer's update rule using the current .grad.
→ when to call = Once per virtual optimizer step — after the accumulation loop AND after clipping. Never inside the micro-batch loop.
Line 27: `print("pre-clip total_norm =", total_norm.item())`

Log the raw gradient norm. In a long training run you would plot this every step to detect spikes.

Line 28: `print("updated weights     =", model.weight.data)`

Show the weights after the (virtual) step. With lr=0.1 and a clipped gradient of norm at most 1.0, the weights move by at most 0.1 in L2 norm from their random init: a single update equivalent to one clipped full-batch SGD step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(2, 1, bias=False)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
X = torch.randn(8, 2)
Y = torch.randn(8, 1)

ACCUM_STEPS = 4              # 4 micro-batches of 2 = effective batch 8
MAX_NORM = 1.0               # gradient clipping threshold

# One "virtual" optimizer step = ACCUM_STEPS micro-batch forward/backwards.
optim.zero_grad()
for step in range(ACCUM_STEPS):
    xb = X[step*2:(step+1)*2]
    yb = Y[step*2:(step+1)*2]
    yhat = model(xb)
    # Divide the loss by ACCUM_STEPS so that summed .grad = mean .grad.
    loss = ((yhat - yb) ** 2).mean() / ACCUM_STEPS
    loss.backward()          # ACCUMULATES into .grad; does NOT overwrite

# Now .grad holds the full-batch gradient. Clip, then step.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_NORM)
optim.step()
print("pre-clip total_norm =", total_norm.item())
print("updated weights     =", model.weight.data)
```

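Mainstream training frameworks, including Hugging Face Trainer, PyTorch Lightning, DeepSpeed, Megatron-LM, and FairScale, implement essentially this loop, although the knob is named differently: gradient_accumulation_steps in Hugging Face Trainer and DeepSpeed, accumulate_grad_batches in PyTorch Lightning, while Megatron-LM derives the number of micro-batches from its global and micro batch sizes. Around the loop they add options such as activation checkpointing, mixed precision, and a distributed all-reduce of the gradients before the step. The core algorithm is unchanged.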
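Framework internals differ, so the following is only a framework-agnostic sketch of what such a setting has to do, written in NumPy with the toy gradient from earlier. The helper train_one_epoch and its policy of flushing a shorter final group are illustrative assumptions, not any framework's documented behavior. It also divides once per group instead of once per micro-batch, which gives the same average when micro-batches are equal-sized and handles a short final group naturally.

```python
import numpy as np

def batch_grad(w, xb, yb):
    # Mean-over-rows gradient of the squared error (w·x - y)^2.
    err = xb @ w - yb.ravel()
    return (2.0 / xb.shape[0]) * (err @ xb)

def train_one_epoch(w, micro_batches, accum_steps, lr):
    """Hypothetical helper: one SGD step per `accum_steps` equal-sized micro-batches.

    A shorter final group is averaged over the micro-batches it actually holds.
    """
    buf = np.zeros_like(w)      # plays the role of .grad
    n_in_group = 0
    n_steps = 0
    for xb, yb in micro_batches:
        buf += batch_grad(w, xb, yb)
        n_in_group += 1
        if n_in_group == accum_steps:
            w = w - lr * buf / n_in_group   # one optimizer step per group
            buf[:] = 0.0                    # the equivalent of zero_grad()
            n_in_group = 0
            n_steps += 1
    if n_in_group > 0:                      # flush a partial final group
        w = w - lr * buf / n_in_group
        n_steps += 1
    return w, n_steps

rng = np.random.default_rng(1)
X = rng.standard_normal((20, 2))
Y = rng.standard_normal((20, 1))
micro = [(X[i:i + 2], Y[i:i + 2]) for i in range(0, 20, 2)]   # 10 micro-batches of 2
w_final, steps = train_one_epoch(np.zeros(2), micro, accum_steps=4, lr=0.1)
print(steps)   # 3: two full groups of 4 and one partial group of 2
```

Each step here matches one full-batch SGD step over the samples in its group (8, 8, and then 4 samples).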

Clipping + Accumulation in One Training Loop

Both tricks compose without interference. The canonical ordering is:

  1. zero_grad() — clear the buffer once, at the start of the virtual step.
  2. For each micro-batch: forward, compute loss / accum_steps, backward() (accumulates into .grad). Do not step yet.
  3. After the loop: clip_grad_norm_. This clips the full-batch gradient, which is what we wanted.
  4. optimizer.step(). One parameter update per virtual step.

Clip after the full accumulation, not per micro-batch. Clipping each micro-batch gradient separately changes the direction of the full-batch update, because the rescaling is no longer uniform across micro-batches. Clipping once at the end makes the update identical to clipping the big-batch gradient, whichever optimizer you use.
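A two-micro-batch example makes the direction change concrete. The sketch uses plain NumPy vectors as stand-in gradients; clip_by_global_norm is a hypothetical helper that follows the same rule as torch.nn.utils.clip_grad_norm_ (scale down only when the norm exceeds the threshold), with a small epsilon to guard against division by zero.

```python
import numpy as np

def clip_by_global_norm(g, max_norm, eps=1e-6):
    # Rescale g so its L2 norm is at most max_norm; the direction is unchanged.
    norm = np.linalg.norm(g)
    return g * min(1.0, max_norm / (norm + eps))

def unit(v):
    return v / np.linalg.norm(v)

K, MAX_NORM = 2, 1.0
g1 = np.array([10.0, 0.0])   # a micro-batch with a large gradient
g2 = np.array([0.0, 1.0])    # a micro-batch with a small gradient

# Correct order: accumulate first, clip once.
g_acc = (g1 + g2) / K                                   # [5.0, 0.5]
g_clip_after = clip_by_global_norm(g_acc, MAX_NORM)

# Wrong order: clip each micro-batch, then accumulate.
g_clip_each = (clip_by_global_norm(g1, MAX_NORM)
               + clip_by_global_norm(g2, MAX_NORM)) / K  # ≈ [0.5, 0.5]

print("direction of full-batch gradient :", unit(g_acc))
print("clip after accumulation          :", unit(g_clip_after))
print("clip per micro-batch             :", unit(g_clip_each))
```

Clipping after accumulation keeps the full-batch direction (mostly along the first axis); clipping per micro-batch lets the small micro-batch pull the update to roughly 45 degrees.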

The table below compares what happens in each scheme when the same 8 samples are processed:

| Scheme | Forward passes | Peak GPU memory | .grad buffer | Optimizer steps |
|---|---|---|---|---|
| Big batch (8) | 1 (batch=8) | high | 1, filled once | 1 |
| Micro-batch (batch=2), no accumulation | 4 (batch=2) | low | 1, zeroed and refilled 4 times | 4 |
| Micro-batch (batch=2) + accumulation (K=4) | 4 (batch=2) | low | 1, accumulated over 4 passes | 1 |

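Row 3 is the one that matters: low memory per pass, yet the optimizer receives the same gradient as with batch size 8 and takes a single step. The memory saved is activation memory. Parameters, gradients, and optimizer state still have to fit, so accumulation buys a larger effective batch, not room for a larger model.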


How This Scales: Transformers, Flash Attention, and LLMs

Both techniques are load-bearing in every frontier-scale training run. The connection is worth making explicit because it explains why recent architecture papers focus so obsessively on memory.

1. Transformer training — clipping is the default, not a fallback

Large-scale transformer recipes treat global norm clipping as standard: the GPT-3 and LLaMA papers, for example, both report clipping the global gradient norm at 1.0, and τ = 1.0 is by far the most common setting. The reason is numerical fragility: softmax attention can produce gradient spikes when a few logits dominate, especially early in training. Without norm clipping, a single outlier step can push the weights far enough that later forward passes overflow to inf and NaN.

2. Flash Attention and the memory game

Flash Attention (Dao et al., 2022; v2 and v3 refinements through 2024) rewrote the attention kernel so that the full N × N attention matrix is never materialized in GPU HBM; instead it is streamed through on-chip SRAM in tiles. Attention memory therefore drops from quadratic to linear in sequence length. That saving only pays off once it is spent on longer sequences or larger micro-batches, and this is where gradient accumulation comes in: the freed memory can feed a bigger effective batch, reached one micro-batch at a time. Flash Attention and gradient accumulation are complementary levers on the same trade-off between per-step memory and per-step compute.
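The tiling idea can be sketched in NumPy for a single query vector: walk through K and V in tiles and keep only a running maximum, a running softmax denominator, and a running weighted sum, rescaling them whenever a new tile raises the maximum. This illustrates the online-softmax trick under simplifying assumptions (one query, no masking, no backward pass, no GPU memory hierarchy); it is not the Flash Attention kernel itself.

```python
import numpy as np

def attention_full(q, K, V):
    # Reference: materialize every score for this query, then softmax.
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

def attention_tiled(q, K, V, tile=4):
    # Stream K and V in tiles; keep only running statistics, never all scores.
    m = -np.inf                     # running max of the scores seen so far
    denom = 0.0                     # running sum of exp(score - m)
    acc = np.zeros(V.shape[1])      # running sum of exp(score - m) * v
    for start in range(0, K.shape[0], tile):
        s = K[start:start + tile] @ q / np.sqrt(q.shape[0])
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)   # rescale old statistics to the new max
        p = np.exp(s - m_new)
        denom = denom * scale + p.sum()
        acc = acc * scale + p @ V[start:start + tile]
        m = m_new
    return acc / denom

rng = np.random.default_rng(0)
q = rng.standard_normal(8)          # one query, head dimension 8
K = rng.standard_normal((10, 8))    # 10 keys: tiles of 4, 4, and 2
V = rng.standard_normal((10, 3))
print(np.max(np.abs(attention_full(q, K, V) - attention_tiled(q, K, V))))
```

The tiled result matches the full softmax up to rounding, yet the loop never holds more than one tile of scores at a time.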

3. Multi-head attention and LayerNorm — interacting with clipping

In multi-head attention each head has its own $Q, K, V$ projections, and LayerNorm rescales activations layer by layer. Normalization locally bounds activation magnitudes, which indirectly bounds the per-layer Jacobian norms. Norm clipping then acts as a global safety net on top of these local controls. You cannot remove clipping just because you have LayerNorm: a single bad data example or a single fp16 overflow can still produce a spike that global clipping catches.

4. Positional encodings (RoPE / ALiBi) and long-context training

Long-context models are extreme memory customers. The standard recipe is Flash Attention (to keep attention memory linear in sequence length) plus gradient accumulation (to reach a stable effective batch size despite the tiny micro-batch that fits in memory at 100k-token sequences) plus norm clipping (because positional schemes such as RoPE or ALiBi can introduce numerical anomalies at extrapolated positions). All three are typically used together.

5. KV-cache at inference time, checkpointing and accumulation at training time

The KV-cache optimization is about inference, not training: during autoregressive decoding, the Key and Value tensors for previous tokens are cached so attention need not recompute them. It spends memory to save compute. The closest training-time technique runs the other way: gradient checkpointing discards most activations on the forward pass and recomputes them during the backward pass, spending extra compute to save memory. Checkpointing composes cleanly with accumulation, because both cut peak activation memory per optimizer step; stacking them stretches the memory budget further.

6. Scaling laws and effective batch size

Chinchilla-era scaling studies (Hoffmann et al., 2022) treat effective batch size, measured in tokens per optimizer step, as a hyperparameter to be chosen deliberately rather than dictated by hardware. To hit a target tokens-per-step count on hardware that cannot process it in a single forward pass, gradient accumulation is the standard tool. The relationship is simple: $\text{effective batch} = \text{micro batch} \times \text{accum steps} \times \text{DP world size}$. If you are GPU-rich, you increase the world size; if you are GPU-poor, you increase the accumulation steps. The math does not care which.
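The formula is easy to turn around: a target effective batch, the micro-batch size that fits in memory, and the number of data-parallel replicas together fix the number of accumulation steps. A small sketch with hypothetical function names and illustrative numbers:

```python
def effective_batch(micro_batch, accum_steps, dp_world_size):
    # Sequences contributing to one optimizer step.
    return micro_batch * accum_steps * dp_world_size

def accum_steps_for(target_batch, micro_batch, dp_world_size):
    # Accumulation steps needed to reach target_batch exactly.
    per_pass = micro_batch * dp_world_size
    if target_batch % per_pass != 0:
        raise ValueError("target_batch must be a multiple of micro_batch * dp_world_size")
    return target_batch // per_pass

SEQ_LEN = 2048   # illustrative sequence length in tokens
TARGET = 1024    # illustrative target: 1024 sequences per optimizer step

gpu_rich = accum_steps_for(TARGET, micro_batch=4, dp_world_size=256)   # 1
gpu_poor = accum_steps_for(TARGET, micro_batch=4, dp_world_size=8)     # 32
tokens_per_step = effective_batch(4, gpu_poor, 8) * SEQ_LEN            # 2,097,152
print(gpu_rich, gpu_poor, tokens_per_step)
```

Both configurations produce the same 1024-sequence (about 2.1M-token) optimizer step; only the wall-clock time per step differs.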


Tradeoffs and Common Pitfalls

What clipping does and does not fix

  • Does fix: occasional gradient spikes, fp16 overflow artifacts, one-bad-batch disasters, loss explosions during the first few thousand steps of training.
  • Does not fix: a badly chosen learning rate (if clipping triggers every step, lr is too high), bad initialization (tune init, not τ), vanishing gradients (opposite problem — clipping doesn't help when gradients are already too small).

Pitfalls with accumulation

  • Forgetting to divide by K. The accumulated gradient silently becomes K× too large. With plain SGD that is a K× larger learning rate, and with any optimizer the clipping threshold now applies to a gradient K times bigger than intended. Typical symptom: the loss becomes unstable or diverges as soon as accumulation is turned on.
  • zero_grad() in the wrong place. Call it before the micro-batch loop, not inside. Inside, each call clears the gradients accumulated so far, so only the last micro-batch contributes and no accumulation happens at all.
  • BatchNorm with accumulation is subtly broken. BN computes statistics over the micro-batch, not the full batch, so its effective statistics are noisier than you think. Prefer LayerNorm/GroupNorm/RMSNorm in pipelines that use accumulation, which covers nearly all of modern transformer training.
  • Dropout and data augmentation use independent randomness per micro-batch. That is usually fine, since each sample gets its own dropout mask and augmentation in a single big batch too; what changes is the order in which random numbers are drawn, so bit-exact reproducibility requires care with seeding.
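The first two pitfalls are easy to reproduce with the toy NumPy gradient from earlier, using a plain array as a stand-in for .grad: dropping the 1/K factor yields exactly K times the full-batch gradient, and clearing the buffer inside the loop leaves only the last micro-batch's contribution.

```python
import numpy as np

def batch_grad(w, xb, yb):
    err = xb @ w - yb.ravel()
    return (2.0 / xb.shape[0]) * (err @ xb)

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 2))
Y = rng.standard_normal((8, 1))
w = np.zeros(2)
K = 4
slices = [slice(2 * k, 2 * k + 2) for k in range(K)]
g_full = batch_grad(w, X, Y)

# Pitfall 1: forgetting the 1/K factor.
g_no_div = np.zeros(2)
for s in slices:
    g_no_div += batch_grad(w, X[s], Y[s])            # no / K

# Pitfall 2: clearing the buffer (zero_grad) inside the loop.
for s in slices:
    g_zero_inside = np.zeros(2)                      # wiped every micro-batch
    g_zero_inside += batch_grad(w, X[s], Y[s]) / K

last = slices[-1]
print(np.allclose(g_no_div, K * g_full))                                 # True
print(np.allclose(g_zero_inside, batch_grad(w, X[last], Y[last]) / K))   # True
```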

Clipping + LR warmup — tune them together

Gradient clipping and learning-rate warmup are both solutions to the same underlying problem: early-training instability. A random-init network's gradients are often chaotic for the first few hundred to few thousand steps. Warmup ramps the LR from $\approx 0$ up to its target over (typically) $2000$ to $10{,}000$ steps, smoothing the transition. Clipping bounds each step's magnitude as a hard safety net.
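A linear warmup is a one-line schedule. The sketch below uses the peak LR and warmup length quoted in this section only as illustrative values; warmup_lr is a hypothetical helper, and real recipes usually follow warmup with a decay (for example cosine), which is omitted here. The clipping threshold is set separately and stays fixed throughout.

```python
def warmup_lr(step, peak_lr, warmup_steps):
    """Linear warmup from near 0 to peak_lr over warmup_steps, then constant.

    Any post-warmup decay schedule is deliberately left out of this sketch.
    """
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr

PEAK_LR, WARMUP = 6e-5, 2000   # illustrative values
for step in (0, 999, 1999, 5000):
    print(step, warmup_lr(step, PEAK_LR, WARMUP))
```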

The GPT-3 paper (Brown et al. 2020, Appendix B) uses both: a linear warmup over the first 375M tokens to a peak of $\text{LR} = 6 \times 10^{-5}$ (for the 175B model), combined with global norm clipping at $\tau = 1.0$. Similarly, LLaMA-1 (Touvron et al. 2023) uses a $2000$-step warmup plus $\tau = 1.0$. The rule: set warmup to reduce the frequency of clipping events, and leave clipping in place as insurance. As a rough rule of thumb, if clipping fires on more than $\sim 5\%$ of steps after warmup completes, something is miscalibrated: usually the peak LR is too high or the weight initialization is off.
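The ~5% heuristic is straightforward to monitor from the logged pre-clip norms (the total_norm returned by clip_grad_norm_). A minimal sketch; clip_fraction is a hypothetical helper and the norm log is synthetic:

```python
import numpy as np

def clip_fraction(pre_clip_norms, max_norm, warmup_steps):
    """Fraction of post-warmup steps whose pre-clip gradient norm exceeded max_norm."""
    norms = np.asarray(pre_clip_norms, dtype=float)[warmup_steps:]
    if norms.size == 0:
        return 0.0
    return float(np.mean(norms > max_norm))

# Synthetic log: noisy warmup, then mostly healthy steps with a few spikes.
norms = [5.0] * 100 + [0.5] * 95 + [1.5] * 5
frac = clip_fraction(norms, max_norm=1.0, warmup_steps=100)
print(f"clipped on {frac:.1%} of post-warmup steps")   # clipped on 5.0% of post-warmup steps
if frac > 0.05:
    print("clipping fires too often: check the peak LR and the initialization")
```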

When clipping hurts

If τ is set too small, every gradient gets clipped, and (for plain SGD) the optimizer effectively runs at learning rate $\eta \cdot \tau / \|g\|_2$, much smaller than you think. This makes training slow but not catastrophic; the cure is to raise τ or (more commonly) to lower the raw learning rate so clipping triggers less often. Log $\|g\|_2$ every step so you can see how often this happens.
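For plain SGD the equivalence is quick to check: a clipped step with learning rate η lands exactly where an unclipped step with learning rate η·τ/‖g‖₂ would. A NumPy sketch with made-up numbers:

```python
import numpy as np

eta, tau = 0.1, 1.0
w = np.array([1.0, -1.0])
g = np.array([3.0, 4.0])                    # ||g||_2 = 5 > tau, so clipping is active

g_clipped = g * (tau / np.linalg.norm(g))   # rescaled to norm tau, same direction
w_after_clip = w - eta * g_clipped

eta_eff = eta * tau / np.linalg.norm(g)     # 0.1 * 1.0 / 5 = 0.02
w_after_small_lr = w - eta_eff * g          # unclipped step at the smaller LR

print(eta_eff, np.allclose(w_after_clip, w_after_small_lr))
```

If every step looks like this, the nominal η is effectively unused, which is why a clipping rate near 100% points at the learning rate rather than at τ.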


Key Takeaways

  • Clipping bounds the STEP; accumulation bounds the MEMORY. Two orthogonal tricks that solve two separate problems.
  • Norm clipping preserves direction, value clipping does not. Prefer clip_grad_norm_ essentially always.
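  • Accumulation is mathematically exact. Dividing the micro-batch loss by $K$ and calling backward() $K$ times yields the full-batch gradient, exactly in real arithmetic and up to rounding in floating point, provided the micro-batches are equal-sized and no layer (such as BatchNorm) couples samples within a batch.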
  • Clip AFTER the accumulation loop, not per micro-batch. The contract is that clipping acts on the effective full-batch gradient.
  • Log the pre-clip norm every step. It is the single most informative scalar you can track to diagnose training instability.
  • Both tricks compose with Flash Attention, mixed precision, and checkpointing. Frontier-scale training runs routinely use them together. Tiny tricks, enormous leverage.

The bigger lesson. At scale, optimization becomes a memory and numerical-stability problem, not a gradient-direction problem. Clipping protects against numerical catastrophe; accumulation amortizes memory. Together they let hardware that fits only a small micro-batch train with the effective batch size, and the stability, of a much larger setup.

References. Pascanu, Mikolov, Bengio (2013), On the difficulty of training recurrent neural networks — introduced norm clipping. Ott et al. (2018), Scaling neural machine translation — popularized gradient accumulation in large-scale transformer training. Dao et al. (2022), Flash Attention: Fast and Memory-Efficient Exact Attention with IO-Awareness. PyTorch documentation: torch.nn.utils.clip_grad_norm_, torch.nn.utils.clip_grad_value_, and the accumulation pattern in the PyTorch examples repository. Hugging Face Trainer gradient_accumulation_steps configuration.
