Chapter 15
Section 61 of 121

Gradient Clipping and Weight EMA

AMNL Training Pipeline

Two Stabilisers Against Two Shocks

Training a deep network is mostly smooth. Mostly. Two kinds of shock disrupt that smoothness: occasional outlier batches with huge gradient norms, and per-step weight jitter that makes any single ‘final’ snapshot a poor representative of the trajectory's plateau. AMNL pipes both shocks through dedicated stabilisers: gradient clipping caps the first; weight EMA smooths the second.

The pair. Clip the joint gradient norm to 1.0 (paper-canonical); maintain an EMA of every parameter with decay 0.999. Both come from paper_ieee_tii/grace/training/callbacks.py lines 54-111. Together they cut FD002 RMSE by roughly 0.3-0.5 cycles relative to a no-stabiliser baseline.

Gradient Clipping: Rescale by Norm

Concatenate every per-parameter gradient into one big vector $g$ and compute its L2 norm $\|g\|_2$. If it exceeds a threshold $c$, rescale every gradient by the same factor $c / \|g\|_2$. The DIRECTION is preserved; only the magnitude shrinks.

$$g_i \leftarrow \begin{cases} g_i \cdot c / \|g\|_2 & \text{if } \|g\|_2 > c \\ g_i & \text{otherwise} \end{cases}$$
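As a quick check of the rule, take c = 1 and the two-element gradient (3, 4). Its L2 norm is 5, so both entries are scaled by 1/5 to give (0.6, 0.8), which has norm 1 and points the same way. Below is a minimal NumPy sketch for a single flattened gradient vector. The name clip_joint is ours and is for illustration only.

```python
import numpy as np


def clip_joint(g: np.ndarray, c: float) -> np.ndarray:
    """Rescale g so its L2 norm is at most c; the direction is unchanged."""
    norm = float(np.linalg.norm(g))
    return g * (c / norm) if norm > c else g


g = np.array([3.0, 4.0])        # ||g||_2 = 5
print(clip_joint(g, c=1.0))     # [0.6 0.8]: norm 1, same direction
print(clip_joint(g, c=10.0))    # [3. 4.]: below the threshold, returned untouched
```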

Why the JOINT norm and not per-parameter norms? Per-parameter clipping would break the gradient's direction: one parameter's gradient shrinks while another's doesn't, so the combined ‘arrow’ rotates. Joint-norm clipping rescales every entry by the same scalar, which preserves the direction. Geometrically, this is clipping to a ball, not a cube. Element-wise value clipping is the cube; per-parameter norm clipping is a product of smaller balls.
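To see the rotation concretely, here is a small NumPy comparison of three ways to clip the same two-tensor gradient with threshold c = 1. The gradient values are made up for illustration. The three methods are: one joint scalar, one scalar per tensor, and element-wise clamping (what clip_grad_value_ does). Only the joint version keeps a cosine similarity of 1 with the unclipped gradient.

```python
import numpy as np


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two flat vectors (1.0 means same direction)."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


c = 1.0
g_w = np.array([3.0, 4.0])    # gradient of a large weight tensor, norm 5
g_b = np.array([0.3, 0.4])    # gradient of a small bias tensor, norm 0.5
full = np.concatenate([g_w, g_b])

# Joint-norm clipping: a single scale factor for every entry.
joint = full * min(1.0, c / np.linalg.norm(full))

# Per-tensor clipping: each tensor gets its own scale factor.
per_tensor = np.concatenate(
    [g * min(1.0, c / np.linalg.norm(g)) for g in (g_w, g_b)]
)

# Element-wise value clipping: clamp every entry to [-c, c] independently.
elementwise = np.clip(full, -c, c)

for name, clipped in [("joint", joint), ("per-tensor", per_tensor), ("element-wise", elementwise)]:
    print(f"{name:>12}: cos = {cosine(full, clipped):.3f}")
```

On these values the per-tensor and element-wise versions land at cosines of about 0.934 and 0.962. The bias gradient keeps its full size while the weight gradient shrinks, so the combined vector tilts.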

Weight EMA: Inertial Average

After every optim.step(), update a SHADOW copy of every parameter:

$$\bar{\theta}_t = \beta \cdot \bar{\theta}_{t-1} + (1 - \beta) \cdot \theta_t$$

with $\beta = 0.999$. The shadow is a single-pole IIR low-pass filter on the weight trajectory; its half-life is $t_{1/2} = \log(0.5) / \log(\beta) \approx 693$ steps. At validation time, swap the shadow IN as the live weights, run forward, then swap them back.
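Unrolling the recursion makes the half-life and the start-up lag explicit. This is a standard derivation, with $\bar{\theta}_0$ the shadow's starting value (the weights at construction time):

$$\bar{\theta}_t = \beta^{t}\,\bar{\theta}_0 + (1-\beta)\sum_{k=1}^{t}\beta^{\,t-k}\,\theta_k, \qquad \beta^{t} + (1-\beta)\sum_{k=1}^{t}\beta^{\,t-k} = 1 .$$

Everything older than the last $n$ steps (including $\bar{\theta}_0$) carries total weight $\beta^{n}$, so the most recent $n$ steps carry $1-\beta^{n}$. Solving $\beta^{n} = 1/2$ gives $n = t_{1/2}$. Now add a simplifying assumption: the live weights jitter around a fixed plateau with i.i.d. noise of variance $\sigma^2$. Under that assumption, the shadow's stationary variance is

$$\operatorname{Var}(\bar{\theta}) = \sigma^2 (1-\beta)^2 \sum_{j \ge 0} \beta^{2j} = \sigma^2\,\frac{1-\beta}{1+\beta} \approx 5 \times 10^{-4}\,\sigma^2 \quad (\beta = 0.999).$$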

β         half-life (steps)   behaviour
0.9       ≈ 7                 barely smooths - shadow tracks raw closely
0.99      ≈ 69                modest smoothing - useful when training is short
0.999     ≈ 693               paper choice - smooth but lags by ~1 epoch
0.9999    ≈ 6,931             very heavy lag - shadow stays near init unless the run lasts tens of thousands of steps
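The half-life column follows directly from $t_{1/2} = \log(0.5)/\log(\beta)$. The quantity $\beta^{T}$ gives the share of the shadow still made of its initial value after $T$ updates. Here is a small Python sketch that reproduces both; the helper names are ours.

```python
import math


def half_life(beta: float) -> float:
    """Steps after which one update's weight in the EMA shadow has halved (beta**t = 0.5)."""
    return math.log(0.5) / math.log(beta)


def init_weight(beta: float, steps: int) -> float:
    """Share of the shadow still made of its initial value after `steps` updates."""
    return beta ** steps


for beta in (0.9, 0.99, 0.999, 0.9999):
    print(f"beta={beta:<7} half-life ≈ {half_life(beta):7.1f} steps   "
          f"init share after 2000 steps = {init_weight(beta, 2000):.3f}")
```

For β = 0.9999, the initial weights still make up about 82% of the shadow after 2,000 steps (0.9999^2000 ≈ e^(-0.2)); this is the lag the last row warns about. For β = 0.999 the share is about 13.5%.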

Interactive: Clip Threshold and EMA Decay

Drag the clip slider to see what fraction of mini-batches gets rescaled. Drag the decay slider to see how strongly the shadow smooths a noisy weight trajectory. Both panels update together.

Try this. Set clip = 0.5 and the histogram shows ~50% of batches clipped, so training would crawl. Set clip = 5.0 and 0% are clipped, leaving no protection against spike batches. The paper's clip = 1.0 catches just the top ~10-20% of batches (the spikes) without throttling normal training.
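The clipped fraction is simply the share of logged pre-clip norms above max_norm. Below is a minimal sketch for measuring it on your own logs, where the values come from the number clip_grad_norm_ returns at each step. The log-normal norms are synthetic stand-ins, not paper data. The percentile rule at the end is one heuristic we are assuming, not the paper's method.

```python
import numpy as np


def clip_rate(pre_clip_norms, max_norm: float) -> float:
    """Fraction of logged steps whose joint gradient norm exceeded max_norm."""
    norms = np.asarray(pre_clip_norms, dtype=float)
    return float(np.mean(norms > max_norm))


# Synthetic stand-in for a log of pre-clip norms (NOT paper data):
# log-normal, i.e. mostly moderate values with a heavy right tail of spikes.
rng = np.random.default_rng(0)
logged = rng.lognormal(mean=-0.5, sigma=0.6, size=10_000)

for c in (0.5, 1.0, 5.0):
    print(f"max_norm={c:>3}: {clip_rate(logged, c):6.1%} of steps clipped")

# A heuristic (our assumption, not the paper's rule): choose max_norm near a high
# percentile of the logged norms so only the spike batches get rescaled.
print(f"90th-percentile norm: {np.percentile(logged, 90):.2f}")
```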

Python: Both From Scratch

Pure NumPy. clip_grad_by_norm mirrors PyTorch's clip_grad_norm_; EMA mirrors the paper's ExponentialMovingAverage. The smoke test runs 2000 SGD steps on a synthetic 8×8 weight matrix and prints raw vs shadow trajectories.

clip_grad_by_norm() + class EMA
🐍clip_ema_numpy.py
1import numpy as np

NumPy provides the (P,) gradient and weight arrays plus the math we need: np.sum, np.log, np.random.seed, np.random.randn. The whole stabiliser stack runs in pure NumPy.

EXECUTION STATE
📚 numpy = Library: ndarray + linear algebra + math + random.
as np = Universal alias.
4def clip_grad_by_norm(grads, max_norm=1.0, norm_type=2.0) -> tuple[float, list[np.ndarray]]:

Reimplementation of torch.nn.utils.clip_grad_norm_(parameters, max_norm). Computes the JOINT norm of all gradient arrays (treated as one big concatenated vector) and rescales them all by the same factor if the joint norm exceeds max_norm.

EXECUTION STATE
⬇ input: grads = List of ndarrays - one per parameter group. Joint-norm convention treats them as one big vector.
⬇ input: max_norm = 1.0 = Paper default. Joint norm is rescaled to at most this value.
⬇ input: norm_type = 2.0 = Lp norm. Default is L2 (Euclidean). Same default as PyTorch's clip_grad_norm_.
⬆ returns = (pre_clip_norm, clipped_grads_list) tuple. Pre-clip norm is logged; clipped list replaces the originals.
12sq = 0.0

Accumulator for sum of squared gradient elements (across every group).

13for g in grads:

Iterate parameter groups. Each g is one ndarray of gradient values.

EXECUTION STATE
iter var: g = One ndarray of gradients. Could be (D,), (D, D'), etc - shape doesn't matter for the joint norm.
LOOP TRACE · 1 iteration
g = grad of W (8, 8)
g shape = (8, 8) - 64 elements
Σ g² = Σ over all 64 elements
running sq = += 64-element sum of squares
14sq += float(np.sum(np.abs(g) ** norm_type))

Accumulate the sum of |g_ij|^p across all groups. With norm_type=2 this is Σ g_ij². With norm_type='inf' we'd use np.max(|g|) instead.

EXECUTION STATE
operator: np.abs(g) ** norm_type = Element-wise |g|^p. The abs only matters for odd p; for p=2 this squares each element.
📚 np.sum(arr) = Reduce-sum over all elements.
📚 float(x) = 0-D ndarray → Python float so the running accumulator stays scalar.
operator: += = Augmented assignment. sq = sq + new_term.
15total_norm = float(sq ** (1.0 / norm_type))

Convert sum-of-p-th-powers to the Lp norm: (Σ |x|^p)^(1/p). For p=2 this is sqrt of the sum of squares.

EXECUTION STATE
operator: ** (1.0 / norm_type) = p-th root via fractional power. For p=2: sqrt.
→ numerical example = If sum(g²) = 4.5 then total_norm = sqrt(4.5) ≈ 2.12.
⬆ result: total_norm = Python float - the joint L2 norm across every parameter.
18coef = max_norm / (total_norm + 1e-6)

Compute the rescaling factor. The +1e-6 guards against division by zero. PyTorch's clip_grad_norm_ uses the same max_norm / (total_norm + 1e-6) coefficient.

EXECUTION STATE
→ at total_norm = 2.12, max_norm = 1.0 = coef = 1.0 / 2.12 ≈ 0.472. Each grad gets scaled by 0.472.
→ at total_norm = 0.5, max_norm = 1.0 = coef = 1.0 / 0.5 = 2.0. Above 1, so we won't clip.
⬆ result: coef = Python float. < 1 ⇒ clip; ≥ 1 ⇒ pass through.
19if coef < 1.0:

Only rescale if joint norm exceeded max_norm. Otherwise leave grads untouched.

20clipped = [g * coef for g in grads]

List comprehension. Each gradient array is multiplied by coef. The result is a NEW list of new arrays, so the originals are not modified. PyTorch's clip_grad_norm_, by contrast, rescales each .grad in place.

EXECUTION STATE
→ list comprehension = [expr for x in iter] - builds a new list.
operator: * = Scalar × ndarray broadcast. Each element of g is multiplied by coef.
⬆ result: clipped = List of rescaled gradient arrays. Joint norm is now max_norm (a hair below it, because of the 1e-6 guard).
21else:

Pass-through path.

22clipped = [g.copy() for g in grads]

Even on the pass-through path we return COPIES so the caller can't accidentally mutate originals through the returned list.

EXECUTION STATE
📚 .copy() = ndarray method. Allocates a new array with the same shape and values. NOT a view; later edits don't affect the original.
23return total_norm, clipped

Return the PRE-clip norm (for logging) and the post-clip gradients. Trainer logs the pre-clip norm so you can see how often clipping fires.

26class EMA:

Pure-NumPy port of the paper's ExponentialMovingAverage class (paper_ieee_tii/grace/training/callbacks.py:54-87). Stores a SHADOW copy of every parameter; updates that shadow toward the live weights at every step using exponential decay.

33def __init__(self, params, decay=0.999) -> None:

Build the shadow copy at construction time.

EXECUTION STATE
⬇ input: params = Dict {name: ndarray} mapping parameter name to the current weight. Same convention as torch.named_parameters() yields.
⬇ input: decay = 0.999 = Paper default. Half-life = ln(0.5) / ln(0.999) ≈ 693 steps. Larger decay = more inertia.
34self.decay = decay

Store the decay rate for use in update().

35self.shadow = {name: w.copy() for name, w in params.items()}

Dict comprehension. Make a fresh copy of every parameter array - the shadow needs to live independently so updates don't accidentally mutate the live weights.

EXECUTION STATE
📚 dict comprehension = {key_expr: value_expr for k, v in iter} - builds a dict by evaluating both exprs per pair.
📚 dict.items() = View of (key, value) pairs.
📚 .copy() = Independent copy of the ndarray.
⬆ result: self.shadow = {name: ndarray} - exact replica of params at construction time.
36self.backup: dict[str, np.ndarray] = {}

Empty backup dict. Used by apply_shadow / restore to swap live weights with shadow values temporarily during validation.

EXECUTION STATE
→ why backup? = Validation uses the SHADOW weights (smoother, generalises better). But training continues with the LIVE weights. apply_shadow swaps them in for eval; restore swaps them back.
38def update(self, params) -> None:

Call this AFTER optim.step() each iteration. Pulls the shadow toward the freshly-updated live weights.

EXECUTION STATE
⬇ input: params = Dict {name: ndarray} of CURRENT live weights.
40for name, w in params.items():

Iterate parameters.

LOOP TRACE · 1 iteration
name = 'W'
shadow_new = 0.999 · shadow_old + 0.001 · w
interpretation = 99.9% inertia from past + 0.1% from this step
41self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * w

EMA update formula. Combines old shadow with new weight in a convex blend. Equivalent to a single-pole IIR low-pass filter on the weight trajectory.

EXECUTION STATE
operator: * = Scalar × ndarray broadcasts.
operator: + = Element-wise add.
→ derivation = shadow_t = β · shadow_{t-1} + (1-β) · θ_t. Half-life: t such that β^t = 0.5 ⇒ t = ln(0.5)/ln(β).
→ β=0.999 = half-life ≈ 693 steps. The most recent 693 steps together make up about 50% of the current shadow; everything older (including the initial weights) makes up the other 50%.
→ β=0.99 = half-life ≈ 69 steps. Much less inertia.
43def apply_shadow(self, params) -> None:

Swap live weights with shadow values for validation. Backs up the live weights so we can restore them.

EXECUTION STATE
→ use case = Trainer calls model.apply_shadow() before val, evaluates with smoothed weights, then model.restore() to continue training.
45for name, w in params.items():

Iterate parameters.

46self.backup[name] = w.copy()

Save the live weight before overwriting.

47params[name] = self.shadow[name].copy()

Replace the live weight with the shadow. We .copy() so subsequent mutations to params don't corrupt the shadow.

49def restore(self, params) -> None:

Reverse of apply_shadow - put the live weights back from backup. Trainer calls this after validation finishes.

51for name in self.backup:

Iterate the backup dict.

52params[name] = self.backup[name].copy()

Restore.

56np.random.seed(0)

Repro.

EXECUTION STATE
📚 np.random.seed(s) = Set NumPy's legacy global PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
57W = np.random.randn(8, 8).astype(np.float32) * 0.1

Tiny 8×8 weight matrix initialised at small Gaussian scale.

EXECUTION STATE
📚 np.random.randn(*size) = Sample i.i.d. N(0, 1).
📚 .astype(np.float32) = Cast to float32.
operator: * 0.1 = Small init scale.
58ema = EMA({"W": W.copy()}, decay=0.999)

Build the EMA tracker. Pass a copy so subsequent updates to W don't affect the shadow's initial state.

60half_life = np.log(0.5) / np.log(0.999)

Compute the EMA half-life. With β=0.999, the weight any single step carries in the shadow halves after ~693 further steps.

EXECUTION STATE
📚 np.log(arr) = Element-wise natural logarithm. For scalars returns a scalar.
→ derivation = Want t such that β^t = 0.5. Take ln of both sides: t · ln(β) = ln(0.5) ⇒ t = ln(0.5) / ln(β).
⬆ result: half_life = ≈ 692.8 steps.
61shadow_traj = []

Lists to track shadow vs raw weight values for plotting.

62raw_traj = []

Same for raw.

63params = {"W": W.copy()}

Live parameter dict - trainer mutates this in-place.

64for step in range(2000):

Run 2000 SGD steps with synthetic gradients.

LOOP TRACE · 4 iterations
step 0
shadow = = initial W (no movement yet)
raw = = initial W
step 200
shadow = still ≈ 82% initial W (0.999^200 ≈ 0.82) - 200 steps < half-life
step 700
shadow = ≈ half initial W, half a weighted average of recent raw weights (0.999^700 ≈ 0.50) - 1 half-life elapsed
step 2000
shadow = lagging far behind raw, but smoother
65grad = np.random.randn(*W.shape).astype(np.float32)

Synthetic noisy gradient with the same shape as W.

EXECUTION STATE
→ *W.shape = Tuple unpacking. W.shape = (8, 8); *W.shape = 8, 8 ⇒ randn(8, 8).
66norm, clipped = clip_grad_by_norm([grad], max_norm=1.0)

Apply gradient clipping with the paper threshold.

EXECUTION STATE
→ tuple unpacking = Right-hand side is a 2-tuple; LHS has 2 names ⇒ each gets one element.
67params["W"] -= 1e-2 * clipped[0]

Plain SGD step. lr = 1e-2; clipped[0] is the post-clip W gradient.

EXECUTION STATE
operator: -= = In-place subtraction.
68ema.update(params)

Pull the shadow toward the freshly-updated live weight. THIS is where the EMA happens.

69if step % 200 == 0:

Log every 200 steps to keep the printed trajectory short.

EXECUTION STATE
operator: % = Python modulo. step % 200 == 0 ⇒ log on steps 0, 200, 400, …, 1800.
70raw_traj.append(float(params["W"][0, 0]))

Snapshot of the (0, 0) element of the live weight.

EXECUTION STATE
📚 list.append(x) = In-place append.
→ indexing [0, 0] = First row, first column of the 8×8 matrix. We pick one element so we can plot it.
71shadow_traj.append(float(ema.shadow["W"][0, 0]))

Same element of the shadow.

73print(f"half-life of decay=0.999 : {half_life:.0f} steps")

Format with no decimals so we get a clean integer.

EXECUTION STATE
→ :.0f = Float, 0 decimals (rounded).
Output = half-life of decay=0.999 : 693 steps
74print(f"raw θ_0 (every 200 steps) : {[round(v, 3) for v in raw_traj]}")

List comprehension inside the f-string for clean rounding.

EXECUTION STATE
📚 round(number, ndigits) = Python built-in. Returns a float rounded to ndigits decimals.
→ list comprehension = [round(v, 3) for v in raw_traj] - rounds each element.
Output = ten rounded values (steps 0, 200, …, 1800). Exact values depend on the random stream. With seed 0 the list starts at ≈ 0.176, which is 0.1 × 1.764 (the first N(0, 1) draw).
75print(f"ema θ_0 (every 200 steps) : {[round(v, 3) for v in shadow_traj]}")

Same for the shadow trajectory.

EXECUTION STATE
Output = ten rounded values. The list starts at the same ≈ 0.176 as the raw list and moves more slowly, since each update pulls it only 0.1% of the way toward the live weight.
→ reading = Raw values jitter from log to log; the shadow drifts smoothly. EMA smooths exactly the noise we don't want to evaluate on.
76print(f"raw final - shadow final : {raw_traj[-1] - shadow_traj[-1]:+.4f}")

Per-element gap at the final logged step. The shadow lags the raw by ~1 half-life worth of trajectory.

EXECUTION STATE
→ :+.4f = Float, force sign, 4 decimals.
Output = raw final - shadow final : a signed 4-decimal value; the exact number depends on the random stream.
import numpy as np


def clip_grad_by_norm(grads:    list[np.ndarray],
                      max_norm: float = 1.0,
                      norm_type: float = 2.0) -> tuple[float, list[np.ndarray]]:
    """Rescale a list of gradient arrays so their joint L2 norm is <= max_norm.

    Mirrors torch.nn.utils.clip_grad_norm_(parameters, max_norm).
    Returns (pre_clip_norm, clipped_grads_list).
    """
    # 1. Compute the joint norm across every parameter group
    sq = 0.0
    for g in grads:
        sq += float(np.sum(np.abs(g) ** norm_type))   # |g|^p, correct for odd p too
    total_norm = float(sq ** (1.0 / norm_type))

    # 2. If the joint norm exceeds max_norm, rescale every grad
    coef = max_norm / (total_norm + 1e-6)
    if coef < 1.0:
        clipped = [g * coef for g in grads]
    else:
        clipped = [g.copy() for g in grads]
    return total_norm, clipped


class EMA:
    """Per-parameter exponential moving average of model weights.

    Mirrors paper_ieee_tii/grace/training/callbacks.py::ExponentialMovingAverage.

        shadow_t = decay * shadow_{t-1} + (1 - decay) * theta_t
    """

    def __init__(self, params: dict[str, np.ndarray], decay: float = 0.999) -> None:
        self.decay  = decay
        self.shadow = {name: w.copy() for name, w in params.items()}
        self.backup: dict[str, np.ndarray] = {}

    def update(self, params: dict[str, np.ndarray]) -> None:
        """Call once per training step AFTER optim.step()."""
        for name, w in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * w

    def apply_shadow(self, params: dict[str, np.ndarray]) -> None:
        """Replace live weights with shadow values; backup the live ones."""
        for name, w in params.items():
            self.backup[name] = w.copy()
            params[name]      = self.shadow[name].copy()

    def restore(self, params: dict[str, np.ndarray]) -> None:
        """Put back the original live weights from the backup."""
        for name in self.backup:
            params[name] = self.backup[name].copy()


# ---------- Smoke test ----------
np.random.seed(0)
W   = np.random.randn(8, 8).astype(np.float32) * 0.1
ema = EMA({"W": W.copy()}, decay=0.999)

half_life = np.log(0.5) / np.log(0.999)                      # ≈ 693 steps (692.8)
shadow_traj  = []
raw_traj     = []
params       = {"W": W.copy()}
for step in range(2000):
    grad   = np.random.randn(*W.shape).astype(np.float32)    # synthetic noisy grad
    norm, clipped = clip_grad_by_norm([grad], max_norm=1.0)
    params["W"] -= 1e-2 * clipped[0]                          # plain SGD step
    ema.update(params)
    if step % 200 == 0:
        raw_traj.append(float(params["W"][0, 0]))
        shadow_traj.append(float(ema.shadow["W"][0, 0]))

print(f"half-life of decay=0.999 : {half_life:.0f} steps")
print(f"raw θ_0 (every 200 steps) : {[round(v, 3) for v in raw_traj]}")
print(f"ema θ_0 (every 200 steps) : {[round(v, 3) for v in shadow_traj]}")
print(f"raw final  - shadow final : {raw_traj[-1] - shadow_traj[-1]:+.4f}")
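The walkthrough notes that norm_type='inf' would need np.max(|g|) rather than a power sum, and clip_grad_by_norm above only covers finite p. Below is a hedged sketch of a variant that handles both; clip_grad_by_norm_p is our name, not the paper's. It uses np.abs so odd p is correct. It also multiplies by min(coef, 1) instead of branching, which gives the same result as the if/else above.

```python
import numpy as np


def clip_grad_by_norm_p(grads, max_norm=1.0, norm_type=2.0):
    """Joint-norm clipping for any p >= 1, including p = inf.

    Returns (pre_clip_norm, clipped_grads), like clip_grad_by_norm.
    """
    p = float(norm_type)                      # accepts 2, 2.0, "inf" or float("inf")
    if p == float("inf"):
        # Infinity norm: the largest |g| over every element of every group.
        total_norm = max(float(np.max(np.abs(g))) for g in grads)
    else:
        total_norm = float(sum(np.sum(np.abs(g) ** p) for g in grads) ** (1.0 / p))
    coef = max_norm / (total_norm + 1e-6)
    scale = min(coef, 1.0)                    # never scale gradients up
    return total_norm, [g * scale for g in grads]   # new arrays, originals untouched


grads = [np.array([3.0, -4.0]), np.array([[12.0, 0.0]])]
n2, clipped2 = clip_grad_by_norm_p(grads, max_norm=1.0)                       # sqrt(9+16+144)
ninf, clippedinf = clip_grad_by_norm_p(grads, max_norm=1.0, norm_type="inf")  # largest |g|
print(n2, ninf)   # 13.0 12.0
```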

PyTorch: Paper's Implementations

The exact ExponentialMovingAverage class from paper_ieee_tii/grace/training/callbacks.py lines 54-87, plus a smoke test that uses torch.nn.utils.clip_grad_norm_ after each backward and validates with the shadow weights.

ExponentialMovingAverage class + clip_grad_norm_ usage
🐍clip_ema_torch.py
1import torch

Top-level PyTorch.

EXECUTION STATE
📚 torch = Tensor library + autograd + nn modules + optim.
2import torch.nn as nn

Module containers - we use nn.Linear in the smoke test.

5class ExponentialMovingAverage:

EXACT paper class from paper_ieee_tii/grace/training/callbacks.py lines 54-87. Plain Python class (NOT an nn.Module) - the shadow lives outside the autograd graph by design.

8def __init__(self, model, decay=0.999) -> None:

Build the shadow at construction time.

EXECUTION STATE
⬇ input: model = An nn.Module - we read .named_parameters() off it.
⬇ input: decay = 0.999 = Paper default. Half-life ≈ 693 steps.
9self.decay = decay

Store decay for use in update().

10self.shadow: dict[str, torch.Tensor] = {}

Empty dict. Will hold one shadow tensor per parameter, keyed by name.

11self.backup: dict[str, torch.Tensor] = {}

Empty backup dict for the apply_shadow / restore swap.

12for name, param in model.named_parameters():

Iterate every parameter with its qualified name (e.g. 'weight', 'cnn.stack.0.bias').

EXECUTION STATE
📚 .named_parameters() = Iterator yielding (full_qualified_name, parameter) for every param in the module tree.
iter vars = name (str), param (nn.Parameter).
13if param.requires_grad:

Only shadow LEARNABLE parameters. Frozen layers (e.g. pretrained backbones with requires_grad=False) get skipped - their values don't change so a shadow would be redundant.

14self.shadow[name] = param.data.clone()

.data accesses the underlying Tensor without autograd tracking. .clone() copies storage.

EXECUTION STATE
📚 .data = Direct access to the parameter's tensor value, bypassing autograd. Useful for non-gradient bookkeeping like EMA.
📚 .clone() = Returns a tensor with its own storage and the same values.
→ why .data and .clone()? = .data avoids accidentally creating an autograd graph; .clone() ensures the shadow lives independently of the live weight.
16def update(self, model) -> None:

Call AFTER optimizer.step() each iteration. Pulls the shadow toward freshly-updated live weights.

EXECUTION STATE
⬇ input: model = The same nn.Module passed at __init__. Its weights have been updated by optim.step() since the last call.
17for name, param in model.named_parameters():

Iterate.

18if param.requires_grad and name in self.shadow:

Defensive check: skip params that are frozen now (even if they were trainable at init) or that never got into the shadow dict.

EXECUTION STATE
→ why both checks? = name in self.shadow handles the case where a parameter was added AFTER EMA construction (e.g. dynamic head registration). param.requires_grad handles the case where a parameter got frozen mid-training.
19self.shadow[name] = self.shadow[name].to(param.device)

Move the shadow to the same device as the live param. Handles the case where the model gets moved to a different GPU between calls.

EXECUTION STATE
📚 .to(device) = Tensor method. Move to the given device. No-op if already there.
→ why? = If the user calls model.to('cuda') after constructing the EMA, the shadow stays on CPU. This line catches that case.
20self.shadow[name] = ( self.decay * self.shadow[name] + (1.0 - self.decay) * param.data )

EMA update. Same formula as the NumPy version.

EXECUTION STATE
operator: * / + = Tensor arithmetic, no autograd graph since param.data is used.
→ no learnable params = self.decay is a Python float. self.shadow[name] is a tensor but with no grad. The shadow stays out of the autograd graph by design.
24def apply_shadow(self, model) -> None:

Swap live weights with shadow values for validation.

25for name, param in model.named_parameters():

Iterate.

26if param.requires_grad and name in self.shadow:

Same defensive check as update.

27self.backup[name] = param.data.clone()

Save the live weight before overwriting.

28param.data.copy_(self.shadow[name].to(param.device))

IN-PLACE copy of the shadow values into the live parameter; the trailing underscore in copy_ marks it in-place. Writing in place keeps the same Parameter object (and storage) registered with the module and the optimiser, which keys its state on those Parameter objects. It also never aliases the shadow tensor into the model.

EXECUTION STATE
📚 .copy_(src) = In-place copy from src into self. Matches PyTorch convention: trailing underscore = in-place.
→ why in-place? = Registering a new Parameter would leave the optimiser updating the old one. Assigning param.data = shadow would make later training steps mutate the shadow itself. copy_ avoids both by overwriting the values in the existing tensor.
30def restore(self, model) -> None:

Reverse of apply_shadow.

31for name, param in model.named_parameters():

Iterate.

32if param.requires_grad and name in self.backup:

Defensive check - only restore parameters we actually backed up.

33param.data.copy_(self.backup[name])

In-place copy from backup.

37torch.manual_seed(0)

Repro.

EXECUTION STATE
📚 torch.manual_seed(s) = Set the global PyTorch PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
38model = nn.Linear(64, 1)

Tiny stand-in.

EXECUTION STATE
📚 nn.Linear(in, out) = One fully-connected layer.
39optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

§15.2 paper-canonical AdamW.

40ema = ExponentialMovingAverage(model, decay=0.999)

Wire the EMA. Shadow snapshot taken now.

42x = torch.randn(32, 64)

Synthetic input batch.

EXECUTION STATE
📚 torch.randn(*size) = Sample i.i.d. N(0, 1).
43target = torch.randn(32, 1)

Synthetic regression target.

45for step in range(5):

Five training steps - just enough to see clipping and EMA update.

LOOP TRACE · 5 iterations
step 0
expected pre_clip_norm = ≈ 1.5-3.0 - clipped to 1.0
step 1
loss = drops slightly
step 2
shadow drift = ~0.1% movement per step toward live
step 3
step 4
shadow = still very close to init - 5 steps << 693-step half-life
46optimizer.zero_grad()

Reset .grad before each backward.

47pred = model(x)

Forward.

48loss = ((pred - target) ** 2).mean()

Plain MSE for the smoke test - real training would use the AMNL stack from §14.4.

EXECUTION STATE
operator: ** 2 = Element-wise square.
📚 .mean() = Reduce-mean to a 0-D scalar.
49loss.backward()

Reverse-mode autograd. Populates model parameter .grad fields.

52pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Compute joint L2 norm of all gradients and rescale them in-place if it exceeds max_norm. Returns the PRE-clip total norm for logging.

EXECUTION STATE
📚 torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None) = Computes the total p-norm of `parameters[*].grad`; if it exceeds max_norm, scales every grad in place by max_norm / total_norm. Trailing underscore = in-place.
⬇ arg 1: parameters = model.parameters() = Iterator over the params whose grads we want to clip jointly.
⬇ arg 2: max_norm = 1.0 = Paper default. After clipping, joint L2 norm ≤ 1.0.
→ in-place = Modifies param.grad tensors directly. The optimizer that runs next sees the clipped gradients.
⬆ result: pre_clip = 0-D tensor - the PRE-clip joint norm (useful for logging).
54optimizer.step()

Apply the AdamW update with the (now clipped) gradients.

57ema.update(model)

Pull the shadow toward the freshly-updated live weights. ORDER MATTERS: must be AFTER optimizer.step() - otherwise the shadow tracks the pre-step weight.

59print(f"step={step} pre_clip_norm={pre_clip.item():.4f} loss={loss.item():.4f}")

Per-step log.

EXECUTION STATE
📚 .item() = 0-D tensor → Python float.
→ :.4f = Float, 4 decimals.
Output = five lines of the form step=0 pre_clip_norm=… loss=…; exact values depend on the seed stream and PyTorch version.
→ reading = Any step whose pre_clip_norm exceeds 1.0 had its gradients rescaled to norm 1.0 before optimizer.step(); steps at or below 1.0 pass through unmodified. With lr=1e-3, AdamW moves each weight by only about 1e-3 per step, so the loss changes only slightly over five steps.
62ema.apply_shadow(model)

Swap live weights with shadow values BEFORE evaluation.

63val_pred = model(x)

Forward with shadow weights.

64val_loss = ((val_pred - target) ** 2).mean().item()

Evaluate.

65ema.restore(model)

Put live weights back so training continues correctly.

66print(f"val_loss with shadow weights : {val_loss:.4f}")

Log.

EXECUTION STATE
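Output = val_loss with shadow weights : a 4-decimal value, close to the step-0 loss printed above.
→ reading = After 5 steps the shadow has barely moved (decay=0.999, half-life=693; 0.999^5 ≈ 0.995, so it is still ≈ 99.5% the initial weights). Val loss is therefore close to the step-0 loss (the loss at the initial weights), not the step-4 loss. The smoothing benefit shows after hundreds of steps.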
import torch
import torch.nn as nn

# Source: paper_ieee_tii/grace/training/callbacks.py:54-87
class ExponentialMovingAverage:
    """Maintains an EMA of model parameters for evaluation."""

    def __init__(self, model: nn.Module, decay: float = 0.999) -> None:
        self.decay = decay
        self.shadow: dict[str, torch.Tensor] = {}
        self.backup: dict[str, torch.Tensor] = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, model: nn.Module) -> None:
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.shadow[name] = self.shadow[name].to(param.device)
                self.shadow[name] = (
                    self.decay * self.shadow[name] + (1.0 - self.decay) * param.data
                )

    def apply_shadow(self, model: nn.Module) -> None:
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name].to(param.device))

    def restore(self, model: nn.Module) -> None:
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])


# ---------- Smoke test: clip + EMA in a training step ----------
torch.manual_seed(0)
model     = nn.Linear(64, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
ema       = ExponentialMovingAverage(model, decay=0.999)

x      = torch.randn(32, 64)
target = torch.randn(32, 1)

for step in range(5):
    optimizer.zero_grad()
    pred = model(x)
    loss = ((pred - target) ** 2).mean()
    loss.backward()

    # 1. Clip joint gradient norm to 1.0 (paper default)
    pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

    # 2. Update EMA shadow AFTER optimizer.step()
    ema.update(model)

    print(f"step={step}  pre_clip_norm={pre_clip.item():.4f}  loss={loss.item():.4f}")

# At validation time: swap in the shadow weights
ema.apply_shadow(model)
val_pred = model(x)
val_loss = ((val_pred - target) ** 2).mean().item()
ema.restore(model)
print(f"val_loss with shadow weights : {val_loss:.4f}")
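The smoothing benefit only appears after hundreds of steps. A toy NumPy experiment makes that concrete under idealised assumptions, which are ours and not the paper's. A single weight sits on a plateau at 0, and the live value jitters around it with i.i.d. N(0, 1) noise. The RMS error of the final raw snapshot stays near 1. The EMA shadow's error should approach sqrt((1 - β) / (1 + β)), about 0.022 for β = 0.999.

```python
import math

import numpy as np


def snapshot_rmse(beta, steps=10_000, sigma=1.0, trials=400, seed=0):
    """RMS error (vs. the plateau value 0) of the final live weight and of its EMA shadow,
    over independent trials. Toy model: live weight = plateau + i.i.d. Gaussian noise."""
    rng = np.random.default_rng(seed)
    live = rng.normal(0.0, sigma, size=trials)       # live weights at step 0
    shadow = live.copy()                              # shadow starts as a copy, like EMA.__init__
    for _ in range(1, steps):
        live = rng.normal(0.0, sigma, size=trials)   # jitter around the plateau
        shadow = beta * shadow + (1.0 - beta) * live  # same update rule as EMA.update
    raw_rmse = float(np.sqrt(np.mean(live ** 2)))
    ema_rmse = float(np.sqrt(np.mean(shadow ** 2)))
    return raw_rmse, ema_rmse


for beta in (0.99, 0.999):
    raw, ema = snapshot_rmse(beta)
    theory = math.sqrt((1.0 - beta) / (1.0 + beta))
    print(f"beta={beta}: raw snapshot RMSE={raw:.3f}  EMA RMSE={ema:.4f}  theory={theory:.4f}")
```

Real weight trajectories also drift while they jitter, and that drift is where the lag from the half-life table bites. The toy model isolates only the noise-averaging half of the trade-off.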

Where Else These Two Show Up

Both stabilisers transfer to nearly every deep-learning domain. The hyperparameters change with training length and gradient regime; the underlying recipe stays.

Domain                           max_norm     EMA decay   Notes
RUL prediction (this book)       1.0          0.999       paper default
Transformer language modelling   1.0          0.9999      long runs ⇒ heavier EMA
Vision Transformer (ImageNet)    1.0          0.9998      Polyak averaging - same idea
GAN generator training           5.0 - 10.0   0.999       rare clipping; EMA on G is critical
Reinforcement learning (PPO)     0.5          none        no EMA - policy distribution changes too fast
Diffusion model training         1.0          0.9999      EMA at inference is THE main eval trick
Diffusion models live or die on EMA. Many diffusion training runs report two metrics: with and without EMA. The EMA version often improves FID by 30-50%. It's not a small effect.

Three Stabiliser Pitfalls

Pitfall 1: Calling EMA.update() BEFORE optimizer.step(). The shadow then tracks the PRE-step weight, which is one step stale. Always order: backward → clip → step → EMA update.
Pitfall 2: Element-wise or per-parameter clipping instead of joint norm. torch.nn.utils.clip_grad_value_ exists but clamps each ELEMENT independently, which BREAKS the gradient direction. Clipping each parameter by its own norm rotates it too. Always use clip_grad_norm_ over all parameters at once (norm_type=2, the default) for direction-preserving rescaling.
Pitfall 3: Forgetting to restore() after validation. If you call apply_shadow for val and forget to restore, training continues with the SHADOW weights instead of the live ones. The next optimizer.step() applies updates to weights that lag the trainer's internal state. The result is a subtle bug, plausible-looking loss curves, and irreproducible runs.
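One way to make a forgotten restore() impossible is to wrap the swap in a context manager whose finally clause always restores. This is our suggestion, not part of the paper's callbacks.py. The NumPy sketch below uses shadow_weights, a hypothetical helper name. The EMA class is a compact copy of the NumPy one in this section, so the block runs on its own. A PyTorch version would wrap ema.apply_shadow(model) and ema.restore(model) the same way.

```python
import contextlib

import numpy as np


class EMA:
    """Compact copy of this section's NumPy EMA (dict of arrays)."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.copy() for k, v in params.items()}
        self.backup = {}

    def update(self, params):
        for k, v in params.items():
            self.shadow[k] = self.decay * self.shadow[k] + (1.0 - self.decay) * v

    def apply_shadow(self, params):
        for k, v in params.items():
            self.backup[k] = v.copy()
            params[k] = self.shadow[k].copy()

    def restore(self, params):
        for k in self.backup:
            params[k] = self.backup[k].copy()


@contextlib.contextmanager
def shadow_weights(ema, params):
    """Swap the shadow in for the duration of the block; always swap back, even on error."""
    ema.apply_shadow(params)
    try:
        yield params
    finally:
        ema.restore(params)


params = {"W": np.ones((2, 2))}
ema = EMA(params, decay=0.5)
params["W"] = params["W"] + 1.0        # stand-in for optimizer.step(): live W becomes 2.0
ema.update(params)                     # shadow = 0.5 * 1.0 + 0.5 * 2.0 = 1.5

with shadow_weights(ema, params):
    print(params["W"][0, 0])           # 1.5  (evaluating with shadow weights)
print(params["W"][0, 0])               # 2.0  (live weights are back)
```

Because restore() lives in finally, an exception raised mid-validation still leaves training on the live weights.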
The point. Two cheap tricks: clip joint gradient norm to 1.0, maintain an EMA of every parameter with decay 0.999. Both are paper-canonical and add about 2 lines to the training loop. §15.4 turns to per-dataset dropout - the last AMNL pipeline knob before the full training-script walkthrough in §15.5.

Takeaway

  • clip_grad_norm_(params, max_norm=1.0). Joint L2 norm; preserves direction; runs after backward, before step.
  • ExponentialMovingAverage(model, decay=0.999). Shadow per-parameter; updates after step; half-life ~693 steps.
  • apply_shadow → eval → restore. Validation uses shadow weights but training continues with live ones.
  • Order matters. backward → clip → step → EMA.update.
  • Cheap. Two extra lines per training step. Cuts FD002 RMSE by ~0.3-0.5 cycles.