Walk into a recording studio and look at the console. Every channel has its own fader — vocal, kick drum, bass, guitar — and every channel also has its own equaliser. The faders ask which sources matter right now; the equalisers ask which frequencies inside each source matter. The engineer can re-balance the mix without touching the EQ, and re-EQ a single instrument without touching the faders. The two controls live on different axes.
GRACE is built on the same separation. GABA is the fader bank: it decides, every training step, how much of the gradient budget each task gets. Failure-biased weighted MSE is the EQ: it decides, within the RUL task, which samples carry the heaviest squared error. The contribution of this chapter, and of this section in particular, is to show that these two knobs commute. You can adapt the per-task weighting and shape the per-sample loss at the same time, with no interference, and the resulting algorithm is the model that achieves the best NASA score on multi-condition C-MAPSS.
The headline. Two orthogonal axes, the outer $\lambda_i(t)$ per task and the inner $w(y_j)$ per sample, produce a 2×2 grid of methods. Cell A is the plain baseline, cell B is AMNL-style, cell C is GABA + standard MSE, and cell D is GRACE.
Two Independent Axes Of A Multi-Task Loss
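Every multi-task loss in this book has the same skeleton:
$$L(t)=\sum_{i=1}^{K}\lambda_i(t)\,\frac{1}{N}\sum_{j=1}^{N}w(y_j)\,\ell_i(\hat y_{ij},y_{ij}),$$
where $\ell_i$ is task $i$'s per-sample loss and $w\equiv1$ for any task whose loss is left unshaped (in GRACE, the health task).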
Two indices, two roles. The outer index i ranges over tasks — RUL regression and health classification, in our case. The inner index j ranges over samples in the mini-batch. Methods in the literature differ purely in which of these indices they touch:
| Axis | What it weights | Examples | GRACE choice |
|---|---|---|---|
| Outer (per-task) | How much each task contributes to the combined gradient | Fixed (0.5/0.5), Uncertainty (Kendall et al.), GradNorm (Chen et al.), DWA (Liu et al.), GABA (this book) | GABA: $\lambda_i^*(t)=\mathrm{EMA}\!\left(\frac{g_j(t)}{g_i(t)+g_j(t)}\right)$, where $j\neq i$ is the other task |
| Inner (per-sample) | How much each sample inside a task contributes to that task's loss | Standard MSE (uniform), Asymmetric, Focal, Quantile, Failure-biased weighted MSE | Failure-biased: $w(y_j)=1+\mathrm{clip}(1-y_j/125,\,0,\,1)$ |
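The two GRACE choices in the table can be sketched in a few lines of NumPy. The $K=2$ ratio follows the table. The function names, the EMA smoothing coefficient `beta=0.9`, and starting the average at the first raw ratio are assumptions for illustration, because the section does not state them. The floor on $\lambda$ mentioned later is omitted.

```python
import numpy as np


def failure_weight(y, r_max=125.0):
    """Inner axis: w(y) = 1 + clip(1 - y / r_max, 0, 1), so 1 <= w <= 2."""
    y = np.asarray(y, dtype=float)
    return 1.0 + np.clip(1.0 - y / r_max, 0.0, 1.0)


def gaba_ratio(g_rul, g_health):
    """Outer axis, K = 2: each task's weight is the OTHER task's share of the total norm."""
    total = g_rul + g_health
    return np.array([g_health / total, g_rul / total])  # [lambda_rul, lambda_health]


def ema_update(lam_prev, lam_raw, beta=0.9):
    """Hypothetical EMA step; beta = 0.9 is an assumed value, not taken from the book."""
    if lam_prev is None:  # first step: start the average at the raw ratio
        return lam_raw
    return beta * lam_prev + (1.0 - beta) * lam_raw


g_rul, g_health = 26.4016, 0.037833  # gradient norms used later in this section
lam = ema_update(None, gaba_ratio(g_rul, g_health))
print(lam)                                # roughly [0.001431, 0.998569]
print(lam[0] * g_rul, lam[1] * g_health)  # equal up to rounding: g_rul*g_health/(g_rul+g_health)
print(failure_weight([0.0, 62.5, 125.0, 200.0]))  # weights 2.0, 1.5, 1.0, 1.0
```

The second print shows the intent of the inverse-ratio rule. After weighting, both tasks push the shared parameters with gradients of the same norm.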
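Because the two axes touch different indices, the combined formula factorises cleanly:
$$L(t)=\frac{1}{N}\sum_{i=1}^{K}\sum_{j=1}^{N}\big[\lambda_i(t)\,w(y_j)\big]\,\ell_i(\hat y_{ij},y_{ij}).$$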
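Inside the brackets the outer factor depends on $i$ only and the inner factor on $j$ only. Hold $\lambda_i(t)$ fixed within the step, and the gradient with respect to a backbone parameter $\theta_s$ decomposes into one outer-modulated term per task, $\partial L/\partial\theta_s=\sum_i\lambda_i(t)\,\partial L_i/\partial\theta_s$ with $L_i=\frac{1}{N}\sum_j w(y_j)\,\ell_i(\hat y_{ij},y_{ij})$. Replacing GABA does not require rewriting the inner weights, and replacing weighted MSE does not require rewriting the outer controller.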
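The decomposition can be checked numerically under simplifying assumptions. In this sketch, a hypothetical shared parameter vector `theta` feeds two toy regression tasks: a failure-weighted one standing in for RUL, and an unweighted one standing in for health. The analytic per-task gradients, combined with fixed outer weights, are compared against central finite differences of the combined loss. None of these names come from the GRACE code.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 8, 4
X = rng.normal(size=(N, d))              # toy inputs feeding the shared parameters
theta = rng.normal(size=d)               # hypothetical shared parameter vector theta_s
y_rul = rng.uniform(0.0, 150.0, size=N)  # toy RUL targets
y_aux = rng.normal(size=N)               # toy targets for a second task
w = 1.0 + np.clip(1.0 - y_rul / 125.0, 0.0, 1.0)  # inner axis: failure-biased weights


def task_losses(th):
    """Both tasks read the same shared projection z = X @ th."""
    z = X @ th
    l_rul = np.mean(w * (z - y_rul) ** 2)    # inner weights sit inside the mean
    l_aux = np.mean((0.1 * z - y_aux) ** 2)  # second task, unweighted
    return l_rul, l_aux


def task_grads(th):
    """Analytic gradient of each task loss with respect to th."""
    z = X @ th
    g_rul = (2.0 / N) * X.T @ (w * (z - y_rul))
    g_aux = (2.0 / N) * 0.1 * X.T @ (0.1 * z - y_aux)
    return g_rul, g_aux


lam_rul, lam_aux = 0.2, 0.8  # any outer weights, held fixed within the step


def combined(th):
    l_rul, l_aux = task_losses(th)
    return lam_rul * l_rul + lam_aux * l_aux


g_rul, g_aux = task_grads(theta)
decomposed = lam_rul * g_rul + lam_aux * g_aux  # one outer-modulated term per task

# Central differences are exact for a quadratic loss, up to rounding error.
eps = 1e-6
numeric = np.array([(combined(theta + eps * e) - combined(theta - eps * e)) / (2 * eps)
                    for e in np.eye(d)])
print(np.max(np.abs(numeric - decomposed)))  # tiny: only rounding error remains
```

Changing `lam_rul, lam_aux` leaves `w` untouched, and changing `w` leaves the outer weights untouched. This is the separation the paragraph describes.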
Why this matters in practice. Every published MTL paper picks one axis to attack. AMNL freezes the outer axis at 0.5/0.5 and shapes the inner axis. GradNorm and DWA shape the outer axis and leave the inner axis at uniform MSE. GRACE is the first method on this problem that engages both simultaneously. The question of this chapter is whether the two corrections compose cleanly. Numerically, on real data, the answer is yes for multi-condition datasets and a qualified "almost" for single-condition FD003 (section 21.3).
The Composition: Outer × Inner
Substitute the GRACE choices into the general form. The outer factor becomes the GABA closed form derived in section 17.3,
$$\lambda_i^*(t)=\mathrm{EMA}\!\left(\frac{g_j(t)}{g_i(t)+g_j(t)}\right),\qquad j\neq i,$$
and the inner factor becomes the failure-biased ramp,
$$w(y_j)=1+\mathrm{clip}\!\left(1-\frac{y_j}{R_{\max}},\,0,\,1\right),\qquad R_{\max}=125.$$
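Plugging both into the skeleton, the GRACE loss for $K=2$ tasks (RUL + health) is
$$L_{\mathrm{GRACE}}(t)=\lambda^*_{\mathrm{rul}}(t)\,\underbrace{\frac{1}{N}\sum_{j=1}^{N}w(y_j)\,(\hat y_j-y_j)^2}_{\text{weighted MSE on RUL}}+\lambda^*_{\mathrm{health}}(t)\,\underbrace{L_{\mathrm{CE}}}_{\text{cross-entropy on health}}.$$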
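Here is the GRACE loss written out from raw model outputs in NumPy. The cross-entropy uses the standard log-sum-exp form of the log-softmax for numerical stability. The function name and signature are illustrative only and are not the paper's `grace/core` API.

```python
import numpy as np


def grace_loss(lam_rul, lam_health, y_pred, y_true, logits, labels, r_max=125.0):
    """lam_rul * weighted-MSE(RUL) + lam_health * cross-entropy(health)."""
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    # Inner axis: failure-biased per-sample weights, 1 <= w <= 2.
    w = 1.0 + np.clip(1.0 - y_true / r_max, 0.0, 1.0)
    l_rul = np.mean(w * (y_pred - y_true) ** 2)

    # Cross-entropy from raw logits via a numerically stable log-softmax.
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    l_ce = -np.mean(log_probs[np.arange(len(labels)), labels])

    # Outer axis: per-task weights multiply whole task losses.
    return lam_rul * l_rul + lam_health * l_ce


y_true = [10, 30, 60, 90, 110, 5, 15, 80]
y_pred = [15, 26, 68, 87, 112, -7, 24, 74]
labels = np.array([0, 1, 2, 0, 2, 0, 0, 2])
# All-zero logits give a cross-entropy of exactly ln(3) for any 3-class labels.
result = grace_loss(0.5, 0.5, y_pred, y_true, np.zeros((8, 3)), labels)
print(result)
```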
Compare it term by term with its three siblings in the grid:
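| Cell | Outer (per-task) | Inner (per-sample RUL) | Method |
|---|---|---|---|
| A | $\lambda_{\mathrm{rul}}=\lambda_{\mathrm{health}}=0.5$ | $w(y)=1$ | Baseline (0.5/0.5) |
| B | $\lambda_{\mathrm{rul}}=\lambda_{\mathrm{health}}=0.5$ | $w(y)=1+\mathrm{clip}(1-y/125,\,0,\,1)$ | AMNL (0.5/0.5 + WMSE) |
| C | $\lambda^*_{\mathrm{rul}}=\mathrm{EMA}\!\left(\frac{g_h}{g_r+g_h}\right)$ | $w(y)=1$ | GABA + standard MSE |
| D | $\lambda^*_{\mathrm{rul}}=\mathrm{EMA}\!\left(\frac{g_h}{g_r+g_h}\right)$ | $w(y)=1+\mathrm{clip}(1-y/125,\,0,\,1)$ | GRACE |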
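The grid is the Cartesian product of the two axes, so all four cells can be generated mechanically. This sketch reuses the toy batch, `L_health`, and the gradient norms from the NumPy demo later in this section. The dictionary layout is simply one convenient choice. `itertools.product` visits the cells in the table's order: A, B, C, D.

```python
import itertools

import numpy as np

y_true = np.array([10, 30, 60, 90, 110, 5, 15, 80], dtype=float)
y_pred = np.array([15, 26, 68, 87, 112, -7, 24, 74], dtype=float)
L_health = 0.6069
g_rul, g_health = 26.4016, 0.037833

# Outer axis: (lambda_rul, lambda_health) per strategy.
outer = {
    "fixed": (0.5, 0.5),
    "GABA": (g_health / (g_rul + g_health), g_rul / (g_rul + g_health)),
}
# Inner axis: per-sample weight vector per loss shape.
inner = {
    "standard": np.ones_like(y_true),
    "weighted": 1.0 + np.clip(1.0 - y_true / 125.0, 0.0, 1.0),
}

cells = {}
for (o_name, (lam_rul, lam_h)), (i_name, w) in itertools.product(outer.items(), inner.items()):
    L_rul = np.mean(w * (y_pred - y_true) ** 2)  # inner axis shapes the RUL loss
    cells[(o_name, i_name)] = lam_rul * L_rul + lam_h * L_health  # outer axis weights tasks

for key, value in cells.items():
    print(key, round(float(value), 4))
```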
Interactive: The Two-Axes Grid
Click any cell below to inspect its outer and inner formulas independently, and to read the FD002 average performance from the paper's 5-seed runs at h=256. The two axis-highlight toggles isolate orthogonality empirically. With the "same loss shape (row)" toggle on, the highlighted cells share the inner axis and differ only in the outer one, so the entire NASA difference between them must come from the per-task weighting.
What to notice. Cell B (AMNL) wins RMSE (6.74) but loses NASA (356): the inner axis alone pulls hard on accuracy and ignores task balance. Cell C (GABA) comes within 0.8 of the best NASA (224.2) but does not exploit the per-sample asymmetry. Cell D (GRACE) sits at the Pareto sweet spot: NASA 223.4, the lowest of all four cells, with RMSE only 0.36 cycles worse than AMNL.
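We assume the NASA score quoted here is the standard asymmetric scoring function from the PHM08 challenge that is commonly reported on C-MAPSS. With $d_j=\hat y_j-y_j$, a late prediction ($d_j\ge0$) costs $e^{d_j/10}-1$ and an early one costs $e^{-d_j/13}-1$, summed over engines. A minimal NumPy sketch (the function name is ours):

```python
import numpy as np


def nasa_score(y_pred, y_true):
    """Sum of asymmetric exponential penalties; d > 0 means a late (over-estimated) RUL."""
    d = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    penalty = np.where(d < 0, np.exp(-d / 13.0) - 1.0, np.exp(d / 10.0) - 1.0)
    return float(penalty.sum())


# Same 10-cycle error, different sign: the late prediction costs more.
print(nasa_score([110], [100]), nasa_score([90], [100]))  # about 1.718 vs about 1.158
```

Unlike RMSE, this score is not symmetric in the sign of the error. That is why a method can win one metric and lose the other.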
Python: Computing All Four Cells From Scratch
Read the code line by line. The annotations trace each line's execution, including the intermediate values of $y_j$, $w(y_j)$, and $\lambda_i$ and the final scalar produced by each cell. The script reuses the gradient-norm numbers from chapter 18 §1, so the outer-axis lambdas match that section exactly rather than being invented.
Reconstructing the 2×2 grid in NumPy
grace_two_axes_demo.py
1 Module docstring: names the 2×2 grid
Names the 2×2 grid we will reconstruct numerically. The outer axis flips between FIXED equal weighting (0.5/0.5) and GABA's adaptive closed form. The inner axis flips between uniform per-sample weighting and the failure-biased ramp w(y). The Cartesian product of the two axes generates exactly four cells — and Cell D (adaptive × failure-biased) is GRACE.
→ why a docstring? = A module docstring is the FIRST string literal at the top of a Python file. Tools like help(), Sphinx, and IDEs surface it as documentation. At runtime Python stores it as the module's __doc__ attribute; it has no other effect on execution.
→ cell map = Cell A = fixed × uniform (Baseline MTL)
Cell B = fixed × weighted (AMNL)
Cell C = GABA × uniform (GABA + std MSE)
Cell D = GABA × weighted (GRACE)
7 import numpy as np
NumPy gives us vectorised arithmetic over y_pred and y_true. All four loss shapes below are one-line array operations rather than Python loops — broadcasting, element-wise power, and .mean() reductions all run as optimised C under the hood.
EXECUTION STATE
📚 numpy = Numerical-array library. Provides ndarray, np.clip, np.array, broadcasting, element-wise ops, and reductions (.mean, .sum) used throughout the demo. Every operation here would be a slow Python for-loop without it.
as np = Standard alias by community convention. Lets us write np.array(), np.clip() instead of numpy.array(). Almost universal in scientific Python code.
→ why we need it here = Computing w * (y_pred - y_true) ** 2 requires element-wise broadcasting across 8 samples. Pure-Python lists can't do this without an explicit for-loop.
10 y_true = np.array([10,30,60,90,110,5,15,80], dtype=float)
Eight ground-truth RUL values spanning the full range: three near failure (5, 10, 15), two early in life (90, 110), and three in between (30, 60, 80). The mix lets the inner-axis weight w(y) actually do something.
EXECUTION STATE
📚 np.array(list, dtype=float) = Builds a 1-D ndarray. dtype=float stores the values as float64 from the start, so the later arithmetic (residuals, w(y) = 1 + clip(...)) never relies on implicit integer-to-float promotion.
11 y_pred = np.array([15,26,68,87,112,-7,24,74], dtype=float)
Predictions chosen so the residual y_pred − y_true varies in sign and magnitude across samples. Two near-failure samples are off by −12 and +9, which is exactly where the inner-axis weight will amplify the penalty.
15 def rul_loss_standard(y_pred, y_true):
The bottom-row INNER axis: every sample contributes the same weight to the average. This is the loss shape used in cell A (Baseline) and cell C (GABA + standard MSE). Mathematically: L = (1/N) Σⱼ (ŷⱼ − yⱼ)².
→ y_pred purpose = What the network predicts for each engine's remaining useful life. Element j is the model's RUL forecast for engine window j. Note sample 5 = -7 (unphysical negative RUL — the linear head doesn't know about non-negativity).
→ y_true purpose = Cycles-until-failure for each engine. Mix is deliberate: samples 0,5,6 are near failure (≤15), samples 3,4 are healthy (≥90). The mix lets the inner-axis weight w(y) actually do something.
→ no max_rul here = Standard MSE has no per-sample shape, so it doesn't need the RUL cap. The function is dataset-agnostic.
⬆ returns = Python float — the unweighted mean of squared errors. Single scalar that the autograd graph (in PyTorch) or the optimiser (here, just a print) consumes.
16 Function docstring: "Plain MSE. Every sample weighted equally."
Function docstring for the uniform inner-axis loss. PEP 257 convention: a string literal that is the first statement in a function body becomes function.__doc__, which is what help() shows. This one names the loss shape (plain MSE) and the key invariant (uniform weighting) so readers don't have to derive it from the body.
EXECUTION STATE
→ why uniform weighting matters here = Contrast with rul_loss_weighted() at line 20, where each sample is multiplied by w(y). The docstring's job is to surface that distinction without forcing the reader to diff the two function bodies.
17 return ((y_pred - y_true) ** 2).mean()
Three vectorised operations chained in a single line: subtract, square, average. Equivalent to (1/N) · Σⱼ (ŷⱼ − yⱼ)². NumPy does not fuse these operations, but each one runs as a single compiled C loop. On 8 samples the speed difference from a Python for-loop is negligible; the gain only matters for large batches.
→ broadcasting rule = Two ndarrays of the same shape: NumPy subtracts element-wise. Output shape = input shape = (8,). No copy of either array — the result is freshly allocated.
📚 ** 2 (element-wise power) = NumPy operator: applies pow(x, 2) to every element. Equivalent to np.square(x) but written inline. e.g. (-12)**2 = 144, (-3)**2 = 9.
📚 .mean() (ndarray reduction) = ndarray method: returns sum(self) / size(self). With no axis arg, reduces over ALL elements. Here: (25+16+64+9+4+144+81+36) / 8 = 379 / 8 = 47.375. Returns a NumPy scalar (np.float64), which is a subclass of Python float.
→ arg axis (not used here) = If we passed axis=0 it would mean per-column on a 2-D array. With a 1-D vector, axis=0 reduces along the only axis — same as no axis.
⬆ return: L_rul^MSE = 47.3750 — the value that lines 26, 43, 45 will consume.
20 def rul_loss_weighted(y_pred, y_true, max_rul=125.0):
The top-row INNER axis: each sample is multiplied by w(y) = 1 + clip(1 − y/max_rul, 0, 1). Samples near failure (y → 0) get weights close to 2; samples at or beyond max_rul get weight 1. Line-for-line equivalent to grace/core/weighted_mse.py:moderate_weighted_mse_loss, and used in cells B (AMNL) and D (GRACE).
→ y_pred role = Same predictions as in rul_loss_standard — only the loss SHAPE changes between functions, not the inputs. That is the orthogonality story for the inner axis.
→ y_true role here = Used TWICE inside the body: once for the residual (y_pred - y_true) and once for the weight w(y_true). The weight depends ONLY on the target, never on the prediction — so dw/dy_pred = 0.
⬇ input: max_rul = 125.0 = RUL cap (default value). Matches the paper's piecewise-linear RUL target where everything ≥ 125 cycles is treated as 'fully healthy'. Samples with y_true ≥ 125 → weight = 1 (no extra emphasis).
→ why max_rul = 125? = Empirical choice from the C-MAPSS paper. With y measured in cycles and engines lasting ~200-400 cycles, 125 is the boundary where degradation typically becomes detectable. Hard-coding 125.0 as a kwarg default lets callers override it for other datasets.
⬆ returns = Python float: weighted mean of squared errors with per-sample weights w(yⱼ) ∈ [1, 2]. Always ≥ standard MSE because every weight is ≥ 1.
21 Function docstring: "Failure-biased MSE (paper: grace/core/weighted_mse.py)."
Function docstring for the failure-biased loss. It pins the implementation to a real source file in the GRACE codebase so readers can cross-check the closed-form weight w(y) against the paper's production code. The reproducibility contract is: same inputs, same outputs, line for line.
EXECUTION STATE
→ traceability = The paper file grace/core/weighted_mse.py:20 holds moderate_weighted_mse_loss — the PyTorch version. Body identical to lines 22-23 below, modulo torch.clamp ↔ np.clip.
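22 w = 1.0 + np.clip(1.0 - y_true / max_rul, 0.0, 1.0)
Builds the per-sample weight vector. Linear ramp: 2 at y = 0, 1 at y ≥ max_rul. The clip keeps the weight from falling below 1 when y > max_rul. Without it, healthy-engine errors would be down-weighted, and for y > 2·max_rul the weight would even turn negative.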
📚 np.clip(x, lo, hi) = Element-wise: x_i = max(lo, min(hi, x_i)). Here clip(..., 0.0, 1.0) keeps the ramp inside [0, 1] so weights stay in [1, 2].
→ why clip? = If y_true > max_rul (some samples go up to 200+ before the piecewise cap), 1 − y/max_rul becomes negative, so the weight drops below 1 and healthy-engine samples are accidentally down-weighted. clip pins the ramp at 0 so the weight is exactly 1 there.
→ contribution check = Sample 5 (y=5, w=1.96) contributes 282.24, which is 42% of the weighted sum 672.92. Under standard MSE the same sample contributes only 144/379 = 38%. The shift is the ‘loss-shape’ effect.
📚 .mean() = Sum 672.92 / 8 = 84.115.
⬆ return: L_rul^WMSE = 84.1150
26 L_rul_std = rul_loss_standard(y_pred, y_true)
Compute one half of the inner axis: the unshaped RUL loss for cells A and C.
EXECUTION STATE
L_rul_std = Float = 47.3750.
27 L_rul_w = rul_loss_weighted(y_pred, y_true)
The other half of the inner axis: the shaped RUL loss for cells B and D. It can never be smaller than L_rul_std because every weight is ≥ 1.
EXECUTION STATE
L_rul_w = Float = 84.1150.
28 L_health = 0.6069  # cross-entropy from the same forward pass
Hard-coded scalar standing in for the health-classification cross-entropy from the same mini-batch's forward pass. Held constant across all four cells because this section isolates the RUL leg of the OUTER × INNER composition; varying L_health here would conflate axes.
EXECUTION STATE
L_health = 0.6069 — Python float, used by every compose() call below.
→ typical magnitude = About 55% of ln(3) ≈ 1.099, which is the cross-entropy of a uniform (random-guess) 3-class prediction. This is consistent with a partly trained 3-class head. Lower values mean the classifier is more confident on the right class.
→ why constant = If L_health changed between cells the differences L_B−L_A, L_C−L_A, L_D−L_A would mix RUL-axis effects with health-axis noise and the orthogonality demonstration would no longer be clean.
→ comment '# cross-entropy from the same forward pass' = Tells the reader L_health was computed from F.cross_entropy(hp_logits, hp_target) — same minibatch, same backbone activations as the RUL loss. Same forward pass = single autograd graph in the PyTorch sibling.
32 lam_fixed_rul, lam_fixed_h = 0.5, 0.5
First half of the OUTER axis: equal-weight MTL. Used in cells A and B.
EXECUTION STATE
lam_fixed_rul = 0.5 — half the gradient budget goes to RUL.
lam_fixed_h = 0.5 — half goes to health.
→ known weakness = On C-MAPSS the gradient ratio is ~500x in favour of RUL, so 0.5/0.5 gives the health head almost no real influence. See chapter 18 §1 for the empirical figure.
33 g_rul, g_health = 26.4016, 0.037833
Per-task gradient norms on the shared backbone for this batch. These are the exact numbers reproduced in chapter 18 §1 — we copy them so this section's lambdas match that one's.
EXECUTION STATE
g_rul = 26.4016 — L2 norm of dL_rul/dtheta_shared.
g_health = 0.037833 — L2 norm of dL_health/dtheta_shared.
ratio = g_rul / g_health = 698x — the structural imbalance GABA fixes.
34 lam_gaba_rul = g_health / (g_rul + g_health)
Second half of the OUTER axis: GABA closed form for K=2. The task with the SMALLER gradient norm gets the LARGER weight — inverse-ratio balancing.
→ reading = RUL gradient is 698x bigger, so GABA gives RUL only 0.14% of the loss weight. The optimisation step now becomes ~equally pushed by both task gradients.
35 lam_gaba_h = g_rul / (g_rul + g_health)
Other side of the GABA formula — the K=2 closed form gives the OTHER task's gradient norm (in the numerator) divided by the total. The two lambdas always sum to 1.
→ reading = Even though L_health is numerically tiny (0.6069 vs 47.375 for L_rul), GABA gives it 99.86% of the loss weight. After multiplying, the two per-task gradient contributions have exactly equal norms: λ_i·g_i = g_rul·g_health/(g_rul + g_health) ≈ 0.0378 for both tasks. That is the whole point.
39 def compose(lam_rul, lam_h, L_rul, L_h):
The composition rule of the entire chapter, in one line. Note that compose() does NOT care which loss shape produced L_rul or which weighting strategy produced lam_rul. That is exactly the orthogonality this section is about.
EXECUTION STATE
⬇ input: lam_rul = Per-task weight on RUL. Comes from EITHER axis of OUTER.
⬇ input: lam_h = Per-task weight on health. Sums to 1 with lam_rul (after EMA & floor).
⬇ input: L_rul = Scalar RUL loss. Comes from EITHER axis of INNER (standard or weighted MSE).
⬇ input: L_h = Scalar health loss. Cross-entropy here.
⬆ returns = Scalar combined loss for one training step.
40 return lam_rul * L_rul + lam_h * L_h
Weighted sum. The two lambdas live OUTSIDE the per-sample mean, while the per-sample weights w(y_j) live INSIDE it. The algebra confirms the two dimensions are independent.
EXECUTION STATE
→ orthogonality = Mathematically: L = sum_i lam_i * (1/N) sum_j w(y_j) e_ij^2 = (1/N) sum_i sum_j lam_i w(y_j) e_ij^2. The factorisation lam_i * w(y_j) shows that i and j operate on disjoint indices.
45 L_C = compose(lam_gaba_rul, lam_gaba_h, L_rul_std, L_health)  # GABA
Cell C: GABA + standard MSE. Now only the OUTER axis is flipped.
EXECUTION STATE
Step 1: 0.001431 * 47.3750 = 0.06779
Step 2: 0.998569 * 0.6069 = 0.60603
L_C = 0.06779 + 0.60603 = 0.6738
→ vs cell A = Outer axis flipped only. RUL contribution falls 23.6875 → 0.0678 because the gradient-balanced weight is tiny. The huge raw loss is no longer dominating optimisation — that is the point of GABA.
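49 print(f"L_rul^MSE = {L_rul_std:.4f} L_rul^WMSE = {L_rul_w:.4f}")
Show the inner-axis contrast: the failure-biased loss is 1.78x bigger than the standard one. The largest single contributor is the worst-residual sample (y=5, residual=−12), which now counts at weight 1.96; every other weight also exceeds 1.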
51 print(f"lambda_GABA =({lam_gaba_rul:.6f}, {lam_gaba_h:.6f})")
Print the gradient-balanced pair, an easy reference for what cells C and D share. The format spec :.6f (six decimals) is needed because lam_gaba_rul ≈ 0.001431 would round to 0.0014 at four decimals and lose precision.
EXECUTION STATE
📚 f-string with :.6f = Python f-string format: variable formatted as fixed-point with 6 decimals. e.g. f"{0.001431:.6f}" → '0.001431'. Different from :.4f used elsewhere because lam_gaba_rul has small magnitude.
Output = lambda_GABA =(0.001431, 0.998569)
52 print()  # blank line separator
Bare print() with no arguments emits a single newline. It separates the parameter dump above (loss values + lambdas) from the four-cell results below, so the terminal output has visual hierarchy.
EXECUTION STATE
📚 print() with no args = Python builtin: writes the value of `end` (default '\n') to stdout. Equivalent to print('', end='\n') or sys.stdout.write('\n').
→ why split sections? = Without this blank line, the output would be a single 7-line block. The visual gap signals to the reader: 'parameters above, results below'. Trivial cost, big readability gain when running scripts in a terminal.
53 print(f"Cell A (fixed x standard) = {L_A:.4f}")
Plain MTL combined loss. The baseline.
EXECUTION STATE
Output = Cell A (fixed x standard) = 23.9910
54 print(f"Cell B (fixed x weighted) = {L_B:.4f}")
AMNL-style cell. Bigger than A purely because of the inner-axis weight.
56 print(f"Cell D (GRACE = GABA x WMSE)= {L_D:.4f}")
GRACE cell, with both axes engaged. The value of this scalar matters less than the gradient it produces; see the PyTorch demo next.
EXECUTION STATE
Final output =
L_rul^MSE = 47.3750 L_rul^WMSE = 84.1150
lambda_fixed=(0.5000, 0.5000)
lambda_GABA =(0.001431, 0.998569)
Cell A (fixed x standard) = 23.9910
Cell B (fixed x weighted) = 42.3610
Cell C (GABA x standard) = 0.6738
Cell D (GRACE = GABA x WMSE)= 0.7264
 1  """GRACE separation of concerns: the four cells of the 2x2 grid.
 2
 3  Outer axis = per-task weighting { fixed, GABA-adaptive }
 4  Inner axis = per-sample RUL shape { 1, w(y) failure-biased }
 5  """
 6
 7  import numpy as np
 8
 9  # Toy mini-batch of 8 samples
10  y_true = np.array([10, 30, 60, 90, 110, 5, 15, 80], dtype=float)
11  y_pred = np.array([15, 26, 68, 87, 112, -7, 24, 74], dtype=float)
12
13
14  # ---------- Inner axis: two RUL loss shapes ----------
15  def rul_loss_standard(y_pred, y_true):
16      """Plain MSE. Every sample weighted equally."""
17      return ((y_pred - y_true) ** 2).mean()
18
19
20  def rul_loss_weighted(y_pred, y_true, max_rul=125.0):
21      """Failure-biased MSE (paper: grace/core/weighted_mse.py)."""
22      w = 1.0 + np.clip(1.0 - y_true / max_rul, 0.0, 1.0)
23      return (w * (y_pred - y_true) ** 2).mean()
24
25
26  L_rul_std = rul_loss_standard(y_pred, y_true)
27  L_rul_w = rul_loss_weighted(y_pred, y_true)
28  L_health = 0.6069  # cross-entropy from the same forward pass
29
30
31  # ---------- Outer axis: two task-weighting strategies ----------
32  lam_fixed_rul, lam_fixed_h = 0.5, 0.5
33  g_rul, g_health = 26.4016, 0.037833  # see chapter 18 numbers
34  lam_gaba_rul = g_health / (g_rul + g_health)
35  lam_gaba_h = g_rul / (g_rul + g_health)
36
37
38  # ---------- Compose the four cells of the grid ----------
39  def compose(lam_rul, lam_h, L_rul, L_h):
40      return lam_rul * L_rul + lam_h * L_h
41
42
43  L_A = compose(lam_fixed_rul, lam_fixed_h, L_rul_std, L_health)  # Baseline
44  L_B = compose(lam_fixed_rul, lam_fixed_h, L_rul_w, L_health)    # AMNL
45  L_C = compose(lam_gaba_rul, lam_gaba_h, L_rul_std, L_health)    # GABA
46  L_D = compose(lam_gaba_rul, lam_gaba_h, L_rul_w, L_health)      # GRACE
47
48
49  print(f"L_rul^MSE = {L_rul_std:.4f} L_rul^WMSE = {L_rul_w:.4f}")
50  print(f"lambda_fixed=({lam_fixed_rul:.4f}, {lam_fixed_h:.4f})")
51  print(f"lambda_GABA =({lam_gaba_rul:.6f}, {lam_gaba_h:.6f})")
52  print()
53  print(f"Cell A (fixed x standard) = {L_A:.4f}")
54  print(f"Cell B (fixed x weighted) = {L_B:.4f}")
55  print(f"Cell C (GABA x standard) = {L_C:.4f}")
56  print(f"Cell D (GRACE = GABA x WMSE)= {L_D:.4f}")
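The contribution check in the annotations can be reproduced per sample with a few extra lines. Sample 5 carries 38% of the standard-MSE sum but 42% of the weighted sum. This standalone snippet repeats the demo's data; the `share_*` variable names are new.

```python
import numpy as np

y_true = np.array([10, 30, 60, 90, 110, 5, 15, 80], dtype=float)
y_pred = np.array([15, 26, 68, 87, 112, -7, 24, 74], dtype=float)

sq_err = (y_pred - y_true) ** 2                    # per-sample squared error
w = 1.0 + np.clip(1.0 - y_true / 125.0, 0.0, 1.0)  # failure-biased weights
share_std = sq_err / sq_err.sum()                  # share of the standard-MSE sum
share_wtd = w * sq_err / (w * sq_err).sum()        # share of the weighted sum

for j in np.argsort(y_true):                       # nearest-to-failure first
    print(f"y={y_true[j]:5.0f}  w={w[j]:.2f}  "
          f"std={share_std[j]:6.1%}  weighted={share_wtd[j]:6.1%}")
```

The reweighting moves loss mass toward the low-RUL rows and away from the healthy ones. That is the loss-shape effect in isolation.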
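Sanity check on the diagonal. Cells A and D differ in both axes, while the moves A→B and A→C each change one axis. The algebraic identity $(L_D-L_C)-(L_B-L_A)=(\lambda_{\mathrm{rul}}^{\mathrm{GABA}}-0.5)\,(L_{\mathrm{rul}}^{\mathrm{WMSE}}-L_{\mathrm{rul}}^{\mathrm{MSE}})$ is what "orthogonal" means in numbers. The combined loss is bilinear in the RUL weight and the RUL loss, so the interaction between the two axes is exactly the product of the change along each axis, with no other cross-term. Plug the printed values in and verify: $(0.7264-0.6738)-(42.3610-23.9910)=-18.3174$ and $(0.001431-0.5)(84.1150-47.3750)\approx-18.3174$.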
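The same check, done in code rather than by hand. The snippet starts from the values the demo prints and rebuilds the four cells so it runs on its own. Writing the health weight as `1 - lam_rul` uses the $K=2$ property that the two lambdas sum to 1.

```python
L_rul_std, L_rul_w, L_health = 47.375, 84.115, 0.6069  # values printed by the demo
g_rul, g_health = 26.4016, 0.037833
lam_g = g_health / (g_rul + g_health)                   # GABA weight on RUL


def compose(lam_rul, L_rul, L_h=L_health):
    # K = 2: the health weight is 1 - lam_rul.
    return lam_rul * L_rul + (1.0 - lam_rul) * L_h


L_A, L_B = compose(0.5, L_rul_std), compose(0.5, L_rul_w)
L_C, L_D = compose(lam_g, L_rul_std), compose(lam_g, L_rul_w)

interaction = (L_D - L_C) - (L_B - L_A)            # second difference across the grid
predicted = (lam_g - 0.5) * (L_rul_w - L_rul_std)  # product of the two single-axis changes
print(interaction, predicted)  # agree up to floating-point rounding
```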
PyTorch: The Paper's Composition In Six Lines
Same algebra, but now imported from grace/core/ directly — moderate_weighted_mse_loss for the inner axis and compute_task_grad_norm for the outer axis. No re-implementation. The four cells condense to four single-line compositions on lines 54–57.
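Here is a framework-free sketch of what the outer axis needs from the model. It selects the shared parameters by name and takes one L2 norm over all of their gradients together. The parameter names and shapes mirror the TinyDualHead model in the listing. The gradient arrays are random placeholders: the paper's compute_task_grad_norm obtains real gradients through torch.autograd.grad, which this sketch does not reproduce.

```python
import numpy as np

rng = np.random.default_rng(0)

# Parameter shapes of the toy TinyDualHead: Linear(4, 6), Linear(6, 1), Linear(6, 3).
shapes = {
    "backbone.weight": (6, 4), "backbone.bias": (6,),
    "rul_head.weight": (1, 6), "rul_head.bias": (1,),
    "health_head.weight": (3, 6), "health_head.bias": (3,),
}
# Placeholder gradients of one task loss w.r.t. every parameter (random, for illustration).
grads = {name: rng.normal(size=shape) for name, shape in shapes.items()}

# Same name-based rule as the demo: anything containing "head" is task-specific.
shared = [name for name in shapes if "head" not in name]


def grad_l2_norm(grad_list):
    """Square root of the summed squares of every gradient entry, across all tensors."""
    return float(np.sqrt(sum(np.sum(g ** 2) for g in grad_list)))


g_shared = grad_l2_norm([grads[n] for n in shared])
flat = np.concatenate([grads[n].ravel() for n in shared])
n_params = sum(int(np.prod(s)) for s in shapes.values())
print(shared)                          # ['backbone.weight', 'backbone.bias']
print(n_params)                        # 58 parameters in total
print(g_shared, np.linalg.norm(flat))  # same number, computed two ways
```

Taking one norm over the concatenated gradients, rather than averaging per-tensor norms, treats the shared backbone as a single vector. This is what the inverse-ratio rule compares across tasks.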
The same composition with PyTorch autograd
grace_two_axes_pytorch.py
1 Module docstring: contract for the PyTorch sibling
States the contract: this script imports the paper's production functions verbatim and reproduces the four cells. Same algebra as the NumPy demo, but now every loss is a 0-dim torch.Tensor that .backward() can flow through. The reproducibility claim is that on identical inputs the four cell values match the NumPy version to 4 decimal places.
EXECUTION STATE
→ why two demos? = NumPy version makes the algebra visible and the values printable; PyTorch version proves the same four cells exist inside a real autograd graph and can drive optimiser steps.
→ 'same forward' = All four cells reuse the SAME forward pass — y_pred, hp_logits, L_rul_*, L_health, g_*. This is the GABA pre-condition: the gradient norms must come from the same computational graph as the loss being weighted.
8 import torch
Core PyTorch. Provides the Tensor class, the autograd engine, the device abstraction (CPU/GPU), random number generators, and almost every primitive used by neural-network code.
EXECUTION STATE
📚 torch = Library namespace. Concretely we use: torch.tensor (build a Tensor from data), torch.randn (sample N(0,1)), torch.manual_seed (pin RNG), torch.relu (in the model), and Tensor methods .item() (→ Python float) and .detach() (strip autograd).
→ Tensor vs ndarray = torch.Tensor is like np.ndarray but (a) can record the operation that produced it (grad_fn), which is what lets autograd differentiate through it, and (b) lives on a device (CPU/GPU). For our purposes, you can read 'Tensor' as 'ndarray that knows how to compute its own derivative'.
9 import torch.nn as nn
Layer primitives. nn provides STATEFUL building blocks — modules that own learnable parameters. We use nn.Module (the base class for our TinyDualHead) and nn.Linear (each of the three projections).
EXECUTION STATE
📚 torch.nn = Submodule of torch. Holds Module, Parameter, Linear, Conv2d, LSTM, etc. The 'nn' alias is universal in PyTorch code so we write nn.Linear instead of torch.nn.Linear.
📚 nn.Module — base class = Inheriting from nn.Module gives you (1) automatic parameter registration when you do self.x = nn.Parameter(...) or self.x = nn.Linear(...), (2) .parameters() / .named_parameters() iterators, (3) .to(device) recursive movement, (4) .train()/.eval() mode switching, (5) state_dict serialisation.
📚 nn.Linear(in, out, bias=True) = A learnable affine layer. Stores weight tensor W of shape (out, in) and (optional) bias b of shape (out,). Forward: y = x @ W.T + b. Used three times below for backbone, RUL head, health head.
10 import torch.nn.functional as F
Stateless functional API. F.* functions take inputs explicitly (no hidden state) — contrast with nn.* which OWNS parameters. We use F.mse_loss and F.cross_entropy because they have no parameters of their own.
EXECUTION STATE
📚 torch.nn.functional = Submodule. Aliased as F by convention. Contains the same operations as nn.* but without state — e.g. F.linear(x, W, b) vs nn.Linear(in, out). Use nn.* when the operation owns weights, F.* when it doesn't.
📚 F.mse_loss(input, target, reduction='mean') = Pure function: returns ((input - target)**2).mean() by default. With reduction='sum' it skips the mean; with 'none' it returns per-sample errors. Equivalent to NumPy demo line 17 when reduction='mean'. Returns a 0-dim tensor.
📚 F.cross_entropy(logits, target, reduction='mean') = Combined log-softmax + NLL in one call. input shape (B, C) — raw logits, no need to apply softmax first. target shape (B,) — int64 class indices. Returns a 0-dim tensor. The 'fused' kernel is more numerically stable than computing softmax then NLL.
→ why not nn.MSELoss? = nn.MSELoss is a Module wrapper around F.mse_loss with no parameters. Either works; using F here keeps the code lean — no extra object to construct.
Paper code. Computes ||dL/dtheta_shared||_2 without writing to .grad. Used for the OUTER axis (GABA closed form).
EXECUTION STATE
📚 compute_task_grad_norm(loss, shared_params, retain_graph=True) = torch.autograd.grad with create_graph=False, sum-of-squared per-parameter gradients, then sqrt. Same algorithm walked through in chapter 18 §1.
15 torch.manual_seed(0)
Pin the global PyTorch CPU PRNG so the random batch (line 36) and the random weight initialisations (line 31) are identical on every run. Without this the per-task gradient norms drift and the four cells stop being directly comparable across runs.
EXECUTION STATE
📚 torch.manual_seed(seed: int) = Sets the seed for torch's CPU random generator AND the default CUDA generator if a GPU is visible. After this, every subsequent torch.randn / torch.rand / nn.init call produces a deterministic sequence.
⬇ arg: 0 = Any int works; 0 is just a convention. Different seeds give different randomness — change it only when you want a different sample of the random distribution.
→ not the same as numpy seed = torch's RNG is independent of np.random. Setting torch seed does NOT affect np.random.randn(). For full determinism in mixed code you'd also need np.random.seed(0).
→ why seed for GABA? = The OUTER axis (lambda_GABA) depends on g_rul / g_health which depends on init weights. Without seeding, every run reports different lambdas and the four-cell ALGEBRA can't be checked across runs.
19 class TinyDualHead(nn.Module):
Minimum viable dual-task model: one shared backbone + two heads. Mirrors grace/models/dual_task_model.py at toy scale. The class definition itself just declares the type — no allocation happens until line 31's TinyDualHead() call.
EXECUTION STATE
📚 class C(BaseClass): = Standard Python class definition. Any class that wants to participate in PyTorch's nn.Module machinery (parameter registration, .train()/.eval(), state_dict) must inherit from nn.Module.
📚 nn.Module — what inheritance gives us = Three things we use here: (1) self.x = nn.Linear(...) automatically registers x's params for the optimiser, (2) model(x) syntax — Module overrides __call__ to invoke forward() with hooks, (3) named_parameters() recursion through children.
→ why 'TinyDualHead'? = The 'Dual' refers to two output heads (RUL regression + health classification). The 'Tiny' is to keep this demo running in milliseconds. Real grace.models.dual_task_model.GraceModel has a Bi-LSTM backbone with ~200k params.
20 def __init__(self):
Constructor: runs once when TinyDualHead() is invoked. Its only job is to allocate the three sub-modules and assign them to self.*, where nn.Module's __setattr__ hook registers them.
EXECUTION STATE
⬇ input: self = The freshly-allocated TinyDualHead instance. Python passes it implicitly when you write TinyDualHead(). At entry it has no attributes yet — they get set on lines 21-24.
→ no other args? = Most production models take config kwargs (hidden_dim, n_layers, dropout). For this toy demo all dimensions are hard-coded so __init__ stays parameter-free.
21 super().__init__()
Required nn.Module bookkeeping. It calls nn.Module.__init__(self), which initialises the internal dicts (_parameters, _modules, _buffers) that Module.__setattr__ uses to track parameters and child modules. Skipping this line leaves the module unable to register parameters, and assigning a sub-module then fails with a cryptic 'cannot assign module before Module.__init__() call' error.
EXECUTION STATE
📚 super() = Python builtin: returns a proxy object that delegates method calls to the parent class (here, nn.Module). super().__init__() is equivalent to nn.Module.__init__(self) but doesn't hard-code the parent name — survives refactors.
→ what __init__ sets up = Sets self._parameters = OrderedDict(), self._buffers = OrderedDict(), self._modules = OrderedDict(), self.training = True. After this, self.x = nn.Linear(...) goes into self._modules['x'] via Module.__setattr__.
22 self.backbone = nn.Linear(4, 6)
Shared trunk. Reads all 4 input features and emits a 6-dimensional latent that BOTH heads consume.
EXECUTION STATE
📚 nn.Linear(in, out) = Stores W (out, in) and b (out). Forward: y = x @ W.T + b. Total params here = 6*4 + 6 = 30.
→ shared by design = The OUTER axis (per-task weighting) operates on gradients THROUGH this module. Gradients from both heads accumulate into self.backbone's parameters during .backward().
23 self.rul_head = nn.Linear(6, 1)
Regression head: 6 → 1 RUL prediction. Reads from the 6-D backbone latent, projects to a single scalar per sample. NOT shared — only the RUL loss gradient flows here.
EXECUTION STATE
📚 nn.Linear(6, 1) = Affine layer with W shape (1, 6) — 1 output × 6 inputs = 6 weight params, plus 1 bias. Forward: y = x @ W.T + b → scalar per sample.
⬇ arg 1: in_features = 6 = Must match the backbone's out_features. Mismatched dims would raise a runtime shape error inside forward().
⬇ arg 2: out_features = 1 = One scalar RUL prediction per sample. After self.rul_head(feat) the shape is (8, 1); .squeeze(-1) on line 28 collapses it to (8,) for F.mse_loss.
→ name 'rul_head' = The substring 'head' is what line 32 uses to EXCLUDE this from the shared-parameter list — same convention as compute_task_grad_norm requires. If you renamed this 'rul_output' the filter on line 32 would silently include rul_output's params in `shared` and GABA's gradient norms would be wrong.
24 self.health_head = nn.Linear(6, 3)
Classification head: 6 → 3 logits (Normal / Degrading / Critical). Like rul_head but emits 3 values per sample (one logit per class). Cross-entropy then does its own softmax internally.
EXECUTION STATE
📚 nn.Linear(6, 3) = Affine layer with W shape (3, 6) → 18 weight params + 3 bias = 21 total. Forward: y = x @ W.T + b → 3 logits per sample.
⬇ arg 1: in_features = 6 = Same 6-D shared backbone latent feeds BOTH heads. This is the architectural reason GABA's gradient norms compose on the SAME shared parameters.
⬇ arg 2: out_features = 3 = Three classes: Normal / Degrading / Critical. The output (8, 3) tensor is raw logits — F.cross_entropy applies softmax internally for numerical stability.
→ name 'health_head' = Substring 'head' again — same exclusion rule. Both heads filtered out of `shared` so only the backbone params drive GABA's gradient-norm comparison.
26 def forward(self, x):
Single forward returns a tuple — RUL prediction + health logits — both produced from the SAME backbone activation.
EXECUTION STATE
⬇ input: x (8, 4) = 8 random samples of dimension 4.
⬆ returns = Tuple (Tensor (8,), Tensor (8, 3)).
27 feat = torch.relu(self.backbone(x))
Linear → ReLU. Adds a non-linearity so the shared trunk and the heads cannot be collapsed into a single linear map.
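Line 28: return self.rul_head(feat).squeeze(-1), self.health_head(feat)
Apply each head to the SAME feat tensor. The .squeeze(-1) collapses (8, 1) → (8,) so F.mse_loss compares predictions with y_true elementwise; an (8, 1) input against an (8,) target would silently broadcast to an (8, 8) grid, and PyTorch only emits a warning.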
EXECUTION STATE
📚 .squeeze(dim) = Drops a size-1 dimension. .squeeze(-1) drops the trailing 1 in (8, 1) → (8,).
→ why one feat? = Computing feat once means both losses hang off ONE autograd graph, so g_rul and g_health are measured on exactly the same activations and the backbone forward is paid for only once.
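A short check of why the .squeeze(-1) matters, assuming random stand-in predictions in place of a real head: the squeezed version yields one residual per sample, while the unsqueezed one broadcasts into a grid and F.mse_loss averages all 64 prediction/target pairs.

```python
import warnings

import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred_col = torch.randn(8, 1)  # stand-in for rul_head(feat) BEFORE .squeeze(-1)
y_true = torch.tensor([10., 30., 60., 90., 110., 5., 15., 80.])

# (8,) - (8,): one residual per sample, as intended.
good_residuals = pred_col.squeeze(-1) - y_true
# (8, 1) - (8,): broadcasts to an (8, 8) grid, every prediction vs every target.
bad_residuals = pred_col - y_true

loss_good = F.mse_loss(pred_col.squeeze(-1), y_true)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loss_bad = F.mse_loss(pred_col, y_true)  # runs, but averages 64 pairs

print(good_residuals.shape, bad_residuals.shape)
print(loss_good.item(), loss_bad.item(), len(caught))
```

The broadcast version does not raise, which is what makes this bug easy to miss.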
Line 31: model = TinyDualHead()
Instantiate. Default initialisation produces small-magnitude weights — typical training start.
EXECUTION STATE
model = TinyDualHead with 30 + 7 + 21 = 58 trainable parameters.
Line 32: shared = [p for n, p in model.named_parameters() if "head" not in n]
Filter by name: keep only parameters whose path does NOT contain ‘head’. Matches the policy of grace/core/gradient_utils.py:get_shared_params.
EXECUTION STATE
📚 .named_parameters() = nn.Module method. Yields (str, nn.Parameter) for every parameter, recursive.
Line 38: hp_target = torch.tensor([0,1,2,0,2,0,0,2])
Health labels in {0, 1, 2} are class indices, NOT one-hot vectors. F.cross_entropy expects integer class indices and picks out each sample's log-probability directly, so no one-hot encoding is needed. dtype is auto-inferred to int64 (torch.long) because every list element is a Python int.
EXECUTION STATE
📚 torch.tensor(data) = Builds a Tensor from a Python list/tuple/scalar. dtype is auto-inferred: all-int input → torch.int64, mixed/float input → torch.float32. To force a dtype use torch.tensor(data, dtype=torch.long).
⬇ arg: [0, 1, 2, 0, 2, 0, 0, 2] = 8 class labels, one per sample, matching y_true positionally. Sample 0 has RUL=10 (near failure) yet is labelled class 0 (Normal): labels and RUL are not strictly aligned in this toy demo.
hp_target (8,) = [0, 1, 2, 0, 2, 0, 0, 2] — torch.int64. Class distribution: four 0s, one 1, three 2s.
→ why int64? = With class-index targets, F.cross_entropy expects int64 (torch.long); a float target with the same shape as the logits is instead interpreted as class probabilities. torch.tensor on all-int data gives int64 automatically, so no conversion is needed.
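The class-index contract can be verified against a hand-written cross-entropy. This is a minimal sketch with random logits standing in for the health head; it shows that F.cross_entropy equals the negated mean of each sample's true-class log-softmax entry, with no one-hot tensor involved.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hp_logits = torch.randn(8, 3)                        # raw scores, (batch, classes)
hp_target = torch.tensor([0, 1, 2, 0, 2, 0, 0, 2])  # all-int list -> torch.int64

ce = F.cross_entropy(hp_logits, hp_target)

# Same value by hand: log-softmax over classes, pick each sample's true class
# by integer indexing, negate, average.
log_probs = F.log_softmax(hp_logits, dim=-1)
ce_manual = -log_probs[torch.arange(8), hp_target].mean()
print(hp_target.dtype, ce.item(), ce_manual.item())
```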
Line 40: y_pred, hp_logits = model(x)
Run the forward pass ONCE. This builds the autograd graph from the leaf parameters through feat to both heads; the loss lines that follow extend this single graph.
EXECUTION STATE
y_pred = Tensor (8,) — the RUL predictions for this batch.
hp_logits = Tensor (8, 3) — unnormalised class scores.
Line 41: L_rul_std = F.mse_loss(y_pred, y_true)
Inner-axis: standard MSE. PyTorch equivalent of the NumPy rul_loss_standard.
EXECUTION STATE
📚 F.mse_loss(input, target, reduction='mean') = Scalar reduction by default. Returns a 0-dim tensor whose .backward() distributes 1/N over each squared residual.
L_rul_std = 0-dim tensor. Numerically depends on the random init; on this seed it's a few thousand because y_pred is small but y_true ranges up to 110.
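Line 42: L_rul_w = moderate_weighted_mse_loss(y_pred, y_true, max_rul=125.0)
Inner axis: failure-biased MSE. Calls the paper's implementation directly rather than re-implementing it. Returns a 0-dim tensor with the same autograd hooks.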
EXECUTION STATE
📚 moderate_weighted_mse_loss(pred, target, max_rul) = From grace/core/weighted_mse.py:20. Body: pred_flat=pred.view(-1); target_flat=target.view(-1); w = 1.0 + torch.clamp(1.0 - target_flat / max_rul, 0, 1.0); return (w * (pred_flat - target_flat) ** 2).mean().
→ autograd shape = The weights w are computed from y_true ONLY — they don't require grad — so dL_rul_w/dy_pred = 2 w (y_pred - y_true) / N. The shape is what changes; the autograd graph is identical to MSE.
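To see the gradient formula in action without the grace package, the sketch below re-types the quoted body as a hypothetical weighted_mse_sketch and compares autograd's gradient with 2 w (y_pred - y_true) / N. It assumes an untrained head that predicts zeros for the demo's eight targets.

```python
import torch


def weighted_mse_sketch(pred, target, max_rul=125.0):
    # Hypothetical stand-in re-typed from the quoted body of
    # grace.core.weighted_mse.moderate_weighted_mse_loss.
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    # Weight in [1, 2]: 2 at RUL = 0, falling linearly to 1 at RUL >= max_rul.
    w = 1.0 + torch.clamp(1.0 - target_flat / max_rul, 0, 1.0)
    return (w * (pred_flat - target_flat) ** 2).mean()


y_true = torch.tensor([10., 30., 60., 90., 110., 5., 15., 80.])
y_pred = torch.zeros(8, requires_grad=True)  # untrained head: predicts 0

loss = weighted_mse_sketch(y_pred, y_true)
loss.backward()

# Closed-form gradient: w depends only on y_true, so it acts as a constant.
w = 1.0 + torch.clamp(1.0 - y_true / 125.0, 0, 1.0)
expected_grad = 2 * w * (y_pred.detach() - y_true) / y_true.numel()
print(w)
print(y_pred.grad)
```

Samples closer to failure (smaller RUL) get weights near 2, so their residuals pull harder on the RUL head.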
Line 47: g_rul = compute_task_grad_norm(L_rul_std, shared, retain_graph=True)
OUTER axis, half 1: gradient norm of L_rul on the shared backbone. retain_graph=True is critical: it keeps the autograd graph alive so the next call can reuse the SAME forward pass.
EXECUTION STATE
g_rul = 0-dim tensor. Real value depends on init; what matters here is that it is computed against EXACTLY the same shared params as g_health.
→ why retain_graph=True = Without it, the graph is freed after this single backward and the next compute_task_grad_norm raises ‘Trying to backward through the graph a second time’.
Line 48: g_health = compute_task_grad_norm(L_health, shared, retain_graph=True)
OUTER axis, half 2: gradient norm of L_health. Same forward pass; that is the pivotal property GABA depends on.
EXECUTION STATE
g_health = 0-dim tensor. The ratio g_rul/g_health is what GABA inverts.
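The body of compute_task_grad_norm is not shown here, so the sketch below uses a hypothetical task_grad_norm that assumes the usual definition: the global L2 norm of torch.autograd.grad over all shared tensors together. It uses the demo's layer sizes, one shared forward pass, and measures both tasks' norms on the retained graph.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def task_grad_norm(loss, params, retain_graph=True):
    # Hypothetical stand-in for compute_task_grad_norm: the global L2 norm of
    # d(loss)/d(params), taken over all shared tensors together.
    grads = torch.autograd.grad(loss, params, retain_graph=retain_graph)
    return torch.sqrt(sum(g.pow(2).sum() for g in grads))


torch.manual_seed(0)
backbone = nn.Linear(4, 6)       # shared trunk
rul_head = nn.Linear(6, 1)
health_head = nn.Linear(6, 3)
shared = list(backbone.parameters())

x = torch.randn(8, 4)
y_true = torch.tensor([10., 30., 60., 90., 110., 5., 15., 80.])
hp_target = torch.tensor([0, 1, 2, 0, 2, 0, 0, 2])

feat = torch.relu(backbone(x))   # ONE shared forward pass feeds both heads
L_rul = F.mse_loss(rul_head(feat).squeeze(-1), y_true)
L_health = F.cross_entropy(health_head(feat), hp_target)

g_rul = task_grad_norm(L_rul, shared)        # retain_graph=True keeps the trunk graph...
g_health = task_grad_norm(L_health, shared)  # ...alive for this second call
print(f"g_rul = {g_rul.item():.3f}, g_health = {g_health.item():.3f}")
```

In this sketch torch.autograd.grad runs without create_graph, so the returned norms already carry no autograd history; an explicit .detach() on the lambdas stays the safe guarantee when the norm helper's internals are unknown.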
Line 49: S = g_rul + g_health
K=2 normaliser used in the §17.3 closed form.
EXECUTION STATE
S = 0-dim tensor = g_rul + g_health.
Line 50: lam_gaba_rul = (g_health / S).detach()
Closed form: the small-gradient task gets the big weight. .detach() severs the computation graph because lambda is a SCALAR FACTOR — not a learnable input. Letting autograd flow through it would propagate gradients into the gradient-norm calculation, which is meta-learning territory and wildly more expensive.
EXECUTION STATE
📚 .detach() = Returns a tensor sharing storage but stripped of grad_fn. Treated as a constant by autograd.
→ why detach? = We want d(lam * L_rul)/d(theta) = lam * dL_rul/dtheta, NOT lam * dL_rul/dtheta + L_rul * dlam/dtheta. The second term would couple the GABA controller into the optimiser's update — a different algorithm (closer to GradNorm).
Line 51: lam_gaba_h = (g_rul / S).detach()
Other half. On this seed lam_gaba_h is essentially 1 because L_rul dominates the gradient norm by orders of magnitude.
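The K = 2 closed form is small enough to test in isolation. In the sketch below the gradient norms are made-up values (RUL 100× larger than health), and gaba_lambdas_k2 is a hypothetical wrapper around lines 49–51. A useful identity falls out: since λ_rul = g_health/S and λ_health = g_rul/S, the weighted norms λ_rul·g_rul and λ_health·g_health both equal g_rul·g_health/S, so after weighting the two tasks push on the shared trunk with equal gradient magnitude.

```python
import torch


def gaba_lambdas_k2(g_rul, g_health):
    # Hypothetical wrapper around the K = 2 closed form on lines 49-51:
    # each task receives the OTHER task's share of the summed norm.
    S = g_rul + g_health
    return (g_health / S).detach(), (g_rul / S).detach()


# Made-up norms: the RUL gradient is 100x the health gradient.
g_rul = torch.tensor(250.0)
g_health = torch.tensor(2.5)
lam_rul, lam_h = gaba_lambdas_k2(g_rul, g_health)

print(f"lambda_rul = {lam_rul.item():.4f}, lambda_health = {lam_h.item():.4f}")
# Weighted norms match: both equal g_rul * g_health / S.
print(f"{(lam_rul * g_rul).item():.4f} == {(lam_h * g_health).item():.4f}")
```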
Line 55: L_A = 0.5 * L_rul_std + 0.5 * L_health
Cell A. Plain MTL. Both lambdas are Python scalars — torch promotes them to 0-dim tensors automatically and the autograd graph stays intact through L_rul_std + L_health.
EXECUTION STATE
L_A = 0-dim tensor. Numerically matches the NumPy demo cell A to machine precision when y_pred is fed identical values.
Line 56: L_B = 0.5 * L_rul_w + 0.5 * L_health
Cell B. Inner axis flipped: weighted MSE replaces standard MSE. Lambdas unchanged.
→ caveat = The exact numbers depend on torch.manual_seed(0) and the RUL head's init. The four-cell ALGEBRA is what reproduces — every cell is a different OUTER x INNER product on the SAME forward pass.
```python
"""GRACE composition with the paper's production helpers.

Imports the EXACT functions from grace/core/weighted_mse.py and
grace/core/gradient_utils.py, runs one forward pass, and prints the same
four cells as the NumPy demo. Numbers match to 4 decimals.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from grace.core.weighted_mse import moderate_weighted_mse_loss
from grace.core.gradient_utils import compute_task_grad_norm

torch.manual_seed(0)


# ---------- Tiny dual-task model (8 samples, 4 features) ----------
class TinyDualHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 6)       # shared trunk: 4 -> 6
        self.rul_head = nn.Linear(6, 1)       # RUL regression head
        self.health_head = nn.Linear(6, 3)    # 3-class health head

    def forward(self, x):
        feat = torch.relu(self.backbone(x))
        return self.rul_head(feat).squeeze(-1), self.health_head(feat)


model = TinyDualHead()
shared = [p for n, p in model.named_parameters() if "head" not in n]


# ---------- One forward pass ----------
x = torch.randn(8, 4)
y_true = torch.tensor([10., 30., 60., 90., 110., 5., 15., 80.])
hp_target = torch.tensor([0, 1, 2, 0, 2, 0, 0, 2])

y_pred, hp_logits = model(x)
L_rul_std = F.mse_loss(y_pred, y_true)
L_rul_w = moderate_weighted_mse_loss(y_pred, y_true, max_rul=125.0)
L_health = F.cross_entropy(hp_logits, hp_target)


# ---------- GABA closed form on the SAME forward ----------
g_rul = compute_task_grad_norm(L_rul_std, shared, retain_graph=True)
g_health = compute_task_grad_norm(L_health, shared, retain_graph=True)
S = g_rul + g_health
lam_gaba_rul = (g_health / S).detach()
lam_gaba_h = (g_rul / S).detach()

# Cells C and D share one pair of lambdas, computed above from L_rul_std.
# ---------- Compose: outer x inner ----------
L_A = 0.5 * L_rul_std + 0.5 * L_health
L_B = 0.5 * L_rul_w + 0.5 * L_health
L_C = lam_gaba_rul * L_rul_std + lam_gaba_h * L_health
L_D = lam_gaba_rul * L_rul_w + lam_gaba_h * L_health

print(f"L_A {L_A.item():.4f} L_B {L_B.item():.4f}")
print(f"L_C {L_C.item():.4f} L_D {L_D.item():.4f}")
print(f"lambda_GABA = ({lam_gaba_rul.item():.6f}, {lam_gaba_h.item():.6f})")
```
Detach the lambdas. Lines 50–51 call .detach() on the GABA weights. Without it, .backward() could propagate through the gradient-norm computation, turning GABA into a meta-gradient method with ∼2× the memory and a different algorithm entirely. The closed form must enter the optimisation as a constant.
Real Measurements On FD002
Toy losses on 8 samples make the algebra visible. Production training on 17,631 multi-condition windows (FD002, 5 seeds, 500 epochs each) confirms that the two axes really do compose without interference. The rows below come straight from data_analysis/cmapss_h256_complete_140.csv; the Δ column gives each method's change relative to Baseline, so rows B and C each isolate a single axis.
| Method | Outer | Inner | FD002 RMSE (cycles) | FD002 NASA | Δ vs Baseline |
|---|---|---|---|---|---|
| Baseline (A) | Fixed | MSE | 7.37 | 224.5 | n/a |
| AMNL (B = +inner only) | Fixed | WMSE | 6.74 | 356.0 | RMSE −0.63, NASA +131.5 |
| GABA (C = +outer only) | GABA | MSE | 7.53 | 224.2 | RMSE +0.16, NASA −0.3 |
| GRACE (D = both axes) | GABA | WMSE | 7.72 | 223.4 | RMSE +0.35, NASA −1.1 |
Two structural observations. First, on the safety-critical NASA score, the OUTER axis decides the outcome: both adaptive-outer rows (C and D) sit at about 224 regardless of the inner loss, whereas with fixed outer weights, switching the inner loss from MSE to WMSE (row A to row B) inflates NASA from 224.5 to 356.0. Second, on RMSE, the INNER axis is the lever: AMNL lowers RMSE by 0.63 cycles relative to Baseline through loss shaping alone. GRACE keeps the adaptive-outer NASA level (223.4, the lowest in the table) and accepts a small RMSE cost (+0.35 cycles), a deliberate Pareto choice the paper documents in chapter 23.
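The Δ column and the axis comparisons above are plain differences between rows of the FD002 table. The arithmetic below only re-derives those differences from the published numbers; it adds no new measurements.

```python
# FD002 numbers copied from the table: (RMSE in cycles, NASA score).
rows = {
    "A Baseline": (7.37, 224.5),
    "B AMNL": (6.74, 356.0),
    "C GABA": (7.53, 224.2),
    "D GRACE": (7.72, 223.4),
}
base_rmse, base_nasa = rows["A Baseline"]

# Delta vs Baseline, rounded to the table's precision.
deltas = {
    name: (round(rmse - base_rmse, 2), round(nasa - base_nasa, 1))
    for name, (rmse, nasa) in rows.items()
}

# Effect on NASA of flipping the INNER axis (MSE -> WMSE) under each OUTER choice.
inner_flip_fixed = round(rows["B AMNL"][1] - rows["A Baseline"][1], 1)  # A -> B
inner_flip_gaba = round(rows["D GRACE"][1] - rows["C GABA"][1], 1)     # C -> D

for name, (d_rmse, d_nasa) in deltas.items():
    print(f"{name:11s} RMSE {d_rmse:+.2f}  NASA {d_nasa:+.1f}")
print("NASA change from WMSE, fixed outer:", inner_flip_fixed)
print("NASA change from WMSE, GABA outer: ", inner_flip_gaba)
```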
The Same Decomposition In Other Fields
The OUTER × INNER pattern is not unique to RUL. Any time a learner has multiple competing objectives and non-uniform sample importance, the same separation applies:
| Domain | Outer axis (per-task / per-objective) | Inner axis (per-sample) |
|---|---|---|
| Self-driving perception | Detection vs. depth vs. lane segmentation losses, balanced per scene | Higher weight on pedestrians and night-time frames vs. empty highway |
| Medical imaging (cancer detection) | Pixel-wise segmentation loss vs. patient-level classification loss | Higher weight on biopsy-confirmed positives near the malignant boundary |
| Recommender systems | Click-through rate head vs. dwell-time head vs. revenue head | Higher weight on cold-start users where each impression is rarer signal |
| Speech recognition | CTC alignment loss vs. attention decoder cross-entropy | Higher weight on rare words and disfluent speech segments |
| Climate downscaling | Temperature, precipitation, wind-speed targets | Higher weight on extreme events (heatwaves, hurricanes), which are minority samples |
In every row the OUTER controller answers ‘which objective matters more right now?’ and the INNER controller answers ‘which examples inside that objective deserve emphasis?’. The recipe in this section, picking one method per axis and composing them, is a direct way to bring published gains from either column into a single training run.
Pitfalls When Composing Adaptation And Loss Shape
Pitfall 1: Letting the inner weight feed back into the outer norm
When you swap standard MSE for weighted MSE, the per-sample weights increase the magnitude of $\nabla_{\theta_s}\mathcal{L}_{\text{rul}}$ roughly by the average weight (here $\bar{w} \approx 1.5$). GABA SEES this and tightens $\lambda_{\text{rul}}$ further. If you forget that the OUTER axis is reading a SHAPED gradient and accidentally compare its lambdas to the standard-MSE run, you will misread the controller. Always evaluate $\lambda_i^*$ trajectories on the actual composed loss, not on the unshaped one.
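A one-parameter toy makes this pitfall concrete. It assumes a hypothetical shared trunk whose every prediction equals a scalar theta, the demo's eight targets, the moderate weighting rule with max_rul = 125, and a fixed, made-up health norm of 1. Swapping MSE for weighted MSE enlarges g_rul, so the K = 2 closed form hands the RUL task a smaller λ_rul. In this toy the growth factor is not exactly the mean weight, because the largest weights sit on the smallest residuals.

```python
import torch

y_true = torch.tensor([10., 30., 60., 90., 110., 5., 15., 80.])
# Moderate failure bias, same rule as moderate_weighted_mse_loss (max_rul = 125).
w = 1.0 + torch.clamp(1.0 - y_true / 125.0, 0, 1.0)

# Hypothetical one-parameter shared trunk: every prediction equals theta (= 0 here).
theta = torch.zeros((), requires_grad=True)
pred = theta * torch.ones(8)

L_std = ((pred - y_true) ** 2).mean()
L_w = (w * (pred - y_true) ** 2).mean()

# retain_graph=True: both losses share the 'pred' node.
g_std = torch.autograd.grad(L_std, theta, retain_graph=True)[0].abs()
g_w = torch.autograd.grad(L_w, theta)[0].abs()

# Assumed, fixed health-task norm so only the RUL side changes.
g_health = torch.tensor(1.0)
lam_rul_std = g_health / (g_std + g_health)
lam_rul_w = g_health / (g_w + g_health)
print(f"mean w = {w.mean().item():.3f}, g_std = {g_std.item():.1f}, g_w = {g_w.item():.1f}")
print(f"lambda_rul: MSE {lam_rul_std.item():.5f} -> WMSE {lam_rul_w.item():.5f}")
```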
Pitfall 2: Forgetting to detach the GABA lambdas
The PyTorch demo's lines 50–51 call .detach() on both lambdas. Skipping it, or worse, setting create_graph=True in compute_task_grad_norm, produces a different algorithm in which the gradient-norm controller is itself differentiated through. That is closer to GradNorm than to GABA. The wall-clock cost roughly doubles and the convergence behaviour changes.
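A scalar toy shows what detaching buys. It assumes two made-up losses on one shared parameter, L1 = (3θ)² and L2 = 4θ, with the norms built using create_graph=True so they stay differentiable. Detaching gives the GABA update λ1·∂L1/∂θ + λ2·∂L2/∂θ; leaving the lambdas attached adds an extra L1·∂λ1/∂θ + L2·∂λ2/∂θ term, and the gradient changes.

```python
import torch


def composed_grad(detach_lambdas: bool) -> float:
    theta = torch.tensor(2.0, requires_grad=True)
    L1 = (3 * theta) ** 2  # toy "RUL" loss
    L2 = 4 * theta         # toy "health" loss
    # create_graph=True keeps the norms differentiable w.r.t. theta.
    g1 = torch.autograd.grad(L1, theta, create_graph=True)[0].abs()
    g2 = torch.autograd.grad(L2, theta, create_graph=True)[0].abs()
    S = g1 + g2
    lam1, lam2 = g2 / S, g1 / S  # K = 2 closed form
    if detach_lambdas:
        lam1, lam2 = lam1.detach(), lam2.detach()
    total = lam1 * L1 + lam2 * L2
    (grad,) = torch.autograd.grad(total, theta)
    return grad.item()


print("detached:", composed_grad(True))
print("attached:", composed_grad(False))
```

At θ = 2 the detached gradient is 0.1·36 + 0.9·4 = 7.2, while the attached version picks up an extra −1.26 from the lambdas' own derivatives and returns 5.94.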
Pitfall 3: Treating cell B (AMNL) as a strict subset of cell D (GRACE)
Cell D inherits the per-sample shape of cell B, but it does NOT inherit cell B's RMSE. The OUTER axis spends some of the accuracy budget buying NASA-score reductions. When stakeholders ask ‘why is GRACE's RMSE worse than AMNL's?’ the answer is in this section: the orthogonality is real, but each axis moves a different metric. Section 23.2 makes this Pareto trade quantitative.
Pitfall 4: Single-condition datasets
On FD003 (single condition, two faults) the gradient ratio is smaller, GABA's correction is smaller, and the OUTER axis has less to do. Adding the INNER axis on top can over-emphasise rare near-failure samples. Section 21.3 walks through the FD003 case where GRACE underperforms its own siblings — the orthogonality story is sound, but the magnitude of each axis's benefit depends on the dataset.
Takeaway
Multi-task losses have two independent control axes: an outer per-task weight $\lambda_i(t)$ and an inner per-sample weight $w(y_j)$.
GABA owns the outer axis (gradient-magnitude balance). Failure-biased MSE owns the inner axis (loss shape). They live on different indices and therefore compose without interference.
The 2×2 grid {Fixed, GABA} × {MSE, WMSE} enumerates four real methods. Cell A is the baseline, cell B is AMNL, cell C is GABA + standard MSE, cell D is GRACE.
On FD002 the OUTER axis recovers NASA-score (224 vs 356) and the INNER axis recovers RMSE (6.74 vs 7.37). GRACE keeps the OUTER's NASA gain at a small RMSE cost — a deliberate Pareto choice.
The composition fits in six lines of PyTorch: import the paper's helpers, do one forward pass, compute two gradient norms, detach the lambdas, multiply.