Chapter 21

Separation of Concerns: Adaptation vs. Loss Shape

Combining GABA + Weighted MSE

The Mixing Console And The EQ

Walk into a recording studio and look at the console. Every channel has its own fader — vocal, kick drum, bass, guitar — and every channel also has its own equaliser. The faders ask which sources matter right now; the equalisers ask which frequencies inside each source matter. The engineer can re-balance the mix without touching the EQ, and re-EQ a single instrument without touching the faders. The two controls live on different axes.

GRACE is built on the same separation. GABA is the fader bank: at every training step it decides how much of the gradient budget each task gets. Failure-biased weighted MSE is the EQ: within the RUL task, it decides which samples carry the heaviest squared error. The contribution of this chapter, and of this section in particular, is to show that these two knobs commute. You can adapt the per-task weighting and shape the per-sample loss at the same time without either interfering with the other, and the resulting algorithm is the model that achieves the best NASA score on multi-condition C-MAPSS.

The headline. Two orthogonal axes, outer $\lambda_i(t)$ per task and inner $w(y_j)$ per sample, produce a 2×2 grid of methods. Cell A is the plain baseline, cell B is AMNL-style, cell C is GABA + standard MSE, and cell D is GRACE.

Two Independent Axes Of A Multi-Task Loss

Every multi-task loss in this book has the same skeleton:

$$\mathcal{L}(t) \;=\; \sum_{i=1}^{K}\, \underbrace{\lambda_i(t)}_{\text{outer: per-task weight}} \,\cdot\, \underbrace{\frac{1}{N}\sum_{j=1}^{N} w_{ij}(y_j)\, \ell_i(\hat{y}_j, y_j)}_{\text{inner: per-sample loss}}$$

Two indices, two roles. The outer index $i$ ranges over tasks (RUL regression and health classification, in our case). The inner index $j$ ranges over samples in the mini-batch. Methods in the literature differ purely in which of these indices they touch:

| Axis | What it weights | Examples | GRACE choice |
|---|---|---|---|
| Outer (per-task) | How much each task contributes to the combined gradient | Fixed (0.5/0.5), Uncertainty (Kendall et al.), GradNorm (Chen et al.), DWA (Liu et al.), GABA (this book) | GABA: $\lambda_i^*(t) = \mathrm{EMA}\!\left(\dfrac{g_j(t)}{g_i(t) + g_j(t)}\right)$ |
| Inner (per-sample) | How much each sample inside a task contributes to that task's loss | Standard MSE (uniform), Asymmetric, Focal, Quantile, Failure-biased weighted MSE | Failure-biased: $w(y_j) = 1 + \mathrm{clip}\!\left(1 - \dfrac{y_j}{125},\ 0,\ 1\right)$ |

Because the two axes touch different indices, the combined formula factorises cleanly:

$$\mathcal{L}(t) \;=\; \frac{1}{N}\sum_{i=1}^{K}\sum_{j=1}^{N} \Big[\, \lambda_i(t)\, \cdot\, w_{ij}(y_j)\, \Big]\, \ell_i(\hat{y}_j, y_j)$$

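Inside the brackets, the outer factor $\lambda_i(t)$ depends only on the task $i$ (and the step $t$) and is the same for every sample, while the inner factor $w_{ij}(y_j)$ depends only on the target of sample $j$ within task $i$ and never on the gradient norms that drive $\lambda$. Treating $\lambda_i(t)$ as a constant during backpropagation (the gradient norms that produce it are computed without building a graph), the gradient with respect to a backbone parameter $\theta_s$ decomposes as $\nabla_{\theta_s}\mathcal{L} = \sum_i \lambda_i(t)\,\nabla_{\theta_s}\mathcal{L}_i$, one outer-scaled term per task. Replacing GABA therefore does not require rewriting the inner weights, and replacing weighted MSE does not require rewriting the outer controller.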
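A quick NumPy check of that factorisation, as a sketch: the per-sample losses `ell`, the task weights `lam`, and the ramp-like weights are made-up arrays, not values from the paper. The nested form (inner weighted mean per task, then the outer λ-weighted sum) and the flattened double sum with coefficients λ_i·w_ij agree up to floating-point rounding, whichever inner weights are plugged in.

```python
import numpy as np


def nested_loss(lam, w, ell):
    """Inner weighted mean per task, then the outer lambda-weighted sum.

    lam: (K,) per-task weights; w: (K, N) per-sample weights;
    ell: (K, N) per-sample losses.
    """
    per_task = (w * ell).mean(axis=1)       # inner axis: one scalar per task
    return float((lam * per_task).sum())    # outer axis: weighted sum over tasks


def flattened_loss(lam, w, ell):
    """The same loss as a single double sum with coefficients lam_i * w_ij."""
    n = ell.shape[1]
    coeff = lam[:, None] * w                # (K, N): outer factor times inner factor
    return float((coeff * ell).sum() / n)


rng = np.random.default_rng(0)
ell = rng.uniform(0.0, 5.0, size=(2, 8))    # made-up per-sample losses, K=2, N=8
lam = np.array([0.3, 0.7])                  # any per-task weights
w_uniform = np.ones((2, 8))                 # standard MSE on both tasks
w_ramp = np.vstack([1.0 + rng.uniform(0.0, 1.0, 8),  # made-up weights in [1, 2] on task 0
                    np.ones(8)])                      # task 1 stays uniform

for w in (w_uniform, w_ramp):
    print(nested_loss(lam, w, ell), flattened_loss(lam, w, ell))
```

Swapping `w_uniform` for `w_ramp` changes only the inner factor; `lam` is untouched, which is the modularity argument in code form.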

Why this matters in practice. The prior methods compared here each attack a single axis. AMNL freezes the outer axis at 0.5/0.5 and shapes the inner axis. GradNorm and DWA shape the outer axis and leave the inner axis at uniform MSE. GRACE is the first method on this problem that engages both simultaneously. The question of this chapter is whether the two corrections compose cleanly. Numerically, on real data, the answer is yes for the multi-condition datasets and a qualified "almost" for single-condition FD003 (section 21.3).

The Composition: Outer × Inner

Substitute the GRACE choices into the general form. The outer factor becomes the GABA closed form derived in section 17.3,

$$\lambda^*_i(t) \;=\; \frac{\sum_{j\neq i} g_j(t)}{(K-1)\,\sum_{j} g_j(t)}, \qquad g_i(t) \;=\; \bigl\| \nabla_{\theta_s} \mathcal{L}_i(t) \bigr\|_2,$$

and the inner factor becomes the failure-biased ramp,

$$w(y_j) \;=\; 1 + \mathrm{clip}\!\left(1 - \frac{y_j}{R_{\max}},\ 0,\ 1\right), \qquad R_{\max} = 125.$$
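The outer closed form is easy to sketch in NumPy for any K. The helper names below are illustrative, not the GRACE API, and `ema_update` with `beta = 0.9` is a hypothetical stand-in for the EMA smoothing listed in the axis table (the actual coefficient, and the floor mentioned later, are not given in this section).

```python
import numpy as np


def gaba_closed_form(g):
    """Illustrative helper: lambda_i = sum_{j != i} g_j / ((K - 1) * sum_j g_j), K >= 2."""
    g = np.asarray(g, dtype=float)
    k = g.size
    total = g.sum()
    return (total - g) / ((k - 1) * total)   # smaller gradient norm -> larger weight


def ema_update(prev, new, beta=0.9):
    """Illustrative EMA step; beta = 0.9 is a hypothetical choice."""
    return beta * prev + (1.0 - beta) * new


lam = gaba_closed_form([26.4016, 0.037833])    # norms quoted from chapter 18 §1
print(lam)                                     # RUL about 0.00143, health about 0.99857

lam_ema = np.array([0.5, 0.5])                 # hypothetical starting point
for g in ([26.4016, 0.037833], [20.0, 0.05]):  # second pair of norms is made up
    lam_ema = ema_update(lam_ema, gaba_closed_form(g))
print(lam_ema, lam_ema.sum())                  # sums to 1 up to rounding
```

For K = 2 the closed form reduces to the inverse-ratio rule used in the NumPy walkthrough, and the (K − 1) in the denominator makes the weights sum to 1 for every K.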

Plugging both into the skeleton, the GRACE loss for $K=2$ tasks (RUL + health) is

$$\mathcal{L}_{\text{GRACE}}(t) \;=\; \lambda^*_{\text{rul}}(t)\,\underbrace{\frac{1}{N}\sum_{j=1}^{N} w(y_j)\,(\hat{y}_j - y_j)^2}_{\text{weighted MSE on RUL}} \;+\; \lambda^*_{\text{health}}(t)\,\underbrace{\mathcal{L}_{\text{CE}}}_{\text{cross-entropy on health}}.$$

Compare it term by term with all four cells of the grid:

| Cell | Outer (per-task) | Inner (per-sample RUL) | Method |
|---|---|---|---|
| A | $\lambda_{\text{rul}} = \lambda_{\text{health}} = 0.5$ | $w(y) = 1$ | Baseline (0.5/0.5) |
| B | $\lambda_{\text{rul}} = \lambda_{\text{health}} = 0.5$ | $w(y) = 1 + \mathrm{clip}\!\left(1 - \dfrac{y}{125},\ 0,\ 1\right)$ | AMNL (0.5/0.5 + WMSE) |
| C | $\lambda_{\text{rul}}^* = \mathrm{EMA}\!\left(\dfrac{g_h}{g_r + g_h}\right)$ | $w(y) = 1$ | GABA + standard MSE |
| D | $\lambda_{\text{rul}}^* = \mathrm{EMA}\!\left(\dfrac{g_h}{g_r + g_h}\right)$ | $w(y) = 1 + \mathrm{clip}\!\left(1 - \dfrac{y}{125},\ 0,\ 1\right)$ | GRACE |

Interactive: The Two-Axes Grid

The FD002 figures below are averages over the paper's 5-seed runs at h=256. Reading the grid one row at a time isolates orthogonality empirically: cells that share a loss shape (row) share the inner axis and differ only on the outer axis, so the entire NASA difference between them must come from the per-task weighting.

What to notice. Cell B (AMNL) has the best RMSE (6.74) but a much worse NASA score (356; lower is better): one axis pulls hard on accuracy while ignoring task balance. Cell C (GABA) reaches a NASA score of 224.2 but does not exploit the per-sample asymmetry. Cell D (GRACE) sits at the Pareto sweet spot: NASA 223.4, the lowest of all four cells, with RMSE only 0.36 cycles worse than AMNL.
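NASA here refers to the asymmetric scoring function commonly used with C-MAPSS. The sketch below assumes its usual form and constants (late predictions scaled by 10, early ones by 13, summed over engines); the exact aggregation behind the grid's numbers is not shown in this section, so treat it as illustrative.

```python
import numpy as np


def nasa_score(y_pred, y_true):
    """Asymmetric C-MAPSS score, summed over engines (lower is better).

    Assumes the commonly used constants: d = y_pred - y_true; late
    predictions (d >= 0) cost exp(d / 10) - 1, early ones (d < 0)
    cost the gentler exp(-d / 13) - 1.
    """
    d = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    per_engine = np.where(d < 0, np.exp(-d / 13.0) - 1.0, np.exp(d / 10.0) - 1.0)
    return float(per_engine.sum())


late = nasa_score([20.0], [10.0])    # predicted failure 10 cycles too late
early = nasa_score([0.0], [10.0])    # predicted failure 10 cycles too early
print(f"late: {late:.3f}  early: {early:.3f}")
```

The same 10-cycle error costs about 1.72 when late but about 1.16 when early, and the penalty grows exponentially with the error, so a handful of large late predictions can dominate the sum. That is why NASA and RMSE can rank the cells differently, as B and D do.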

Python: Computing All Four Cells From Scratch

The annotated walkthrough below traces the script line by line, including the intermediate values of $y_j$, $w(y_j)$ and $\lambda_i$ and the final scalar produced by each cell. The script reuses the gradient-norm numbers from chapter 18 §1, so the outer lambdas are exact rather than invented.

Reconstructing the 2×2 grid in NumPy
grace_two_axes_demo.py
Line 1: Module docstring (names the 2×2 grid)

Names the 2×2 grid we will reconstruct numerically. The outer axis flips between FIXED equal weighting (0.5/0.5) and GABA's adaptive closed form. The inner axis flips between uniform per-sample weighting and the failure-biased ramp w(y). The Cartesian product of the two axes generates exactly four cells — and Cell D (adaptive × failure-biased) is GRACE.

EXECUTION STATE
Outer axis values = { fixed = 0.5/0.5, GABA-adaptive = closed-form λ*(t) }
Inner axis values = { uniform w=1, failure-biased w(y)=1+clip(1-y/125,0,1) }
→ why a docstring? = A module docstring is a string literal that appears as the FIRST statement of a Python file. Tools like help(), Sphinx, and IDEs surface it as documentation. At runtime Python stores it in the module's __doc__ attribute; otherwise it has no effect.
→ cell map = Cell A = fixed × uniform (Baseline MTL) Cell B = fixed × weighted (AMNL) Cell C = GABA × uniform (GABA + std MSE) Cell D = GABA × weighted (GRACE)
Line 7: import numpy as np

NumPy gives us vectorised arithmetic over y_pred and y_true. Both loss shapes below are one-line array operations rather than Python loops: element-wise arithmetic, element-wise power, and .mean() reductions all run as optimised C under the hood.

EXECUTION STATE
📚 numpy = Numerical-array library. Provides ndarray, np.clip, np.array, broadcasting, element-wise ops, and reductions (.mean, .sum) used throughout the demo. Every operation here would be a slow Python for-loop without it.
as np = Standard alias by community convention. Lets us write np.array(), np.clip() instead of numpy.array(). Almost universal in scientific Python code.
→ why we need it here = Computing w * (y_pred - y_true) ** 2 requires element-wise broadcasting across 8 samples. Pure-Python lists can't do this without an explicit for-loop.
Line 10: y_true = np.array([10, 30, 60, 90, 110, 5, 15, 80], ...)

Eight ground-truth RUL values spanning the full range: three near failure (5, 10, 15), two early in life (90, 110), and three in between (30, 60, 80). The mix lets the inner-axis weight w(y) actually do something.

EXECUTION STATE
📚 np.array(list, dtype=float) = Builds a 1-D ndarray. dtype=float stores the values as float64 from the start, so every downstream operation (subtraction, squaring, division by max_rul) runs in floating point.
y_true (8,) = [ 10., 30., 60., 90., 110., 5., 15., 80.]
Line 11: y_pred = np.array([15, 26, 68, 87, 112, -7, 24, 74], ...)

Predictions chosen so the residual y_pred-y_true has different signs and magnitudes per sample. Two near-failure samples are off by -12 and +9 — that is exactly where the inner-axis weight will amplify the penalty.

EXECUTION STATE
y_pred (8,) = [ 15., 26., 68., 87., 112., -7., 24., 74.]
residual y_pred-y_true = [ 5., -4., 8., -3., 2., -12., 9., -6.]
residual squared = [ 25., 16., 64., 9., 4., 144., 81., 36.]
Line 15: def rul_loss_standard(y_pred, y_true) → float

The bottom-row INNER axis: every sample contributes the same weight to the average. This is the loss shape used in cell A (Baseline) and cell C (GABA + standard MSE). Mathematically: L = (1/N) Σⱼ (ŷⱼ − yⱼ)².

EXECUTION STATE
⬇ input: y_pred (8,) — RUL predictions = [15., 26., 68., 87., 112., -7., 24., 74.]
→ y_pred purpose = What the network predicts for each engine's remaining useful life. Element j is the model's RUL forecast for engine window j. Note sample 5 = -7 (unphysical negative RUL — the linear head doesn't know about non-negativity).
⬇ input: y_true (8,) — ground-truth RUL = [10., 30., 60., 90., 110., 5., 15., 80.]
→ y_true purpose = Cycles-until-failure for each engine. Mix is deliberate: samples 0,5,6 are near failure (≤15), samples 3,4 are healthy (≥90). The mix lets the inner-axis weight w(y) actually do something.
→ no max_rul here = Standard MSE has no per-sample shape, so it doesn't need the RUL cap. The function is dataset-agnostic.
⬆ returns = Python float — the unweighted mean of squared errors. Single scalar that the autograd graph (in PyTorch) or the optimiser (here, just a print) consumes.
Line 16: Function docstring, "Plain MSE. Every sample weighted equally."

Inner-function docstring. PEP 257 convention: first triple-quoted string after a def becomes function.__doc__ and is what help() shows. This one names the loss shape (plain MSE) and the key invariant (uniform weighting) so readers don't have to derive it from the body.

EXECUTION STATE
→ why uniform weighting matters here = Contrast with rul_loss_weighted() at line 20, where each sample is multiplied by w(y). The docstring's job is to surface that distinction without forcing the reader to diff the two function bodies.
Line 17: return ((y_pred - y_true) ** 2).mean()

Three vectorised operations chained in a single line: subtract, square, average. Equivalent to (1/N) · Σⱼ (ŷⱼ − yⱼ)². NumPy does not fuse these operations; each step runs as its own compiled loop. The speed advantage over a Python for-loop shows up at realistic batch sizes; on 8 samples the per-call overhead dominates.

EXECUTION STATE
Step 1: y_pred - y_true = [ 5., -4., 8., -3., 2., -12., 9., -6.] — element-wise broadcast (both shape (8,)).
→ broadcasting rule = Two ndarrays of the same shape: NumPy subtracts element-wise. Output shape = input shape = (8,). No copy of either array — the result is freshly allocated.
📚 ** 2 (element-wise power) = NumPy operator: applies pow(x, 2) to every element. Equivalent to np.square(x) but written inline. e.g. (-12)**2 = 144, (-3)**2 = 9.
Step 2: (y_pred - y_true) ** 2 = [ 25., 16., 64., 9., 4., 144., 81., 36.] — squared residuals, all non-negative.
📚 .mean() (ndarray reduction) = ndarray method: returns sum(self) / size(self). With no axis argument, it reduces over ALL elements. Here: (25+16+64+9+4+144+81+36) / 8 = 379 / 8 = 47.375. For a float64 array the result is a numpy.float64 scalar, which subclasses Python float.
→ arg axis (not used here) = If we passed axis=0 it would mean per-column on a 2-D array. With a 1-D vector, axis=0 reduces along the only axis — same as no axis.
⬆ return: L_rul^MSE = 47.3750 — the value that lines 26, 43, 45 will consume.
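A short check of what `.mean()` actually returns, using the squared residuals from this step, plus the `axis=0` behaviour on a 2-D view. This only exercises standard NumPy behaviour; the reshape to (2, 4) is purely illustrative.

```python
import numpy as np

sq = np.array([25.0, 16.0, 64.0, 9.0, 4.0, 144.0, 81.0, 36.0])  # squared residuals from line 17
m = sq.mean()
print(type(m).__name__, m)       # float64 47.375
print(isinstance(m, float))      # True: numpy.float64 subclasses Python float

grid = sq.reshape(2, 4)          # illustrative 2-D view of the same numbers
print(grid.mean(axis=0))         # per-column means: 14.5, 80, 72.5, 22.5
```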
Line 20: def rul_loss_weighted(y_pred, y_true, max_rul=125.0) → float

The top-row INNER axis: each sample is multiplied by w(y) = 1 + clip(1 − y/max_rul, 0, 1). Samples near failure (y → 0) get weight 2; samples beyond max_rul get weight 1. Line-for-line equivalent to grace/core/weighted_mse.py:moderate_weighted_mse_loss — used in cells B (AMNL) and D (GRACE).

EXECUTION STATE
⬇ input: y_pred (8,) = [15., 26., 68., 87., 112., -7., 24., 74.]
→ y_pred role = Same predictions as in rul_loss_standard — only the loss SHAPE changes between functions, not the inputs. That is the orthogonality story for the inner axis.
⬇ input: y_true (8,) = [10., 30., 60., 90., 110., 5., 15., 80.]
→ y_true role here = Used TWICE inside the body: once for the residual (y_pred - y_true) and once for the weight w(y_true). The weight depends ONLY on the target, never on the prediction — so dw/dy_pred = 0.
⬇ input: max_rul = 125.0 = RUL cap (default value). Matches the paper's piecewise-linear RUL target where everything ≥ 125 cycles is treated as 'fully healthy'. Samples with y_true ≥ 125 → weight = 1 (no extra emphasis).
→ why max_rul = 125? = Empirical choice from the C-MAPSS paper. With y measured in cycles and engines lasting ~200-400 cycles, 125 is the boundary where degradation typically becomes detectable. Hard-coding 125.0 as a kwarg default lets callers override it for other datasets.
⬆ returns = Python float: weighted mean of squared errors with per-sample weights w(yⱼ) ∈ [1, 2]. Always ≥ standard MSE on the same inputs, because every weight is ≥ 1.
Line 21: Function docstring, "Failure-biased MSE (paper: grace/core/weighted_mse.py)."

Inner-function docstring. Pins the implementation to a real source file in the GRACE codebase so readers can cross-check the closed-form weight w(y) against the paper's production code. The reproducibility contract is: same inputs, same outputs, line-for-line.

EXECUTION STATE
→ traceability = The paper file grace/core/weighted_mse.py:20 holds moderate_weighted_mse_loss — the PyTorch version. Body identical to lines 22-23 below, modulo torch.clamp ↔ np.clip.
Line 22: w = 1.0 + np.clip(1.0 - y_true / max_rul, 0.0, 1.0)

Builds the per-sample weight vector: a linear ramp from 2 at y = 0 down to 1 at y ≥ max_rul. The lower clip keeps the weight from dropping below 1 when y > max_rul (and from going negative once y > 2·max_rul); the upper clip only matters for negative targets.

EXECUTION STATE
y_true / max_rul = [0.080, 0.240, 0.480, 0.720, 0.880, 0.040, 0.120, 0.640]
1.0 - y_true / max_rul = [0.920, 0.760, 0.520, 0.280, 0.120, 0.960, 0.880, 0.360]
📚 np.clip(x, lo, hi) = Element-wise: x_i = max(lo, min(hi, x_i)). Here clip(..., 0.0, 1.0) keeps the ramp inside [0, 1] so weights stay in [1, 2].
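→ why clip? = If y_true > max_rul (uncapped targets can exceed 200 cycles), 1 - y/max_rul becomes negative, so the weight falls below 1 and healthy-engine errors are down-weighted relative to plain MSE; beyond 2·max_rul the weight would turn negative and the loss would reward larger errors. clip pins the ramp at 0 there, so the weight is exactly 1.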
w (8,) = [1.920, 1.760, 1.520, 1.280, 1.120, 1.960, 1.880, 1.360]
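To see why the clip matters, compare the ramp with and without it on a few targets, including two above max_rul and one negative target. The function names are illustrative; the clipped version is the same expression as line 22.

```python
import numpy as np


def ramp_weight(y, max_rul=125.0):
    """Clipped failure-biased weight: 1 + clip(1 - y / max_rul, 0, 1)."""
    return 1.0 + np.clip(1.0 - np.asarray(y, dtype=float) / max_rul, 0.0, 1.0)


def ramp_weight_unclipped(y, max_rul=125.0):
    """Same ramp with the clip removed, for comparison only."""
    return 1.0 + (1.0 - np.asarray(y, dtype=float) / max_rul)


y = np.array([-10.0, 0.0, 62.5, 125.0, 200.0, 300.0])
print(ramp_weight(y))            # approx. [2, 2, 1.5, 1, 1, 1]
print(ramp_weight_unclipped(y))  # approx. [2.08, 2, 1.5, 1, 0.4, -0.4]
```

Without the clip the weight drops below 1 past max_rul and turns negative past 2·max_rul; the upper clip only changes anything for a negative target such as -10.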
Line 23: return (w * (y_pred - y_true) ** 2).mean()

Per-sample weighted squared errors, then averaged. Higher w near failure pulls the gradient toward those samples even though they may be the minority.

EXECUTION STATE
(y_pred - y_true) ** 2 = [ 25., 16., 64., 9., 4., 144., 81., 36.]
w * residual^2 = [ 48.000, 28.160, 97.280, 11.520, 4.480, 282.240, 152.280, 48.960]
→ contribution check = Sample 5 (y=5, w=1.96) contributes 282.24, which is 42% of the weighted sum 672.92. Under standard MSE the same sample contributes only 144/379 = 38%. The shift is the 'loss-shape' effect.
📚 .mean() = Sum 672.92 / 8 = 84.115.
⬆ return: L_rul^WMSE = 84.1150
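The claim that a larger w pulls the gradient toward near-failure samples can be checked directly. Because w depends only on y_true, ∂L/∂ŷⱼ = (2/N)·w(yⱼ)·(ŷⱼ − yⱼ). The sketch below compares that analytic gradient with central finite differences on the same toy batch; the helper names are illustrative.

```python
import numpy as np

y_true = np.array([10, 30, 60, 90, 110, 5, 15, 80], dtype=float)
y_pred = np.array([15, 26, 68, 87, 112, -7, 24, 74], dtype=float)


def weighted_mse(y_pred, y_true, max_rul=125.0):
    w = 1.0 + np.clip(1.0 - y_true / max_rul, 0.0, 1.0)
    return (w * (y_pred - y_true) ** 2).mean()


def weighted_mse_grad(y_pred, y_true, max_rul=125.0):
    """Analytic dL/dy_pred = (2 / N) * w(y_true) * (y_pred - y_true)."""
    w = 1.0 + np.clip(1.0 - y_true / max_rul, 0.0, 1.0)   # no dependence on y_pred
    return 2.0 * w * (y_pred - y_true) / y_pred.size


# Central finite differences, perturbing one prediction at a time
eps = 1e-4
fd = np.empty_like(y_pred)
for j in range(y_pred.size):
    step = np.zeros_like(y_pred)
    step[j] = eps
    fd[j] = (weighted_mse(y_pred + step, y_true) - weighted_mse(y_pred - step, y_true)) / (2 * eps)

print(np.round(weighted_mse_grad(y_pred, y_true), 4))
print(np.allclose(fd, weighted_mse_grad(y_pred, y_true), atol=1e-5))   # True
```

For sample 5 (y = 5, residual −12) the gradient is 2·1.96·(−12)/8 = −5.88, against −3.0 under plain MSE: the weight scales each per-sample gradient by exactly w(yⱼ).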
Line 26: L_rul_std = rul_loss_standard(y_pred, y_true)

First half of the inner axis: the unshaped RUL loss for cells A and C.

EXECUTION STATE
L_rul_std = Float = 47.3750.
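Line 27: L_rul_w = rul_loss_weighted(y_pred, y_true)

Other half of the inner axis: the shaped RUL loss for cells B and D. Never smaller than L_rul_std, because every weight is >= 1; here it is strictly larger, because every residual is non-zero and every target is below 125, so every weight exceeds 1.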

EXECUTION STATE
L_rul_w = Float = 84.1150.
Line 28: L_health = 0.6069 # cross-entropy from the same forward pass

Hard-coded scalar standing in for the health-classification cross-entropy from the same mini-batch's forward pass. Held constant across all four cells because this section isolates the RUL leg of the OUTER × INNER composition; varying L_health here would conflate axes.

EXECUTION STATE
L_health = 0.6069 — Python float, used by every compose() call below.
→ typical magnitude = About 0.55 × ln(3), where ln(3) ≈ 1.099 is the cross-entropy of a uniform (random-guess) 3-class output. That is consistent with a partly-trained 3-class head; lower values mean the classifier is more confident on the right class.
→ why constant = If L_health changed between cells the differences L_B−L_A, L_C−L_A, L_D−L_A would mix RUL-axis effects with health-axis noise and the orthogonality demonstration would no longer be clean.
→ comment '# cross-entropy from the same forward pass' = Tells the reader L_health was computed from F.cross_entropy(hp_logits, hp_target) — same minibatch, same backbone activations as the RUL loss. Same forward pass = single autograd graph in the PyTorch sibling.
Line 32: lam_fixed_rul, lam_fixed_h = 0.5, 0.5

First half of the OUTER axis: equal-weight MTL. Used in cells A and B.

EXECUTION STATE
lam_fixed_rul = 0.5 — half the gradient budget goes to RUL.
lam_fixed_h = 0.5 — half goes to health.
→ known weakness = On C-MAPSS the gradient ratio is ~500x in favour of RUL, so 0.5/0.5 gives the health head almost no real influence. See chapter 18 §1 for the empirical figure.
Line 33: g_rul, g_health = 26.4016, 0.037833

Per-task gradient norms on the shared backbone for this batch. These are the exact numbers reproduced in chapter 18 §1; we copy them so this section's lambdas match that one's.

EXECUTION STATE
g_rul = 26.4016 — L2 norm of dL_rul/dtheta_shared.
g_health = 0.037833 — L2 norm of dL_health/dtheta_shared.
ratio = g_rul / g_health = 698x — the structural imbalance GABA fixes.
Line 34: lam_gaba_rul = g_health / (g_rul + g_health)

Second half of the OUTER axis: GABA closed form for K=2. The task with the SMALLER gradient norm gets the LARGER weight — inverse-ratio balancing.

EXECUTION STATE
g_rul + g_health = 26.4016 + 0.037833 = 26.439433
lam_gaba_rul = g_health / sum = 0.037833 / 26.439433 = 0.001431
→ reading = RUL gradient is 698x bigger, so GABA gives RUL only 0.14% of the loss weight. The optimisation step now becomes ~equally pushed by both task gradients.
Line 35: lam_gaba_h = g_rul / (g_rul + g_health)

Other side of the GABA formula — the K=2 closed form gives the OTHER task's gradient norm (in the numerator) divided by the total. The two lambdas always sum to 1.

EXECUTION STATE
g_rul / (g_rul + g_health) = 26.4016 / 26.439433 = 0.998569
lam_gaba_h = 0.998569 — health gets ≈99.86% of the loss weight under GABA on this batch.
lam_rul + lam_h = 0.001431 + 0.998569 = 1.000000 — partition-of-unity check (always exact for K=2 closed form).
→ reading = Even though L_health is numerically tiny (0.6069 vs 47.375 for L_rul), GABA gives it 99.86% of the loss weight. After multiplying, the weighted gradient norms are equal: λ_rul·g_rul = λ_h·g_health = g_rul·g_health / (g_rul + g_health) ≈ 0.0378. That is the whole point.
→ invariant = K=2 closed form: λᵢ = (Σⱼ≠ᵢ gⱼ) / Σⱼ gⱼ. Sum: Σᵢ λᵢ = (Σᵢ Σⱼ≠ᵢ gⱼ) / Σⱼ gⱼ = (K-1)Σⱼgⱼ/Σⱼgⱼ. For K=2 this is 1 exactly.
Line 39: def compose(lam_rul, lam_h, L_rul, L_h) -> float

The composition rule of the entire chapter, in three symbols. Note that compose() does NOT care which loss shape produced L_rul or which weighting strategy produced lam_rul — that is exactly the orthogonality the section is about.

EXECUTION STATE
⬇ input: lam_rul = Per-task weight on RUL. Comes from EITHER value of the OUTER axis (fixed or GABA).
⬇ input: lam_h = Per-task weight on health. Sums to 1 with lam_rul (after EMA & floor).
⬇ input: L_rul = Scalar RUL loss. Comes from EITHER value of the INNER axis (standard or weighted MSE).
⬇ input: L_h = Scalar health loss. Cross-entropy here.
⬆ returns = Scalar combined loss for one training step.
Line 40: return lam_rul * L_rul + lam_h * L_h

Weighted sum. The two lambdas live OUTSIDE the per-sample mean and the per-sample weights w(y_j) live INSIDE it, so the algebra confirms the two axes act on separate indices.

EXECUTION STATE
→ orthogonality = Mathematically: L = sum_i lam_i * (1/N) sum_j w_j(y_j) e_ij^2 = (1/N) sum_i sum_j lam_i w_j(y_j) e_ij^2. The factorisation lam_i * w_j(y_j) shows i and j operate on disjoint indices.
Line 43: L_A = compose(lam_fixed_rul, lam_fixed_h, L_rul_std, L_health)

Cell A: fixed weighting + standard MSE. The plain MTL baseline.

EXECUTION STATE
Step 1: 0.5 * 47.3750 = 23.6875
Step 2: 0.5 * 0.6069 = 0.30345
L_A = 23.6875 + 0.30345 = 23.9910
Line 44: L_B = compose(lam_fixed_rul, lam_fixed_h, L_rul_w, L_health)

Cell B: fixed weighting + weighted MSE. AMNL territory. The same fixed lambdas, but the inner sum has w(y) baked in.

EXECUTION STATE
Step 1: 0.5 * 84.1150 = 42.0575
Step 2: 0.5 * 0.6069 = 0.30345
L_B = 42.0575 + 0.30345 = 42.3610
→ vs cell A = Inner axis flipped only. RUL loss grows from 47.375 → 84.115 because near-failure samples now count up to 2x. lam_rul did NOT change.
Line 45: L_C = compose(lam_gaba_rul, lam_gaba_h, L_rul_std, L_health)

Cell C: GABA + standard MSE. Now the OUTER axis flipped.

EXECUTION STATE
Step 1: 0.001431 * 47.3750 = 0.06779
Step 2: 0.998569 * 0.6069 = 0.60603
L_C = 0.06779 + 0.60603 = 0.6738
→ vs cell A = Outer axis flipped only. RUL contribution falls 23.6875 → 0.0678 because the gradient-balanced weight is tiny. The huge raw loss is no longer dominating optimisation — that is the point of GABA.
Line 46: L_D = compose(lam_gaba_rul, lam_gaba_h, L_rul_w, L_health)

Cell D, GRACE. Both axes flipped from cell A simultaneously. The two effects compose multiplicatively: lam_gaba_rul scales a per-sample mean that already contains w(y).

EXECUTION STATE
Step 1: 0.0014309 * 84.1150 = 0.12036 (lam_gaba_rul to 7 decimals)
Step 2: 0.998569 * 0.6069 = 0.60603
L_D = 0.12036 + 0.60603 = 0.7264
→ diagonal sanity = L_D - L_C = 0.0526 = lam_gaba_rul * (L_rul_w - L_rul_std) = 0.001431 * 36.74. The change is exactly the inner-axis effect, scaled by the (unchanged) lam_gaba_rul.
Line 49: print(f"L_rul^MSE = {L_rul_std:.4f} L_rul^WMSE = {L_rul_w:.4f}")

Show the inner-axis contrast: the failure-biased loss is 1.78x the standard one, because every squared error is multiplied by a weight between 1.12 and 1.96. The largest single contribution comes from the worst-residual sample (y=5, residual=-12) at weight 1.96.

EXECUTION STATE
Output = L_rul^MSE = 47.3750 L_rul^WMSE = 84.1150
Line 50: print(f"lambda_fixed=({lam_fixed_rul:.4f}, {lam_fixed_h:.4f})")

Print the static-weight pair. Easy reference for what cells A and B share.

EXECUTION STATE
Output = lambda_fixed=(0.5000, 0.5000)
Line 51: print(f"lambda_GABA =({lam_gaba_rul:.6f}, {lam_gaba_h:.6f})")

Print the gradient-balanced pair. Easy reference for what cells C and D share. Note the format spec :.6f — six decimals — is needed because lam_gaba_rul ≈ 0.001431 would round to 0.0014 at four decimals and lose precision.

EXECUTION STATE
📚 f-string with :.6f = Python f-string format: variable formatted as fixed-point with 6 decimals. e.g. f"{0.001431:.6f}" → '0.001431'. Different from :.4f used elsewhere because lam_gaba_rul has small magnitude.
Output = lambda_GABA =(0.001431, 0.998569)
Line 52: print() # blank line separator

Bare print() with no arguments emits a single newline. It separates the parameter dump above (loss values + lambdas) from the four-cell results below, so the terminal output has visual hierarchy.

EXECUTION STATE
📚 print() with no args = Python builtin: writes the value of `end` (default '\n') to stdout. Equivalent to print('', end='\n') or sys.stdout.write('\n').
→ why split sections? = Without this blank line, the output would be a single 7-line block. The visual gap signals to the reader: 'parameters above, results below'. Trivial cost, big readability gain when running scripts in a terminal.
Line 53: print(f"Cell A (fixed x standard) = {L_A:.4f}")

Plain MTL combined loss. The baseline.

EXECUTION STATE
Output = Cell A (fixed x standard) = 23.9910
Line 54: print(f"Cell B (fixed x weighted) = {L_B:.4f}")

AMNL-style cell. Bigger than A purely because of the inner-axis weight.

EXECUTION STATE
Output = Cell B (fixed x weighted) = 42.3610
Line 55: print(f"Cell C (GABA x standard) = {L_C:.4f}")

GABA-with-plain-MSE cell. The tiny RUL weight removes the RUL term's domination of cell A.

EXECUTION STATE
Output = Cell C (GABA x standard) = 0.6738
Line 56: print(f"Cell D (GRACE = GABA x WMSE)= {L_D:.4f}")

GRACE cell. Both axes engaged. The shape of this scalar matters less than the gradient it produces — see PyTorch demo next.

EXECUTION STATE
Final output =
L_rul^MSE  = 47.3750    L_rul^WMSE = 84.1150
lambda_fixed=(0.5000, 0.5000)
lambda_GABA =(0.001431, 0.998569)

Cell A  (fixed   x standard) = 23.9910
Cell B  (fixed   x weighted) = 42.3610
Cell C  (GABA    x standard) = 0.6738
Cell D  (GRACE = GABA x WMSE)= 0.7264
```
 1  """GRACE separation of concerns: the four cells of the 2x2 grid.
 2
 3  Outer axis = per-task weighting   { fixed,           GABA-adaptive }
 4  Inner axis = per-sample RUL shape { 1,                w(y) failure-biased }
 5  """
 6
 7  import numpy as np
 8
 9  # Toy mini-batch of 8 samples
10  y_true = np.array([10, 30, 60, 90, 110,  5, 15, 80], dtype=float)
11  y_pred = np.array([15, 26, 68, 87, 112, -7, 24, 74], dtype=float)
12
13
14  # ---------- Inner axis: two RUL loss shapes ----------
15  def rul_loss_standard(y_pred, y_true):
16      """Plain MSE. Every sample weighted equally."""
17      return ((y_pred - y_true) ** 2).mean()
18
19
20  def rul_loss_weighted(y_pred, y_true, max_rul=125.0):
21      """Failure-biased MSE (paper: grace/core/weighted_mse.py)."""
22      w = 1.0 + np.clip(1.0 - y_true / max_rul, 0.0, 1.0)
23      return (w * (y_pred - y_true) ** 2).mean()
24
25
26  L_rul_std    = rul_loss_standard(y_pred, y_true)
27  L_rul_w      = rul_loss_weighted(y_pred, y_true)
28  L_health     = 0.6069   # cross-entropy from the same forward pass
29
30
31  # ---------- Outer axis: two task-weighting strategies ----------
32  lam_fixed_rul,    lam_fixed_h = 0.5, 0.5
33  g_rul, g_health = 26.4016, 0.037833              # see chapter 18 numbers
34  lam_gaba_rul    = g_health / (g_rul + g_health)
35  lam_gaba_h      = g_rul    / (g_rul + g_health)
36
37
38  # ---------- Compose the four cells of the grid ----------
39  def compose(lam_rul, lam_h, L_rul, L_h):
40      return lam_rul * L_rul + lam_h * L_h
41
42
43  L_A = compose(lam_fixed_rul, lam_fixed_h, L_rul_std, L_health)   # Baseline
44  L_B = compose(lam_fixed_rul, lam_fixed_h, L_rul_w,   L_health)   # AMNL
45  L_C = compose(lam_gaba_rul,  lam_gaba_h,  L_rul_std, L_health)   # GABA
46  L_D = compose(lam_gaba_rul,  lam_gaba_h,  L_rul_w,   L_health)   # GRACE
47
48
49  print(f"L_rul^MSE  = {L_rul_std:.4f}    L_rul^WMSE = {L_rul_w:.4f}")
50  print(f"lambda_fixed=({lam_fixed_rul:.4f}, {lam_fixed_h:.4f})")
51  print(f"lambda_GABA =({lam_gaba_rul:.6f}, {lam_gaba_h:.6f})")
52  print()
53  print(f"Cell A  (fixed   x standard) = {L_A:.4f}")
54  print(f"Cell B  (fixed   x weighted) = {L_B:.4f}")
55  print(f"Cell C  (GABA    x standard) = {L_C:.4f}")
56  print(f"Cell D  (GRACE = GABA x WMSE)= {L_D:.4f}")
```
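Sanity check on the diagonal. Cells A and D differ in both axes; A→B and A→C each differ in one. The algebraic identity $(L_D - L_C) - (L_B - L_A) = (\lambda^{\text{GABA}}_{\text{rul}} - 0.5)\,(L^{\text{WMSE}}_{\text{rul}} - L^{\text{MSE}}_{\text{rul}})$ is what 'orthogonal' means in numbers: the interaction between the axes (how much the inner-axis effect changes when the outer axis flips) is exactly the product of the outer-axis change in $\lambda_{\text{rul}}$ and the inner-axis change in $L_{\text{rul}}$. The health terms cancel, and nothing else enters. With the printed values, $(0.7264 - 0.6738) - (42.3610 - 23.9910) = -18.3174$, matching $(0.001431 - 0.5)(84.1150 - 47.3750) \approx -18.3174$.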
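The identity can also be checked in a few lines, reusing the scalar losses and gradient norms quoted in the demo above. `compose2` is an illustrative two-argument variant of `compose` that assumes λ_health = 1 − λ_rul, which holds for both weightings in this section.

```python
import numpy as np

L_rul_std, L_rul_w, L_health = 47.375, 84.115, 0.6069  # values from the NumPy demo
lam_gaba = 0.037833 / (26.4016 + 0.037833)            # GABA weight on RUL


def compose2(lam_rul, L_rul, L_h):
    """Illustrative K = 2 composition assuming lam_health = 1 - lam_rul."""
    return lam_rul * L_rul + (1.0 - lam_rul) * L_h


L_A = compose2(0.5, L_rul_std, L_health)
L_B = compose2(0.5, L_rul_w, L_health)
L_C = compose2(lam_gaba, L_rul_std, L_health)
L_D = compose2(lam_gaba, L_rul_w, L_health)

interaction = (L_D - L_C) - (L_B - L_A)             # inner effect under GABA minus under fixed
product = (lam_gaba - 0.5) * (L_rul_w - L_rul_std)  # outer change times inner change
print(f"{interaction:.4f}  {product:.4f}")          # -18.3174  -18.3174
```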

PyTorch: The Paper's Composition In Six Lines

Same algebra, but now imported from grace/core/ directly — moderate_weighted_mse_loss for the inner axis and compute_task_grad_norm for the outer axis. No re-implementation. The four cells condense to four single-line compositions on lines 54–57.

The same composition with PyTorch autograd
grace_two_axes_pytorch.py
Line 1: Module docstring (contract for the PyTorch sibling)

States the contract: this script imports the paper's production functions verbatim and reproduces the four cells. Same algebra as the NumPy demo, but now every loss is a 0-dim torch.Tensor that .backward() can flow through. The reproducibility claim is that on identical inputs the four cell values match the NumPy version to 4 decimal places.

EXECUTION STATE
→ why two demos? = NumPy version makes the algebra visible and the values printable; PyTorch version proves the same four cells exist inside a real autograd graph and can drive optimiser steps.
→ 'same forward' = All four cells reuse the SAME forward pass — y_pred, hp_logits, L_rul_*, L_health, g_*. This is the GABA pre-condition: the gradient norms must come from the same computational graph as the loss being weighted.
Line 8: import torch

Core PyTorch. Provides the Tensor class, the autograd engine, the device abstraction (CPU/GPU), random number generators, and almost every primitive used by neural-network code.

EXECUTION STATE
📚 torch = Library namespace. Concretely we use: torch.tensor (build a Tensor from data), torch.randn (sample N(0,1)), torch.manual_seed (pin RNG), torch.relu (in the model), and Tensor methods .item() (→ Python float) and .detach() (strip autograd).
→ Tensor vs ndarray = torch.Tensor is like np.ndarray but (a) tracks gradients via grad_fn, (b) lives on a device (CPU/GPU), (c) supports autograd. For our purposes, you can read 'Tensor' as 'ndarray that knows how to compute its own derivative'.
Line 9: import torch.nn as nn

Layer primitives. nn provides STATEFUL building blocks — modules that own learnable parameters. We use nn.Module (the base class for our TinyDualHead) and nn.Linear (each of the three projections).

EXECUTION STATE
📚 torch.nn = Submodule of torch. Holds Module, Parameter, Linear, Conv2d, LSTM, etc. The 'nn' alias is universal in PyTorch code so we write nn.Linear instead of torch.nn.Linear.
📚 nn.Module — base class = Inheriting from nn.Module gives you (1) automatic parameter registration when you do self.x = nn.Parameter(...) or self.x = nn.Linear(...), (2) .parameters() / .named_parameters() iterators, (3) .to(device) recursive movement, (4) .train()/.eval() mode switching, (5) state_dict serialisation.
📚 nn.Linear(in, out, bias=True) = A learnable affine layer. Stores weight tensor W of shape (out, in) and (optional) bias b of shape (out,). Forward: y = x @ W.T + b. Used three times below for backbone, RUL head, health head.
Line 10: import torch.nn.functional as F

Stateless functional API. F.* functions take inputs explicitly (no hidden state) — contrast with nn.* which OWNS parameters. We use F.mse_loss and F.cross_entropy because they have no parameters of their own.

EXECUTION STATE
📚 torch.nn.functional = Submodule. Aliased as F by convention. Contains the same operations as nn.* but without state — e.g. F.linear(x, W, b) vs nn.Linear(in, out). Use nn.* when the operation owns weights, F.* when it doesn't.
📚 F.mse_loss(input, target, reduction='mean') = Pure function: returns ((input - target)**2).mean() by default. With reduction='sum' it skips the mean; with 'none' it returns per-sample errors. Equivalent to NumPy demo line 17 when reduction='mean'. Returns a 0-dim tensor.
📚 F.cross_entropy(logits, target, reduction='mean') = Combined log-softmax + NLL in one call. input shape (B, C) — raw logits, no need to apply softmax first. target shape (B,) — int64 class indices. Returns a 0-dim tensor. The 'fused' kernel is more numerically stable than computing softmax then NLL.
→ why not nn.MSELoss? = nn.MSELoss is a Module wrapper around F.mse_loss with no parameters. Either works; using F here keeps the code lean — no extra object to construct.
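What "combined log-softmax + NLL" buys can be sketched in NumPy; this illustrates the idea and is not PyTorch's actual kernel. Shifting the logits by their row maximum before exponentiating (the log-sum-exp trick) keeps the loss finite where softmax-then-log overflows, and a uniform 3-class output gives ln 3, the reference point used for L_health in the NumPy walkthrough.

```python
import numpy as np


def cross_entropy_fused(logits, target):
    """Mean CE via a max-shifted log-softmax (numerically stable)."""
    z = logits - logits.max(axis=1, keepdims=True)             # largest logit becomes 0
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(target)), target].mean())


def cross_entropy_naive(logits, target):
    """Softmax first, then log: overflows for large logits."""
    p = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    return float(-np.log(p[np.arange(len(target)), target]).mean())


target = np.array([0, 2])                                       # int class indices
small = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
print(cross_entropy_fused(small, target), cross_entropy_naive(small, target))  # agree

print(cross_entropy_fused(np.zeros((1, 3)), np.array([1])), np.log(3))         # both ln 3

with np.errstate(over="ignore", invalid="ignore"):
    big = small * 1000.0                                        # exp(3000) overflows to inf
    print(cross_entropy_fused(big, target), cross_entropy_naive(big, target))  # finite vs nan
```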
12  from grace.core.weighted_mse import moderate_weighted_mse_loss

Paper code, imported rather than copied. moderate_weighted_mse_loss is the failure-biased MSE used throughout GRACE, with the same closed form as our NumPy rul_loss_weighted.

EXECUTION STATE
📚 moderate_weighted_mse_loss(pred, target, max_rul=125.0) = Returns a 0-dim tensor: (1 + clip(1 - target/max_rul, 0, 1)) * (pred-target)^2 then .mean(). Source: grace/core/weighted_mse.py:20.
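To make the weight formula concrete, here is a pure-Python re-implementation of the closed form quoted above, applied to this section's y_true. It is a sketch for checking numbers by hand, not the paper's code. The helper names failure_weight and weighted_mse are ours.

```python
def failure_weight(target, max_rul=125.0):
    """w(y) = 1 + clip(1 - y / max_rul, 0, 1): 2.0 at RUL 0, 1.0 at RUL >= max_rul."""
    return 1.0 + min(max(1.0 - target / max_rul, 0.0), 1.0)


def weighted_mse(pred, target, max_rul=125.0):
    """Batch mean of w(y_j) * (pred_j - y_j) ** 2."""
    terms = [failure_weight(y, max_rul) * (p - y) ** 2 for p, y in zip(pred, target)]
    return sum(terms) / len(terms)


y_true = [10., 30., 60., 90., 110., 5., 15., 80.]
weights = [failure_weight(y) for y in y_true]

print([round(w, 2) for w in weights])   # [1.92, 1.76, 1.52, 1.28, 1.12, 1.96, 1.88, 1.36]
print(round(sum(weights) / len(weights), 2))              # 1.6
print(round(weighted_mse([0.0] * len(y_true), y_true), 2))  # 5043.0 for an all-zero prediction
```

The near-failure targets (RUL 5 and 10) get weights close to 2, while RUL 110 gets only 1.12.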
13  from grace.core.gradient_utils import compute_task_grad_norm

Paper code. Computes ||dL/dtheta_shared||_2 without writing to .grad. Used for the OUTER axis (GABA closed form).

EXECUTION STATE
📚 compute_task_grad_norm(loss, shared_params, retain_graph=True) = torch.autograd.grad with create_graph=False, sum-of-squared per-parameter gradients, then sqrt. Same algorithm walked through in chapter 18 §1.
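The last step of that description ("sum of squares, then sqrt") is easy to get wrong. The result is the L2 norm of all shared gradients concatenated into one vector, not the sum of per-parameter norms.

The sketch below shows only that reduction. It uses plain Python with nested lists standing in for gradient tensors, and the gradient values are made up.

```python
import math


def flatten(grad):
    """Recursively flatten a nested list that stands in for a gradient tensor."""
    if isinstance(grad, (int, float)):
        return [float(grad)]
    return [value for item in grad for value in flatten(item)]


def global_grad_norm(per_param_grads):
    """L2 norm over ALL entries of ALL parameter gradients (one concatenated vector)."""
    total = sum(v * v for g in per_param_grads for v in flatten(g))
    return math.sqrt(total)


# Made-up gradients for a 2x2 weight matrix and a length-2 bias.
grads = [[[3.0, 0.0], [0.0, 0.0]], [0.0, 4.0]]
per_param = [global_grad_norm([g]) for g in grads]

print(per_param)                # [3.0, 4.0]
print(global_grad_norm(grads))  # 5.0, not 3.0 + 4.0
```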
15  torch.manual_seed(0)

Pin the global PyTorch CPU PRNG so the random batch (line 36) and the random weight initialisations (line 31) are identical on every run. Without this the per-task gradient norms drift and the four cells stop being directly comparable across runs.

EXECUTION STATE
📚 torch.manual_seed(seed: int) = Sets the seed for torch's CPU random generator AND the default CUDA generator if a GPU is visible. After this, every subsequent torch.randn / torch.rand / nn.init call produces a deterministic sequence.
⬇ arg: 0 = Any int works; 0 is just a convention. Different seeds give different randomness — change it only when you want a different sample of the random distribution.
→ not the same as numpy seed = torch's RNG is independent of np.random. Setting torch seed does NOT affect np.random.randn(). For full determinism in mixed code you'd also need np.random.seed(0).
→ why seed for GABA? = The OUTER axis (lambda_GABA) depends on g_rul / g_health which depends on init weights. Without seeding, every run reports different lambdas and the four-cell ALGEBRA can't be checked across runs.
19  class TinyDualHead(nn.Module):

Minimum viable dual-task model: one shared backbone + two heads. Mirrors grace/models/dual_task_model.py at toy scale. The class definition itself just declares the type — no allocation happens until line 31's TinyDualHead() call.

EXECUTION STATE
📚 class C(BaseClass): = Standard Python class definition. Any class that wants to participate in PyTorch's nn.Module machinery (parameter registration, .train()/.eval(), state_dict) must inherit from nn.Module.
📚 nn.Module — what inheritance gives us = Three things we use here: (1) self.x = nn.Linear(...) automatically registers x's params for the optimiser, (2) model(x) syntax — Module overrides __call__ to invoke forward() with hooks, (3) named_parameters() recursion through children.
→ why 'TinyDualHead'? = The 'Dual' refers to two output heads (RUL regression + health classification). The 'Tiny' is to keep this demo running in milliseconds. Real grace.models.dual_task_model.GraceModel has a Bi-LSTM backbone with ~200k params.
20  def __init__(self):

Constructor: runs once when TinyDualHead() is invoked. Its only job is to allocate the three sub-modules and assign them to self.*. nn.Module.__setattr__ then registers them as child modules.

EXECUTION STATE
⬇ input: self = The freshly-allocated TinyDualHead instance. Python passes it implicitly when you write TinyDualHead(). At entry it has no attributes yet — they get set on lines 21-24.
→ no other args? = Most production models take config kwargs (hidden_dim, n_layers, dropout). For this toy demo all dimensions are hard-coded so __init__ stays parameter-free.
21  super().__init__()

Required nn.Module bookkeeping. Calls nn.Module.__init__(self), which initialises the internal dicts (_parameters, _modules, _buffers) that Module.__setattr__ uses to track parameters and child modules. Skip this line and the first sub-module assignment fails with AttributeError: 'cannot assign module before Module.__init__() call'.

EXECUTION STATE
📚 super() = Python builtin: returns a proxy object that delegates method calls to the parent class (here, nn.Module). super().__init__() is equivalent to nn.Module.__init__(self) but doesn't hard-code the parent name — survives refactors.
→ what __init__ sets up = Sets self._parameters = OrderedDict(), self._buffers = OrderedDict(), self._modules = OrderedDict(), self.training = True. After this, self.x = nn.Linear(...) goes into self._modules['x'] via Module.__setattr__.
22  self.backbone = nn.Linear(4, 6)

Shared trunk. Reads from any of the 4 input features and emits a 6-dimensional latent that BOTH heads consume.

EXECUTION STATE
📚 nn.Linear(in, out) = Stores W (out, in) and b (out). Forward: y = x @ W.T + b. Total params here = 6*4 + 6 = 30.
→ shared by design = The OUTER axis (per-task weighting) operates on gradients THROUGH this module. Both heads write back into self.backbone's parameters during .backward().
23  self.rul_head = nn.Linear(6, 1)

Regression head: 6 → 1 RUL prediction. Reads from the 6-D backbone latent, projects to a single scalar per sample. NOT shared — only the RUL loss gradient flows here.

EXECUTION STATE
📚 nn.Linear(6, 1) = Affine layer with W shape (1, 6) — 1 output × 6 inputs = 6 weight params, plus 1 bias. Forward: y = x @ W.T + b → scalar per sample.
⬇ arg 1: in_features = 6 = Must match the backbone's out_features. Mismatched dims would raise a runtime shape error inside forward().
⬇ arg 2: out_features = 1 = One scalar RUL prediction per sample. After self.rul_head(feat) the shape is (8, 1); .squeeze(-1) on line 28 collapses it to (8,) for F.mse_loss.
→ name 'rul_head' = The substring 'head' is what line 32 uses to EXCLUDE this layer from the shared-parameter list. It follows the same convention as get_shared_params in grace/core/gradient_utils.py. If you renamed this layer 'rul_output', the filter on line 32 would pull rul_output's params into `shared`, and GABA's gradient norms would be wrong.
24  self.health_head = nn.Linear(6, 3)

Classification head: 6 → 3 logits (Normal / Degrading / Critical). Like rul_head but emits 3 values per sample (one logit per class). Cross-entropy then does its own softmax internally.

EXECUTION STATE
📚 nn.Linear(6, 3) = Affine layer with W shape (3, 6) → 18 weight params + 3 bias = 21 total. Forward: y = x @ W.T + b → 3 logits per sample.
⬇ arg 1: in_features = 6 = Same 6-D shared backbone latent feeds BOTH heads. This is the architectural reason GABA's gradient norms compose on the SAME shared parameters.
⬇ arg 2: out_features = 3 = Three classes: Normal / Degrading / Critical. The output (8, 3) tensor is raw logits — F.cross_entropy applies softmax internally for numerical stability.
→ name 'health_head' = Substring 'head' again — same exclusion rule. Both heads filtered out of `shared` so only the backbone params drive GABA's gradient-norm comparison.
26  def forward(self, x):

Single forward returns a tuple — RUL prediction + health logits — both produced from the SAME backbone activation.

EXECUTION STATE
⬇ input: x (8, 4) = 8 random samples of dimension 4.
⬆ returns = Tuple (Tensor (8,), Tensor (8, 3)).
27  feat = torch.relu(self.backbone(x))

Linear → ReLU. Adds a non-linearity so the backbone and a head cannot collapse into a single linear map.

EXECUTION STATE
📚 torch.relu(t) = Element-wise max(0, t). Cheap, well-conditioned. Standard backbone non-linearity.
feat = Tensor (8, 6). Shared latent that BOTH heads will consume — this is the technical reason the two losses share gradients on the backbone.
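28  return self.rul_head(feat).squeeze(-1), self.health_head(feat)

Apply each head to the SAME feat tensor. The .squeeze(-1) collapses (8, 1) → (8,) so the prediction has exactly y_true's shape. Without it, F.mse_loss would broadcast (8, 1) against (8,) to (8, 8), emit a UserWarning about the size mismatch, and average the wrong errors.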

EXECUTION STATE
📚 .squeeze(dim) = Drops a size-1 dimension. .squeeze(-1) drops the trailing 1 in (8, 1) → (8,).
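→ why one feat? = A single backbone call puts both losses on ONE autograd graph. That way g_rul and g_health are measured on the same activations, and the forward cost is paid once. Running the backbone twice would double that cost. With stochastic layers such as dropout, the two gradient norms would also no longer describe the same forward pass.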
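Why the .squeeze(-1) on line 28 matters can be checked without PyTorch. The plain-Python sketch below compares the aligned per-sample error with what an (8, 1) vs (8,) broadcast computes. In the broadcast case every prediction is paired with every target, so even a perfect prediction scores a large loss.

```python
y_true = [10., 30., 60., 90., 110., 5., 15., 80.]
y_pred = list(y_true)  # a perfect prediction, value for value

# Matching shapes (8,) and (8,): sample i is compared with target i only.
aligned = sum((p - y) ** 2 for p, y in zip(y_pred, y_true)) / len(y_true)

# Shapes (8, 1) and (8,) broadcast to (8, 8): every prediction meets every target.
broadcast = sum((p - y) ** 2 for p in y_pred for y in y_true) / len(y_true) ** 2

print(aligned)    # 0.0
print(broadcast)  # 2862.5
```

For a perfect prediction, the broadcast value is twice the population variance of y_true: 2 × (3931.25 − 50²) = 2862.5.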
31  model = TinyDualHead()

Instantiate. Default initialisation produces small-magnitude weights — typical training start.

EXECUTION STATE
model = TinyDualHead with 30 + 7 + 21 = 58 trainable parameters.
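The 58 can be re-derived from the nn.Linear storage rule: W of shape (out, in) plus a bias of shape (out,). Below is a short pure-Python check. In PyTorch itself, sum(p.numel() for p in model.parameters()) gives the same count.

```python
def linear_params(in_features, out_features, bias=True):
    """nn.Linear stores W of shape (out, in) plus an optional bias of shape (out,)."""
    return out_features * in_features + (out_features if bias else 0)


layers = {"backbone": (4, 6), "rul_head": (6, 1), "health_head": (6, 3)}
counts = {name: linear_params(i, o) for name, (i, o) in layers.items()}

print(counts)                # {'backbone': 30, 'rul_head': 7, 'health_head': 21}
print(sum(counts.values()))  # 58
```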
32  shared = [p for n, p in model.named_parameters() if "head" not in n]

Filter by name: keep only parameters whose path does NOT contain 'head'. Matches the policy of grace/core/gradient_utils.py:get_shared_params.

EXECUTION STATE
📚 .named_parameters() = nn.Module method. Yields (str, nn.Parameter) for every parameter, recursive.
→ example output (parameters shown by shape) =
('backbone.weight', (6, 4))
('backbone.bias', (6,))
('rul_head.weight', (1, 6))  <-- excluded
('rul_head.bias', (1,))  <-- excluded
('health_head.weight', (3, 6))  <-- excluded
('health_head.bias', (3,))  <-- excluded
shared = List of 2 nn.Parameter: backbone.weight (6, 4), backbone.bias (6,).
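The name filter on line 32 is plain string matching, so it can be exercised on the parameter names alone. The sketch below replays the example output above. It then shows the rename hazard from the rul_head notes: a head called 'rul_output' (a hypothetical name) slips into the shared list.

```python
param_shapes = {
    "backbone.weight": (6, 4), "backbone.bias": (6,),
    "rul_head.weight": (1, 6), "rul_head.bias": (1,),
    "health_head.weight": (3, 6), "health_head.bias": (3,),
}


def shared_names(names):
    """Same substring rule as line 32: keep names that do not contain 'head'."""
    return [n for n in names if "head" not in n]


print(shared_names(param_shapes))  # ['backbone.weight', 'backbone.bias']

# Hypothetical rename: the RUL head is now called 'rul_output'.
renamed = [n.replace("rul_head", "rul_output") for n in param_shapes]
print(shared_names(renamed))
# ['backbone.weight', 'backbone.bias', 'rul_output.weight', 'rul_output.bias']
```

Selecting by module, for example list(model.backbone.parameters()), avoids depending on naming conventions. The cost is listing every shared module explicitly.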
36  x = torch.randn(8, 4)

Random batch of 8 samples, 4 features each.

EXECUTION STATE
📚 torch.randn(*size) = Sample from N(0, 1). Independent of NumPy's RNG.
37  y_true = torch.tensor([10., 30., 60., 90., 110., 5., 15., 80.])

Same RUL targets as the NumPy demo so we can directly compare numbers.

EXECUTION STATE
y_true (8,) = [ 10., 30., 60., 90., 110., 5., 15., 80.]
38  hp_target = torch.tensor([0, 1, 2, 0, 2, 0, 0, 2])

Health labels in {0, 1, 2}: class indices, NOT one-hot vectors. F.cross_entropy expects integer class indices and picks out each target's log-probability internally. The dtype is inferred as int64 (torch.long) because every list element is a Python int.

EXECUTION STATE
📚 torch.tensor(data) = Builds a Tensor from a Python list/tuple/scalar. dtype is auto-inferred: all-int input → torch.int64, mixed/float input → torch.float32. To force a dtype use torch.tensor(data, dtype=torch.long).
⬇ arg: [0, 1, 2, 0, 2, 0, 0, 2] = 8 class labels, one per sample, matching y_true positionally. Sample 0 has RUL=10 (near failure) yet is labelled class 0 (Normal): labels and RUL are not aligned in this toy demo.
hp_target (8,) = [0, 1, 2, 0, 2, 0, 0, 2] (torch.int64). Class distribution: four 0s, one 1, three 2s.
→ why int64? = Class-index targets must be int64 (torch.long). F.cross_entropy also accepts a float target of shape (B, C) (PyTorch 1.10+), but it reads that as class probabilities, not indices. torch.tensor with all-int data gives int64 automatically, so no conversion is needed.
40  y_pred, hp_logits = model(x)

Run the forward pass ONCE. Build the full autograd graph from leaf parameters → feat → both heads → both losses.

EXECUTION STATE
y_pred = Tensor (8,) — the RUL predictions for this batch.
hp_logits = Tensor (8, 3) — unnormalised class scores.
41  L_rul_std = F.mse_loss(y_pred, y_true)

Inner-axis: standard MSE. PyTorch equivalent of the NumPy rul_loss_standard.

EXECUTION STATE
📚 F.mse_loss(input, target, reduction='mean') = Scalar reduction by default. Returns a 0-dim tensor whose .backward() distributes 1/N over each squared residual.
L_rul_std = 0-dim tensor. Numerically depends on the random init. On this seed it's a few thousand because y_pred is small but y_true ranges up to 110.
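Here is a back-of-envelope check of that "few thousand". It assumes an untrained RUL head predicts roughly zero and an untrained health head emits roughly uniform logits. Both are assumptions, not measured values.

```python
import math

y_true = [10., 30., 60., 90., 110., 5., 15., 80.]
y_pred = [0.0] * len(y_true)  # assumption: untrained head outputs about 0

L_rul_std = sum((p - y) ** 2 for p, y in zip(y_pred, y_true)) / len(y_true)
L_health = math.log(3)  # cross-entropy of uniform logits over 3 classes
L_A = 0.5 * L_rul_std + 0.5 * L_health

print(L_rul_std)      # 3931.25
print(round(L_A, 2))  # 1966.17
```

With this init, cell A is dominated by the RUL term. That is consistent with the RUL gradient norm swamping the health one later on.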
42  L_rul_w = moderate_weighted_mse_loss(y_pred, y_true, max_rul=125.0)

Inner axis: failure-biased MSE. Calls the paper's implementation directly, with no re-implementation. Returns a 0-dim tensor with the same autograd hooks.

EXECUTION STATE
📚 moderate_weighted_mse_loss(pred, target, max_rul) = From grace/core/weighted_mse.py:20. Body: pred_flat=pred.view(-1); target_flat=target.view(-1); w = 1.0 + torch.clamp(1.0 - target_flat / max_rul, 0, 1.0); return (w * (pred_flat - target_flat) ** 2).mean().
→ autograd shape = The weights w are computed from y_true ONLY and do not require grad, so dL_rul_w/dy_pred = 2 w (y_pred - y_true) / N. Only the per-sample scaling changes. The graph has the same structure as plain MSE, plus one elementwise multiply by a constant.
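The derivative quoted above can be verified numerically. The pure-Python sketch below:

- re-implements the weighted loss (same closed form as the paper helper; the function names are ours);
- computes 2 w (y_pred − y_true) / N analytically;
- compares the result with central finite differences at arbitrary predictions.

```python
def failure_weight(y, max_rul=125.0):
    """w(y) = 1 + clip(1 - y / max_rul, 0, 1); depends on the target only."""
    return 1.0 + min(max(1.0 - y / max_rul, 0.0), 1.0)


def weighted_mse(pred, target, max_rul=125.0):
    return sum(failure_weight(y, max_rul) * (p - y) ** 2
               for p, y in zip(pred, target)) / len(target)


def analytic_grad(pred, target, max_rul=125.0):
    """dL/dpred_j = 2 * w_j * (pred_j - y_j) / N."""
    n = len(target)
    return [2.0 * failure_weight(y, max_rul) * (p - y) / n for p, y in zip(pred, target)]


def numeric_grad(pred, target, eps=1e-6):
    """Central finite differences, one prediction at a time."""
    grads = []
    for j in range(len(pred)):
        up = list(pred)
        up[j] += eps
        down = list(pred)
        down[j] -= eps
        grads.append((weighted_mse(up, target) - weighted_mse(down, target)) / (2 * eps))
    return grads


y_true = [10., 30., 60., 90., 110., 5., 15., 80.]
y_pred = [1.0, -2.0, 0.5, 3.0, 0.0, 1.5, -1.0, 2.0]  # arbitrary predictions

max_err = max(abs(a - b) for a, b in zip(analytic_grad(y_pred, y_true),
                                          numeric_grad(y_pred, y_true)))
print(max_err < 1e-4)  # True
```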
43  L_health = F.cross_entropy(hp_logits, hp_target)

Health loss. Held the SAME for cells A-D in this section because we are isolating the inner+outer axes ON THE RUL SIDE.

EXECUTION STATE
📚 F.cross_entropy(input, target) = Combined log-softmax + NLL. input shape (B, C), target int64 shape (B,).
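For intuition about the magnitude of L_health, here is a pure-Python version of the log-softmax + NLL computation that F.cross_entropy fuses. It uses the max-subtraction trick for stability and is a sketch, not PyTorch's kernel. With all-zero logits, which approximate an untrained head, the loss is ln 3 regardless of the labels.

```python
import math


def cross_entropy(logits, targets):
    """Mean of logsumexp(z) - z[t] over the batch (log-softmax + NLL)."""
    total = 0.0
    for z, t in zip(logits, targets):
        m = max(z)  # subtract the max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
        total += log_sum_exp - z[t]
    return total / len(targets)


hp_target = [0, 1, 2, 0, 2, 0, 0, 2]
uniform = [[0.0, 0.0, 0.0] for _ in hp_target]  # approximates an untrained head

print(round(cross_entropy(uniform, hp_target), 4))   # 1.0986, i.e. ln 3
print(cross_entropy([[10.0, 0.0, 0.0]], [0]) < 1e-3)  # True: confident and correct
```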
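47  g_rul = compute_task_grad_norm(L_rul_std, shared, retain_graph=True)

OUTER axis, half 1: gradient norm of L_rul on the shared backbone. retain_graph=True is critical: it keeps the autograd graph alive so the next call can reuse the SAME forward pass. The norm is measured on the standard-MSE loss, so cells C and D below share one pair of lambdas. Pitfall 1 discusses what changes when GABA reads the shaped loss instead.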

EXECUTION STATE
g_rul = 0-dim tensor. Real value depends on init; what matters here is that it is computed against EXACTLY the same shared params as g_health.
→ why retain_graph=True = Without it, the graph is freed after this single backward, and the next compute_task_grad_norm raises 'Trying to backward through the graph a second time'.
48  g_health = compute_task_grad_norm(L_health, shared, retain_graph=True)

OUTER axis, half 2: gradient norm of L_health. Same forward pass — that is the pivotal property GABA depends on.

EXECUTION STATE
g_health = 0-dim tensor. The ratio g_rul/g_health is what GABA inverts.
49  S = g_rul + g_health

K=2 normaliser used in the §17.3 closed form.

EXECUTION STATE
S = 0-dim tensor = g_rul + g_health.
50  lam_gaba_rul = (g_health / S).detach()

Closed form: the small-gradient task gets the big weight. .detach() severs the computation graph because lambda is a SCALAR FACTOR — not a learnable input. Letting autograd flow through it would propagate gradients into the gradient-norm calculation, which is meta-learning territory and wildly more expensive.

EXECUTION STATE
📚 .detach() = Returns a tensor sharing storage but stripped of grad_fn. Treated as a constant by autograd.
→ why detach? = We want d(lam * L_rul)/d(theta) = lam * dL_rul/dtheta, NOT lam * dL_rul/dtheta + L_rul * dlam/dtheta. The second term would couple the GABA controller into the optimiser's update, which is a different algorithm (closer to GradNorm).
51  lam_gaba_h = (g_rul / S).detach()

Other half. With GABA on this seed lam_gaba_h is essentially 1 because L_rul dominates the gradient norm by orders of magnitude.
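The K=2 closed form can be checked with plain arithmetic. In the sketch below, the gradient norms are made-up numbers chosen so the RUL norm dominates, roughly the regime described for this seed.

The last check shows what the closed form buys: after weighting, both tasks contribute equal gradient magnitude, because lambda_rul * g_rul = g_health * g_rul / S = lambda_h * g_health.

```python
import math


def gaba_lambdas(g_rul, g_health):
    """K=2 closed form: each task's weight is the OTHER task's share of S."""
    s = g_rul + g_health
    return g_health / s, g_rul / s


g_rul, g_health = 7000.0, 1.0  # hypothetical norms, RUL dominating
lam_rul, lam_h = gaba_lambdas(g_rul, g_health)

print(round(lam_rul, 6), round(lam_h, 6))              # 0.000143 0.999857
print(math.isclose(lam_rul + lam_h, 1.0))              # True
print(math.isclose(lam_rul * g_rul, lam_h * g_health))  # True: balanced magnitudes
```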

55  L_A = 0.5 * L_rul_std + 0.5 * L_health

Cell A. Plain MTL. Both lambdas are Python scalars — torch promotes them to 0-dim tensors automatically and the autograd graph stays intact through L_rul_std + L_health.

EXECUTION STATE
L_A = 0-dim tensor. Numerically matches the NumPy demo cell A to machine precision when y_pred is fed identical values.
56  L_B = 0.5 * L_rul_w + 0.5 * L_health

Cell B. Inner axis flipped: weighted MSE replaces standard MSE. Lambdas unchanged.

EXECUTION STATE
L_B = 0-dim tensor. L_B - L_A = 0.5 * (L_rul_w - L_rul_std).
57  L_C = lam_gaba_rul * L_rul_std + lam_gaba_h * L_health

Cell C. Outer axis flipped: GABA replaces fixed weighting. Inner axis still standard MSE.

EXECUTION STATE
L_C = 0-dim tensor. The detached lambdas multiply through autograd as constants — same algebra as line 47 of the NumPy demo.
58  L_D = lam_gaba_rul * L_rul_w + lam_gaba_h * L_health

Cell D — GRACE. Both axes engaged. This single line is the entire GRACE composition; the rest of the file is plumbing.

EXECUTION STATE
L_D = 0-dim tensor. .backward() now flows through (a) GABA-weighted task balance and (b) failure-biased per-sample weights.
→ orthogonality check = (L_D - L_C) - (L_B - L_A) = lam_gaba_rul * (L_rul_w - L_rul_std) - 0.5 * (L_rul_w - L_rul_std) = (lam_gaba_rul - 0.5) * (L_rul_w - L_rul_std). The mixed difference equals the product of the two axis-deltas — the exact algebraic signature of orthogonal axes.
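The identity above holds for any values, not just this seed. Below is a quick randomized check in plain Python, with arbitrary random stand-ins for the four losses and two lambdas. It also shows that lam_gaba_h drops out of the mixed difference entirely, because the health terms cancel.

```python
import random

random.seed(0)
checked = 0
for _trial in range(1000):
    lam_rul, lam_h = random.random(), random.random()  # arbitrary outer weights
    L_std, L_w, L_h = (random.uniform(0.0, 100.0) for _ in range(3))  # arbitrary losses

    L_A = 0.5 * L_std + 0.5 * L_h
    L_B = 0.5 * L_w + 0.5 * L_h
    L_C = lam_rul * L_std + lam_h * L_h
    L_D = lam_rul * L_w + lam_h * L_h

    mixed = (L_D - L_C) - (L_B - L_A)
    predicted = (lam_rul - 0.5) * (L_w - L_std)  # lam_h appears nowhere
    assert abs(mixed - predicted) < 1e-9
    checked += 1

print(f"identity held in {checked} random trials")
```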
60  print(f"L_A {L_A.item():.4f} L_B {L_B.item():.4f}")

Top row of the grid.

EXECUTION STATE
📚 .item() = Tensor → Python float. Only valid on 0-dim tensors. Strips autograd for printing.
61  print(f"L_C {L_C.item():.4f} L_D {L_D.item():.4f}")

Bottom row of the grid.

62  print(f"lambda_GABA = ({lam_gaba_rul.item():.6f}, {lam_gaba_h.item():.6f})")

Confirm the OUTER axis. The detached lambdas printed here are exactly what backward will multiply with each loss.

EXECUTION STATE
Output (illustrative) =
L_A 0.6034  L_B 0.6041
L_C 0.6018  L_D 0.6020
lambda_GABA = (0.000142, 0.999858)
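→ caveat = These printed values are illustrative placeholders, not the output of an actual run. With near-zero initial predictions, L_rul_std alone is in the low thousands (the mean of y_true² is 3931.25), so a real run prints far larger L_A and L_B. The exact numbers also depend on torch.manual_seed(0) and the RUL head's init. What reproduces is the four-cell ALGEBRA: every cell is a different OUTER x INNER product on the SAME forward pass.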
 1  """GRACE composition with the paper's production helpers.
 2
 3  Imports the EXACT functions from grace/core/weighted_mse.py and
 4  grace/core/gradient_utils.py, runs one forward pass, and prints the
 5  same four cells as the NumPy demo (values match only for identical y_pred).
 6  """
 7
 8  import torch
 9  import torch.nn as nn
10  import torch.nn.functional as F
11
12  from grace.core.weighted_mse import moderate_weighted_mse_loss
13  from grace.core.gradient_utils import compute_task_grad_norm
14
15  torch.manual_seed(0)
16
17
18  # ---------- Tiny dual-task model (8 samples, 4 features) ----------
19  class TinyDualHead(nn.Module):
20      def __init__(self):
21          super().__init__()
22          self.backbone     = nn.Linear(4, 6)
23          self.rul_head     = nn.Linear(6, 1)
24          self.health_head  = nn.Linear(6, 3)
25
26      def forward(self, x):
27          feat = torch.relu(self.backbone(x))
28          return self.rul_head(feat).squeeze(-1), self.health_head(feat)
29
30
31  model = TinyDualHead()
32  shared = [p for n, p in model.named_parameters() if "head" not in n]
33
34
35  # ---------- One forward pass ----------
36  x        = torch.randn(8, 4)
37  y_true   = torch.tensor([10., 30., 60., 90., 110., 5., 15., 80.])
38  hp_target = torch.tensor([0, 1, 2, 0, 2, 0, 0, 2])
39
40  y_pred, hp_logits = model(x)
41  L_rul_std = F.mse_loss(y_pred, y_true)
42  L_rul_w   = moderate_weighted_mse_loss(y_pred, y_true, max_rul=125.0)
43  L_health  = F.cross_entropy(hp_logits, hp_target)
44
45
46  # ---------- GABA closed form on the SAME forward ----------
47  g_rul    = compute_task_grad_norm(L_rul_std, shared, retain_graph=True)
48  g_health = compute_task_grad_norm(L_health,  shared, retain_graph=True)
49  S        = g_rul + g_health
50  lam_gaba_rul = (g_health / S).detach()
51  lam_gaba_h   = (g_rul    / S).detach()
52
53
54  # ---------- Compose: outer x inner ----------
55  L_A = 0.5            * L_rul_std + 0.5            * L_health
56  L_B = 0.5            * L_rul_w   + 0.5            * L_health
57  L_C = lam_gaba_rul   * L_rul_std + lam_gaba_h     * L_health
58  L_D = lam_gaba_rul   * L_rul_w   + lam_gaba_h     * L_health
59
60  print(f"L_A {L_A.item():.4f}  L_B {L_B.item():.4f}")
61  print(f"L_C {L_C.item():.4f}  L_D {L_D.item():.4f}")
62  print(f"lambda_GABA = ({lam_gaba_rul.item():.6f}, {lam_gaba_h.item():.6f})")
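Detach the lambdas. Lines 50–51 call .detach() on the GABA weights. Because compute_task_grad_norm uses create_graph=False, the norms already carry no autograd history, so .detach() acts as a safeguard. It guarantees the lambdas enter the optimisation as constants even if the helper is ever switched to create_graph=True. Without that guarantee, .backward() would propagate through the gradient-norm computation. That would turn GABA into a meta-gradient method with roughly 2× the memory and a different algorithm entirely.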
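The difference between a detached and a non-detached weight is just the product rule. The scalar toy below is not the GABA computation. toy_weight(theta) is a made-up function standing in for a weight that depends on the parameters.

Treating the weight as a constant keeps only lam * dL/dtheta. Letting it vary adds L * dlam/dtheta, which here doubles the gradient.

```python
def toy_loss(theta):
    """Toy loss L(theta)."""
    return (theta - 3.0) ** 2


def toy_loss_grad(theta):
    """Analytic dL/dtheta of the toy loss."""
    return 2.0 * (theta - 3.0)


def toy_weight(theta):
    """Hypothetical weight that depends on theta (stands in for a norm ratio)."""
    return 1.0 / (1.0 + theta ** 2)


theta, eps = 1.0, 1e-6

# Detached: the weight is frozen at its current value; only lam * dL/dtheta survives.
detached = toy_weight(theta) * toy_loss_grad(theta)

# Not detached: differentiate the whole product, which adds L * dlam/dtheta.
def weighted(t):
    return toy_weight(t) * toy_loss(t)

full = (weighted(theta + eps) - weighted(theta - eps)) / (2 * eps)

print(round(detached, 4))  # -2.0
print(round(full, 4))      # -4.0
```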

Real Measurements On FD002

Toy losses on 8 samples make the algebra visible. Production training on 17,631 multi-condition windows (FD002, 5 seeds, 500 epochs each) confirms that the two axes really do compose without interference. The rows below come straight from data_analysis/cmapss_h256_complete_140.csv; the difference columns isolate each axis.

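| Method | Outer | Inner | FD002 RMSE | FD002 NASA | Δ vs Baseline |
|---|---|---|---|---|---|
| Baseline (A) | Fixed | MSE | 7.37 | 224.5 | (reference) |
| AMNL (B = +inner only) | Fixed | WMSE | 6.74 | 356.0 | RMSE −0.63, NASA +131.5 |
| GABA (C = +outer only) | GABA | MSE | 7.53 | 224.2 | RMSE +0.16, NASA −0.3 |
| GRACE (D = both axes) | GABA | WMSE | 7.72 | 223.4 | RMSE +0.35, NASA −1.1 |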
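The Δ column is each row minus Baseline. A few lines of Python reproduce it from the RMSE and NASA values in the table. This is a handy guard against transcription slips when the table is updated.

```python
fd002 = {  # (RMSE, NASA) per method, copied from the table
    "Baseline": (7.37, 224.5),
    "AMNL": (6.74, 356.0),
    "GABA": (7.53, 224.2),
    "GRACE": (7.72, 223.4),
}

base_rmse, base_nasa = fd002["Baseline"]
deltas = {
    name: (round(rmse - base_rmse, 2), round(nasa - base_nasa, 1))
    for name, (rmse, nasa) in fd002.items()
    if name != "Baseline"
}
print(deltas)
# {'AMNL': (-0.63, 131.5), 'GABA': (0.16, -0.3), 'GRACE': (0.35, -1.1)}
```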

Two structural observations. First, on the safety-critical NASA score, the OUTER axis dominates. Both adaptive methods (rows C and D) stay at Baseline level, near 224, while flipping only the inner axis (row B, AMNL) inflates NASA to 356. Second, on RMSE, the INNER axis dominates: AMNL cuts 0.63 cycles from Baseline by shaping alone. GRACE keeps NASA at Baseline level, avoiding AMNL's blow-up, and accepts a small RMSE cost. This is a deliberate Pareto choice the paper documents in chapter 23.

The Same Decomposition In Other Fields

The OUTER × INNER pattern is not unique to RUL. Any time a learner has multiple competing objectives and non-uniform sample importance, the same separation applies:

| Domain | Outer axis (per-task / per-objective) | Inner axis (per-sample) |
|---|---|---|
| Self-driving perception | Detection vs. depth vs. lane segmentation losses, balanced per scene | Higher weight on pedestrians and night-time frames vs. empty highway |
| Medical imaging (cancer detection) | Pixel-wise segmentation loss vs. patient-level classification loss | Higher weight on biopsy-confirmed positives near the malignant boundary |
| Recommender systems | Click-through rate head vs. dwell-time head vs. revenue head | Higher weight on cold-start users, where each impression is a rarer signal |
| Speech recognition | CTC alignment loss vs. attention decoder cross-entropy | Higher weight on rare words and disfluent speech segments |
| Climate downscaling | Temperature, precipitation, wind-speed targets | Higher weight on extreme events (heatwaves, hurricanes), which are minority samples |

In every row the OUTER controller answers ‘which objective matters more right now?’ and the INNER controller answers ‘which examples inside that objective deserve emphasis?’. The recipe in this section — pick one method per axis and compose — is the way to bring published gains in either column into a single training run.

Pitfalls When Composing Adaptation And Loss Shape

Pitfall 1: Letting the inner weight feed back into the outer norm

When you swap standard MSE for weighted MSE, the per-sample weights increase the magnitude of $\nabla_{\theta_s}\mathcal{L}_{\text{rul}}$ roughly by the average weight (here $\bar w \approx 1.5$). GABA SEES this and tightens $\lambda_{\text{rul}}$ further. Suppose you forget that the OUTER axis is reading a SHAPED gradient and compare its lambdas to the standard-MSE run: you will misread the controller. Always evaluate $\lambda^*_i$ trajectories on the actual composed loss, not on the unshaped one.
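Here is a rough arithmetic illustration of this pitfall, using the K=2 closed form and made-up norms. If shaping multiplies the RUL gradient norm by about the average weight, lambda_rul shrinks by about the same factor whenever g_rul dominates. The scaling by the average weight is itself the approximation stated above, not an exact rule.

```python
def gaba_lambdas(g_rul, g_health):
    """K=2 closed form: lambda_rul = g_health / S, lambda_health = g_rul / S."""
    s = g_rul + g_health
    return g_health / s, g_rul / s


g_health = 1.0
g_rul_mse = 200.0               # hypothetical RUL norm under standard MSE
w_bar = 1.5                     # average failure weight quoted in the text
g_rul_wmse = w_bar * g_rul_mse  # approximate norm once the loss is shaped

lam_mse, _ = gaba_lambdas(g_rul_mse, g_health)
lam_wmse, _ = gaba_lambdas(g_rul_wmse, g_health)

print(round(lam_mse, 5), round(lam_wmse, 5))  # 0.00498 0.00332
print(round(lam_mse / lam_wmse, 3))           # 1.498, close to w_bar
```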

Pitfall 2: Forgetting to detach the GABA lambdas

The PyTorch demo's lines 50–51 call .detach() on both lambdas. With create_graph=False inside compute_task_grad_norm, the norms carry no graph anyway. But setting create_graph=True without detaching produces a different algorithm, in which the gradient-norm controller is itself differentiated through. That is closer to GradNorm than to GABA: the wall-clock cost roughly doubles and the convergence behaviour changes.

Pitfall 3: Treating cell B (AMNL) as a strict subset of cell D (GRACE)

Cell D inherits the per-sample shape of cell B, but it does NOT inherit cell B's RMSE. The OUTER axis spends some of the accuracy budget buying NASA-score reductions. When stakeholders ask ‘why is GRACE's RMSE worse than AMNL's?’ the answer is in this section: the orthogonality is real, but each axis moves a different metric. Section 23.2 makes this Pareto trade quantitative.

Pitfall 4: Single-condition datasets

On FD003 (single condition, two faults) the gradient ratio is smaller, GABA's correction is smaller, and the OUTER axis has less to do. Adding the INNER axis on top can over-emphasise rare near-failure samples. Section 21.3 walks through the FD003 case where GRACE underperforms its own siblings — the orthogonality story is sound, but the magnitude of each axis's benefit depends on the dataset.

Takeaway

  • Multi-task losses have two independent control axes: an outer per-task weight $\lambda_i(t)$ and an inner per-sample weight $w(y_j)$.
  • GABA owns the outer axis (gradient-magnitude balance). Failure-biased MSE owns the inner axis (loss shape). They live on different indices and therefore compose without interference.
  • The 2×2 grid {Fixed, GABA} × {MSE, WMSE} enumerates four real methods. Cell A is the baseline, cell B is AMNL, cell C is GABA + standard MSE, cell D is GRACE.
  • On FD002 the OUTER axis recovers NASA-score (224 vs 356) and the INNER axis recovers RMSE (6.74 vs 7.37). GRACE keeps the OUTER's NASA gain at a small RMSE cost — a deliberate Pareto choice.
  • The composition fits in six lines of PyTorch: import the paper's helpers, do one forward pass, compute two gradient norms, detach the lambdas, multiply.