GPS On A Bumpy Road
Your phone's GPS receives a noisy position estimate every second. If the navigation app showed every raw sample, the blue dot would jitter wildly — half a metre forward, a metre sideways, back. Useless for driving. Instead, the app applies a smoother: each new sample contributes a small fraction to the displayed position; most of what you see is the running average of the recent past. The dot still tracks — it just stops twitching.
GABA has the same problem. The closed form is exact, but the inputs are measured ON ONE MINI-BATCH each step. Mini-batch gradients are noisy estimates of the true expected gradient; their L2 norms inherit that noise. Plug noisy into the closed form and you get a noisy that fluctuates 5× per step on a 500×-imbalanced problem. Use those raw weights and the optimiser veers every batch.
Why The Raw Per-Batch Lambda Cannot Be Used
Three sources of noise enter per step:
- Sampling noise. The mini-batch is a random subsample of the dataset. Its gradient is an unbiased estimate of the full-data gradient, but with variance that scales as 1/batch_size.
- Per-condition variance (multi-condition data). On C-MAPSS FD002 / FD004, six different operating conditions in the same batch produce gradients of different magnitudes that average together. Batch composition matters.
- Loss-curvature interactions. Near sharp local minima, small parameter changes produce big gradient changes. Two consecutive batches can land on very different points of the loss surface.
On a typical FD002 training step the realised has standard deviation of order 0.001 around its mean of order 0.002. That is a 50% per-step coefficient of variation. Without smoothing, the optimiser would receive different effective loss weights every step, defeating the convergence guarantee that comes from running gradient descent with a stable objective.
The EMA Update (Paper Eq. 5)
Paper main.tex:347 specifies:
Because , the output stays a convex combination of the inputs and therefore lives on the same simplex as the inputs — and for all . Initial value per the paper algorithm (uniform start).
EMA As A First-Order IIR Low-Pass Filter
Rewrite the update as a discrete-time linear system:
Take the Z-transform: the transfer function is . This is a textbook first-order Infinite-Impulse-Response (IIR) low-pass filter with a single pole at . The frequency response magnitude is:
At (DC, the long-run mean): . Sustained inputs pass through unattenuated — the EMA tracks the true mean. At higher frequencies (per-batch noise), falls off. The 3 dB cutoff for is at:
i.e. a noise period of 625 steps. Per-batch noise (period 1) sees attenuation of more than 30 dB; coherent trends spanning 1000+ steps pass through almost unchanged. The paper's is the result of choosing this cutoff to lie below the slowest meaningful training-time variation.
Time Constant, Half-Life, And Settling
Three equivalent ways to think about how fast the EMA responds:
| Quantity | Formula | β = 0.99 value | Interpretation |
|---|---|---|---|
| Time constant τ | 1 / (1 − β) | 100 steps | Step input absorbed to (1 − 1/e) ≈ 63.2% after τ steps |
| Half-life | ln 2 / − ln β | 68.97 steps | Step input absorbed to 50% (matches τ for β near 1) |
| 95% settling (3τ rule) | 3 · τ | 300 steps | Engineering rule of thumb for a 1st-order LTI system |
| 3 dB cutoff | −ln β / (2π) | 0.0016 cycles/step | Period: 625 steps. Noise faster than this is suppressed |
| Effective averaging window | τ (≈ 1/(1−β)) | ~100 steps | Loosely: each EMA value averages the last τ raw measurements |
The paper (main.tex:387) writes ‘the EMA with serves as a first-order IIR low-pass filter (time constant ~100 steps) that smooths stochastic gradient noise, preventing oscillation’. That is the same statement in plain English.
Variance Reduction Theorem
For an EMA driven by i.i.d. zero-mean noise, the output variance is provably reduced by a closed-form factor. Let with i.i.d. At steady state:
For : theoretical std reduction =. On the realistic synthetic data we generate in the Python demo below, the empirical reduction is — slightly more than theory because the per-batch noise has small positive serial correlation that EMA can exploit further. The takeaway: every 10× multiplier on buys you a std reduction.
Interactive: Beta Sweep And Step Response
Drag β from 0 (no smoothing) to 0.999 (heavy smoothing). The top panel shows a 600-step run with per-batch noise; the bottom panel shows the response to a single step input from 0.5 to 0.998. Increasing β TIGHTENS the trace at the cost of SLOWER tracking.
Try this. Set : the blue trace coincides with the grey raw trace (no memory ⇒ no smoothing). Set : the trace is rock-steady but the bottom panel shows the EMA hasn't even reached 50% after 600 steps (τ = 1000). Paper's sits in the sweet spot — fast enough to track within 300 steps, smooth enough to remove per-batch jitter.
Python: EMA From Scratch
Implement ema_step and ema_run in pure NumPy, generate a 1,000-step synthetic sequence with realistic 500× imbalance, and verify the AR(1) variance-reduction theorem numerically.
PyTorch: register_buffer And The Detach Trick
The paper's actual EMA lives in grace/core/gaba.py as part of the GABALoss class. The pattern: store EMA-smoothed weights as a non-learnable BUFFER (not a Parameter), and call .detach() on every update to prevent autograd-history accumulation. We extract the EMA portion into a standalone class for clarity.
.detach() is the most common GABA bug. A new PyTorch user instantiates GABA, runs 50 steps, watches GPU memory grow linearly, and crashes with OOM around step 1000. The cause is forgetting self.ema_weights = ema.detach(). Without detach, every step appends to the autograd graph through the EMA buffer; backward() then walks all the way back to step 0 every time. The paper's code has .detach() explicitly at grace/core/gaba.py:131; it is the line that makes GABA fixed-memory.EMA In Other Fields
| Field | EMA appears as | Typical β | What it stabilises |
|---|---|---|---|
| Predictive maintenance (this paper) | GABA stabiliser (paper eq. 5) | 0.99 | Per-task loss weights |
| Optimiser internals | Adam first-moment (β₁), second-moment (β₂) | β₁ = 0.9, β₂ = 0.999 | Per-parameter step direction & scale |
| Computer vision | Batch-norm running mean / running var | 0.99 — 0.999 (1 − momentum) | Test-time normalisation statistics |
| Self-supervised learning (BYOL, MoCo) | Target-network EMA | 0.99 — 0.9999 | Slow-moving teacher network |
| Reinforcement learning | Target Q-network update (Polyak averaging) | 0.995 — 0.999 | Bootstrapping target stability |
| Finance | Exponential moving average of price | 0.94 — 0.99 | Price trend, volatility (RiskMetrics) |
| Audio compressors | Attack / release envelope | Hardware time constant | Smoothed signal level |
| Sensor fusion (GPS smoothing) | Position low-pass | Tuned to road conditions | Displayed location |
In every row, the same recursion plays the role of ‘memory with controlled forgetting’. GABA simply names the variables and applies the recursion to multi-task weights.
Three Pitfalls That Break EMA Stabilisation
.detach(). Memory grows linearly with step count. After ~1,000 steps you OOM. Every line that writes back to the EMA buffer must call .detach() on the new value. Paper's grace/core/gaba.py:131 does this; mirror that pattern.torch.ones(n_tasks) / n_tasks) so the EMA starts at a sensible value and tracks down to the correct value over the first ~τ steps.step_count must be a buffer, not a Python int. The warmup gate (§18.4) checks before applying the closed form. Resuming training from a checkpoint at step 5,000 must NOT re-enter warmup because the checkpoint forgot the counter. Storing step_count as a buffer ensures it is saved by state_dict() and restored by load_state_dict(). The PyTorch demo above verifies this round-trip.Takeaway
- Raw per-batch λ is too noisy to use directly. ~50% per-step coefficient of variation on FD002. The optimiser would receive a different effective objective every step.
- EMA with paper β = 0.99 is paper eq. 5. A first-order IIR low-pass filter with single pole at . Convex combination ⇒ output stays on the simplex.
- Time constant τ = 1/(1−β) = 100 steps. Half-life ≈ 69 steps; 95% settling at 3τ = 300 steps; 3 dB cutoff at period 625 steps.
- Variance reduction is closed form. For i.i.d. noise: std-ratio = at β = 0.99. Empirically 19× on realistic (slightly correlated) gradient noise.
register_buffer+.detach()is the implementation contract. Buffer for persistence, detach for fixed-memory autograd. Forgetting either breaks the algorithm.- The same recursion appears everywhere. Adam moments, BatchNorm running stats, BYOL target, Polyak-averaged Q-targets, RiskMetrics volatility, GPS smoothing. GABA just instantiates it for multi-task weights.