Chapter 18

Minimum Floor and Renormalization

The GABA Algorithm

The Pressure-Relief Valve

Industrial boilers have a small spring-loaded valve on top. Ninety-nine percent of the time it does nothing: the boiler operates far below the pressure threshold and the valve is invisible. But on the rare day the controller misreads a sensor and pressure climbs toward the rupture limit, the valve flicks open and dumps steam, preventing a catastrophic failure that no controller logic could recover from in time. The valve is cheap, simple, and always armed.

GABA needs the same component. The closed form $\lambda^*_i = g_j / (g_i + g_j)$ is mathematically beautiful but operationally brittle: if $g_{\text{health}}$ ever drops to zero (perfect classification on a batch), $\lambda_{\text{rul}}$ snaps to 0 and the regression head stops receiving any gradient. The optimiser silently abandons one task. The floor is the pressure-relief valve.
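To see the brittleness concretely, here is a minimal NumPy sketch of the K=2 closed form with no EMA and no floor. The function name `closed_form_weights` is illustrative and is not the paper's code. The gradient norms (5.0 and 0.01) match the 500x imbalance used later in this section.

```python
import numpy as np


def closed_form_weights(g_rul: float, g_health: float) -> np.ndarray:
    """Hypothetical K=2 closed form: lambda_i = g_j / (g_i + g_j), no EMA, no floor.

    Returns [lambda_rul, lambda_health]; each task gets the OTHER task's share
    of the total gradient norm.
    """
    total = g_rul + g_health
    return np.array([g_health / total, g_rul / total])


# A typical batch: the RUL gradient is 500x the health gradient.
print(closed_form_weights(5.0, 0.01))  # RUL weight is about 0.002
# A perfectly classified batch: g_health = 0 switches the RUL task off entirely.
print(closed_form_weights(5.0, 0.0))   # [0. 1.]
```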

Paper Eq. 6: $\lambda^*_i = \max(\hat{\lambda}_i, \lambda_{\min}) / \sum_{j} \max(\hat{\lambda}_j, \lambda_{\min})$ with $\lambda_{\min} = 0.05$. The paper (main.tex:387) calls this an ‘anti-windup mechanism ensuring no task is fully suppressed’. Two operations: floor at $\lambda_{\min}$, then renormalise to the simplex.

What Happens Without A Floor

Run a 1,000-step training trace with the §17.3 closed form and the §18.2 EMA but NO floor. Three pathologies emerge, each of which silently degrades training:

  • Catastrophic suppression. A streak of batches where $g_{\text{health}}$ is unusually small (e.g. all 64 samples are easy ‘Normal’ class) drives $\hat{\lambda}_{\text{rul}}$ near zero. The next 100+ steps see no RUL gradient. When the easy streak ends, the regression head has regressed.
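  • Asymmetric recovery. Once $\hat{\lambda}_{\text{rul}}$ is near zero, the EMA closes only a fraction $(1-\beta)$ of the remaining gap to its target each step. With $\beta = 0.99$ the time constant is $1/(1-\beta) = 100$ steps, and closing 99% of the gap takes $\ln 100 / (-\ln 0.99) \approx 460$ steps, however small the starting value (e.g. $\hat{\lambda}_{\text{rul}} = 0.0001$). That delay is a cost paid in the loss landscape.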
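  • Noise amplification. Near zero, $\lambda^*$ is in a regime where the closed form $g_{\text{health}} / (g_{\text{rul}} + g_{\text{health}})$ has very high relative sensitivity. The §17.3 sensitivity analysis gives the derivative $\partial \lambda_{\text{rul}} / \partial g_{\text{health}} = g_{\text{rul}} / (g_{\text{rul}} + g_{\text{health}})^2 \approx 1/g_{\text{rul}}$ when $g_{\text{health}}$ is small, which is about 0.2 at $g_{\text{rul}} = 5$. Near-zero weights amplify per-batch fluctuations and increase the effective variance the EMA has to fight.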
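The recovery figure in the second bullet follows from EMA arithmetic. This minimal sketch assumes the standard update $\hat{\lambda} \leftarrow \beta \hat{\lambda} + (1-\beta)\lambda^*$; the exact §18.2 form is not reproduced here. Under that update the gap to a fixed target shrinks by a factor $\beta$ per step, so the step count does not depend on how small $\hat{\lambda}_{\text{rul}}$ got. The helper `steps_to_close_gap` is hypothetical.

```python
import math


def steps_to_close_gap(beta: float, fraction: float = 0.99) -> int:
    """Steps an EMA needs to close `fraction` of the gap to a fixed target.

    Assumes ema <- beta * ema + (1 - beta) * target, so the remaining gap is
    multiplied by beta at every step regardless of the starting value.
    """
    gap, steps = 1.0, 0  # gap measured as a fraction of the initial gap
    while gap > 1.0 - fraction:
        gap *= beta
        steps += 1
    return steps


beta = 0.99
print(f"time constant 1/(1-beta): {1.0 / (1.0 - beta):.0f} steps")              # 100
print(f"steps to close 99% of the gap: {steps_to_close_gap(beta)}")             # 459
print(f"closed form ln(100)/-ln(beta): {math.log(100) / -math.log(beta):.1f}")  # 458.2
```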

The floor short-circuits all three failure modes by guaranteeing that every task always has at least $\lambda_{\min}/(1+\lambda_{\min}) \approx \lambda_{\min}$ weight (for $K=2$), regardless of the EMA value. No streak of bad batches can fully suppress a task.

The Floor Formula (Paper Eq. 6)

The two-line transformation:

$$\text{(floor)} \quad c_i = \max(\hat{\lambda}_i, \lambda_{\min}) \qquad \text{(renormalise)} \quad \lambda^*_i = c_i \Big/ \sum_{j=1}^{K} c_j$$

The floor takes any weight below $\lambda_{\min}$ and lifts it to $\lambda_{\min}$; weights above the floor pass through unchanged. After flooring, the sum can exceed 1 (clamping only INCREASES values), so the renormalisation step divides by the new sum to put the result back on the simplex. Together the two operations preserve the property $\sum_i \lambda^*_i = 1$ while bounding $\lambda^*_i$ away from zero.
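A worked K=3 example makes both steps and the intermediate sum visible. The EMA values are made up for illustration, and `floor_and_renorm` simply mirrors the two-line formula.

```python
import numpy as np


def floor_and_renorm(ema: np.ndarray, lam_min: float) -> np.ndarray:
    clamped = np.maximum(ema, lam_min)  # floor: can only raise entries
    return clamped / clamped.sum()      # renormalise: divide by the new sum


ema = np.array([0.01, 0.29, 0.70])       # illustrative K=3 EMA weights, sum = 1
clamped = np.maximum(ema, 0.05)
print(clamped, round(clamped.sum(), 6))   # first entry lifted to 0.05; sum is now 1.04
out = floor_and_renorm(ema, 0.05)
print(out.round(4), round(out.sum(), 6))  # [0.0481 0.2788 0.6731] 1.0
```

The two unfloored entries shrink slightly (0.29 to about 0.2788, 0.70 to about 0.6731). That shrinkage pays for lifting the first entry while keeping the total at 1.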

Anti-Windup: A Control-Theoretic Lens

In classical PID control, ‘windup’ happens when an integrator accumulates error during a period when the actuator is saturated. By the time the actuator un-saturates, the integrator holds a huge accumulated value that takes seconds to discharge — producing large overshoot and oscillation. The fix is anti-windup: clamp the integrator to a bounded range so the accumulator can never blow up.

GABA's EMA $\hat{\lambda}_i$ is functionally an integrator. Without a floor, it can drift toward zero on a long bad streak. The floor is the anti-windup clamp. Strictly, it acts on the EMA's output rather than its stored state, so the state itself can still drift. The weight that reaches the loss, however, can never leave $[\lambda_{\min}/(1+(K-1)\lambda_{\min}),\ 1/(1+(K-1)\lambda_{\min})]$, which is approximately $[\lambda_{\min}, 1 - (K-1)\lambda_{\min}]$ for small $\lambda_{\min}$. The paper (main.tex:387, 688) explicitly draws this analogy: ‘The floor $\lambda_{\min} = 0.05$ acts as an anti-windup mechanism ensuring no task is fully suppressed.’
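A short simulation shows the anti-windup effect and separates the EMA state from the weight actually applied. It assumes the standard EMA update and a hypothetical 1,000-step streak in which the closed form asks for zero RUL weight. None of this is the paper's code.

```python
import numpy as np


def floor_and_renorm(ema: np.ndarray, lam_min: float = 0.05) -> np.ndarray:
    clamped = np.maximum(ema, lam_min)
    return clamped / clamped.sum()


beta = 0.99
state = np.array([0.5, 0.5])   # EMA state [rul, health], starting balanced
raw = np.array([0.0, 1.0])     # hypothetical bad streak: closed form says "drop RUL"
applied_min = 1.0

for _ in range(1000):
    state = beta * state + (1 - beta) * raw   # assumed EMA update; the state is NOT clamped
    applied = floor_and_renorm(state)         # the floor acts on the EMA's output
    applied_min = min(applied_min, applied[0])

print(f"EMA state for RUL after the streak: {state[0]:.1e}")      # roughly 2.2e-05: wound down
print(f"smallest RUL weight actually applied: {applied_min:.4f}")  # 0.0476: held up by the floor
```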

The Bounded-Weight Guarantee

Algebraically, the floor + renormalisation guarantees a closed-form bound on the output. For K=2:

$$\lambda^*_i \in \left[ \frac{\lambda_{\min}}{1 + \lambda_{\min}}, \; \frac{1}{1 + \lambda_{\min}} \right]$$

For $\lambda_{\min} = 0.05$ the bound evaluates to $[0.04762, 0.95238]$. This agrees to two decimal places with the paper's looser statement $[\lambda_{\min}, 1 - \lambda_{\min}] = [0.05, 0.95]$. The Python demo below verifies this saturation point exactly. The general K-task bound is:

$$\lambda^*_i \in \left[ \frac{\lambda_{\min}}{1 + (K - 1) \lambda_{\min}}, \; \frac{1}{1 + (K - 1) \lambda_{\min}} \right]$$

This is the property GradNorm cannot guarantee. Recall §17.4: GradNorm's learnable weights $w_i$ can take any real value, including negative values, after one aggressive aux-loss SGD step. GABA's weights are provably bounded (for $K=2$, in $[0.04762, 0.95238]$) at every step, on every dataset, with every random seed. The paper (main.tex:387) calls this bound a ‘stability property absent from loss-based approaches’.
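The general K-task bound can be spot-checked numerically. This sketch samples random, deliberately skewed points on the simplex with NumPy's Dirichlet sampler; the concentration 0.2 and the seed are arbitrary choices. It confirms that every floored output stays inside the bound and that the vertex input [1, 0, ..., 0] attains both ends. `k_task_bounds` is a hypothetical helper that evaluates the formula above.

```python
import numpy as np


def floor_and_renorm(ema: np.ndarray, lam_min: float) -> np.ndarray:
    clamped = np.maximum(ema, lam_min)
    return clamped / clamped.sum()


def k_task_bounds(k: int, lam_min: float) -> tuple[float, float]:
    denom = 1.0 + (k - 1) * lam_min
    return lam_min / denom, 1.0 / denom


rng = np.random.default_rng(0)
lam_min = 0.05
results = []
for k in (2, 3, 5):
    lo, hi = k_task_bounds(k, lam_min)
    samples = rng.dirichlet(np.full(k, 0.2), size=10_000)   # skewed points on the simplex
    outs = np.array([floor_and_renorm(s, lam_min) for s in samples])
    inside = bool(outs.min() >= lo - 1e-12 and outs.max() <= hi + 1e-12)
    vertex = floor_and_renorm(np.eye(k)[0], lam_min)        # input [1, 0, ..., 0]
    tight = bool(np.isclose(vertex.min(), lo) and np.isclose(vertex.max(), hi))
    results.append((k, lo, hi, inside, tight))
    print(f"K={k}: bound [{lo:.5f}, {hi:.5f}]  samples inside: {inside}  attained: {tight}")
```

The bounds tighten as K grows (for K=5 they are about [0.04167, 0.83333]), because K-1 floored entries can now inflate the renormalising sum.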

Why $\lambda_{\min} = 0.05$

The paper's value of 0.05 is empirical. Three considerations drove the choice:

| Constraint | Implication for λ_min | Paper's pick |
|---|---|---|
| Each task should ALWAYS contribute non-trivially | λ_min > 0 | 0.05 ensures roughly 5% of the loss budget per task (0.05/1.05 ≈ 4.8% for K=2) |
| The floor must NOT distort the closed form when imbalance is real | λ_min should be small relative to the typical λ spread | 0.05 ≪ 0.5; only kicks in for extreme imbalance (a closed-form gradient-norm ratio above 19:1 for K=2) |
| Hyperparameter robustness across datasets | Result must be insensitive to the exact value | Paper §5.8 ablation: results stable for λ_min in [0.02, 0.10] |

The paper's hyperparameter robustness section (paper §5.8) reports that GABA results are statistically indistinguishable for $\lambda_{\min} \in [0.02, 0.10]$ across the C-MAPSS subsets, so 0.05 is comfortably in the robust regime.
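The first two table rows can be made quantitative. Assume the EMA sits at the K=2 closed form (steady state). The RUL weight then drops below $\lambda_{\min}$ exactly when $g_{\text{rul}}/g_{\text{health}} > (1-\lambda_{\min})/\lambda_{\min}$. This sketch tabulates that engagement ratio and the resulting K=2 output bounds across the robust range, plus the two extreme settings ($10^{-6}$ and 0.4) discussed in the pitfalls below. The helper names are hypothetical.

```python
def engage_ratio(lam_min: float) -> float:
    """Gradient-norm ratio g_rul / g_health above which the K=2 floor engages."""
    return (1.0 - lam_min) / lam_min


def k2_bounds(lam_min: float) -> tuple[float, float]:
    """K=2 output bounds after floor + renormalise."""
    return lam_min / (1.0 + lam_min), 1.0 / (1.0 + lam_min)


for lam_min in (1e-6, 0.02, 0.05, 0.10, 0.4):
    lo, hi = k2_bounds(lam_min)
    print(f"lam_min={lam_min:<6g}  engages above {engage_ratio(lam_min):>9.1f}:1  "
          f"bounds [{lo:.4f}, {hi:.4f}]")
```

At 0.05 the floor stays out of the way until one gradient is 19 times the other. At 0.4 it engages at a modest 1.5:1 ratio, which is why such a large floor erodes the closed form.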

Interactive: Watch The Floor Engage

Drag the raw EMA slider from 0 to 1. The middle bars (clamped) light up amber when the floor is active. The right bars (renormalised) show the final $\lambda^*$. The bottom panel plots the input-output transfer curve so you can see the ‘knee’ at $\lambda_{\min}$ explicitly.

Try this. Drag raw λ_rul to 0. Without the floor, the output would also be 0. With the paper's λ_min = 0.05, the output is pinned at 0.04762, the analytic lower bound. Now move λ_min to 0 and watch the bottom curve become y = x: no anti-windup, no bound, full task suppression possible. Move λ_min to 0.5 and the curve flattens into the narrow band [1/3, 2/3] (0.5/1.5 to 1/1.5): the floor is so aggressive that GABA is pushed toward uniform weighting, defeating the closed form.
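Without the interactive widget, the transfer curve can be tabulated directly. This NumPy sketch covers K=2, where the health weight is 1 minus the raw RUL weight; the grid of raw values is arbitrary. It shows the identity at $\lambda_{\min} = 0$, the knee at $\lambda_{\min} = 0.05$, and, at $\lambda_{\min} = 0.5$, an output squeezed into $[1/3, 2/3]$.

```python
import numpy as np


def floor_and_renorm(ema: np.ndarray, lam_min: float) -> np.ndarray:
    clamped = np.maximum(ema, lam_min)
    return clamped / clamped.sum()


def rul_out(raw_rul: float, lam_min: float) -> float:
    """Final RUL weight for the K=2 input [raw_rul, 1 - raw_rul]."""
    return float(floor_and_renorm(np.array([raw_rul, 1.0 - raw_rul]), lam_min)[0])


grid = [0.0, 0.02, 0.05, 0.25, 0.5, 0.75, 1.0]
print("raw lambda_rul:  " + " ".join(f"{x:.3f}" for x in grid))
for lam_min in (0.0, 0.05, 0.5):
    print(f"lam_min = {lam_min:<4}: " + " ".join(f"{rul_out(x, lam_min):.3f}" for x in grid))
```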

Python: Three Scenarios From Scratch

Implement floor_and_renorm in pure NumPy and exercise it on three representative inputs: paper-realistic 500× imbalance, pathological near-zero, and balanced no-op. The third confirms the graceful-degradation property: the floor does nothing when nothing needs fixing.

Paper Eq. 6 — three scenarios + bounded-weight guarantee
🐍floor_and_renorm.py
1docstring

Module docstring: implements paper equation 6 verbatim and tests three illustrative scenarios. The point is to make the bounded-weight guarantee concrete on three realistic input regimes.

3import numpy as np

NumPy supplies the ndarray, np.array, np.maximum (the element-wise floor), and np.allclose (numeric equality test).

EXECUTION STATE
📚 numpy = Numerical computing library. Used for ndarray, np.maximum, np.allclose.
6def floor_and_renorm(ema, lam_min) → np.ndarray

Two lines of meaningful code that implement the entire floor + renormalise step. Takes the EMA-smoothed weights from §18.2 and returns the bounded, simplex-valued lambda* used by the optimiser.

EXECUTION STATE
⬇ input: ema = ndarray (K,). EMA-smoothed task weights from §18.2. Sums to 1, but individual entries can be arbitrarily small or close to 1.
⬇ input: lam_min = Float. The minimum allowed weight per task. Paper sets 0.05.
⬆ returns = ndarray (K,). The floored + renormalised weights. Bounded on the simplex, always sums to 1.
7docstring

Records the formula and bound. The bound for K=2 is approximately [lam_min, 1-lam_min] — paper main.tex:387, 688.

14clamped = np.maximum(ema, lam_min)

Element-wise floor: any ema entry below lam_min is replaced by lam_min; entries above are passed through unchanged.

EXECUTION STATE
📚 np.maximum(a, b) = Element-wise max of two arrays (or array + scalar). Returns max(a_i, b_i) per element. NOT to be confused with np.max which reduces an array.
⬇ arg 1: ema = The EMA-smoothed weights (any value in [0, 1]).
⬇ arg 2: lam_min = Scalar. Broadcasting lifts it to compare against every element of ema.
→ maximum vs max = np.maximum(ema, 0.05) returns a (K,) array with each element floored at 0.05. np.max(ema) returns the SINGLE largest value (a scalar). Different functions, different jobs.
→ example = ema = [0.002, 0.998]: clamped = [max(0.002, 0.05), max(0.998, 0.05)] = [0.05, 0.998]. The first entry was floored; the second was left alone.
clamped = Each entry now ≥ lam_min. Sum can exceed 1 because clamping can only INCREASE values, never decrease them.
15return clamped / clamped.sum()

Renormalise to the simplex. Element-wise divide by the new sum. The output is non-negative and sums to 1 by construction.

EXECUTION STATE
📚 .sum() = ndarray reduction: sum every element, return scalar.
/ (broadcast) = Element-wise division: ndarray / scalar divides every element by the same scalar.
→ why renormalise? = After clamping, sum(clamped) > 1 (or = 1 if no clamping engaged). Without dividing, the ‘loss combination’ would scale up. Dividing puts the weights back on the simplex so the combined-loss scale stays consistent across steps.
⬆ return = ndarray (K,). On the simplex; bounded in [lam_min/(1+lam_min*(K-1)), 1/(1+lam_min*(K-1))].
19lam_min = 0.05

Paper canonical floor (main.tex:354). Empirically robust across all four C-MAPSS subsets; the paper's hyperparameter robustness ablation (§5.8) shows results are insensitive to this within [0.02, 0.10].

EXECUTION STATE
lam_min = 0.05 = Float. ~5% minimum weight per task. Forbids any task from being weighted below roughly 5% (exactly 0.05/1.05 ≈ 4.76% after renormalisation for K=2), even if its gradient is very large relative to its peers.
23ema_a = np.array([0.002, 0.998])

Realistic EMA after a long training run on FD002 with 500x gradient imbalance. The EMA has settled near the closed-form λ_rul = g_health / (g_rul + g_health) = 0.01 / (5.0 + 0.01) ≈ 0.002.

EXECUTION STATE
📚 np.array(list) = Build an ndarray from a Python list. dtype inferred (float64).
ema_a = ndarray (2,) = [0.002, 0.998]. RUL weight is 0.2%, health weight is 99.8%.
24out_a = floor_and_renorm(ema_a, lam_min)

Apply the floor + renorm. The first entry hits the floor; the second is unchanged.

EXECUTION STATE
intermediate clamp = [max(0.002, 0.05), max(0.998, 0.05)] = [0.05, 0.998]
intermediate sum = 0.05 + 0.998 = 1.048
out_a = [0.05/1.048, 0.998/1.048] = [0.04771, 0.95229]
→ reading = λ_rul lifted from 0.002 → 0.04771 (24x increase). λ_health pulled down from 0.998 → 0.95229 (small drop). The drop on health is needed because the simplex constraint sum=1 must be re-imposed.
25print A.raw

Pretty-print scenario A.

EXECUTION STATE
Output = A. raw = [0.002 0.998]
26print A.clamp

Show the intermediate clamped value.

EXECUTION STATE
Output = clamp = [0.05 0.998]
27print A.out

Final output and sanity-check the simplex constraint.

EXECUTION STATE
Output = out = [0.047710, 0.952290] sum = 1.000000
31ema_b = np.array([1e-8, 1.0 - 1e-8])

Pathological case: imagine the EMA has converged to a near-zero weight for RUL (e.g. after a very long training run with a freak large g_rul / tiny g_health draw). Without a floor, RUL would be effectively turned off.

EXECUTION STATE
ema_b = ndarray (2,) = [1e-8, ~1.0]. RUL is one part in 100 million.
→ why test this? = Without the floor, λ_rul · L_rul ≈ 1e-8 · 100 = 1e-6. The optimiser sees no RUL signal at all and the regression head stops learning. The floor prevents this.
32out_b = floor_and_renorm(ema_b, lam_min)

Apply floor + renorm. The pathological tiny weight is rescued to ~0.0476.

EXECUTION STATE
intermediate clamp = [max(1e-8, 0.05), max(1-1e-8, 0.05)] = [0.05, ~1.0]
intermediate sum = ≈ 1.05
out_b = [0.04762, 0.95238]
→ key observation = out_b is nearly identical to out_a (0.04762 vs 0.04771) even though the raw inputs differ by more than five orders of magnitude (1e-8 vs 0.002). This is the ‘saturating’ behaviour of the floor: once raw_min < lam_min, the output is confined to the narrow band between lam_min/(1+lam_min) ≈ 0.04762 and lam_min = 0.05.
33print B.raw

Pretty-print scenario B.

EXECUTION STATE
Output = (blank) B. raw = [1.00e-08, 1.000000]
34print B.clamp

Clamped intermediate.

EXECUTION STATE
Output = clamp = [0.05 0.99999999]
35print B.out

Output. Note the saturation effect.

EXECUTION STATE
Output = out = [0.047619, 0.952381]
39ema_c = np.array([0.5, 0.5])

Balanced case: gradient norms are equal, EMA is at uniform 0.5/0.5. The floor should be a no-op.

EXECUTION STATE
ema_c = ndarray (2,) = [0.5, 0.5]. Both weights well above lam_min = 0.05.
40out_c = floor_and_renorm(ema_c, lam_min)

Apply floor + renorm. Neither entry is below the floor, so clamping is a no-op; sum is already 1, so renormalisation is also a no-op.

EXECUTION STATE
intermediate clamp = [max(0.5, 0.05), max(0.5, 0.05)] = [0.5, 0.5]
intermediate sum = 1.0
out_c = [0.5, 0.5]
→ reading = Identity transformation. The floor is invisible when nothing is in danger of being suppressed. This is the ‘graceful degradation’ property: the stabiliser does nothing harmful when not needed.
41print C.raw

Pretty-print scenario C.

EXECUTION STATE
Output = (blank) C. raw = [0.5 0.5]
42print C.clamp

Clamped intermediate.

EXECUTION STATE
Output = clamp = [0.5 0.5]
43print C.out

Output equals input.

EXECUTION STATE
Output = out = [0.500000, 0.500000]
44print unchanged?

np.allclose checks element-wise near-equality with a default tolerance.

EXECUTION STATE
📚 np.allclose(a, b) = Returns True if |a − b| ≤ atol + rtol·|b| element-wise. Default atol=1e-8, rtol=1e-5.
Output = unchanged? True
48lower = lam_min / (1 + lam_min)

Closed-form lower bound for K=2. Worst case occurs when the OTHER weight is at its maximum 1.0; clamped+renormalised that gives lam_min / (1 + lam_min).

EXECUTION STATE
lower = 0.05 / 1.05 = 0.04762
→ derivation = When ema = [ε, 1−ε] with ε → 0: clamp = [lam_min, 1−ε ≈ 1]; sum ≈ 1 + lam_min; out_min = lam_min / (1 + lam_min).
49upper = 1.0 / (1 + lam_min)

Closed-form upper bound. Symmetric to the lower bound; achieved by the LARGE weight when the small weight is being clamped up to lam_min.

EXECUTION STATE
upper = 1.0 / 1.05 = 0.95238
50print lower bound

Print the analytic lower bound.

EXECUTION STATE
Output = (blank) K=2 bounds: lower = lam_min / (1 + lam_min) = 0.047619
51print upper bound

Print the analytic upper bound.

EXECUTION STATE
Output = upper = 1 / (1 + lam_min) = 0.952381
52print paper claim

The paper's [lam_min, 1-lam_min] is the loose statement; the tight bound for K=2 has the (1+lam_min) denominator. At lam_min = 0.05 the two agree to two decimal places (0.0476 vs 0.05 and 0.9524 vs 0.95).

EXECUTION STATE
Final output =
A. raw   = [0.002 0.998]
   clamp = [0.05  0.998]
   out   = [0.047710, 0.952290]   sum = 1.000000

B. raw   = [1.00e-08, 1.000000]
   clamp = [0.05       0.99999999]
   out   = [0.047619, 0.952381]

C. raw   = [0.5 0.5]
   clamp = [0.5 0.5]
   out   = [0.500000, 0.500000]
   unchanged? True

K=2 bounds:  lower = lam_min / (1 + lam_min)   = 0.047619
             upper = 1     / (1 + lam_min)     = 0.952381
             paper claim: in [lam_min, 1-lam_min] = [0.05, 0.95]  (approx)
1"""Floor + renormalisation: paper equation 6 from scratch."""
2
3import numpy as np
4
5
6def floor_and_renorm(ema: np.ndarray, lam_min: float) -> np.ndarray:
7    """Paper eq. 6: clamp at floor, then renormalise to the simplex.
8
9        lambda_i* = max(ema_i, lam_min) / sum_j max(ema_j, lam_min)
10
11    Guarantees lambda_i* in [lam_min/(1+lam_min*(K-1)), 1/(1+lam_min*(K-1))]
12    for K tasks; for K=2 this is approximately [lam_min, 1-lam_min].
13    """
14    clamped = np.maximum(ema, lam_min)
15    return clamped / clamped.sum()
16
17
18# ---------- Paper canonical floor ----------
19lam_min = 0.05
20
21
22# ---------- Scenario A: paper-realistic 500x imbalance ----------
23ema_a = np.array([0.002, 0.998])
24out_a = floor_and_renorm(ema_a, lam_min)
25print(f"A. raw   = {ema_a}")
26print(f"   clamp = {np.maximum(ema_a, lam_min)}")
27print(f"   out   = [{out_a[0]:.6f}, {out_a[1]:.6f}]   sum = {out_a.sum():.6f}")
28
29
30# ---------- Scenario B: extreme imbalance (g_health -> 0) ----------
31ema_b = np.array([1e-8, 1.0 - 1e-8])
32out_b = floor_and_renorm(ema_b, lam_min)
33print(f"\nB. raw   = [{ema_b[0]:.2e}, {ema_b[1]:.6f}]")
34print(f"   clamp = {np.maximum(ema_b, lam_min)}")
35print(f"   out   = [{out_b[0]:.6f}, {out_b[1]:.6f}]")
36
37
38# ---------- Scenario C: balanced (no clamping needed) ----------
39ema_c = np.array([0.5, 0.5])
40out_c = floor_and_renorm(ema_c, lam_min)
41print(f"\nC. raw   = {ema_c}")
42print(f"   clamp = {np.maximum(ema_c, lam_min)}")
43print(f"   out   = [{out_c[0]:.6f}, {out_c[1]:.6f}]")
44print(f"   unchanged? {np.allclose(ema_c, out_c)}")
45
46
47# ---------- Bounded-weight guarantee (K=2 closed form) ----------
48lower = lam_min / (1 + lam_min)
49upper = 1.0     / (1 + lam_min)
50print(f"\nK=2 bounds:  lower = lam_min / (1 + lam_min)   = {lower:.6f}")
51print(f"             upper = 1     / (1 + lam_min)     = {upper:.6f}")
52print(f"             paper claim: in [lam_min, 1-lam_min] = [0.05, 0.95]  (approx)")

PyTorch: clamp + renormalise (Paper Code)

The paper code in grace/core/gaba.py:134-135 is two lines: weights = ema_w.clamp(min=self.min_weight) followed by weights = weights / weights.sum(). We reproduce both verbatim and confirm that the boundary case $\hat{\lambda} = [0, 1]$ reproduces the analytic bound $[0.04762, 0.95238]$ to float32 precision.

Paper code: clamp + renorm with autograd verification
🐍gaba_floor_torch.py
1docstring

Module docstring. The two-line implementation below is the EXACT paper code from grace/core/gaba.py:134-135 — the entire floor + renorm step.

3import torch

Core PyTorch.

EXECUTION STATE
📚 torch = Tensor library with autograd. We use .clamp, .sum, and tensor arithmetic.
6# Paper canonical hyperparameters

Section header marking lam_min as a hyperparameter.

7LAM_MIN = 0.05

Module-level constant. Paper main.tex:354 sets λ_min = 0.05.

EXECUTION STATE
LAM_MIN = 0.05. Convention: ALL_CAPS for module-level constants.
10def gaba_floor_and_renorm(ema_w, min_weight=LAM_MIN) → torch.Tensor

Two-line function. Mirrors the paper's code at grace/core/gaba.py:134-135 line for line.

EXECUTION STATE
⬇ input: ema_w = Tensor (K,). EMA-smoothed task weights from the §18.2 update.
⬇ input: min_weight = Float. The floor. Default = paper canonical 0.05.
⬆ returns = Tensor (K,). Floored + renormalised weights. On the simplex, bounded in [LAM_MIN/(1+LAM_MIN), 1/(1+LAM_MIN)] for K=2.
11docstring

Records the exact two lines copied from the paper code.

16weights = ema_w.clamp(min=min_weight)

Element-wise clamp at the floor. Paper code line 134.

EXECUTION STATE
📚 .clamp(min, max) = Tensor method. Element-wise clip: clamp(x, m) returns max(x, m) if only min is given. Equivalent to torch.clamp(tensor, min=m).
⬇ arg: min=min_weight = Lower bound. Any element below this gets replaced by this value. (No max= argument ⇒ unbounded above.)
→ vs np.maximum = PyTorch's clamp(min=v) is the tensor-method equivalent of np.maximum(arr, v). Same semantics, different syntax.
→ autograd note = clamp is differentiable everywhere except at the threshold. The gradient is 1 above the floor and 0 below it; exactly at the threshold PyTorch returns one of the valid sub-gradient values.
17weights = weights / weights.sum()

Renormalise to the simplex. Paper code line 135.

EXECUTION STATE
📚 .sum() = Tensor reduction. Sums all elements to a 0-dim tensor.
/ = Element-wise division (broadcasts a 0-dim tensor against a (K,) tensor).
18return weights

Return the bounded simplex-valued tensor.

EXECUTION STATE
⬆ return = Tensor (K,). Sums to 1, bounded.
22ema_a = torch.tensor([0.002, 0.998])

Paper-realistic 500x imbalance after EMA convergence.

EXECUTION STATE
📚 torch.tensor(list) = Build a tensor from a Python list. Default dtype = float32.
ema_a = Tensor (2,) = [0.002, 0.998].
23ema_b = torch.tensor([1e-8, 1.0 - 1e-8])

Pathological corner case.

EXECUTION STATE
ema_b = Tensor (2,) = [1e-8, ~1.0]. RUL effectively zero.
24ema_c = torch.tensor([0.5, 0.5])

Balanced case where the floor is a no-op.

26print A

Apply the function and print.

EXECUTION STATE
Output = A. raw = [0.002, 0.998] -> out = [0.04770992398262024, 0.9522900581359863]
27print B

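B is the saturation case. In float32, 1.0 - 1e-8 rounds to exactly 1.0 (and 1e-8 is stored as about 9.99999994e-09). After clamping, B is therefore identical to the [0, 1] boundary case and prints the same output.

EXECUTION STATE
Output = B. raw = [9.99999993922529e-09, 1.0] -> out = [0.04761904925107956, 0.9523810148239136]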
28print C

C is the no-op case.

EXECUTION STATE
Output = C. raw = [0.5, 0.5] -> out = [0.5, 0.5]
32ema_extreme = torch.tensor([0.0, 1.0])

Boundary: zero on one side, one on the other. The maximally suppressed input.

EXECUTION STATE
ema_extreme = Tensor (2,) = [0.0, 1.0]. RUL fully suppressed before the floor.
33out_extreme = gaba_floor_and_renorm(ema_extreme)

Apply the floor. Output should be exactly [LAM_MIN/(1+LAM_MIN), 1/(1+LAM_MIN)].

EXECUTION STATE
out_extreme = Tensor (2,) = [0.04762, 0.95238]. Hits the analytic bound exactly.
34print extreme

Pretty-print the actual output.

EXECUTION STATE
Output = (blank) at ema=[0, 1]: out = [0.04761904925107956, 0.9523810148239136]
35print theoretical bound

Compare the empirical to the analytic bound. They agree to float32 precision.

EXECUTION STATE
Output = theoretical bound: [0.047619, 0.952381]
39ema_grad = torch.tensor([0.002, 0.998], requires_grad=True)

Set up an autograd test. Track gradients on the input so we can backprop through the clamp.

EXECUTION STATE
requires_grad=True = Track gradients. Required so backward() through any function of ema_grad can write to ema_grad.grad.
40out_grad = gaba_floor_and_renorm(ema_grad)

Apply the function. The output is differentiable in the input (sub-gradient at the clamp threshold).

EXECUTION STATE
out_grad = Tensor (2,) with requires_grad=True (because ema_grad does).
41out_grad.sum().backward()

Differentiate sum(out) with respect to ema_grad. The sum is identically 1 by construction (renormalisation), so the gradient should be zero up to floating-point rounding. This makes a useful sanity check.

EXECUTION STATE
📚 .sum().backward() = Compose: scalar = sum, then backprop. Writes gradients to all leaf tensors.
→ why test grad of sum? = sum(weights) ≡ 1 after renormalisation, so d(sum)/d(any input) ≡ 0. If the gradient is non-zero, the renormalisation is broken (e.g. division by something that isn't the sum).
42print sum-grad header

Format the autograd readout.

EXECUTION STATE
Output = (blank) grad through clamp at ema=[0.002, 0.998]:
43print grad value

Read off the grad attribute.

EXECUTION STATE
Output = d(sum out) / d ema = [0.0, 0.0]
44print why this should be zero

Final commentary.

EXECUTION STATE
Final output =
A. raw = [0.002, 0.998] -> out = [0.04770992398262024, 0.9522900581359863]
B. raw = [9.99999993922529e-09, 1.0] -> out = [0.04761904925107956, 0.9523810148239136]
C. raw = [0.5, 0.5] -> out = [0.5, 0.5]

at ema=[0, 1]: out = [0.04761904925107956, 0.9523810148239136]
theoretical bound: [0.047619, 0.952381]

grad through clamp at ema=[0.002, 0.998]:
  d(sum out) / d ema = [0.0, 0.0]
  (sum out is identically 1, so grad should be 0)
1"""Paper code: ema_w.clamp(min=lam_min) / sum (grace/core/gaba.py:134)."""
2
3import torch
4
5
6# ---------- Paper canonical hyperparameters ----------
7LAM_MIN = 0.05
8
9
10def gaba_floor_and_renorm(ema_w: torch.Tensor, min_weight: float = LAM_MIN) -> torch.Tensor:
11    """Floor + renormalise (paper grace/core/gaba.py lines 134-135).
12
13    weights = ema_w.clamp(min=self.min_weight)
14    weights = weights / weights.sum()
15    """
16    weights = ema_w.clamp(min=min_weight)
17    weights = weights / weights.sum()
18    return weights
19
20
21# ---------- Three test cases ----------
22ema_a = torch.tensor([0.002, 0.998])
23ema_b = torch.tensor([1e-8, 1.0 - 1e-8])
24ema_c = torch.tensor([0.5, 0.5])
25
26print(f"A. raw = {ema_a.tolist()} -> out = {gaba_floor_and_renorm(ema_a).tolist()}")
27print(f"B. raw = {ema_b.tolist()} -> out = {gaba_floor_and_renorm(ema_b).tolist()}")
28print(f"C. raw = {ema_c.tolist()} -> out = {gaba_floor_and_renorm(ema_c).tolist()}")
29
30
31# ---------- Confirm the bound at the saturated regime ----------
32ema_extreme = torch.tensor([0.0, 1.0])
33out_extreme = gaba_floor_and_renorm(ema_extreme)
34print(f"\nat ema=[0, 1]: out = {out_extreme.tolist()}")
35print(f"theoretical bound: [{LAM_MIN / (1 + LAM_MIN):.6f}, {1.0 / (1 + LAM_MIN):.6f}]")
36
37
38# ---------- Differentiability of the clamp (forward only) ----------
39ema_grad = torch.tensor([0.002, 0.998], requires_grad=True)
40out_grad = gaba_floor_and_renorm(ema_grad)
41out_grad.sum().backward()
42print(f"\ngrad through clamp at ema=[0.002, 0.998]:")
43print(f"  d(sum out) / d ema = {ema_grad.grad.tolist()}")
44print(f"  (sum out is identically 1, so grad should be 0)")
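The autograd note on .clamp has a consequence worth checking directly: a floored entry receives no gradient, while the renormalisation couples the unfloored entries to the floored weight. The sketch below checks this without PyTorch, using central finite differences in NumPy on the same input [0.002, 0.998]. `numeric_jacobian` is a hypothetical helper written for this check, and eps = 1e-7 is an arbitrary step size.

```python
import numpy as np


def floor_and_renorm(ema: np.ndarray, lam_min: float = 0.05) -> np.ndarray:
    clamped = np.maximum(ema, lam_min)
    return clamped / clamped.sum()


def numeric_jacobian(f, x, eps: float = 1e-7) -> np.ndarray:
    """Central finite differences: J[i, j] = d f_i / d x_j."""
    x = np.asarray(x, dtype=float)
    cols = []
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        cols.append((f(x + step) - f(x - step)) / (2 * eps))
    return np.stack(cols, axis=1)


J = numeric_jacobian(floor_and_renorm, [0.002, 0.998])
print(np.round(J, 4))
# Column 0 is zero: the floored RUL entry receives no gradient.
# Each column sums to zero: the outputs always sum to 1.
```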

Floors In Other Fields

| Field | Floor mechanism | Why it's needed |
|---|---|---|
| Predictive maintenance (this paper) | λ_min = 0.05 on per-task weight | Prevent task suppression on adversarial gradient streaks |
| Optimisation: gradient clipping | max_norm = 1.0 on gradient L2 norm | Prevent exploding gradients in RNNs / Transformers |
| RL: epsilon-greedy exploration | ε ∈ [0.01, 0.1] floor on exploration probability | Prevent the agent from getting stuck in a sub-optimal policy |
| NLP: label smoothing | 1 − ε on the true class, ε/(K−1) on others (ε ≈ 0.1) | Prevent the model from over-confidently predicting one class |
| Recommender systems | Minimum exposure (e.g. 0.5%) per item | Prevent rich-get-richer feedback loops |
| Control theory: PID anti-windup | Saturation limits on integrator state | Prevent integrator wind-up during actuator saturation |
| Audio: noise gate threshold | Below-threshold signals attenuated, above pass through | Suppress hum during silence, pass through during speech |
| Finance: portfolio diversification | Min weight per asset (e.g. 1%) | Prevent the optimiser from putting everything into a few assets |

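Variations on the same recipe (bound an adaptive quantity, then renormalise or rescale where needed) appear in many fields under many names. Not every row is literally a floor: gradient clipping and PID saturation are caps, and a noise gate bounds a signal without renormalising it. The shared pattern is a common way to add stability to an otherwise unconstrained adaptive controller.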
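The label-smoothing row can be checked directly. Smoothing a one-hot target with the (1 − ε, ε/(K−1)) variant is exactly floor + renormalise with floor f = ε/((K−1)(1−ε)). After flooring, the sum is 1 + (K−1)f = 1/(1−ε), so dividing gives 1 − ε on the true class and ε/(K−1) elsewhere. A short NumPy sketch follows; K = 5, ε = 0.1 and the true class index are arbitrary illustrative values.

```python
import numpy as np


def floor_and_renorm(p: np.ndarray, floor: float) -> np.ndarray:
    clamped = np.maximum(p, floor)
    return clamped / clamped.sum()


K, eps, true_class = 5, 0.1, 2                 # illustrative values
one_hot = np.eye(K)[true_class]

smoothed = np.full(K, eps / (K - 1))           # eps/(K-1) on every other class
smoothed[true_class] = 1.0 - eps               # 1 - eps on the true class

floor = eps / ((K - 1) * (1.0 - eps))          # the floor that reproduces it
print(floor_and_renorm(one_hot, floor).round(4))
print(np.allclose(floor_and_renorm(one_hot, floor), smoothed))  # True
```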

Pitfalls In The Floor + Renormalise Step

Pitfall 1: Floor too aggressive (e.g. $\lambda_{\min} = 0.4$ on K=2). The bound becomes $[0.286, 0.714]$, less than half the dynamic range of GABA. The closed form becomes nearly redundant: the floor engages on a task whenever the gradient-norm ratio exceeds 1.5:1, and the output is dragged toward uniform 0.5/0.5. You've unintentionally re-implemented Fixed Baseline.
Pitfall 2: Floor too small (e.g. $\lambda_{\min} = 10^{-6}$). The bound becomes $[\sim 10^{-6}, \sim 1]$. Anti-windup is essentially disabled; the optimiser can spend 500 consecutive steps with $\lambda_{\text{rul}} \sim 10^{-6}$ and the regression head stops learning. The EMA closes only about $(1-\beta) \sim 10^{-2}$ of the gap per step, so recovering from this regime takes hundreds of steps.
Pitfall 3: Forgetting the renormalisation. If you clamp without dividing, the output sums to $> 1$. The combined loss scale increases, and it changes whenever the floor engages or disengages. Under SGD this inflates the effective learning rate, and Adam's second-moment estimator sees a shifting gradient scale. Convergence behaviour shifts in subtle ways that look like a bug elsewhere. The paper's code always pairs .clamp with / .sum() on the next line precisely to prevent this.
Pitfall 4: Applying the floor BEFORE the EMA. Floors must operate on the SMOOTHED weights, not the raw closed-form weights. If you floor the raw $\lambda^*_i$ and then EMA the floored value, the floor applies on every step and creates persistent low-frequency bias in the EMA. Paper-canonical order: closed form → EMA → floor → renormalise.
Why the paper explicitly says ‘bounded at every step’. Compare with GradNorm (§17.4): its learnable $w_i$ is bounded only at convergence of the auxiliary loss, which it may or may not reach. GABA's bound is per-step, deterministic, and provable from the formula. In the paper's N-CMAPSS runs, GradNorm diverged on 1 of 5 seeds (seed 789), while GABA ran to completion on the same seeds and data. This section attributes that difference to this single algebraic guarantee.
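The canonical order (closed form → EMA → floor → renormalise) fits in a few lines. This is a hypothetical K=2 sketch, not the paper's gaba.py. The class name, the uniform initial EMA state, and the EMA form `ema <- beta * ema + (1 - beta) * raw` are all assumptions. Under a persistent 500x imbalance, the EMA state settles near the raw closed form, while the applied weights stay floored and sum to 1.

```python
import numpy as np


class GabaWeightsSketch:
    """Hypothetical K=2 sketch: closed form -> EMA -> floor -> renormalise.

    The EMA update ema <- beta * ema + (1 - beta) * raw is an assumed form.
    """

    def __init__(self, beta: float = 0.99, lam_min: float = 0.05):
        self.beta, self.lam_min = beta, lam_min
        self.ema = np.array([0.5, 0.5])  # [rul, health], start from uniform weights

    def step(self, g_rul: float, g_health: float) -> np.ndarray:
        total = g_rul + g_health
        raw = np.array([g_health, g_rul]) / total                 # 1. closed form (unfloored)
        self.ema = self.beta * self.ema + (1 - self.beta) * raw   # 2. EMA on the raw weights
        clamped = np.maximum(self.ema, self.lam_min)              # 3. floor the SMOOTHED weights
        return clamped / clamped.sum()                            # 4. renormalise (Pitfall 3)


gaba = GabaWeightsSketch()
for _ in range(2000):
    weights = gaba.step(g_rul=5.0, g_health=0.01)  # persistent 500x imbalance
print(f"EMA state: {gaba.ema.round(4)}  applied: {weights.round(4)}")
# EMA state settles near [0.002 0.998]; applied weights near [0.0477 0.9523]
```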

Takeaway

  • Paper Eq. 6 is two lines. weights = ema_w.clamp(min=lam_min); weights = weights / weights.sum(). That is the entire stabiliser.
  • The floor is anti-windup. Without it, a streak of bad batches can drive a task weight to ~0 and the EMA takes hundreds of steps to recover. With it, every task always carries at least $\lambda_{\min}/(1 + \lambda_{\min})$ weight (for $K=2$).
  • Bounded-weight guarantee. For K=2: $\lambda^* \in [0.04762, 0.95238]$ per step, deterministic. This is the property GradNorm cannot match.
  • Paper picks $\lambda_{\min} = 0.05$. Robust across $[0.02, 0.10]$ per the §5.8 ablation. Small enough not to distort the closed form; large enough to provide real anti-windup.
  • Graceful no-op when not needed. If all $\hat{\lambda}_i \geq \lambda_{\min}$, the floor is the identity and renormalisation divides by 1; the entire stabiliser becomes invisible.
  • Same recipe everywhere. Gradient clipping, ε-greedy floors, label smoothing, PID anti-windup, audio gating, portfolio diversification: all variations on ‘bound an adaptive quantity (and usually renormalise) to add stability to an adaptive controller’.