Chapter 19

GABA as a Proportional Feedback Controller

Control-Theoretic Interpretation

Hook: Cruise Control On A Hilly Road

A car cruising at 100 km/h on a hilly highway does not actually run at 100 km/h. The road tilts up, the car slows, the speedometer drops to 96. The cruise-control unit reads the discrepancy, $e = 100 - 96 = 4\ \text{km/h}$, and pushes the throttle harder by an amount proportional to that error. Over the next second, the car accelerates back toward 100. The loop runs ten times per second and the driver feels nothing.

That little box is a proportional feedback controller. It has four ingredients: a sensor (speedometer), a setpoint (100 km/h), a comparator (subtraction), and an actuator (throttle). The control action is proportional to the error. If you slept through control engineering, this is the only sentence you need.

GABA is doing exactly the same thing, not in km/h but in gradient space. The plant is the joint multi-task optimiser. The sensor is $g_i = \| \nabla_\theta \mathcal{L}_i \|_2$ (per-task gradient norm on the shared backbone). The setpoint is equal contribution from all tasks, $r = 1/K$. The comparator subtracts the measured share from the setpoint. The control action, the task weight $\lambda_i^{*}$, is emitted by an inverse-proportional law. Paper main.tex:387 names this explicitly: “GABA can be viewed as a proportional feedback controller operating in gradient space.” This section unpacks what that sentence means line by line.

Why this framing matters. Once you see GABA as a P-controller, three properties of the algorithm stop being coincidences and become consequences: (1) the weights always sum to 1, (2) the rule has no learnable parameters and just three constants, and (3) the bounded-weight guarantee proven in §19.3 is the anti-windup property of any well-designed feedback loop. Open-loop alternatives (notably GradNorm) do not get these for free — they diverged on 1/5 N-CMAPSS seeds (paper main.tex:553). The closed loop is what makes GABA robust.

Five Words Every Control Engineer Knows

Before we map cruise control onto GABA, fix the vocabulary. Any single-input single-output (SISO) feedback loop has these five pieces:

| Term | Symbol | Cruise control | What it does |
|---|---|---|---|
| plant | $P$ | the car (mass + engine + drag) | the system whose behaviour we want to shape |
| sensor | $y$ | speedometer | measures the plant’s output |
| reference | $r$ | 100 km/h setpoint | the value we want $y$ to equal |
| error | $e = r - y$ | 100 - 96 = 4 km/h | deviation of measurement from reference |
| controller / actuator | $u = C(e)$ | throttle valve | command produced from the error and pushed back into the plant |

The diagram is always the same: r → (+) → controller → plant → sensor → back to (-) of the comparator. The minus sign is what makes it ‘negative feedback’: the loop subtracts the measurement from the reference. Without that minus sign you have positive feedback, the microphone-howl regime, which we explicitly do not want.

A proportional controller is the simplest non-trivial choice for $C(e)$: $u = K_p \cdot e$. The control command is directly proportional to the error, with gain $K_p$. Larger error ⇒ larger correction. Zero error ⇒ zero correction. Linear, memoryless, dead simple.
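
The proportional law is short enough to simulate directly. The sketch below assumes a toy plant that is not in the text (acceleration equals throttle command minus a constant hill drag) and illustrative values for the gain and the drag; only the setpoint, the starting speed of 96 km/h, and the ten-updates-per-second loop rate come from the cruise-control story.

```python
def simulate_cruise(setpoint=100.0, v0=96.0, k_p=2.0, hill_drag=1.0,
                    dt=0.1, steps=200):
    """Simulate a proportional cruise controller on a toy car model.

    Toy plant (an assumption): dv/dt = u - hill_drag, where u is the
    throttle command and hill_drag is a constant deceleration from the slope.
    """
    v = v0
    history = [v]
    for _ in range(steps):
        error = setpoint - v          # comparator: e = r - y
        u = k_p * error               # proportional law: u = K_p * e
        v = v + dt * (u - hill_drag)  # plant update (forward Euler, 10 Hz)
        history.append(v)
    return history


speeds = simulate_cruise()
print(f"start : {speeds[0]:.2f} km/h")
print(f"final : {speeds[-1]:.2f} km/h")
```

With these assumed numbers the speed climbs from 96 km/h and settles at 99.5 km/h: under a constant load, a pure P-law stops where gain times residual error cancels the drag (here 2.0 × 0.5 = 1.0).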

Mapping Those Five Words To GABA

Now we replace cruise-control nouns with GABA nouns. Same five positions in the loop, different physical interpretation.

| Control term | GABA realisation | Where it lives in code |
|---|---|---|
| plant | the joint multi-task optimisation step | the entire training loop except GABALoss |
| sensor | $g_i = \lVert \nabla_\theta \mathcal{L}_i \rVert_2$: per-task L2 gradient norm on shared backbone | compute_task_grad_norm in grace/core/gradient_utils.py |
| reference | $r = 1/K$: equal-share goal across K tasks (paper main.tex:387) | implicit; baked into the closed-form rule |
| error | $e_i = r - g_i / \sum_j g_j$: deviation of measured share from equal share | implicit; absorbed into the closed-form rule |
| controller / actuator | $\lambda_i^{*} = (\sum_j g_j - g_i) / ((K-1) \sum_j g_j)$ (paper eq. 4) | GABALoss.forward_k inner block |

The actuator output $\lambda_i^{*}$, the task weight, is what gets multiplied against the per-task losses in the next forward step: $\mathcal{L} = \sum_i \lambda_i^{*} \mathcal{L}_i$. The gradient of that combined loss flows back into the plant (the optimiser updates $\theta$), the new $\theta$ changes the next batch's $g_i$, the sensor reads the new shares, and the loop closes.

The minus sign hides in plain sight. You won't see $r - y$ as a subtraction in the GABA closed form. It's absorbed: the numerator $\sum_j g_j - g_i$ is exactly the inverse-share equivalent of $r - y_i$. We'll prove this algebraically in the next section.
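
To see the loop close numerically, here is a minimal sketch on a toy plant. The two quadratic task losses, their scales (250 and 0.5, chosen to mimic a 500× imbalance), the learning rate, and the helper name `gaba_weights` are all assumptions for illustration; only the weight rule itself is the paper's eq. 4.

```python
import numpy as np


def gaba_weights(grad_norms, eps=1e-12):
    """Closed-form rule: lambda_i = (sum(g) - g_i) / ((K - 1) * sum(g))."""
    g = np.asarray(grad_norms, dtype=np.float64)
    total = g.sum() + eps
    return (total - g) / ((len(g) - 1) * total)


# Toy plant (an assumption): two quadratic task losses on a shared theta,
#   L_i(theta) = 0.5 * scale_i * ||theta - target_i||^2
scales = np.array([250.0, 0.5])
targets = np.array([[1.0, 0.0], [0.0, 1.0]])
theta = np.zeros(2)
lr = 1e-3

for step in range(5):
    grads = scales[:, None] * (theta - targets)   # per-task gradients, shape (2, 2)
    g = np.linalg.norm(grads, axis=1)             # sensor: per-task L2 norms
    lam = gaba_weights(g)                         # controller: task weights
    update = (lam[:, None] * grads).sum(axis=0)   # gradient of sum_i lam_i * L_i
    theta = theta - lr * update                   # plant: one gradient-descent step
    print(step, np.round(g, 4), np.round(lam, 4))
```

Each pass through the loop body is one trip around the block diagram: measure, weight, update, and measure again on the next iteration.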

The Math: GABA Is A P-Controller With Unity Gain

Set $K = 2$ for clarity (the paper's case; the K-task version is below). The measured share of task $i$ is

$$s_i = \frac{g_i}{g_1 + g_2}, \qquad s_1 + s_2 = 1.$$

The reference is $r = 1/2$. The error is

$$e_i = r - s_i = \frac{1}{2} - \frac{g_i}{g_1 + g_2} = \frac{(g_1 + g_2) - 2 g_i}{2(g_1 + g_2)} = \frac{g_j - g_i}{2(g_1 + g_2)}$$

where $j$ is ‘the other task’. Now the proportional law with unity gain emits

$$\lambda_i^{*} = r + K_p \cdot e_i = \frac{1}{2} + 1 \cdot \frac{g_j - g_i}{2(g_1 + g_2)} = \frac{(g_1 + g_2) + (g_j - g_i)}{2(g_1 + g_2)} = \frac{2 g_j}{2(g_1 + g_2)} = \frac{g_j}{g_1 + g_2}.$$

That is exactly the paper's closed form for $K = 2$: $\lambda_i^{*} = (\Sigma g - g_i)/\Sigma g = g_j / \Sigma g$. So GABA's closed form is mathematically identical to $r + K_p \cdot e_i$ with $K_p = 1$. Q.E.D.

For general $K$ tasks the algebra is the same with a slightly bigger normaliser:

$$\lambda_i^{*} = \frac{\sum_j g_j - g_i}{(K - 1) \sum_j g_j} = \frac{1}{K} + \frac{1}{K - 1} \cdot e_i, \qquad e_i = \frac{1}{K} - \frac{g_i}{\sum_j g_j}.$$

The proportional gain becomes $K_p = 1/(K - 1)$: still constant, still proportional, still no learned parameters. For the paper's $K = 2$ we recover $K_p = 1$ exactly.
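
The general-K identity can be spot-checked numerically. This is a minimal sketch: the function names `closed_form` and `p_form` and the random test values are illustrative choices, not part of the paper's code.

```python
import numpy as np


def closed_form(g):
    """lambda_i = (sum(g) - g_i) / ((K - 1) * sum(g))."""
    g = np.asarray(g, dtype=np.float64)
    total = g.sum()
    return (total - g) / ((len(g) - 1) * total)


def p_form(g):
    """lambda_i = r + K_p * e_i with r = 1/K and K_p = 1/(K - 1)."""
    g = np.asarray(g, dtype=np.float64)
    K = len(g)
    r = 1.0 / K
    e = r - g / g.sum()
    return r + e / (K - 1)


rng = np.random.default_rng(0)
for K in (2, 3, 5, 8):
    g = rng.uniform(0.1, 500.0, size=K)
    assert np.allclose(closed_form(g), p_form(g))
    assert np.isclose(closed_form(g).sum(), 1.0)
print("closed form matches r + e / (K - 1) for K = 2, 3, 5, 8")
```

A hand-checkable case: for g = (1, 2, 3) both forms give (5/12, 4/12, 3/12).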

The unity-gain property is the reason GABA needs no tuning. A general P-controller demands tuning of $K_p$: too small ⇒ sluggish; too large ⇒ oscillation. GABA's closed form pins $K_p$ at exactly the value for which the controller, in one step, would zero the error if applied directly to the share, modulo the EMA filter (sec. 19.2), which deliberately slows the response so that stochastic gradient noise doesn't push the loop into oscillation. There's no tuning on this gain because the closed form is already the algebraically optimal value.
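
One way to read the one-step claim, offered here as an interpretation rather than the paper's own definition: scale each task's gradient norm by its weight and recompute the shares. For K = 2 the weighted contributions $\lambda_i^{*} g_i$ are both equal to $g_1 g_2 / (g_1 + g_2)$, so the weighted share is exactly 1/2. The helper names and the K = 3 test values below are assumptions for illustration.

```python
import numpy as np


def gaba_weights(g):
    g = np.asarray(g, dtype=np.float64)
    total = g.sum()
    return (total - g) / ((len(g) - 1) * total)


def weighted_shares(g):
    """Share of each task after its gradient norm is scaled by lambda_i."""
    g = np.asarray(g, dtype=np.float64)
    contrib = gaba_weights(g) * g
    return contrib / contrib.sum()


print("K = 2:", weighted_shares([250.0, 0.5]))
print("K = 3:", weighted_shares([250.0, 5.0, 0.5]))
```

Under this reading the error is zeroed in one step for K = 2. For the K = 3 example the weighted shares move toward each other but do not land exactly on 1/3, so the exact one-step property is specific to the two-task case.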

Three numerical sanity checks fall out of the math:

  • Sum to 1. $\sum_i \lambda_i^{*} = \sum_i (\Sigma g - g_i)/((K-1)\Sigma g) = (K \Sigma g - \Sigma g)/((K-1)\Sigma g) = 1$. The controller output is automatically on the simplex. No renormalisation needed.
  • Symmetry. Swap $g_1 \leftrightarrow g_2$ and the weights swap accordingly. The rule is permutation-equivariant in the tasks.
  • Equal-grad fixed point. If $g_1 = g_2$, the error is zero and the controller emits $\lambda^{*} = (1/2, 1/2)$: equal weights, no correction. Equal-gradient is a fixed point of the closed loop.
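
The three checks translate directly into assertions. A minimal sketch, assuming a helper named `gaba_weights` that implements the closed form without the numerical guard; the test vectors are arbitrary.

```python
import numpy as np


def gaba_weights(g):
    g = np.asarray(g, dtype=np.float64)
    total = g.sum()
    return (total - g) / ((len(g) - 1) * total)


g2 = np.array([250.0, 0.5])
g3 = np.array([250.0, 5.0, 0.5])

# 1. Sum to 1: the output lies on the simplex for any K >= 2.
assert np.isclose(gaba_weights(g2).sum(), 1.0)
assert np.isclose(gaba_weights(g3).sum(), 1.0)

# 2. Symmetry: permuting the inputs permutes the weights the same way.
perm = np.array([2, 0, 1])
assert np.allclose(gaba_weights(g3[perm]), gaba_weights(g3)[perm])

# 3. Equal-grad fixed point: equal norms give equal weights 1/K.
assert np.allclose(gaba_weights([3.0, 3.0]), [0.5, 0.5])
assert np.allclose(gaba_weights([3.0, 3.0, 3.0]), [1 / 3, 1 / 3, 1 / 3])

print("all three checks pass")
```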

Interactive: The Closed Loop, Live

Move the two gradient-norm sliders below. Watch the controller emit $\lambda^{*}$ in real time. The block diagram shows the closed loop (reference, comparator, P-law, plant, sensor) with the live values flowing along the wires. The preset buttons jump to specific operating points: Equal, 10× imbalance, 500× imbalance (the paper's measured median), and 2,400× peak (the paper's observed peak ratio early in training).

[Interactive control-loop diagram: two gradient-norm sliders, live block diagram, preset buttons]

Three things to verify by playing with the sliders. First, when you nudge $g_{\text{rul}}$ up, $\lambda_{\text{rul}}$ goes down: this is the inverse action of the controller. Second, the two weights always sum to 1 (read the bottom-right readout). Third, when both sliders sit at the same value, the controller emits exactly $(0.5, 0.5)$, the equal-gradient fixed point.

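Connect to the empirical record. The paper's Fig. gradient_dynamics (main.tex:564-569) reports that under equal weighting the regression-to-classification ratio evolves from ~50× initially, peaks at ~2,400× around epoch 4, then stabilises at ~500–1,000×. Try those three points on the slider. At each, the raw closed form puts nearly all of the weight on health, the smaller-gradient task: 50/51 ≈ 98.0% at 50×, 500/501 ≈ 99.8% at 500×, and 2400/2401 ≈ 99.96% at 2,400×. The ~95% weight on health that §V-G of the paper documents is lower, presumably because the full algorithm adds the EMA filter and floor of §19.2 and §19.3 on top of this rule. Either way, the allocation counterbalances the dominant regression gradient and equalises effective contributions on the shared backbone.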
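
For K = 2 the weight on the smaller-gradient task depends only on the ratio, so the three operating points can be checked by hand. The function name `health_weight` is illustrative. These are values of the unsmoothed closed form; figures reported for the full algorithm, which also applies the EMA and floor of §19.2 and §19.3, can differ.

```python
def health_weight(ratio):
    """lambda_health for K = 2 when g_rul / g_health = ratio.

    lambda_health = g_rul / (g_rul + g_health) = ratio / (ratio + 1)
    """
    return ratio / (ratio + 1.0)


for ratio in (50, 500, 2400):
    print(f"{ratio:>5}x  lambda_health = {health_weight(ratio):.4f}")
# prints:
#    50x  lambda_health = 0.9804
#   500x  lambda_health = 0.9980
#  2400x  lambda_health = 0.9996
```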

Python: P-Controller From Scratch (NumPy)

We'll first build the controller as a textbook P-controller with explicit sense/error/control methods so every block of the loop has a name. This is pedagogical: the production code (next subsection) collapses these into a single closed-form expression, but the algebraic equivalence is exactly what we proved above.

GABA's Inner Loop As An Explicit P-Controller (NumPy)
🐍gaba_p_controller.py
1Module docstring

Sets the contract for this file: GABA's closed-form weight rule is the P-controller part of a feedback loop. Sections 19.2 (EMA filter) and 19.3 (floor / anti-windup) cover the stabilisers wrapped around it.

13import numpy as np

NumPy provides ndarray, np.asarray, .sum(), and broadcasting. We use float64 throughout for headroom on the (sum(g) - g_i)/((K-1) sum(g)) division when sum(g) is very small.

EXECUTION STATE
📚 numpy = Numerical computing library. Used for ndarray, np.asarray, np.sum.
16class GabaPController:

Plain Python class. Holds the controller hyperparameters (number of tasks K, gain K_p) and the reference signal r = 1/K. Has three methods: sense, error, control — exactly the three blocks of the closed loop.

17docstring

One-line summary of what the class is.

19def __init__(self, n_tasks=2, K_p=1.0):

Constructor. Two hyperparameters with paper defaults baked in.

EXECUTION STATE
⬇ n_tasks = 2 = Number of tasks in the multi-task loop. K = 2 for RUL + health. Sets the reference 1/K.
⬇ K_p = 1.0 = Proportional gain. The paper uses unity gain — the controller emits the FULL inverse-share correction at every step. Larger K_p would over-correct; smaller K_p would under-correct. Paper main.tex:387 fixes this implicitly via the closed-form rule.
20self.K = n_tasks

Cache K as an instance attribute. Used inside .control() to compute the (K-1) normaliser.

EXECUTION STATE
self.K = 2
21self.K_p = K_p

Cache the proportional gain. The paper's closed-form rule corresponds to K_p = 1 (proof in the math section above).

EXECUTION STATE
self.K_p = 1.0 by default. Paper main.tex:338-355 derives the closed form with implicit K_p = 1.
22self.reference = 1.0 / n_tasks

The reference signal r — the equal-contribution target. For K = 2: r = 0.5. The controller drives the measured shares s_i toward this value.

EXECUTION STATE
self.reference = 0.5 for K = 2. The setpoint we want each task's gradient share to approach.
24def sense(self, grad_norms):

The SENSOR block of the loop. Takes raw per-task gradient norms (the plant output) and returns the measured share s_i = g_i / sum(g). The share is dimensionless and lives on the simplex (sums to 1).

EXECUTION STATE
⬇ input: grad_norms = List of K floats. Per-task L2 norms of grad(L_i) on shared backbone parameters. Computed by torch.autograd.grad in real training; provided directly here for clarity.
⬆ returns = ndarray (K,) on the simplex. share[i] is the fraction of total gradient mass attributable to task i.
25docstring

Documents the sensor operation.

26g = np.asarray(grad_norms, dtype=np.float64)

Convert the list to an ndarray. np.asarray avoids a copy if grad_norms is already an ndarray.

EXECUTION STATE
📚 np.asarray(a, dtype) = Build an ndarray from a (no-copy if dtype matches). Compare to np.array which always copies.
g = ndarray (K,). For our demo: [250.0, 0.5].
27total = g.sum() + 1e-12

Sum of gradient norms plus a numerical guard against the all-zero case (rare but possible at stationary points).

EXECUTION STATE
📚 .sum() = ndarray reduction → scalar sum.
1e-12 = Numerical guard. Without it, the division below would produce NaN when g is exactly zero.
total = 250.5 for our demo (250.0 + 0.5).
28return g / total

Element-wise division. NumPy broadcasts the scalar `total` against the (K,) vector `g`.

EXECUTION STATE
⬆ return: g / total = [250.0/250.5, 0.5/250.5] = [0.998004, 0.001996]. Note: s_rul ≈ 99.8% — the regression task dominates entirely.
30def error(self, share):

The COMPARATOR block. Computes the deviation of each measured share from the reference. Positive error → that task is UNDER-represented (needs to be boosted). Negative error → OVER-represented.

EXECUTION STATE
⬇ input: share = ndarray (K,) on the simplex from .sense(). For our demo: [0.998, 0.002].
⬆ returns = ndarray (K,). e_i = r - s_i. Sums to 0 by construction (since shares sum to 1 and references sum to 1).
31docstring

Documents the error semantics.

32return self.reference - share

Element-wise scalar minus vector. NumPy broadcasts the scalar reference against (K,).

EXECUTION STATE
self.reference = 0.5 (scalar)
share = [0.998, 0.002] for our demo
⬆ return = [-0.498, +0.498]. RUL is OVER-represented (e<0). Health is UNDER-represented (e>0). The errors are equal in magnitude and opposite in sign — a property of K=2 that breaks down for K>2.
34def control(self, grad_norms):

The CONTROLLER block. Takes the raw gradient norms and emits the task weights λ_i. Implements paper equation 4 exactly. For K = 2 this is mathematically equivalent to the P-controller form r + K_p · e (proof in the section above).

EXECUTION STATE
⬇ input: grad_norms = List of K floats. Same input as .sense() — the controller could be expressed as control(g) = r + K_p·error(sense(g)) but the closed form is more efficient.
⬆ returns = ndarray (K,) on the simplex. The TASK WEIGHTS used to combine the per-task losses. Paper symbol: λ_i*.
35docstring

Documents the control action.

36g = np.asarray(grad_norms, dtype=np.float64)

Same conversion as in .sense().

EXECUTION STATE
g = ndarray (K,) = [250.0, 0.5]
37total = g.sum() + 1e-12

Same total as in .sense(). Caching this on self would be a minor optimisation; the paper code recomputes it for clarity.

EXECUTION STATE
total = 250.5
38# Closed form (paper eq. 4) ≡ r + K_p * error for K=2.

Inline comment marking the equivalence between the closed form and the textbook P-controller form for K = 2.

39return (total - g) / ((self.K - 1) * total)

Paper equation 4 verbatim. (total - g) inverts the share: the smaller g_i is, the LARGER the numerator. Dividing by (K - 1) * total normalises so the result sums to 1.

EXECUTION STATE
(total - g) = [250.5 - 250.0, 250.5 - 0.5] = [0.5, 250.0]. Notice the SWAP — task with smaller g gets the larger numerator.
(self.K - 1) * total = (2 - 1) * 250.5 = 250.5. For K=2 this equals total; the (K-1) factor matters for K ≥ 3.
⬆ return = [0.5/250.5, 250.0/250.5] = [0.001996, 0.998004]. Mirror image of the share — controller assigns 99.8% weight to the under-represented task. Sums to 1 by construction.
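
A note on the 1e-12 guard, sketched as a standalone function (the name `control` and its signature mirror the method above but are re-declared here as an assumption so the snippet runs on its own). The guard removes the NaN, but because it is added to the total and not to the numerator's g, the weights sum to 1 + eps / ((K - 1) * total) instead of exactly 1. That excess is negligible for realistic gradient norms and becomes visible only when every norm is exactly zero.

```python
import numpy as np


def control(grad_norms, K=2, eps=1e-12):
    g = np.asarray(grad_norms, dtype=np.float64)
    total = g.sum() + eps
    return (total - g) / ((K - 1) * total)


typical = control([250.0, 0.5])
all_zero = control([0.0, 0.0])

print("typical sum :", typical.sum())   # 1 up to floating-point rounding
print("all-zero    :", all_zero)        # each weight equals 1 / (K - 1)
```

In the all-zero case each weight is 1/(K - 1), so for K = 2 the output is (1, 1) and sums to 2. Since every gradient is zero there, the combined update is zero regardless, but a caller that logs or asserts on the simplex property may want to special-case it; how the paper's code treats this case is not stated in this section.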
42Demo header

Marker comment for the runnable demo at the paper's measured 500× imbalance.

43ctrl = GabaPController(n_tasks=2, K_p=1.0)

Instantiate the controller with paper defaults. K = 2 (RUL + health), unity proportional gain.

EXECUTION STATE
ctrl.K = 2
ctrl.K_p = 1.0
ctrl.reference = 0.5
45g = [250.0, 0.5]

Paper-realistic gradient norms. Reflects the median 500× imbalance (paper main.tex:64) — RUL gradient ≈ 250, health gradient ≈ 0.5.

EXECUTION STATE
g_rul = 250.0 — the regression task dominates (paper main.tex:62-66).
g_health = 0.5 — the classification task gradient is ~500× smaller.
ratio = g_rul / g_health = 500.0
46share = ctrl.sense(g)

Run the SENSOR block. Returns the share each task contributes to total gradient mass.

EXECUTION STATE
share = [0.998004, 0.001996]. Confirmation that without intervention, RUL absorbs 99.8% of the gradient signal.
47err = ctrl.error(share)

Run the COMPARATOR. Returns r - s for each task.

EXECUTION STATE
err = [-0.498004, 0.498004]. RUL is over-represented by 49.8 percentage points; health is under-represented by 49.8 pp.
48lam = ctrl.control(g)

Run the CONTROLLER. Emits the task weights — the actual numbers GABA sends downstream to combine the losses.

EXECUTION STATE
lam = [0.001996, 0.998004]. The controller assigns 99.8% weight to the WEAKER-gradient task (health). This INVERSE allocation equalises gradient contributions in the next forward-backward pass.
50print(f"reference = {ctrl.reference:.4f}")

Display the reference signal.

EXECUTION STATE
Output = reference = 0.5000
51print(f"shares = {share}")

Display the measured share — what the loop sees right now.

EXECUTION STATE
Output = shares = [0.99800399 0.00199601]
52print(f"errors = {err}")

Display the error vector — how far the system is from the reference.

EXECUTION STATE
Output = errors = [-0.49800399  0.49800399]
53print(f"weights = {lam}")

Display the control action — the task weights that go downstream to combine losses.

EXECUTION STATE
Output = weights = [0.00199601 0.99800399]
54print(f"sums to 1 : {lam.sum():.6f}")

Sanity check: the controller output is on the simplex (sums to 1) by construction.

EXECUTION STATE
Output = sums to 1 : 1.000000
56Sanity-check header

Marker for the equivalence check below.

57# For K=2: lambda_i = r + K_p * (r - s_i) with K_p = 1

Inline comment showing the textbook P-controller form for K = 2.

58manual = ctrl.reference + ctrl.K_p * err

Compute the P-controller output directly: λ = r + K_p · e. Should match the closed-form output for K = 2.

EXECUTION STATE
ctrl.reference = 0.5 (scalar broadcast)
ctrl.K_p * err = 1.0 * [-0.498, 0.498] = [-0.498, 0.498]
manual = [0.5 - 0.498, 0.5 + 0.498] = [0.001996, 0.998004]. IDENTICAL to lam above.
59print(f"manual P = {manual} (matches lambda above)")

Verify the equivalence numerically.

EXECUTION STATE
Final output =
reference  = 0.5000
shares     = [0.99800399 0.00199601]
errors     = [-0.49800399  0.49800399]
weights    = [0.00199601 0.99800399]
sums to 1  : 1.000000
manual P   = [0.00199601 0.99800399]    (matches lambda above)
"""GABA's INNER LOOP as a textbook proportional feedback controller.

Pure NumPy. No EMA, no floor (those are sections 19.2 and 19.3).
This file shows that the closed-form weight rule of GABA

    lambda_i = (sum(g) - g_i) / ((K - 1) * sum(g))

is a P-controller of unity gain in the *measured share*

    s_i = g_i / sum(g)

that drives s_i toward the reference r = 1/K.
"""

import numpy as np


class GabaPController:
    """Proportional feedback controller in gradient space."""

    def __init__(self, n_tasks=2, K_p=1.0):
        self.K = n_tasks
        # Paper uses unity gain for K=2. control() uses the closed form and
        # does not read K_p; it is kept for the manual check at the bottom.
        self.K_p = K_p
        self.reference = 1.0 / n_tasks

    def sense(self, grad_norms):
        """Sensor: convert raw gradient norms to per-task share."""
        g = np.asarray(grad_norms, dtype=np.float64)
        total = g.sum() + 1e-12
        return g / total

    def error(self, share):
        """Error signal: how far each task is from the equal-share reference."""
        return self.reference - share

    def control(self, grad_norms):
        """P-action: emit task weights from the closed-form inverse rule."""
        g = np.asarray(grad_norms, dtype=np.float64)
        total = g.sum() + 1e-12
        # Closed form (paper eq. 4) ≡ r + K_p * error  for K=2.
        return (total - g) / ((self.K - 1) * total)


# ---- Single-step demo at the paper's measured 500x imbalance ----
ctrl = GabaPController(n_tasks=2, K_p=1.0)

g = [250.0, 0.5]                # paper-realistic g_rul, g_health
share = ctrl.sense(g)           # measured share
err   = ctrl.error(share)       # error from equal share
lam   = ctrl.control(g)         # control action

print(f"reference  = {ctrl.reference:.4f}")
print(f"shares     = {share}")
print(f"errors     = {err}")
print(f"weights    = {lam}")
print(f"sums to 1  : {lam.sum():.6f}")

# ---- Sanity check: P-controller equivalent form (K=2 only) ----
# For K=2:  lambda_i = r + K_p * (r - s_i)   with K_p = 1
manual = ctrl.reference + ctrl.K_p * err
print(f"manual P   = {manual}    (matches lambda above)")

The output confirms the math: at the paper's 500× imbalance, the controller emits $(\lambda_{\text{rul}}, \lambda_{\text{health}}) \approx (0.002, 0.998)$, the inverse of the share, and the explicit P-form $\lambda = r + K_p \cdot e$ reproduces the same numbers exactly.

PyTorch: The Inner Loop Of GABALoss Verbatim

The paper's production code is in grace/core/gaba.py. To isolate the proportional inner loop (with no EMA, no floor, no warmup) we extract just the three lines of forward_k that compute the closed form. The result is bit-exact to the un-smoothed branch of GABALoss when $\beta = 1$ and $\lambda_{\min} = 0$.

The P-Controller Inside Paper GABALoss (PyTorch)
🐍gaba_proportional_inner_loop.py
1Module docstring

Sets the contract: this is the inner P-controller of paper grace/core/gaba.py:GABALoss with the EMA and floor stripped out. The result is bit-exact to GABALoss when beta=1 and min_weight=0.

8import torch

PyTorch core. We need torch.autograd.grad for the sensor (gradient probe), torch.tensor and torch.zeros for buffers, torch.nn.Parameter for the demo backbone parameter.

EXECUTION STATE
📚 torch = PyTorch core library. Provides Tensor, autograd, nn module system.
9import torch.nn as nn

nn aliases torch.nn so we can write nn.Module, nn.functional, nn.Parameter without the long path.

12def compute_task_grad_norm(loss, shared_params, retain_graph=True):

The SENSOR. Probes the gradient of one scalar task loss with respect to a list of shared backbone parameters and returns the L2 norm of the concatenated gradient vector. This IS the plant output measurement in control-theoretic terms.

EXECUTION STATE
⬇ loss = scalar Tensor. One task's loss (e.g. MSE for RUL). Must have requires_grad on the path through shared_params.
⬇ shared_params = List of nn.Parameter — the backbone parameters that BOTH tasks contribute gradients to. Excludes the per-task heads. Paper main.tex:63 specifies this exclusion.
⬇ retain_graph=True = Keeps the autograd graph alive after this call so the SAME backward pass can compute the OTHER task's gradient norm. Without this, the second call crashes with 'Trying to backward through the graph a second time'.
⬆ returns = scalar Tensor. The L2 norm. Paper symbol: g_i = ||∇_θ L_i||₂.
13docstring

Locates the source of this helper inside the paper's repo (paper grace/core/gradient_utils.py).

14grads = torch.autograd.grad(...)

Functional autograd call. Returns a TUPLE of gradient tensors, one per parameter in shared_params, with the same shapes as those parameters.

EXECUTION STATE
📚 torch.autograd.grad(outputs, inputs, retain_graph, create_graph, allow_unused) = Functional autograd. Unlike .backward() it does NOT populate .grad on the parameters — it returns the gradients as fresh tensors. Used here so the EMA state and per-task gradient measurement don't collide with the optimiser's .grad accumulators.
⬇ outputs = loss = scalar Tensor. The function being differentiated.
⬇ inputs = shared_params = list of Tensors. The variables we're differentiating WITH RESPECT TO.
⬇ create_graph=False = Don't build a graph for higher-order gradients. We only need first-order, so this saves memory and compute.
⬇ allow_unused=True = If a parameter doesn't appear in the loss's compute graph, return None for its gradient instead of raising. Useful when shared_params includes layers that one task happens not to use this batch.
17total = torch.tensor(0.0, device=loss.device)

Initialise the squared-norm accumulator on the same device as the loss, so every addition in the loop below stays on one device and the code is GPU/CPU agnostic.

EXECUTION STATE
📚 torch.tensor(data, device) = Build a Tensor from raw data. device='cuda' or 'cpu' or specific GPU index. Setting it from loss.device makes this code GPU/CPU agnostic.
total (init) = 0.0 on loss.device
18for g in grads:

Iterate over the per-parameter gradient tensors returned by autograd.grad.

LOOP TRACE · 1 iteration
first parameter (the (3,) shared vector in our demo)
g = Tensor (3,). Gradient of loss w.r.t. shared. e.g. tensor([2.4, -1.1, 0.7]) for an MSE-style loss.
19if g is not None:

Skip parameters that didn't contribute to this task's loss (allow_unused=True returned None for them).

20total = total + g.pow(2).sum()

Accumulate the squared L2 norm parameter-by-parameter. g.pow(2) squares element-wise; .sum() reduces to a scalar; tensor + scalar adds element-wise.

EXECUTION STATE
📚 .pow(2) = Element-wise square. Same as g ** 2.
📚 .sum() = Reduce a tensor to its scalar sum.
Example = g = [2.4, -1.1, 0.7] → g.pow(2) = [5.76, 1.21, 0.49] → .sum() = 7.46 → total += 7.46.
21return total.sqrt()

Square root of the accumulated squared norm = L2 norm. We do the sqrt OUTSIDE the loop so we only compute it once.

EXECUTION STATE
📚 .sqrt() = Element-wise square root. On a scalar tensor it's just √x.
⬆ return = scalar Tensor. The per-task gradient norm g_i. e.g. √7.46 ≈ 2.731.
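
The three behaviours of torch.autograd.grad that the sensor relies on can be checked on a toy graph. This is a minimal sketch: the parameters `w` and `unused` and the loss are made up for illustration and are not part of the paper's code.

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))   # enters the loss
unused = torch.nn.Parameter(torch.tensor([5.0]))   # never enters the loss
loss = (w ** 2).sum()                              # d(loss)/dw = 2 * w

# Functional gradient probe, with the same flags the walkthrough describes.
grads = torch.autograd.grad(
    loss, [w, unused],
    retain_graph=True, create_graph=False, allow_unused=True,
)
print(grads[0])   # gradient w.r.t. w
print(grads[1])   # None, because `unused` is not in the graph
print(w.grad)     # None, because autograd.grad does not populate .grad

# L2 norm accumulated parameter by parameter, skipping the None entries.
total = torch.tensor(0.0, device=loss.device)
for g in grads:
    if g is not None:
        total = total + g.pow(2).sum()
norm = total.sqrt()
print(norm, norm.requires_grad)

# The graph was retained, so a second probe on the same loss still works.
grads_again = torch.autograd.grad(loss, [w])
```

Here the gradient is (2, 4), so the norm is √20 ≈ 4.472, and it carries no autograd graph because create_graph is False.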
24class GabaProportionalInnerLoop(nn.Module):

The P-controller as an nn.Module. Inheriting from nn.Module gives us .to(device), .state_dict(), .train(), .eval() — and lets PyTorch see registered buffers so they move with the module.

EXECUTION STATE
📚 nn.Module = Base class for all PyTorch modules. Tracks parameters and buffers; supports .to(), .train(), .eval(), state_dict(), load_state_dict().
25docstring

Records the relationship to the paper class — bit-exact when beta=1 and min_weight=0.

30def __init__(self, n_tasks=2):

Constructor. Only one hyperparameter (K). The proportional gain is implicit (K_p = 1/(K - 1), which is 1 for K = 2) and the reference is derived from K.

EXECUTION STATE
⬇ n_tasks = 2 = Number of tasks. Paper uses 2 (RUL + health).
31super().__init__()

Required call. Initialises nn.Module's internal _parameters / _buffers dicts before we register anything.

32self.n_tasks = n_tasks

Cache K as an attribute. Note that .forward() re-derives K from len(losses) to size grad_norms and to compute the (K-1) normaliser, so the list passed in should have n_tasks entries.

33self.register_buffer("reference", torch.full((n_tasks,), 1.0 / n_tasks))

Register the reference signal as a non-learnable buffer. .register_buffer adds it to state_dict and moves it with .to(device), but the optimiser won't see it.

EXECUTION STATE
📚 register_buffer(name, tensor) = Register a non-learnable tensor on this Module. Distinct from register_parameter (which IS learnable). Buffers are persisted in state_dict and move with .to() but aren't updated by the optimiser. Used here for the reference signal which is a constant 1/K.
📚 torch.full(shape, fill) = Build a tensor of given shape filled with the given scalar. Equivalent to np.full.
reference (init) = Tensor (2,) = [0.5, 0.5] for K=2.
35def forward(self, losses, shared_params):

The full closed loop in forward order: sense → compute total → emit weights → combine losses. Returns the combined loss, the weight vector, and the gradient norms (the latter two are returned for logging, not used downstream).

EXECUTION STATE
⬇ losses = List of K scalar Tensors. Per-task losses for THIS batch. Each must depend on shared_params via autograd.
⬇ shared_params = List of nn.Parameter — backbone (excluding per-task heads). Paper main.tex:63.
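⬆ returns = Tuple (combined_loss, lambda, grad_norms). Only combined_loss carries an autograd graph. grad_norms comes out of the sensor with no graph because torch.autograd.grad is called with create_graph=False, and lambda is computed from grad_norms alone, so it carries none either.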
36K = len(losses)

Cache the number of tasks. For our demo K=2, so K matches self.n_tasks.

EXECUTION STATE
K = 2
37device = losses[0].device

Pick the device from the first loss tensor. All other tensors must live on the same device — PyTorch raises if not.

EXECUTION STATE
📚 .device = Tensor attribute. Returns torch.device('cpu') or torch.device('cuda:0') etc. Used to keep newly-created tensors on the correct device.
38# Sensor: per-task gradient norms on shared backbone.

Inline comment marking the SENSOR block of the closed loop.

39grad_norms = torch.zeros(K, device=device)

Pre-allocate the (K,) gradient-norm vector on the right device. Filled element-by-element in the for loop below.

EXECUTION STATE
📚 torch.zeros(shape, device) = Build a Tensor of given shape filled with 0, on given device. Shape (K,) here.
grad_norms (init) = Tensor (2,) = [0.0, 0.0] on device.
40for i, loss_i in enumerate(losses):

Iterate over each task's loss. Calling compute_task_grad_norm K times is intentional — there's no efficient way to vectorise autograd across multiple scalars without merging them, which would lose the per-task identity we need.

LOOP TRACE · 2 iterations
i = 0 (RUL loss)
loss_i = scalar Tensor. The MSE on the RUL output. Large because regression error in cycles is on order of 100+.
grad_norms[0] = Set to compute_task_grad_norm(rul_loss, shared) — paper-realistic value O(100-500) on shared backbone.
i = 1 (health loss)
loss_i = scalar Tensor. Cross-entropy on the health classifier. Naturally bounded near ln(C) for C classes — paper-realistic value O(0.1-1.0) on shared backbone.
grad_norms[1] = Set to compute_task_grad_norm(health_loss, shared) — paper-realistic value 100-2400× smaller than grad_norms[0].
41grad_norms[i] = compute_task_grad_norm(loss_i, shared_params)

Run the SENSOR for task i. The default retain_graph=True is essential — this is what allows the SECOND iteration to access the same shared autograd graph.

EXECUTION STATE
Per-iteration write = grad_norms[i] gets one scalar Tensor. The other slots stay at zero until visited.
42# Closed-form proportional law (paper eq. 4 for K tasks).

Inline comment marking the CONTROLLER block.

43total_norm = grad_norms.sum() + 1e-12

Sum + numerical guard. Same logic as the NumPy version.

EXECUTION STATE
📚 .sum() = Tensor reduction → scalar Tensor. grad_norms carries no autograd graph (the sensor uses create_graph=False), so nothing is backpropagated through this branch.
1e-12 = Numerical guard against the all-zero case.
total_norm = scalar Tensor on device. Paper-realistic O(100-500) early in training.
44lam = (total_norm - grad_norms) / ((K - 1) * total_norm)

Paper equation 4 verbatim, vectorised over K tasks. (total_norm - grad_norms) inverts the share; the (K-1) * total_norm normaliser puts λ on the simplex.

EXECUTION STATE
(total_norm - grad_norms) = Tensor (K,). Element-wise scalar minus vector. The 'inverse-share' numerator.
(K - 1) * total_norm = scalar Tensor. For K=2 this equals total_norm; for K=3 it's 2 * total_norm; etc.
lam = Tensor (K,) on the simplex. Paper-realistic for the 500× ratio: [≈0.002, ≈0.998].
45# Combined loss = sum(lambda_i * loss_i)

Inline comment marking the loss combination.

46return sum(w * L for w, L in zip(lam, losses)), lam, grad_norms

Element-wise weight × loss, then sum to a scalar. The Python sum() over a generator is a tiny convenience over an explicit loop. The combined loss is the ONLY returned tensor the optimiser backpropagates through; lam and grad_norms carry no autograd graph (the sensor uses create_graph=False), so they act as constants in the weighted sum and are returned for logging.

EXECUTION STATE
📚 zip(a, b) = Built-in: pair elements from two iterables. Used to walk lambda and losses in lockstep.
Generator expression = Lazy iteration. sum(...) consumes it without building an intermediate list.
⬆ return = Tuple of three. combined_loss is the scalar that .backward() will be called on. lam and grad_norms are for logging.
49Demo header

Marker for the runnable smoke test.

50torch.manual_seed(0)

Make the random tensors below reproducible. Without this, every run produces different numbers and the 'expected output' cell at the bottom would become noise.

EXECUTION STATE
📚 torch.manual_seed(seed) = Set the global PRNG seed for torch.randn, torch.randint, etc. Important for reproducibility.
shared = torch.nn.Parameter(torch.randn(3))

A tiny stand-in for the 3.5M-parameter backbone — just a (3,) vector wrapped as a Parameter so autograd treats it as something to differentiate w.r.t.

EXECUTION STATE
📚 nn.Parameter(tensor) = Subclass of Tensor that signals 'this is a learnable parameter' to nn.Module. Auto-registered when assigned to a Module attribute.
shared = Tensor (3,) with requires_grad=True. e.g. tensor([1.54, -0.29, -2.18]) for seed 0.
x = torch.randn(8, 3)

Batch of 8 inputs, each 3-dim. Stand-in for sensor windows in the real pipeline.

EXECUTION STATE
📚 torch.randn(*shape) = Standard-normal random tensor of given shape. No requires_grad by default.
x = Tensor (8, 3).
y_rul = torch.randn(8)

Random RUL targets for the smoke test. Real RUL targets are clipped at 125 cycles (paper main.tex:80) but here we just need ANYTHING that produces a non-zero loss.

EXECUTION STATE
y_rul = Tensor (8,). Random regression targets.
y_health = torch.randint(0, 3, (8,))

Random class labels in {0, 1, 2} for the health classifier. Real health labels come from the FD002 / FD004 piecewise RUL → 3-class mapping (paper main.tex:152).

EXECUTION STATE
📚 torch.randint(low, high, shape) = Random integers in [low, high). Returns a long tensor.
y_health = Tensor (8,) of long ints in {0, 1, 2}.
Loss-construction header

Marker for the synthetic per-task losses.

rul_pred = (x * shared).sum(dim=1)

Stand-in regression head. Element-wise (8,3) * (3,) broadcasts the shared vector across the batch; .sum(dim=1) reduces over the feature dimension to give one prediction per example.

EXECUTION STATE
📚 .sum(dim=1) = Reduce along dim=1. For shape (8, 3) this gives shape (8,). dim=0 would give shape (3,).
rul_pred = Tensor (8,) — one regression prediction per batch element.
logits = torch.outer(x.mean(dim=1), torch.ones(3)) * shared

Stand-in classification head producing a (8, 3) logit tensor. The outer product spreads the per-example mean across 3 classes, and multiplying by `shared` ensures the logits depend on `shared` (so autograd will produce non-zero gradients).

EXECUTION STATE
📚 torch.outer(a, b) = Outer product of two 1D tensors → 2D matrix of shape (len(a), len(b)).
📚 .mean(dim=1) = Reduce by mean along dim=1. (8,3) → (8,). Same shape rules as .sum(dim=1).
logits = Tensor (8, 3). Per-example, per-class scores.
rul_loss = ((rul_pred - y_rul) ** 2).mean()

MSE — the regression loss. Paper main.tex:336-355 uses this as L_rul (or its failure-biased weighted version, but the gradient-magnitude story is dominated by the squared-error term).

EXECUTION STATE
(rul_pred - y_rul) ** 2 = Element-wise squared error, shape (8,).
📚 .mean() = Reduce all elements to a scalar mean.
rul_loss = scalar Tensor. Larger than health_loss by an ample margin for these random tensors. The order-of-magnitude gap is what creates the imbalance the controller rebalances.
health_loss = nn.functional.cross_entropy(logits, y_health)

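Cross-entropy, the classification loss. It sits near ln(3) ≈ 1.10 for a 3-class problem whenever the predictions are close to uniform (as at initialisation), regardless of model size, and its gradient with respect to the logits is softmax minus one-hot, so every entry lies in [−1, 1]. This boundedness is precisely WHY the gradient magnitudes are so much smaller than MSE's.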

EXECUTION STATE
📚 F.cross_entropy(logits, target) = Combined log-softmax + NLL. Returns scalar mean. Targets are class indices, NOT one-hot. For C classes the loss equals ln C at uniform predictions. The loss itself can grow without bound on confidently wrong predictions, but its gradient with respect to the logits stays within [−1, 1] per entry.
health_loss = scalar Tensor. Near ln(3) ≈ 1.10 for 3-class at initialisation. Smaller than rul_loss by 100-1000× in magnitude on the paper's real backbone.
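The magnitude argument above can be illustrated at the level of a single example, without any model. The sketch below is plain Python with hypothetical inputs (an untrained 3-class head with uniform logits, and an RUL prediction that is 40 cycles off); it only shows why a squared-error gradient can dwarf a cross-entropy gradient and says nothing about the paper's real backbone.

```python
import math


def softmax(logits):
    m = max(logits)                      # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def cross_entropy_and_grad(logits, target):
    """Single-example cross-entropy and its gradient w.r.t. the logits."""
    p = softmax(logits)
    loss = -math.log(p[target])
    grad = [p_c - (1.0 if c == target else 0.0) for c, p_c in enumerate(p)]
    return loss, grad


def squared_error_and_grad(pred, y):
    """Single-example squared error and its gradient w.r.t. the prediction."""
    return (pred - y) ** 2, 2.0 * (pred - y)


# Uniform logits (an untrained 3-class head): loss equals ln(3).
ce_loss, ce_grad = cross_entropy_and_grad([0.0, 0.0, 0.0], target=1)
print(round(ce_loss, 4), round(math.log(3), 4))   # 1.0986 1.0986
print([round(g, 4) for g in ce_grad])             # [0.3333, -0.6667, 0.3333]

# Hypothetical RUL residual of 40 cycles: the gradient grows with the error.
se_loss, se_grad = squared_error_and_grad(pred=85.0, y=125.0)
print(se_loss, se_grad)                           # 1600.0 -80.0
```

The cross-entropy gradient entries are probabilities minus indicators, so they cannot leave [−1, 1], while the squared-error gradient scales linearly with the residual.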
inner = GabaProportionalInnerLoop(n_tasks=2)

Instantiate the P-controller. Just K=2; everything else is fixed.

total, lam, gn = inner([rul_loss, health_loss], [shared])

Run the full forward of the controller. Returns (combined_loss, lambda, grad_norms). The trainer's optimiser would call .backward() on `total`.

EXECUTION STATE
total = scalar Tensor. λ_rul · rul_loss + λ_health · health_loss.
lam = Tensor (2,) on the simplex. Inverse-share weights.
gn = Tensor (2,). Per-task L2 gradient norms on shared. The output of the SENSOR.
print(f"grad norms = {gn.tolist()}")

Display the sensor output. Concrete numbers depend on torch.manual_seed(0); the example below shows the canonical first-run output.

EXECUTION STATE
📚 .tolist() = Convert a Tensor to a nested Python list. Detaches and moves to CPU implicitly. Useful for printing or JSON serialisation.
Output = grad norms = [4.221, 0.064] # paper-style: g_rul ≫ g_health
print(f"weights = {lam.tolist()} (sums to {lam.sum().item():.4f})")

Display the controller output. The pair sums to 1 by construction.

EXECUTION STATE
📚 .item() = Convert a 0-dim Tensor to a Python scalar. Raises if the tensor has more than one element.
Output = weights = [0.0150, 0.9850] (sums to 1.0000) Note: identical to (gn[1]/(gn[0]+gn[1]), gn[0]/(gn[0]+gn[1])) by the K=2 closed form.
print(f"total loss = {total.item():.4f}")

Display the combined scalar loss. This is what gets sent to .backward() in real training.

EXECUTION STATE
Final output =
grad norms = [4.221, 0.064]
weights    = [0.0150, 0.9850]    (sums to 1.0000)
total loss = 1.0832
# (Exact numbers vary slightly with torch version due to autograd algebra,
#  but the qualitative pattern — λ inverse to grad_norms — is invariant.)
"""The proportional inner loop of GABALoss in PyTorch.

Strips out the EMA and floor (those are sec. 19.2, 19.3). What remains
is exactly the P-controller: sensor + comparator + proportional law.

This block is the inner three lines of grace/core/gaba.py:GABALoss.forward_k.
"""

import torch
import torch.nn as nn


def compute_task_grad_norm(loss, shared_params, retain_graph=True):
    """L2 norm of grad(loss) on shared_params (paper grace/core/gradient_utils.py)."""
    grads = torch.autograd.grad(loss, shared_params,
                                retain_graph=retain_graph,
                                create_graph=False, allow_unused=True)
    total = torch.tensor(0.0, device=loss.device)
    for g in grads:
        if g is not None:
            total = total + g.pow(2).sum()
    return total.sqrt()


class GabaProportionalInnerLoop(nn.Module):
    """The inner P-controller of GABA, with no EMA, no floor.

    Identical numerics to the un-smoothed branch of paper GABALoss when
    beta = 0 (no smoothing: each new measurement replaces the EMA state)
    and min_weight = 0.
    """

    def __init__(self, n_tasks=2):
        super().__init__()
        self.n_tasks = n_tasks
        self.register_buffer("reference",
                             torch.full((n_tasks,), 1.0 / n_tasks))

    def forward(self, losses, shared_params):
        K = len(losses)
        device = losses[0].device
        # Sensor: per-task gradient norms on shared backbone.
        grad_norms = torch.zeros(K, device=device)
        for i, loss_i in enumerate(losses):
            grad_norms[i] = compute_task_grad_norm(loss_i, shared_params)
        # Closed-form proportional law (paper eq. 4 for K tasks).
        total_norm = grad_norms.sum() + 1e-12
        lam = (total_norm - grad_norms) / ((K - 1) * total_norm)
        # Combined loss = sum(lambda_i * loss_i)
        return sum(w * L for w, L in zip(lam, losses)), lam, grad_norms


# ---- Tiny smoke test (no real model) ----
torch.manual_seed(0)

# Two losses sharing a (3,) backbone parameter:
shared = torch.nn.Parameter(torch.randn(3))
x = torch.randn(8, 3)
y_rul    = torch.randn(8)
y_health = torch.randint(0, 3, (8,))

# Manufacture realistic per-task losses on shared
rul_pred    = (x * shared).sum(dim=1)                           # regression head
logits      = torch.outer(x.mean(dim=1), torch.ones(3)) * shared
rul_loss    = ((rul_pred - y_rul) ** 2).mean()                  # MSE: large grad
health_loss = nn.functional.cross_entropy(logits, y_health)     # CE: small grad

inner = GabaProportionalInnerLoop(n_tasks=2)
total, lam, gn = inner([rul_loss, health_loss], [shared])
print(f"grad norms = {gn.tolist()}")
print(f"weights    = {lam.tolist()}    (sums to {lam.sum().item():.4f})")
print(f"total loss = {total.item():.4f}")

The smoke test demonstrates the controller in action on real autograd: the regression task's gradient norm exceeds the classification task's by roughly 66× (4.221 / 0.064) on this random tensor, and the controller emits weights ~(0.015, 0.985), the exact inverse share. On the paper's real 3.5M-parameter backbone the imbalance is closer to 500× (paper main.tex:64) and the controller pushes $\lambda_{\text{health}}$ to ~0.995 (paper main.tex:647) before the EMA filter slows it down (sec. 19.2).
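The arithmetic behind those weights can be replayed by hand from the two printed sensor readings. The sketch below does so in plain Python and also shows the effect on the weighted gradient norms, under the assumption (not stated in the original) that scaling a loss by a weight scales its gradient norm by the same factor.

```python
# Sensor readings quoted in the smoke-test output above.
g_rul, g_health = 4.221, 0.064
total = g_rul + g_health

# Inverse-share weights for K = 2: each task gets the OTHER task's share.
lam_rul = g_health / total
lam_health = g_rul / total
print(round(lam_rul, 4), round(lam_health, 4))   # 0.0149 0.9851

# Assumption: scaling a loss by lambda scales its gradient norm by lambda.
weighted_rul = lam_rul * g_rul
weighted_health = lam_health * g_health
print(round(weighted_rul, 4), round(weighted_health, 4))   # 0.063 0.063
print(round(g_rul / g_health, 1))                # 66.0
```

The weights agree with the printed (0.0150, 0.9850) up to the rounding of the displayed norms, and for two tasks the weighted norms come out equal, which is the equal-contribution setpoint the loop is built around.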

The Same P-Controller In Other Domains

The control-theoretic framing is not just a metaphor. The same proportional inner loop — sensor, comparator, proportional law, plant, feedback wire — appears in physical, biological, and digital systems whose engineering history goes back centuries. Recognising the pattern across domains is what tells you that GABA's stability story is on solid ground.

| Domain | Plant | Sensor | Reference | Controller (proportional law) | Why P-only is enough |
| --- | --- | --- | --- | --- | --- |
| Predictive maintenance (this paper) | Joint multi-task SGD step on shared backbone | g_i = ‖∇_θ L_i‖₂ on shared params | 1/K equal share | λ_i = (Σg − g_i)/((K−1) Σg) | Combined with EMA + floor for noise + saturation |
| Automotive cruise control (Watt 1788, Bosch 1995) | Engine + drivetrain + drag | Speedometer (wheel-tachometer) | Driver-set speed (km/h) | Throttle ∝ (setpoint − measured) | Modern systems add I/D for steady-state error & lag |
| Insulin pump in artificial pancreas | Blood-glucose dynamics in tissue | Continuous glucose monitor (CGM) | Target glucose (e.g. 5.5 mmol/L) | Bolus ∝ (measured − target) | Combined with predictive model for meal disturbance |
| Aircraft autopilot pitch hold | Pitch axis dynamics (thrust + elevator) | Inertial pitch gyro / AoA sensor | Commanded attitude angle | Elevator deflection ∝ (commanded − measured) | Augmented with rate damping for short-period mode |
| Federated learning client weighting | Edge devices computing local updates Δ_i | ‖Δ_i‖ for each client | Equal contribution per client | Inverse-norm weighting: w_i ∝ 1/‖Δ_i‖ | Server-side outlier rejection handles Byzantine clients |
| Robotic manipulator joint torque | Robot arm joint dynamics | Joint angle encoder | Commanded trajectory point | Torque ∝ (commanded − measured) | Combined with feed-forward model for nonlinear dynamics |
| Power grid load-frequency control | Generator + transmission + load | Grid frequency (50 / 60 Hz) | Nominal frequency | Generator output ∝ (nominal − measured) | AGC adds integral action over 1-2 minutes |

Two patterns repeat. First, the P-only controller is rarely the whole story — in every application a real engineer adds memory (the I in PID), prediction (the D in PID), or a bounding mechanism. GABA does the same: the EMA filter (§19.2) is a bounded-memory smoother, and the floor (§19.3) is the bounding mechanism. Second, in every well-designed loop the controller emits a value inverse to whatever the sensor reads — if the engine slows, throttle up; if the gradient is too large, weight down. That is precisely the inverse-proportional rule of paper eq. 4.

Cross-domain analogy: federated learning. A 2024 line of work on robust federated averaging weights edge-device gradients inversely to their L2 norm to prevent a single chatty client from dominating the global update. The mechanism is identical to GABA's inverse-share rule, the goal is identical (equal effective contribution to the shared model), and the derivation works through the same five-piece feedback loop. Recognise the closed loop in one domain and you can transfer the analysis to another.
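How close the two rules are can be checked numerically. The sketch below compares the inverse-share rule from this section with a normalised inverse-norm rule (w_i ∝ 1/‖Δ_i‖) in plain Python; the function names and the example norms are hypothetical, and the normalisation of the federated weights to sum to 1 is an assumption made for the comparison.

```python
def gaba_weights(norms):
    """Inverse-share rule from this section: (sum - g_i) / ((K - 1) * sum)."""
    k, total = len(norms), sum(norms)
    return [(total - g) / ((k - 1) * total) for g in norms]


def inverse_norm_weights(norms):
    """Normalised inverse-norm rule: w_i proportional to 1 / g_i."""
    inv = [1.0 / g for g in norms]
    s = sum(inv)
    return [v / s for v in inv]


# Two clients / tasks: the two rules coincide.
two = [4.0, 1.0]
print(gaba_weights(two), inverse_norm_weights(two))   # [0.2, 0.8] [0.2, 0.8]

# Three clients / tasks: same ordering, different values.
three = [4.0, 2.0, 1.0]
print([round(w, 3) for w in gaba_weights(three)])          # [0.214, 0.357, 0.429]
print([round(w, 3) for w in inverse_norm_weights(three)])  # [0.143, 0.286, 0.571]
```

For two participants the rules give the same weights. For three or more they agree on the ordering (largest norm, smallest weight) but not on the exact values, so results transferred between the two settings should be checked for the K in question.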

Why ‘Feedback’ Matters: GradNorm Has Open-Loop Failure Modes

The strongest argument for the closed-loop framing is what happens without it. GradNorm (Chen et al. 2018) tries to equalise per-task gradient magnitudes by adding an auxiliary loss that is itself differentiated and minimised alongside the task losses. From a control-theoretic standpoint this is an open-loop design: the auxiliary loss does not directly observe the current measured share and emit a proportional weight; instead it feeds the error into a parameter that is itself updated by gradient descent — a cascade of two optimisation processes whose stability depends on the relative time constants of the inner and outer loops.

The paper observed the consequence empirically (paper main.tex:553): “GradNorm diverged on 1 of 5 N-CMAPSS seeds (seed 789 produced NaN gradients); its auxiliary gradient-balancing loss can amplify oscillations when the imbalance exceeds 500×, whereas GABA's direct inverse-ratio formulation with EMA smoothing avoids this failure mode.”

| Property | GABA (closed loop) | GradNorm (open loop with auxiliary loss) |
| --- | --- | --- |
| weight bound | λ_i ∈ [λ_min, 1−λ_min] every step (proven in sec. 19.3) | weights can drift unboundedly across batches |
| learnable parameters | 0 | K (one per task) |
| stability under 500× imbalance | stable on 5/5 seeds (paper main.tex:553) | diverged on 1/5 N-CMAPSS seeds with NaN gradients |
| computational overhead | <10% (two extra autograd.grad calls) | ~K extra backwards (one per task) plus auxiliary loss step |
| hyperparameters | 3: β = 0.99, λ_min = 0.05, warmup = 100 steps | ≥ 2: target ratio α (commonly 1.5) plus the auxiliary-loss learning rate |

The take-away is structural, not numerical. Closed-loop design with a fixed proportional gain and an explicit bounding mechanism is more robust than open-loop design with a learnable parameter, particularly when the disturbance amplitude is 500× the reference. This is a textbook control-engineering result — well-tuned P controllers beat parameter-adapted ones in high-disturbance, low-prior settings — and GABA's 5/5 vs. GradNorm's 4/5 reproduction rate on N-CMAPSS is the empirical confirmation in the prognostics setting.

Pitfalls When Treating GABA As A Controller

Pitfall 1: forgetting that the ‘plant’ is non-stationary. A textbook P-controller assumes a time-invariant plant. The optimisation process is anything but: the gradient norms $g_i$ drift across orders of magnitude during training (paper Fig. gradient_dynamics: 50× → 2,400× → 500–1,000×). If you analyse the loop as an LTI system you'll get the wrong stability result. The EMA filter (§19.2) is precisely the design choice that adapts the controller's effective response time to a slowly-drifting plant.
Pitfall 2: confusing ‘reference’ with ‘target output’. GABA's reference $r = 1/K$ is the target SHARE, not the target WEIGHT. If you read $\lambda^{*}$ and try to compare it to $r$, you'll be confused: at the 500× imbalance, $\lambda^{*}_{\text{health}} \approx 0.998$ (very far from $r = 0.5$). That's correct: the controller is doing its job. The thing the loop is driving toward $1/K$ is $s_i = g_i / \Sigma g$ (the gradient share), not $\lambda_i^{*}$.
Pitfall 3: chaining the gain to the imbalance. A well-meaning practitioner might think ‘at 500× imbalance I should crank up $K_p$ to react faster.’ Tempting, but exactly wrong: the controller is already at the maximum sensible $K_p$. The closed-form rule is equivalent to $K_p = 1/(K-1)$ in the share-error coordinates we derived above, which is $K_p = 1$ in the two-task case. Cranking it up converts a stable P controller into an oscillating one. The imbalance is handled by sustained negative feedback over many steps, not by a bigger gain on each step.
Pitfall 4: assuming the closed loop produces zero steady-state error. A P-only controller has nonzero steady-state error in the presence of a constant disturbance; this is classical control. GABA's share does not converge exactly to $1/K$; under unity gain it converges to a value strictly between the equal-share reference and the open-loop measurement. The paper's reported convergence $\lambda_{\text{health}} = 0.995 \pm 0.001$ (main.tex:647) is the steady-state response, not perfect tracking. Adding integral action (the I in PID) would zero the steady-state error but is not what the paper does: the bounded-weight guarantee of §19.3 is more important here than perfect tracking.
Pitfall 5: forgetting the sensor sees only the SHARED backbone. The paper's gradient probe (get_shared_params in §18.5) excludes the per-task heads. If you accidentally include them, the per-task heads' gradient norms dominate the measurement and the controller rebalances against the wrong signal. Paper main.tex:63 fixes this: the comparison must be on shared parameters, where both tasks actually compete for representation capacity.
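Pitfall 5 is easy to see with toy numbers. The sketch below is plain Python with entirely hypothetical gradients (two entries per parameter group, and dictionary keys `backbone` and `head` chosen for illustration); it compares the weights produced by a sensor restricted to shared parameters with those produced when head gradients leak into the measurement.

```python
import math


def l2_norm(vectors):
    """L2 norm of several gradient vectors treated as one flat vector."""
    return math.sqrt(sum(x * x for vec in vectors for x in vec))


def inverse_share_weights(norms):
    k, total = len(norms), sum(norms)
    return [(total - g) / ((k - 1) * total) for g in norms]


# Hypothetical per-task gradients: one entry for the shared backbone,
# one for that task's private head.
rul = {"backbone": [3.0, 4.0], "head": [0.0, 1.0]}
health = {"backbone": [0.03, 0.04], "head": [6.0, 8.0]}

# Correct sensor: shared parameters only.
shared_norms = [l2_norm([rul["backbone"]]), l2_norm([health["backbone"]])]
# Faulty sensor: head gradients leak into the measurement.
mixed_norms = [l2_norm(rul.values()), l2_norm(health.values())]

print([round(w, 3) for w in inverse_share_weights(shared_norms)])  # [0.01, 0.99]
print([round(w, 3) for w in inverse_share_weights(mixed_norms)])   # [0.662, 0.338]
```

With the faulty sensor the large head gradient of the health task hides its tiny backbone gradient, and the controller ends up favouring the task that already dominates the shared representation.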

Takeaway

  • GABA's closed-form rule IS a textbook P-controller. The algebraic equivalence is exact: $\lambda_i^{*} = r + K_p \cdot e_i$ with $r = 1/K$, $K_p = 1/(K - 1)$, and $e_i = r - s_i$ the deviation of the measured share $s_i$ from the equal-share reference.
  • The five blocks are concrete. Plant = joint optimiser. Sensor = $\| \nabla_\theta \mathcal{L}_i \|_2$. Reference = $1/K$. Comparator = subtraction absorbed into the closed form. Controller = paper eq. 4. Feedback wire = the next batch's autograd graph.
  • Unity gain is the algebraic optimum. The closed-form rule fixes $K_p = 1/(K-1)$, the exact value at which a one-step proportional correction would zero the share error in the absence of EMA smoothing. No tuning, no learnable parameters.
  • Three guarantees fall out of the framing. Weights sum to 1 (simplex constraint). The rule is symmetric under task permutation. Equal gradients is a fixed point.
  • Closed loop > open loop in disturbance regimes. GABA's 5/5 reproduction rate vs. GradNorm's 4/5 on N-CMAPSS (paper main.tex:553) is the empirical confirmation of a textbook control-engineering result: well-designed feedback rejects high-magnitude disturbances better than parameter-adapted alternatives.
  • The next two sections close the analysis. §19.2 identifies the EMA buffer as a first-order IIR low-pass filter with time constant $\tau = 1/(1-\beta) = 100$ steps, completing the noise-rejection story. §19.3 identifies the floor as the anti-windup mechanism and proves the bounded-weight guarantee in $[\lambda_{\min}, 1 - \lambda_{\min}]$ that GradNorm cannot match.
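The equivalence claimed in the first bullet can be checked in two lines, taking the error as reference minus measured share, $e_i = r - s_i$ with $s_i = g_i / \sum_j g_j$. That sign convention is an assumption here, chosen because it is the one that makes the identity hold term by term.

```latex
\begin{aligned}
\lambda_i^{*}
  &= \frac{\sum_j g_j - g_i}{(K-1)\sum_j g_j}
   = \frac{1 - s_i}{K-1},
   \qquad s_i = \frac{g_i}{\sum_j g_j}, \\[4pt]
r + K_p\, e_i
  &= \frac{1}{K} + \frac{1}{K-1}\left(\frac{1}{K} - s_i\right)
   = \frac{(K-1) + 1 - K s_i}{K(K-1)}
   = \frac{1 - s_i}{K-1}.
\end{aligned}
```

Two of the guarantees follow directly: summing over $i$ gives $(K - \sum_i s_i)/(K-1) = 1$, and setting $s_i = 1/K$ for every task gives $\lambda_i^{*} = 1/K$, the equal-gradient fixed point.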