Chapter 19

EMA as a First-Order IIR Low-Pass Filter

Control-Theoretic Interpretation

Hook: A Microphone, A Hiss, And A Capacitor

Plug a vocal microphone into a cheap mixing board and you hear two things: the singer's voice (the slow signal you want) and a hiss from the preamp (the fast noise you don't). To clean it up, every audio engineer reaches for the same tool: a low-pass filter. Pass the slow voice through; reject the high-frequency hiss. The simplest realisation is a single resistor and a single capacitor, a textbook RC low-pass network, whose cutoff frequency is $f_c = 1/(2\pi R C)$. That circuit dates to 1899 and it has not gone out of style.

GABA solves the same problem in software. The closed-form rule from §19.1 emits a per-step weight that is brutally noisy: the per-batch gradient ratio in the paper fluctuates between 10× and 10,000× from one minibatch to the next (paper main.tex:576), so the un-smoothed $\lambda_i^{*}$ swings wildly. Feeding it directly into the optimiser would make the loss combination oscillate and wreck training. The fix is exactly the audio engineer's move: put a low-pass filter between the controller and the actuator. Paper main.tex:387 names that filter explicitly: “The EMA with $\beta = 0.99$ serves as a first-order IIR low-pass filter (time constant $\sim 100$ steps) that smooths stochastic gradient noise, preventing oscillation.”

Why this section matters. The EMA buffer is not just an empirical smoother: it is a digital realisation of the same RC low-pass topology that an audio engineer reaches for. Once we identify the filter, we get its frequency response, its time constant, and its noise-rejection figure for free from decades of DSP textbooks. We'll show that the time constant is $\tau = 1/(1-\beta) = 100$ samples, derive the magnitude response, and show why $\beta = 0.99$ is the smallest of the conventional choices (0.9, 0.99, 0.999) that still rejects the paper's observed batch-to-batch noise by 30+ dB.

What Is A First-Order IIR Low-Pass Filter?

A digital filter is a recipe that turns one sequence of samples $x[t]$ into another sequence $y[t]$. Linear time-invariant filters come in exactly two flavours.

| Flavour | Recursion | Memory | Examples |
|---|---|---|---|
| FIR (Finite Impulse Response) | $y[t] = \sum_{k=0}^{N-1} a_k\, x[t-k]$ (no past $y$) | Bounded: depends only on the last $N$ inputs | Moving average; convolution kernel; CNN layer |
| IIR (Infinite Impulse Response) | $y[t] = \sum_{k \ge 0} a_k\, x[t-k] + \sum_{k \ge 1} b_k\, y[t-k]$ (past $y$ appears) | Unbounded: recursive memory of all past inputs | EMA; RC circuit; Kalman filter; GABA |
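A minimal sketch of the FIR/IIR distinction in plain Python (the helper names `moving_average` and `ema` are illustrative, not taken from the paper's code): feed one unit impulse into a 10-tap moving average and into the EMA recursion, and count how long each one remembers it.

```python
def moving_average(x, n):
    """FIR: y[t] is the mean of the last n inputs (inputs before t = 0 count as 0)."""
    return [sum(x[max(0, t - n + 1): t + 1]) / n for t in range(len(x))]


def ema(x, beta):
    """IIR: y[t] = beta * y[t-1] + (1 - beta) * x[t], starting from y[-1] = 0."""
    y, out = 0.0, []
    for sample in x:
        y = beta * y + (1.0 - beta) * sample
        out.append(y)
    return out


impulse = [1.0] + [0.0] * 29          # one unit sample at t = 0, then silence
fir = moving_average(impulse, n=10)
iir = ema(impulse, beta=0.9)

# FIR: the impulse leaves the window after exactly n samples.
print("FIR nonzero outputs:", sum(1 for v in fir if v != 0.0))
# IIR: every later output still carries a geometrically shrinking trace.
print("IIR nonzero outputs:", sum(1 for v in iir if v != 0.0))
print(f"IIR output 29 samples later: {iir[29]:.2e}")
```

The FIR output drops to exactly zero once the impulse slides out of its window, while the feedback term keeps the IIR output nonzero indefinitely.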

IIR filters are recursive: the new output depends on past outputs. That recursion is what gives them infinite memory and an exponential decay structure. The simplest non-trivial IIR is the first-order low-pass:

$$y[t] = \beta \cdot y[t-1] + (1 - \beta) \cdot x[t], \qquad 0 \le \beta < 1.$$

Two coefficients, both fixed by a single parameter $\beta$. The new output is a convex combination of yesterday's output (weight $\beta$) and today's input (weight $1-\beta$). Because the coefficients sum to 1, the filter passes a constant signal through unchanged, which engineers call “unity DC gain.”

The transfer function (z-transform of the recursion) is

$$H(z) = \frac{1 - \beta}{1 - \beta z^{-1}}, \qquad \text{single pole at } z = \beta.$$

That single pole on the real axis at $z = \beta$ is what makes this a one-pole low-pass. At $\beta = 0.99$ the pole sits very close to the unit circle $|z| = 1$, the boundary of instability, which is exactly why the filter has such a long memory.
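To make the pole picture concrete, here is a small sketch (helper names are illustrative, not from the paper's code) that evaluates $H(z)$ directly on the unit circle with complex arithmetic and counts how many samples the impulse response needs to fall to 1% of its first value as the pole moves toward $|z| = 1$.

```python
import cmath
import math


def transfer_function(z, beta):
    """H(z) = (1 - beta) / (1 - beta * z**-1); its only pole is at z = beta."""
    return (1.0 - beta) / (1.0 - beta / z)


def samples_to_decay(beta, fraction=0.01):
    """Smallest k with beta**k <= fraction, i.e. when h[k] = (1-beta)*beta**k
    has fallen to `fraction` of h[0]."""
    return math.ceil(math.log(fraction) / math.log(beta))


for beta in (0.5, 0.9, 0.99):
    dc = abs(transfer_function(cmath.exp(0j), beta))                  # z = e^{j0} = 1
    nyquist = abs(transfer_function(cmath.exp(1j * math.pi), beta))   # z = -1
    print(f"pole at z={beta}: |H| at DC = {dc:.3f}, at Nyquist = {nyquist:.4f}, "
          f"memory to 1% = {samples_to_decay(beta)} samples")
```

The DC gain stays at 1 for every pole position, while the memory grows sharply as the pole approaches the unit circle.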

Why The EMA Update IS That Filter

Compare the two recursions side by side. The paper's EMA update is paper eq. 5 (main.tex:346):

$$\hat{\lambda}_i^{(t)} = \beta \cdot \hat{\lambda}_i^{(t-1)} + (1 - \beta) \cdot \lambda_i^{(t)}, \qquad \beta = 0.99.$$

The first-order IIR low-pass recursion is

$$y[t] = \beta \cdot y[t-1] + (1 - \beta) \cdot x[t].$$

Rename $x[t] = \lambda_i^{(t)}$ (the raw per-step weight from the closed-form P-controller) and $y[t] = \hat{\lambda}_i^{(t)}$ (the smoothed weight). The two recursions are the same equation, character for character: the EMA buffer is a first-order digital low-pass IIR filter applied to the closed-form weight stream.

The impulse response, i.e. what the filter does to a single isolated sample at $t = 0$, falls out of the recursion by unrolling:

$$h[k] = (1 - \beta)\, \beta^{k}, \qquad k = 0, 1, 2, \ldots$$

Geometric decay with ratio $\beta$. At $\beta = 0.99$, the weight on a sample 100 steps ago is $(0.01)(0.99)^{100} \approx 0.0037$, about $1/e$ of the weight $1 - \beta = 0.01$ on the most recent sample. Read the other way, a single input sample keeps influencing every later output with geometrically shrinking weight. This unbounded but exponentially decaying memory is the “I” in IIR.

Sanity check. The impulse weights sum to 1: $\sum_{k=0}^{\infty} (1-\beta)\beta^k = (1-\beta)/(1-\beta) = 1$. That is exactly the unity-DC-gain property: a constant input goes through with no scaling.
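The unrolling can be checked numerically. This plain-Python sketch (an illustration, not the paper's code) drives the recursion with a unit impulse and compares each output with the closed form $h[k] = (1-\beta)\beta^k$.

```python
import math


def impulse_response(beta, n):
    """Unroll y[t] = beta*y[t-1] + (1-beta)*x[t] for a unit impulse at t = 0."""
    y, h = 0.0, []
    for t in range(n):
        x = 1.0 if t == 0 else 0.0
        y = beta * y + (1.0 - beta) * x
        h.append(y)
    return h


beta = 0.99
n = 5000
h = impulse_response(beta, n)
closed_form = [(1.0 - beta) * beta ** k for k in range(n)]

max_err = max(abs(a - b) for a, b in zip(h, closed_form))
print(f"max |recursion - closed form| = {max_err:.1e}")
print(f"h[100] = {h[100]:.4f}; h[100] / h[0] = {h[100] / h[0]:.3f} vs 1/e = {1 / math.e:.3f}")
print(f"sum of the first {n} weights = {sum(h):.6f}")   # approaches 1 (unity DC gain)
```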

The Time Constant: τ = 1 / (1 − β) = 100 Steps

Every first-order low-pass has a single number that summarises its memory: the time constant $\tau$. For the discrete EMA filter it is

$$\tau = \frac{1}{1 - \beta} \quad \text{samples}.$$

The paper's default $\beta = 0.99$ gives $\tau = 100$. Three concrete consequences fall out of this number:

  • Step-response settling. If the input jumps from 0 to 1 at $t = 0$, the filter output reaches $1 - e^{-1} \approx 0.632$ at $t = \tau$, $1 - e^{-2} \approx 0.865$ at $t = 2\tau$, and $1 - e^{-3} \approx 0.950$ at $t = 3\tau$: the same canonical milestones as an analog RC low-pass. (The discrete filter gives $1 - \beta^{n}$ after $n$ samples, i.e. 0.634, 0.866 and 0.951 after 100, 200 and 300 samples, within 0.2 percentage points of the continuous values.)
  • Effective averaging window. The geometric weights $(1-\beta)\beta^k$ put about 63% of the total weight ($1 - \beta^{\tau} \approx 0.634$) on the most recent $\tau$ samples. Matched against a uniform moving average, the filter behaves like a window of $(1+\beta)/(1-\beta) = 2\tau - 1 = 199$ samples: that length gives the same noise-variance reduction $(1-\beta)/(1+\beta)$ and the same mean lag $\beta/(1-\beta) = 99$ samples. Unlike a moving average, it is causal and recursive (constant memory and compute per step instead of $O(\tau)$).
  • Adaptation speed. If the ‘true’ $\lambda_i$ jumps from 0.5 to 0.998 (the paper's post-warmup transition), the filter needs about $3\tau = 300$ steps to close 95% of the gap. Paper main.tex:576 is consistent with this: “$\lambda_{\text{health}}$ converges from the equal initialization (0.5) to $> 0.99$ within ~10 epochs”, and 10 epochs at the paper's ~30 batches/epoch is $\sim 300$ steps. The match is in order of magnitude rather than exact: climbing from 0.5 past 0.99 toward a 0.998 target strictly takes about $4.1\tau \approx 410$ steps.
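The three consequences above reduce to a few lines of arithmetic. The sketch below is illustrative only; the 0.5 → 0.998 step and the 0.99 threshold are the numbers quoted in the bullets, not values read from the paper's code.

```python
import math

beta = 0.99
tau = 1.0 / (1.0 - beta)          # about 100 samples
window = round(tau)               # 1/(1 - 0.99) is 99.999... in floating point

# 1) Step response after n samples (state starts at 0): 1 - beta**n.
for n in (window, 2 * window, 3 * window):
    print(f"after {n} samples: {1 - beta ** n:.4f}  "
          f"(continuous RC: {1 - math.exp(-n / tau):.4f})")

# 2) Share of the impulse-response weight inside the most recent tau samples.
recent_share = sum((1 - beta) * beta ** k for k in range(window))
print(f"weight in the last {window} samples: {recent_share:.3f}")

# Uniform moving-average length with the same noise-variance reduction.
equiv_window = (1 + beta) / (1 - beta)
print(f"equivalent moving-average length: {equiv_window:.0f} samples")

# 3) Steps for a 0.5 start to pass 0.99 when the raw target is 0.998.
start, target, threshold = 0.5, 0.998, 0.99
steps = math.ceil(math.log((target - threshold) / (target - start)) / math.log(beta))
print(f"steps from {start} to above {threshold}: {steps} (~{steps / tau:.1f} tau)")
```

Note the `round(tau)`: truncating with `int()` would silently give a 99-sample window because of floating-point rounding in `1 - 0.99`.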
Why the paper picked $\beta = 0.99$ specifically. The choice trades off two timescales: the disturbance timescale (how often the per-batch gradient ratio fluctuates) versus the adaptation timescale (how fast the loop should respond to a real regime shift). The paper's training runs use ~30 batches/epoch and the gradient ratio decorrelates roughly every 1–3 batches, so the disturbance band is ~0.1–0.3 cycles/step. $\beta = 0.99$ gives a cutoff frequency of $f_c \approx 0.0016$ cycles/step, roughly two decades (60–190×) below the disturbance band, leaving about 36–44 dB of rejection. That margin is what makes the EMA-smoothed weight approximately constant within an epoch even though the raw weight bounces around all the way from 0.001 to 0.5 batch by batch.
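A quick way to check the rejection claim is to evaluate the magnitude response across the stated 0.1 to 0.3 cycles/step disturbance band. This sketch assumes that band; `magnitude_db` is an illustrative helper, not part of the paper's code.

```python
import math


def magnitude_db(beta, f):
    """Gain of the first-order IIR low-pass at f cycles/step, in dB."""
    w = 2 * math.pi * f
    mag = (1 - beta) / math.sqrt(1 - 2 * beta * math.cos(w) + beta * beta)
    return 20 * math.log10(mag)


beta = 0.99
# Exact -3 dB frequency from cos(w_c) = 1 - (1 - beta)^2 / (2 beta).
f_c = math.acos(1 - (1 - beta) ** 2 / (2 * beta)) / (2 * math.pi)
for f in (0.1, 0.2, 0.3):            # assumed disturbance band, cycles/step
    print(f"f = {f}: {f / f_c:4.0f}x above f_c, gain {magnitude_db(beta, f):6.1f} dB")
```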

Frequency Response And The −3 dB Cutoff

Substitute $z = e^{j\omega}$ into the transfer function to get the steady-state response to a sinusoid at angular frequency $\omega$ rad/sample:

$$|H(e^{j\omega})| = \frac{1 - \beta}{\sqrt{1 - 2\beta \cos\omega + \beta^2}}.$$

At $\omega = 0$ (DC): $|H| = (1-\beta)/(1-\beta) = 1$, unity gain at zero frequency. At $\omega = \pi$ (Nyquist): $|H| = (1-\beta)/(1+\beta)$, the maximum possible attenuation. For $\beta = 0.99$, that is $0.01/1.99 \approx 5 \times 10^{-3}$, or about −46 dB. Between DC and Nyquist the magnitude decreases monotonically.

The conventional summary number is the −3 dB cutoff $f_c$, the frequency at which the magnitude drops to $1/\sqrt{2} \approx 0.707$. Setting $|H|^2 = 1/2$ and solving for the first-order IIR gives

$$\cos\omega_c = 1 - \frac{(1-\beta)^2}{2\beta} \approx 1 - \frac{1}{2\tau^2}, \qquad \omega_c \approx \frac{1}{\tau} \quad (\text{for } \beta \to 1).$$

For $\beta = 0.99$: $\omega_c \approx 0.01$ rad/sample, so $f_c \approx 0.01/(2\pi) \approx 1.6 \times 10^{-3}$ cycles/step. Past that cutoff the filter rolls off at the textbook −20 dB/decade of any first-order low-pass, the same Bode slope you see on a single RC stage, until the response flattens out approaching Nyquist.
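The exact cutoff and the $1/\tau$ approximation can be compared directly. The sketch below (illustrative helper names) solves the cosine equation above for several $\beta$, confirms that the gain at $\omega_c$ is $1/\sqrt{2}$, and measures the roll-off over one decade.

```python
import math


def magnitude(beta, omega):
    """|H(e^{j omega})| for H(z) = (1 - beta) / (1 - beta z^-1)."""
    return (1 - beta) / math.sqrt(1 - 2 * beta * math.cos(omega) + beta * beta)


def cutoff(beta):
    """Exact -3 dB angular frequency from cos(w_c) = 1 - (1 - beta)^2 / (2 beta)."""
    return math.acos(1 - (1 - beta) ** 2 / (2 * beta))


for beta in (0.5, 0.9, 0.99, 0.999):
    w_c = cutoff(beta)
    print(f"beta={beta}: w_c = {w_c:.5f} rad/sample vs 1/tau = {1 - beta:.5f}; "
          f"|H(w_c)| = {magnitude(beta, w_c):.4f}")

# Slope over one decade, f = 0.01 -> 0.1 cycles/step, at beta = 0.99.
beta = 0.99
slope = 20 * math.log10(magnitude(beta, 2 * math.pi * 0.1) / magnitude(beta, 2 * math.pi * 0.01))
print(f"roll-off: {slope:.1f} dB per decade")
```

The approximation $\omega_c \approx 1/\tau$ is tight for $\beta \ge 0.9$ and visibly off at $\beta = 0.5$, which is why the text restricts it to $\beta \to 1$.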

Interactive: Impulse And Frequency Response

Two coordinated panels. The left shows the impulse-response weights $h[k] = (1-\beta)\beta^k$, the memory of the filter, stem-plot style. The right shows the magnitude Bode plot $|H(e^{j\omega})|$ in dB versus normalised frequency. Move the $\beta$ slider to retune the filter; move the probe-frequency slider to inject a sinusoidal disturbance and read off the steady-state attenuation.


Three guided experiments to run on the viz:

  • Memory length. Click the $\beta = 0.99$ preset. Read the time-constant marker at $\tau = 100$ on the impulse panel. About 63% of the total weight sits to the left of that marker: this is the “effective averaging window.”
  • Noise rejection at paper-realistic frequencies. Set the probe slider to $f = 0.1$ cycles/step (one oscillation every 10 steps, paper-realistic for batch-to-batch gradient noise). Read the dB number on the right panel: at $\beta = 0.99$ the attenuation is about −36 dB, i.e. the noise amplitude shrinks by a factor of about 62.
  • Why $\beta$ can't be smaller. Drop $\beta$ to 0.5. The time constant collapses to 2 steps, the cutoff frequency moves right by a factor of about 70 (to $f_c \approx 0.115$ cycles/step), and the same probe at $f = 0.1$ now sees only about −2.5 dB of attenuation. The smoothed weight would shadow the noisy raw weight almost step for step, precisely the oscillation regime that paper main.tex:553 attributes to GradNorm's divergence on 1/5 N-CMAPSS seeds.
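Without the widget, the second and third experiments can be reproduced in the time domain: drive the recursion with a unit sinusoid, let the start-up transient die out, and measure the output amplitude. This is a minimal sketch with illustrative helpers; it estimates the amplitude from the RMS over a whole number of periods, which is exact for a sampled sinusoid.

```python
import math


def measured_gain(beta, f, n_steps=6000, window=1000):
    """Drive y[t] = beta*y[t-1] + (1-beta)*x[t] with x[t] = sin(2 pi f t) and
    return the steady-state output amplitude, sqrt(2 * mean(y^2)) over the last
    `window` samples. `window` must span a whole number of periods of f."""
    y, sum_sq = 0.0, 0.0
    for t in range(n_steps):
        y = beta * y + (1.0 - beta) * math.sin(2 * math.pi * f * t)
        if t >= n_steps - window:
            sum_sq += y * y
    return math.sqrt(2.0 * sum_sq / window)


def predicted_gain(beta, f):
    """Closed-form |H| at f cycles/step."""
    w = 2 * math.pi * f
    return (1 - beta) / math.sqrt(1 - 2 * beta * math.cos(w) + beta * beta)


for beta in (0.99, 0.5):
    g = measured_gain(beta, f=0.1)
    print(f"beta={beta}: measured {g:.5f}, formula {predicted_gain(beta, 0.1):.5f}, "
          f"{20 * math.log10(g):.1f} dB, amplitude / {1 / g:.1f}")
```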

Why β = 0.99 Survives The 10× to 10,000× Disturbance Range

The paper's Fig. gradient_dynamics reports that the per-batch gradient ratio $\|g_{\text{rul}}\| / \|g_{\text{health}}\|$ fluctuates between 10× and 10,000× on the log scale, with thin per-seed curves piling on top of a slowly drifting mean (paper main.tex:576). On the simplex this corresponds to a raw un-smoothed $\lambda^{*}$ swinging between approximately 0.0001 and 0.1 every few batches.

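| $\beta$ | $\tau$ (samples) | $f_c$ (cycles/step) | Gain at $f = 0.1$ | Attenuation | Verdict |
|---|---|---|---|---|---|
| 0.5 | 2 | 0.115 | 0.75 | 1.3× | Almost no smoothing: output ≈ input. Oscillates. |
| 0.9 | 10 | 0.017 | 0.17 | 5.9× | Mild smoothing. Still visibly noisy. |
| 0.99 (paper) | 100 | 0.0016 | 0.016 | 62× | 30+ dB rejection. Tracks the trend, ignores per-batch noise. |
| 0.999 | 1000 | 1.6e-4 | 0.0016 | 620× | Over-smoothed: adapts 10× slower than the paper default. |
| 1.0 | ∞ | 0 | 0 | n/a | Degenerate: pole on the unit circle and input gain $1-\beta = 0$, so the output stays frozen at its initial state. |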

The $\beta = 0.99$ row sits at the knee: enough rejection (about 62×) that the per-batch noise is invisible in the smoothed weight, but not so much that the filter drags behind a real regime shift. Paper main.tex:647 reports that across 40 runs the converged smoothed weight is $\lambda_{\text{health}} = 0.995 \pm 0.001$, a coefficient of variation of about 0.1%, which is consistent with this much rejection.
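The table rows follow from the formulas for $\tau$, the exact −3 dB cutoff, and $|H|$. A short sketch to regenerate them (formatting is illustrative; the degenerate $\beta = 1$ row is skipped because $\tau$ is infinite there):

```python
import math


def table_row(beta, f=0.1):
    """tau, exact -3 dB cutoff (cycles/step) and gain at f for one beta."""
    tau = 1 / (1 - beta)
    f_c = math.acos(1 - (1 - beta) ** 2 / (2 * beta)) / (2 * math.pi)
    w = 2 * math.pi * f
    gain = (1 - beta) / math.sqrt(1 - 2 * beta * math.cos(w) + beta * beta)
    return tau, f_c, gain


print(f"{'beta':>6} {'tau':>6} {'f_c':>9} {'gain@0.1':>9} {'atten':>8}")
for beta in (0.5, 0.9, 0.99, 0.999):
    tau, f_c, gain = table_row(beta)
    print(f"{beta:>6} {tau:>6.0f} {f_c:>9.2e} {gain:>9.4f} {1 / gain:>7.1f}x")
```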

The 30+ dB rule of thumb. A common rule of thumb in control engineering treats a low-pass filter as “adequate” for noise rejection if it provides at least 30 dB of attenuation at the dominant disturbance frequency. $\beta = 0.99$ at $f = 0.1$ gives about −36 dB, a comfortable margin above the threshold. Drop to $\beta = 0.9$ and you get only about −15.5 dB, which fails the rule. The paper's default isn't a guess: any $\beta$ above roughly 0.981 clears 30 dB on the empirical disturbance band, and 0.99 is the smallest value on the conventional 0.9 / 0.99 / 0.999 ladder that does so.
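The threshold $\beta$ can be found by bisection, since the attenuation at a fixed frequency grows monotonically with $\beta$. A minimal sketch, assuming the 30 dB target and $f = 0.1$ cycles/step from the paragraph above (helper names are illustrative):

```python
import math


def attenuation_db(beta, f=0.1):
    """Rejection at f cycles/step in dB (positive number = attenuation)."""
    w = 2 * math.pi * f
    mag = (1 - beta) / math.sqrt(1 - 2 * beta * math.cos(w) + beta * beta)
    return -20 * math.log10(mag)


def smallest_beta(target_db=30.0, f=0.1, lo=0.0, hi=0.9999, iters=60):
    """Bisection: lo always misses the target, hi always meets it."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if attenuation_db(mid, f) >= target_db:
            hi = mid
        else:
            lo = mid
    return hi


b = smallest_beta()
print(f"smallest beta with >= 30 dB at f = 0.1: {b:.4f} (tau ~ {1 / (1 - b):.0f} samples)")
for beta in (0.9, 0.99, 0.999):
    print(f"beta = {beta}: {attenuation_db(beta):.1f} dB")
```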

Python: IIR Filter From Scratch (NumPy)

Build the filter as a small class with three methods: a single-sample step, a vector wrapper, and a frequency-domain probe. Run a unit-step input to observe the canonical 63%/86%/95% milestones at $\tau$, $2\tau$, $3\tau$, and a sinusoidal probe to evaluate the magnitude formula at a paper-realistic frequency.

GABA's EMA As A First-Order IIR Filter (NumPy)
🐍ema_iir_filter.py
1Module docstring

Sets the contract: GABA's EMA recursion (paper eq. 5) is a textbook first-order IIR low-pass digital filter. We give it three names — recursion, transfer function, impulse response — to ground the equivalence in DSP language.

19import numpy as np

Used for ndarray, np.cos, np.sqrt, np.log10. We do NOT use scipy.signal here because the filter is so simple that bringing in a heavy dependency would obscure the recursion.

EXECUTION STATE
📚 numpy = Numerical computing. ndarray, math functions, log scaling.
22class FirstOrderIIRFilter:

A tiny class wrapping the EMA recursion. Three methods — step (one sample), filter (whole signal), magnitude_response (frequency-domain probe).

23docstring

Records that this class is bit-exact to GABA's EMA buffer.

25def __init__(self, beta=0.99, init=0.0):

Constructor. Two arguments, both with paper-canonical defaults.

EXECUTION STATE
⬇ beta = 0.99 = Pole location of the filter. Paper main.tex:347. Larger β → longer memory, slower adaptation, more noise rejection. Always strictly less than 1: at β = 1 the pole sits on the unit circle and the input gain 1 − β vanishes, so the state never changes.
⬇ init = 0.0 = Initial state. GABA initialises ema_weights to 1/K = 0.5 instead; here we use 0.0 so the step response below is clean and shows the rise from zero.
26self.beta = beta

Cache the EMA coefficient. Used inside .step() for the recursion.

27self.y = float(init)

Filter state — a single float for a scalar input. For GABA's K-task version this is a (K,) ndarray instead. Storing as a Python float removes any chance of dtype confusion in the recursion below.

EXECUTION STATE
self.y (init) = 0.0 (scalar). The filter's memory of all past inputs.
28self.tau = 1.0 / (1.0 - beta)

Time constant in samples. For β = 0.99: τ = 100. Paper main.tex:387 quotes this as ~100 steps. The filter state takes about τ samples to absorb 63% of a step input — derived in the §19.2 prose just above.

EXECUTION STATE
self.tau = 100.0 for β = 0.99. The single most important number in the analysis: it tells you in samples how long the filter remembers.
30def step(self, x):

The CORE recursion. Paper eq. 5 verbatim. One sample in, one sample out. State y is mutated in place.

EXECUTION STATE
⬇ input: x = scalar float — one sample of the raw lambda from the closed-form rule (sec. 19.1). For GABA this is ~0.998 in the imbalanced regime.
⬆ returns = scalar float — the smoothed output. Updated state, returned as a convenience.
31docstring

Identifies the recursion as paper eq. 5.

32self.y = self.beta * self.y + (1.0 - self.beta) * x

Convex combination of old state and new measurement. Coefficients sum to 1: β + (1 − β) = 1. With β = 0.99, the filter keeps 99% of yesterday's belief and absorbs only 1% of today's measurement.

EXECUTION STATE
self.beta * self.y = Memory term. 0.99 × y[t−1]. Carries forward almost all of the old state.
(1.0 - self.beta) * x = Innovation term. 0.01 × x[t]. Absorbs only 1% of the new sample.
self.y (after) = Updated filter state.
33return self.y

Return the new state. The caller doesn't strictly need this (state is also accessible via self.y) but it makes vectorised use convenient.

EXECUTION STATE
⬆ return: self.y = Smoothed output for this sample.
35def filter(self, signal):

Vector wrapper. Runs .step() on every sample of an ndarray. scipy.signal.lfilter([1 - beta], [1, -beta], signal) computes the same difference equation (with zero initial state) in compiled code; we keep the loop in Python for transparency.

EXECUTION STATE
⬇ input: signal = ndarray (N,). The full input sequence.
⬆ returns = ndarray (N,). Same length as input. y[t] depends causally on x[0..t] (no look-ahead).
36docstring

Documents the vector form.

37out = np.empty_like(signal, dtype=np.float64)

Allocate output array of the same shape as input but always float64 for headroom on the recursion.

EXECUTION STATE
📚 np.empty_like(a, dtype) = Build an uninitialised ndarray with the same shape as `a` and the given dtype. Faster than np.zeros_like when every entry will be overwritten.
38for t, x in enumerate(signal):

Iterate causally over the input. enumerate gives both the time index and the sample.

LOOP TRACE · 5 iterations
t = 0, x = 1.0 (first step of the unit step)
self.y before = 0.0 (init)
self.y after = 0.99 × 0 + 0.01 × 1 = 0.0100
out[0] = 0.0100
t = 99, x = 1.0
self.y after = 1 − 0.99^100 ≈ 0.6340 (100 samples absorbed: one time constant)
out[99] = 0.6340
t = 100, x = 1.0
self.y after = 1 − 0.99^101 ≈ 0.6376 (101 samples absorbed)
out[100] = 0.6376
t = 200, x = 1.0
self.y after = 1 − 0.99^201 ≈ 0.8674 (close to the two-time-constant value of 86.5%)
t = 300, x = 1.0
self.y after = 1 − 0.99^301 ≈ 0.9514 (close to the three-time-constant value of 95.0%)
39out[t] = self.step(x)

Run the recursion for this sample and write the result.

40return out

Hand back the filtered sequence. Same shape as input; same dtype as the buffer (float64).

42def magnitude_response(self, omega):

Frequency-domain probe. Returns |H(e^{j omega})| — the steady-state amplitude response to a sinusoid at angular frequency omega. Derived from the transfer function H(z) = (1−β)/(1−β z^{-1}) by substituting z = e^{j omega} and taking |·|.

EXECUTION STATE
⬇ input: omega = Angular frequency in radians/sample. Real frequency f in cycles/sample × 2π. Range: 0 to π (Nyquist).
⬆ returns = Magnitude in [0, 1]. 1 at DC (omega=0), monotonically decreasing toward (1−β)/(1+β) at Nyquist (omega=π).
43docstring

Documents the formula for |H| as a function of normalised angular frequency.

44b = self.beta

Local alias to keep the formula readable. Pure cosmetic.

45return (1.0 - b) / np.sqrt(1.0 - 2.0 * b * np.cos(omega) + b * b)

Magnitude of the first-order IIR transfer function evaluated on the unit circle. At omega=0 (DC): denominator = sqrt(1 − 2b + b²) = (1−b), so the magnitude is exactly 1 — DC gain. At omega=π (Nyquist): denominator = (1+b), so the magnitude is (1−b)/(1+b) — minimum gain. Monotonic between.

EXECUTION STATE
📚 np.cos(x), np.sqrt(x) = Element-wise cosine and square root. Work on scalars or ndarrays.
DC gain = 1.0 (the filter passes DC unattenuated). Critical: this is why the EMA correctly tracks the slow trend in the gradient ratio without bias.
Nyquist gain = (1−β)/(1+β) = 0.01/1.99 ≈ 0.005025 ≈ −46 dB at β = 0.99. The very fastest possible noise gets cut by ~200×.
48Demo header

Marker for the runnable step-response demo.

49filt = FirstOrderIIRFilter(beta=0.99)

Instantiate at the paper's default. State y = 0, tau = 100.

51x = np.ones(400)

Unit step input — every sample is 1.0. Lets us read off the canonical 63%/86%/95% milestones at τ, 2τ, 3τ.

EXECUTION STATE
📚 np.ones(shape) = Build an ndarray of all-1s with given shape.
x = ndarray (400,) of 1.0s.
52y = filt.filter(x)

Run the filter on the whole step input. y[0..399] is the step response.

EXECUTION STATE
y = ndarray (400,). y[t] approaches 1.0 exponentially with time constant τ = 100.
54Settling-milestones header

Marker comment — the print statements below correspond to the textbook step-response milestones.

55print(f"y[ 0 ] = {y[0]:.4f}")

Output on the first sample of the step. y[0] = 0.0100 (= 1−β): the filter has absorbed exactly 1% of the unit step on its first sample.

EXECUTION STATE
Output = y[ 0 ] = 0.0100
56print(f"y[100] = {y[100]:.4f} (~63%)")

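Output at one time constant. The continuous-time value is 1 − e^{−1} ≈ 0.6321; the discrete filter gives 1 − 0.99^{101} ≈ 0.6376 because index 100 is the 101st sample of the step.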

EXECUTION STATE
Output = y[100] = 0.6376 (~63%)
57print(f"y[200] = {y[200]:.4f} (~86%)")

Output at two time constants. The continuous-time value is 1 − e^{−2} ≈ 0.8647; the discrete filter gives 1 − 0.99^{201} ≈ 0.8674.

EXECUTION STATE
Output = y[200] = 0.8674 (~86%)
58print(f"y[300] = {y[300]:.4f} (~95%)")

Output at three time constants. The continuous-time value is 1 − e^{−3} ≈ 0.9502; the discrete filter gives 1 − 0.99^{301} ≈ 0.9514. By 3τ the filter has effectively converged for any practical purpose.

EXECUTION STATE
Output = y[300] = 0.9514 (~95%)
60Sinusoidal-disturbance header

Marker for the frequency-domain probe below.

61freq = 0.05

Frequency of the disturbance in cycles/step. 0.05 = one full oscillation every 20 steps. Paper Fig. gradient_dynamics shows the per-batch gradient ratio fluctuating across multiple orders of magnitude every few steps; 0.05 is paper-realistic for the slow envelope of those fluctuations.

EXECUTION STATE
freq = 0.05 cycles/step
62omega = 2 * np.pi * freq

Convert from cycles/step to radians/sample. Standard DSP convention.

EXECUTION STATE
📚 np.pi = Constant: π ≈ 3.14159…
omega = 2π × 0.05 ≈ 0.3142 rad/sample
63mag = filt.magnitude_response(omega)

Steady-state magnitude at this frequency.

EXECUTION STATE
mag = (1−0.99)/sqrt(1 − 2·0.99·cos(0.3142) + 0.99²) ≈ 0.0321: about 3% of the input amplitude survives.
64print(f"\n|H| at f={freq} = {mag:.5f} ({20 * np.log10(mag):.1f} dB)")

Display in linear and dB form. dB is the more common engineering shorthand: 20·log10(0.0321) ≈ −29.9 dB.

EXECUTION STATE
Output = |H| at f=0.05 = 0.03211 (-29.9 dB)
65print(f"attenuation = {1 / mag:.0f}x")

Inverse magnitude — how many times smaller the output amplitude is than the input. Useful for headline interpretation.

EXECUTION STATE
Final output =
y[ 0 ] = 0.0100
y[100] = 0.6376  (~63%)
y[200] = 0.8674  (~86%)
y[300] = 0.9514  (~95%)

|H| at f=0.05 = 0.03211  (-29.9 dB)
attenuation = 31x
```python
"""GABA's EMA buffer as a first-order IIR low-pass filter.

Pure NumPy. Runs the SAME numerical recursion as paper grace/core/gaba.py
but documents it in DSP notation (impulse response, frequency response,
time constant, -3 dB cutoff) so the connection to classical signal
processing is explicit.

Filter law (paper eq. 5):

    y[t] = beta * y[t-1] + (1 - beta) * x[t]

Transfer function (z-domain):

    H(z) = (1 - beta) / (1 - beta * z^{-1})        # one pole at z = beta

Impulse response:

    h[k] = (1 - beta) * beta^k    for k >= 0
"""

import numpy as np


class FirstOrderIIRFilter:
    """Identical numerics to GABA's EMA buffer."""

    def __init__(self, beta=0.99, init=0.0):
        self.beta = beta
        self.y = float(init)             # filter state (a single float for a scalar input)
        self.tau = 1.0 / (1.0 - beta)    # time constant in samples

    def step(self, x):
        """One sample in, one sample out (paper eq. 5)."""
        self.y = self.beta * self.y + (1.0 - self.beta) * x
        return self.y

    def filter(self, signal):
        """Vector form: filter a length-N input array."""
        out = np.empty_like(signal, dtype=np.float64)
        for t, x in enumerate(signal):
            out[t] = self.step(x)
        return out

    def magnitude_response(self, omega):
        """|H(e^{j omega})| for a normalised angular frequency omega."""
        b = self.beta
        return (1.0 - b) / np.sqrt(1.0 - 2.0 * b * np.cos(omega) + b * b)


# ---- Step-response demo: x jumps from 0 to 1 at t = 0 ----
filt = FirstOrderIIRFilter(beta=0.99)

x = np.ones(400)                     # unit step input
y = filt.filter(x)

# Settling milestones for tau = 100 steps (index t has absorbed t + 1 samples):
print(f"y[ 0 ] = {y[0]:.4f}")        # output on the first sample of the step
print(f"y[100] = {y[100]:.4f}  (~63%)")
print(f"y[200] = {y[200]:.4f}  (~86%)")
print(f"y[300] = {y[300]:.4f}  (~95%)")

# ---- Sinusoidal disturbance demo at f = 0.05 cycles/step (paper-realistic) ----
freq = 0.05
omega = 2 * np.pi * freq
mag = filt.magnitude_response(omega)
print(f"\n|H| at f={freq} = {mag:.5f}  ({20 * np.log10(mag):.1f} dB)")
print(f"attenuation = {1 / mag:.0f}x")
```

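The output is the discrete step response $1 - \beta^{t+1}$: 1.00% absorbed on the first sample, 63.8% at $t = \tau = 100$, 86.7% at $2\tau$, 95.1% at $3\tau$. Each is slightly above the continuous-time milestones (63.2%, 86.5%, 95.0%), mostly because index $t$ has already absorbed $t+1$ samples. The frequency probe at $f = 0.05$ gives −29.9 dB, a 31× attenuation at a paper-realistic disturbance frequency.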
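The one-sample offset is easy to confirm. This sketch (plain Python that mirrors `FirstOrderIIRFilter.filter` rather than importing it) compares each output index with the closed form $1 - \beta^{t+1}$.

```python
beta = 0.99

# Run the recursion on a unit step the same way FirstOrderIIRFilter.filter does:
# out[t] is written after x[t] has been absorbed, so out[0] already holds 1 - beta.
y, out = 0.0, []
for _ in range(400):
    y = beta * y + (1.0 - beta) * 1.0
    out.append(y)

for t in (0, 99, 100, 200, 300):
    closed = 1.0 - beta ** (t + 1)      # t + 1 samples absorbed by index t
    print(f"out[{t:3d}] = {out[t]:.4f}   closed form 1 - beta**{t + 1} = {closed:.4f}")
```

Reading the milestones at indices 99, 199, 299 instead gives exactly $1 - \beta^{n}$ for $n = 100, 200, 300$ samples.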

PyTorch: The EMA Branch Of GABALoss Verbatim

The production code in grace/core/gaba.py writes the EMA recursion directly inside GABALoss.forward_k. To isolate it as a standalone module, we extract just the recursion and the buffer-registration logic. The result is bit-exact to the EMA branch of GABALoss when the floor is disabled.
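Before the PyTorch version, the same two-regime experiment (100 uniform steps, then a [0.002, 0.998] raw input) can be replayed in plain float64 Python as a cross-check of the trajectory values printed further down. This is an illustrative sketch, not the production code; PyTorch computes in float32 by default, so its printout may differ from this one in the last printed digit.

```python
beta = 0.99
lam = [0.5, 0.5]                                  # uniform 1/K start, K = 2
history = {}
for t in range(1, 201):
    raw = [0.5, 0.5] if t <= 100 else [0.002, 0.998]
    # Same recursion as paper eq. 5, applied per task.
    lam = [beta * l + (1.0 - beta) * r for l, r in zip(lam, raw)]
    history[t] = tuple(lam)

for t in (1, 100, 101, 150, 200):
    print(f"step {t:3d}: lambda_hat = ({history[t][0]:.6f}, {history[t][1]:.6f})")
```

After $n$ imbalanced steps the first weight is $0.002 + (0.5 - 0.002)\,\beta^{n}$, which is the closed form the printed values should follow.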

The EMA Buffer Inside Paper GABALoss (PyTorch)
🐍gaba_ema_buffer.py
1Module docstring

Pins this code to the EMA branch of paper grace/core/gaba.py:GABALoss.forward_k. Bit-exact when min_weight = 0.

7import torch

PyTorch core. We need torch.tensor, torch.full, torch.manual_seed, the nn.Module buffer machinery (register_buffer), and .detach() to keep gradients from leaking into the filter state.

8import torch.nn as nn

Alias for nn.Module — base class for any PyTorch module.

11class GabaEMABuffer(nn.Module):

Standalone EMA buffer extracted from GABALoss. Inheriting from nn.Module gives us .to(device), .state_dict() / .load_state_dict() support for checkpointing, and the buffer-registration mechanism.

12docstring

Records the paper-eq. 5 mapping and the bit-exact relationship to the production code.

19def __init__(self, beta=0.99, n_tasks=2):

Two arguments; both with paper-canonical defaults. No proportional-controller hyperparameters here — this class handles only the EMA stabiliser.

EXECUTION STATE
⬇ beta = 0.99 = Pole location. Paper main.tex:347.
⬇ n_tasks = 2 = Sets the (K,) shape of the buffer.
20super().__init__()

Required to initialise nn.Module's internal _parameters / _buffers dicts.

21self.beta = beta

Cache β as a Python float (NOT a buffer). Storing as a plain attribute is fine because β is a fixed hyperparameter — no need for it to migrate with .to(device).

22self.n_tasks = n_tasks

Cache K. Used to size the lambda_hat buffer below.

24self.register_buffer("lambda_hat", torch.full((n_tasks,), 1.0 / n_tasks))

Register the smoothed weight vector as a buffer. .register_buffer ensures: (a) lambda_hat moves with .to(device), (b) it appears in state_dict() for checkpointing, (c) the optimiser does NOT see it (it's not a learnable parameter).

EXECUTION STATE
📚 register_buffer(name, tensor) = nn.Module method. Adds a non-learnable tensor to the module under the given name. Distinct from register_parameter (learnable) and from plain attribute assignment (which doesn't migrate to GPU).
📚 torch.full(shape, fill) = Build a tensor of given shape filled with the given scalar. Equivalent to NumPy's np.full.
lambda_hat (init) = Tensor (2,) = [0.5, 0.5] — the equal-share starting point. Paper Algorithm 1 line 2.
26@property

Decorator marking tau as a derived attribute, computed on access. Lets the user write ema.tau instead of ema.tau() — and forces tau to recompute if the user ever changes self.beta.

EXECUTION STATE
📚 @property = Built-in Python decorator. Wraps a method so it's accessed as an attribute. Useful for derived quantities that should always be up-to-date.
27def tau(self):

The time constant in samples — derived, not stored.

EXECUTION STATE
⬆ returns = 1 / (1 − β). For β = 0.99: 100.0.
28docstring

Documents the time-constant interpretation.

29return 1.0 / (1.0 - self.beta)

The textbook formula. Critical: τ blows up as β → 1, so β must be strictly < 1 for the filter to be defined.

31def forward(self, raw):

One filter step. nn.Module's standard forward signature lets us call ema(raw) directly. Internally PyTorch routes that to forward(raw) and adds hook plumbing.

EXECUTION STATE
⬇ input: raw = Tensor (K,). The closed-form weight from the inner P-controller (sec. 19.1). For the imbalanced regime: raw ≈ [0.002, 0.998].
⬆ returns = Tensor (K,). The smoothed lambda_hat after this step.
32docstring

Documents what the forward step does.

33# In-place would silently break autograd if raw is grad-tracked.

Inline comment marking the autograd-safety reason for the next line.

34new_state = self.beta * self.lambda_hat + (1.0 - self.beta) * raw.detach()

The recursion (paper eq. 5) on PyTorch tensors. raw.detach() severs the autograd graph from raw — we don't want gradients to flow from future training steps back into the filter state via the EMA dependency. The paper code does the same.

EXECUTION STATE
📚 .detach() = Returns a tensor that shares storage with the original but is excluded from autograd. Critical for filter state: without it, every training step would grow the autograd graph indefinitely.
self.beta * self.lambda_hat = 0.99 × [0.5, 0.5] (first imbalanced call, t = 101) = [0.495, 0.495]. Memory term.
(1 - self.beta) * raw.detach() = 0.01 × [0.002, 0.998] = [2e-5, 0.00998]. Innovation term — only 1% of the new measurement is absorbed.
new_state = Tensor (K,) = [0.49502, 0.50498] on the FIRST imbalanced step. After 100 such steps: ≈ [0.184, 0.816]. After 3τ = 300 steps: ≈ [0.026, 0.974], within 5% of the remaining gap, i.e. tracking the new target.
35# detach() is intentional…

Inline comment explaining the detach contract.

36# Same convention as paper grace/core/gaba.py.

Cross-references the paper code.

37self.lambda_hat = new_state.detach()

Write back to the buffer. Buffers don't have a setter API distinct from attribute assignment; PyTorch tracks them by name. The redundant .detach() guards against the rare case where new_state still carries a grad_fn (e.g. if a user subclasses and forgets to detach raw).

EXECUTION STATE
self.lambda_hat (after) = Updated state. Persists for the next call.
38return self.lambda_hat

Hand back the smoothed weights. The trainer (sec. 19.3) feeds these into the floor + renorm step.

41Demo header

Marker for the runnable smoke test.

42torch.manual_seed(0)

Make the test deterministic. The recursion itself has no randomness, so the seed does not change any number printed here; it is pinned as a reproducibility habit in case the demo is extended with random inputs (torch.randn, torch.randint).

EXECUTION STATE
📚 torch.manual_seed(seed) = Set the global CPU PRNG seed. Important for reproducibility of any tensor that depends on torch.randn / torch.randint.
44ema = GabaEMABuffer(beta=0.99, n_tasks=2)

Instantiate at paper defaults. State [0.5, 0.5], τ = 100.

45print(f"tau (samples) = {ema.tau}")

Sanity check that the @property works.

EXECUTION STATE
Output = tau (samples) = 100.0
47Trajectory comment

Sets up the two-regime smoke test.

48Trajectory comment 2

Documents the 0.5 → 0.998 step that mirrors the paper's transition out of warmup into the active regime.

49trajectory = []

Collect (t, lambda_hat[0], lambda_hat[1]) tuples for inspection.

50for t in range(1, 201):

Run 200 EMA steps. The first 100 hold the equal-share input (mimicking warmup); the next 100 inject the imbalanced 0.998-target.

LOOP TRACE · 5 iterations
t = 1 (first warmup-like step)
raw = [0.5, 0.5]
lambda_hat = 0.99×[0.5,0.5] + 0.01×[0.5,0.5] = [0.5, 0.5] — no change
t = 100 (last warmup-like step)
lambda_hat = [0.5, 0.5] — still uniform
t = 101 (FIRST imbalanced step)
raw = [0.002, 0.998] — closed-form rule's output for 500× imbalance
lambda_hat = 0.99×[0.5, 0.5] + 0.01×[0.002, 0.998] = [0.49502, 0.50498]
t = 150 (50 steps into imbalanced regime)
lambda_hat = ≈ [0.303, 0.697]: about 40% of the way to the new target after τ/2 = 50 imbalanced steps (1 − 0.99^50 ≈ 0.395)
t = 200 (~τ steps into imbalanced regime)
lambda_hat = ≈ [0.184, 0.816]: one time constant of settling, 0.002 + (0.5 − 0.002)·0.99^100 ≈ 0.1843
if t <= 100:

First 100 steps simulate the warmup regime, where the closed-form rule would emit the uniform 1/K.

raw = torch.tensor([0.5, 0.5])

Equal-share raw input. The EMA must not drift while this is the input.

else:

Steps 101+ enter the imbalanced regime.

raw = torch.tensor([0.002, 0.998])

Paper-realistic raw input from the closed-form rule under 500× imbalance.

out = ema(raw)

Run one filter step; returns the updated lambda_hat. Equivalent to ema.forward(raw): nn.Module's __call__ runs any registered hooks, then calls forward.

trajectory.append((t, out[0].item(), out[1].item()))

Record the smoothed values for plotting or printing. .item() converts a one-element tensor to a plain Python number with no autograd history, copying it to the CPU if needed.

Print-milestones comment (# Print key milestones)

Marker for the inspection prints.

for t in [1, 100, 101, 150, 200]:

Print only the most informative steps: first sample, last warmup step, first imbalanced step, mid-trajectory, end.

_, lr, lh = trajectory[t - 1]

Unpack the recorded tuple. The leading _ discards the timestep (we already have t in the loop variable).

print(f"step {t:3d}: lambda_hat = ({lr:.6f}, {lh:.6f})")

Format-print the milestone, right-aligning the step number in a three-character field.

Final output (the buffer is float32, so the sixth decimal of the step 150 and step 200 values may differ by one):
tau (samples) = 99.99999999999991
step   1: lambda_hat = (0.500000, 0.500000)
step 100: lambda_hat = (0.500000, 0.500000)
step 101: lambda_hat = (0.495020, 0.504980)
step 150: lambda_hat = (0.303293, 0.696707)
step 200: lambda_hat = (0.184284, 0.815716)
```python
"""The EMA branch of paper grace/core/gaba.py:GABALoss.forward_k.

Same recursion as the NumPy class above, but on PyTorch tensors using
register_buffer so the state moves with .to(device) and survives
checkpoint save/load.
"""

import torch
import torch.nn as nn


class GabaEMABuffer(nn.Module):
    """The EMA stabiliser of GABA (paper eq. 5).

    Holds the smoothed weight vector lambda_hat as a (K,) buffer.
    No learnable parameters. Bit-exact to the EMA branch of paper
    GABALoss.forward_k when min_weight = 0.
    """

    def __init__(self, beta=0.99, n_tasks=2):
        super().__init__()
        self.beta = beta
        self.n_tasks = n_tasks
        # Initialise at uniform 1/K (paper Algorithm 1 line 2).
        self.register_buffer("lambda_hat", torch.full((n_tasks,), 1.0 / n_tasks))

    @property
    def tau(self):
        """Time constant in samples (1 / (1 - beta))."""
        return 1.0 / (1.0 - self.beta)

    def forward(self, raw):
        """One filter step. raw is the (K,) closed-form weight from sec. 19.1."""
        # Out-of-place update: an in-place add_ of a grad-tracked raw would
        # silently pull the buffer into the autograd graph.
        new_state = self.beta * self.lambda_hat + (1.0 - self.beta) * raw.detach()
        # detach() is intentional: the EMA state is optimiser-style running
        # state, not a quantity we backprop through. Same convention as
        # paper grace/core/gaba.py.
        self.lambda_hat = new_state.detach()
        return self.lambda_hat


# ---- Smoke test: 200-step trajectory tracking the paper's 0.5 -> 0.998 jump ----
torch.manual_seed(0)

ema = GabaEMABuffer(beta=0.99, n_tasks=2)
print(f"tau (samples) = {ema.tau}")

# Steps 1-100: warmup-style raw input still 0.5
# Steps 101-200: imbalanced regime; closed-form rule outputs (0.002, 0.998)
trajectory = []
for t in range(1, 201):
    if t <= 100:
        raw = torch.tensor([0.5, 0.5])
    else:
        raw = torch.tensor([0.002, 0.998])
    out = ema(raw)
    trajectory.append((t, out[0].item(), out[1].item()))

# Print key milestones
for t in [1, 100, 101, 150, 200]:
    _, lr, lh = trajectory[t - 1]
    print(f"step {t:3d}: lambda_hat = ({lr:.6f}, {lh:.6f})")
```

The smoke test reproduces the paper's warmup-to-active transition: 100 steps of equal-share input leave the buffer untouched (it was already at 0.5); the next 100 steps push the second component from 0.5 toward 0.998 along the canonical exponential trajectory, reaching about 0.184/0.816 at $t = \tau = 100$ steps after the transition. That is $1 - 0.99^{100} \approx 63.4\%$ of the step, the discrete-time counterpart of the textbook $1 - e^{-1} \approx 63.2\%$ settling.
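
The trajectory can also be checked in closed form: for a step input that switches after t = 100, the recursion unrolls to λ̂(t) = target + (0.5 − target)·β^(t − 100). The sketch below is a minimal pure-Python version (double precision, first component only, no PyTorch) that runs the recursion next to that formula. Because the PyTorch buffer above is float32, the two agree with its printout only to roughly five or six decimals.

```python
BETA = 0.99
START, TARGET = 0.5, 0.002  # small-share component before / after the switch
SWITCH = 100                # last warmup-like step


def recursion(t_end):
    """Run lam <- beta*lam + (1-beta)*raw for t = 1..t_end; return {t: lam}."""
    lam, history = START, {}
    for t in range(1, t_end + 1):
        raw = START if t <= SWITCH else TARGET
        lam = BETA * lam + (1.0 - BETA) * raw
        history[t] = lam
    return history


def closed_form(t):
    """Unrolled step response: geometric decay from START toward TARGET."""
    n = max(0, t - SWITCH)  # steps spent in the imbalanced regime
    return TARGET + (START - TARGET) * BETA ** n


hist = recursion(200)
for t in [1, 100, 101, 150, 200]:
    print(f"step {t:3d}: recursion = {hist[t]:.6f}, closed form = {closed_form(t):.6f}")

# Fraction of the 0.5 -> 0.002 step covered after exactly tau = 100 steps.
print(f"settled fraction at tau: {1 - BETA ** 100:.4f}")
```

At steps 150 and 200 both columns come out near 0.303293 and 0.184284, and the settled fraction at τ is about 0.634 rather than the continuous-time 0.632.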

The Same Filter In Other Domains

The first-order IIR low-pass is one of the most widely deployed algorithms in engineering. It shows up under many names because the recursion is the same regardless of what is being filtered.

| Domain | Filter purpose | β-equivalent | Time constant |
|---|---|---|---|
| GABA EMA buffer (this paper) | Smooth per-batch gradient-ratio noise | β = 0.99 | τ = 100 steps |
| Adam optimiser (Kingma & Ba 2015), 1st moment | Smooth per-step gradient direction | β₁ = 0.9 | τ ≈ 10 steps |
| Adam optimiser, 2nd moment | Smooth per-step squared gradient | β₂ = 0.999 | τ ≈ 1000 steps |
| BatchNorm running statistics | Smooth per-batch activation mean/var for inference | β = 0.99 (Keras default, momentum = 0.99); β = 0.9 (PyTorch default, momentum = 0.1) | τ ≈ 100 batches (Keras); τ ≈ 10 batches (PyTorch) |
| BYOL (Grill et al. 2020) / MoCo (He et al. 2020) target encoder | EMA-track the online network for stable contrastive targets | β = 0.99 to 0.9999 | τ = 100 to 10,000 steps |
| Polyak averaging (1992) for trained NN evaluation | Smooth final epochs of weights for evaluation | β = 0.999 | τ = 1000 steps |
| RL Q-target soft update (DDPG-family actor-critic methods, e.g. TD3, SAC) | Smooth target Q-network parameters for stable bootstrapping | soft-update rate 0.005 ⇒ β = 0.995 | τ = 200 updates |
| RC low-pass audio filter (analog) | Reject high-frequency hiss while passing voice | β = exp(−Δt / RC) | τ = RC seconds |
| Pulse-oximeter heart-rate display | Smooth per-pulse rate to a stable beats-per-minute reading | β ≈ 0.95 | τ ≈ 20 pulses |
| Smartphone accelerometer gravity vector | Estimate gravity by low-pass filtering (g + linear acceleration) | β ≈ 0.9 | τ ≈ 10 samples |

The pattern is “recursion + state buffer + unity DC gain” (the two coefficients β and 1 − β sum to 1, so a constant input passes through unchanged). The specific value of $\beta$ in each row encodes a domain-specific tradeoff between adaptation speed and noise rejection, but the algorithm and the analysis are identical. Recognising the filter in one row gives you the time constant, the cutoff frequency, and the noise-rejection figure in every other row for free.
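
The “for free” claim is easy to make concrete. The pure-Python sketch below (the helper `filter_metrics` is hypothetical, written only for this illustration) takes any row's β and returns the time constant, the exact −3 dB cutoff of the discrete filter, and the attenuation at a chosen frequency. It uses 0.05 cycles/step as the reference because that is the low end of the disturbance band quoted in the takeaways.

```python
import cmath
import math


def filter_metrics(beta, f=None):
    """Time constant, exact -3 dB cutoff and optional attenuation for the
    first-order IIR low-pass y[n] = beta * y[n-1] + (1 - beta) * x[n].

    Frequencies are in cycles per sample (Nyquist = 0.5).
    """
    alpha = 1.0 - beta
    tau = 1.0 / alpha  # time constant in samples
    # |H(e^{jw})|^2 = alpha^2 / (1 - 2*beta*cos(w) + beta^2); setting it to 1/2
    # gives cos(w_c) = 1 - alpha^2 / (2*beta).
    f_c = math.acos(1.0 - alpha ** 2 / (2.0 * beta)) / (2.0 * math.pi)
    atten_db = None
    if f is not None:
        h = alpha / (1.0 - beta * cmath.exp(-2j * math.pi * f))  # H at frequency f
        atten_db = -20.0 * math.log10(abs(h))  # positive number = attenuation
    return tau, f_c, atten_db


rows = [("GABA EMA", 0.99), ("Adam 1st moment", 0.9),
        ("Adam 2nd moment", 0.999), ("RL soft update", 0.995)]
for name, beta in rows:
    tau, f_c, att = filter_metrics(beta, f=0.05)
    print(f"{name:16s} beta={beta:<6} tau={tau:8.1f}  f_c={f_c:.2e}  "
          f"atten@0.05={att:5.1f} dB")
```

For β = 0.99 it gives τ = 100 samples, f_c ≈ 1.6 × 10⁻³ cycles/step, and about 29.9 dB of attenuation at 0.05 cycles/step (about 35.8 dB at 0.1). The exact discrete cutoff differs only slightly from the continuous-time estimate 1/(2πτ) ≈ 1.59 × 10⁻³.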

Cross-domain analogy: BatchNorm running stats. BatchNorm tracks the running mean and variance of activations with an EMA. At the Keras default (momentum = 0.99, i.e. $\beta = 0.99$) the numerics and the time constant are identical to GABA's weight smoother; PyTorch's default momentum = 0.1 weights the new batch instead, which corresponds to $\beta = 0.9$ and $\tau \approx 10$ batches. The motivation is identical either way: per-batch statistics are noisy and you want the steady trend. When you train a CNN on ImageNet you rely on this filter every step without thinking about it. GABA is applying exactly the same trick to a different signal: per-batch gradient ratios instead of per-batch activation statistics.
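
The opposite momentum conventions are where this analogy usually trips people up, so it helps to write both running-mean updates out once. This is a pure-Python sketch: the function names are illustrative, not framework APIs, and the two formulas follow the update rules stated in the PyTorch and Keras BatchNorm documentation.

```python
def ema_step(state, x, beta):
    """Generic first-order IIR low-pass (EMA) step."""
    return beta * state + (1.0 - beta) * x


def torch_style_update(running, batch_stat, momentum=0.1):
    # PyTorch BatchNorm convention: momentum weights the NEW batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat


def keras_style_update(running, batch_stat, momentum=0.99):
    # Keras BatchNormalization convention: momentum weights the OLD running value.
    return momentum * running + (1.0 - momentum) * batch_stat


beta_torch = 1.0 - 0.1  # PyTorch default momentum = 0.1  ->  beta = 0.9
beta_keras = 0.99       # Keras default momentum = 0.99   ->  beta = 0.99
for name, beta in [("PyTorch default", beta_torch), ("Keras default", beta_keras)]:
    print(f"{name}: beta = {beta:.2f}, tau = {1.0 / (1.0 - beta):.0f} batches")
```

With PyTorch's defaults the running statistics therefore average over roughly 10 batches, a tenth of GABA's window; passing momentum=0.01 to a PyTorch BatchNorm layer gives the β = 0.99 filter.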

Pitfalls In Tuning Or Replacing The EMA

Pitfall 1: setting $\beta = 1$ thinking ‘more smoothing is always better’. At $\beta = 1$ the input gain $1 - \beta$ vanishes, so the buffer ignores every new sample and freezes at its initial value 0.5 forever: infinite memory, zero adaptation. The pole sits exactly on the unit circle at $z = 1$, the marginally stable case. ALWAYS pick $\beta < 1$; the closer to 1, the slower the adaptation but the better the noise rejection.
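
Pitfall 1 is easy to see numerically. A minimal pure-Python sketch (the function `run_ema` is hypothetical) feeds a scalar EMA a persistent imbalanced input at β = 1 and at β = 0.99:

```python
def run_ema(beta, inputs, init=0.5):
    """Scalar EMA over a list of inputs, starting from the uniform share 0.5."""
    lam = init
    for x in inputs:
        lam = beta * lam + (1.0 - beta) * x
    return lam


inputs = [0.002] * 500            # a persistent, strongly imbalanced raw weight
frozen = run_ema(1.0, inputs)     # input gain 1 - beta = 0: samples never enter
tracking = run_ema(0.99, inputs)  # tau = 100: five time constants of tracking
print(frozen, tracking)
```

With β = 1 the result is exactly the initial 0.5; with β = 0.99 only 0.99^500 ≈ 0.0066 of the original gap remains, leaving the state within about 0.0033 of the target.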
Pitfall 2: bias correction (Adam-style) is not needed here. Adam divides its EMA by $1 - \beta^t$ because its moment buffers start at zero; the correction only matters while $\beta^t$ is non-negligible, i.e. for the first few time constants. GABA does not need it, because the buffer is initialised at 1/K (the equal-share value, which IS the right answer during warmup) and the warmup branch (sec. 18.4) doesn't even update the buffer for the first 100 steps. On a buffer that does not start at zero, bias correction actively introduces error: at $t = 1$ it scales the estimate up by $1/(1 - \beta) = 100\times$. Don't copy-paste it from an Adam implementation.
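
The size of the Pitfall 2 error is clearest at t = 1. This pure-Python sketch (not the paper's or Adam's code) compares three cases: the uncorrected 1/K-initialised buffer, the same buffer with an Adam-style correction pasted in, and a zero-initialised buffer where the correction does its intended job.

```python
BETA, K = 0.99, 2
raw = 1.0 / K  # the raw weight already sits at the equal share

# GABA-style buffer: initialised at 1/K, no bias correction.
gaba_state = BETA * (1.0 / K) + (1.0 - BETA) * raw
# Copy-pasted Adam correction applied to that same non-zero-initialised buffer.
miscorrected = gaba_state / (1.0 - BETA ** 1)

# Adam-style buffer: initialised at 0, where the correction is appropriate.
adam_state = BETA * 0.0 + (1.0 - BETA) * raw
corrected = adam_state / (1.0 - BETA ** 1)

print(gaba_state, miscorrected, corrected)
```

The uncorrected GABA buffer and the corrected zero-initialised buffer both read 0.5, while the mis-applied correction reports about 50, a 100× overshoot.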
Pitfall 3: putting the filter before the closed form instead of after. A common refactor mistake is to smooth the gradient norms $g_i$ first and then run the closed-form rule. The result is mathematically different: the EMA is linear but the closed-form rule is not, so the two operations do not commute, and smoothed-then-divided $\ne$ divided-then-smoothed. The paper smooths AFTER the inversion (paper Algorithm 1, lines 8–9 in order) precisely because we want the smoothed quantity to live on the simplex; smoothing the unbounded $g_i$ first would put the filter state on the raw gradient-norm scale, where samples span several orders of magnitude and the linear average is dominated by the largest ones, and would break the bounded-weight guarantee of sec. 19.3.
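
A small numeric sketch shows why the order in Pitfall 3 matters. The closed-form rule of §19.1 is not reproduced in this section, so the sketch assumes a hypothetical inverse-norm rule, λ_i ∝ 1/g_i. That rule does give the (0.002, 0.998) split at a 500× ratio, but the point here is only that a linear filter and a nonlinear map do not commute.

```python
def closed_form(g):
    """HYPOTHETICAL stand-in for the sec. 19.1 rule: weights proportional to 1/g_i."""
    inv = [1.0 / gi for gi in g]
    total = sum(inv)
    return [v / total for v in inv]


def ema(seq, beta=0.99):
    """EMA over a list of equal-length vectors, started at the first sample."""
    state = list(seq[0])
    for x in seq:
        state = [beta * s + (1.0 - beta) * xi for s, xi in zip(state, x)]
    return state


# Noisy gradient norms: task 0 alternates between 10x and 1000x task 1's norm.
norms = [[10.0, 1.0] if t % 2 == 0 else [1000.0, 1.0] for t in range(1000)]

invert_then_smooth = ema([closed_form(g) for g in norms])  # order used by the paper
smooth_then_invert = closed_form(ema(norms))               # refactor mistake
print(invert_then_smooth, smooth_then_invert)
```

Under this stand-in, the paper's order gives a small-task weight of about 0.046, the average of the two per-batch weights (≈0.091 and ≈0.001). Smoothing the norms first gives about 0.002, the weight for the average norm (≈505). Both results are valid simplex points, yet they differ by more than an order of magnitude.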
Pitfall 4: forgetting .detach() on the EMA write-back in PyTorch. Without raw.detach() in the recursion, every training step grows the autograd graph by one more EMA-update node; over a few hundred steps the graph balloons to gigabytes and OOM kills the process. Paper grace/core/gaba.py:94 uses .detach() for exactly this reason; sec. 18.5 calls this out as Pitfall 3 in the GABALoss wiring discussion.
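
The leak in Pitfall 4 can be made visible without waiting for an OOM. This minimal PyTorch sketch is not the paper's code. It assumes the closed-form output carries autograd history and mimics that with a fresh grad-tracked leaf at every step. It then checks whether a backward pass from step 300 still reaches step 1.

```python
import torch

BETA = 0.99


def run(steps, detach):
    """Run the EMA for `steps` steps; each step gets a fresh grad-tracked raw input."""
    state = torch.full((2,), 0.5)  # buffer-like state, starts without grad
    first_raw = None
    for _ in range(steps):
        # Stand-in for the closed-form output, which may carry autograd history.
        raw = torch.tensor([0.002, 0.998], requires_grad=True)
        if first_raw is None:
            first_raw = raw
        x = raw.detach() if detach else raw
        state = BETA * state + (1.0 - BETA) * x
    return state, first_raw


leaky, first_raw = run(300, detach=False)
print(leaky.requires_grad)  # the state has joined the autograd graph
leaky.sum().backward()      # ...and backprop still reaches step 1's input,
print(first_raw.grad)       # so every one of the 300 update nodes was kept alive

safe, _ = run(300, detach=True)
print(safe.requires_grad)   # nothing is retained between steps
```

The gradient reaching step 1 is (1 − β)·β^299 ≈ 5 × 10⁻⁴ per component: small, but its existence means the whole 300-step chain was held in memory.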
Pitfall 5: comparing $\hat{\lambda}$ to $\lambda^{*}$ after the floor and expecting equality. The smoothed weight $\hat{\lambda}$ can fall below the floor $\lambda_{\min} = 0.05$ at the 500× imbalance (the controller's raw output is ~0.002, and the EMA tracks toward that), in which case the floor + renorm step (sec. 19.3) clips it back up to 0.05. So $\lambda^{*} \ge \hat{\lambda}$ in the small-share dimension and $\lambda^{*} \le \hat{\lambda}$ in the large-share dimension. Paper get_gradient_stats() exposes both keys (raw_weight_health and post-floor) so you can inspect the gap during training.
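
For two tasks, the Pitfall 5 inequalities can be checked directly. The floor + renorm step itself belongs to §19.3 and is not reproduced here. This pure-Python sketch substitutes a hypothetical clamp into [λ_min, 1 − λ_min], which preserves the sum only when K = 2.

```python
LAMBDA_MIN = 0.05


def floor_k2(lam_hat):
    """HYPOTHETICAL two-task stand-in for the sec. 19.3 floor + renorm step.

    Clamps each weight into [LAMBDA_MIN, 1 - LAMBDA_MIN]. For K = 2 with weights
    that already sum to 1, this keeps the sum at 1 without a separate renorm.
    """
    lo, hi = LAMBDA_MIN, 1.0 - LAMBDA_MIN
    return [min(max(w, lo), hi) for w in lam_hat]


# EMA state 300 steps into the 500x-imbalance regime (closed-form decay from 0.5).
small = 0.002 + (0.5 - 0.002) * 0.99 ** 300
lam_hat = [small, 1.0 - small]
lam_star = floor_k2(lam_hat)
print(lam_hat, lam_star)
```

Here λ̂ ≈ (0.026, 0.974) has drifted below the floor, and the clamped λ* = (0.05, 0.95) sits above it in the small-share slot and below it in the large-share slot, exactly the gap the two logged keys expose.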

Takeaway

  • The EMA recursion IS a first-order IIR low-pass filter, character for character. Paper eq. 5 and the textbook DSP recursion are the same equation; the EMA buffer is a digital realisation of the same RC low-pass topology audio engineers have been using since 1899.
  • Time constant $\tau = 1/(1-\beta) = 100$ samples. Step response hits 63%/86%/95% at $\tau, 2\tau, 3\tau$. Effective averaging window is ~τ samples.
  • Cutoff frequency $f_c \approx 1.6 \times 10^{-3}$ cycles/step. Magnitude rolls off at −20 dB/decade past the cutoff. At paper-realistic disturbance frequencies (~0.05 to 0.1 cycles/step) the attenuation is roughly 30 to 36 dB (≈29.9 dB at 0.05 and ≈35.8 dB at 0.1 cycles/step), right at the 30 dB engineering rule of thumb.
  • The default $\beta = 0.99$ is the lowest of the usual round-number choices that meets the 30 dB rule (to within a fraction of a dB) across the paper's observed disturbance band; $\beta = 0.98$ already drops to ≈24 dB at 0.05 cycles/step. Going higher slows adaptation unnecessarily; going lower fails the rule and re-introduces the oscillation regime.
  • The same filter is everywhere. Adam moments, BatchNorm running stats, BYOL/MoCo targets, Polyak averaging, RL target networks, RC analog audio filters — the recursion, the time constant, and the Bode plot are identical. Recognising the EMA in GABA gives you all of those analyses for free.
  • Coming next. §19.3 identifies the floor $\lambda_{\min} = 0.05$ as the anti-windup mechanism that completes the controller, and proves the bounded-weight guarantee $\lambda^{*} \in [\lambda_{\min}, 1 - \lambda_{\min}]$ that paper main.tex:387 calls ‘a stability property absent from loss-based approaches.’