Chapter 18

Exponential Moving Average (β = 0.99)

The GABA Algorithm

GPS On A Bumpy Road

Your phone's GPS receives a noisy position estimate every second. If the navigation app showed every raw sample, the blue dot would jitter wildly — half a metre forward, a metre sideways, back. Useless for driving. Instead, the app applies a smoother: each new sample contributes a small fraction to the displayed position; most of what you see is the running average of the recent past. The dot still tracks — it just stops twitching.

GABA has the same problem. The closed form $\lambda^*_i = g_j / (g_i + g_j)$ is exact, but the inputs $g_i$ are measured ON ONE MINI-BATCH each step. Mini-batch gradients are noisy estimates of the true expected gradient, and their L2 norms inherit that noise. Plug noisy $g_i$ into the closed form and you get a noisy $\lambda^*_i$ that can fluctuate by up to 5× from step to step on a 500×-imbalanced problem. Use those raw weights and the optimiser veers every batch.

The fix. Apply an exponential moving average, $\hat{\lambda}^{(t)}_i = \beta \, \hat{\lambda}^{(t-1)}_i + (1 - \beta) \, \lambda^{(t)}_i$, with the paper-canonical $\beta = 0.99$. This is eq. 5 of the paper (main.tex). In control-theory terms it is a first-order IIR low-pass filter: it smooths out high-frequency noise while still tracking slow trends.

Why The Raw Per-Batch Lambda Cannot Be Used

Three sources of noise enter $\lambda^*_i$ at every step:

  • Sampling noise. The mini-batch is a random subsample of the dataset. Its gradient is an unbiased estimate of the full-data gradient, but with variance that scales as 1/batch_size.
  • Per-condition variance (multi-condition data). On C-MAPSS FD002 / FD004, six different operating conditions in the same batch produce gradients of different magnitudes that average together. Batch composition matters.
  • Loss-curvature interactions. Near sharp local minima, small parameter changes produce big gradient changes. Two consecutive batches can land on very different points of the loss surface.
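The sampling-noise point can be checked with a toy simulation. This is a sketch under a simplifying assumption: each example's gradient is a single scalar drawn i.i.d. from a normal distribution (mu, sigma and the batch sizes are arbitrary illustrations, not C-MAPSS measurements). The mini-batch gradient is the mean over B examples, so its variance should fall as 1/B, which means its standard deviation falls as 1/sqrt(B).

```python
import numpy as np

# Toy assumption: per-example gradients are scalars drawn i.i.d. from N(mu, sigma^2).
# mu, sigma and the batch sizes are arbitrary illustrative values.
rng = np.random.default_rng(0)
mu, sigma, n_batches = 1.0, 2.0, 20_000


def minibatch_grad_std(batch_size: int) -> float:
    """Empirical std of the mini-batch mean gradient across many independent batches."""
    per_example = rng.normal(mu, sigma, size=(n_batches, batch_size))
    return float(per_example.mean(axis=1).std())


results = {b: minibatch_grad_std(b) for b in (16, 64, 256)}
for b, emp in results.items():
    print(f"B={b:4d}  empirical std = {emp:.4f}   sigma/sqrt(B) = {sigma / np.sqrt(b):.4f}")
```

Quadrupling the batch size halves the noise in the gradient estimate, but it never removes it, which is why smoothing across steps is still needed.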

On a typical FD002 training step the realised $\lambda^*_{\text{rul}}$ has a standard deviation of order 0.001 around a mean of order 0.002, a per-step coefficient of variation of about 50%. Without smoothing, the optimiser would receive different effective loss weights every step, defeating the convergence guarantee that comes from running gradient descent on a stable objective.
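Where does a coefficient of variation near 50% come from? A rough sketch, assuming the same synthetic gradient-norm statistics used in the from-scratch demo below (RUL norm with mean 5.0 and std 1.0, health norm with mean 0.01 and std 0.005; these are illustrative values, not FD002 measurements). Because the health norm is tiny, $\lambda^*_{\text{rul}} \approx g_{\text{health}} / g_{\text{rul}}$, and to first order the relative fluctuations of numerator and denominator add in quadrature: $\sqrt{0.5^2 + 0.2^2} \approx 0.54$.

```python
import numpy as np

# Illustrative synthetic statistics (same as the from-scratch demo, not FD002 data).
rng = np.random.default_rng(1)
n = 20_000
g_rul = np.maximum(rng.normal(5.0, 1.0, n), 0.01)
g_health = np.maximum(rng.normal(0.01, 0.005, n), 1e-6)
lam = g_health / (g_rul + g_health)          # K = 2 closed form for the RUL weight

cv_mc = float(lam.std() / lam.mean())
# First-order estimate: lam ~ g_health / g_rul, so relative fluctuations add in quadrature.
cv_delta = float(np.sqrt((0.005 / 0.01) ** 2 + (1.0 / 5.0) ** 2))
print(f"Monte Carlo CV = {cv_mc:.2f}   first-order estimate = {cv_delta:.2f}")
```

The health norm's own 50% relative noise dominates; the ratio simply passes it through to the weight.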

The EMA Update (Paper Eq. 5)

Paper main.tex:347 specifies:

$$\hat{\lambda}^{(t)}_i = \beta \, \hat{\lambda}^{(t-1)}_i + (1 - \beta) \, \lambda^{(t)}_i, \quad \beta = 0.99$$

Because $\beta + (1 - \beta) = 1$, each update is a convex combination of the previous EMA value and the new raw weights. If both lie on the probability simplex, so does the output. Starting from the uniform initial value $\hat{\lambda}^{(0)}_i = 1/K$ in the paper's algorithm, induction gives $\hat{\lambda}_i \in [0, 1]$ and $\sum_i \hat{\lambda}_i = 1$ for all $t$.
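A quick numerical check of the simplex argument, as a sketch. The per-step closed-form weights are replaced by random points on the simplex (Dirichlet draws with K = 3). This is an illustrative assumption, not real GABA output.

```python
import numpy as np

# Sketch: K = 3 tasks, raw per-step weights drawn from a Dirichlet distribution
# (an illustrative stand-in for the closed-form lambda, not real GABA output).
rng = np.random.default_rng(0)
K, beta, steps = 3, 0.99, 1_000

ema = np.full(K, 1.0 / K)                    # uniform start, lambda_hat^(0) = 1/K
for _ in range(steps):
    raw = rng.dirichlet(np.ones(K))          # non-negative, sums to 1
    ema = beta * ema + (1.0 - beta) * raw    # paper eq. 5, vectorised over tasks
    # A convex combination of two simplex points is again on the simplex.
    assert np.all(ema >= 0.0) and abs(ema.sum() - 1.0) < 1e-9

print("smoothed weights after", steps, "steps:", np.round(ema, 4), "sum:", ema.sum())
```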

EMA As A First-Order IIR Low-Pass Filter

Rewrite the update as a discrete-time linear system:

$$y[n] = \beta \, y[n - 1] + (1 - \beta) \, x[n]$$

Take the Z-transform: the transfer function is $H(z) = (1 - \beta) / (1 - \beta z^{-1})$. This is a textbook first-order infinite-impulse-response (IIR) low-pass filter with a single pole at $z = \beta$. The magnitude of the frequency response is:

$$|H(e^{j\omega})| = \frac{1 - \beta}{\sqrt{1 - 2 \beta \cos\omega + \beta^2}}$$

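At $\omega = 0$ (DC, the long-run mean), $|H| = 1$: sustained inputs pass through unattenuated, so the EMA tracks the true mean. At higher frequencies (per-batch noise), $|H|$ falls off. For $\beta$ close to 1, the 3 dB cutoff is approximately:

$$f_c \approx -\ln \beta / (2\pi) \approx 0.0016 \text{ cycles/step}$$

i.e. a period of about 625 steps. Per-batch noise is roughly white, spread across all frequencies up to the Nyquist limit (period 2 steps, $\omega = \pi$). There the gain bottoms out at $(1 - \beta)/(1 + \beta) \approx 0.005$, about 46 dB of attenuation. Components with periods shorter than about 20 steps are attenuated by roughly 30 dB or more. Slow trends survive: a sinusoid with a 1,000-step period keeps about 85% of its amplitude, and slower trends pass almost unchanged. One way to read the paper's choice of $\beta = 0.99$ is that it places this cutoff below the slowest meaningful training-time variation.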
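The gains quoted above can be reproduced by evaluating the magnitude formula directly. The following is a short NumPy sketch. The cutoff uses the $-\ln\beta$ approximation from the text, and the listed periods are arbitrary probe points.

```python
import numpy as np

beta = 0.99


def gain(omega: float) -> float:
    """|H(e^{j*omega})| of the EMA filter y[n] = beta*y[n-1] + (1-beta)*x[n]."""
    return float((1 - beta) / np.sqrt(1 - 2 * beta * np.cos(omega) + beta ** 2))


probes = {
    "DC": 0.0,
    "cutoff (period ~625)": -np.log(beta),   # approximate 3 dB point for beta near 1
    "period 1000": 2 * np.pi / 1000,
    "period 20": 2 * np.pi / 20,
    "Nyquist (period 2)": np.pi,
}
gains = {name: gain(w) for name, w in probes.items()}
for name, g in gains.items():
    print(f"{name:22s} |H| = {g:.4f}  ({20 * np.log10(g):6.1f} dB)")
```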

Time Constant, Half-Life, And Settling

Three equivalent ways to think about how fast the EMA responds:

| Quantity | Formula | Value at β = 0.99 | Interpretation |
|---|---|---|---|
| Time constant τ | 1 / (1 − β) | 100 steps | A step input is absorbed to about (1 − 1/e) ≈ 63.2% after τ steps |
| Half-life | ln 2 / (− ln β) | 68.97 steps | A step input is absorbed to 50% (≈ τ · ln 2 for β near 1) |
| 95% settling (3τ rule) | 3 · τ | 300 steps | Engineering rule of thumb for a first-order LTI system |
| 3 dB cutoff | −ln β / (2π) (β near 1) | 0.0016 cycles/step | Period ≈ 625 steps; noise faster than this is suppressed |
| Effective averaging window | τ ≈ 1 / (1 − β) | ~100 steps | Loosely: each EMA value averages the last ~τ raw measurements |

The paper (main.tex:387) writes ‘the EMA with $\beta = 0.99$ serves as a first-order IIR low-pass filter (time constant ~100 steps) that smooths stochastic gradient noise, preventing oscillation’. That is the same statement in plain English.
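The first three rows of the table can be checked by iterating the recursion on a unit step, where the input jumps from 0 to 1 at step 1. For that input the closed form is $y[n] = 1 - \beta^n$. A minimal sketch:

```python
import math


def step_response(n_steps: int, beta: float) -> float:
    """Iterate y = beta*y + (1-beta)*1 from y = 0; returns the fraction of the unit step absorbed."""
    y = 0.0
    for _ in range(n_steps):
        y = beta * y + (1.0 - beta) * 1.0
    return y


beta = 0.99
tau = 1.0 / (1.0 - beta)
half_life = math.log(2) / -math.log(beta)
checkpoints = {"half-life": round(half_life), "tau": round(tau), "3*tau": round(3 * tau)}
absorbed = {}
for label, n in checkpoints.items():
    absorbed[label] = step_response(n, beta)
    print(f"{label:9s} = {n:3d} steps -> {absorbed[label]:.3f} absorbed "
          f"(closed form {1 - beta ** n:.3f})")
```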

Variance Reduction Theorem

For an EMA driven by i.i.d. noise, the output variance is reduced by a closed-form factor. Let $y[n] = \beta y[n-1] + (1-\beta) x[n]$ with $x[n]$ i.i.d., mean $\mu$ and variance $\sigma^2$ (for example $x[n] \sim \mathcal{N}(\mu, \sigma^2)$). Because $y[n-1]$ depends only on past inputs, it is independent of $x[n]$, so at steady state $\mathrm{Var}(y) = \beta^2 \mathrm{Var}(y) + (1-\beta)^2 \sigma^2$. Solving:

$$\mathrm{Var}(y) = \frac{1 - \beta}{1 + \beta} \, \mathrm{Var}(x), \qquad \frac{\mathrm{std}(x)}{\mathrm{std}(y)} = \sqrt{\frac{1 + \beta}{1 - \beta}}$$

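For $\beta = 0.99$ the theoretical std reduction is $\sqrt{199} \approx 14.1\times$. On the 1,000-step synthetic sequence in the Python demo below, the measured reduction is $19.4\times$. The synthetic noise there is i.i.d. by construction, so the gap is a finite-sample effect rather than a property of the noise. The EMA output is so strongly autocorrelated (time constant 100 steps) that 1,000 steps contain only a few effectively independent samples. The demo also warm-starts the EMA at the sample mean. Together, these make the empirical std of the smoothed sequence come out biased low, so the ratio overshoots the steady-state value. The takeaway: every 10× multiplier on $(1+\beta)/(1-\beta)$ buys you a $\sqrt{10} \approx 3.2\times$ std reduction.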
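The theorem and the short-run overshoot can both be seen in simulation. This sketch uses unit-variance Gaussian i.i.d. noise, which is an assumption; any i.i.d. noise with finite variance gives the same steady-state ratio. It compares one long run against many short runs that mimic the 1,000-step, warm-started setup of the from-scratch demo.

```python
import math
import numpy as np


def ema(values: np.ndarray, beta: float, init: float) -> np.ndarray:
    """y[n] = beta*y[n-1] + (1-beta)*x[n], with y[-1] = init."""
    out = np.empty_like(values)
    y = init
    for i, x in enumerate(values):
        y = beta * y + (1.0 - beta) * x
        out[i] = y
    return out


beta = 0.99
theory = math.sqrt((1 + beta) / (1 - beta))          # sqrt(199), about 14.1
rng = np.random.default_rng(0)

# One long i.i.d. run: sample statistics are reliable, so the ratio sits near theory.
x_long = rng.normal(0.0, 1.0, 200_000)
y_long = ema(x_long, beta, init=0.0)[1_000:]          # drop the start-up transient
long_ratio = x_long.std() / y_long.std()

# Many short runs shaped like the demo: 1,000 steps, warm-started at the sample mean.
# The EMA output is strongly autocorrelated, so its 1,000-sample std is biased low.
short_ratios = []
for _ in range(200):
    x = rng.normal(0.0, 1.0, 1_000)
    y = ema(x, beta, init=float(x.mean()))
    short_ratios.append(x.std() / y.std())

print(f"theory                          = {theory:.1f}x")
print(f"200,000-step run                = {long_ratio:.1f}x")
print(f"1,000-step runs (median of 200) = {np.median(short_ratios):.1f}x")
```

The long run lands close to $\sqrt{199}$, while the short runs overshoot it, consistent with the finite-sample explanation above.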

Interactive: Beta Sweep And Step Response

Drag β from 0 (no smoothing) to 0.999 (heavy smoothing). The top panel shows a 600-step run with per-batch noise; the bottom panel shows the response to a single step input from 0.5 to 0.998. Increasing β TIGHTENS the trace at the cost of SLOWER tracking.

Try this. Set $\beta = 0$: the blue trace coincides with the grey raw trace (no memory ⇒ no smoothing). Set $\beta = 0.999$: the trace is rock-steady, but the bottom panel shows the EMA has not even reached 50% after 600 steps (τ = 1000). The paper's $\beta = 0.99$ sits in the sweet spot: fast enough to settle within 300 steps, smooth enough to remove per-batch jitter.
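The bottom-panel numbers can also be reproduced without the visualiser. Iterate the recursion from 0.5 toward the constant input 0.998 for 600 steps and measure how much of the gap has closed. This is a sketch, and the list of β values is an illustrative choice.

```python
def step_response(beta: float, start: float, target: float, steps: int) -> float:
    """EMA value after `steps` updates with constant input `target`, starting from `start`."""
    y = start
    for _ in range(steps):
        y = beta * y + (1.0 - beta) * target
    return y


start, target, horizon = 0.5, 0.998, 600
gap_closed = {}
for beta in (0.0, 0.9, 0.99, 0.999):
    y = step_response(beta, start, target, horizon)
    gap_closed[beta] = (y - start) / (target - start)
    print(f"beta={beta:<6} tau={1.0 / (1.0 - beta):7.1f} steps  "
          f"gap closed after {horizon} steps: {gap_closed[beta]:6.1%}")
```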

Python: EMA From Scratch

Implement ema_step and ema_run in pure NumPy, generate a 1,000-step synthetic $\lambda_{\text{rul}}$ sequence with a realistic 500× imbalance, and check the AR(1) variance-reduction theorem numerically.

EMA + AR(1) noise-reduction theorem
ema_from_scratch.py
docstring

Module docstring: from-scratch EMA + numerical confirmation of the AR(1) noise-reduction theorem. The point is that the paper's β=0.99 is not arbitrary — it has a closed-form noise-reduction guarantee.

import math

Standard library. We need math.log and math.sqrt for the analytic time-constant formulas.

EXECUTION STATE
📚 math = Python standard library for elementary scalar math: log, sqrt, pi, etc.
import numpy as np

NumPy supplies the ndarray, np.random for synthetic gradients, and np.maximum for the positivity guard.

EXECUTION STATE
📚 numpy = Numerical computing library. Used for ndarray, np.random.randn, np.maximum, np.empty_like.
np.random.seed(0)

Fix the PRNG so the synthetic gradient sequence is reproducible.

EXECUTION STATE
📚 np.random.seed(s) = Sets the global NumPy PRNG. Affects np.random.randn going forward.
def ema_step(prev, current, beta) → float

Paper equation 5 in one line: ema = β · prev + (1 − β) · current. The fundamental atom of the EMA stabiliser.

EXECUTION STATE
⬇ input: prev = Previous EMA value (last step's output).
⬇ input: current = Current raw measurement (this step's λ from the closed form).
⬇ input: beta = Smoothing coefficient. Larger β = more weight on history (slower response, less noise).
⬆ returns = Float — the new EMA value.
docstring

Records this as paper equation 5 (line 347 of main.tex).

return beta * prev + (1.0 - beta) * current

Convex combination: weight β goes to history, weight (1−β) goes to the new measurement. β + (1−β) = 1, so the output stays on the same scale as the input.

EXECUTION STATE
beta * prev = Inertia term. With β=0.99: 99% of the new value comes from history.
(1.0 - beta) * current = Innovation term. With β=0.99: only 1% of the new value comes from the noisy measurement.
→ AR(1) form = This is exactly an autoregressive-of-order-1 filter: y[n] = β·y[n−1] + (1−β)·x[n]. A foundational object in time-series, signal processing, and finance.
def ema_run(values, beta, init=None)

Apply ema_step to every element of a sequence. Start from init (default: first value).

EXECUTION STATE
⬇ input: values = 1-D ndarray of raw measurements.
⬇ input: beta = Smoothing coefficient.
⬇ input: init = Optional seed value. Default None ⇒ use values[0]. Pass the long-run mean to skip warmup.
⬆ returns = ndarray of the same shape as values, with the EMA-smoothed sequence.
docstring

Records the contract: 1-D in, 1-D out, optional warm-start.

out = np.empty_like(values)

Allocate output array with the same shape and dtype as values. No initialisation cost beyond the malloc.

EXECUTION STATE
📚 np.empty_like(a) = Build an uninitialised ndarray with the same shape and dtype as a. Faster than np.zeros_like when you intend to overwrite every element.
out[0] = values[0] if init is None else init

Seed the EMA. If the caller passed an init, use it; otherwise default to values[0].

EXECUTION STATE
ternary expression = Python conditional: x if cond else y. Returns x when cond is truthy, y otherwise.
→ why init matters = Without warm-start, the EMA spends ~3τ steps catching up to the steady-state mean. Passing init=values.mean() makes the noise-reduction comparison fair.
for t in range(1, len(values)):

Walk steps 1..n-1. We've already set out[0].

LOOP TRACE · first 4 iterations (representative)
t = 1 = out[1] = 0.99 · out[0] + 0.01 · values[1] ≈ tiny correction
t = 2 = out[2] = 0.99 · out[1] + 0.01 · values[2]
t = 3 = out[3] = 0.99 · out[2] + 0.01 · values[3]
t = 4 = out[4] = 0.99 · out[3] + 0.01 · values[4]
→ behaviour = Each step the EMA absorbs only 1% of the new measurement. After 100 steps it has absorbed ≈ 63% of any sustained change (1 − 1/e).
out[t] = ema_step(out[t - 1], values[t], beta)

Apply paper eq. 5 element by element.

return out

Final EMA-smoothed sequence.

EXECUTION STATE
⬆ return = ndarray of same shape as input. Element t is the EMA value at step t.
# Synthetic per-batch lambda series

Build a 1,000-step sequence that mimics what the GABA closed form produces on a real training run.

n = 1000

Sequence length. Long enough for meaningful steady-state statistics (1000 ≫ τ = 100).

EXECUTION STATE
n = 1000 — number of training steps simulated.
g_rul = np.maximum(5.0 + 1.0 * np.random.randn(n), 0.01)

1,000 noisy RUL gradient norms with mean 5.0, std 1.0. The np.maximum guards against (rare) negative draws causing divide-by-zero downstream.

EXECUTION STATE
📚 np.random.randn(n) = Sample n values from N(0, 1) as a 1-D ndarray.
📚 np.maximum(a, b) = Element-wise max of two arrays (or array + scalar). Floors values at 0.01.
5.0 + 1.0 * = Mean = 5.0, std = 1.0. Mimics the paper's measured RUL gradient norm at a typical training step.
g_rul = ndarray (1000,). Per-batch RUL gradient norm sequence.
g_health = np.maximum(0.01 + 0.005 * np.random.randn(n), 1e-6)

Same idea for health: mean 0.01, std 0.005. The floor 1e-6 prevents numerical issues if a draw goes negative.

EXECUTION STATE
g_health = ndarray (1000,). Per-batch health gradient norm sequence. ~500x smaller than g_rul, matching the paper.
lam_raw = g_health / (g_rul + g_health)

Apply the K=2 closed form per step. Result is the un-smoothed λ_rul sequence.

EXECUTION STATE
lam_raw = ndarray (1000,). mean ≈ 0.00213, std ≈ 0.00115. Visibly noisy.
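→ why noisy = Both g_rul and g_health fluctuate batch-to-batch. Since g_health ≪ g_rul, lam_raw ≈ g_health / g_rul, so the relative fluctuations of the two norms add roughly in quadrature: √(0.5² + 0.2²) ≈ 0.54, matching std / mean ≈ 0.00115 / 0.00213 ≈ 0.54.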
beta = 0.99

Paper canonical smoothing coefficient. Eq. 5 in main.tex specifies β = 0.99.

EXECUTION STATE
beta = 0.99 = EMA coefficient. With β=0.99: time constant τ = 100 steps; AR(1) noise reduction factor sqrt((1+β)/(1−β)) ≈ 14x.
lam_ema = ema_run(lam_raw, beta=beta, init=lam_raw.mean())

Apply EMA. Warm-start at the long-run mean to skip the warmup transient.

EXECUTION STATE
lam_ema = ndarray (1000,). mean ≈ 0.00214, std ≈ 0.0000587 — order-of-magnitude tighter than lam_raw.
init=lam_raw.mean() = Warm start. Without it, the first ~300 steps would still be settling from out[0]=lam_raw[0].
std_ratio_th = math.sqrt((1 + beta) / (1 - beta))

Theoretical std-reduction factor for an EMA of i.i.d. noise. Derived from Var(EMA) = ((1−β)/(1+β)) · Var(input).

EXECUTION STATE
📚 math.sqrt(x) = Scalar square root.
std_ratio_th = sqrt((1+0.99)/(1-0.99)) = sqrt(199) ≈ 14.1
→ derivation = Let y[n] = β y[n-1] + (1-β) x[n] with i.i.d. x ~ N(μ, σ²). Then Var(y) = ((1-β)/(1+β)) σ². Standard AR(1) result.
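std_ratio_emp = lam_raw.std() / lam_ema.std()

Empirical std reduction for THIS sequence. Over an infinitely long i.i.d. sequence it would converge to the theoretical value; over a short window it need not match.

EXECUTION STATE
std_ratio_emp = lam_raw.std() / lam_ema.std() ≈ 19.4x (the printed value below).
→ why empirical > theory? = Not because of correlated noise: lam_raw is i.i.d. by construction. The EMA output is strongly autocorrelated (τ = 100), so 1,000 steps hold only a few effectively independent samples, and the EMA was warm-started at the sample mean. Both effects bias lam_ema.std() low, so the ratio overshoots. Theory gives the steady-state, long-run value, not a lower bound for a short run.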
print raw stats

Pretty-print.

EXECUTION STATE
Output = raw lambda: mean=0.002130 std=0.001150
print ema stats

Pretty-print the smoothed stats. Note the std drops by ~20x.

EXECUTION STATE
Output = ema lambda: mean=0.002140 std=0.000059
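print empirical reduction

Prints the empirical reduction, to be read against the AR(1) prediction on the next line rather than as an exact confirmation of it.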

EXECUTION STATE
Output = std reduction empirical = 19.4x
print theoretical reduction

The closed-form prediction.

EXECUTION STATE
Output = std reduction theory = sqrt((1+0.99)/(1-0.99)) = 14.1x
tau = 1.0 / (1 - beta)

Time constant. The number of steps after which a sustained input change has been absorbed to fraction (1 − 1/e) ≈ 63.2%.

EXECUTION STATE
tau = 1 / (1 − 0.99) = 100 steps. Paper's '~100 steps' quote (main.tex:387).
half_life = math.log(2) / -math.log(beta)

Half-life: number of steps for a step input to be tracked to 50%.

EXECUTION STATE
📚 math.log(x) = Natural logarithm.
half_life = ln(2) / |ln(0.99)| = 0.6931 / 0.01005 ≈ 68.97 steps.
→ derivation = After T steps with EMA β: fraction remaining of an 'old value' is β^T. Set β^T = 0.5 ⇒ T = ln(2) / -ln(β).
print tau

Print the time constant.

EXECUTION STATE
Output = (blank) time constant tau = 1/(1-beta) = 100.0 steps
print half-life

Half-life.

EXECUTION STATE
Output = half-life = ln(2) / -ln(beta) = 68.97 steps
print 3*tau settling

Engineering rule of thumb: an LTI 1st-order system reaches 95% of a step at 3τ.

EXECUTION STATE
Final output =
raw lambda: mean=0.002130  std=0.001150
ema lambda: mean=0.002140  std=0.000059
std reduction empirical = 19.4x
std reduction theory    = sqrt((1+0.99)/(1-0.99)) = 14.1x

time constant tau = 1/(1-beta) = 100.0 steps
half-life          = ln(2) / -ln(beta) = 68.97 steps
95% settling at 3*tau = 300 steps
```python
"""EMA smoothing for GABA: from-scratch implementation and noise-reduction analysis."""

import math
import numpy as np

np.random.seed(0)


def ema_step(prev: float, current: float, beta: float) -> float:
    """Single-step exponential moving average update (paper eq. 5)."""
    return beta * prev + (1.0 - beta) * current


def ema_run(values, beta, init=None):
    """Apply EMA to a sequence. Initial value defaults to values[0]."""
    out = np.empty_like(values)
    out[0] = values[0] if init is None else init
    for t in range(1, len(values)):
        out[t] = ema_step(out[t - 1], values[t], beta)
    return out


# ---------- Synthetic per-batch lambda series ----------
n = 1000
g_rul    = np.maximum(5.0  + 1.0   * np.random.randn(n), 0.01)
g_health = np.maximum(0.01 + 0.005 * np.random.randn(n), 1e-6)
lam_raw  = g_health / (g_rul + g_health)


# ---------- EMA with paper canonical beta = 0.99 ----------
beta = 0.99
lam_ema = ema_run(lam_raw, beta=beta, init=lam_raw.mean())


# ---------- Variance-reduction theorem ----------
# Note: over only 1,000 strongly autocorrelated EMA samples, lam_ema.std() is
# biased low, so the empirical ratio tends to exceed the steady-state theory.
std_ratio_th  = math.sqrt((1 + beta) / (1 - beta))
std_ratio_emp = lam_raw.std() / lam_ema.std()


print(f"raw lambda: mean={lam_raw.mean():.6f}  std={lam_raw.std():.6f}")
print(f"ema lambda: mean={lam_ema.mean():.6f}  std={lam_ema.std():.6f}")
print(f"std reduction empirical = {std_ratio_emp:.1f}x")
print(f"std reduction theory    = sqrt((1+0.99)/(1-0.99)) = {std_ratio_th:.1f}x")


# ---------- Time constants ----------
tau       = 1.0 / (1 - beta)
half_life = math.log(2) / -math.log(beta)
print(f"\ntime constant tau = 1/(1-beta) = {tau:.1f} steps")
print(f"half-life          = ln(2) / -ln(beta) = {half_life:.2f} steps")
print(f"95% settling at 3*tau = {3 * tau:.0f} steps")
```

PyTorch: register_buffer And The Detach Trick

The paper's actual EMA lives in grace/core/gaba.py as part of the GABALoss class. The pattern: store EMA-smoothed weights as a non-learnable BUFFER (not a Parameter), and call .detach() on every update to prevent autograd-history accumulation. We extract the EMA portion into a standalone class for clarity.

EMA via register_buffer + .detach() — paper pattern
gaba_ema_buffer.py
docstring

Module docstring. The class below mirrors the EMA portion of grace/core/gaba.py — same buffers, same update rule, same detach() discipline.

import torch

Core PyTorch.

EXECUTION STATE
📚 torch = Tensor library with autograd. Used for tensors, register_buffer, .detach.
import torch.nn as nn

Module primitives.

EXECUTION STATE
📚 torch.nn = Neural network module. nn.Module base class.
class GabaEMA(nn.Module):

Inherit from nn.Module so we get automatic .to(device), .state_dict() persistence, and named buffer registration.

EXECUTION STATE
📚 nn.Module = Base class for stateful PyTorch components. Tracks parameters, buffers, submodules.
→ why a Module? = We need persistent state (ema_weights, step_count) that survives across batches AND survives checkpoint save/load. nn.Module provides that for free.
docstring

Records the buffer-vs-parameter design choice. Buffer: state, no gradient. Parameter: state WITH gradient.

def __init__(self, beta=0.99, n_tasks=2):

Constructor. Both defaults are paper-canonical.

EXECUTION STATE
⬇ input: beta = 0.99 = EMA smoothing coefficient. Paper Eq. 5.
⬇ input: n_tasks = 2 = K. RUL + health for this book.
super().__init__()

Initialise the nn.Module base class. It must run before any parameter, buffer, or submodule is registered, so by convention it is the first line of a Module's __init__.

self.beta = beta

Store as plain Python float on the instance. (No need to register: it's a hyperparameter, not state.)

# Initial weights = uniform 1/K. shape (K,).

Inline comment marking the initialisation choice.

self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)

Register a non-learnable buffer. EMA weights are STATE: they change over training but receive no gradient and are not optimised.

EXECUTION STATE
📚 register_buffer(name, tensor) = nn.Module method. Registers a tensor as a named buffer. Buffers are: (a) tracked in state_dict, (b) move with .to(device), (c) NOT updated by optimisers.
⬇ arg 1: 'ema_weights' = Buffer name. Becomes self.ema_weights and shows up in state_dict with key 'ema_weights'.
⬇ arg 2: torch.ones(n_tasks) / n_tasks = Initial value. For n_tasks=2: torch.ones(2) = [1, 1]; / 2 = [0.5, 0.5]. Uniform K-task weighting at startup.
→ buffer vs parameter = buffer: STATE without gradient. parameter: state WITH gradient. We use a buffer because GABA computes weights from a closed form — there is no 'loss to backprop into the EMA'.
self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))

Step counter buffer. Used by the warmup gate (§18.4). dtype=torch.long because step counts are integers, and a 64-bit counter will not overflow for any realistic number of training steps.

EXECUTION STATE
📚 torch.tensor(value, dtype) = Build a 0-dim tensor with the given value and dtype.
⬇ dtype = torch.long = 64-bit signed integer. Step counts grow monotonically; long is the safe choice.
→ why a buffer? = A Python int wouldn't be saved in state_dict. Resuming training from a checkpoint would lose the warmup progress and the GABA stabiliser would re-warm-up from step 0.
def update(self, raw_weights):

The per-step EMA update entry point. Takes the un-smoothed λ from the §17.3 closed form, returns the smoothed value.

EXECUTION STATE
⬇ input: raw_weights = Tensor (K,). Per-step inverse-proportional weights from gabaWeights(g_norms).
⬆ returns = Tensor (K,). EMA-smoothed weights at this step.
docstring

Records that update is in-place: the buffer is mutated.

self.step_count += 1

Increment the counter. PyTorch supports in-place add on 0-dim tensors.

EXECUTION STATE
+= operator = Tensor in-place addition. Works on integer dtype tensors.
ema = self.beta * self.ema_weights + (1.0 - self.beta) * raw_weights

Paper eq. 5. Convex combination of history and the new measurement.

EXECUTION STATE
self.beta = Python float = 0.99.
self.ema_weights = Buffer tensor (K,). At step 1: [0.5, 0.5].
raw_weights = Tensor (K,). E.g. [0.002, 0.998] for the realistic 500x imbalance.
ema = Step-1 result: 0.99 · [0.5, 0.5] + 0.01 · [0.002, 0.998] = [0.49502, 0.50498].
self.ema_weights = ema.detach()

Save the new EMA value back to the buffer. .detach() is CRITICAL — without it the autograd graph grows linearly with step count.

EXECUTION STATE
📚 .detach() = Tensor method. Returns a new tensor that shares data but has no autograd history. Future backward() through it is a no-op.
→ why detach? = If raw_weights was produced under autograd (it usually is), then ema also has history. Without detach, every step appends to that history. Memory grows linearly with step count → OOM after ~1000 steps. With detach: constant memory.
→ paper bug check = grace/core/gaba.py:131 has the equivalent .detach(). Forgetting it is the most common GABA implementation bug.
return self.ema_weights

Return the (just-updated) buffer. Caller multiplies it into the per-task losses.

EXECUTION STATE
⬆ return = Tensor (K,). Smoothed task weights at this step.
torch.manual_seed(0)

Set the global PRNG so the smoke-test output is reproducible.

ema = GabaEMA(beta=0.99, n_tasks=2)

Instantiate the EMA stabiliser with paper defaults.

EXECUTION STATE
ema.ema_weights = Initial buffer = [0.5, 0.5].
ema.step_count = Initial buffer = tensor(0).
raw_lambda = torch.tensor([0.002, 0.998])

Realistic K=2 closed-form weight from the 500x-imbalance regime. Held FIXED across all 400 steps so we can read off the settling curve without confounding noise.

EXECUTION STATE
raw_lambda = Tensor (2,) = [0.002, 0.998]. The target the EMA will track.
print start ema

Pretty-print the initial buffer.

EXECUTION STATE
Output = start ema: [0.5, 0.5]
for step in range(1, 401):

400 steps of EMA updates. At paper's β=0.99 with τ=100, this is 4 time constants — well past 95% settling but not at numerical equilibrium.

LOOP TRACE · 6 iterations
step = 1
ema after = (0.495020, 0.504980) — first 1% absorbed.
step = 50
ema after = (0.303293, 0.696708): 39.5% of the gap closed (1 − 0.99^50); the half-life (≈ 69 steps) is still ahead.
step = 100 (= 1τ)
ema after = (0.184284, 0.815718): at 1τ, fraction absorbed = 1 − 0.99^100 ≈ 63.4% (close to 1 − 1/e ≈ 63.2%).
→ check = 0.5 + (0.998 − 0.5) · 0.634 ≈ 0.8157. Matches!
step = 200 (= 2τ)
ema after = (0.068722, 0.931281): fraction absorbed ≈ 86.6%.
step = 300 (= 3τ)
ema after = (0.026422, 0.973581): 95.1% settled (engineering rule of thumb).
step = 400 (= 4τ)
ema after = (0.010939, 0.989064) — 98.2% settled. From here on, convergence is asymptotic.
smoothed = ema.update(raw_lambda)

One EMA step. Mutates ema.ema_weights and ema.step_count.

if step in (1, 50, 100, 200, 300, 400):

Print only at the milestone steps.

print formatted row

f-string with width specs to make the convergence table align.

state = ema.state_dict()

Snapshot the module's persistent state. Includes both registered buffers automatically.

EXECUTION STATE
📚 .state_dict() = nn.Module method. Returns an OrderedDict mapping registered names → tensors. Used for checkpoint save/load.
state = OrderedDict with keys ['ema_weights', 'step_count']. Plain Python value of beta is NOT in state_dict (it's an attribute, not a buffer/parameter).
print state_dict keys

Confirm both buffers are tracked.

EXECUTION STATE
Output = (blank) state_dict keys: ['ema_weights', 'step_count']
ema2 = GabaEMA(beta=0.99, n_tasks=2)

Make a fresh module. Its buffers are at defaults: [0.5, 0.5] and 0.

EXECUTION STATE
ema2.ema_weights = [0.5, 0.5] — fresh instance, default values.
ema2.load_state_dict(state)

Copy the snapshot into the fresh module. Both buffers are overwritten.

EXECUTION STATE
📚 .load_state_dict(state) = nn.Module method. Copies tensors from state into the matching registered names. Raises if keys are missing or shapes mismatch.
→ why test this? = Resuming training from a checkpoint must restore the EMA state exactly, otherwise the warmup progress is lost and GABA re-converges from scratch.
print reloaded ema_weights

Confirm the round-trip preserved the EMA.

EXECUTION STATE
Output = reloaded ema_weights: [0.010939374566078186, 0.9890638589859009]
print reloaded step_count

And the step counter.

EXECUTION STATE
Final output =
start ema: [0.5, 0.5]
step    1 | ema = (0.495020, 0.504980)  step_count = 1
step   50 | ema = (0.303293, 0.696708)  step_count = 50
step  100 | ema = (0.184284, 0.815718)  step_count = 100
step  200 | ema = (0.068722, 0.931281)  step_count = 200
step  300 | ema = (0.026422, 0.973581)  step_count = 300
step  400 | ema = (0.010939, 0.989064)  step_count = 400

state_dict keys: ['ema_weights', 'step_count']
reloaded ema_weights: [0.010939374566078186, 0.9890638589859009]
reloaded step_count : 400
→ reading = After 400 steps with constant input [0.002, 0.998]: EMA at (0.0109, 0.9891). 98.2% of the gap closed. Persistence confirmed.
```python
"""Paper code: EMA stabiliser via nn.Module register_buffer (grace/core/gaba.py)."""

import torch
import torch.nn as nn


class GabaEMA(nn.Module):
    """One-line stabiliser used inside GABALoss.

    The EMA-smoothed task weights are stored as a non-learnable BUFFER, not
    a Parameter. Buffers move with the model (.to(device), .cuda()) and
    survive checkpoint save/load, but never receive gradients themselves.
    """

    def __init__(self, beta: float = 0.99, n_tasks: int = 2) -> None:
        super().__init__()
        self.beta = beta
        # Initial weights = uniform 1/K. shape (K,).
        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
        self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))

    def update(self, raw_weights: torch.Tensor) -> torch.Tensor:
        """Update the EMA in place and return the smoothed weights."""
        self.step_count += 1
        # CRITICAL: .detach() prevents autograd history from accumulating
        # across steps. Without it, every backward() walks back to step 0.
        ema = self.beta * self.ema_weights + (1.0 - self.beta) * raw_weights
        self.ema_weights = ema.detach()
        return self.ema_weights


# ---------- Smoke test ----------
torch.manual_seed(0)
ema = GabaEMA(beta=0.99, n_tasks=2)

raw_lambda = torch.tensor([0.002, 0.998])
print(f"start ema: {ema.ema_weights.tolist()}")

for step in range(1, 401):
    smoothed = ema.update(raw_lambda)
    if step in (1, 50, 100, 200, 300, 400):
        print(f"step {step:4d} | ema = ({smoothed[0]:.6f}, {smoothed[1]:.6f})  step_count = {ema.step_count.item()}")


# ---------- Persistence: state_dict round-trip ----------
state = ema.state_dict()
print(f"\nstate_dict keys: {list(state.keys())}")

ema2 = GabaEMA(beta=0.99, n_tasks=2)
ema2.load_state_dict(state)
print(f"reloaded ema_weights: {ema2.ema_weights.tolist()}")
print(f"reloaded step_count : {ema2.step_count.item()}")
```
Why .detach() is the most common GABA bug. A new PyTorch user instantiates GABA, runs 50 steps, watches GPU memory grow linearly, and crashes with OOM around step 1000. The cause is forgetting self.ema_weights = ema.detach(). Without detach, every step appends to the autograd graph through the EMA buffer; backward() then walks all the way back to step 0 every time. The paper's code has .detach() explicitly at grace/core/gaba.py:131; it is the line that makes GABA fixed-memory.
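The mechanism can be made visible without a GPU by counting the autograd nodes that the stored EMA tensor keeps alive. In this sketch the raw weights come from a softmax over a learnable logit vector plus noise. This is a hypothetical stand-in for the closed form, chosen only so that the raw weights carry autograd history. graph_size is a small helper written for this example, not a PyTorch API. backward() is never called, so the count shows how much upstream graph each stored EMA value references.

```python
import torch


def graph_size(t: torch.Tensor) -> int:
    """Count distinct autograd nodes reachable from t.grad_fn (0 if t has no history).

    Hypothetical helper for this example; it walks grad_fn.next_functions.
    """
    seen = set()
    stack = [t.grad_fn] if t.grad_fn is not None else []
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        stack.extend(fn for fn, _ in node.next_functions)
    return len(seen)


def run(steps: int, detach: bool, beta: float = 0.99) -> int:
    """Apply `steps` EMA updates and return the graph size held by the stored EMA."""
    torch.manual_seed(0)
    logits = torch.zeros(2, requires_grad=True)      # hypothetical learnable input
    ema = torch.full((2,), 0.5)                      # uniform 1/K start, K = 2
    for _ in range(steps):
        # Stand-in for the closed-form weights: carries autograd history via `logits`.
        raw = torch.softmax(logits + 0.1 * torch.randn(2), dim=0)
        new = beta * ema + (1.0 - beta) * raw
        ema = new.detach() if detach else new         # the line that matters
    return graph_size(ema)


for steps in (10, 100, 1000):
    print(f"{steps:5d} steps: without detach {run(steps, detach=False):5d} nodes, "
          f"with detach {run(steps, detach=True)} nodes")
```

Without detach, the node count grows linearly with the number of updates. With detach, the stored value has no history at all, whatever the step count.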

EMA In Other Fields

| Field | EMA appears as | Typical β | What it stabilises |
|---|---|---|---|
| Predictive maintenance (this paper) | GABA stabiliser (paper eq. 5) | 0.99 | Per-task loss weights |
| Optimiser internals | Adam first moment (β₁), second moment (β₂) | β₁ = 0.9, β₂ = 0.999 | Per-parameter step direction and scale |
| Computer vision | Batch-norm running mean / running var | 0.9 (PyTorch default: β = 1 − momentum, momentum = 0.1) to 0.99 (Keras default) | Test-time normalisation statistics |
| Self-supervised learning (BYOL, MoCo) | Target-network EMA | 0.99 to 0.9999 | Slow-moving teacher network |
| Reinforcement learning | Target Q-network update (Polyak averaging) | 0.995 to 0.999 | Bootstrapping target stability |
| Finance | Exponential moving average of price | 0.94 to 0.99 | Price trend, volatility (RiskMetrics) |
| Audio compressors | Attack / release envelope | Set by hardware time constant | Smoothed signal level |
| Sensor fusion (GPS smoothing) | Position low-pass | Tuned to road conditions | Displayed location |

In every row, the same recursion $y[n] = \beta y[n-1] + (1 - \beta) x[n]$ plays the role of ‘memory with controlled forgetting’. GABA simply names the variables $\hat{\lambda}_i$ and applies the recursion to multi-task weights.
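The batch-norm row is easy to check directly. In PyTorch, nn.BatchNorm1d updates its running mean during training-mode forward passes as running_mean = (1 − momentum) · running_mean + momentum · batch_mean. Setting momentum = 1 − β therefore gives exactly the recursion above. In the sketch below, the batch size, feature count and offset are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
beta = 0.99
# PyTorch convention: running = (1 - momentum) * running + momentum * batch_stat,
# so momentum = 1 - beta reproduces the EMA recursion with coefficient beta.
bn = nn.BatchNorm1d(num_features=3, momentum=1.0 - beta)
bn.train()                                   # running stats update only in training mode

running = torch.zeros(3)                     # BatchNorm initialises running_mean to zeros
for _ in range(5):
    x = torch.randn(32, 3) + 2.0             # batch of 32 samples, 3 features
    bn(x)                                    # forward pass updates bn.running_mean
    running = beta * running + (1.0 - beta) * x.mean(dim=0)   # same recursion by hand

print(torch.allclose(bn.running_mean, running, atol=1e-6))   # True
```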

Three Pitfalls That Break EMA Stabilisation

Pitfall 1: Forgetting .detach(). Memory grows linearly with step count. After ~1,000 steps you OOM. Every line that writes back to the EMA buffer must call .detach() on the new value. Paper's grace/core/gaba.py:131 does this; mirror that pattern.
Pitfall 2: Picking β too aggressively (e.g. 0.999). With $\beta = 0.999$: τ = 1000 steps, 3τ = 3000 steps. On a 500-step training run the EMA never reaches steady state; you spend the entire run inside the warmup transient. The paper's $\beta = 0.99$, with τ = 100 steps, reaches 95% settling at step 300, well within typical training horizons.
Pitfall 3: Initialising at zero instead of 1/K. With $\hat{\lambda}^{(0)}_i = 0$ and raw weights that sum to 1, the smoothed weights sum to $1 - \beta^t$ rather than 1. For roughly the first 100 steps, training therefore runs on un-normalised weights: the combined loss has a scale near zero, the gradient is tiny, and Adam's second-moment estimator starts from a bad place. The paper initialises at the uniform $1/K$ (line: torch.ones(n_tasks) / n_tasks). The EMA thus starts at a sensible value and converges toward the closed-form weights over the first few τ.
Why step_count must be a buffer, not a Python int. The warmup gate (§18.4) checks $t \leq W$ before applying the closed form. Resuming training from a checkpoint at step 5,000 must NOT re-enter warmup because the checkpoint forgot the counter. Storing step_count as a buffer ensures it is saved by state_dict() and restored by load_state_dict(). The PyTorch demo above verifies this round-trip.
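Pitfalls 2 and 3 come down to simple arithmetic on the recursion. The sketch below assumes a hypothetical 500-step run for pitfall 2. For pitfall 3 it uses a constant K = 2 raw weight vector (0.002, 0.998), like the one in the PyTorch demo.

```python
def ema_after(steps: int, beta: float, init: float, target: float) -> float:
    """EMA value after `steps` updates with a constant input `target`."""
    y = init
    for _ in range(steps):
        y = beta * y + (1.0 - beta) * target
    return y


# Pitfall 2: how much of a step change is absorbed within a 500-step run.
settled = {beta: ema_after(500, beta, init=0.0, target=1.0) for beta in (0.99, 0.999)}
for beta, frac in settled.items():
    print(f"beta={beta}: {frac:.1%} settled after 500 steps")

# Pitfall 3: sum of the K = 2 smoothed weights under zero vs uniform (1/K) init.
beta, raw = 0.99, (0.002, 0.998)
for t in (1, 10, 100):
    zero_sum = sum(ema_after(t, beta, 0.0, r) for r in raw)      # equals 1 - beta**t
    uniform_sum = sum(ema_after(t, beta, 0.5, r) for r in raw)   # stays at 1
    print(f"t={t:3d}: zero-init sum = {zero_sum:.3f}, uniform-init sum = {uniform_sum:.3f}")
```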

Takeaway

  • Raw per-batch λ is too noisy to use directly. ~50% per-step coefficient of variation on FD002. The optimiser would receive a different effective objective every step.
  • EMA with paper β = 0.99 is paper eq. 5. A first-order IIR low-pass filter with a single pole at $z = 0.99$. Convex combination ⇒ output stays on the simplex.
  • Time constant τ = 1/(1−β) = 100 steps. Half-life ≈ 69 steps; 95% settling at 3τ = 300 steps; 3 dB cutoff at period 625 steps.
  • Variance reduction is closed form. For i.i.d. noise: std ratio $= \sqrt{(1+\beta)/(1-\beta)} \approx 14\times$ at β = 0.99. The 1,000-step demo measures about 19×, an overshoot caused by finite-sample bias on a strongly autocorrelated EMA output, not by the noise itself.
  • register_buffer + .detach() is the implementation contract. Buffer for persistence, detach for fixed-memory autograd. Forgetting either breaks the algorithm.
  • The same recursion appears everywhere. Adam moments, BatchNorm running stats, BYOL target, Polyak-averaged Q-targets, RiskMetrics volatility, GPS smoothing. GABA just instantiates it for multi-task weights.