Chapter 21

The GRACE Loss Equation

Combining GABA + Weighted MSE

An Equation You Can Read In One Breath

A pharmacist filling a prescription does two things at once. They weigh each active ingredient against the others — more antibiotic, less antihistamine, an exact ratio — and they shape how much of the dose hits each tissue via the delivery vehicle. A capsule with the same total mass can be tuned to release fast or slow, in the stomach or the small intestine. The two adjustments live on different axes, and a working prescription engages both.

GRACE is a prescription. The outer ratio is GABA, deciding how the gradient budget is split between the RUL and health tasks. The inner shaping is the failure-biased weighted MSE, deciding which samples inside the RUL task carry the heaviest squared error. This section writes that prescription as a single equation, walks every symbol, traces the four-stage controller pipeline, computes a worked example end to end, and shows the same equation in production PyTorch.

The headline equation. $$\mathcal{L}_{\text{GRACE}}(t) = \lambda^*_{\text{rul}}(t)\,\Big[\tfrac{1}{N}\sum_{j=1}^{N} w(y_j)\,(\hat{y}_j-y_j)^2\Big] + \lambda^*_{\text{health}}(t)\,\mathcal{L}_{\text{CE}}.$$

Anatomy: Six Symbols, Two Axes

Six symbols carry every piece of GRACE's machinery. Three are per-task (outer), three are per-sample (inner):

| Symbol | Axis | Lives in | Role |
|---|---|---|---|
| $\lambda^*_{\text{rul}}(t),\ \lambda^*_{\text{health}}(t)$ | Outer (per-task) | grace/core/gaba.py | EMA-smoothed, floored, renormalised GABA weights |
| $g_{\text{rul}}(t),\ g_{\text{health}}(t)$ | Outer (per-task) | grace/core/gradient_utils.py | Per-task L2 gradient norms on the shared backbone |
| $\mathcal{L}_{\text{CE}}$ | Outer (per-task) | torch.nn.functional.cross_entropy | Health-classifier loss; computed once per forward pass |
| $w(y_j)$ | Inner (per-sample) | grace/core/weighted_mse.py | Failure-biased ramp $w(y) = 1 + \mathrm{clip}\big(1 - \tfrac{y}{125},\ 0,\ 1\big)$ |
| $(\hat{y}_j - y_j)^2$ | Inner (per-sample) | moderate_weighted_mse_loss | Squared residual per RUL prediction |
| $\tfrac{1}{N}\sum_{j=1}^{N}$ | Inner (per-sample) | Tensor.mean() | Sample average over the mini-batch of size $N$ |

Read the equation left to right: the outer factor $\lambda^*_i(t)$ is a scalar that changes from step to step, driven by the gradient-norm ratio through the EMA; the inner factor $(1/N) \sum_j w(y_j)\,(\hat y_j - y_j)^2$ is a per-batch scalar that depends only on the RUL predictions and targets. Each task loss is scaled by its $\lambda^*_i(t)$ and the two products are summed. That is the entire loss.
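The composition can be sketched in a few lines of plain Python. This is a minimal illustration, not the paper's implementation: the helper names failure_weight, weighted_mse and grace_loss are hypothetical, the λ* values are passed in as fixed numbers (the controller that produces them is covered under the outer expansion), and the toy batch is the one used in the NumPy walkthrough later in this section.

```python
def failure_weight(y, max_rul=125.0):
    """Inner-axis weight w(y) = 1 + clip(1 - y/max_rul, 0, 1)."""
    return 1.0 + min(max(1.0 - y / max_rul, 0.0), 1.0)


def weighted_mse(y_pred, y_true, max_rul=125.0):
    """Inner factor: (1/N) * sum_j w(y_j) * (yhat_j - y_j)**2."""
    n = len(y_true)
    return sum(failure_weight(y, max_rul) * (yh - y) ** 2
               for yh, y in zip(y_pred, y_true)) / n


def grace_loss(lam_rul, lam_health, y_pred, y_true, loss_ce):
    """Hypothetical helper: each task loss times its outer weight, summed."""
    return lam_rul * weighted_mse(y_pred, y_true) + lam_health * loss_ce


# Toy batch and cross-entropy value reused from the NumPy walkthrough.
y_true = [10, 30, 60, 90, 110, 5, 15, 80]
y_pred = [15, 26, 68, 87, 112, -7, 24, 74]
print(round(weighted_mse(y_pred, y_true), 3))                   # 84.115
print(round(grace_loss(0.5, 0.5, y_pred, y_true, 0.6069), 3))   # 42.361
```

With uniform weights the RUL term dominates by two orders of magnitude; the outer controller exists to rebalance exactly this.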

Inner Expansion: The Weighted MSE

Expand the inner term explicitly. For one mini-batch of size $N$:

$$\frac{1}{N}\sum_{j=1}^{N} w(y_j)\,(\hat{y}_j - y_j)^2 = \frac{1}{N}\sum_{j=1}^{N} \Big[1 + \mathrm{clip}\big(1 - \tfrac{y_j}{125},\ 0,\ 1\big)\Big]\,(\hat{y}_j - y_j)^2.$$

The bracketed factor is the per-sample weight. It rewrites cleanly as a piecewise function:

| Region | $y_j$ range | $w(y_j)$ value | Effect on the squared residual |
|---|---|---|---|
| Failure | $y_j = 0$ | 2.00 | Doubled |
| Critical | $0 < y_j \leq 60$ | linear $1.52 \to 2.00$ | Up-weighted up to $2\times$ |
| Mid-life | $60 < y_j \leq 125$ | linear $1.00 \to 1.52$ | Mildly up-weighted |
| Healthy (capped) | $y_j \geq 125$ | 1.00 | Standard MSE, no boost |

The piecewise structure has a physical reading: a prediction error at failure is twice as costly as the same error during early-life operation. The clip at $y \geq 125$ matches the paper's piecewise-linear RUL target: for engines that have not started visibly degrading yet, every cycle looks the same to the labels, so the loss treats every cycle the same too. There is no asymmetry to encode at the healthy end.
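To check the table row by row, here is a small sketch that evaluates the moderate ramp at one probe point per region. failure_weight is a hypothetical stand-in for the weight inside grace/core/weighted_mse.py, written from the formula above rather than copied from that file.

```python
def failure_weight(y, max_rul=125.0):
    """Moderate ramp: w(y) = 1 + clip(1 - y/max_rul, 0, 1)."""
    return 1.0 + min(max(1.0 - y / max_rul, 0.0), 1.0)


# One probe per table region, plus a target far beyond the cap.
for y in (0, 30, 60, 90, 125, 200):
    print(y, round(failure_weight(y), 2))
# 0 2.0
# 30 1.76
# 60 1.52
# 90 1.28
# 125 1.0
# 200 1.0

# Same residual of 10 cycles, at failure versus on the healthy plateau.
print(failure_weight(0) * 10 ** 2 / (failure_weight(125) * 10 ** 2))  # 2.0
```

The upper clip never fires for valid labels (they are non-negative); the lower clip is what keeps every weight beyond the cap at exactly 1.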

Outer Expansion: The Four-Stage GABA Pipeline

The outer factor is not a closed-form expression in the usual sense; it is the output of a four-stage controller. Stage by stage, with $K=2$ tasks:

Stage 1: Closed form

$$\lambda^{\text{raw}}_i(t) = \frac{\sum_{j\neq i} g_j(t)}{(K-1)\,\sum_{j} g_j(t)}, \qquad g_i(t) = \bigl\| \nabla_{\theta_s}\mathcal{L}_i(t)\bigr\|_2.$$

Inverse-ratio: the task with the smaller gradient gets the larger raw weight. For $K=2$ this collapses to $\lambda^{\text{raw}}_{\text{rul}} = g_{\text{health}}/(g_{\text{rul}}+g_{\text{health}})$.
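Stage 1 for general $K$, as a sketch written directly from the formula above (the function name gaba_raw is hypothetical). It shows that the raw weights always sum to one, since each $g_j$ appears in exactly $K-1$ numerators, and that for $K=2$ it reduces to the swapped ratio.

```python
def gaba_raw(grad_norms):
    """Stage 1 closed form for K tasks:
    lambda_raw_i = sum_{j != i} g_j / ((K - 1) * sum_j g_j).
    """
    k = len(grad_norms)
    if k < 2:
        raise ValueError("GABA needs at least two tasks")
    total = sum(grad_norms)
    # sum_{j != i} g_j is the total minus the task's own norm.
    return [(total - g_i) / ((k - 1) * total) for g_i in grad_norms]


# K = 2 with the C-MAPSS norms used later in this section: (g_rul, g_health).
raw = gaba_raw([26.4016, 0.037833])
print([round(x, 6) for x in raw])                        # [0.001431, 0.998569]

# K = 3: the smallest gradient norm receives the largest raw weight.
print([round(x, 4) for x in gaba_raw([1.0, 2.0, 3.0])])  # [0.4167, 0.3333, 0.25]
```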

Stage 2: Exponential moving average

$$\bar\lambda_i(t) = \beta\,\bar\lambda_i(t-1) + (1-\beta)\,\lambda^{\text{raw}}_i(t), \qquad \bar\lambda_i(0) = 1/K.$$

Smooths out per-batch noise in $\lambda^{\text{raw}}_i(t)$. Paper default $\beta = 0.99$, an effective memory of $1/(1-\beta) \approx 100$ steps. The starting value $1/K$ is uniform.
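A short sketch of the Stage 2 recursion, assuming for illustration that the raw weight is held fixed at its step-1 value (in real training it changes every batch). With a constant input the EMA has the closed form $\bar\lambda(n) = \beta^n \bar\lambda(0) + (1-\beta^n)\,\lambda^{\text{raw}}$, which the loop confirms; after $1/(1-\beta) = 100$ steps about 37% (roughly $1/e$) of the initial value is still present.

```python
def ema_step(prev, raw, beta=0.99):
    """Stage 2: lambda_bar(t) = beta * lambda_bar(t-1) + (1 - beta) * lambda_raw(t)."""
    return beta * prev + (1.0 - beta) * raw


beta = 0.99
lam0 = 0.5          # lambda_bar_rul(0) = 1/K with K = 2
raw_rul = 0.001431  # assumption: raw weight frozen at its step-1 value

lam = lam0
for _ in range(100):
    lam = ema_step(lam, raw_rul, beta)

# Constant input => closed form: the initial value decays as beta**n.
closed = beta ** 100 * lam0 + (1.0 - beta ** 100) * raw_rul
print(abs(lam - closed) < 1e-12)   # True
print(round(beta ** 100, 3))       # 0.366
```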

Stage 3: Per-element floor

$$\tilde\lambda_i(t) = \max\bigl(\bar\lambda_i(t),\ \varepsilon\bigr), \qquad \varepsilon = 0.05.$$

Prevents either task from being driven to zero. Without this clamp, after enough EMA iterations the over-gradient task (RUL on C-MAPSS) would collapse to a numerically negligible weight and the joint optimisation would degenerate to single-task health training.

Stage 4: Renormalisation

$$\lambda^*_i(t) = \frac{\tilde\lambda_i(t)}{\sum_{k=1}^{K} \tilde\lambda_k(t)}.$$

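Restores the sum-to-one property after the floor breaks it. The final $\lambda^*_i(t)$ is a proper probability distribution over the $K$ tasks at every step. Because renormalisation divides by a total above 1 whenever the floor fires, a floored weight ends up slightly below $\varepsilon$ (0.0477 rather than 0.05 on C-MAPSS).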
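Stages 3 and 4 together, as a hypothetical helper floor_and_renorm. The input is an assumed EMA state that has settled on the step-1 raw weights (0.001431, 0.998569); the output reproduces the 0.0477 converged weight quoted in the interactive section. For $K=2$, whenever the EMA state sums to one, the renormalised floored weight cannot fall below $\varepsilon/(1+\varepsilon) \approx 0.0476$.

```python
def floor_and_renorm(lam_bar, eps=0.05):
    """Stage 3 (per-element floor) followed by Stage 4 (renormalise to sum 1)."""
    floored = [max(x, eps) for x in lam_bar]
    total = sum(floored)
    return [f / total for f in floored]


# Assumed input: an EMA state that has settled on the step-1 raw weights.
lam_star = floor_and_renorm([0.001431, 0.998569])
print([round(x, 4) for x in lam_star])   # [0.0477, 0.9523]

# Floor inactive: weights already above eps pass through unchanged.
print(floor_and_renorm([0.5, 0.5]))      # [0.5, 0.5]
```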

The full chain in one breath. $g_i \to \lambda^{\text{raw}}_i \to \bar\lambda_i \to \tilde\lambda_i \to \lambda^*_i$. Each arrow is a deterministic function of its predecessor; the EMA arrow additionally reads the previous step's state $\bar\lambda_i(t-1)$. The composed map is what the GABA controller in grace/core/gaba.py:88 computes per training step.

What .backward() Actually Computes

The whole point of writing the equation is to differentiate it. The gradient with respect to a backbone parameter $\theta_s$ is

$$\nabla_{\theta_s}\mathcal{L}_{\text{GRACE}}(t) = \lambda^*_{\text{rul}}(t)\,\nabla_{\theta_s}\mathcal{L}^{\text{WMSE}}_{\text{rul}}(t) + \lambda^*_{\text{health}}(t)\,\nabla_{\theta_s}\mathcal{L}_{\text{CE}}(t).$$

Two structural facts to internalise:

  • The lambdas $\lambda^*_i(t)$ are constants from autograd's point of view. The implementation calls .detach() before the multiply (gaba.py:131), so PyTorch never tries to differentiate through the gradient-norm computation. Forgetting this turns GABA into a meta-gradient method (closer to GradNorm) with ~2× memory and a different convergence law.
  • The inner per-sample weights $w(y_j)$ depend only on $y_j$, not on $\hat y_j$, so they are also constants for autograd. The gradient simplifies to $(2/N)\sum_j w(y_j)\,(\hat y_j - y_j)\,\nabla_{\theta_s}\hat y_j$, the standard MSE backward multiplied per-sample by $w(y_j)$.
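Both bullets can be checked numerically without autograd. The sketch below, in plain Python with hypothetical helper names, compares the analytic gradient $(2/N)\,w(y_j)(\hat y_j - y_j)$, scaled by a $\lambda^*_{\text{rul}}$ treated as a constant, against central finite differences with respect to each prediction $\hat y_j$. The health term $\lambda^*_{\text{health}}\mathcal{L}_{\text{CE}}$ does not depend on the RUL predictions, so it contributes nothing here; the step from $\hat y_j$ to $\theta_s$ is ordinary backprop.

```python
def failure_weight(y, max_rul=125.0):
    """Inner-axis weight; depends on the target only, never on the prediction."""
    return 1.0 + min(max(1.0 - y / max_rul, 0.0), 1.0)


def wmse(y_pred, y_true):
    """(1/N) * sum_j w(y_j) * (yhat_j - y_j)**2."""
    n = len(y_true)
    return sum(failure_weight(y) * (yh - y) ** 2 for yh, y in zip(y_pred, y_true)) / n


def wmse_grad(y_pred, y_true):
    """Analytic d(WMSE)/d(yhat_j) = (2/N) * w(y_j) * (yhat_j - y_j)."""
    n = len(y_true)
    return [2.0 / n * failure_weight(y) * (yh - y) for yh, y in zip(y_pred, y_true)]


def central_diff(f, x, h=1e-4):
    """Numerical gradient of f at x, one coordinate at a time."""
    grads = []
    for j in range(len(x)):
        up = list(x)
        down = list(x)
        up[j] += h
        down[j] -= h
        grads.append((f(up) - f(down)) / (2.0 * h))
    return grads


y_true = [10.0, 30.0, 60.0, 90.0, 110.0, 5.0, 15.0, 80.0]
y_pred = [15.0, 26.0, 68.0, 87.0, 112.0, -7.0, 24.0, 74.0]
lam_rul = 0.495014  # treated as a detached constant, like lambda*_rul

analytic = [lam_rul * g for g in wmse_grad(y_pred, y_true)]
numeric = central_diff(lambda p: lam_rul * wmse(p, y_true), y_pred)
print(max(abs(a - b) for a, b in zip(analytic, numeric)))  # tiny: roundoff only
```

Because the loss is quadratic in each $\hat y_j$, the central difference is exact up to floating-point roundoff, which is why the agreement is so tight.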

Interactive: Deconstructing The Equation

Click any symbol pill below the rendered equation to see its definition, code reference, and current numeric value. The three sliders drive the OUTER controller live: vary the gradient ratio, the EMA $\beta$, and the number of EMA steps to watch $\lambda^*_i$ evolve from $(0.5, 0.5)$ at $t=0$ toward the floor-clamped converged value at large $t$.

Try this. Set $\log_{10}(g_{\text{rul}}/g_{\text{health}}) = 2.84$ (the C-MAPSS measured value), $\beta = 0.99$, and slide steps from 0 → 1000. You will see $\lambda^*_{\text{rul}}$ drift down from 0.5 toward the floor at 0.0477; the EMA weight crosses $\varepsilon = 0.05$ around step 230, after which the floor and the renormalisation take over. That trajectory is what the paper's gradient_logger.py records during real training.
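The slider experiment can be approximated offline. This is a sketch under stated assumptions, not the explorer's actual code: the gradient ratio is held constant, there is no warmup, and the EMA state keeps the unfloored smoothed value so the floor and renormalisation act only on the output. Under those assumptions the floor first fires at step 232 and $\lambda^*_{\text{rul}}$ ends at 0.0477 after 1000 steps.

```python
def simulate_lambda_rul(log10_ratio=2.84, beta=0.99, eps=0.05, steps=1000):
    """Replay stages 1-4 for K = 2 with a constant gradient ratio.

    Assumptions: the EMA state keeps the unfloored smoothed value (floor and
    renormalisation act only on the output), and there is no warmup period.
    """
    ratio = 10.0 ** log10_ratio                          # g_rul / g_health
    raw = [1.0 / (1.0 + ratio), ratio / (1.0 + ratio)]   # stage 1, K = 2
    bar = [0.5, 0.5]                                     # lambda_bar(0) = 1/K
    first_floor_step = None
    lam_rul = bar[0]
    for t in range(1, steps + 1):
        bar = [beta * b + (1.0 - beta) * r for b, r in zip(bar, raw)]  # stage 2
        floored = [max(b, eps) for b in bar]                          # stage 3
        lam_rul = floored[0] / sum(floored)                           # stage 4
        if first_floor_step is None and bar[0] < eps:
            first_floor_step = t
    return lam_rul, first_floor_step


lam_rul, first_floor = simulate_lambda_rul()
print(round(lam_rul, 4), first_floor)  # 0.0477 232
```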

Python: One Training Step, Symbol By Symbol

Walk every line. The script reproduces the four-stage GABA pipeline by hand, computes the failure-biased MSE explicitly, and prints every intermediate value — raw, smoothed, floored, renormalised — so the equation can be checked numerically before we trust it to PyTorch.

One GRACE step, every symbol, in NumPy
grace_loss_equation_demo.py
Line 1: docstring

States the contract: walk the equation symbol by symbol, in NumPy. No autograd, no PyTorch: just the algebra, so the next file's tensor version is unambiguous.

Line 3: import numpy as np

Numerical-array library. Provides np.array, np.clip, np.maximum, np.round used throughout.

EXECUTION STATE
📚 numpy = ndarray + vectorised arithmetic + reductions. Standard alias np. Every numeric value below is an np.ndarray (dense, contiguous, C-backed).
→ why np? = np is the standard community alias, used throughout NumPy's own documentation. Lets us write np.clip(...) instead of numpy.clip(...): shorter, and instantly recognised by readers.
Line 6: # ---------- Mini-batch ----------

Section divider. The next four lines set up the toy mini-batch (y_true, y_pred, N, L_health) that walks every symbol of the GRACE equation.

EXECUTION STATE
Section role = Inputs to the equation. Below this comment: ground-truth RUL, predicted RUL, batch size, and a frozen L_health value from the same forward pass.
Line 7: y_true = np.array([10, 30, 60, 90, 110, 5, 15, 80], ...)

Eight ground-truth RUL values. Two near-failure (5, 10), two early-life (90, 110), four mid-life. The mix lets the failure-bias weight w(y) actually do something.

EXECUTION STATE
y_true (8,) = [ 10., 30., 60., 90., 110., 5., 15., 80.]
→ why this mix? = On C-MAPSS, near-failure windows are a minority of the training samples. Eight handcrafted values put both regimes, near-failure and early-life, into one tractable batch so that w(y) has something to act on.
Line 8: y_pred = np.array([15, 26, 68, 87, 112, -7, 24, 74], ...)

Predictions chosen so the near-failure samples at 0-based indices j=5 (y=5) and j=6 (y=15) have notably bigger residuals (−12 and 9) than the early-life ones (−3 and 2). The failure-bias weight will amplify those.

EXECUTION STATE
y_pred (8,) = [ 15., 26., 68., 87., 112., -7., 24., 74.]
y_pred − y_true = [ 5., -4., 8., -3., 2., -12., 9., -6.]
Line 9: N = len(y_true)

Sample count. Used as the denominator of the per-sample mean inside the WMSE.

EXECUTION STATE
📚 len(ndarray) = Returns the size of the first axis. For a 1-D array of length 8 → 8.
N = 8
Line 10: L_health = 0.6069

Cross-entropy on the 3-class health head, recorded from the same forward pass. Held constant in this section so the algebra is visible.

EXECUTION STATE
L_health = 0.6069. Typical value for a partly-trained 3-class softmax (~ ln 3 / 1.8).
→ why frozen? = Pinning L_health lets us isolate the OUTER (lambda) and INNER (w·r²) machinery without the noise of also re-deriving the cross-entropy. Real training computes this from F.cross_entropy(hp_logits, hp_target).
Line 13: # ---------- Inner axis: weighted MSE ----------

Section divider. Below this comment lives the INNER half of GRACE — the failure-biased weighted MSE. Matches the Inner row in the §anatomy table.

EXECUTION STATE
What this section computes = L_rul^WMSE = (1/N) Σ_j w(y_j) · (ŷ_j - y_j)². Three lines: define w(y), apply it, average.
Line 14: def w_failure(y, max_rul=125.0):

The inner-axis weight function. Mathematically: w(y) = 1 + clip(1 − y/R_max, 0, 1). At y=0 → w=2 (failure: doubled penalty). At y ≥ R_max → w=1 (healthy: vanilla MSE). Linear ramp in between. This is the 'moderate' variant from grace/core/weighted_mse.py; the paper also defines 'mild' (slope 0.5) and 'steep' (slope 2.0) variants.

EXECUTION STATE
⬇ input: y = Scalar or ndarray of RUL targets. In this script y_true is the (8,) batch.
→ y_true (8,) — what gets passed in = [ 10., 30., 60., 90., 110., 5., 15., 80.]
→ y purpose = Each y_j is the remaining-useful-life label for sample j. Lower y → closer to failure → bigger weight w(y) wanted.
⬇ input: max_rul = 125.0 = RUL cap. Matches the paper's piecewise-linear RUL convention: any engine with y > 125 cycles to failure is treated as 'fully healthy' for both labels and weights. Default keyword arg; the caller may override it.
→ why 125? = C-MAPSS sub-datasets show that RUL labels above ~125 cycles are noisy and indistinguishable from a sensor-quality standpoint. The paper's ablation locks this constant at 125 across FD001–FD004.
⬆ returns = ndarray same shape as y. Per-sample weight in [1, 2]. For y_true above: [1.92, 1.76, 1.52, 1.28, 1.12, 1.96, 1.88, 1.36].
Line 15: docstring: """Per-sample weight w(y) = 1 + clip(1 - y/max_rul, 0, 1)."""

Single-line docstring. Tools like help(w_failure), Sphinx, and IDEs surface this. States the closed-form expression so a reader can match the code to the equation in §inner without scrolling.

Line 16: return 1.0 + np.clip(1.0 - y / max_rul, 0.0, 1.0)

Linear ramp from 2 (y=0, failure) to 1 (y ≥ max_rul). The clip prevents w < 1 when y > max_rul, which would down-weight healthy-engine errors below plain MSE (and, for y > 2·max_rul, make the weight negative).

EXECUTION STATE
y / max_rul = [0.080, 0.240, 0.480, 0.720, 0.880, 0.040, 0.120, 0.640]
1 − y/max_rul = [0.920, 0.760, 0.520, 0.280, 0.120, 0.960, 0.880, 0.360]
📚 np.clip(arr, lo, hi) = Element-wise: arr_i ↦ max(lo, min(hi, arr_i)). Pins to [0, 1] so w stays in [1, 2].
⬆ return: w(y) (8,) = [1.920, 1.760, 1.520, 1.280, 1.120, 1.960, 1.880, 1.360]
Line 19: w = w_failure(y_true)

Apply the weight function to the whole batch. Vectorised — one NumPy call, no Python loop.

EXECUTION STATE
w (8,) = [1.920, 1.760, 1.520, 1.280, 1.120, 1.960, 1.880, 1.360]
→ reading = Sample 5 (y=5) gets weight 1.96 — almost double. Sample 4 (y=110) gets weight 1.12 — basically unweighted.
Line 20: residual_sq = (y_pred - y_true) ** 2

Element-wise squared residuals. Same as the standard MSE inner sum, before the weights.

EXECUTION STATE
y_pred - y_true = [5., -4., 8., -3., 2., -12., 9., -6.]
📚 ** 2 = Element-wise square (exponent 2, broadcast as a scalar). Equivalent to np.square.
residual_sq (8,) = [25., 16., 64., 9., 4., 144., 81., 36.]
Line 21: L_rul_w = (w * residual_sq).mean()

The inner half of GRACE. Element-wise multiply w · r², then average. Equivalent to the formal expression (1/N) Σ_j w(y_j)·(ŷ_j - y_j)².

EXECUTION STATE
w * residual_sq = [ 48.000, 28.160, 97.280, 11.520, 4.480, 282.240, 152.280, 48.960]
→ contribution = Sample 5 (j=5) contributes 282.24 — that is 42% of the weighted sum 672.92, even though it is only 1/8 of the batch (12.5%). The inner axis has shifted attention toward the near-failure sample.
📚 .mean() = Sum / N. Here 672.92 / 8 = 84.115.
L_rul_w = 84.1150
Line 24: # ---------- Outer axis: GABA closed form, EMA, floor, renorm ----------

Section divider. Below this comment lives the OUTER half of GRACE, the four-stage GABA controller. Inside gaba_step, each stage is a comment marker followed by one or two lines of code.

EXECUTION STATE
Pipeline (4 stages) = g_i → λ_raw → λ̄ (EMA) → λ̃ (floor) → λ* (renorm)
→ input = Per-task gradient norms g_rul, g_health (computed once per training step on the shared backbone).
→ output = λ* — a probability vector over K=2 tasks that multiplies each task loss.
Line 25: def gaba_step(g_rul, g_health, prev, beta=0.99, eps=0.05):

The OUTER half of GRACE. Encapsulates the entire four-stage GABA pipeline (closed form → EMA → floor → renormalise) in pure NumPy. Returns the four intermediate vectors so the equation explorer above can highlight every stage. Production code (grace/core/gaba.py) returns only the final λ*.

EXECUTION STATE
⬇ input: g_rul = L2 norm of the RUL loss gradient on the shared backbone: ||∂L_rul/∂θ_shared||₂. In this script: 26.4016.
→ why backbone-only? = GABA balances tasks on the parameters they SHARE. Head-only parameters trivially have a one-to-one mapping to a single task and cannot conflict.
⬇ input: g_health = Same for the health loss: ||∂L_health/∂θ_shared||₂. In this script: 0.037833. Roughly 698x smaller than g_rul — the structural imbalance GABA fixes.
⬇ input: prev = Tuple (λ̄_rul(t-1), λ̄_health(t-1)). EMA state carried over from the previous training step. Initialised to (0.5, 0.5) at t=0. In this script: (0.5, 0.5) — first step.
→ why a tuple? = Stays cheap to pass between training iterations. Real PyTorch code keeps this state in a registered buffer (self.ema_weights) on the GABALoss module. Same idea, autograd-safe.
⬇ input: beta=0.99 = EMA smoothing coefficient. β=1 → frozen at prev (no learning). β=0 → no smoothing (echoes raw). Paper default 0.99 = effective memory of 1/(1−β) ≈ 100 steps. Default keyword arg.
⬇ input: eps=0.05 = Per-task floor. Prevents either task from being completely silenced after the EMA collapses one weight to ~0. Paper default = 5%. Default keyword arg.
⬆ returns = Tuple (normed, raw, smoothed, floored). Each is a NumPy array of shape (2,) for K=2 tasks. Caller can unpack all four to inspect the pipeline; production code unpacks only `normed`.
Line 26: docstring: """One step of the four-stage GABA pipeline. Returns (normed, raw, smoothed, floored)."""

Reminds the reader that each call advances the controller by ONE step. To simulate convergence, call gaba_step in a loop, threading the returned `smoothed` vector (the EMA state) back in as the next `prev`.

Line 27: # Stage 1: closed form (raw)

Marker for Stage 1 of the four-stage pipeline. The closed form is a function of the gradient norms only: no state, no smoothing, no floor. It produces λ_raw, the 'ideal' per-step weight vector that perfectly inverts the gradient imbalance.

EXECUTION STATE
Stage 1 output = λ_raw — pure inverse-ratio of gradient norms. Sums to 1 by construction.
→ formula (K=2) = λ_raw_rul = g_health / (g_rul + g_health); λ_raw_health = g_rul / (g_rul + g_health). Note the SWAP: each task's weight uses the OTHER task's gradient norm.
→ next two lines do this = Line 28 computes S = denominator. Line 29 computes the swapped numerators.
Line 28: S = g_rul + g_health

Stage 1a: total gradient norm. The denominator of the K=2 closed form.

EXECUTION STATE
S = 26.4016 + 0.037833 = 26.439433
Line 29: raw = np.array([g_health / S, g_rul / S])

Stage 1b: closed form. Inverse-ratio. The task with the SMALLER gradient gets the LARGER weight. Note the swap — element 0 is rul but uses g_health.

EXECUTION STATE
raw[0] = g_health / S = 0.037833 / 26.439433 = 0.001431
raw[1] = g_rul / S = 26.4016 / 26.439433 = 0.998569
→ reason for swap = If RUL gradients are 698x bigger, RUL is already over-represented in the combined gradient. To balance, give RUL only 0.14% of the loss budget.
Line 30: # Stage 2: EMA smoothing

Marker for Stage 2. The closed form is too noisy to use directly — gradient norms can swing 10× between mini-batches. An exponential moving average filters out the high-frequency noise.

EXECUTION STATE
Stage 2 output = λ̄ — EMA-smoothed weights. Inherits the sum-to-one property from λ_raw and prev.
→ formula = λ̄(t) = β · λ̄(t−1) + (1−β) · λ_raw(t)
→ effective window = 1/(1−β). For β=0.99: ~100 steps. For β=0.9: ~10 steps. For β=0.5: ~2 steps.
Line 31: smoothed = beta * np.array(prev) + (1.0 - beta) * raw

Stage 2: EMA. New state = β · old state + (1-β) · new measurement. With β=0.99, only 1% of the new raw value enters per step — rapid mood swings get filtered out.

EXECUTION STATE
📚 np.array(prev) = Convert tuple (0.5, 0.5) → ndarray for vectorised arithmetic.
beta * np.array(prev) = 0.99 * [0.5, 0.5] = [0.495, 0.495]
(1 - beta) * raw = 0.01 * [0.001431, 0.998569] = [0.0000143, 0.0099857]
smoothed = [0.495014, 0.504986]
→ reading = After ONE step, the smoothed value is barely different from the (0.5, 0.5) starting point. This is the 'why warmup matters' phenomenon: GABA needs ~100 steps before its output reflects the gradient-imbalance signal.
Line 32: # Stage 3: floor

Marker for Stage 3. After many EMA steps with a strong gradient imbalance, one weight can drift toward 0 — silencing that task entirely. The floor clamps each weight at ε so no task can be fully ignored.

EXECUTION STATE
Stage 3 output = λ̃ — element-wise max(λ̄, ε). MAY break the sum-to-one property → Stage 4 fixes that.
→ formula = λ̃_i = max(λ̄_i, ε)
→ when active? = Only when λ̄_i < ε. On step 1 with prev=(0.5, 0.5), neither weight is anywhere near ε=0.05, so the floor is inactive. With the C-MAPSS imbalance and β=0.99, λ̄_rul drops below 0.05 after roughly 230 EMA steps and the floor activates.
Line 33: floored = np.maximum(smoothed, eps)

Stage 3: per-element floor. If a smoothed weight has dropped below ε=0.05, clamp it up. Prevents either task from being completely silenced after many EMA steps in extreme regimes.

EXECUTION STATE
📚 np.maximum(a, b) = Element-wise max of two arrays (or array + scalar). Different from np.max which is a reduction.
smoothed > eps? = [0.495 > 0.05 → no clamp, 0.505 > 0.05 → no clamp]
floored = [0.495014, 0.504986] (unchanged on this step)
→ when does clamp fire? = Run the explorer with EMA steps n=500: smoothed_rul drops to ~0.005 — below the floor — and gets clamped to 0.05. The renormalisation then pushes lam_rul down to ~0.048.
Line 34: # Stage 4: renormalise so weights sum to 1

Marker for Stage 4. The floor in Stage 3 may have made the weights sum to slightly more than 1.0. Renormalisation restores the simplex constraint so λ* is a proper probability distribution over the K tasks.

EXECUTION STATE
Stage 4 output = λ* — final weights used to multiply the loss. Guaranteed to sum to exactly 1.0.
→ formula = λ*_i = λ̃_i / Σ_k λ̃_k
→ why bother? = Without renorm, summing weighted losses with Σλ ≠ 1 silently changes the effective learning rate. Keeping Σλ = 1 means the GRACE loss has a stable magnitude no matter how the floor activated.
Line 35: normed = floored / floored.sum()

Stage 4: renormalise so the two weights sum to exactly 1. Required because the floor breaks the sum-to-one property of the EMA-smoothed pair.

EXECUTION STATE
📚 .sum() = ndarray reduction. Adds all elements: 0.495014 + 0.504986 = 1.000000.
normed = [0.495014, 0.504986]
normed.sum() check = 1.000000 — exactly. Guaranteed by construction.
Line 36: return normed, raw, smoothed, floored

Return all four intermediates so the caller can print or visualise the full pipeline. Production code returns only `normed`.

EXECUTION STATE
⬆ tuple of 4 vectors = (λ*, raw, smoothed, floored). Each shape (2,) for K=2 tasks.
→ why all four? = The interactive viz above and the §table below both display every stage. Returning only λ* would force a re-run to inspect Stage 1 / 2 / 3.
Line 39: # Real measurements from the same forward pass (chapter 18 §1)

Annotation: the next two lines (g_rul, g_health) come from real C-MAPSS training metrics, not toy values. Reproducible — same model, same seed, same forward pass as chapter 18 §1.

EXECUTION STATE
Why this matters = The 698x ratio between g_rul and g_health is not a contrived demo number — it is the structural imbalance of multi-task learning on C-MAPSS, the exact problem GABA was designed to fix.
Line 40: g_rul, g_health = 26.4016, 0.037833

Per-task gradient norms reproduced from chapter 18 §1 — same model, same seed, same forward pass. Hard-coded here so the equation walk-through is reproducible without running PyTorch.

EXECUTION STATE
g_rul = 26.4016 — paper-canonical L2 norm of dL_rul/dtheta_shared.
g_health = 0.037833 — same for L_health.
ratio = 26.4016 / 0.037833 = 698x — the structural imbalance GABA fixes.
Line 41: prev_lam = (0.5, 0.5)

EMA initial state. Paper convention: the controller starts uniform and drifts toward the gradient-balanced answer. This matches gaba.py:54 register_buffer('ema_weights', torch.ones(K)/K).

EXECUTION STATE
prev_lam (tuple, len 2) = (0.5, 0.5) — uniform over K=2 tasks. Equivalent to torch.ones(2)/2.
→ why uniform? = Before the first step, the controller has no information about the gradient ratio. Uniform is the maximum-entropy starting point — the EMA then walks it toward the inverse-ratio answer over ~100 steps.
→ why a Python tuple? = The NumPy demo passes state as a plain tuple. Real PyTorch code uses `register_buffer` so the state moves with the module (.to(device), .state_dict()) and survives checkpoint round-trips.
Line 42: lam, raw, sm, fl = gaba_step(g_rul, g_health, prev_lam)

Run one full GABA step.

EXECUTION STATE
lam (final λ*) = [0.495014, 0.504986]
raw = [0.001431, 0.998569]
sm (after EMA) = [0.495014, 0.504986]
fl (after floor) = [0.495014, 0.504986]
→ caveat = After ONE step the EMA has barely moved off (0.5, 0.5). The published GRACE controllers run for hundreds of steps — see the slider in the Interactive section above.
Line 45: # ---------- Compose: outer * inner ----------

Section divider. The next line is the entire equation in one expression. Inner axis (L_rul_w, L_health) and outer axis (lam) meet here.

EXECUTION STATE
What composition means = Weighted sum: λ*_rul · L_rul^WMSE + λ*_health · L_CE. Two scalar multiplies, one scalar add, one scalar output.
Line 46: L_GRACE = lam[0] * L_rul_w + lam[1] * L_health

The whole equation, in one line: λ*_rul · L_WMSE + λ*_health · L_CE. After a single EMA step the lambdas are still close to (0.5, 0.5), so this looks similar to the AMNL cell from section 21·1.

EXECUTION STATE
lam[0] * L_rul_w = 0.495014 * 84.1150 = 41.6381
lam[1] * L_health = 0.504986 * 0.6069 = 0.3065
L_GRACE = 41.9446
→ vs converged = After thousands of steps, lam[0] would settle at ~0.048 (floor active) and L_GRACE ≈ 4.59. The interactive viz lets you scrub between these two regimes.
Line 49: # ---------- Inspect every symbol ----------

Section divider. Below this comment we print every intermediate value so the equation can be checked numerically before trusting it to PyTorch. Each print line names one symbol from the headline equation.

EXECUTION STATE
Why print every symbol? = Reading the formula and reading the values must agree. Printing y_true → r² → w → w·r² → L_rul_w then g_rul → raw → EMA → floored → λ* → L_GRACE walks the equation top-to-bottom.
Line 50: print(f"y_true = {y_true.astype(int).tolist()}")

Pretty-print y_true.

EXECUTION STATE
Output = y_true = [10, 30, 60, 90, 110, 5, 15, 80]
Line 51: print(f"y_pred = {y_pred.astype(int).tolist()}")

Pretty-print y_pred.

EXECUTION STATE
Output = y_pred = [15, 26, 68, 87, 112, -7, 24, 74]
Line 52: print(f"residual^2 = {residual_sq.astype(int).tolist()}")

Squared residuals.

EXECUTION STATE
Output = residual^2 = [25, 16, 64, 9, 4, 144, 81, 36]
Line 53: print(f"w(y) = {np.round(w, 3).tolist()}")

The inner per-sample weights.

EXECUTION STATE
Output = w(y) = [1.92, 1.76, 1.52, 1.28, 1.12, 1.96, 1.88, 1.36]
Line 54: print(f"w * r^2 = {np.round(w * residual_sq, 2).tolist()}")

Per-sample weighted squared residuals — the elements that go into L_rul_w.

EXECUTION STATE
Output = w * r^2 = [48.0, 28.16, 97.28, 11.52, 4.48, 282.24, 152.28, 48.96]
Line 55: print(f"L_rul^WMSE = {L_rul_w:.4f}")

Final value of the inner half.

EXECUTION STATE
Output = L_rul^WMSE = 84.1150
Line 56: print() (blank line)

Empty print emits a blank line. Visually separates the inner-axis output above from the outer-axis output below — a common Python idiom for sectioning printed reports.

EXECUTION STATE
📚 print() with no args = Writes only the end character (default '\n'). Equivalent to print('').
Line 57: print(f"g_rul = {g_rul:.4f} g_health = {g_health:.6f}")

Per-task gradient norms — input to the OUTER axis.

EXECUTION STATE
Output = g_rul = 26.4016 g_health = 0.037833
Line 58: print(f"raw lambdas = ({raw[0]:.6f}, {raw[1]:.6f})")

Stage 1: closed form output. Pure inverse-ratio.

EXECUTION STATE
Output = raw lambdas = (0.001431, 0.998569)
Line 59: print(f"EMA lambdas = ({sm[0]:.6f}, {sm[1]:.6f}) (β=0.99, prev=0.5)")

Stage 2: smoothing. After one step, barely different from prev.

EXECUTION STATE
Output = EMA lambdas = (0.495014, 0.504986) (β=0.99, prev=0.5)
Line 60: print(f"floored = ({fl[0]:.6f}, {fl[1]:.6f}) (ε=0.05)")

Stage 3: floor. Inactive on this step because both EMA lambdas are above 0.05.

EXECUTION STATE
Output = floored = (0.495014, 0.504986) (ε=0.05)
Line 61: print(f"final λ* = ({lam[0]:.6f}, {lam[1]:.6f}) (renormalised)")

Stage 4: renormalise. Trivial here — already sums to 1.

EXECUTION STATE
Output = final λ* = (0.495014, 0.504986) (renormalised)
Line 62: print() (blank line)

Empty print emits a blank line. Separates the four-stage pipeline output from the final composed L_GRACE line below.

EXECUTION STATE
📚 print() = No args ⇒ writes a single newline. Used here purely for layout.
Line 63: print(f"L_GRACE = {lam[0]:.4f}*{L_rul_w:.4f} + {lam[1]:.4f}*{L_health:.4f}")

Show the composition formula with substituted numbers.

EXECUTION STATE
Output = L_GRACE = 0.4950*84.1150 + 0.5050*0.6069
Line 64: print(f" = {L_GRACE:.4f}")

Final scalar.

EXECUTION STATE
Final output =
y_true       = [10, 30, 60, 90, 110, 5, 15, 80]
y_pred       = [15, 26, 68, 87, 112, -7, 24, 74]
residual^2   = [25, 16, 64, 9, 4, 144, 81, 36]
w(y)         = [1.92, 1.76, 1.52, 1.28, 1.12, 1.96, 1.88, 1.36]
w * r^2      = [48.0, 28.16, 97.28, 11.52, 4.48, 282.24, 152.28, 48.96]
L_rul^WMSE   = 84.1150

g_rul        = 26.4016    g_health = 0.037833
raw lambdas  = (0.001431, 0.998569)
EMA lambdas  = (0.495014, 0.504986)    (β=0.99, prev=0.5)
floored      = (0.495014, 0.504986)    (ε=0.05)
final λ*     = (0.495014, 0.504986)    (renormalised)

L_GRACE      = 0.4950*84.1150 + 0.5050*0.6069
             = 41.9446
```python
"""GRACE loss equation — every symbol, every value, in NumPy."""

import numpy as np


# ---------- Mini-batch ----------
y_true = np.array([10, 30, 60, 90, 110,  5, 15, 80], dtype=float)
y_pred = np.array([15, 26, 68, 87, 112, -7, 24, 74], dtype=float)
N      = len(y_true)
L_health = 0.6069                    # cross-entropy from same forward pass


# ---------- Inner axis: weighted MSE ----------
def w_failure(y, max_rul=125.0):
    """Per-sample weight w(y) = 1 + clip(1 - y/max_rul, 0, 1)."""
    return 1.0 + np.clip(1.0 - y / max_rul, 0.0, 1.0)


w           = w_failure(y_true)
residual_sq = (y_pred - y_true) ** 2
L_rul_w     = (w * residual_sq).mean()


# ---------- Outer axis: GABA closed form, EMA, floor, renorm ----------
def gaba_step(g_rul, g_health, prev, beta=0.99, eps=0.05):
    """One step of the four-stage GABA pipeline. Returns (normed, raw, smoothed, floored)."""
    # Stage 1: closed form (raw)
    S = g_rul + g_health
    raw = np.array([g_health / S, g_rul / S])
    # Stage 2: EMA smoothing
    smoothed = beta * np.array(prev) + (1.0 - beta) * raw
    # Stage 3: floor
    floored = np.maximum(smoothed, eps)
    # Stage 4: renormalise so weights sum to 1
    normed = floored / floored.sum()
    return normed, raw, smoothed, floored


# Real measurements from the same forward pass (chapter 18 §1)
g_rul, g_health = 26.4016, 0.037833
prev_lam        = (0.5, 0.5)
lam, raw, sm, fl = gaba_step(g_rul, g_health, prev_lam)


# ---------- Compose: outer * inner ----------
L_GRACE = lam[0] * L_rul_w + lam[1] * L_health


# ---------- Inspect every symbol ----------
print(f"y_true       = {y_true.astype(int).tolist()}")
print(f"y_pred       = {y_pred.astype(int).tolist()}")
print(f"residual^2   = {residual_sq.astype(int).tolist()}")
print(f"w(y)         = {np.round(w, 3).tolist()}")
print(f"w * r^2      = {np.round(w * residual_sq, 2).tolist()}")
print(f"L_rul^WMSE   = {L_rul_w:.4f}")
print()
print(f"g_rul        = {g_rul:.4f}    g_health = {g_health:.6f}")
print(f"raw lambdas  = ({raw[0]:.6f}, {raw[1]:.6f})")
print(f"EMA lambdas  = ({sm[0]:.6f}, {sm[1]:.6f})    (β=0.99, prev=0.5)")
print(f"floored      = ({fl[0]:.6f}, {fl[1]:.6f})    (ε=0.05)")
print(f"final λ*     = ({lam[0]:.6f}, {lam[1]:.6f})    (renormalised)")
print()
print(f"L_GRACE      = {lam[0]:.4f}*{L_rul_w:.4f} + {lam[1]:.4f}*{L_health:.4f}")
print(f"             = {L_GRACE:.4f}")
```
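The gaba_step docstring describes a single step; convergence needs a loop. Below is a pure-Python mirror of gaba_step (an illustration, not the production code) driven for 3000 steps, assuming the gradient norms stay fixed at the chapter-18 values and that the caller threads the unfloored `smoothed` vector back in as `prev`. It recovers both regimes discussed in the walkthrough: L_GRACE ≈ 41.94 after one step and ≈ 4.59 once the floor is active. Threading the renormalised output instead would feed the floor back into the EMA and settle λ*_rul closer to ε.

```python
def gaba_step(g_rul, g_health, prev, beta=0.99, eps=0.05):
    """Pure-Python mirror of the NumPy gaba_step (K = 2).
    Returns (normed, raw, smoothed, floored)."""
    s = g_rul + g_health
    raw = [g_health / s, g_rul / s]                                    # stage 1
    smoothed = [beta * p + (1.0 - beta) * r for p, r in zip(prev, raw)]  # stage 2
    floored = [max(x, eps) for x in smoothed]                          # stage 3
    total = sum(floored)
    normed = [f / total for f in floored]                              # stage 4
    return normed, raw, smoothed, floored


g_rul, g_health = 26.4016, 0.037833   # assumption: held fixed for illustration
L_rul_w, L_health = 84.115, 0.6069     # inner-axis values from the walkthrough

state = (0.5, 0.5)
history = {}
for step in range(1, 3001):
    lam, _, smoothed, _ = gaba_step(g_rul, g_health, state)
    state = smoothed  # thread the unfloored EMA state, not the renormalised output
    if step in (1, 3000):
        history[step] = lam[0] * L_rul_w + lam[1] * L_health

print({k: round(v, 2) for k, v in history.items()})  # {1: 41.94, 3000: 4.59}
```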

PyTorch: The Same Equation In Production Code

Now the equation as the paper actually runs it. GABALoss is an nn.Module — a stateful controller that keeps the EMA buffer between steps and exposes inspection helpers. The whole composition is a single call on line 49: L_GRACE = gaba(L_rul_w, L_health, shared_params=shared).

The same equation in production PyTorch
grace_loss_equation_pytorch.py
Line 1: docstring

Same algebra as the NumPy demo, but every value lives in a torch.Tensor that autograd can backprop through. Imports paper code verbatim — no replication.

3import torch

Core PyTorch. Tensors, autograd, the device abstraction.

EXECUTION STATE
📚 torch = Provides torch.Tensor (the differentiable equivalent of np.ndarray), torch.manual_seed, torch.randn, autograd graph machinery, and the device abstraction (CPU/CUDA/MPS).
→ key difference vs NumPy = A tensor produced by operations on tensors with requires_grad=True records how it was created (its grad_fn, a node in the autograd graph). Calling .backward() on a scalar walks that graph in reverse and fills in .grad on every leaf tensor that has requires_grad=True.
4import torch.nn as nn

Module/parameter machinery. nn.Module, nn.Linear here.

EXECUTION STATE
📚 torch.nn = Stateful neural-net building blocks. nn.Module (the base class), nn.Linear, nn.Conv*, nn.LSTM, etc. Each Module owns nn.Parameter tensors that autograd tracks.
→ why aliased as `nn`? = Two-letter alias is the universal PyTorch convention. Lets us write nn.Linear(4, 6) instead of torch.nn.Linear(4, 6).
5import torch.nn.functional as F

Stateless functional API. F.cross_entropy here.

EXECUTION STATE
📚 torch.nn.functional = Pure functions: same operations as nn.* but no parameters of their own. F.cross_entropy, F.mse_loss, F.relu, F.softmax. Used when there is nothing to register as a Module.
📚 F.cross_entropy(logits, target) = Combined log-softmax + NLL. Inputs: logits (B, C), target int64 (B,).
→ why F not nn? = L_health is computed once per forward — there are no learnable parameters in cross-entropy itself. The functional form avoids creating a useless empty Module.
7from grace.core.gaba import GABALoss

Paper's GABA controller. nn.Module subclass. forward() takes the two task losses + shared params, runs the four-stage pipeline, and returns the combined loss.

EXECUTION STATE
📚 GABALoss(beta, warmup_steps, min_weight, n_tasks) = From grace/core/gaba.py:30. Stores ema_weights and step_count as buffers. Implements the closed form + EMA + floor + renorm.
8from grace.core.weighted_mse import moderate_weighted_mse_loss

Paper's failure-biased MSE. Pure function, no state.

EXECUTION STATE
📚 moderate_weighted_mse_loss(pred, target, max_rul) = From grace/core/weighted_mse.py:20. Body: w = 1.0 + clip(1.0 - target/max_rul, 0, 1.0); return (w * (pred-target)**2).mean().
10torch.manual_seed(0)

Pin the PRNG so the model init and the random batch are reproducible.

EXECUTION STATE
📚 torch.manual_seed(seed) = Seeds the random number generator for all devices (CPU and CUDA). Every subsequent torch.randn / torch.rand call produces the same sequence. Does NOT affect Python's random or NumPy's RNGs; set those separately if needed.
⬇ arg: 0 = The seed value. Any int works. Convention is 0 or 42 in pedagogical examples.
→ what it pins = (1) nn.Linear init (Kaiming-style draws from torch.empty().uniform_), (2) the torch.randn(8, 4) batch on line 31. Without this seed the printed numbers would change every run.
13# ---------- Tiny dual-task model ----------

Section divider. Below this comment we define a minimal nn.Module that mirrors the structure of grace/models/dual_task_model.py: one shared backbone, two task heads.

EXECUTION STATE
Why a toy model? = The point of this script is the LOSS equation, not the architecture. A 4→6→{1,3} two-head MLP is the smallest model that still has a real shared backbone — exactly what GABA needs to balance.
14class TinyDualHead(nn.Module):

Toy multi-task model. One shared backbone + RUL head + health head. Mirrors grace/models/dual_task_model.py at minimal scale.

EXECUTION STATE
📚 class X(nn.Module) = Standard PyTorch model recipe. Subclass nn.Module → register sub-modules in __init__ → define forward(). PyTorch then auto-tracks parameters, supports .to(device), .state_dict(), DDP wrapping, etc.
Architecture = Inputs (B, 4) → Linear(4, 6) → ReLU → split into RUL Linear(6, 1) and Health Linear(6, 3).
→ what makes it dual-task? = TWO heads consume the SAME backbone activation `feat`. Both task losses backprop through the same backbone parameters — that shared coupling is precisely what creates the gradient-conflict problem GABA was built to solve.
15def __init__(self):

Constructor. Allocates the three sub-modules so PyTorch can register their parameters.

EXECUTION STATE
⬇ input: self = The TinyDualHead instance being constructed. nn.Module's __setattr__ hook intercepts assignments like `self.backbone = nn.Linear(...)` and registers the sub-module.
→ why no other args? = All sizes are hard-coded for the demo. A real implementation would take (in_features, hidden, n_classes) arguments; here we keep it minimal.
⬆ returns = None — Python convention for __init__. The constructed object is returned implicitly by the class call `TinyDualHead()`.
16super().__init__()

Required nn.Module bookkeeping.

EXECUTION STATE
📚 super() = Returns a proxy object that delegates method calls to the parent class (nn.Module). Modern Python 3 form — no need for super(TinyDualHead, self).
→ what nn.Module.__init__ does = Initialises internal dicts: _parameters (learnable tensors), _buffers (non-learnable state), _modules (sub-Modules), the training flag, and hooks. Forgetting this line makes the first sub-module assignment below raise AttributeError ("cannot assign module before Module.__init__() call").
⬆ returns = None.
17self.backbone = nn.Linear(4, 6)

Shared trunk. 4 → 6 dims. Feeds BOTH heads — the technical reason the two losses share gradients on the same parameters.

EXECUTION STATE
📚 nn.Linear(in, out) = Stores W (out, in), b (out). Forward: y = x @ W.T + b.
18self.rul_head = nn.Linear(6, 1)

Regression head: 6 → 1.

EXECUTION STATE
→ name 'rul_head' = The substring 'head' is what get_shared_params() filters on. Renaming it to e.g. 'rul_branch' would silently include this head in the shared list and break GABA.
19self.health_head = nn.Linear(6, 3)

Classification head: 6 → 3.

EXECUTION STATE
📚 nn.Linear(6, 3) = Stores W (3, 6) + b (3). Forward: hp_logits = feat @ W.T + b. Output: 3 raw class scores per sample (NOT probabilities — softmax happens inside F.cross_entropy).
→ 3 classes are? = Healthy / Degrading / Failing — the standard piecewise-linear health-state classification used in chapter 19.
→ name 'health_head' = Substring 'head' is what get_shared_params() filters out. Both task heads share that suffix so neither leaks into the shared-params list.
21def forward(self, x):

Single forward returns (y_pred, hp_logits) — both produced from the SAME backbone activation.

EXECUTION STATE
⬇ input: self = TinyDualHead instance — gives access to self.backbone, self.rul_head, self.health_head.
⬇ input: x = Tensor (B, 4). The mini-batch — B samples, each with 4 input features.
→ x in this script = torch.randn(8, 4) — a (8, 4) tensor of standard-normal noise. Real C-MAPSS would use 24 features (sensors + operational settings) but the loss equation is dimension-agnostic.
⬆ returns = Tuple (y_pred (8,), hp_logits (8, 3)). y_pred is the RUL regression output; hp_logits is the 3-class classification output.
→ why one forward? = GABA needs BOTH losses computed against the SAME backbone parameters. A single shared forward guarantees that.
22feat = torch.relu(self.backbone(x))

Shared latent. Every gradient computation downstream traces back through this tensor.

EXECUTION STATE
📚 self.backbone(x) = Forward pass through nn.Linear(4, 6): x @ W.T + b. Output shape: (8, 6).
📚 torch.relu = Element-wise max(0, x). Standard backbone non-linearity. Differentiable everywhere except x=0 (subgradient = 0 there). Replaces sigmoid/tanh in modern nets to avoid vanishing gradients.
feat = Tensor (8, 6). Both heads consume this. The autograd graph below `feat` is what compute_task_grad_norm walks.
→ why a single shared `feat`? = Because BOTH heads use the same `feat`, computing dL_rul/dW_backbone and dL_health/dW_backbone uses the same forward graph. That is what makes the gradients on shared parameters comparable — and what makes GABA possible.
23return self.rul_head(feat).squeeze(-1), self.health_head(feat)

Apply both heads to feat. .squeeze(-1) collapses the trailing 1 in (8, 1) → (8,) so the RUL predictions line up element-wise with y_true (8,).

EXECUTION STATE
📚 self.rul_head(feat) = nn.Linear(6, 1) forward. Output shape: (8, 1) — a column vector of scalar RUL predictions.
📚 .squeeze(-1) = Removes the LAST dimension if it has size 1. (8, 1) → (8,). Without this, F.mse_loss(pred (8,1), target (8,)) would broadcast to (8, 8) and compute the wrong loss.
→ why squeeze last only? = We only want to drop the regression-head 1, not any other size-1 dim that might appear (e.g. batch=1 in inference). .squeeze() with no arg would drop ALL size-1 dims and could silently break batching.
📚 self.health_head(feat) = nn.Linear(6, 3) forward. Output shape: (8, 3) — three logits per sample. NOT squeezed because all three dims carry meaning.
⬆ returns: (y_pred (8,), hp_logits (8, 3)) = Tuple — Python lets us return multiple values without a struct. Unpacked as `y_pred, hp_logits = model(x)` on line 35.
25def get_shared_params(self):

Helper: backbone-only parameter list. GABA needs these for compute_task_grad_norm. Filtering by name substring is robust to wrappers (EMA, DDP) prefixing module paths.

EXECUTION STATE
⬇ input: self = Module instance — gives access to .named_parameters() which walks every (name, Parameter) pair recursively.
⬆ returns = List[nn.Parameter] — all parameters of `backbone` (its weight + bias). Excludes both heads. Length 2 for this model.
→ why a method, not a cached attribute? = Re-running this every call is cheap (a walk over six named parameters) and immune to model surgery: if a future user adds another shared layer, this returns it automatically.
26return [p for n, p in self.named_parameters() if 'head' not in n]

List comprehension. Walks every (name, param) pair and keeps those whose name doesn't contain 'head'.

EXECUTION STATE
📚 .named_parameters() = nn.Module method. Recursive yield of (str, nn.Parameter).
→ output (this model) = [backbone.weight (6, 4), backbone.bias (6,)]
29# ---------- One mini-batch forward pass ----------

Section divider. Below this comment we materialise the inputs (x, y_true, hp_target) and run the model once to produce predictions.

EXECUTION STATE
Why a mini-batch? = The task losses, and therefore the per-task gradient norms, are averages over the mini-batch. Eight samples are enough to mix failure-region and healthy-region targets, so the failure-bias weight visibly changes the loss.
30model = TinyDualHead()

Instantiate. Default (Kaiming-style) init — small-magnitude weights, typical training start.

EXECUTION STATE
📚 TinyDualHead() = Calls __init__ on a freshly allocated instance. Triggers the three nn.Linear constructions inside. PyTorch then auto-registers their parameters via the __setattr__ hook.
model parameters (after init) = backbone.weight (6, 4), backbone.bias (6,), rul_head.weight (1, 6), rul_head.bias (1,), health_head.weight (3, 6), health_head.bias (3,): 6 tensors, 58 learnable scalars in total (backbone 30, RUL head 7, health head 21).
→ init scheme = nn.Linear's default is Kaiming uniform on the weight (U(-√k, √k) where k = 1/in_features) and U(-√k, √k) on the bias. With manual_seed(0) above, the exact draws are reproducible.
31x = torch.randn(8, 4)

Random batch.

EXECUTION STATE
📚 torch.randn(*size) = Draws from N(0, 1) — standard normal. *size is a positional shape spec. Returns a fresh leaf tensor with requires_grad=False (inputs do not need gradients).
⬇ arg: 8 = Batch size. Number of samples in this mini-batch.
⬇ arg: 4 = Feature dimension. Matches backbone's in_features=4. Mismatched dims would raise on the first matmul inside model(x).
x = Tensor (8, 4) of standard-normal numbers. Determined by the seed on line 10.
32y_true = torch.tensor([10., 30., 60., 90., 110., 5., 15., 80.])

Same RUL targets as the NumPy demo, so the per-sample weights w(y_j) match it exactly. y_pred comes from the untrained toy model, so the residuals and the loss value differ.

EXECUTION STATE
📚 torch.tensor([...]) = Builds a tensor from a Python list. Dtype is inferred — the trailing dots make these float32 floats (otherwise int64 would be picked).
y_true (8,) = [10., 30., 60., 90., 110., 5., 15., 80.] — same mix of failure-region and healthy-region samples as the NumPy demo on line 7.
→ why match the NumPy values? = Identical y_true means w(y_true) is identical to the inner-axis worked example. Only y_pred differs (random model init), and with it the residuals and the loss value.
33hp_target = torch.tensor([0, 1, 2, 0, 2, 0, 0, 2])

Health labels for the 3-class CE.

EXECUTION STATE
hp_target (8,) int64 = [0, 1, 2, 0, 2, 0, 0, 2]. Class indices for the 3 health states.
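📚 dtype int64 by default = Integer literals (no trailing dot) yield torch.int64. For class-index targets F.cross_entropy expects int64; a float target is instead interpreted as class probabilities with the same shape as the logits (B, C), so a float (8,) tensor here would raise a RuntimeError.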
→ labelling rule = Nominally 0 = healthy, 1 = degrading, 2 = failing. These toy labels are not derived from y_true (for example y = 5 and y = 90 are both labelled 0), so they only exercise the cross-entropy mechanics. Real C-MAPSS pipelines derive health states from RUL thresholds that differ per sub-dataset.
35y_pred, hp_logits = model(x)

ONE forward pass. The autograd graph extending from (backbone.W, backbone.b) → feat → both heads → both losses now exists in memory.

EXECUTION STATE
📚 model(x) = nn.Module's __call__ runs registered hooks then dispatches to .forward(x). Always use model(x), never model.forward(x), so hooks fire correctly.
y_pred (8,) = Predicted RUL — one float per sample. Random for an untrained model; replaced by real predictions after training.
hp_logits (8, 3) = Raw class scores — 3 floats per sample. Negative or positive, not normalised. F.cross_entropy below applies log-softmax internally.
→ autograd graph after this line = Both outputs carry .grad_fn pointers back through their head's Linear, the ReLU, the backbone Linear, all the way to the input x. Only the parameter leaves (backbone.W, backbone.b, head.W, head.b) have requires_grad=True — those are where gradients accumulate.
38# ---------- Inner axis: failure-biased MSE on RUL ----------

Section divider. The next two lines compute both task losses (L_rul_w, L_health) — these are the INPUTS to the OUTER axis (GABA controller) below.

EXECUTION STATE
What this section produces = Two 0-dim tensors L_rul_w and L_health, both with live autograd graphs back to the shared backbone parameters.
39L_rul_w = moderate_weighted_mse_loss(y_pred, y_true, max_rul=125.0)

Inner axis. Calls paper code; returns a 0-dim tensor with autograd hooks.

EXECUTION STATE
📚 moderate_weighted_mse_loss(pred, target, max_rul) = From grace/core/weighted_mse.py. Body: w = 1.0 + (1.0 - target/max_rul).clamp(0, 1); return (w * (pred - target).pow(2)).mean(). Pure tensor ops — autograd-safe.
⬇ arg: y_pred (8,) = Predictions from the forward pass. Has requires_grad=True via the head, so gradient flows back through here on .backward().
⬇ arg: y_true (8,) = Ground-truth RUL. Has requires_grad=False — backward does NOT flow through targets (correct — labels are not learnable).
⬇ arg: max_rul=125.0 = Same RUL cap as the NumPy demo. Keyword arg — explicit at the call site for readability.
L_rul_w = 0-dim tensor. Backward through this gives 2·w(y_j)·(y_pred_j - y_true_j)/N per output element.
40L_health = F.cross_entropy(hp_logits, hp_target)

Health cross-entropy. Same forward as L_rul_w — gradients on the shared backbone are coupled through `feat`.

EXECUTION STATE
📚 F.cross_entropy(input, target, reduction='mean') = Combined log_softmax + nll_loss. Inputs: input (B, C) raw logits, target (B,) int64 class indices. Returns scalar mean cross-entropy. Also accepts class weights, label smoothing, and ignore_index.
⬇ arg: hp_logits (8, 3) = Raw class scores from the health head. NOT softmax-normalised: passing softmaxed values would apply softmax twice and silently produce wrong gradients.
⬇ arg: hp_target (8,) = True class indices. Must be int64. Each value in [0, C-1] — F.cross_entropy will raise on out-of-range targets.
L_health = 0-dim tensor. Typical magnitude for a barely-trained 3-class softmax: ~ln(3) ≈ 1.10. Backward gives (softmax(hp_logits) - one_hot(hp_target)) / N.
→ why same forward? = Because L_rul_w and L_health both trace back through `feat`, the gradients on backbone parameters from the two tasks ARE comparable — a precondition for GABA to balance them.
43# ---------- Outer axis: GABA controller ----------

Section divider. The next four lines instantiate the GABA controller and run one full step (closed form → EMA → floor → renorm → combine).

EXECUTION STATE
What this section produces = L_GRACE = the composed loss. One scalar, ready for .backward(). The controller's internal state (ema_weights, step_count) is updated in-place as a side effect.
44gaba = GABALoss(beta=0.99, warmup_steps=0, min_weight=0.05, n_tasks=2)

Instantiate the OUTER controller. warmup_steps=0 disables the early-step uniform-weight phase so this single step does the full pipeline. min_weight=0.05 = paper default.

EXECUTION STATE
→ ema_weights buffer = Initialised to ones(2)/2 = (0.5, 0.5). Same starting point as the NumPy `prev_lam`.
→ step_count buffer = Initialised to 0. Incremented on every forward.
45shared = model.get_shared_params()

Backbone-only parameter list. Passed to GABA so it can compute g_i = ||dL_i/dtheta_shared||_2 internally.

EXECUTION STATE
shared (list) = [backbone.weight (6, 4), backbone.bias (6,)] — 2 nn.Parameter tensors. Both have requires_grad=True.
→ why pass to GABA? = GABA needs to compute ||dL_rul/dtheta_shared|| and ||dL_health/dtheta_shared|| separately. It calls torch.autograd.grad(L_i, shared, retain_graph=True) for each task, then norms the result.
→ why backbone-only? = Head parameters belong to a single task by construction — no balancing question to answer. Including them would dilute the gradient-imbalance signal.
47# One step of GABA: the controller computes per-task g_i internally,

First half of a two-line comment block describing what the next line does.

EXECUTION STATE
Comment role = Multi-line # comment directly above the call site. Common Python style: it keeps the why-this-call next to the call itself, so the reader does not have to scroll to the function definition.
48# applies EMA + floor + renorm, and returns the COMBINED loss.

Second half of the two-line explanation. Reminder that this single function call hides the entire four-stage pipeline.

EXECUTION STATE
What gaba(...) hides = Closed form (Stage 1) → EMA on the buffer (Stage 2) → clamp(min=0.05) (Stage 3) → divide by sum (Stage 4) → weighted-sum the two losses (Stage 5: combine).
49L_GRACE = gaba(L_rul_w, L_health, shared_params=shared)

ONE call. GABA's forward (gaba.py:64) computes both per-task gradient norms, applies the closed form, EMA-smooths, floors, renormalises, and returns the weighted sum λ*_rul · L_rul_w + λ*_health · L_health with the weights detached.

EXECUTION STATE
→ step 1 (closed form) = Inside gaba.forward_k: compute_task_grad_norm with retain_graph=True, twice. Then raw_weights = (S - g_i) / ((K-1) S).
→ step 2 (EMA) = ema_w = beta * ema_w + (1 - beta) * raw_weights. Then ema_weights buffer is updated in-place (.detach()).
→ step 3 (floor) = weights = ema_w.clamp(min=self.min_weight). Inactive on step 1 because both ema_w > 0.05.
→ step 4 (renorm) = weights = weights / weights.sum().
→ step 5 (combine) = Return Σ_i weights[i] * losses[i]. The weights enter as DETACHED scalars — autograd treats them as constants.
L_GRACE = 0-dim tensor. .backward() flows through ONLY the loss values, not the gradient-norm computation.
52# ---------- Inspect the controller state ----------

Section divider. Below this comment we pull the inspection helpers off the GABA module — the per-task gradient norms, the raw lambdas, and the final lambda*. Useful for logging and debugging.

EXECUTION STATE
Why inspect? = Real training logs the four-stage pipeline to TensorBoard so a human can see when the floor activates, when the EMA converges, and whether the gradient ratio is stable.
53weights = gaba.get_weights()

Inspection helper. Returns a dict {'rul_weight': float, 'health_weight': float} from the EMA buffer. Useful for logging; during training the controller has already applied these weights internally.

EXECUTION STATE
📚 gaba.get_weights() = Reads the ema_weights buffer, returns Python floats (not tensors) so the dict is JSON-serialisable for loggers like W&B/TensorBoard.
weights (dict) = {'rul_weight': 0.4950, 'health_weight': 0.5050} after one step. → ~(0.0477, 0.9523) after convergence with floor active.
54grad_stats = gaba.get_gradient_stats()

Returns g_rul, g_health, raw_weight_rul, raw_weight_health from the LAST forward call (cached on the module). Lets you log the four-stage pipeline without re-computing.

EXECUTION STATE
📚 gaba.get_gradient_stats() = Reads cached values stamped during the last gaba.forward() call. Free to call (no autograd, no extra forward).
grad_stats (dict) = {'grad_norm_rul': 26.4016, 'grad_norm_health': 0.0378, 'raw_weight_rul': 0.0014, 'raw_weight_health': 0.9986} (illustrative: these are the chapter 18 measurements used in the NumPy demo; the toy model here will report different values).
→ why cache? = compute_task_grad_norm is expensive (it calls torch.autograd.grad). Caching the per-step result means logging is free.
56print(f"L_rul^WMSE = {L_rul_w.item():.4f}")

Inner-axis value.

EXECUTION STATE
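📚 .item() = Tensor → Python float. Only works on tensors with exactly one element. The result is a plain Python number with no autograd history, so calling .item() on a graph-tracked tensor is safe.
Output (illustrative) = L_rul^WMSE ≈ 5000. At init y_pred is close to 0, so the loss is roughly (1/8)·Σ w(y_j)·y_j² ≈ 5043 for these targets; the exact value depends on the seed, and the ALGEBRA matches the NumPy demo.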
57print(f"L_health = {L_health.item():.4f}")

Health-task value.

EXECUTION STATE
Output (illustrative) = L_health = ~1.10 (≈ ln 3 for an untrained 3-class softmax with random logits)
58print(f"g_rul = {grad_stats['grad_norm_rul']:.4f}")

Outer-axis sensor — measured per-task gradient norm.

EXECUTION STATE
📚 dict[key] = Plain Python subscript on the dict returned by get_gradient_stats(). Different syntax from grad_stats.grad_norm_rul (attribute) — dict has bracket access only.
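Output (illustrative) = g_rul is large: at init the residuals are roughly the targets themselves (5 to 110 cycles), and the MSE gradient grows linearly with the residual. The exact value depends on the seed.
59print(f"g_health = {grad_stats['grad_norm_health']:.4f}")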

Same for health.

EXECUTION STATE
Output (illustrative) = g_health = ~0.5 (small — cross-entropy gradients are bounded by O(1))
→ ratio = g_rul / g_health ≫ 1: the structural imbalance GABA exists to fix.
60print(f"raw_λ_rul = {grad_stats['raw_weight_rul']:.6f}")

Stage 1 output — pre-EMA closed form.

EXECUTION STATE
Output (illustrative) = raw_λ_rul ≈ g_health / (g_rul + g_health), close to 0: RUL gets a vanishingly small raw weight because its gradient is already huge.
61print(f"final λ* = (...)")

Stage 4 output — what backward will multiply with each loss.

EXECUTION STATE
Output (illustrative) =
L_rul^WMSE = ~5000
L_health   = ~1.10
g_rul      = large (seed-dependent)
g_health   = ~0.5
raw_λ_rul  = close to 0
final λ*   = (0.4950, 0.5050)
L_GRACE    = ~2500
→ caveat = Numbers depend on the random init of model.backbone. The four-stage ALGEBRA is what reproduces; the magnitudes track the seed.
62print(f"L_GRACE = {L_GRACE.item():.4f}")

The whole equation as one scalar. Calling .backward() on it would fill .grad for every parameter in the model with GRACE's composed gradient.

EXECUTION STATE
📚 L_GRACE.item() = 0-dim tensor → Python float. Equivalent to float(L_GRACE); .item() is the more explicit, idiomatic spelling.
Output (illustrative) = L_GRACE ≈ 2500 (≈ 0.495 · L_rul_w + 0.505 · L_health on step 1)
→ after this line = L_GRACE.backward() would populate .grad on every parameter in `model`. The optimiser step would then update them with the composed GRACE direction. The whole training loop is just: zero_grad → forward → gaba(...) → backward → step.
1"""GRACE loss equation — production composition with the paper helpers."""
2
3import torch
4import torch.nn as nn
5import torch.nn.functional as F
6
7from grace.core.gaba           import GABALoss
8from grace.core.weighted_mse   import moderate_weighted_mse_loss
9
10torch.manual_seed(0)
11
12
13# ---------- Tiny dual-task model ----------
14class TinyDualHead(nn.Module):
15    def __init__(self):
16        super().__init__()
17        self.backbone    = nn.Linear(4, 6)
18        self.rul_head    = nn.Linear(6, 1)
19        self.health_head = nn.Linear(6, 3)
20
21    def forward(self, x):
22        feat = torch.relu(self.backbone(x))
23        return self.rul_head(feat).squeeze(-1), self.health_head(feat)
24
25    def get_shared_params(self):
26        return [p for n, p in self.named_parameters() if "head" not in n]
27
28
29# ---------- One mini-batch forward pass ----------
30model = TinyDualHead()
31x         = torch.randn(8, 4)
32y_true    = torch.tensor([10., 30., 60., 90., 110., 5., 15., 80.])
33hp_target = torch.tensor([0, 1, 2, 0, 2, 0, 0, 2])
34
35y_pred, hp_logits = model(x)
36
37
38# ---------- Inner axis: failure-biased MSE on RUL ----------
39L_rul_w  = moderate_weighted_mse_loss(y_pred, y_true, max_rul=125.0)
40L_health = F.cross_entropy(hp_logits, hp_target)
41
42
43# ---------- Outer axis: GABA controller ----------
44gaba = GABALoss(beta=0.99, warmup_steps=0, min_weight=0.05, n_tasks=2)
45shared = model.get_shared_params()
46
47# One step of GABA: the controller computes per-task g_i internally,
48# applies EMA + floor + renorm, and returns the COMBINED loss.
49L_GRACE = gaba(L_rul_w, L_health, shared_params=shared)
50
51
52# ---------- Inspect the controller state ----------
53weights      = gaba.get_weights()                        # final λ*
54grad_stats   = gaba.get_gradient_stats()                 # g_i + raw_λ_i
55
56print(f"L_rul^WMSE = {L_rul_w.item():.4f}")
57print(f"L_health   = {L_health.item():.4f}")
58print(f"g_rul      = {grad_stats['grad_norm_rul']:.4f}")
59print(f"g_health   = {grad_stats['grad_norm_health']:.4f}")
60print(f"raw_λ_rul  = {grad_stats['raw_weight_rul']:.6f}")
61print(f"final λ*   = ({weights['rul_weight']:.6f}, {weights['health_weight']:.6f})")
62print(f"L_GRACE    = {L_GRACE.item():.4f}")
Two non-obvious responsibilities of GABALoss.forward. First, it calls compute_task_grad_norm with retain_graph=True for every task, which means the SAME forward graph survives until the final .backward() call. Second, the EMA buffer is updated in-place on the module (line 131 of gaba.py: self.ema_weights[:K] = ema_w.detach()) so saving the controller's state is just saving its state_dict.
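The second responsibility has a practical consequence: because the EMA lives in the controller's state, a resumed run must restore it, or the weights restart from uniform. The framework-free sketch below illustrates this with a plain dict standing in for the module's state_dict. The `gaba_update` helper and the dict keys are hypothetical, and the gradient norms are held constant at the NumPy-demo values.

```python
import copy
from typing import Dict, List


def gaba_update(state: Dict, g_rul: float, g_health: float,
                beta: float = 0.99, eps: float = 0.05) -> List[float]:
    """Hypothetical two-task GABA step; mutates `state` in place, like the EMA buffer."""
    s = g_rul + g_health
    raw = (g_health / s, g_rul / s)                                # closed form
    state["ema"] = [beta * e + (1 - beta) * r for e, r in zip(state["ema"], raw)]
    state["step"] += 1
    floored = [max(e, eps) for e in state["ema"]]                  # floor
    total = sum(floored)
    return [f / total for f in floored]                            # renormalise


G_RUL, G_HEALTH = 26.4016, 0.037833

# Uninterrupted run: 300 steps, with a "checkpoint" taken after step 250.
state = {"ema": [0.5, 0.5], "step": 0}
for _ in range(300):
    lam_full = gaba_update(state, G_RUL, G_HEALTH)
    if state["step"] == 250:
        checkpoint = copy.deepcopy(state)   # analogous to saving the state_dict

# Resume from the checkpoint for the remaining 50 steps.
resumed = copy.deepcopy(checkpoint)
for _ in range(50):
    lam_resumed = gaba_update(resumed, G_RUL, G_HEALTH)

# Restart without restoring state: the EMA begins again at (0.5, 0.5).
fresh = {"ema": [0.5, 0.5], "step": 0}
for _ in range(50):
    lam_fresh = gaba_update(fresh, G_RUL, G_HEALTH)

print(lam_full, lam_resumed, lam_fresh)
```

The resumed run reproduces the uninterrupted weights exactly, while the fresh run is still near-uniform and would quietly hand the RUL task several times more weight than the converged controller.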

A Worked Numerical Example, End To End

The same 8-sample mini-batch as section 21·1, with every intermediate value at every stage. Inputs first, then the inner axis, then the four-stage outer pipeline, then the composed loss. Numbers come from the printed output of the NumPy script above.

Inner axis on this batch

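| $j$ | $y_j$ | $\hat{y}_j$ | $(\hat{y}_j-y_j)^2$ | $w(y_j)$ | $w(y_j)\,(\hat{y}_j-y_j)^2$ |
|---|---|---|---|---|---|
| 0 | 10 | 15 | 25 | 1.92 | 48.00 |
| 1 | 30 | 26 | 16 | 1.76 | 28.16 |
| 2 | 60 | 68 | 64 | 1.52 | 97.28 |
| 3 | 90 | 87 | 9 | 1.28 | 11.52 |
| 4 | 110 | 112 | 4 | 1.12 | 4.48 |
| 5 | 5 | -7 | 144 | 1.96 | 282.24 |
| 6 | 15 | 24 | 81 | 1.88 | 152.28 |
| 7 | 80 | 74 | 36 | 1.36 | 48.96 |
| | | | $\sum = 379$ | | $\sum = 672.92$ |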

Standard MSE would be $379/8 = 47.375$; weighted MSE is $672.92/8 = 84.115$. The ratio $84.115/47.375 \approx 1.78$ depends on the $y$-distribution of THIS batch: for a batch of all-healthy engines (every $y_j \geq 125$) the two losses would be identical. The inner axis bites only when the batch contains failure-region samples.
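The inner-axis table can be recomputed directly. A minimal pure-Python sketch, with y_pred read off the table and the ramp written exactly as $w(y) = 1 + \mathrm{clip}(1 - y/125, 0, 1)$:

```python
MAX_RUL = 125.0


def w(y: float, max_rul: float = MAX_RUL) -> float:
    """Failure-biased ramp: 2.0 at y = 0, falling linearly to 1.0 for y >= max_rul."""
    return 1.0 + min(max(1.0 - y / max_rul, 0.0), 1.0)


y_true = [10, 30, 60, 90, 110, 5, 15, 80]
y_pred = [15, 26, 68, 87, 112, -7, 24, 74]

sq = [(p - t) ** 2 for p, t in zip(y_pred, y_true)]
mse = sum(sq) / len(sq)
wmse = sum(w(t) * r for t, r in zip(y_true, sq)) / len(sq)
print(sum(sq), mse, round(wmse, 3), round(wmse / mse, 2))   # 379 47.375 84.115 1.78

# An all-healthy batch (every y >= 125) gets w = 1 everywhere, so WMSE equals MSE.
healthy = [130, 150, 125]
print([w(y) for y in healthy])                               # [1.0, 1.0, 1.0]
```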

Outer axis on this batch

| Stage | Operation | $\lambda_{\text{rul}}$ | $\lambda_{\text{health}}$ |
|---|---|---|---|
| 1. closed form | $\dfrac{g_h}{g_r + g_h},\ \dfrac{g_r}{g_r + g_h}$ | 0.001431 | 0.998569 |
| 2. EMA ($\beta = 0.99$, prev $= 0.5$, $n = 1$) | $0.99 \cdot 0.5 + 0.01 \cdot \text{raw}$ | 0.495014 | 0.504986 |
| 2′. EMA ($n = \infty$, converged) | $\to \text{raw}$ | 0.001431 | 0.998569 |
| 3. floor ($\varepsilon = 0.05$, applied to row 2′) | $\max(\text{EMA},\ \varepsilon)$ | 0.050000 | 0.998569 |
| 4. renorm (of row 3) | $\div \sum$ | 0.047684 | 0.952316 |

The two EMA rows show the controller's two regimes. Early in training ($n = 1$), $\lambda^*$ is essentially uniform: GABA is still ‘learning’ the gradient balance. After convergence ($n = \infty$), the floor activates, fixing $\lambda^*_{\text{rul}}$ at 0.0477 and $\lambda^*_{\text{health}}$ at 0.9523. Production GRACE training spends most of its 500 epochs in the converged regime; the warmup_steps parameter (default 100) explicitly disables adaptation until enough EMA history exists.
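How quickly does the controller move from the first regime to the second? Under the simplifying assumptions that the gradient norms stay fixed at the NumPy-demo values (real norms drift during training) and that warmup is disabled, the RUL EMA after $n$ steps is $\text{raw} + (0.5 - \text{raw})\,\beta^n$. A short loop finds the first step at which the floor engages and confirms the converged weights:

```python
G_RUL, G_HEALTH = 26.4016, 0.037833
BETA, EPS = 0.99, 0.05

raw = [G_HEALTH / (G_RUL + G_HEALTH), G_RUL / (G_RUL + G_HEALTH)]
ema = [0.5, 0.5]

first_floor_step = None
for n in range(1, 2001):
    ema = [BETA * e + (1 - BETA) * r for e, r in zip(ema, raw)]
    if first_floor_step is None and ema[0] < EPS:
        first_floor_step = n        # the RUL EMA has just dropped below the floor

floored = [max(e, EPS) for e in ema]
final = [f / sum(floored) for f in floored]
print(first_floor_step)                 # 232 with these constants
print([round(v, 6) for v in final])     # [0.047684, 0.952316] after 2000 steps
```

With $\beta = 0.99$ the EMA has a time constant of roughly 100 steps, so under these assumptions the floor engages after a couple of hundred updates.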

Composed GRACE loss

$$\mathcal{L}_{\text{GRACE}}(t) \;=\; 0.047684 \cdot 84.115 \;+\; 0.952316 \cdot 0.6069 \;=\; 4.011 + 0.578 \;=\; 4.589 \quad (t = \infty).$$

Compare this with the AMNL cell from section 21·1 (cell B): same inner WMSE, but fixed lambdas at $(0.5, 0.5)$ give $0.5 \cdot 84.115 + 0.5 \cdot 0.6069 = 42.36$. The OUTER axis shifts the optimisation target by an order of magnitude. That shift, mediated by the four-stage controller, is what produces the 224 NASA score on FD002 instead of 356.
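The comparison is plain arithmetic on the two task losses of this batch. A quick check, using the converged $\lambda^*$ from the outer-axis table:

```python
L_RUL_W, L_HEALTH = 84.115, 0.6069      # inner WMSE and health CE for this batch

lam_grace = (0.047684, 0.952316)        # converged, floored, renormalised lambda*
lam_amnl = (0.5, 0.5)                   # fixed weights of the AMNL cell

grace = lam_grace[0] * L_RUL_W + lam_grace[1] * L_HEALTH
amnl = lam_amnl[0] * L_RUL_W + lam_amnl[1] * L_HEALTH
print(round(grace, 3), round(amnl, 2), round(amnl / grace, 1))   # 4.589 42.36 9.2
```

The roughly ninefold gap is the "order of magnitude" quoted in the text.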

The Same Equation In Other Fields

The shape $\sum_i \lambda^*_i(t) \cdot \frac{1}{N}\sum_j w_j \ell_{ij}$ appears any time a learner has multiple competing objectives and non-uniform sample importance. A few canonical examples:

| Domain | Task index $i$ (outer) | Sample weight $w_j$ (inner) |
|---|---|---|
| Object detection (Faster R-CNN, YOLO) | Class loss, bbox regression, objectness | Up-weight rare classes via focal loss; up-weight hard examples via OHEM |
| Speech recognition (RNN-T) | CTC alignment, attention decoder | Up-weight rare words and disfluent segments |
| Robot grasping (multi-modal) | Force prediction, joint angles, success classification | Up-weight near-failure grasps where small errors flip success → failure |
| Drug-response modelling | Survival regression, cancer-subtype classification | Up-weight late-stage patients where prediction errors are most consequential |
| Climate downscaling | Temperature, precipitation, wind targets | Up-weight extreme events (heatwaves, hurricanes), which dominate impact |

In every row, replacing the outer factor with GABA and the inner weight with a domain-specific failure or rarity ramp produces a GRACE-shaped algorithm. The equation is universal; the choice of $w(\cdot)$ is the domain knowledge.
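The shared shape can be written once as a generic helper. The sketch below is a hypothetical illustration: `composed_loss` and `extreme_event_weight` are made-up names, and the threshold, boost, lambdas and per-sample losses are arbitrary placeholders. Each task supplies per-sample losses and per-sample weights; the caller supplies the outer lambdas, however they were computed.

```python
from typing import Sequence


def composed_loss(lambdas: Sequence[float],
                  per_sample_losses: Sequence[Sequence[float]],
                  sample_weights: Sequence[Sequence[float]]) -> float:
    """Hypothetical helper: sum_i lambda_i * (1/N_i) * sum_j w_ij * l_ij."""
    total = 0.0
    for lam, losses, weights in zip(lambdas, per_sample_losses, sample_weights):
        total += lam * sum(w * l for w, l in zip(weights, losses)) / len(losses)
    return total


def extreme_event_weight(value: float, threshold: float = 35.0, boost: float = 2.0) -> float:
    """Hypothetical rarity weight, e.g. for temperature targets in climate downscaling."""
    return boost if value >= threshold else 1.0


# Task 0: regression with rarity weights. Task 1: uniform weights.
temps = [20.0, 36.0, 41.0, 18.0]
sq_err = [1.0, 4.0, 9.0, 1.0]
task1_losses = [0.7, 0.5, 0.9, 0.3]

loss = composed_loss(
    lambdas=[0.3, 0.7],
    per_sample_losses=[sq_err, task1_losses],
    sample_weights=[[extreme_event_weight(t) for t in temps], [1.0] * len(task1_losses)],
)
print(round(loss, 4))   # 0.3 * 28/4 + 0.7 * 2.4/4 = 2.52
```

GRACE is the special case where task 0 uses the failure ramp $w(y_j)$ on squared RUL residuals, the health task keeps uniform weights, and GABA supplies the lambdas.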

Pitfalls When Reading The Equation

Pitfall 1: confusing the four lambdas

Four different lambda symbols appear in this section: $\lambda^{\text{raw}}_i$ (closed form), $\bar\lambda_i$ (EMA), $\tilde\lambda_i$ (floored), $\lambda^*_i$ (renormalised). Only $\lambda^*_i$ multiplies the loss; the other three are intermediate. Logging the wrong one (especially $\bar\lambda_i$ instead of $\lambda^*_i$) is the most common source of confusion when debugging GABA-style controllers: the controller's actual decision is post-floor, post-renorm.
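One way to avoid logging the wrong symbol is to have the controller step return every stage under an explicit name. A minimal sketch (the function `gaba_stages` and its dict keys are hypothetical), evaluated in the converged regime where the EMA already sits at the raw weights:

```python
from typing import Dict, List, Sequence


def gaba_stages(grad_norms: Sequence[float], ema_prev: Sequence[float],
                beta: float = 0.99, eps: float = 0.05) -> Dict[str, List[float]]:
    """Hypothetical: return lambda^raw, lambda-bar, lambda-tilde and lambda* for one step."""
    k, s = len(grad_norms), sum(grad_norms)
    raw = [(s - g) / ((k - 1) * s) for g in grad_norms]                # lambda^raw
    ema = [beta * e + (1 - beta) * r for e, r in zip(ema_prev, raw)]   # lambda-bar
    floored = [max(e, eps) for e in ema]                               # lambda-tilde
    final = [f / sum(floored) for f in floored]                        # lambda*: the one to log
    return {"raw": raw, "ema": ema, "floored": floored, "final": final}


stages = gaba_stages([26.4016, 0.037833], ema_prev=[0.001431, 0.998569])
for name, values in stages.items():
    print(f"{name:8s} rul={values[0]:.6f}  health={values[1]:.6f}")
```

In this regime the EMA value for RUL is about 0.0014 while the weight actually applied is about 0.0477, a 33-fold difference that is easy to misread on a dashboard.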

Pitfall 2: thinking $w(y_j)$ needs a gradient

$w(y_j)$ depends on $y_j$, the ground truth, not on $\hat y_j$. No backward pass flows through the weight. Treating the weights as needing a gradient (e.g. by making them learnable) is a different algorithm, closer to attention-weighted MSE, and produces qualitatively different training dynamics.
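A quick numerical check makes the point: perturbing $\hat y_j$ leaves $w(y_j)$ untouched, so the derivative of the weighted MSE is simply $2\,w(y_j)(\hat y_j - y_j)/N$, with the weight acting as a constant. A pure-Python central-difference sketch on the worked-example batch:

```python
def w(y: float, max_rul: float = 125.0) -> float:
    return 1.0 + min(max(1.0 - y / max_rul, 0.0), 1.0)


def wmse(pred, target):
    # The weight depends on the target only; pred enters through the residual.
    return sum(w(t) * (p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


y_true = [10.0, 30.0, 60.0, 90.0, 110.0, 5.0, 15.0, 80.0]
y_pred = [15.0, 26.0, 68.0, 87.0, 112.0, -7.0, 24.0, 74.0]
N, h = len(y_true), 1e-6

# Analytic gradient with w(y_j) treated as a constant.
analytic = [2.0 * w(t) * (p - t) / N for p, t in zip(y_pred, y_true)]

# Central differences with respect to each prediction.
numeric = []
for j in range(N):
    up, down = list(y_pred), list(y_pred)
    up[j] += h
    down[j] -= h
    numeric.append((wmse(up, y_true) - wmse(down, y_true)) / (2 * h))

print(max(abs(a - b) for a, b in zip(analytic, numeric)))   # tiny: the two agree
```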

Pitfall 3: forgetting the (1/N) factor

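The mean over the mini-batch keeps the inner term from growing with batch size. If one task's loss switches to reduction='sum' while the other keeps reduction='mean', that task's gradient norm grows linearly with the batch size and the closed-form lambdas drift accordingly. (If both switch, the ratio $g_{\text{rul}}/g_{\text{health}}$, and hence the lambdas, is unchanged, but the combined gradient still grows with $N$.) The paper consistently uses mean reduction; check before changing the loss aggregation.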
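The scale argument can be checked with the two-task closed form alone. It assumes a sum-reduced loss is exactly $N$ times its mean-reduced counterpart on the same batch, so its gradient norm is $N$ times larger. The gradient norms are the NumPy-demo values; the batch sizes are arbitrary.

```python
def raw_lambdas(g_rul: float, g_health: float):
    """Two-task closed form: (lambda_rul, lambda_health) = (g_h / S, g_r / S)."""
    s = g_rul + g_health
    return g_health / s, g_rul / s


G_RUL, G_HEALTH = 26.4016, 0.037833   # norms of the mean-reduced losses
base = raw_lambdas(G_RUL, G_HEALTH)

for n in (8, 64, 256):
    only_rul_sum = raw_lambdas(n * G_RUL, G_HEALTH)    # only RUL uses reduction='sum'
    both_sum = raw_lambdas(n * G_RUL, n * G_HEALTH)    # both use reduction='sum'
    print(n, f"{only_rul_sum[0]:.2e}", f"{both_sum[0]:.2e}", f"{base[0]:.2e}")
```

The mismatched case shrinks $\lambda^{\text{raw}}_{\text{rul}}$ further with every increase in batch size, and the EMA then tracks those shifted values; the matched case reproduces the mean-reduced lambdas.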

Pitfall 4: assuming the equation is differentiable in all symbols

It is differentiable in $\hat y_j$ and the health logits. It is not differentiated with respect to $y_j$ (the targets are data, and the clipped ramp $w$ has a kink at $y_j = 125$), $\beta$, or $\varepsilon$. By design these are fixed inputs and controller hyperparameters, set per experiment and never learned in the same loop as the model. Hyperparameter search happens in a separate outer loop (chapter 22 §2 covers the GRACE search protocol).

Takeaway

  • GRACE is one equation: $\mathcal{L}_{\text{GRACE}}(t) = \lambda^*_{\text{rul}}(t) \cdot \mathcal{L}^{\text{WMSE}}_{\text{rul}}(t) + \lambda^*_{\text{health}}(t) \cdot \mathcal{L}_{\text{CE}}$.
  • The inner factor is a per-sample failure ramp $w(y_j) = 1 + \mathrm{clip}(1 - y_j/125, 0, 1)$ multiplying the squared residual, averaged over the batch.
  • The outer factor is the output of a four-stage controller: closed form → EMA → floor → renormalise. Only the final $\lambda^*_i(t)$ multiplies the loss.
  • The lambdas are detached before the multiply, so autograd treats them as constants. Forgetting this turns GABA into a different algorithm (GradNorm-style meta-gradients) with a substantially larger memory footprint.
  • On the 8-sample worked example: converged $\lambda^* = (0.0477, 0.9523)$, inner WMSE = 84.12, composed $\mathcal{L}_{\text{GRACE}} \approx 4.59$, an order of magnitude below the AMNL cell on the same batch.