Chapter 17

Why Inverse-Proportional Weights Work

Inverse-Gradient Balancing: The Idea

The Dosing Problem

A clinician is treating a patient with two simultaneous drugs. Drug A is 500× more potent than Drug B on a per-milligram basis. If she gives 50 mg of each, Drug A overwhelms the body and Drug B is pharmacologically invisible. The fair prescription is not equal mass — it is equal physiological effect. To get there she scales each dose inversely by potency: very little Drug A, much more Drug B.

Multi-task learning has the same shape. The two losses are the drugs; the shared backbone is the patient. Their ‘potency’ on the backbone is the gradient norm $\|g_i\|$. The effective dose, what the parameters actually feel, is $\lambda_i \cdot \|g_i\|$. Equal loss weights guarantee that the high-potency task dominates every step. Equal effects require inverse weighting. That is the entire principle behind GABA.

The principle in one line. Balance $\lambda_i \cdot \|g_i\|$, not $\lambda_i$. Balancing weights is cosmetic; balancing effective contributions is what changes what the backbone learns.

The Effective-Contribution Principle

Recall the gradient-descent update on the shared backbone parameters $\theta$ for a multi-task loss $\mathcal{L} = \sum_i \lambda_i \, \mathcal{L}_i$:

$$\theta \leftarrow \theta - \eta \sum_i \lambda_i \, \nabla_\theta \mathcal{L}_i$$

Each task contributes a vector $\lambda_i \, \nabla_\theta \mathcal{L}_i$ to the update. For non-negative weights its magnitude is $\lambda_i \, \|g_i\|$, where $g_i = \nabla_\theta \mathcal{L}_i$. We name this scalar the effective contribution of task $i$:

$$c_i \;\equiv\; \lambda_i \cdot \|g_i\|$$

It is a single scalar measuring how hard task $i$ is pulling on the backbone. Every weighting scheme defines a vector $(c_1, \ldots, c_K)$; only one scheme makes the entries equal.

| Quantity | Symbol | What it measures |
|---|---|---|
| Per-task loss | $\mathcal{L}_i$ | How wrong task $i$ currently is |
| Per-task gradient norm | $\lVert g_i \rVert$ | How hard task $i$ would pull if given full weight |
| Per-task weight | $\lambda_i$ | How much we let it pull (scaling knob) |
| Effective contribution | $c_i = \lambda_i \lVert g_i \rVert$ | How hard task $i$ ACTUALLY pulls in this step; the only thing the optimiser sees |
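To make the definition concrete, the sketch below builds two made-up gradient vectors on a four-parameter backbone (the vectors are illustrative assumptions, chosen only so that their norms are 5.0 and 0.01) and checks that the norm of each weighted update vector matches $\lambda_i \|g_i\|$.

```python
import numpy as np

# Hypothetical gradient vectors on a 4-parameter shared backbone
# (illustrative values only, chosen so the norms are 5.0 and 0.01).
g_rul = np.array([3.0, 4.0, 0.0, 0.0])
g_health = np.array([0.0, 0.0, 0.006, 0.008])

lambdas = np.array([0.5, 0.5])  # uniform weights
norms = np.array([np.linalg.norm(g_rul), np.linalg.norm(g_health)])

# Effective contribution c_i = lambda_i * ||g_i||
c = lambdas * norms

# The same quantity measured directly on the weighted update vectors.
c_direct = np.array([
    np.linalg.norm(lambdas[0] * g_rul),
    np.linalg.norm(lambdas[1] * g_health),
])

print("norms    :", norms)
print("c        :", c)
print("c_direct :", c_direct)
print("c_rul / c_health:", c[0] / c[1])
```

With uniform weights the ratio of contributions equals the ratio of the gradient norms, which is the imbalance the rest of the section sets out to remove.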

Why Inverse-Proportional Is The Unique Solution

For two tasks, demand $c_{\text{rul}} = c_{\text{health}}$ with $\lambda_{\text{rul}} + \lambda_{\text{health}} = 1$. That is two equations in two unknowns. Writing $g_{\text{rul}}$ and $g_{\text{health}}$ for the two gradient norms, substitute $\lambda_{\text{health}} = 1 - \lambda_{\text{rul}}$ into the balance condition:

$$\lambda_{\text{rul}} \cdot g_{\text{rul}} = (1 - \lambda_{\text{rul}}) \cdot g_{\text{health}}$$

Collect the $\lambda_{\text{rul}}$ terms, $\lambda_{\text{rul}} \, (g_{\text{rul}} + g_{\text{health}}) = g_{\text{health}}$, and solve for $\lambda_{\text{rul}}$:

$$\lambda_{\text{rul}} = \frac{g_{\text{health}}}{g_{\text{rul}} + g_{\text{health}}}$$

That is exactly the K=2 GABA formula. It is the unique solution: any other weighting violates either the simplex constraint or the equal-contribution constraint. The next section (§17.3) re-derives the same expression from a Lagrangian and shows the closed form survives small perturbations.
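The closed form can be checked without any rounding by using exact rational arithmetic. The sketch below uses Python's standard `fractions` module with the 500× example values; the helper `gap` is a name introduced here for illustration only.

```python
from fractions import Fraction

# Gradient norms as exact rationals (the 500x example: 5.0 and 0.01).
g_rul = Fraction(5)
g_health = Fraction(1, 100)

# Closed form for K = 2.
lam_rul = g_health / (g_rul + g_health)
lam_health = 1 - lam_rul

c_rul = lam_rul * g_rul
c_health = lam_health * g_health

print(lam_rul, lam_health)   # 1/501 500/501
print(c_rul, c_health)       # 5/501 5/501


def gap(lam):
    """Difference between the two contributions at weight lam."""
    return lam * g_rul - (1 - lam) * g_health


# gap() is strictly increasing in lam, so it crosses zero exactly once.
nudge = Fraction(1, 10**6)
print(gap(lam_rul - nudge) < 0, gap(lam_rul) == 0, gap(lam_rul + nudge) > 0)
# True True True
```

Because `gap` rises with slope $g_{\text{rul}} + g_{\text{health}} > 0$, any weight other than 1/501 leaves one task pulling harder than the other.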

The intuition is symmetric. “Each task contributes equally” means $\lambda_i = \lambda_j$ (balance of permission). “Each task contributes equally to what is learned” means $\lambda_i \|g_i\| = \lambda_j \|g_j\|$ (balance of effect). They differ by the gradient ratio. With a 500× ratio they differ by 500×.

Plugging in realistic FD002 numbers (paper §12.3: $g_{\text{rul}} = 5.0$, $g_{\text{health}} = 0.01$):

| Scheme | λ_rul | λ_health | c_rul = λ_rul·g_rul | c_health = λ_health·g_health | Imbalance c_max / c_min |
|---|---|---|---|---|---|
| Uniform (Fixed Baseline, paper §3.5) | 0.5000 | 0.5000 | 2.500000 | 0.005000 | 500.00x |
| Sqrt-inverse (softer rule) | 0.0428 | 0.9572 | 0.214035 | 0.009572 | 22.36x |
| GABA (inverse-proportional) | 0.0020 | 0.9980 | 0.009980 | 0.009980 | 1.00x |

Uniform preserves the 500× gradient imbalance in the contributions — doing nothing useful. Sqrt-inverse leaves a 22× residual. Only the exact inverse rule collapses the imbalance to 1.00×.
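The three rows can be read as points on one family, $\lambda_i \propto g_i^{-p}$, with $p = 0$ (uniform), $p = 0.5$ (sqrt-inverse) and $p = 1$ (inverse). Since $c_i \propto g_i^{1-p}$, the residual imbalance for two tasks should be the gradient ratio raised to the power $|1 - p|$. The sketch below checks that numerically; `power_inverse_weights` is a hypothetical helper introduced here, not a function from the paper.

```python
import numpy as np


def power_inverse_weights(grad_norms, p):
    """Weights proportional to g_i ** (-p), normalised to sum to 1.

    p = 0 is uniform, p = 0.5 is sqrt-inverse, p = 1 is the inverse rule.
    """
    inv = grad_norms ** (-p)
    return inv / inv.sum()


g = np.array([5.0, 0.01])
gradient_ratio = g.max() / g.min()  # 500

for p in (0.0, 0.5, 1.0):
    lam = power_inverse_weights(g, p)
    c = lam * g
    imbalance = c.max() / c.min()
    predicted = gradient_ratio ** abs(1.0 - p)
    print(f"p={p:.1f}  imbalance={imbalance:8.2f}  predicted={predicted:8.2f}")
```

Only $p = 1$ drives the exponent to zero, which is why a softer exponent always leaves some residual imbalance.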

Interactive: Where The Curves Cross

Plot two functions of $\lambda_{\text{rul}}$: the increasing line $\lambda \cdot g_{\text{rul}}$ (in blue) and the decreasing line $(1-\lambda) \cdot g_{\text{health}}$ (in green). They cross at exactly one point, which is the GABA $\lambda^*$. Drag the red λ slider; only at the crossing point are the two contribution bars equal.

Try this. Set $g_{\text{rul}} = g_{\text{health}}$ (move both sliders to 1.0). The two lines become mirror images about $\lambda = 0.5$ and cross exactly there, so GABA reduces to vanilla 0.5/0.5 when nothing needs fixing. Now set the ratio to 500× (the paper default). The crossing point snaps far to the left, and only the bottom GABA card shows a balanced ✓.
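Without the interactive widget, the crossing point can still be located numerically. The sketch below is a plain bisection on the difference of the two lines (the function name `crossing_point` and the tolerance are choices made here for illustration) and compares the result with the closed form for the two slider settings mentioned above.

```python
def crossing_point(g_rul, g_health, tol=1e-12):
    """Bisection on f(lam) = lam * g_rul - (1 - lam) * g_health over [0, 1]."""
    lo, hi = 0.0, 1.0  # f(0) = -g_health < 0 and f(1) = g_rul > 0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if mid * g_rul - (1.0 - mid) * g_health < 0.0:
            lo = mid  # crossing lies to the right of mid
        else:
            hi = mid  # crossing lies at or to the left of mid
    return 0.5 * (lo + hi)


for g_rul, g_health in [(1.0, 1.0), (5.0, 0.01)]:
    numeric = crossing_point(g_rul, g_health)
    closed_form = g_health / (g_rul + g_health)
    print(f"g_rul={g_rul}, g_health={g_health}: "
          f"numeric={numeric:.6f}, closed form={closed_form:.6f}")
```

Equal gradients put the crossing at 0.5; the 500× setting pushes it to roughly 0.002.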

Python: Three Schemes, Side By Side

Implement uniform, sqrt-inverse, and GABA in pure NumPy and print their effective contributions on the realistic 500× example. The point of the table is not to show the λ values; it is to show that only GABA equalises $c_i$.

Three weighting schemes vs. effective contributions
🐍effective_contributions.py
Line 1: import numpy as np

NumPy supplies ndarray (a fast C-backed N-dimensional array) and vectorised math. Every weighting scheme below is one or two array operations, with no Python loops over tasks. For the tiny K-element vectors used here the main benefit is brevity; the large speedups over per-element Python loops show up on big arrays.

EXECUTION STATE
📚 numpy = Numerical computing library. Provides ndarray, broadcasting (mixing scalar↔vector↔matrix shapes automatically), universal functions such as np.sqrt, and array constructors such as np.array and np.full_like. Math runs as compiled C, not interpreted Python. Example: np.array([1,2,3]) * 2 → array([2,4,6]), with no explicit loop.
as np = Standard community alias. Lets us write np.array() instead of numpy.array(). Universal in scientific Python.
Line 4: def gaba_weights(grad_norms) → np.ndarray

The canonical K-task GABA formula (paper eq. 4). Takes per-task gradient norms and returns lambdas where the task with the SMALLER gradient receives the LARGER weight. The lambdas sum to 1.0 (probability simplex), up to the negligible 1e-12 floor added to the total.

EXECUTION STATE
⬇ input: grad_norms (K,) — per-task L2 gradient norms = [5.00, 0.01]
→ grad_norms purpose = 1-D ndarray. Each entry is ‖∇_θ ℒ_i‖₂ — how hard task i would pull on the shared backbone if given full weight. Computed in §11 via torch.autograd.grad on the backbone parameters.
→ why we need this = GABA must SEE the imbalance before it can correct it. The function reads gradient magnitudes, then inverts them so the small task is amplified and the loud task attenuated.
⬆ returns: ndarray (K,) = Probability vector that sums to 1.0. Element i is λ_i ∝ (Σ_j g_j − g_i). Examples below show λ for [5.0, 0.01] and other inputs.
→ mini-example for K=3 = grad_norms = [1.0, 2.0, 7.0]; total = 10.0; λ = [(10−1)/(2·10), (10−2)/(2·10), (10−7)/(2·10)] = [0.45, 0.40, 0.15]. Note: the smallest gradient (1.0) gets the largest weight (0.45). ✓
Line 5 (docstring): "Inverse-proportional task weights (paper eq. 4)."

Anchors the function to the canonical paper formula (eq. 4 in the main.tex preprint). Future readers can grep the docstring back to the math source. Docstrings are accessible at runtime via help(gaba_weights) — they document intent, not execution.

Line 6: K = grad_norms.shape[0]

Read the number of tasks from the input array. We use this to scale the denominator below.

EXECUTION STATE
📚 .shape = ndarray attribute (NOT a method — no parentheses). Returns a Python tuple of dimension sizes. For a 1-D array of length 2: (2,). For a 5×4 matrix: (5, 4). Example: np.zeros((3,4)).shape → (3, 4).
→ .shape[0] = First element of the shape tuple = size of axis 0. For grad_norms shape (2,): .shape[0] = 2.
⬇ grad_norms.shape = (2,) — a 1-element tuple containing 2
⬆ K = 2 (Python int). The number of tasks (RUL + health).
Line 7: total = grad_norms.sum() + 1e-12

Sum all per-task gradient norms into a scalar, plus a tiny floor for numerical safety. This becomes the denominator of the GABA formula.

EXECUTION STATE
📚 .sum() = ndarray reduction method. Without arguments, sums ALL elements into a single scalar regardless of shape. Example: np.array([[1,2],[3,4]]).sum() → 10. With axis=0 it sums over the first axis (one total per column); with axis=-1 it sums over the last axis.
⬇ grad_norms = [5.00, 0.01]
→ .sum() result = 5.00 + 0.01 = 5.0100
+ 1e-12 = Numerical safety floor (≈ 0.000000000001). Prevents divide-by-zero in the next line if EVERY task has zero gradient (rare — would only happen if both losses are at a perfect minimum on the current batch). 1e-12 is far below any real gradient magnitude, so it does not bias real computations.
⬆ total = 5.010000000001 ≈ 5.0100
Line 8: return (total - grad_norms) / ((K - 1) * total)

The K-task GABA formula. (total − grad_norms) computes the SUM-OF-OTHERS for each task (the total minus its own contribution). Dividing by (K−1)·total normalises so the resulting weights sum to 1.0 (probability simplex), up to the negligible 1e-12 floor in total.

EXECUTION STATE
(total − grad_norms) — sum of OTHER tasks' gradients = scalar − vector → broadcasting kicks in. NumPy expands the scalar to match the vector shape, then subtracts element-wise. For total=5.01, grad_norms=[5.00, 0.01]: [5.01−5.00, 5.01−0.01] = [0.01, 5.00].
→ result = [0.0100, 5.0000]
→ key insight = Notice the swap: the SMALL gradient (health=0.01) now has the LARGE numerator (5.00), and vice versa. That's the inversion that gives GABA its name.
(K − 1) * total, the normaliser = (2 − 1) · 5.01 = 5.01. Chosen so that, ignoring the 1e-12 floor, summing the K weights gives Σ(total − g_i) / ((K−1)·total) = (K·total − total) / ((K−1)·total) = (K−1)/(K−1) = 1. ✓
→ division = Element-wise vector ÷ scalar (broadcasting again). [0.01, 5.00] / 5.01 = [0.001996, 0.998004].
⬆ return: lambdas (2,) = [0.0020, 0.9980]
→ verify sum = 0.0020 + 0.9980 = 1.0000 ✓ (probability simplex)
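As a side check that is not taken from the paper, the sketch below compares the sum-of-others expression with a normalised $1/g_i$ rule. Both helper names are introduced here for illustration. For K = 2 the two rules coincide. For the K = 3 mini-example the sum-of-others weights [0.45, 0.40, 0.15] still favour the small gradients, but the contributions come out as [0.45, 0.80, 1.05], so exact equalisation under this expression appears to be specific to K = 2.

```python
import numpy as np


def sum_of_others_weights(g):
    """Same expression as gaba_weights, without the 1e-12 floor."""
    K = g.shape[0]
    total = g.sum()
    return (total - g) / ((K - 1) * total)


def normalised_inverse_weights(g):
    """Weights proportional to 1 / g_i, normalised to sum to 1."""
    inv = 1.0 / g
    return inv / inv.sum()


for g in (np.array([5.0, 0.01]), np.array([1.0, 2.0, 7.0])):
    for name, fn in (("sum-of-others", sum_of_others_weights),
                     ("normalised 1/g", normalised_inverse_weights)):
        lam = fn(g)
        c = lam * g
        print(f"K={len(g)} {name:<15} lambda={np.round(lam, 4)} "
              f"c_max/c_min={c.max() / c.min():.2f}")
```

The comparison is only meant to show where the two-task derivation in this section applies directly; how the paper treats K > 2 is outside what this excerpt states.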
Line 11: def uniform_weights(grad_norms) → np.ndarray

The naive 0.5/0.5 baseline. Ignores gradient magnitudes entirely — every task gets the same weight. This is what the paper calls the Fixed Baseline (§3.5) and what most multi-task tutorials default to.

EXECUTION STATE
⬇ input: grad_norms (K,) = [5.00, 0.01]
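→ input purpose = Same signature as gaba_weights for API uniformity, BUT the values are never read; only the shape and dtype are used. We could pass np.array([0.0, 0.0]) and get the same answer. The template must stay floating-point, because np.full_like casts the fill value to the template's dtype.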
⬆ returns: ndarray (K,) = [0.5, 0.5] — uniform 1/K weights for K=2.
→ why included = It's the control group. Showing GABA's c_i values next to uniform's makes the 500× → 1× collapse undeniable.
Line 12 (docstring): "Standard equal weighting - ignores gradient magnitudes."

Records EXACTLY why this is the wrong scheme: it ignores the very thing GABA exists to correct (the gradient imbalance). Documenting failure modes in the docstring keeps future maintainers from 'simplifying' the code by removing GABA.

Line 13: return np.full_like(grad_norms, 1.0 / len(grad_norms))

Build a same-shape, same-dtype array filled with the constant 1/K. It is a shorter spelling of np.full(...) with an explicit shape and dtype.

EXECUTION STATE
📚 np.full_like(a, fill_value) = Build a new ndarray with the SAME shape and dtype as the template array `a`, filled with fill_value. Saves writing np.full(grad_norms.shape, fill_value, dtype=grad_norms.dtype). Example: np.full_like(np.array([1.0, 2.0, 3.0]), 0.5) → array([0.5, 0.5, 0.5]).
⬇ arg 1: grad_norms (template) = Shape (2,), dtype float64. Used ONLY as a shape/dtype template — values are not read.
⬇ arg 2: 1.0 / len(grad_norms) — fill value = 1.0 / 2 = 0.5. (len() returns the size of the first axis for a 1-D array.)
→ why 1.0 / len(...) not 1/2 hardcoded? = Generic: this function works for K=2, K=3, K=10 unchanged. K=3 → fill = 1/3 = 0.3333.
⬆ return = [0.5, 0.5] — both weights equal 1/K.
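Because np.full_like inherits the template's dtype, the template needs to be floating-point for the 1/K fill value to survive. A small sketch of the difference, using throwaway arrays named here for illustration:

```python
import numpy as np

float_template = np.array([5.0, 0.01])  # dtype float64
int_template = np.array([0, 0])         # integer dtype

print(np.full_like(float_template, 0.5))             # [0.5 0.5]
print(np.full_like(int_template, 0.5))               # [0 0]  (0.5 is cast to int)
print(np.full_like(int_template, 0.5, dtype=float))  # [0.5 0.5]
```

Passing `dtype=float` explicitly is one way to make a uniform-weight helper robust to integer inputs.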
Line 16: def sqrt_inverse_weights(grad_norms) → np.ndarray

A 'softer' inverse rule used in some imbalanced-classification literature: λ_i ∝ 1/√g_i instead of GABA's λ_i ∝ 1/g_i. We include it to show that proximity to the inverse rule is not enough — only the EXACT linear inverse equalises c_i. Anything weaker leaves residual imbalance.

EXECUTION STATE
⬇ input: grad_norms (K,) = [5.00, 0.01]
⬆ returns: ndarray (K,) = [0.0428, 0.9572] for the [5.0, 0.01] case. Sum = 1.0.
→ why include this baseline? = Without it, the comparison would be 'naive vs paper' — too easy a contrast. Sqrt-inverse is in the right family but the wrong exponent. Showing it leaves 22× residual proves the LINEAR inverse is doing real work, not just any monotone-decreasing function.
Line 17 (docstring): "Square-root inverse - a 'softer' inverse rule."

Documents that this is intentionally weaker than GABA. Important context — without this note a future maintainer might assume the sqrt was a bug and 'fix' it.

Line 18: inv = 1.0 / np.sqrt(grad_norms + 1e-12)

Compute the per-task softened inverse weights, before normalisation. Square-root first (smooths the gradient ratio), then reciprocal (inverts).

EXECUTION STATE
📚 np.sqrt(x) = Element-wise square root universal function (ufunc). Vectorised C loop under the hood. Example: np.sqrt([4.0, 9.0, 16.0]) → array([2., 3., 4.]).
⬇ arg: grad_norms + 1e-12 = [5.0 + 1e-12, 0.01 + 1e-12] ≈ [5.0000, 0.0100]
→ why + 1e-12? = Numerical guard. If a gradient ever lands at exactly 0.0, sqrt(0)=0, then 1/0 = inf. The 1e-12 floor keeps the reciprocal finite without measurably affecting real values.
→ np.sqrt result = [√5.00, √0.01] = [2.2361, 0.1000]
1.0 / (element-wise reciprocal) = Broadcasting scalar over vector. Returns 1/each element.
→ reciprocal = [1/2.2361, 1/0.1000] = [0.4472, 10.0000]
⬆ inv = [0.4472, 10.0000]
→ ratio so far = 10.0 / 0.4472 = 22.36 — already softer than the 500× original.
Line 19: return inv / inv.sum()

Normalise so the weights live on the probability simplex (sum to 1.0). Without this step the weights would not be comparable across schemes.

EXECUTION STATE
📚 .sum() = Reduces a vector to a scalar by summing all elements. Same as np.sum(arr).
inv.sum() = 0.4472 + 10.0000 = 10.4472
inv / inv.sum() (vector ÷ scalar) = Broadcasting again. [0.4472/10.4472, 10.0000/10.4472].
⬆ return = [0.0428, 0.9572]
→ vs GABA on same input = GABA → [0.0020, 0.9980]; sqrt-inv → [0.0428, 0.9572]. Sqrt-inv is closer to uniform: λ_rul is 21× larger than GABA's. The 'softening' costs us — the residual c imbalance is 22× instead of 1×.
Line 22: # ---------- Realistic FD002 numbers from section 12.3 ----------

A header comment anchoring the demo to the paper's measured 500× imbalance. These are not contrived numbers — they are the actual mean per-batch gradient norms recorded during the FD002 training runs. Source: paper_ieee_tii/data_analysis/.

Line 23: g = np.array([5.0, 0.01])

Pack the two realistic gradient norms (RUL regression head, health classification head) into a length-2 ndarray.

EXECUTION STATE
📚 np.array(object) = Constructor: converts a Python list, tuple, or nested sequence into an ndarray. The dtype is inferred (float64 for floats here). Example: np.array([1, 2, 3]) → ndarray of shape (3,) with the platform's default integer dtype (int64 on most systems).
⬇ arg: [5.0, 0.01] = Python list of two floats. NumPy infers dtype=float64.
→ 5.0 = g_rul = Mean L2 gradient norm of the RUL regression loss on the shared backbone. Large because MSE on RUL targets in [0, 125] produces big residuals early in training.
→ 0.01 = g_health = Mean L2 gradient norm of the health-classification loss on the shared backbone. Small because cross-entropy on a 3-class problem saturates near ln(3) ≈ 1.10 quickly, leaving small residual gradients.
⬆ g (2,) = [5.00, 0.01]
→ ratio = g_rul / g_health = 5.0 / 0.01 = 500.0 ← the imbalance GABA must correct
Line 25: schemes = [('uniform', uniform_weights), ('sqrt_inv', sqrt_inverse_weights), ('gaba', gaba_weights)]

List of (display_name, weighting_function) tuples. The for-loop below iterates over them, applying each function to the same g and printing a comparison row. Keeps the comparison code DRY: change the schemes list to add or remove a scheme; the loop adapts automatically.

EXECUTION STATE
schemes = Python list of 3 tuples. len(schemes) = 3.
→ element 0 = ('uniform', <function uniform_weights>) — string + function reference (not function call).
→ element 1 = ('sqrt_inv', <function sqrt_inverse_weights>)
→ element 2 = ('gaba', <function gaba_weights>)
→ why functions, not results? = Storing references lets us defer the call to inside the loop. We call fn(g) per iteration, so each scheme processes the SAME g.
Line 26: print(f"{'scheme':<10} | lambda_rul lambda_health | c_rul c_health | imbalance")

Print the table header. The f-string format spec :<10 means left-aligned in a 10-character-wide field — keeps columns visually lined up.

EXECUTION STATE
📚 f-string format spec :<10 = Inside an f-string, {value:<10} pads `value` to width 10 with left alignment. :>10 = right-aligned. :^10 = centred. :>10.4f = right-aligned, width 10, 4 decimals (used in the data rows below).
Output = scheme | lambda_rul lambda_health | c_rul c_health | imbalance
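The alignment specs are easiest to see with visible brackets around each field. The sketch below uses a made-up name and value purely to show the padding.

```python
name, value = "gaba", 0.0019960079

print(f"[{name:<10}]")      # [gaba      ]
print(f"[{name:>10}]")      # [      gaba]
print(f"[{name:^10}]")      # [   gaba   ]
print(f"[{value:>10.4f}]")  # [    0.0020]
```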
Line 27: print("-" * 78)

Visual separator. Python's str * int operator repeats a string n times.

EXECUTION STATE
📚 str * int = Python string repetition. "-" * 78 builds a 78-character dash string. Example: "ab" * 3 = "ababab".
Output = ------------------------------------------------------------------------------
Line 28: for name, fn in schemes:

Iterate the three schemes. Tuple-unpacking destructures each (name, fn) tuple into two named variables per iteration.

LOOP TRACE · 3 iterations
iter 1: name='uniform', fn=uniform_weights
lam = uniform_weights(g) = [0.5000, 0.5000]
c = lam * g = [0.5·5.0, 0.5·0.01] = [2.500000, 0.005000]
ratio = c.max() / c.min() = 2.500 / 0.005 = 500.00x — IDENTICAL to the original gradient ratio. Uniform did literally nothing to balance contributions; the imbalance passed through unchanged.
iter 2: name='sqrt_inv', fn=sqrt_inverse_weights
lam = sqrt_inverse_weights(g) = [0.0428, 0.9572]
c = lam * g = [0.0428·5.0, 0.9572·0.01] = [0.214035, 0.009572]
ratio = 0.214035 / 0.009572 = 22.36x — much better than uniform but the loud task still dominates by a factor of 22.
iter 3: name='gaba', fn=gaba_weights
lam = gaba_weights(g) = [0.0020, 0.9980]
c = lam * g = [0.0020·5.0, 0.9980·0.01] = [0.009980, 0.009980]
ratio = 0.009980 / 0.009980 = 1.00x — EXACTLY balanced. This is why inverse-proportional works: c_i collapses to a single scalar regardless of the input ratio.
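To see that the balance does not depend on the size of the input ratio, the sketch below sweeps the gradient ratio over several orders of magnitude using the K = 2 closed form. The helper `gaba_weights_k2` is a name introduced here, and the ratios are arbitrary test values.

```python
import numpy as np


def gaba_weights_k2(g):
    """K = 2 closed form: each weight is the OTHER task's norm over the sum."""
    return g[::-1] / g.sum()


for ratio in (1.0, 10.0, 500.0, 1e6):
    g = np.array([ratio, 1.0])  # g_rul = ratio * g_health
    lam = gaba_weights_k2(g)
    c = lam * g
    print(f"ratio={ratio:>9.0f}  lambda_rul={lam[0]:.6f}  "
          f"c_max/c_min={c.max() / c.min():.6f}")
```

The weight on the loud task shrinks as the ratio grows, while the contribution ratio stays at 1 up to floating-point rounding.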
Line 29: lam = fn(g)

Apply the current scheme's weighting function to the shared gradient-norm vector. Returns the lambda vector for that scheme.

EXECUTION STATE
fn(g) — dynamic dispatch = fn is the function reference unpacked from the tuple. Calling fn(g) invokes whichever weighting strategy this iteration is testing.
⬆ lam (2,) = Per-task weights summing to 1.0. Specific values depend on which scheme is current — see the iterations card on line 28.
Line 30: c = lam * g

Element-wise multiply: c_i = λ_i · g_i. THIS is the effective contribution — the actual magnitude of task i's pull on the backbone after weighting. The whole chapter is about this one line.

EXECUTION STATE
* (element-wise) = NumPy element-wise multiplication between two same-shape arrays. For shape (K,) × (K,) → (K,). Loop-free under the hood.
⬇ lam = [λ_rul, λ_health] — depends on scheme.
⬇ g = [5.00, 0.01]
⬆ c (2,) = [c_rul, c_health] = [λ_rul·5.0, λ_health·0.01]. The vector GABA equalises.
→ why this matters = If c[0] == c[1], the two tasks pull on the backbone with equal effective force. That is the goal — and only GABA achieves it.
Line 31: ratio = c.max() / (c.min() + 1e-12)

Imbalance metric. 1.0× means perfectly balanced; large numbers mean the bigger contribution dominates the smaller. This is the figure of merit for the comparison table.

EXECUTION STATE
📚 .max() = ndarray reduction — returns the maximum element as a scalar. Example: np.array([3, 1, 4]).max() → 4.
📚 .min() = ndarray reduction — returns the minimum element as a scalar. Example: np.array([3, 1, 4]).min() → 1.
+ 1e-12 = Numerical guard so .min() = 0 doesn't divide-by-zero. Same trick as line 7.
⬆ ratio = Scalar. The figure of merit: how many times bigger the loud contribution is than the quiet one. 500× = uniform; 22× = sqrt-inverse; 1× = GABA.
Line 32: print(f"{name:<10} | {lam[0]:>10.4f} {lam[1]:>13.4f} | {c[0]:>9.6f} {c[1]:>10.6f} | {ratio:>8.2f}x")

Print one comparison row per scheme. Each format spec aligns its column. :>10.4f = right-aligned, width 10, 4 decimal places.

EXECUTION STATE
Format specs used = :<10 = name left-aligned 10 wide; :>10.4f = lam right-aligned 10 wide 4dp; :>9.6f = c right-aligned 9 wide 6dp; :>8.2f = ratio right-aligned 8 wide 2dp.
Final printed table (all 3 iterations) =
scheme     | lambda_rul lambda_health |    c_rul   c_health | imbalance
------------------------------------------------------------------------------
uniform    |     0.5000        0.5000 |  2.500000   0.005000 |   500.00x
sqrt_inv   |     0.0428        0.9572 |  0.214035   0.009572 |    22.36x
gaba       |     0.0020        0.9980 |  0.009980   0.009980 |     1.00x
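→ reading the table = Look at the c_rul and c_health columns. Uniform: 500× apart. Sqrt-inv: 22× apart. GABA: identical to the printed precision (the two values differ only at the level of the 1e-12 floor and floating-point rounding). Only GABA achieves the equal-contribution principle this section derives.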
import numpy as np


def gaba_weights(grad_norms: np.ndarray) -> np.ndarray:
    """Inverse-proportional task weights (paper eq. 4)."""
    K = grad_norms.shape[0]
    total = grad_norms.sum() + 1e-12
    return (total - grad_norms) / ((K - 1) * total)


def uniform_weights(grad_norms: np.ndarray) -> np.ndarray:
    """Standard equal weighting - ignores gradient magnitudes."""
    return np.full_like(grad_norms, 1.0 / len(grad_norms))


def sqrt_inverse_weights(grad_norms: np.ndarray) -> np.ndarray:
    """Square-root inverse - a 'softer' inverse rule."""
    inv = 1.0 / np.sqrt(grad_norms + 1e-12)
    return inv / inv.sum()


# ---------- Realistic FD002 numbers from section 12.3 ----------
g = np.array([5.0, 0.01])

schemes = [("uniform", uniform_weights), ("sqrt_inv", sqrt_inverse_weights), ("gaba", gaba_weights)]
print(f"{'scheme':<10} | lambda_rul lambda_health |    c_rul   c_health | imbalance")
print("-" * 78)
for name, fn in schemes:
    lam = fn(g)
    c = lam * g
    ratio = c.max() / (c.min() + 1e-12)
    print(f"{name:<10} | {lam[0]:>10.4f} {lam[1]:>13.4f} | {c[0]:>9.6f} {c[1]:>10.6f} | {ratio:>8.2f}x")
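One edge case may be worth a look, offered as an aside and not as something the paper states: when every gradient norm is exactly zero, the 1e-12 floor avoids a division by zero, but for K = 2 the listing's formula then returns [1.0, 1.0], which sums to 2 and is off the simplex. The sketch below reproduces that and shows one possible guard; `gaba_weights_guarded` is a hypothetical variant that falls back to uniform weights.

```python
import numpy as np


def gaba_weights(grad_norms):
    """Copy of the listing above, repeated so this sketch runs on its own."""
    K = grad_norms.shape[0]
    total = grad_norms.sum() + 1e-12
    return (total - grad_norms) / ((K - 1) * total)


def gaba_weights_guarded(grad_norms, eps=1e-12):
    """Hypothetical variant: fall back to uniform when all gradients vanish."""
    K = grad_norms.shape[0]
    total = grad_norms.sum()
    if total <= eps:
        return np.full(K, 1.0 / K)
    return (total - grad_norms) / ((K - 1) * total)


zeros = np.zeros(2)
print(gaba_weights(zeros), gaba_weights(zeros).sum())
print(gaba_weights_guarded(zeros), gaba_weights_guarded(zeros).sum())
print(gaba_weights_guarded(np.array([5.0, 0.01])))
```

Whether a uniform fallback is the right behaviour depends on the training setup; it is shown only as one reasonable default.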

PyTorch: Verification On Real Autograd Gradients

The hand-picked NumPy demo is an existence proof. The empirical demo is a real autograd run on a tiny shared backbone. We compute $\|g_{\text{rul}}\|$ and $\|g_{\text{health}}\|$ with torch.autograd.grad, apply the K=2 GABA closed form, and assert $c_{\text{rul}} = c_{\text{health}}$ to within $10^{-6}$. The assertion is expected to pass for any seed, batch, or backbone, because the equality is an algebraic identity; only floating-point rounding separates the two sides.
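The same identity can also be checked without autograd, using hand-derived gradients for a linear backbone with two linear heads. The sketch below is a NumPy stand-in under stated assumptions: the initialisation scale, the random data, and the head layout (stored as (hidden, out) instead of PyTorch's (out, in)) are choices made here, and `backbone_grad_norm` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, H, C = 64, 14, 32, 3  # batch, sensors, hidden features, classes

# Randomly initialised linear backbone and heads (illustrative, not trained).
W, b = rng.normal(0, 0.1, (H, D)), np.zeros(H)
w_rul, b_rul = rng.normal(0, 0.1, (H, 1)), np.zeros(1)
W_hp, b_hp = rng.normal(0, 0.1, (H, C)), np.zeros(C)

x = rng.normal(size=(N, D))
rul_target = rng.uniform(0.0, 125.0, size=(N, 1))
hp_target = rng.integers(0, C, size=N)

feat = x @ W.T + b  # (N, H) shared features


def backbone_grad_norm(d_feat):
    """L2 norm of the gradient on (W, b), given dLoss/dfeat of shape (N, H)."""
    dW = d_feat.T @ x            # (H, D), same shape as W
    db = d_feat.sum(axis=0)      # (H,), same shape as b
    return np.sqrt((dW ** 2).sum() + (db ** 2).sum())


# MSE: the derivative of mean((pred - y)^2) w.r.t. pred is 2 * (pred - y) / N.
pred = feat @ w_rul + b_rul
d_feat_rul = (2.0 * (pred - rul_target) / N) @ w_rul.T

# Cross-entropy: the derivative of the mean loss w.r.t. logits is
# (softmax - onehot) / N.
logits = feat @ W_hp + b_hp
shifted = logits - logits.max(axis=1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
onehot = np.eye(C)[hp_target]
d_feat_hp = ((probs - onehot) / N) @ W_hp.T

g_rul = backbone_grad_norm(d_feat_rul)
g_health = backbone_grad_norm(d_feat_hp)

lam_rul = g_health / (g_rul + g_health)  # K = 2 closed form
lam_health = 1.0 - lam_rul
c_rul, c_health = lam_rul * g_rul, lam_health * g_health

print(f"g_rul={g_rul:.4f}  g_health={g_health:.4f}")
print(f"c_rul={c_rul:.6f}  c_health={c_health:.6f}")
assert np.isclose(c_rul, c_health, rtol=1e-6)
```

The gradient norms themselves depend on the random draw, but the final assertion does not, which is the point of the verification.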

GABA equalises effective contributions — proven by autograd
🐍verify_equal_contributions.py
Line 1: import torch

Core PyTorch module. Provides the Tensor class (an N-dim numerical array on CPU/GPU) and the autograd engine that records every operation so we can differentiate. We need autograd here because the whole point of this script is to verify the equal-contribution identity on REAL backbone gradients computed by the framework, not made-up scalars.

EXECUTION STATE
📚 torch = Tensor library + autograd. Provides torch.randn (sample N(0,1)), torch.rand (sample U[0,1)), torch.randint (uniform ints), torch.autograd.grad (functional differentiation), torch.allclose (numerical equality check).
→ why a framework? = Computing gradients of a 480-param backbone by hand would take pages of chain-rule. Autograd does it in one call.
Line 2: import torch.nn as nn

PyTorch's neural-network submodule. Provides high-level layer primitives so we don't have to allocate weight tensors and write y = x @ W.T + b by hand. We use nn.Linear for the backbone and both heads.

EXECUTION STATE
📚 torch.nn = Layer primitives + loss functions. Exports nn.Linear (fully-connected), nn.Conv2d, nn.LSTM, nn.functional (functional API for losses like cross_entropy).
as nn = Universal alias in the PyTorch community. Lets us write nn.Linear(...) instead of torch.nn.Linear(...).
Line 4: torch.manual_seed(0)

Set the global PyTorch PRNG seed so the experiment is reproducible. Without this, every run produces different gradient norms — and the printed numbers in this section would not be stable across executions.

EXECUTION STATE
📚 torch.manual_seed(seed) = Sets the global PyTorch PRNG. Affects torch.randn, torch.rand, torch.randint, and the random initialisation of nn.Linear weights going forward. Returns a torch.Generator.
⬇ arg: seed = 0 = Any int works; 0 is a common choice in tutorials. Same seed + same code = same numbers on the same PyTorch version and hardware; bit-exact agreement across releases or platforms is not guaranteed.
→ what it determines = The init weights of backbone, rul_head, hp_head; the random batch x; the rul_target; the hp_target. ALL stochastic state below is fixed by this one call.
Line 7: backbone = nn.Linear(14, 32)

Tiny shared backbone — a single fully-connected layer. 14 sensor channels in (matching §5's C-MAPSS 14-sensor subset) → 32 hidden features out. The two heads will both read from these 32 features, which is what makes them share parameters.

EXECUTION STATE
📚 nn.Linear(in_features, out_features, bias=True) = Fully-connected layer. Stores a learnable weight matrix W of shape (out, in) and bias b of shape (out,). Forward pass: y = x @ W.T + b. Example: nn.Linear(14, 32) creates W shape (32, 14) — 448 weights — and b shape (32,) — 32 biases — for 480 total params.
⬇ arg 1: in_features = 14 = Input dimensionality. Matches §5's recommended C-MAPSS 14-sensor subset (after dropping low-information channels). Sets the number of COLUMNS in W.
⬇ arg 2: out_features = 32 = Output dimensionality. Small but non-trivial — large enough that the gradient norms differ realistically across the two heads. Sets the number of ROWS in W.
⬆ backbone = nn.Linear module. backbone.weight has shape (32, 14); backbone.bias has shape (32,). Total 480 learnable params — exactly what shared_grad_norm() will differentiate against.
Line 8: rul_head = nn.Linear(32, 1)

RUL regression head. Maps the 32 shared features to a single scalar — the predicted Remaining Useful Life — which we will train with mean-squared-error against rul_target.

EXECUTION STATE
⬇ arg 1: in_features = 32 = Must match the backbone's out_features (32). Otherwise the matmul backbone(x) @ rul_head.weight.T fails.
⬇ arg 2: out_features = 1 = Single scalar — the predicted RUL value. Trained with MSE on a real-valued target in [0, 125].
⬆ rul_head = nn.Linear(32, 1). 32 weights + 1 bias = 33 params. NOT shared — each head has its own.
Line 9: hp_head = nn.Linear(32, 3)

Health-state classification head. Maps the 32 shared features to 3 logits (Normal / Degrading / Critical), which cross_entropy will turn into class probabilities and a scalar loss.

EXECUTION STATE
⬇ arg 1: in_features = 32 = Same as rul_head: must match backbone output.
⬇ arg 2: out_features = 3 = One logit per health class. Cross-entropy internally applies log-softmax then negative log-likelihood.
⬆ hp_head = nn.Linear(32, 3). 96 weights + 3 biases = 99 params.
→ why two heads? = Multi-task learning: one shared trunk learns generic sensor features; specialised heads decode them per task. The whole reason GABA exists is to balance how strongly each head pulls on the trunk.
Line 12: x = torch.randn(64, 14)

Synthetic mini-batch. 64 samples, each a 14-dim vector — stand-in for one batch of normalised sensor windows from C-MAPSS.

EXECUTION STATE
📚 torch.randn(*size) = Sample a tensor from N(0, 1) (mean 0, variance 1). Variadic in size: torch.randn(64, 14) → shape (64, 14). Example: torch.randn(3) → tensor of 3 standard-normal floats.
⬇ arg: 64, 14 = Shape spec. Each positional argument is one dimension. (64, 14) = 64 rows × 14 cols.
→ 64 = batch size = Number of independent samples per gradient step. Typical mini-batch size.
→ 14 = feature dim = Matches backbone's in_features. Each sample is one 14-D sensor reading.
⬆ x = Tensor (64, 14), dtype float32. Standard-normal random — fine here because we only care about gradient MAGNITUDES, not loss MEANINGS.
Line 13: rul_target = torch.rand(64, 1) * 125.0

Random RUL targets uniformly in [0, 125), bounded above by the paper's piecewise-linear RUL cap. We multiply by 125 because torch.rand samples from [0, 1).

EXECUTION STATE
📚 torch.rand(*size) = Uniform in [0, 1). DIFFERENT from torch.randn (normal). Example: torch.rand(2) → tensor([0.41, 0.83]).
⬇ arg: 64, 1 = Shape (64, 1) — column vector matching rul_head's output shape (so MSE works without broadcasting).
* 125.0 = Scale [0, 1) → [0, 125). Picked because 125 is the §5 RUL clip — lifetimes longer than 125 cycles are clamped to 125 in the paper.
⬆ rul_target = Tensor (64, 1), dtype float32. Synthetic regression labels in [0, 125).
Line 14: hp_target = torch.randint(0, 3, (64,))

Random int64 health labels — class indices in {0, 1, 2}. cross_entropy expects a 1-D int64 target tensor (NOT one-hot encoded), so we sample integers, not floats.

EXECUTION STATE
📚 torch.randint(low, high, size) = Uniform integer tensor with values in [low, high). Note: high is EXCLUSIVE. Default dtype = int64 (torch.long), exactly what cross_entropy expects. Example: torch.randint(0, 3, (5,)) → tensor([1, 2, 0, 2, 1]).
⬇ arg 1: low = 0 = Inclusive lower bound for class indices.
⬇ arg 2: high = 3 = Exclusive upper bound. Combined with low=0 → values in {0, 1, 2} = our 3 health classes.
⬇ arg 3: size = (64,) = Tuple — produces a 1-D tensor of length 64. Note the trailing comma: (64,) is a tuple, (64) is just the int 64.
⬆ hp_target = Tensor (64,), dtype int64. Class indices for the 3 health states.
Line 17: feat = backbone(x)

Forward pass through the shared backbone. (64, 14) input → (64, 32) features. Both heads read from `feat`, so any gradient flowing back from either head's loss reaches the same backbone parameters — that's what 'parameter sharing' means in PyTorch.

EXECUTION STATE
backbone(x) — calls __call__ which calls forward() = Internally: feat = x @ backbone.weight.T + backbone.bias. backbone.weight shape (32,14), .T → (14,32). x (64,14) @ W.T (14,32) = (64,32). + bias (32,) broadcasts.
⬆ feat = Tensor (64, 32) with requires_grad=True (autograd is recording). Shared hidden representation. KEY: both rul_loss and health_loss will flow gradients back through this tensor to the SAME backbone params.
Line 18: rul_loss = ((rul_head(feat) - rul_target) ** 2).mean()

Mean-squared-error loss on RUL. (predicted − target) ², averaged over the batch. One line of standard PyTorch — no nn.MSELoss object needed.

EXECUTION STATE
rul_head(feat) — predicted RULs = Tensor (64, 1). Forward through the regression head: feat @ rul_head.weight.T + rul_head.bias.
− rul_target = Element-wise subtraction (shapes match: (64,1) − (64,1)).
** 2 (element-wise square) = Operator overload: ** delegates to .pow(2). Squaring produces non-negative values.
📚 .mean() = Tensor reduction: averages ALL elements (64 here) into a 0-dim scalar tensor. With a dim= argument it would reduce along that dimension instead.
⬆ rul_loss = 0-dim tensor ≈ 5318. Large because predictions are near 0 (random init) but targets are O(60).
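A rough sanity check on that magnitude: if the predictions sat at exactly 0 and the targets were uniform on [0, 125), the expected squared error would be $125^2 / 3 \approx 5208$. The sketch below confirms this by Monte Carlo in NumPy; pinning the predictions at zero is a simplifying assumption, and a 64-sample batch will scatter around this value.

```python
import numpy as np

rng = np.random.default_rng(0)
targets = rng.uniform(0.0, 125.0, size=1_000_000)

# With predictions pinned at 0, the MSE is just the mean of target ** 2.
mse_if_pred_zero = np.mean(targets ** 2)
analytic = 125.0 ** 2 / 3.0  # E[U^2] for U uniform on [0, 125)

print(f"Monte Carlo: {mse_if_pred_zero:.1f}   analytic: {analytic:.1f}")
```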
Line 19: health_loss = nn.functional.cross_entropy(hp_head(feat), hp_target)

Standard 3-class cross-entropy loss. Internally combines log_softmax over the 3 logits and negative-log-likelihood against the int64 target.

EXECUTION STATE
📚 nn.functional.cross_entropy(input, target, reduction='mean') = PyTorch loss function. Inputs: logits shape (B, C); int64 target shape (B,). Returns 0-dim scalar by default. Combines log_softmax + nll_loss in one numerically stable kernel.
⬇ arg 1: hp_head(feat) — logits = Tensor (64, 3). Raw scores BEFORE softmax. cross_entropy applies softmax internally — DON'T do it yourself or you'll double-softmax.
⬇ arg 2: hp_target — class indices = Tensor (64,) int64. Each element is the correct class index in {0, 1, 2}. NOT one-hot encoded.
⬆ health_loss = 0-dim tensor ≈ 1.10. ≈ ln(3) because random logits give ~uniform probabilities → −log(1/3) = ln 3 ≈ 1.0986.
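The ln 3 starting point can be reproduced without PyTorch. This is a minimal single-sample sketch of cross-entropy from raw logits using the log-sum-exp trick. The function name is illustrative, and it only mirrors what the description above says cross_entropy does; it is not the library's kernel.

```python
import math


def cross_entropy_from_logits(logits, target):
    """Single-sample cross-entropy: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Equal logits give uniform probabilities, so the loss is ln(3).
uniform_loss = cross_entropy_from_logits([0.0, 0.0, 0.0], target=1)
# A confident, correct prediction drives the loss towards 0.
confident_loss = cross_entropy_from_logits([0.0, 5.0, 0.0], target=1)

print(f"uniform logits  : {uniform_loss:.4f}")
print(f"confident logits: {confident_loss:.4f}")
```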
Line 22: def shared_grad_norm(loss, params) → torch.Tensor

Helper that returns the L2 norm of a loss's gradient on a list of parameters — WITHOUT writing into the parameters' .grad attribute. This non-mutating measurement is crucial: GABA needs to read both task gradients before deciding the lambdas, and we don't want measurement to side-effect the eventual loss.backward() step.

EXECUTION STATE
⬇ input: loss — scalar to differentiate = 0-dim tensor. Either rul_loss or health_loss in our case.
⬇ input: params — what to differentiate w.r.t. = List of nn.Parameter objects (typically the shared backbone parameters). Order doesn't matter for the L2 norm.
⬆ returns: 0-dim tensor = ‖∇_params loss‖₂ = sqrt(Σ_p ‖g_p‖²) — the L2 norm of the concatenated per-parameter gradient. The single scalar GABA needs.
→ why 'shared'? = We always pass the SHARED backbone params, not head params. GABA balances tasks at the point where they collide (the trunk), not at the heads where they're already independent.
Line 23: docstring: "L2 norm of loss's gradient on the given parameters."

Documents intent. Critical implicit detail: this function does NOT mutate p.grad. Future maintainers seeing autograd.grad below sometimes assume it side-effects .grad like loss.backward() does — the docstring is the first warning that it doesn't.

Line 24: grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)

Functional differentiation. Returns a tuple of gradient tensors, one per param, WITHOUT writing into p.grad. This lets us call grad() once for rul_loss and once for health_loss, then later call backward() on the weighted combination — all on the same forward graph.

EXECUTION STATE
📚 torch.autograd.grad(outputs, inputs, retain_graph=False, allow_unused=False) = Functional differentiation API. Returns ∂outputs/∂inputs as a tuple of tensors of shapes matching `inputs`. Pure-functional — does NOT mutate .grad. Contrast with loss.backward() which DOES accumulate into .grad.
⬇ arg 1: outputs = loss = Scalar tensor to differentiate. A non-scalar output would also need grad_outputs.
⬇ arg 2: inputs = params = List of leaf tensors with requires_grad=True. We get gradients of `loss` w.r.t. each.
⬇ arg 3: retain_graph = True = Keep the autograd computation graph alive AFTER this call. Default is False — the graph is freed and a second grad()/backward() would error. We set True because we will call grad() again for the OTHER task's loss on the same graph.
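→ mini-example: retain_graph = x = torch.randn(3, requires_grad=True); loss = (x ** 2).sum(); torch.autograd.grad(loss, [x]) # OK; torch.autograd.grad(loss, [x]) # RuntimeError: saved tensors already freed. With retain_graph=True on the first call, the second call works. (A bare x.sum() may not reproduce the error, because its backward pass saves no tensors.)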
⬇ arg 4: allow_unused = True = If a parameter does NOT affect `loss` (e.g. hp_head params have no effect on rul_loss), return None for that entry instead of raising. Important for shared-backbone setups where each task's loss only flows through its own head, not the other head's params.
⬆ grads = Tuple of length 2 (one entry per param in `shared`). Each entry is a Tensor of the same shape as the param, OR None (filtered out below).
Line 25: sq = sum((g.detach().norm() ** 2 for g in grads if g is not None))

Sum of squared per-parameter L2 norms. Uses a generator expression (lazy iterator) so we don't allocate an intermediate list. The 'if' filter discards Nones from allow_unused.

EXECUTION STATE
for g in grads = Iterates over the gradient tuple. Each g is either a Tensor (same shape as the param) or None.
if g is not None = Skip None entries returned by allow_unused for params that didn't participate in the loss.
📚 .detach() = Tensor method. Returns a new tensor sharing storage but DETACHED from the autograd graph (requires_grad=False). Because autograd.grad was called without create_graph=True, the returned gradients already carry no graph, so this is a defensive step that guarantees the norm computation is never differentiated.
📚 .norm(p=2) = L2 norm by default (Frobenius for matrices). For a tensor T: returns sqrt(Σ T_ij²) as a 0-dim tensor. Example: torch.tensor([3., 4.]).norm() → 5.0.
** 2 = Square the per-param L2 norm. So we have Σ_p ‖g_p‖² — building the squared total norm before taking the global sqrt.
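📚 sum(iterable) = Python builtin. Starts from the int 0 and adds each tensor in turn (0 + Tensor, then Tensor + Tensor), returning a 0-dim tensor with the running sum. If every gradient were None it would return the plain int 0, and the .sqrt() call on the next line would fail.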
⬆ sq = 0-dim tensor. = Σ_p ‖g_p‖² over the 2 backbone params (weight + bias). Always non-negative.
Line 26: return sq.sqrt()

Take the global square root. sqrt(Σ_p ‖g_p‖²) is exactly the L2 norm of the CONCATENATED gradient vector — the single number GABA uses as ‖g_i‖.

EXECUTION STATE
📚 .sqrt() = Tensor element-wise square root method. On a 0-dim tensor, returns a 0-dim tensor. Example: torch.tensor(25.0).sqrt() → tensor(5.0).
⬆ return = 0-dim tensor. ‖∇_params loss‖₂ — the per-task gradient norm GABA reads.
→ identity = sqrt(Σ ‖g_p‖²) = ‖concat(g_p)‖₂. The L2 norm of a concatenated vector equals the sqrt of the sum of squared sub-norms.
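The identity is easy to check on small numbers. The sketch below uses two made-up flattened gradients as stand-ins for the weight and bias gradients. The values are hypothetical, chosen so the norms come out as integers.

```python
import math


def l2(values):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values))


# Hypothetical per-parameter gradients, already flattened.
g_weight = [3.0, 0.0, 4.0]  # norm 5
g_bias = [12.0]             # norm 12

# Route 1: square each sub-norm, add, take one global square root.
via_subnorms = math.sqrt(l2(g_weight) ** 2 + l2(g_bias) ** 2)
# Route 2: concatenate first, then take a single norm.
via_concat = l2(g_weight + g_bias)

print(via_subnorms, via_concat)
```

Both routes give 13, which is why shared_grad_norm never has to build the concatenated vector.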
Line 29: shared = list(backbone.parameters())

Materialise the backbone's parameter iterator into a Python list. We need to consume it TWICE (once per task) below, but iterators only walk forward once — converting to a list lets us re-iterate.

EXECUTION STATE
📚 nn.Module.parameters() = Generator method. Yields every nn.Parameter under this module recursively. Returns an iterator, NOT a list — you can iterate once.
📚 list(iterable) = Force materialisation of an iterable into a Python list. After this we can iterate `shared` as many times as we want.
⬆ shared = List of 2 nn.Parameter: [0] backbone.weight — Tensor shape (32, 14), 448 params [1] backbone.bias — Tensor shape (32,), 32 params Total: 480 learnable params.
Line 30: g_rul = shared_grad_norm(rul_loss, shared)

Compute the L2 norm of the RUL loss's gradient on the shared backbone parameters. This is the FIRST of two grad() calls on the same forward graph — that's why retain_graph=True.

EXECUTION STATE
⬇ arg 1: rul_loss = 0-dim tensor ≈ 5318. The MSE scalar.
⬇ arg 2: shared = List of 2 backbone params — what we differentiate against.
⬆ g_rul = 0-dim tensor ≈ 76.5335. Real measured ‖∇_θ rul_loss‖₂. Larger than §12.3's idealised 5.0 only because this toy backbone has more parameters and an O(60) target scale.
Line 31: g_health = shared_grad_norm(health_loss, shared)

SECOND grad() call, on the same forward graph (allowed because the first call used retain_graph=True). Returns the L2 norm of the health loss's gradient on the shared backbone — the OTHER quantity GABA needs.

EXECUTION STATE
⬇ arg 1: health_loss = 0-dim tensor ≈ 1.10. The cross-entropy scalar.
⬇ arg 2: shared = Same list as before — both tasks share these params.
⬆ g_health = 0-dim tensor ≈ 0.3665. Two orders of magnitude smaller than g_rul. The §12.3 imbalance reproduces in miniature even on this toy setup.
→ empirical ratio = g_rul / g_health = 76.5335 / 0.3665 ≈ 208.8x — the exact imbalance GABA must correct on this seed.
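To see what this ratio means before any rebalancing, the sketch below plugs the two measured norms into uniform weights (λ = 0.5 each). The "share" is computed from summed magnitudes and ignores gradient direction, which is a simplification made for illustration.

```python
# Gradient norms reported above for this seed.
g_rul, g_health = 76.5335, 0.3665

lam_uniform = 0.5
c_rul_uniform = lam_uniform * g_rul
c_health_uniform = lam_uniform * g_health

# Equal weights leave the raw imbalance untouched.
ratio = c_rul_uniform / c_health_uniform
# Fraction of the summed pull that comes from the health task.
health_share = c_health_uniform / (c_rul_uniform + c_health_uniform)

print(f"ratio under uniform weights = {ratio:.1f}x")
print(f"health share of total pull  = {health_share:.2%}")
```

Under equal weights the health task supplies under half a percent of the combined magnitude, which is the situation the inverse weights on lines 34 to 36 are built to correct.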
Line 34: S = g_rul + g_health

Sum of the two gradient norms. K=2 normaliser — divides through both lambdas so they sum to 1.0 (probability simplex).

EXECUTION STATE
+ on 0-dim tensors = Returns a new 0-dim tensor.
⬆ S = 0-dim tensor ≈ 76.9000.
Line 35: lam_rul = g_health / S

K=2 closed-form GABA: λ_rul = g_health / (g_rul + g_health). The OTHER task's gradient norm sits in the numerator — that's the inversion that lets the loud task get the small weight.

EXECUTION STATE
g_health = ≈ 0.3665
S = ≈ 76.9000
/ on 0-dim tensors = Element-wise division. Returns a 0-dim tensor.
⬆ lam_rul = 0-dim tensor ≈ 0.004766. ~0.48% of the loss combination goes to RUL.
Line 36: lam_health = g_rul / S

Mirror equation: λ_health = g_rul / (g_rul + g_health). Same swap — the LOUD task's norm in the numerator means the QUIET task gets the big weight.

EXECUTION STATE
⬆ lam_health = 0-dim tensor ≈ 0.995234. ~99.52% goes to health.
→ simplex sanity check = lam_rul + lam_health = 0.004766 + 0.995234 = 1.000000 ✓ — by construction, since both numerators sum to S itself.
Line 39: c_rul = lam_rul * g_rul

Effective contribution of the RUL task to the backbone update. = λ_rul · ‖g_rul‖. The §17.2 'c' quantity in real autograd numbers.

EXECUTION STATE
lam_rul * g_rul = (g_health/S) · g_rul = g_rul · g_health / S
→ numeric = 0.004766 · 76.5335 ≈ 0.364791
⬆ c_rul = 0-dim tensor ≈ 0.364791.
Line 40: c_health = lam_health * g_health

Effective contribution of the health task. = λ_health · ‖g_health‖.

EXECUTION STATE
lam_health * g_health = (g_rul/S) · g_health = g_rul · g_health / S
→ numeric = 0.995234 · 0.3665 ≈ 0.364791
⬆ c_health = 0-dim tensor ≈ 0.364791. IDENTICAL to c_rul (modulo float32 noise).
→ why identical? = Algebraically both reduce to g_rul · g_health / S. The closed form FORCES equality on every batch, every backbone, every seed — not approximately, exactly.
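The same closed form can be exercised on plain floats, independent of autograd. This is a minimal sketch: the function name and the zero-norm guard are additions for illustration, and the second and third input pairs are hypothetical.

```python
def gaba_weights_k2(g_a: float, g_b: float):
    """K=2 inverse-proportional weights: each task gets the OTHER task's norm."""
    total = g_a + g_b
    if total <= 0.0:
        # Guard added for this sketch; the walkthrough script has no such check.
        raise ValueError("gradient norms must not both be zero")
    return g_b / total, g_a / total


cases = [(76.5335, 0.3665), (5.0, 0.01), (1.0, 1.0)]
results = []
for g_a, g_b in cases:
    lam_a, lam_b = gaba_weights_k2(g_a, g_b)
    c_a, c_b = lam_a * g_a, lam_b * g_b  # effective contributions
    results.append((lam_a, lam_b, c_a, c_b))
    print(f"g=({g_a}, {g_b})  lambda=({lam_a:.6f}, {lam_b:.6f})  c=({c_a:.6f}, {c_b:.6f})")
```

Whatever the input pair, the two contributions agree up to floating-point round-off and the weights sum to 1.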
Line 42: print(f"g_rul = {g_rul.item():.4f}")

Pretty-print the RUL gradient norm. f-string interpolation with format spec :.4f → 4 decimal places.

EXECUTION STATE
📚 .item() = Tensor method. Converts a 0-dim tensor to a plain Python scalar (float / int). Calling it explicitly makes the conversion visible and guarantees the format spec is applied to an ordinary Python number.
→ format spec :.4f = Floating-point, 4 digits after the decimal. Example: f"{3.141592:.4f}" → '3.1416'.
Output = g_rul = 76.5335
Line 43: print(f"g_health = {g_health.item():.4f}")

Same pattern for the health gradient norm.

EXECUTION STATE
Output = g_health = 0.3665
Line 44: print(f"ratio = {(g_rul / g_health).item():.1f}x")

Empirical gradient ratio for this seed/batch. Computed inline as g_rul / g_health then converted to a Python float for printing.

EXECUTION STATE
(g_rul / g_health).item() = Compute tensor division → 0-dim tensor → .item() → Python float.
Output = ratio = 208.8x
→ meaning = The RUL gradient is 208.8× larger than the health gradient on this batch. Without GABA, the optimiser would barely 'feel' the health loss.
Line 45: print(f"lambda_rul = {lam_rul.item():.6f}")

Per-task RUL weight. 6 decimals to make the 0.004766 visible (with 4dp it would round to 0.0048).

EXECUTION STATE
Output = lambda_rul = 0.004766
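The choice of precision is easy to confirm with plain Python floats. The snippet below formats the printed λ value with both specs and repeats the scientific-notation example used later in the walkthrough.

```python
lam_rul = 0.004766  # value printed above

four_dp = f"{lam_rul:.4f}"       # rounds away the detail we care about
six_dp = f"{lam_rul:.6f}"        # keeps all four significant digits
sci = f"{0.0000000298:.2e}"      # scientific notation, two decimals

print(four_dp, six_dp, sci)
```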
Line 46: print(f"lambda_health = {lam_health.item():.6f}")

Per-task health weight. ~99.5% → most of the loss combination.

EXECUTION STATE
Output = lambda_health = 0.995234
Line 47: print(f"c_rul = {c_rul.item():.6f}")

RUL effective contribution. The first half of the equality we're proving.

EXECUTION STATE
Output = c_rul = 0.364791
Line 48: print(f"c_health = {c_health.item():.6f}")

Health effective contribution. The second half. Side-by-side with c_rul above, the equality is visually obvious.

EXECUTION STATE
Output = c_health = 0.364791
Line 49: print(f"|c_rul - c_health| = {(c_rul - c_health).abs().item():.2e}")

Numerical-precision sanity check. The difference is single-precision round-off, NOT a real gap.

EXECUTION STATE
📚 .abs() = Tensor element-wise absolute value. On a 0-dim tensor returns a 0-dim non-negative tensor.
→ format spec :.2e = Scientific notation, 2 decimals. Example: f"{0.0000000298:.2e}" → '2.98e-08'.
Output = |c_rul - c_health| = 2.98e-08
→ meaning = ≈ 10⁻⁸ — comfortably inside float32 epsilon (≈ 1.19e-7). The two contributions are EQUAL to numerical precision; the residual is machine noise, not algorithmic error.
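The residual 2.98e-08 is not an arbitrary small number. The sketch below, using only the standard library, computes the spacing between adjacent float32 values near 0.3648. The helper names `to_float32` and `float32_ulp` are illustrative, and the spacing formula assumes a positive value in the normal range.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def float32_ulp(x: float) -> float:
    """Spacing between adjacent float32 values near x (normal range, x > 0)."""
    exponent = math.floor(math.log2(x))  # x lies in [2**exponent, 2**(exponent + 1))
    return 2.0 ** (exponent - 23)        # float32 stores 23 fraction bits


c = to_float32(0.364791)
ulp = float32_ulp(c)
print(f"float32 spacing near {c:.6f} = {ulp:.2e}")
```

The spacing is 2⁻²⁵ ≈ 2.98e-08, so the printed difference is consistent with the two contributions landing on neighbouring float32 values, one unit in the last place apart.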
Line 51: assert torch.allclose(c_rul, c_health, atol=1e-6), "GABA must equalise contributions"

Hard-fail if the equality breaks. Anchors the 'why it works' claim in code, not just prose. The check runs every time the script executes, so the chapter's central identity is regression-tested on every run.

EXECUTION STATE
📚 torch.allclose(input, other, rtol=1e-5, atol=1e-8) = Returns True if |input − other| ≤ atol + rtol·|other| element-wise. The tolerant equality check for floating-point. Example: torch.allclose(torch.tensor(1.0), torch.tensor(1.0000001)) → True.
⬇ arg 1: c_rul = First operand. 0-dim tensor.
⬇ arg 2: c_health = Second operand. 0-dim tensor.
⬇ arg 3: atol = 1e-6 = Absolute tolerance. Comfortably above float32 noise (~1e-7) but several orders of magnitude below ANY real imbalance (the smallest meaningful c-difference would be at least ~1e-3 here).
assert <expr>, <msg> = Python statement. If <expr> is False, raise AssertionError with <msg>. If True, no-op.
→ if this assertion ever fails = Two scenarios: (1) you implemented GABA wrong (typo in the closed form), or (2) you used the K-task formula for K>2 — that formula equalises a DIFFERENT quantity (sum-of-others), not c_i pairwise. See §17.3 for K>2 details.
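Because the call leaves rtol at its default, the tolerance actually applied is atol + rtol·|c_health|, not atol alone. The sketch below re-implements the documented rule for two scalars (the function name is illustrative) and applies it to the values from this run.

```python
def allclose_scalar(a: float, b: float, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Scalar form of the documented rule |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


c_health = 0.364791
# atol=1e-6 from the assert, plus the default relative term.
effective_tol = 1e-6 + 1e-5 * abs(c_health)

roundoff_ok = allclose_scalar(c_health + 2.98e-08, c_health, atol=1e-6)
real_gap_ok = allclose_scalar(c_health + 1e-3, c_health, atol=1e-6)

print(f"effective tolerance = {effective_tol:.2e}")
print(f"round-off residual passes: {roundoff_ok}")
print(f"1e-3 imbalance passes:     {real_gap_ok}")
```

The effective tolerance is about 4.65e-06: loose enough to ignore round-off, tight enough to reject a 1e-3 imbalance.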
Line 52: print("PASS: GABA equalises effective contributions on real autograd gradients.")

Confirmation banner. Reaches stdout only if the assert above passed.

EXECUTION STATE
Final stdout (all 9 prints) =
g_rul         = 76.5335
g_health      = 0.3665
ratio         = 208.8x
lambda_rul    = 0.004766
lambda_health = 0.995234
c_rul         = 0.364791
c_health      = 0.364791
|c_rul - c_health| = 2.98e-08
PASS: GABA equalises effective contributions on real autograd gradients.
```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny shared backbone + two task heads (mimics the §11 architecture).
backbone = nn.Linear(14, 32)
rul_head = nn.Linear(32, 1)
hp_head  = nn.Linear(32, 3)

# One realistic batch.
x          = torch.randn(64, 14)
rul_target = torch.rand(64, 1) * 125.0
hp_target  = torch.randint(0, 3, (64,))

# Forward through shared backbone, then per-task heads.
feat        = backbone(x)
rul_loss    = ((rul_head(feat) - rul_target) ** 2).mean()
health_loss = nn.functional.cross_entropy(hp_head(feat), hp_target)


def shared_grad_norm(loss: torch.Tensor, params: list) -> torch.Tensor:
    """L2 norm of loss's gradient on the given parameters."""
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    sq    = sum((g.detach().norm() ** 2 for g in grads if g is not None))
    return sq.sqrt()


shared   = list(backbone.parameters())
g_rul    = shared_grad_norm(rul_loss,    shared)
g_health = shared_grad_norm(health_loss, shared)

# K=2 inverse-proportional weights.
S          = g_rul + g_health
lam_rul    = g_health / S
lam_health = g_rul    / S

# Effective contributions: lambda_i * ||g_i||
c_rul    = lam_rul    * g_rul
c_health = lam_health * g_health

print(f"g_rul         = {g_rul.item():.4f}")
print(f"g_health      = {g_health.item():.4f}")
print(f"ratio         = {(g_rul / g_health).item():.1f}x")
print(f"lambda_rul    = {lam_rul.item():.6f}")
print(f"lambda_health = {lam_health.item():.6f}")
print(f"c_rul         = {c_rul.item():.6f}")
print(f"c_health      = {c_health.item():.6f}")
print(f"|c_rul - c_health| = {(c_rul - c_health).abs().item():.2e}")

assert torch.allclose(c_rul, c_health, atol=1e-6), "GABA must equalise contributions"
print("PASS: GABA equalises effective contributions on real autograd gradients.")
```
The algebraic identity. For K=2: $c_{\text{rul}} = \lambda_{\text{rul}} \, g_{\text{rul}} = \frac{g_{\text{health}}}{g_{\text{rul}} + g_{\text{health}}} \cdot g_{\text{rul}} = \frac{g_{\text{rul}} \, g_{\text{health}}}{g_{\text{rul}} + g_{\text{health}}}$. By symmetry $c_{\text{health}}$ equals the same expression. They are identically equal: not approximately, not on average, but on every single mini-batch.

The Same Principle In Other Domains

“Equalise effective effect, not nominal weight” recurs whenever a single mechanism aggregates heterogeneous contributors:

| Domain | Mechanism | What ‘effective contribution’ is | Inverse-rule analogue |
| --- | --- | --- | --- |
| Pharmacology (the hook) | Combination drug therapy | dose × potency | Inverse-potency dosing |
| Portfolio risk (Markowitz) | Multi-asset allocation | weight × asset volatility | Risk parity (1/σ_i weights) |
| Federated learning (FedAvg) | Average client gradients | client weight × ‖local update‖ | Inverse-update-norm aggregation |
| Reinforcement learning (DQN with multi-reward) | Sum reward components | weight × reward gradient | Inverse-gradient reward shaping |
| Sound mixing (mastering) | Sum tracks at the bus | fader × stem loudness | LUFS-normalising fader |
| Climate-model ensembles (CMIP6) | Weight per model in multi-model mean | weight × model variance | Inverse-variance weighting |
| Object detection (Detectron2) | Sum bbox-regression + classification + objectness | weight × loss-component grad | GABA / Inv-grad weighting |

In each row, the ‘loud contributor’ is the one with the larger natural magnitude. Naive equal weighting lets it monopolise the mechanism; inverse-proportional weighting equalises effect. The mathematics is identical to K=2 GABA — just renamed for the domain.
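As one concrete reading of the table, the sketch below applies the inverse rule to the portfolio row. The two volatilities are hypothetical, the function name is illustrative, and this is the simple inverse-volatility form of risk parity, which ignores correlations between assets.

```python
def inverse_proportional_weights(magnitudes):
    """Weights proportional to 1/m_i, normalised to sum to 1."""
    if any(m <= 0 for m in magnitudes):
        raise ValueError("magnitudes must be strictly positive")
    inverses = [1.0 / m for m in magnitudes]
    total = sum(inverses)
    return [v / total for v in inverses]


# Hypothetical asset volatilities: one loud asset, one quiet asset.
sigmas = [0.20, 0.05]
weights = inverse_proportional_weights(sigmas)
contributions = [w * s for w, s in zip(weights, sigmas)]

# The K=2 GABA closed form (other contributor's magnitude over the sum).
gaba_style = [sigmas[1] / sum(sigmas), sigmas[0] / sum(sigmas)]

print([round(w, 4) for w in weights])
print([round(c, 4) for c in contributions])
print([round(w, 4) for w in gaba_style])
```

The loud asset receives weight 0.2 and the quiet one 0.8, both contribute 0.04, and the weights coincide with the K=2 closed form used for the two losses.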

Three Misconceptions About ‘Why It Works’

Misconception 1: ‘GABA scales DOWN the loud task.’ No: it scales down the loud task's weight, but the gradient itself is unchanged. The optimiser still receives the full direction of $\nabla_\theta \mathcal{L}_{\text{rul}}$; only its magnitude in the update is reduced. Think attenuation, not censorship.
Misconception 2: ‘Any inverse-shaped rule works.’ Square-root inverse, log-inverse, or rank-based inverse all push in the right direction but leave a residual imbalance. With $\lambda_i \propto 1/\sqrt{g_i}$, for example, $c_i \propto \sqrt{g_i}$, so the 500× gap of the realistic case shrinks only to $\sqrt{500} \approx 22\times$. Only linear inverse exactly equalises $c_i$ for K=2.
Misconception 3: ‘The K-task formula equalises $c_i$ for K>2.’ Only for K=2. The K-task generalisation in §17.1 equalises a different quantity (the sum-of-others). For K>2, see §17.3 for the closed form and the precise invariant that GABA preserves.
Why this matters for the paper's safety claim. Equalising $c_i$ is exactly what allows the health classifier's gradient to actually update the backbone. Under uniform weighting it is buried 500× below the regression gradient and the backbone effectively sees only the regressor, which is precisely the regime that produces the high-RMSE / high-NASA ‘accuracy-only’ failure mode in §13. GABA's −55% NASA improvement on FD002 is a direct consequence of this single algebraic identity.
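To put numbers on Misconception 2, the sketch below measures the residual ratio c_loud / c_quiet for a 500× gradient gap under three weighting rules. Normalisation cancels in the ratio, so unnormalised weights suffice. The log and rank variants are left out because their exact form is not given here, and the function and rule names are illustrative.

```python
import math


def residual_imbalance(ratio: float, weight_of) -> float:
    """c_loud / c_quiet when the quiet norm is 1 and the loud norm is `ratio`."""
    g_loud, g_quiet = ratio, 1.0
    return (weight_of(g_loud) * g_loud) / (weight_of(g_quiet) * g_quiet)


rules = {
    "uniform": lambda g: 1.0,
    "sqrt inverse": lambda g: 1.0 / math.sqrt(g),
    "linear inverse": lambda g: 1.0 / g,
}

residuals = {name: residual_imbalance(500.0, rule) for name, rule in rules.items()}
for name, r in residuals.items():
    print(f"{name:15s} residual imbalance = {r:.1f}x")
```

Uniform weights leave the full 500×, the square-root rule leaves about 22×, and only the linear inverse reaches 1×.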

Takeaway

  • Effective contribution is the right quantity to balance. $c_i = \lambda_i \, \|g_i\|$ is what the optimiser sees, not $\lambda_i$ alone.
  • Inverse-proportional is unique for K=2. Among all weightings on the simplex, only $\lambda_i \propto 1 / g_i$ satisfies $c_i = c_j$ for every batch and every backbone.
  • The identity is algebraic. $c_{\text{rul}} = c_{\text{health}} = g_{\text{rul}} \, g_{\text{health}} / (g_{\text{rul}} + g_{\text{health}})$. The PyTorch assertion passes to single-precision epsilon, on any seed.
  • Sqrt / log / rank-based inverses do NOT balance. They reduce imbalance but leave 22× or more residual. Linear inverse is the only exact rule.
  • The principle generalises. Risk parity, inverse-variance ensemble weighting, LUFS-normalised mixing, and federated-averaging all instantiate the same equal-effect identity under different names.
  • This is why GABA fixes the safety failure. The 500× gradient gap is exactly what was burying the health signal; equalising $c_i$ unburies it, and the −55% NASA improvement on FD002 is the empirical witness.