Walk into a modern car's ABS controller and you find, before any control logic, a wheel-speed sensor sampled many times per second. The controller can't modulate brake pressure unless it MEASURES wheel speed first. Throw the sensor away and the algorithm is just an opinion.
GABA has the same anatomy. The K = 2 closed form $\lambda_i^* = g_j/(g_i + g_j)$ is the controller. Its sensor is the per-task gradient norm $g_i = \|\nabla_{\theta_s} L_i\|_2$ on the shared backbone parameters $\theta_s$. Every training step starts with a measurement, then applies the rule. This section is about the measurement: how to compute $g_i$ exactly, cheaply, and without contaminating the rest of the optimisation.
The headline. The paper's production utility compute_task_grad_norm in grace/core/gradient_utils.py does this in 12 lines: torch.autograd.grad with create_graph=False, a sum of squared per-parameter norms, then a square root. Same algorithm, different scale: the 1,008-parameter toy backbone in this section (480 parameters in fc1, 528 in fc2) reproduces the paper's roughly 500× imbalance.
What ‘Per-Step’ Means In GABA
The GABA algorithm (paper Algorithm 1) is a control loop running once per training step. The four stages, in order:
Measure. Compute $g_i = \|\nabla_{\theta_s} L_i\|_2$ for each task. THIS SECTION.
Compute. Apply the closed form $\lambda_i = (S - g_i)/((K-1)S)$ where $S = \sum_j g_j$ (§17.3 derivation).
Stabilise. EMA-smooth (§18.2), apply the floor (§18.3), and optionally pass through warmup (§18.4).
Combine and step. Form $L = \sum_i \lambda_i^* L_i$ and let the main optimiser take one step.
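One step of this loop can be sketched in plain Python. This is a minimal sketch under stated assumptions: the Stage 1 norms and the per-task losses are already plain floats, Stage 3 (EMA, floor, warmup) is skipped, and the helper names gaba_weights and combined_loss are hypothetical, not taken from the paper's code.

```python
def gaba_weights(norms):
    """Stage 2: closed form lambda_i = (S - g_i) / ((K - 1) * S), S = sum_j g_j."""
    k = len(norms)
    if k < 2:
        raise ValueError("the closed form needs at least two tasks")
    s = sum(norms)
    return [(s - g) / ((k - 1) * s) for g in norms]


def combined_loss(weights, losses):
    """Stage 4: L = sum_i lambda_i * L_i, with the weights used as plain constants."""
    return sum(w * loss for w, loss in zip(weights, losses))


# Stage 1 output: the two norms measured in the PyTorch toy run of this section.
g_rul, g_health = 114.4296, 0.2164
lam_rul, lam_health = gaba_weights([g_rul, g_health])
print(f"lambda_rul={lam_rul:.6f} lambda_health={lam_health:.6f}")
# lambda_rul=0.001888 lambda_health=0.998112
```

For K = 2 the list comprehension reduces to $\lambda_i^* = g_j/(g_i + g_j)$, and the weights always sum to 1. If every norm is 0 the division by S fails, so a real trainer needs a fallback for that case.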
Stages 1 and 4 both involve gradients on the same model, but they are NOT the same computation. Stage 1 needs the SCALAR gi per task; the actual gradient tensors get thrown away. Stage 4 needs the actual gradient of the combined loss applied to all parameters (including heads). The rest of this section makes Stage 1 explicit.
Selecting Shared Parameters
The first decision is which parameters $\theta_s$ actually contains. GABA balances task gradients on the shared backbone, not on every parameter in the model. Including the task heads is wrong: each head's gradient is large for its own task and zero for the others, so adding head parameters inflates that task's norm while leaving the other task's norm untouched, which distorts exactly the ratio GABA is trying to measure.
The paper's utility get_shared_params(model, head_names=("rul_head", "health_head")) in grace/core/gradient_utils.py:13 walks model.named_parameters(), skips frozen parameters (requires_grad=False), and excludes any parameter whose name contains a configured head substring. This is robust to PyTorch wrapping (EMA, DataParallel) because named_parameters() emits dotted-path names, such as module.rul_head.weight, that still contain the head substring.
Why substring matching, not type matching? Type-matching (e.g. ‘exclude all nn.Linear layers’) would also catch the inner backbone layers. Substring on the registered name is a stronger contract because the model author already chose those names to mark roles. The user can pass a different head_names tuple if their architecture differs.
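The substring rule is easy to check on plain strings. A minimal sketch: select_shared_names is a hypothetical helper that mirrors only the name test of get_shared_params (not its requires_grad filter), and the parameter names below are illustrative; in a real model they would come from model.named_parameters().

```python
def select_shared_names(param_names, head_names=("rul_head", "health_head")):
    """Keep parameter names that contain none of the head substrings."""
    return [n for n in param_names if not any(h in n for h in head_names)]


names = [
    "backbone.fc1.weight", "backbone.fc1.bias",
    "rul_head.weight", "health_head.bias",
    # Wrapped models: DataParallel prefixes "module.", an EMA wrapper may add more.
    "module.backbone.fc2.weight", "ema.module.rul_head.weight",
]
print(select_shared_names(names))
# ['backbone.fc1.weight', 'backbone.fc1.bias', 'module.backbone.fc2.weight']
```

An exact-equality test against "rul_head.weight" would keep "ema.module.rul_head.weight" as if it were shared; the substring test drops it.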
Why The L2 Norm (Not L1 Or L∞)
Three norm choices satisfy ‘a positive scalar that increases with magnitude’:
L1 (Manhattan): Σ_p Σ_i |g_{p,i}|. Sum of absolute values; linear in each element. Used by Lasso regression and sparsity penalties.
L2 (Euclidean), the paper's choice: sqrt(Σ_p Σ_i g_{p,i}²). Rotation-invariant; weighted toward the largest elements. Used by GABA, GradNorm and gradient clipping.
L∞ (max): max_{p,i} |g_{p,i}|. Only the single largest element counts; insensitive to how the rest of the mass is distributed. Used by adversarial robustness budgets.
The L2 norm is chosen for three reasons:
Rotation invariance. If you rotate the parameter basis (e.g. SVD reparametrisation), ∥g∥2 does not change. L1 and L∞ do change — they are basis-dependent. Multi-task balance should not be basis-dependent.
Standard practice in optimisation. PyTorch's torch.nn.utils.clip_grad_norm_ defaults to the L2 norm (norm_type=2.0), Adam's second-moment estimate is built from squared gradient entries, and most gradient-based regularisers use L2. Using the same norm downstream avoids subtle interactions.
Smooth gradient. The L2 norm is differentiable everywhere except at zero (a measure-zero point). L1 is non-differentiable wherever any coordinate is zero, and L∞ wherever two or more coordinates tie for the largest magnitude. GABA itself is one-shot (it never differentiates through the norm), so this matters less than for GradNorm, which back-propagates through the norm.
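A two-dimensional example makes the rotation-invariance argument concrete. This is a minimal sketch, assuming a 45° rotation of a toy 2-D gradient stands in for an orthogonal reparametrisation of the real parameter space; the helper names are illustrative.

```python
import math


def l1(v):
    return sum(abs(x) for x in v)


def l2(v):
    return math.sqrt(sum(x * x for x in v))


def linf(v):
    return max(abs(x) for x in v)


def rotate(v, theta):
    """Rotate a 2-D vector by theta radians (an orthogonal change of basis)."""
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1]]


g = [3.0, 4.0]
g_rot = rotate(g, math.pi / 4)
for name, norm in (("L1", l1), ("L2", l2), ("Linf", linf)):
    print(f"{name}: before={norm(g):.4f} after={norm(g_rot):.4f}")
# L1: before=7.0000 after=5.6569
# L2: before=5.0000 after=5.0000
# Linf: before=4.0000 after=4.9497
```

Only the L2 value survives the change of basis, which is the property the balance rule relies on.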
Sum-Of-Squares Equals Concat-Then-Norm
Real backbones store gradients as a list of per-parameter tensors of different shapes (one weight matrix here, one bias vector there). The L2 norm of the full gradient vector $g \in \mathbb{R}^D$, with $D = \sum_p \dim(\theta_p)$, could be computed by concatenating everything into one big 1-D vector and calling np.linalg.norm. Splitting the sum over parameters gives the same quantity:
$$\|g\|_2 = \sqrt{\sum_p \sum_i g_{p,i}^2} = \sqrt{\sum_p \|g_p\|_2^2}$$
The right-hand side is the paper's implementation choice: accumulate squared per-parameter norms, then take the square root once at the end. The two forms are mathematically identical and agree numerically up to floating-point rounding (we verify to 4 decimals in the Python demo). The sum-of-squares version has two practical advantages:
No materialisation. Concatenation would allocate a fresh D-element float buffer. For the paper's 3.5M-parameter backbone in float32 that is 14 MB of temporary memory per gradient norm. The sum-of-squares version never builds that buffer; its only temporary is the squared copy of one parameter's gradient at a time.
Matches the autograd output. torch.autograd.grad already returns a tuple of per-parameter gradient tensors, so the sum-of-squares loop walks them directly without first concatenating.
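The 14 MB figure is simply D times the element width. A back-of-envelope sketch, assuming decimal megabytes (10⁶ bytes) and one flat copy per norm; concat_scratch_bytes is a hypothetical helper.

```python
def concat_scratch_bytes(num_params, bytes_per_element=4):
    """Bytes for one flat copy of a num_params-element gradient vector."""
    return num_params * bytes_per_element


D = 3_500_000  # shared-backbone size quoted in the text
for dtype, width in (("float16", 2), ("float32", 4), ("float64", 8)):
    mb = concat_scratch_bytes(D, width) / 1e6  # decimal megabytes
    print(f"{dtype}: {mb:.1f} MB per gradient norm")
# float16: 7.0 MB per gradient norm
# float32: 14.0 MB per gradient norm
# float64: 28.0 MB per gradient norm
```

Multiply by K for the per-step cost of the concat path, since every task needs its own norm.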
One Forward Pass, K Backward Calls
For K=2 tasks, GABA needs TWO gradient norms per step: $g_{\mathrm{rul}}$ and $g_{\mathrm{health}}$. Naively that is two forward passes, which is expensive. The standard PyTorch idiom is one forward pass and two torch.autograd.grad calls with retain_graph=True:
The first autograd.grad(rul_loss, ..., retain_graph=True) computes $\nabla_{\theta_s} L_{\mathrm{rul}}$ AND keeps the autograd graph alive.
The second autograd.grad(health_loss, ..., retain_graph=True) re-uses the same forward graph to compute $\nabla_{\theta_s} L_{\mathrm{health}}$. Because the graph is still alive, no second forward pass is needed.
After both calls, the trainer typically calls combined_loss.backward() to fill .grad for every parameter (heads included), and optimizer.step() then updates the weights. This third backward pass also re-uses the same graph; since backward() defaults to retain_graph=False, PyTorch frees the graph after it.
Without retain_graph=True on the first call, the second call crashes: PyTorch frees the autograd graph as soon as a backward pass finishes unless told otherwise. The error you would see is: RuntimeError: Trying to backward through the graph a second time. This is a common bug when retrofitting GABA onto an existing trainer.
Interactive: Which Parameters Count?
Toggle individual parameters in or out of the gradient-norm aggregation. The default mirrors the paper's get_shared_params: backbone IN, heads OUT. Watch how including a head parameter inflates one side's norm and changes the resulting GABA λ.
Try this. Start with the paper default, then toggle rul_head.weight on. The RUL norm jumps from ∼110 to ∼180 because that one head parameter contributes a gradient norm of 145.2 by itself, and squared norms of disjoint parameter groups add: sqrt(110² + 145.2²) ≈ 182. The health norm doesn't change, because rul_head is not part of the health loss's graph. Now λ_rul DROPS: you have accidentally told GABA that the RUL task is even more dominant on the shared backbone than it really is.
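The toggle can be replayed numerically. A minimal sketch: 110 and 145.2 are the approximate figures quoted above, while the health norm is a hypothetical stand-in borrowed from the PyTorch toy run (the selector's own health value is not reproduced here); combine_norms is an illustrative helper.

```python
import math


def combine_norms(*group_norms):
    """L2 norm of a union of disjoint parameter groups: sqrt of summed squares."""
    return math.sqrt(sum(n * n for n in group_norms))


g_rul_backbone = 110.0   # approx. backbone-only RUL norm quoted above
rul_head_weight = 145.2  # rul_head.weight contribution quoted above
g_health = 0.2164        # HYPOTHETICAL: borrowed from the PyTorch toy run

for label, g_rul in (("backbone only", g_rul_backbone),
                     ("+ rul_head.weight", combine_norms(g_rul_backbone, rul_head_weight))):
    lam_rul = g_health / (g_rul + g_health)  # K = 2 closed form
    print(f"{label:>18}: g_rul = {g_rul:7.2f}, lambda_rul = {lam_rul:.6f}")
```

With these inputs λ_rul falls from roughly 0.00196 to roughly 0.00119, purely because a head gradient was counted as backbone imbalance.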
Python: Aggregating Per-Parameter Norms From Scratch
Build both versions of the L2 norm in pure NumPy: the paper's sum-of-squares loop and the reference concat-then-norm. Run them on a 4-tensor synthetic backbone and verify they produce identical answers to 4-decimal precision.
Two equivalent ways to aggregate the L2 norm
🐍grad_norm_sum_of_squares.py
Line 1: docstring
Module docstring stating the central identity of this section: aggregating per-parameter squared norms and taking sqrt gives the same answer as flattening every gradient into one vector and taking its L2 norm — but the per-parameter version saves memory.
Line 3: import numpy as np
NumPy supplies the ndarray, np.random for synthetic gradients, np.concatenate, and np.linalg.norm.
EXECUTION STATE
📚 numpy = Numerical computing library. We use ndarray, np.random.randn, np.concatenate, np.linalg.norm.
Line 5: np.random.seed(1)
Fix the PRNG so the synthetic gradients are reproducible. Without this, every run produces different per-parameter norms.
EXECUTION STATE
📚 np.random.seed(s) = Sets the global NumPy PRNG. Affects np.random.randn going forward.
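Line 8: # A miniature shared backbone
Synthetic gradient list. Mimics a 4-tensor shared backbone with two weight matrices (W1, W2) and two bias vectors (b1, b2), the same layout as the PyTorch toy later in this section (the paper's real backbone is far larger).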
Line 9: shared_grads_rul = [...]
List of per-parameter gradient tensors for the RUL task. SAME shape as the parameters they were computed from.
EXECUTION STATE
shared_grads_rul = Python list of 4 ndarrays. Total elements: 12 + 3 + 6 + 2 = 23 scalar gradients.
Line 10: np.random.randn(3, 4) * 5.0
Standard-normal (3, 4) ndarray scaled by 5.0. Mimics the magnitude of a real RUL gradient on shared backbone weights.
EXECUTION STATE
📚 np.random.randn(*shape) = Sample shape from N(0, 1). Returns a fresh ndarray of that shape.
Line 15: shared_grads_health = [...]
Same 4 shapes as shared_grads_rul, but scaled by 0.01 instead of 5.0. Mimics the 500x-smaller cross-entropy gradient norm.
EXECUTION STATE
shared_grads_health = Python list of 4 ndarrays. Same shapes as shared_grads_rul.
→ ratio of scales = 5.0 / 0.01 = 500x. Matches the paper's measured imbalance.
Line 16: np.random.randn(3, 4) * 0.01
Same shape as RUL's W1 gradient, but scaled small.
Line 17: np.random.randn(3) * 0.01
Health bias gradient #1.
Line 18: np.random.randn(2, 3) * 0.01
Health weight gradient #2.
Line 19: np.random.randn(2) * 0.01
Health bias gradient #2.
Line 23: def grad_norm_sum_of_squares(grads) → float
Aggregates per-parameter L2 norms via the squared-sum identity: ||cat(g1,g2,...)||_2^2 = sum_p ||g_p||_2^2. Avoids materialising a single concatenated tensor.
EXECUTION STATE
⬇ input: grads = List of ndarrays of any shape. Per-parameter gradient tensors.
⬆ returns = Float — the L2 norm of the concatenated gradient vector.
Line 24: docstring
Records that this is the paper's implementation pattern (matches grace/core/gradient_utils.py:67).
Line 25: sq = 0.0
Accumulator for the running sum of squared gradients. Starts at 0.0, will hold sum_p ||g_p||^2 at the end of the loop.
EXECUTION STATE
sq = Float scalar. Initialised to 0.0.
Line 26: for g in grads:
Iterate the parameter list. Each g is one per-parameter gradient ndarray. The function is called TWICE in this script (line 46 with shared_grads_rul, line 48 with shared_grads_health), so the loop runs 4+4 = 8 times total. The health sweep is shown below; the RUL sweep follows the same pattern and ends at g_rul_sos.
context = Triggered by line 48. The accumulator sq is reinitialised to 0.0 (fresh function frame). Loop body runs 4 more times over the 4 HEALTH gradient tensors (each element ~ N(0, 1) × 0.01 — 500× smaller).
iter 0 (health): W1 (3, 4)
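g = ndarray (3, 4), the health gradient of W1. These are fresh standard-normal draws (the seeded PRNG has already advanced past the 23 RUL draws), scaled by 0.01 instead of 5.0. The squared norm is therefore (0.01/5.0)² = 4×10⁻⁶ times the RUL one only in expectation, not draw-for-draw.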
(g**2).sum() = = 0.0005048291
‖g_p‖₂ = sqrt(0.0005048291) = 0.022468
sq after = 0.0 + 0.0005048291 = 0.0005048291
iter 1 (health): b1 (3,)
g = ndarray (3,) — health gradient of b1.
(g**2).sum() = = 0.0001303494
‖g_p‖₂ = sqrt(0.0001303494) = 0.011417
sq after = 0.0005048291 + 0.0001303494 = 0.0006351785
iter 2 (health): W2 (2, 3)
g = ndarray (2, 3) — health gradient of W2.
(g**2).sum() = = 0.0007552907
‖g_p‖₂ = sqrt(0.0007552907) = 0.027483
sq after = 0.0006351785 + 0.0007552907 = 0.0013904692
iter 3 (health): b2 (2,)
g = ndarray (2,) — health gradient of b2.
(g**2).sum() = = 0.0000408345
‖g_p‖₂ = sqrt(0.0000408345) = 0.006390
sq after = 0.0013904692 + 0.0000408345 = 0.0014313037
→ next: line 28 = np.sqrt(0.0014313037) = 0.037833 returned as g_health_sos.
── after both calls ──
g_rul_sos = 26.401602
g_health_sos = 0.037833
ratio = 26.401602 / 0.037833 ≈ 698× per-norm imbalance for this seed.
Line 27: sq += (g ** 2).sum()
Accumulate the sum of squared elements of this parameter's gradient.
EXECUTION STATE
📚 (g ** 2) = ndarray element-wise square. Returns a new ndarray of the same shape with each element squared.
📚 .sum() = ndarray reduction. Sums every element to a scalar.
→ why squared? = L2 norm is sqrt(sum of squares). We accumulate squares now, sqrt at the end.
Line 28: return np.sqrt(sq)
Final square root recovers the L2 norm: sqrt(sum_p ||g_p||^2) = ||cat(g_p)||_2.
EXECUTION STATE
📚 np.sqrt(x) = Element-wise square root. On a Python float returns a numpy scalar.
⬆ return = Float ≈ 26.4016 for shared_grads_rul. Single L2 norm of the concatenated 23-element gradient.
Line 31: def grad_norm_via_concat(grads) → float
Reference implementation: explicitly concatenate all gradients into a single 1-D vector and take its L2 norm. Numerically identical to the sum-of-squares version.
EXECUTION STATE
⬇ input: grads = Same list as before.
⬆ returns = Float — same answer as grad_norm_sum_of_squares but materialises a temporary concat array.
Line 32: docstring
Records that this is the ‘naive’ reference: same answer, more memory.
Line 33: flat = np.concatenate([g.reshape(-1) for g in grads])
Build a single 1-D vector containing every gradient element. .reshape(-1) flattens any-shape ndarray into 1-D; np.concatenate stitches them end-to-end.
EXECUTION STATE
📚 .reshape(-1) = ndarray method: -1 in reshape means ‘infer this dimension’. With a single -1 ⇒ flatten to 1-D.
📚 np.concatenate(seq) = Stack a sequence of arrays end-to-end along axis 0 (default). For 1-D input: returns a 1-D vector with total length = sum of lengths.
flat = 1-D ndarray of length 12 + 3 + 6 + 2 = 23.
→ memory cost = Allocates a fresh 23-element buffer. For real backbones with millions of params this is millions of floats temporarily duplicated.
Line 34: return np.linalg.norm(flat, ord=2)
L2 norm of the concatenated vector. ord=2 is the default but we make it explicit.
EXECUTION STATE
📚 np.linalg.norm(x, ord) = Compute matrix or vector norm. For a 1-D vector with ord=2: sqrt(sum of squares).
⬇ ord = 2 = Euclidean norm. Other options: ord=1 (L1), ord=np.inf (Linf / max-abs).
⬆ return = Float ≈ 26.4016, equal to grad_norm_sum_of_squares up to floating-point rounding.
Line 38: names = ['W1', 'b1', 'W2', 'b2']
Display labels for the per-parameter table.
EXECUTION STATE
names = Python list of 4 strings. Index-aligned with shared_grads_rul / _health.
→ used downstream by = trainer step: combined = lambda_rul * rul_loss + lambda_health * health_loss; combined.backward(); optimizer.step(). lambda values become scalar coefficients; their gradients do NOT flow (treated as constants).
📐 Toy example used throughout this trace
One small numerical setup that every iteration card below refers to. Read this card first; it makes the rest of the walkthrough hand-traceable.
EXECUTION STATE
shared_grads_rul (4 tensors, total 23 scalars) =
W1 grad: shape (3, 4) — each element ~ N(0, 1) × 5.0
b1 grad: shape (3,) — each element ~ N(0, 1) × 5.0
W2 grad: shape (2, 3) — each element ~ N(0, 1) × 5.0
b2 grad: shape (2,) — each element ~ N(0, 1) × 5.0
shared_grads_health (same 4 shapes) = Each element ~ N(0, 1) × 0.01.
This is 500× smaller per-element than the RUL gradient — mimics the paper's measured imbalance.
📊 Variable trace: sq accumulator across both grad_norm_sum_of_squares calls
Step-by-step evolution of the squared-norm accumulator as the for-loop on line 26 walks the four parameter tensors. The function is called twice (RUL on line 46, health on line 48); the accumulator is fresh in each call. The final sqrt on line 28 produces the value fed to the GABA closed form.
⚠️ Edge cases for the per-parameter L2 aggregation
Failure modes you must handle when wiring this aggregator into a real trainer. The reference implementation (numpy + paper PyTorch) is robust to most of these — but only because it makes deliberate choices.
EXECUTION STATE
Empty grads list = sum-of-squares: returns sqrt(0.0) = 0.0. No error, but if every task's norm comes back 0, the closed form divides by S = sum(g_i) = 0. Guard the trainer: if S == 0, fall back to uniform λ_i = 1/K.
Single-element grads list = Works correctly — accumulator visits one tensor, sqrt returns that tensor's own L2 norm. No special case.
NaN in any gradient = (g**2).sum() propagates NaN; final sqrt is NaN. Guard with np.isfinite(g).all() before adding. Common with mixed-precision training.
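+inf in any gradient = sq becomes +inf and sqrt(inf) = inf; large finite values can also overflow to +inf when squared, especially in float16. The λ rule then breaks: S = inf, so the affected task gets λ = g_j/inf = 0 while the other task gets inf/inf = NaN, and the optimizer step blows up. Check np.isfinite on the norms before computing λ and skip the step if the check fails.
Mixed dtypes (float16 + float32) = NumPy upcasts silently; PyTorch type-promotes mixed dtypes in arithmetic but raises if devices mismatch. Best practice: cast every g to float32 before squaring, because in float16 any element above about 256 squares past the float16 maximum of 65,504.
Gradient is None (PyTorch w/ allow_unused=True) = (None ** 2) raises TypeError in Python. The torch reference handles this on line 48: `if g is not None`. Mirror that guard if you ever swap NumPy for a real autograd output.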
GPU tensors (PyTorch) = Accumulator must be on the same device as the gradients. The torch reference uses `torch.tensor(0.0, device=loss.device)` on line 46 — never `0.0` (float, CPU).
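The guards above can be combined into one aggregator. A minimal pure-Python sketch, assuming each per-parameter gradient arrives as a flat list of floats, or None as torch.autograd.grad returns with allow_unused=True. The names guarded_grad_norm and safe_gaba_weights and the eps threshold are hypothetical; raising on non-finite values is one possible policy, skipping the step is another.

```python
import math


def guarded_grad_norm(grads):
    """Sum-of-squares L2 norm over per-parameter gradients, with guards.

    Each entry of `grads` is a flat sequence of floats (one parameter tensor)
    or None, mimicking what torch.autograd.grad(..., allow_unused=True) returns.
    """
    sq = 0.0
    for g in grads:
        if g is None:  # parameter is not in this loss's graph: contributes 0
            continue
        # Cast to float before squaring; x * x yields inf (not an exception) on overflow.
        part = sum(float(x) * float(x) for x in g)
        if not math.isfinite(part):  # NaN or inf somewhere in this tensor
            raise FloatingPointError("non-finite gradient contribution")
        sq += part
    return math.sqrt(sq)


def safe_gaba_weights(norms, eps=1e-12):
    """Closed-form weights with a uniform 1/K fallback when S is (almost) zero."""
    k = len(norms)
    if k < 2:
        raise ValueError("the closed form needs at least two tasks")
    s = sum(norms)
    if s < eps:
        return [1.0 / k] * k
    return [(s - g) / ((k - 1) * s) for g in norms]


g_a = guarded_grad_norm([[3.0, 4.0], None, [0.0]])  # None entry is skipped
g_b = guarded_grad_norm([None, None])               # every entry unused
print(g_a, g_b, safe_gaba_weights([g_b, g_b]))      # 5.0 0.0 [0.5, 0.5]
```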
🐛 Debug version, instrumented with prints
Drop-in replacement that prints every accumulation step and the final norm. Use this once when wiring GABA into a new codebase to verify the per-parameter contributions match what you expect.
EXECUTION STATE
Instrumented function =
def grad_norm_sum_of_squares_debug(grads, label='grads'):
    print(f'\n=== {label} ===')
    sq = 0.0
    for i, g in enumerate(grads):
        contribution = (g ** 2).sum()
        sq += contribution
        print(f' param {i}: shape={str(g.shape):<8} '
              f'||g_p||={np.sqrt(contribution):>10.6f} '
              f'sq_running={sq:>12.6f}')
    norm = np.sqrt(sq)
    print(f' FINAL ||g||_2 = {norm:.6f}')
    return norm
# Run on the same toy example
g_rul = grad_norm_sum_of_squares_debug(shared_grads_rul, 'RUL')
g_health = grad_norm_sum_of_squares_debug(shared_grads_health, 'health')
Copy this whole block into a fresh .py file and run with `python file.py`. Verifies sum-of-squares equals concat-then-norm to within 1e-12.
EXECUTION STATE
Standalone script =
import numpy as np
np.random.seed(1)
grads = [np.random.randn(3, 4) * 5.0,
         np.random.randn(3) * 5.0,
         np.random.randn(2, 3) * 5.0,
         np.random.randn(2) * 5.0]
# Method A: paper's sum of squares
sq = sum((g ** 2).sum() for g in grads)
norm_a = np.sqrt(sq)
# Method B: concat then L2 norm
flat = np.concatenate([g.reshape(-1) for g in grads])
norm_b = np.linalg.norm(flat, ord=2)
print(f'Method A (sum of squares): {norm_a:.6f}')
print(f'Method B (concat-then-norm): {norm_b:.6f}')
# rtol=0.0 makes this a true |a - b| <= 1e-12 check (np.isclose defaults to rtol=1e-5)
print(f'identical to 1e-12: {np.isclose(norm_a, norm_b, rtol=0.0, atol=1e-12)}')
Expected stdout =
Method A (sum of squares): 26.401602
Method B (concat-then-norm): 26.401602
identical to 1e-12: True
✅ In one sentence
The whole script distilled.
EXECUTION STATE
This script shows = Adding up per-parameter squared L2 norms and taking the square root produces the SAME number as flattening every gradient into one big vector and taking its L2 norm (equal up to floating-point rounding), without ever allocating the flattened copy.
Why it matters = GABA's stage 1 sensor (||g_i||) is computed every training step. On the paper's 3.5 M-param backbone, the concat-then-norm path would allocate ~14 MB of scratch per task per step. The sum-of-squares identity gives the same answer for free.
Now the production version — line-for-line from paper_ieee_tii/grace/core/gradient_utils.py. A toy 14→32→16 shared backbone with two heads; one forward pass; two gradient norms via torch.autograd.grad with create_graph=False and retain_graph=True; finally the K=2 closed form.
Paper code: get_shared_params + compute_task_grad_norm
🐍compute_task_grad_norm.py
Line 1: docstring
Module docstring. The two functions below are line-for-line copies of grace/core/gradient_utils.py — the paper's production utilities.
Line 3: import torch
Core PyTorch.
EXECUTION STATE
📚 torch = Tensor library with autograd. Used for tensors, torch.autograd.grad, torch.tensor, torch.randn.
Line 4: import torch.nn as nn
Layer primitives.
EXECUTION STATE
📚 torch.nn = Neural-network module. nn.Module base class, nn.Linear layers.
Line 6: torch.manual_seed(0)
Fix the PRNG so the smoke-test output is reproducible.
Line 9: class TinyBackbone(nn.Module):
A two-layer shared backbone: 14 → 32 → 16. Mimics the paper's CNN-BiLSTM-Attention stack but small enough to keep the demo readable.
EXECUTION STATE
📚 nn.Module = Base class for stateful PyTorch components. Subclasses override forward() and register parameters via attribute assignment.
Line 10: def __init__(self):
Backbone constructor.
Line 11: super().__init__()
Initialise the nn.Module base class. Required.
Line 12: self.fc1 = nn.Linear(14, 32)
First fully-connected layer. Total params = 14 * 32 + 32 = 480.
EXECUTION STATE
📚 nn.Linear(in, out) = Stores W (out × in) and b (out). Forward: y = x @ W.T + b.
→ 14 inputs = Matches the paper's 14-sensor C-MAPSS input.
→ 32 hidden = Hidden width.
Line 13: self.fc2 = nn.Linear(32, 16)
Second layer. 32 * 16 + 16 = 528 params.
Line 15: def forward(self, x):
Backbone forward pass.
EXECUTION STATE
⬇ input: x = Tensor (batch, 14).
⬆ returns = Tensor (batch, 16) — shared features for both heads.
Line 16: return self.fc2(torch.relu(self.fc1(x)))
Linear → ReLU → Linear. Composes the two layers with a non-linearity in between.
EXECUTION STATE
📚 torch.relu(x) = Element-wise max(0, x). Adds a non-linearity so the two layers can't be collapsed into one.
Line 19: class DualHead(nn.Module):
The full multi-task model: shared backbone + two task-specific heads. Mimics the paper's DualTaskModel architecture.
Line 20: def __init__(self):
DualHead constructor.
Line 21: super().__init__()
nn.Module init.
Line 22: self.backbone = TinyBackbone()
Embed the shared backbone as a submodule. Both heads will read from its output.
EXECUTION STATE
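→ param naming = All backbone params will appear under named_parameters() with the prefix ‘backbone.’ (e.g. ‘backbone.fc1.weight’). get_shared_params() keeps them because those names contain neither head substring; it does not match on the ‘backbone.’ prefix itself.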
Line 23: self.rul_head = nn.Linear(16, 1)
Regression head: 16 → 1 scalar (predicted RUL).
EXECUTION STATE
→ name 'rul_head' = This name is checked by get_shared_params() to EXCLUDE these params from the gradient-norm computation.
→ name 'health_head' = Also excluded by get_shared_params().
Line 26: def forward(self, x):
Full forward returns a tuple of (rul_pred, hp_logits).
EXECUTION STATE
⬆ returns = Tuple of two tensors: (B, 1) RUL, (B, 3) health logits.
Line 27: feat = self.backbone(x)
Run the shared backbone exactly once. Both heads will read from `feat`.
EXECUTION STATE
feat = Tensor (B, 16). Critical: BOTH losses depend on the SAME backbone parameters via this tensor — that is what makes them ‘shared-parameter’.
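→ why one feat? = Running the backbone once means both losses hang off the same forward graph, so both autograd.grad calls can reuse it. Re-running the backbone per head would still produce gradients for the same shared parameters, but it doubles the forward cost and, with dropout active, the two losses would no longer see the same forward pass.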
Line 31: def get_shared_params(model, head_names=("rul_head", "health_head"))
Return only the trainable parameters that belong to the shared backbone: those whose registered name does NOT contain ‘rul_head’ or ‘health_head’. Line-for-line copy from grace/core/gradient_utils.py:13.
EXECUTION STATE
⬇ input: model = An nn.Module subclass instance. Walks .named_parameters() to find candidates.
⬇ input: head_names = Tuple of strings to EXCLUDE. Default matches the paper's DualTaskModel naming.
⬆ returns = List of nn.Parameter — the backbone-only subset.
Line 32: docstring
Records the function's purpose: filter named_parameters() down to the backbone-only subset.
Line 33: out = []
Empty accumulator for the filtered parameter list.
Line 34: for name, p in model.named_parameters():
Iterate every (name, parameter) pair in the model. Names look like ‘backbone.fc1.weight’, ‘rul_head.bias’, etc.
EXECUTION STATE
📚 .named_parameters() = nn.Module method. Yields (str, nn.Parameter) pairs for every parameter (recursive across submodules).
Line 35: if not p.requires_grad:
Skip frozen parameters. torch.autograd.grad raises a RuntimeError if any of its inputs does not require grad, so frozen parameters must be filtered out before the call.
EXECUTION STATE
📚 p.requires_grad = Tensor attribute. True if the parameter is being trained (autograd records ops on it), False if frozen (e.g. transfer learning, frozen embeddings).
🔀 branch — toy example = All 8 params in DualHead have requires_grad=True (default for nn.Linear). So `not p.requires_grad` is False on every iteration → this branch never taken in the toy run.
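→ if True instead = `continue` runs and the parameter is skipped, so it is omitted from the returned `out` list. That is what we want for a partially frozen backbone (e.g. early layers frozen during fine-tuning). If the WHOLE backbone is frozen, `out` ends up empty and the gradient-norm call has nothing to differentiate.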
Line 36: continue
Skip this iteration; move to the next parameter.
Line 37: if any(h in name for h in head_names):
Substring match on the head names. If ‘rul_head’ or ‘health_head’ appears anywhere in `name`, this parameter is a head and must be excluded from the shared-backbone gradient norm.
EXECUTION STATE
📚 any(iterable) = Python builtin. Returns True if any element of the iterable is truthy. Short-circuits on the first True.
h in name = Python substring test. Matches ‘rul_head.weight’, ‘rul_head.bias’, ‘ema.module.rul_head.weight’, etc.
→ why substring? = Robust to parent prefixes. If the model is wrapped (e.g. EMA, DataParallel), names look like ‘ema.module.rul_head.weight’ and a substring match still catches it. Equality match would silently miss them.
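grads = torch.autograd.grad(loss, shared_params, retain_graph=retain_graph, create_graph=False, allow_unused=True)
Functional autograd. Returns gradient tensors WITHOUT writing to p.grad. This flag combination is the one the paper's utility uses.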
EXECUTION STATE
📚 torch.autograd.grad(outputs, inputs, ...) = Functional differentiation. Returns ∂outputs/∂inputs as a tuple. Unlike loss.backward(), it does not accumulate into .grad.
⬇ retain_graph=retain_graph = Default True. Keep the autograd graph after this call so the second task can compute its gradient on the SAME forward pass without recomputing.
⬇ create_graph=False = Do NOT track gradient-of-gradient. ~1x memory. Saves ~50% vs GradNorm's create_graph=True.
⬇ allow_unused=True = Tolerate parameters that don't appear in the autograd graph of `loss` (returns None for those entries instead of raising).
→ why allow_unused = Robust to architecture variants where some shared params are detached for one head only. Paper's DualTaskModel uses this defensively.
Line 46: total = torch.tensor(0.0, device=loss.device)
Accumulator. Built on the same device as the loss (CPU / GPU / MPS).
EXECUTION STATE
📚 torch.tensor(value, device) = Build a 0-dim tensor at the given value.
⬇ device=loss.device = Match the loss's device so the additions below stay on the same hardware.
Line 47: for g in grads:
Iterate per-parameter gradient tensors. Each g matches the shape of the corresponding parameter. compute_task_grad_norm is called TWICE in this script (line 65 with rul_loss, line 66 with health_loss), so this loop runs 4+4 = 8 times total. All values measured under torch.manual_seed(0).
context = Triggered by line 66. SAME forward pass, SAME shared params, but ∂health_loss instead of ∂rul_loss. The cross-entropy gradient on the shared backbone is ~500× smaller per-element. total is reinitialised to 0.0 (fresh function frame).
iter 0 (health): g = ∂health_loss/∂fc1.weight (32, 14)
g.shape = (32, 14)
g.pow(2).sum() = = 0.008390
‖g_p‖₂ = sqrt(0.008390) = 0.09160
total after = 0.0 + 0.008390 = 0.008390
iter 1 (health): g = ∂health_loss/∂fc1.bias (32,)
g.shape = (32,)
g.pow(2).sum() = = 0.001306
‖g_p‖₂ = sqrt(0.001306) = 0.03614
total after = 0.008390 + 0.001306 = 0.009696
iter 2 (health): g = ∂health_loss/∂fc2.weight (16, 32)
g.shape = (16, 32)
g.pow(2).sum() = = 0.027842
‖g_p‖₂ = sqrt(0.027842) = 0.16686
total after = 0.009696 + 0.027842 = 0.037538
iter 3 (health): g = ∂health_loss/∂fc2.bias (16,)
g.shape = (16,)
g.pow(2).sum() = = 0.009293
‖g_p‖₂ = sqrt(0.009293) = 0.09640
total after = 0.037538 + 0.009293 = 0.046830
→ next: line 50 = total.sqrt() = sqrt(0.046830) = 0.2164 returned as g_health.
── after both calls ──
g_rul = tensor(114.4296)
g_health = tensor(0.2164)
ratio = 114.4296 / 0.2164 ≈ 528.8× — reproduces the paper's measured 500× imbalance.
Line 48: if g is not None:
Skip None entries left by allow_unused=True. Without this guard, `g.pow(2)` would raise AttributeError on a None value.
EXECUTION STATE
🔀 branch — toy example = All 4 shared backbone params appear in BOTH rul_loss and health_loss autograd graphs (they flow through `feat`). So g is never None on this run → branch always True.
→ if False instead = g is None: the accumulation is skipped and that param contributes 0 to the squared sum. This is the correct behaviour for a parameter that doesn't appear in the loss's autograd graph.
→ when can g be None? = (a) Parameter is detached from the loss (e.g. one head doesn't read from it). (b) Architecture variants where some shared params route only through one head. Frozen parameters (requires_grad=False) are a different case: passing one to torch.autograd.grad raises an error instead of returning None, which is why get_shared_params filters them out first.
Line 49: total = total + g.pow(2).sum()
Accumulate the squared-norm contribution. Out-of-place add to keep autograd happy.
EXECUTION STATE
📚 .pow(n) = Tensor element-wise power. .pow(2) is element-wise square.
📚 .sum() = Tensor reduction. Sum every element.
→ why out-of-place? = In-place add (total += ...) sometimes breaks autograd's view tracking. Out-of-place creates a fresh tensor each step — safer and the cost is negligible at this scale.
Line 50: return total.sqrt()
Final L2 norm.
EXECUTION STATE
📚 .sqrt() = Tensor element-wise square root.
⬆ return = 0-dim tensor ≈ 114.4296 for the RUL task on this seed.
Line 54: model = DualHead()
Instantiate the multi-task model.
EXECUTION STATE
model = DualHead with 8 trainable parameters (4 backbone + 2 rul_head + 2 health_head).
Line 55: shared = get_shared_params(model)
Extract the backbone-only subset for gradient-norm computation.
EXECUTION STATE
shared = List of 4 nn.Parameter: backbone.fc1.weight (32×14), backbone.fc1.bias (32,), backbone.fc2.weight (16×32), backbone.fc2.bias (16,).
Line 57: x = torch.randn(64, 14)
Random batch.
EXECUTION STATE
📚 torch.randn(*size) = Sample from N(0, 1).
Line 58: rul_target = torch.rand(64, 1) * 125.0
Random RUL targets in [0, 125): torch.rand samples [0, 1), and 125 is the paper's RUL cap.
Line 65: g_rul = compute_task_grad_norm(rul_loss, shared, retain_graph=True)
First gradient norm. retain_graph=True is critical: without it the next call would crash.
EXECUTION STATE
g_rul = 0-dim tensor ≈ 114.4296. Real measurement on this seed.
→ retain_graph=True = Required so the SAME forward pass survives for the second autograd.grad call below. Without it, PyTorch frees the graph after the first backward and the second call raises ‘Trying to backward through the graph a second time’.
Line 66: g_health = compute_task_grad_norm(health_loss, shared)
Second gradient norm on the SAME forward. This is where retain_graph from line 65 pays off.
EXECUTION STATE
g_health = 0-dim tensor ≈ 0.2164. Same forward; different loss; different gradients.
Line 68: print ||g_rul||
Pretty-print.
EXECUTION STATE
Output = ||g_rul|| = 114.4296
Line 69: print ||g_health||
Pretty-print.
EXECUTION STATE
Output = ||g_health|| = 0.2164
Line 70: print ratio
The empirical gradient-magnitude ratio for this seed.
EXECUTION STATE
Output = ratio = 528.8x
→ reading = Reproduces the paper's ~500x imbalance figure on a tiny untrained backbone. That suggests the imbalance is driven by the loss scales (MSE on a 0–125 target vs. cross-entropy near ln 3 at initialisation) rather than by the specific architecture.
Line 72: S = g_rul + g_health
K=2 normaliser.
EXECUTION STATE
S = 0-dim tensor ≈ 114.6460.
Line 73: print lambda_rul
Apply the §17.3 closed form.
EXECUTION STATE
Output = (blank line)
lambda_rul = 0.001888
Line 74: print lambda_health
Final result. The trainer would now form combined_loss = lambda_rul·rul_loss + lambda_health·health_loss and call combined_loss.backward() — that backward call closes the autograd graph the two grad() calls have been retaining since line 65.
→ used downstream by = combined_loss = lambda_rul.detach() * rul_loss + lambda_health.detach() * health_loss
combined_loss.backward() # frees the retained autograd graph
optimizer.step() # actually updates the weights
optimizer.zero_grad() # clear .grad for next step
📐 Toy example used throughout this trace
One small concrete setup that every iteration card refers to. Read this card first; it makes the rest of the walkthrough hand-traceable. Tiny enough that you could re-run it in your head.
EXECUTION STATE
Model: DualHead = TinyBackbone (14 → 32 → 16) + rul_head (16 → 1) + health_head (16 → 3).
8 trainable parameters, 4 of which are 'shared' (backbone).
Expected losses = rul_loss ≈ 5300: random-init MSE on a 0–125 target
health_loss ≈ ln 3 ≈ 1.10 — uniform softmax over 3 classes
Expected ||g_rul|| = ≈ 114.4296 (sqrt of sum of squared per-param grads on the shared backbone)
Expected ||g_health|| = ≈ 0.2164
Expected ratio = 528.8x — close to the paper's measured 500x median imbalance.
Expected lambdas = λ_rul = 0.001888
λ_health = 0.998112 (almost all weight goes to the underdog task)
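The parameter counts on this card are easy to confirm. The sketch below rebuilds only the backbone from the listing and counts tensors and scalar entries; fc1 alone accounts for 480 of the shared scalars.

```python
import torch.nn as nn


class TinyBackbone(nn.Module):
    # Same layer shapes as the paper-code listing (forward omitted: only counting).
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(14, 32)
        self.fc2 = nn.Linear(32, 16)


backbone = TinyBackbone()
shared_tensors = list(backbone.parameters())
n_tensors = len(shared_tensors)                      # fc1.weight, fc1.bias, fc2.weight, fc2.bias
n_scalars = sum(p.numel() for p in shared_tensors)   # (14*32 + 32) + (32*16 + 16)
per_layer = [sum(p.numel() for p in m.parameters()) for m in (backbone.fc1, backbone.fc2)]
print(n_tensors, n_scalars, per_layer)
```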
📊 Variable trace: `total` accumulator across both compute_task_grad_norm calls
Step-by-step evolution of the squared-norm accumulator inside the for-loop on line 47, for BOTH calls in this script (RUL on line 65, health on line 66). All values measured under torch.manual_seed(0) on the toy DualHead model.
Failure modes you must handle when wiring this into a real GABA trainer. Most of these are silent — they produce a number that looks plausible but is wrong.
loss is not a 0-dim tensor = torch.autograd.grad raises RuntimeError if loss has more than one element. Fix: pass `.mean()` or `.sum()` explicitly. Common when the user forgets the reduction on a per-sample loss.
shared_params is empty = torch.autograd.grad raises a RuntimeError because it needs at least one input tensor. This is a caller bug: get_shared_params filtered everything out (e.g. wrong head_names). Always assert len(shared) > 0.
Some grad is None (allow_unused=True) = Handled correctly by line 48's `if g is not None` guard. Without that guard, `g.pow(2)` raises AttributeError. Common when one head's loss doesn't depend on a particular shared param.
All grads are None = total stays at 0.0; total.sqrt() = 0.0. Returns tensor(0.0). Downstream λ rule divides by S = 0 → NaN. Trainer should fall back to uniform λ_i = 1/K when S < epsilon.
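retain_graph=False on the FIRST call = The forward graph's saved tensors are freed when this call returns, so the SECOND call (line 66) raises 'RuntimeError: Trying to backward through the graph a second time'. In GABA, keep retain_graph=True on all K per-task calls, because combined_loss.backward() still needs the graph afterwards. This is the most common GABA-integration bug.
create_graph=True instead of False = Autograd also records the operations of the gradient computation itself, which can roughly double memory. GABA does NOT need this: the closed form is not differentiated through. (GradNorm DOES need it; that's the §17.4 difference.)
loss is on GPU but accumulator on CPU = tensor(0.0) without device= lives on CPU. Depending on the operation, mixing it with CUDA gradients either raises 'Expected all tensors to be on the same device' or silently returns a CPU tensor (for example when every grad is None). Setting `device=loss.device` on line 46 keeps the accumulator on the loss's device and avoids both cases.
Mixed precision (fp16 gradients) = fp16 tops out at about 65,504, so g.pow(2) overflows to inf as soon as any gradient entry exceeds about 256 in magnitude, and g.pow(2).sum() overflows once one parameter's squared norm passes 65,504. In this toy run the RUL squared norm is already about 13,094 (114.43²), within a factor of five of that limit. Cast each gradient to fp32 (g.float()) before squaring.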
Gradient contains NaN / inf = Propagates through total.sqrt(). Trainer should check torch.isfinite(g_rul) and torch.isfinite(g_health) before applying the closed form, and skip the step on NaN.
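Several of these failure modes can be guarded in code. The sketch below is a hypothetical hardened variant (compute_task_grad_norm_checked and safe_lambdas are names invented here, not the paper's utility): it rejects non-scalar losses and empty parameter lists, accumulates in fp32 on the loss's device, and returns None from the K=2 weight rule when a norm is not finite so the caller can skip the step.

```python
import torch


def compute_task_grad_norm_checked(loss, shared_params, retain_graph=True):
    """Hypothetical hardened variant of compute_task_grad_norm (not the paper's code)."""
    if loss.numel() != 1:
        raise ValueError("loss must be a scalar; reduce it with .mean() or .sum() first")
    if len(shared_params) == 0:
        raise ValueError("shared_params is empty; check the head-name filter")
    grads = torch.autograd.grad(loss, shared_params, retain_graph=retain_graph,
                                create_graph=False, allow_unused=True)
    # fp32 accumulator on the loss's device; each grad is upcast before squaring
    total = torch.zeros((), dtype=torch.float32, device=loss.device)
    for g in grads:
        if g is not None:
            total = total + g.float().pow(2).sum()
    return total.sqrt()


def safe_lambdas(g_a, g_b, eps=1e-12):
    """Hypothetical K=2 weight rule with a finiteness check and a uniform fallback."""
    if not (torch.isfinite(g_a) and torch.isfinite(g_b)):
        return None  # caller should skip this training step
    S = g_a + g_b
    if S < eps:
        return torch.tensor(0.5), torch.tensor(0.5)
    return g_b / S, g_a / S


# fp16 parameter whose gradient entries (1000) would overflow if squared in fp16
p16 = torch.full((4,), 1000.0, dtype=torch.float16, requires_grad=True)
unused = torch.zeros(3, requires_grad=True)       # not in the loss, so its grad is None
loss16 = 0.5 * (p16.float() ** 2).sum()           # d(loss)/dp = p = 1000 per entry
g16 = compute_task_grad_norm_checked(loss16, [p16, unused])
print(g16)  # sqrt(4 * 1000**2) = 2000, computed in fp32
```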
🐛 Debug version: instrumented compute_task_grad_norm
Drop-in replacement that prints every per-parameter contribution. Use it once when wiring GABA into a new codebase to verify that your shared-param filter and gradient norms match expectations.
Instrumented function:
def compute_task_grad_norm_debug(loss, shared_params, label='task',
                                 retain_graph=True):
    print(f'\n=== {label} ===')
    grads = torch.autograd.grad(loss, shared_params,
                                retain_graph=retain_graph,
                                create_graph=False,
                                allow_unused=True)
    total = torch.tensor(0.0, device=loss.device)
    for i, (p, g) in enumerate(zip(shared_params, grads)):
        if g is None:
            print(f' [{i}] shape={tuple(p.shape)} grad=None (skipped)')
            continue
        contribution = g.pow(2).sum()
        total = total + contribution
        print(f' [{i}] shape={str(tuple(p.shape)):<10} '
              f'||g_p||={contribution.sqrt().item():>10.4f} '
              f'total_running={total.item():>12.4f}')
    norm = total.sqrt()
    print(f' FINAL ||g||_2 = {norm.item():.4f}')
    return norm

# Replace the two calls on lines 65-66 with:
g_rul = compute_task_grad_norm_debug(rul_loss, shared, 'RUL', retain_graph=True)
g_health = compute_task_grad_norm_debug(health_loss, shared, 'health', retain_graph=True)
Smallest possible self-contained script that produces the paper's gradient-norm imbalance. Copy it into a fresh .py file and run it with `python file.py`.
Standalone script:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class DualHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(14, 32), nn.ReLU(),
                                      nn.Linear(32, 16))
        self.rul_head = nn.Linear(16, 1)
        self.health_head = nn.Linear(16, 3)

    def forward(self, x):
        f = self.backbone(x)
        return self.rul_head(f), self.health_head(f)


def shared_params(model):
    return [p for n, p in model.named_parameters()
            if 'rul_head' not in n and 'health_head' not in n
            and p.requires_grad]


def grad_norm(loss, params):
    grads = torch.autograd.grad(loss, params,
                                retain_graph=True,
                                create_graph=False,
                                allow_unused=True)
    total = torch.tensor(0.0, device=loss.device)
    for g in grads:
        if g is not None:
            total = total + g.pow(2).sum()
    return total.sqrt()


model = DualHead()
shared = shared_params(model)

x = torch.randn(64, 14)
rul_target = torch.rand(64, 1) * 125.0
hp_target = torch.randint(0, 3, (64,))

rul_pred, hp_logits = model(x)
rul_loss = ((rul_pred - rul_target) ** 2).mean()
health_loss = F.cross_entropy(hp_logits, hp_target)

g_rul = grad_norm(rul_loss, shared)
g_health = grad_norm(health_loss, shared)
S = g_rul + g_health

print(f'||g_rul|| = {g_rul.item():.4f}')
print(f'||g_health|| = {g_health.item():.4f}')
print(f'ratio = {(g_rul / g_health).item():.1f}x')
print(f'lambda_rul = {(g_health / S).item():.6f}')
print(f'lambda_health= {(g_rul / S).item():.6f}')
This script shows = One forward pass plus two torch.autograd.grad calls (with retain_graph=True, create_graph=False) gives both per-task gradient norms on the shared backbone, and a 1,008-parameter toy backbone already reproduces the paper's ~500x imbalance.
Why it matters = GABA's stage 1 sensor must be cheap (it runs every step) and exact (the closed form consumes it directly). torch.autograd.grad with create_graph=False skips the extra graph that GradNorm's create_graph=True builds, roughly halving memory; retain_graph=True is what lets two tasks share one forward pass; allow_unused=True is what makes the function work on architectures where some shared params route through only one head.
 1  """Paper code: per-task gradient norm on shared backbone (compute_task_grad_norm)."""
 2
 3  import torch
 4  import torch.nn as nn
 5
 6  torch.manual_seed(0)
 7
 8
 9  class TinyBackbone(nn.Module):
10      def __init__(self):
11          super().__init__()
12          self.fc1 = nn.Linear(14, 32)
13          self.fc2 = nn.Linear(32, 16)
14
15      def forward(self, x):
16          return self.fc2(torch.relu(self.fc1(x)))
17
18
19  class DualHead(nn.Module):
20      def __init__(self):
21          super().__init__()
22          self.backbone = TinyBackbone()
23          self.rul_head = nn.Linear(16, 1)
24          self.health_head = nn.Linear(16, 3)
25
26      def forward(self, x):
27          feat = self.backbone(x)
28          return self.rul_head(feat), self.health_head(feat)
29
30
31  def get_shared_params(model, head_names=("rul_head", "health_head")):
32      """Return parameters that belong to the shared backbone."""
33      out = []
34      for name, p in model.named_parameters():
35          if not p.requires_grad:
36              continue
37          if any(h in name for h in head_names):
38              continue
39          out.append(p)
40      return out
41
42
43  def compute_task_grad_norm(loss, shared_params, retain_graph=True):
44      """L2 norm of grad(loss) on shared_params. create_graph=False for speed."""
45      grads = torch.autograd.grad(loss, shared_params, retain_graph=retain_graph, create_graph=False, allow_unused=True)
46      total = torch.tensor(0.0, device=loss.device)
47      for g in grads:
48          if g is not None:
49              total = total + g.pow(2).sum()
50      return total.sqrt()
51
52
53  # ---------- One forward pass, two gradient norms ----------
54  model = DualHead()
55  shared = get_shared_params(model)
56
57  x = torch.randn(64, 14)
58  rul_target = torch.rand(64, 1) * 125.0
59  hp_target = torch.randint(0, 3, (64,))
60
61  rul_pred, hp_logits = model(x)
62  rul_loss = ((rul_pred - rul_target) ** 2).mean()
63  health_loss = nn.functional.cross_entropy(hp_logits, hp_target)
64
65  g_rul = compute_task_grad_norm(rul_loss, shared, retain_graph=True)
66  g_health = compute_task_grad_norm(health_loss, shared, retain_graph=True)
67
68  print(f"||g_rul|| = {g_rul.item():.4f}")
69  print(f"||g_health|| = {g_health.item():.4f}")
70  print(f"ratio = {(g_rul / g_health).item():.1f}x")
71
72  S = g_rul + g_health
73  print(f"\nlambda_rul = {(g_health / S).item():.6f}")
74  print(f"lambda_health = {(g_rul / S).item():.6f}")
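The 528.8× ratio in the PyTorch output is not a coincidence. A 1,008-parameter toy backbone with random init reproduces the paper's ~500× imbalance because the asymmetry is structural. The MSE loss grows as $O(R_{\max}^2)$ with the regression target range (paper $R_{\max} = 125$ cycles), so its gradient with respect to the prediction, $2(\hat{y} - y)$, grows as $O(R_{\max})$. The cross-entropy gradient with respect to the logits, $\mathrm{softmax}(z) - \mathrm{onehot}(y)$, has L2 norm at most $\sqrt{2}$ per sample for any number of classes (here $K = 3$). The ratio is a property of the LOSSES and the TARGET RANGES, not the backbone size.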
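The linear-versus-bounded argument can be checked numerically. The sketch below is a simplified stand-in rather than the paper's measurement: it replaces the regressor with an all-zero prediction (roughly what an untrained head outputs) and measures gradients with respect to the model outputs, not the backbone parameters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 64
u = torch.rand(N, 1)                            # fixed draws, rescaled to each target range
pred = torch.zeros(N, 1, requires_grad=True)    # stand-in for an untrained regressor


def mse_output_grad_norm(r_max):
    # d/dpred of mean((pred - t)^2) = 2 * (pred - t) / N, which is linear in r_max
    loss = F.mse_loss(pred, u * r_max)
    (g,) = torch.autograd.grad(loss, pred)
    return g.norm().item()


for r_max in (10.0, 125.0, 1000.0):
    print(r_max, mse_output_grad_norm(r_max))

# Cross-entropy: per-sample gradient w.r.t. logits is softmax(z) - onehot(y)
logits = (torch.randn(N, 3) * 50).requires_grad_()   # large logits, confident predictions
y = torch.randint(0, 3, (N,))
ce = F.cross_entropy(logits, y, reduction="sum")     # sum keeps each row's own gradient
(g_logits,) = torch.autograd.grad(ce, logits)
per_sample = g_logits.norm(dim=1)
print(per_sample.max().item())   # bounded by sqrt(2) however large the logits get
```

Going from $R_{\max} = 125$ to $1000$ multiplies the MSE output gradient by exactly 8, while the cross-entropy side cannot grow past $\sqrt{2}$ per sample.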
Measured On A Real Backbone
On the paper's actual CNN-BiLSTM-Attention backbone (3.5 M parameters), the same compute_task_grad_norm utility produces the empirical distribution that motivates GABA. Quoting the paper directly (paper main.tex:319):
“During joint training with standard MSE and cross-entropy losses, regression (RUL) gradients exceed classification (health) gradients on shared backbone parameters by ∼500× (median across n=4,120 epoch-level gradient samples from 20 training runs).”
| Quantity | What it means | Where it comes from |
|---|---|---|
| n = 4,120 samples | Epoch-level gradient samples, one per (run, epoch) pair: 4,120 = 20 runs × 206 epochs on average. | paper main.tex:73 |
| 20 training runs | 5 random seeds × 4 C-MAPSS subsets (FD001–FD004). | paper main.tex:319 |
| ~500× ratio | Median of g_rul / g_health across the 4,120 samples. | paper main.tex:48, 319 |
| Peak ~2,400× (around epoch 4) | Transient maximum early in training, before the ratio settles to ~500–1,000×. | paper main.tex:564 |
| Steady state ~500–1,000× | Stabilised ratio after early training. | paper main.tex:564 |
The takeaway: compute_task_grad_norm run across an entire training run produces a SIGNAL, not a single number. GABA's stage 3 (§18.2) smooths that signal with an EMA so the resulting λ∗ does not oscillate.
The Same Pattern In Other Fields
| Field | Per-step measurement | Aggregation | Used to control |
|---|---|---|---|
| Predictive maintenance (this paper) | ‖∇ task_loss w.r.t. shared_params‖₂ | Sum of squared per-parameter L2 norms, then sqrt | Multi-task weight λ_i |
| Federated learning (robust FedAvg variants) | ‖client_update‖₂ | Server-side L2 over flattened deltas | Inverse-norm aggregation, Byzantine robustness |
| Gradient clipping (every modern trainer) | ‖∇ combined_loss‖₂ | Sum of squared per-parameter L2 norms, then sqrt | Rescale gradient if norm > threshold |
| Adam / RMSProp | Per-parameter g_i² | Element-wise EMA | Per-parameter learning rate |
| Reinforcement learning (TRPO) | Fisher-vector products via autograd | Fisher-weighted quadratic form sᵀFs over policy parameters | Trust-region step size |
| Continual learning (EWC / SI) | Per-parameter squared gradients (EWC's diagonal Fisher) or gradient × update (SI) | Per-parameter accumulation | Per-parameter regularisation strength |
| Audio mastering (LUFS) | K-weighted mean-square level per channel | Channel-weighted sum on a log scale | Per-track gain |
In every row, a control mechanism reads a per-step squared-magnitude statistic (an L2 norm or an element-wise or weighted variant of one) and feeds it back into the next step's decision. GABA's contribution is the closed form plugged into stage 2; the measurement (this section) is a cross-disciplinary pattern.
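The gradient-clipping row uses the same aggregation as compute_task_grad_norm. As a quick check, the sketch below compares a hand-rolled sum-of-squares norm with the total norm returned by torch.nn.utils.clip_grad_norm_, which reports the norm measured before clipping. The small network and random data are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(14, 32), nn.ReLU(), nn.Linear(32, 1))
loss = net(torch.randn(8, 14)).pow(2).mean()
loss.backward()

# Same aggregation as compute_task_grad_norm: sum of squared per-parameter norms, one sqrt
manual = torch.stack([p.grad.pow(2).sum() for p in net.parameters()]).sum().sqrt()
# clip_grad_norm_ returns the total norm measured BEFORE it rescales the grads in place
reported = torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
after = torch.stack([p.grad.pow(2).sum() for p in net.parameters()]).sum().sqrt()
print(manual.item(), reported.item(), after.item())
```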
Pitfalls In Per-Step Norm Computation
Pitfall 1: Forgetting retain_graph=True. The first call works; the second raises RuntimeError: Trying to backward through the graph a second time. Fix: pass retain_graph=True on every per-task gradient call. The final combined_loss.backward() should keep its default retain_graph=False: it is the last pass over the graph, and freeing the graph there releases its memory.
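The failure is easy to reproduce on a two-loss toy graph. In the sketch below, w stands in for the shared parameters and tanh for the shared forward pass (tanh saves its output for backward, so freeing the graph actually breaks the second call).

```python
import torch

# Broken: the first grad() call frees the shared graph
w = torch.randn(5, requires_grad=True)
h = torch.tanh(w)                     # shared intermediate: the "forward pass"
loss_a, loss_b = h.pow(2).sum(), h.sum()
torch.autograd.grad(loss_a, [w])      # retain_graph defaults to False
try:
    torch.autograd.grad(loss_b, [w])
    second_call_ok = True
except RuntimeError:                  # "Trying to backward through the graph a second time"
    second_call_ok = False
print(second_call_ok)  # False

# Working: retain the graph for every per-task call, let the final backward free it
w2 = torch.randn(5, requires_grad=True)
h2 = torch.tanh(w2)
la, lb = h2.pow(2).sum(), h2.sum()
ga = torch.autograd.grad(la, [w2], retain_graph=True)
gb = torch.autograd.grad(lb, [w2], retain_graph=True)
(0.3 * la + 0.7 * lb).backward()      # default retain_graph=False: graph freed here
print(w2.grad)
```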
Pitfall 2: Including head parameters in shared_params. Each head is dedicated to one task, so its parameters pick up a large gradient from that task's loss and a zero gradient from the other task's loss. That skews the reading and does not reflect the shared-backbone dynamic. Always pass get_shared_params(model), not list(model.parameters()).
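A small sketch of the distortion, using a toy model like the walkthrough's (the helper l2 is a name chosen here): the health head receives no gradient at all from the RUL loss, while the RUL head adds gradient mass that never touches the backbone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
backbone = nn.Sequential(nn.Linear(14, 32), nn.ReLU(), nn.Linear(32, 16))
rul_head, health_head = nn.Linear(16, 1), nn.Linear(16, 3)


def l2(loss, params):
    # Sum-of-squares L2 norm, skipping params the loss does not reach
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return torch.stack([g.pow(2).sum() for g in grads if g is not None]).sum().sqrt()


feat = backbone(torch.randn(64, 14))
rul_loss = F.mse_loss(rul_head(feat), torch.rand(64, 1) * 125.0)

shared = list(backbone.parameters())
everything = shared + list(rul_head.parameters()) + list(health_head.parameters())

g_shared = l2(rul_loss, shared)
g_all = l2(rul_loss, everything)      # inflated by rul_head's own gradient
health_head_grads = torch.autograd.grad(rul_loss, list(health_head.parameters()),
                                        retain_graph=True, allow_unused=True)
print(g_shared.item(), g_all.item(), health_head_grads)  # health head: (None, None)
```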
Pitfall 3: Using create_graph=True when you don't need second-order autograd. GABA does NOT need it — its closed form is not differentiated through. Setting create_graph=True roughly doubles memory for nothing. The paper's utility hard-codes create_graph=False for this reason. (GradNorm DOES need True — that's a §17.4 difference.)
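The difference shows up directly on the returned gradients. With create_graph=False the result is a plain tensor; with create_graph=True it carries a grad_fn, so a function of it (such as the norm GradNorm differentiates) could be backpropagated, and autograd keeps that extra graph alive.

```python
import torch

w = torch.randn(10, requires_grad=True)
loss = (w.sin() * w).sum()

(g_plain,) = torch.autograd.grad(loss, [w], retain_graph=True, create_graph=False)
(g_graph,) = torch.autograd.grad(loss, [w], retain_graph=True, create_graph=True)
print(g_plain.requires_grad, g_graph.requires_grad)  # False True
print(g_graph.grad_fn)                               # a backward node: extra graph kept alive
```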
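Pitfall 4: Computing $g_i$ on the FULL combined loss instead of the per-task loss. $\|\nabla_{\theta_s}(\sum_i \lambda_i L_i)\|_2 \le \sum_i \lambda_i \|\nabla_{\theta_s} L_i\|_2$ is an inequality (the triangle inequality), not an equality, and the combined norm cannot be split back into per-task parts. GABA needs the PER-TASK norms separately to apply its inverse rule. The utility takes one task's loss at a time for exactly this reason.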
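A numeric sketch of the gap, with two arbitrary losses on a shared parameter vector theta (the losses and weights are placeholders). Gradients are linear, so the combined gradient is exactly the λ-weighted sum of task gradients, but its norm is smaller than the weighted sum of norms unless the task gradients point the same way.

```python
import torch

torch.manual_seed(0)
theta = torch.randn(6, requires_grad=True)
A, B = torch.randn(6, 6), torch.randn(6, 6)
L1 = (A @ theta).pow(2).mean()
L2 = (B @ theta).sin().sum()
lam1, lam2 = 0.3, 0.7

(g1,) = torch.autograd.grad(L1, [theta], retain_graph=True)
(g2,) = torch.autograd.grad(L2, [theta], retain_graph=True)
(gc,) = torch.autograd.grad(lam1 * L1 + lam2 * L2, [theta])

combined_norm = gc.norm().item()
weighted_sum = lam1 * g1.norm().item() + lam2 * g2.norm().item()
print(combined_norm, weighted_sum)  # combined_norm <= weighted_sum
```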
Pitfall 5: Mutating p.grad accidentally. If you use loss.backward() instead of torch.autograd.grad, gradients accumulate into p.grad and contaminate the eventual weight update. The paper's utility uses the FUNCTIONAL autograd.grad precisely so the GABA measurement does not touch .grad at all.
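The difference between the functional and the accumulating API can be checked on a single layer: autograd.grad returns gradients without writing to .grad, whereas each backward() call adds into .grad.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
x = torch.randn(3, 4)
loss = layer(x).pow(2).sum()

gw, gb = torch.autograd.grad(loss, [layer.weight, layer.bias], retain_graph=True)
print(layer.weight.grad)             # None: the functional call never touches .grad

loss.backward()                      # now .grad is populated ...
first = layer.weight.grad.clone()
layer(x).pow(2).sum().backward()     # ... and a second backward ADDS to it
print(torch.allclose(layer.weight.grad, 2 * first))  # True
```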
Takeaway
GABA's sensor is $g_i = \|\nabla_{\theta_s} L_i\|_2$ on the shared backbone. It is computed every step and drives the closed-form weight rule.
Shared parameters are selected by name. get_shared_params(model) excludes the heads via substring match on named_parameters(). Including heads contaminates the imbalance reading.
The L2 norm is the right choice. Rotation-invariant, differentiable, and the same quantity PyTorch's gradient clipping (torch.nn.utils.clip_grad_norm_) measures.
Aggregate via sum-of-squares. $\|g\|_2 = \sqrt{\sum_p \|g_p\|_2^2}$ gives the same answer as concatenate-then-norm without materialising a flattened copy of every gradient.
One forward pass serves K gradient calls. Pass retain_graph=True on every per-task autograd.grad; the final combined_loss.backward() runs with the default retain_graph=False and frees the graph.
create_graph=False is what makes GABA cheap. No second-order autograd, no double memory. This is the operational gap vs GradNorm.
This section's 1,008-parameter toy backbone reproduces the ~500× imbalance. The ratio is structural: a property of the loss families and target range, not of the architecture.
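The sum-of-squares identity can be checked in a few lines. The random tensors below are placeholders shaped like the toy backbone's four gradients; all three routes give the same L2 norm, but only the concatenation route builds a flat copy of every gradient.

```python
import torch

torch.manual_seed(0)
# Placeholder gradients with the toy backbone's shapes (fc1.weight, fc1.bias, fc2.weight, fc2.bias)
grads = [torch.randn(32, 14), torch.randn(32), torch.randn(16, 32), torch.randn(16)]

# What compute_task_grad_norm does: per-parameter squared norms, summed, one sqrt
sum_of_squares = torch.stack([g.pow(2).sum() for g in grads]).sum().sqrt()
# Concatenate into one flat vector first (allocates a 1,008-element copy), then norm
concat_norm = torch.cat([g.flatten() for g in grads]).norm()
# Equivalent again: the L2 norm of the vector of per-parameter L2 norms
norm_of_norms = torch.stack([g.norm() for g in grads]).norm()
print(sum_of_squares.item(), concat_norm.item(), norm_of_norms.item())
```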