Chapter 20

GABA + Standard MSE Training Pipeline

Training GABA & Results

Hook: Two Steering Wheels, One Driver

Imagine an autonomous car with two controllers riding the same chassis: a lane-keeping module that nudges steering, and a cruise-control module that nudges throttle. Each module computes its own correction and pushes it onto the same actuators. If one module’s corrections are a thousand times larger than the other’s, the small one gets drowned out — the car holds its lane perfectly but fails to brake. The fix is not to make either controller smarter; it is to put a balancer between them that listens to how much each is pushing and re-weights their outputs so neither dominates.

That is exactly what this section’s pipeline does for the dual-task RUL model. Two heads — RUL regression and health classification — share one backbone. Their gradients differ by ~1000× in magnitude on C-MAPSS FD002. GABA is the balancer. Standard MSE is the deliberately plain RUL loss we keep so that every gain we report is attributable to the balancer alone, not to a fancy loss surface. The whole training pipeline is < 100 lines of Python; this section dissects it stage by stage.

What you will be able to do after this section: read the paper’s GABATrainer class top to bottom, point to the eight stages of the inner loop, name what each stage produces, and predict the failure mode if any one stage is omitted (including the actual bug that produced 6 corrupted experiments in the paper’s norm-ablation study).

What "GABA + Standard MSE" Means Mathematically

For a batch of $B$ samples, the per-step optimisation target is the single scalar

$$L = \lambda^*_{rul} \cdot L_{rul} + \lambda^*_{h} \cdot L_{h}$$

where the two task losses are $L_{rul} = \tfrac{1}{B}\sum_{i=1}^{B} (\hat{y}_i - y_i)^2$ (standard mean-squared error, with no per-sample weighting) and $L_{h} = -\tfrac{1}{B}\sum_{i=1}^{B} \log p_{i, y_i}$ (multi-class cross-entropy on three health states), and the weights $(\lambda^*_{rul},\, \lambda^*_{h})$ come from the GABA closed form on the per-task gradient norms $g_{rul},\, g_{h}$ measured on the shared backbone (NOT the heads):

$$g_k = \| \nabla_{\theta_{shared}} L_k \|_2,\quad \lambda_k = \frac{\sum_{j} g_j - g_k}{(K-1) \sum_{j} g_j}$$

followed by EMA smoothing with $\beta = 0.99$, a floor at $\lambda_{\min} = 0.05$, and renormalisation onto the simplex. For the first $100$ steps the gate forces uniform $\lambda = (0.5, 0.5)$ while batch statistics settle.
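For the two-task case used throughout this section, the closed form collapses to a simple swap of the two gradient norms. As a worked instance, the lines below plug in the gradient norms from the step-3 row of the NumPy demo later in this section ($g_{rul} = 280$, $g_h = 0.22$); these are demo values, not measurements.

```latex
\begin{aligned}
K = 2:\quad
\lambda_{rul} &= \frac{(g_{rul} + g_h) - g_{rul}}{(2-1)\,(g_{rul} + g_h)} = \frac{g_h}{g_{rul} + g_h},
&\quad
\lambda_{h} &= \frac{g_{rul}}{g_{rul} + g_h} \\[4pt]
g_{rul} = 280,\ g_h = 0.22:\quad
\lambda_{rul} &= \frac{0.22}{280.22} \approx 0.000785,
&\quad
\lambda_{h} &= \frac{280}{280.22} \approx 0.999215
\end{aligned}
```

The two raw weights sum to 1 by construction, and the task with the smaller gradient norm receives the larger weight. These are the raw values, before EMA smoothing, the floor, and renormalisation are applied.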

That is the entire mathematical content of GABA + Standard MSE. Everything else in the next 90 lines of code is plumbing: how to schedule learning rate, when to evaluate, how to keep an EMA shadow of the model, when to halve the lr, when to early-stop. The eight stages below show how this math becomes one concrete training step.

Anatomy of One Training Step

Inside GABATrainer._train_epoch each batch runs through eight stages in strict order. Stages 1–3 produce raw losses. Stage 4 is the only stage that distinguishes GABA from a fixed-weight baseline. Stages 5–8 are standard PyTorch boilerplate — but each has a specific role and a specific failure mode.

| # | Stage | Code (one line) | What it produces |
|---|---|---|---|
| 1 | Forward pass | `rul_pred, health_logits = model(seq)` | (B, 1) and (B, 3) tensors |
| 2 | RUL loss (standard MSE) | `rul_loss = MSELoss()(rul_pred, rul_tgt)` | scalar ≈ 75 (at init) |
| 3 | Health loss (CE) | `health_loss = CrossEntropy()(health_logits, health_tgt)` | scalar ≈ 0.37 |
| 4 | GABA combine | `loss = gaba_loss(L_rul, L_h, shared_params=...)` | scalar (with grad_fn) |
| 5 | Backward | `loss.backward()` | .grad on every parameter |
| 6 | Gradient clipping | `clip_grad_norm_(params, 1.0)` | clipped .grad in place |
| 7 | AdamW step | `optimizer.step()` | params updated in place |
| 8 | EMA shadow update | `ema.update(model)` | ema.shadow updated |

Think of stages 1–4 as the "measurement" phase (read the world, decide what to change) and stages 5–8 as the "commit" phase (write the change back to parameters). The boundary between phases is the single scalar that pops out of stage 4.
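Stages 6 and 8 are one-liners in PyTorch, so it can help to see the arithmetic they perform. The NumPy sketch below is an illustration only: the function names, the toy gradients, and the shadow decay of 0.999 are assumptions made for this example, not values taken from the paper's code. Global-norm clipping rescales all gradients by one shared factor when their joint L2 norm exceeds the threshold; the shadow update is an exponential moving average of the parameters.

```python
import numpy as np


def clip_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale every gradient by one shared factor so the joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


def ema_shadow_update(shadow, params, decay=0.999):
    """Move each shadow tensor a fraction (1 - decay) toward the live parameter."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]


# Toy gradients for two parameter tensors (hypothetical values).
grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]
clipped, norm_before = clip_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(float(np.sum(g ** 2)) for g in clipped))
print(f"norm before = {norm_before:.4f}, after = {norm_after:.4f}")

# Toy parameters and their shadow copy (hypothetical values, assumed decay).
params = [np.array([1.0, 2.0])]
shadow = [np.array([0.0, 0.0])]
shadow = ema_shadow_update(shadow, params, decay=0.999)
print("shadow after one update =", shadow[0])
```

Because the scale factor is shared across all tensors, clipping shrinks the update without changing its direction, so the balance between tasks that stage 4 set is carried through to the optimizer step.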

Interactive: One Step, Eight Stages

Click any stage below to see its inputs, outputs, math, and the exact line in fix_gaba_norm_ablation.py. Press play auto-cycle to watch the pipeline step through automatically. Drag the lower slider to scrub through a synthesised 500-step training trace and watch $(L_{rul},\, L_{h},\, \lambda_{rul},\, \lambda_{h})$ evolve. The amber band on the trace is the warmup period (uniform 0.5/0.5); the green band is the active adaptive period.


Notice three things while scrubbing: (a) during warmup $\lambda_{rul} = 0.5$ exactly, because no measurement is taken; (b) immediately after step 100 the EMA starts absorbing closed-form measurements at 1% per step, so the weights drift smoothly rather than jumping; (c) the weight never reaches the closed-form target $\approx 0.0008$: three active steps only move $\lambda_{rul}$ to about $0.485$, and under a sustained imbalance the floor at 0.05 followed by renormalisation stops the drift near $\lambda_{rul} \approx 0.048$.
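The size of the drift at any step can be worked out by hand. Assuming, for illustration only, that the closed-form measurement stays fixed at a value $r$ on every active step (real gradient norms fluctuate), the EMA recursion unrolls into a geometric approach toward $r$, where $n$ counts active steps since the end of warmup:

```latex
\begin{aligned}
e_n &= \beta\, e_{n-1} + (1-\beta)\, r, \qquad e_0 = 0.5 \\
\Rightarrow\quad e_n &= r + (0.5 - r)\,\beta^{\,n} \\[4pt]
\beta = 0.99,\ r = 0.0008:\quad
e_1 &\approx 0.4950,\quad e_3 \approx 0.4852,\quad e_{100} \approx 0.1835,\quad e_{400} \approx 0.0098
\end{aligned}
```

Under this assumption the smoothed value crosses the floor $\lambda_{\min} = 0.05$ after roughly 230 active steps; from there the clamp and renormalisation, rather than the EMA, determine the emitted weight.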

Python From Scratch: A Minimal Trainer

Before reading the paper’s PyTorch code, walk through the same eight-stage pipeline in pure NumPy with the autograd substituted by hand-supplied gradient norms. This makes the structure visible without any framework magic. Click any line below to see its execution state.

Minimal GABA + standard MSE training loop (NumPy)
🐍gaba_minimal_trainer.py
Lines 1–7: Module docstring

Sets the goal: a NumPy-only walkthrough of the trainer's inner loop so the structure is visible without PyTorch's autograd. We hand-feed the gradient norms (g_rul, g_h) because we cannot compute them without autograd; everything else — warmup gate, closed-form, EMA, floor, renormalise, weighted sum — is identical to the paper code at fix_gaba_norm_ablation.py:196-218.

EXECUTION STATE
what this file demonstrates = 1) the K=2 closed-form GABA combiner, 2) standard MSE (no sample weighting), 3) numerically-stable cross-entropy, 4) the per-step pipeline that ties them together.
Line 9: import numpy as np

NumPy is the only import. It provides the ndarray type and every math operation used in this file: np.array, np.maximum, np.exp, np.log, np.mean, np.arange.

EXECUTION STATE
📚 numpy = Numerical computing library for Python. Provides ndarray (N-dimensional array stored contiguously in memory), broadcasting rules, vectorised ufuncs, and a comprehensive math API. Every operation here runs as compiled C — no Python loop overhead, no autograd, no GPU.
as np = Aliases the numpy module to 'np' so calls become np.array(), np.exp(), etc. Universal Python convention since ~2010.
Line 12: class GABALossNumPy:

A minimal K=2 GABA combiner that mirrors the paper class line for line. We covered the full version in §18.5; this is the same logic with warmup_steps=2 so we exit warmup quickly in the demo. Plain Python class — does NOT subclass nn.Module because we have no learnable parameters and no autograd graph to register buffers into.

Line 13: Class docstring: "K=2 GABA combiner. Same logic as paper class, no autograd."

K=2 means we balance exactly two tasks (RUL regression + health classification). The paper class supports arbitrary K; here we hard-code K=2 for clarity. The 'no autograd' note means we skip the torch.autograd.grad call that the paper uses to compute g_rul / g_h — instead the test harness hand-supplies them so the rest of the pipeline can be visualised in plain NumPy.
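For reference, the general-K closed form from the equation earlier in this section can be written in a few lines of NumPy. This is a sketch for illustration: the function name `gaba_raw_weights` is hypothetical and the code is not the paper class. It shows that the K=2 expression `(tot - g) / tot` used in this demo is the special case where the `K - 1` factor equals 1.

```python
import numpy as np


def gaba_raw_weights(grad_norms, eps=1e-12):
    """Raw closed-form weights: lambda_k = (sum_j g_j - g_k) / ((K - 1) * sum_j g_j)."""
    g = np.asarray(grad_norms, dtype=float)
    k = g.size
    if k < 2:
        raise ValueError("need at least two tasks to balance")
    tot = g.sum() + eps  # eps guards against an all-zero gradient vector
    return (tot - g) / ((k - 1) * tot)


w2 = gaba_raw_weights([280.0, 0.22])        # two tasks, step-3 demo norms
w3 = gaba_raw_weights([280.0, 0.22, 5.0])   # three tasks, third norm is made up
print("K=2:", np.round(w2, 6), "sum =", round(float(w2.sum()), 6))
print("K=3:", np.round(w3, 6), "sum =", round(float(w3.sum()), 6))
```

For any K the raw weights sum to 1 (up to the epsilon), and the task with the largest gradient norm always receives the smallest weight.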

Line 15: def __init__(self, beta=0.99, warmup_steps=2, min_weight=0.05):

Constructor. Stores three hyperparameters and initialises two state buffers (self.ema, self.t). Default values match paper defaults except warmup_steps (paper uses 100; demo uses 2 to exit warmup within 3 steps).

EXECUTION STATE
⬇ input: self = Newly-allocated GABALossNumPy instance — empty until we attach attributes inside __init__.
⬇ input: beta = 0.99 = EMA decay coefficient. Each active step blends 99% old weights + 1% new measurement. Time constant 1/(1−β) = 100 steps. Big enough to absorb per-step noise, small enough to track real gradient-norm shifts as training progresses.
⬇ input: warmup_steps = 2 = Steps with uniform 0.5/0.5 weights before the adaptive logic kicks in. Paper uses 100; demo uses 2 so step 3 of the demo is the first 'active' step. Lets the network get past the noisy initial epoch where gradient norms haven't stabilised.
⬇ input: min_weight = 0.05 = Floor for the post-EMA clamp. Guarantees no task ever gets fully ignored — even if the closed form drives a task's weight toward 0, the floor re-injects 5% so its loss still backprops.
⬆ returns = None — Python __init__ always returns None and mutates self in place.
Line 16: self.beta = beta

Store the EMA coefficient as an instance attribute so step() can read it on every call. Plain attribute assignment — no validation, no copying.

EXECUTION STATE
self.beta (after) = 0.99
Line 17: self.warmup_steps = warmup_steps

Store the warmup gate threshold. Used inside step() at line 24 to decide whether to take the warmup branch or the adaptive branch.

EXECUTION STATE
self.warmup_steps (after) = 2
Line 18: self.min_weight = min_weight

Store the floor used by np.maximum() inside step() at line 31. A value of 0.05 means each task keeps roughly a 5% share of the combined weight after the clamp+renormalise step (at least 0.05 / 1.05 ≈ 4.8% for K=2, because the rescale happens after the floor).

EXECUTION STATE
self.min_weight (after) = 0.05
Line 19: self.ema = np.array([0.5, 0.5])

Persistent EMA-smoothed weights. Initialised at uniform 1/K so the very first active-step EMA update starts from the same place the warmup branch was producing — no jump in the loss curve when we exit warmup.

EXECUTION STATE
📚 np.array(list) = Constructor: builds a 1-D ndarray from a Python list. Different from np.asarray (which avoids the copy if input is already ndarray) and np.zeros/np.ones (which initialise from a fill value). Inferred dtype here is float64.
→ why uniform init? = Setting ema = [0.5, 0.5] means the first active step blends 99% of [0.5, 0.5] with 1% of the closed-form raw weights — a smooth handoff from warmup, not a discontinuity.
self.ema (init) = ndarray (2,) = [0.5, 0.5]
Line 20: self.t = 0

Plain int counter. Incremented at the top of every step() call. Used by the warmup gate at line 24. Starts at 0 so after the first increment t=1, which (with warmup_steps=2) is still warmup.

EXECUTION STATE
self.t (init) = 0
Line 22: def step(self, l_rul, l_h, g_rul, g_h):

One full GABA per-step update. Takes the two task losses + the two gradient norms and returns the combined scalar plus the weight vector used. This is the function the trainer calls once per batch in the inner loop (replaces the autograd-version GABALoss.forward() in the paper code).

EXECUTION STATE
⬇ input: self = Class instance. Provides self.beta (0.99), self.warmup_steps (2), self.min_weight (0.05), self.ema (the (2,) buffer), self.t (the step counter).
⬇ input: l_rul = Scalar Python float. RUL regression loss for this batch. In the demo this comes from mse() (defined on line 37, called on line 65).
→ l_rul example values = step 1 → 75.00, step 2 → 59.25, step 3 → 31.75 (RUL predictions improve each step)
⬇ input: l_h = Scalar Python float. Health-classification loss for this batch (cross-entropy on 3 classes). Same value across steps in the demo because the logits are kept identical.
→ l_h example value = 0.3676 across all three demo steps
⬇ input: g_rul = Scalar Python float. L2 norm of grad(L_rul) on the shared backbone parameters. Hand-supplied here; in the paper code computed by torch.autograd.grad(L_rul, shared_params, retain_graph=True) and then .norm() over each tensor.
⬇ input: g_h = Scalar Python float. L2 norm of grad(L_h) on the shared backbone. Same shared_params list as g_rul — only the loss differs.
→ realistic ratio = g_rul ≈ 250–280, g_h ≈ 0.20–0.22 → ratio ~1250×. This is the chapter's whole motivation: the regression gradient dominates the classification gradient by 3 orders of magnitude.
⬆ returns = Tuple (total: float, w: ndarray (2,)). 'total' is the scalar to backprop through; 'w' is the lambda* vector used (kept for logging / monitoring).
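The NumPy demo hand-supplies g_rul and g_h, so the measurement itself is never shown. The PyTorch sketch below illustrates one way such norms could be obtained with torch.autograd.grad on the shared parameters only. The toy layers, sizes, random data, and the helper name `shared_grad_norm` are all assumptions made for this example; it is not the paper's GABALoss code, and the printed ratio says nothing about the real model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy two-head model (hypothetical layer sizes, for illustration only).
backbone = nn.Linear(8, 16)
rul_head = nn.Linear(16, 1)
health_head = nn.Linear(16, 3)

# Random stand-in batch: 4 samples, RUL targets in [0, 125], 3 health classes.
x = torch.randn(4, 8)
rul_tgt = torch.rand(4, 1) * 125.0
health_tgt = torch.tensor([0, 1, 2, 0])

feats = torch.relu(backbone(x))
l_rul = nn.functional.mse_loss(rul_head(feats), rul_tgt)
l_h = nn.functional.cross_entropy(health_head(feats), health_tgt)

# Only the backbone is shared; the head parameters are deliberately excluded.
shared_params = list(backbone.parameters())


def shared_grad_norm(loss, params):
    """L2 norm of d(loss)/d(params), joined across all tensors in params."""
    # retain_graph=True keeps the graph alive for the second loss and for backward().
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    return torch.sqrt(sum((g ** 2).sum() for g in grads)).item()


g_rul = shared_grad_norm(l_rul, shared_params)
g_h = shared_grad_norm(l_h, shared_params)
print(f"g_rul = {g_rul:.4f}, g_h = {g_h:.4f}, ratio = {g_rul / g_h:.1f}x")
```

The per-tensor norms are combined as the square root of the summed squares, which equals the L2 norm of all shared gradients flattened into one vector.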
Line 23: self.t += 1

Increment FIRST so the gate at line 24 compares the post-increment value. After the first call self.t = 1; after the second call self.t = 2; etc. Equivalent to self.t = self.t + 1.

EXECUTION STATE
self.t (call 1, after) = 1
self.t (call 2, after) = 2
self.t (call 3, after) = 3
Line 24: if self.t <= self.warmup_steps:

Warmup gate. Inclusive ≤ so step 2 is still warmup (t=2 ≤ warmup_steps=2 → True), step 3 is the first active step (t=3 ≤ 2 → False).

EXECUTION STATE
<= vs < = Inclusive comparison. With warmup_steps=2 we want exactly 2 warmup calls, then the third call enters the adaptive branch. Using < instead of <= would only give us 1 warmup call.
evaluation table = call 1: 1 ≤ 2 → True (warmup) call 2: 2 ≤ 2 → True (warmup) call 3: 3 ≤ 2 → False (active)
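The off-by-one behaviour of the gate is easy to confirm. The helper below is a hypothetical stand-in that reproduces only the counter and the comparison from lines 23–24, then counts how many calls land in the warmup branch for each operator.

```python
import operator


def count_warmup_calls(compare, warmup_steps, n_calls):
    """Replay the increment-first counter and the gate; return the number of warmup calls."""
    t = 0
    warmup_calls = 0
    for _ in range(n_calls):
        t += 1  # increment first, exactly as step() does
        if compare(t, warmup_steps):
            warmup_calls += 1
    return warmup_calls


inclusive = count_warmup_calls(operator.le, warmup_steps=2, n_calls=3)
strict = count_warmup_calls(operator.lt, warmup_steps=2, n_calls=3)
print(f"t <= warmup_steps: {inclusive} warmup calls")  # 2
print(f"t <  warmup_steps: {strict} warmup call")      # 1
```

With the inclusive comparison, `warmup_steps` is exactly the number of uniform-weight calls, which keeps the hyperparameter's meaning literal.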
Line 25: w = np.array([0.5, 0.5])

Warmup branch. Uniform 0.5/0.5 — identical to a fixed-weight baseline. Crucially, the EMA buffer (self.ema) is NOT updated during warmup, so when we eventually exit warmup the EMA still holds its [0.5, 0.5] init value.

EXECUTION STATE
w (warmup) = [0.5, 0.5]
→ why fresh np.array each call? = Cheap allocation; avoids any chance of an outside caller mutating the warmup constant. Could equivalently use a class-level constant, but a per-call literal keeps the branch self-contained.
Line 26: else: (adaptive / active branch)

Entered when t > warmup_steps (step 3 of the demo). Runs the full GABA pipeline: pack gradient norms → closed form → EMA smoothing → floor → renormalise. All of lines 27–32 belong to this branch.

EXECUTION STATE
branch invariants = self.ema still equals its previous value (NOT updated during warmup). self.t is now ≥ warmup_steps + 1.
Line 27: g = np.array([g_rul, g_h])

Pack the two scalar gradient norms into a (2,) ndarray so the rest of the pipeline can use vectorised math (subtraction, division, element-wise max).

EXECUTION STATE
📚 np.array([g_rul, g_h]) = Build a 1-D ndarray of dtype float64 by default. Inferred from the two Python floats. Shape (2,).
g (step 3 demo) = [280.0, 0.22]
→ ratio = 280 / 0.22 ≈ 1273× — RUL gradient dominates the health gradient by 3 orders of magnitude. Without GABA, the health head would be effectively ignored.
Line 28: tot = g.sum() + 1e-12

Sum of the two gradient norms plus a tiny epsilon to avoid divide-by-zero on the (rare) all-zero step. Used as the denominator in the closed-form raw-weight calculation on the next line.

EXECUTION STATE
📚 ndarray.sum() = Reduction method on ndarray — collapses all elements to a single scalar (Python float here). For a (2,) array, sum() returns g[0] + g[1]. Equivalent to np.sum(g).
⬇ arg: g = [280.0, 0.22] from line 27
g.sum() (step 3) = 280.0 + 0.22 = 280.22
+ 1e-12 = Numerical guard. Without it, an all-zero gradient (stationary point) would divide by zero on the next line. 1e-12 is small enough to be invisible in normal training (loss of precision ~10⁻¹⁵) but bulletproof against the edge case.
tot (step 3) = 280.22 + 1e-12 ≈ 280.22
Line 29: raw = (tot - g) / tot

K=2 GABA closed form. NumPy broadcasts the scalar tot against the (2,) vector g to produce a (2,) result. Smaller g_i → larger numerator → larger weight (notice the SWAP — GABA up-weights the task with the SMALLER gradient).

EXECUTION STATE
📚 broadcasting = NumPy automatically lifts a scalar to match the shape of an ndarray. Here tot (scalar) − g ((2,)) yields a (2,) result with element-wise subtraction. Same for the / division.
tot − g (step 3) = 280.22 − [280.0, 0.22] = [0.22, 280.0]
→ notice the SWAP = g was [280.0, 0.22]; tot − g is [0.22, 280.0] — the small value is now in slot 0 and the large value in slot 1. This is what makes GABA up-weight the small-gradient task.
raw (step 3) = [0.22, 280.0] / 280.22 = [0.000785, 0.999215]
→ why up-weight the small task? = The dominant-gradient task already drives learning hard; the small-gradient task needs amplification to be heard. In the raw closed form the SMALL-gradient task receives ~99.9% of the weight, exactly the opposite of what an unweighted sum would favour. EMA smoothing and the floor then limit how much of that swing reaches the combined loss on any single step.
Line 30: self.ema = self.beta * self.ema + (1 - self.beta) * raw

Convex combination: 99% of the previous EMA + 1% of the new raw weights. Smooths per-step noise so a single noisy batch can't swing the weights by much. First active step: ema = 0.99·[0.5, 0.5] + 0.01·[0.000785, 0.999215].

EXECUTION STATE
📚 EMA formula = ema_new = β·ema_old + (1−β)·measurement_new. Equivalent to a low-pass filter with time constant 1/(1−β) ≈ 100 steps for β=0.99. Pure scalar math broadcast over the (2,) vector.
self.ema (before, step 3) = [0.5, 0.5] (untouched during warmup)
0.99 · ema_prev = 0.99 · [0.5, 0.5] = [0.495, 0.495]
0.01 · raw = 0.01 · [0.000785, 0.999215] = [0.00000785, 0.00999215]
self.ema (after, step 3) = [0.495008, 0.504992]
→ barely off uniform = After ONE active step the EMA has moved less than 0.005 off uniform, by design. Three active steps at this imbalance bring it to about [0.485, 0.515]; if the imbalance persists for hundreds of steps, the EMA keeps moving toward the closed-form target [0.0008, 0.9992].
Line 31: clamped = np.maximum(self.ema, self.min_weight)

Element-wise floor at 0.05. For freshly-out-of-warmup ema ≈ [0.495, 0.505], the floor is a no-op (both above 0.05). It becomes active only once the EMA decays one component below 0.05, which at a sustained ~1000× imbalance takes roughly 230 active steps.

EXECUTION STATE
📚 np.maximum(a, b) = Element-wise max. Returns an ndarray where result[i] = max(a[i], b[i]). When b is a scalar (here 0.05), it broadcasts. NOT the same as np.max which REDUCES an array to a single scalar.
⬇ arg 1: self.ema = [0.495008, 0.504992] from line 30
⬇ arg 2: self.min_weight = 0.05 (Python float, broadcast)
→ element-wise eval = max(0.495008, 0.05) = 0.495008 ✓ max(0.504992, 0.05) = 0.504992 ✓
clamped (step 3) = [0.495008, 0.504992] (no clamping triggered)
→ why a floor at all? = Insurance. If the closed form drives one weight near 0 for many steps in a row, the EMA will eventually drift below 0.05. Without the floor, that task could be effectively dropped from training. With the floor, every task keeps a weight of at least 0.05 / 1.05 ≈ 4.8% after renormalisation (for K=2).
Line 32: w = clamped / clamped.sum()

Renormalise to the simplex (sum = 1). After clamp the sum can drift above the original simplex sum (if a value was floored UP); this restores the simplex constraint so weights are a proper probability distribution.

EXECUTION STATE
📚 division by reduction = clamped is (2,); clamped.sum() is a scalar. NumPy broadcasts the scalar over the vector for element-wise division.
clamped.sum() (step 3) = 0.495008 + 0.504992 = 1.0 (no-op renormalisation here)
w (step 3, active) = [0.495008, 0.504992] / 1.0 = [0.495008, 0.504992] — sum = 1.0 ✓
→ when does this matter? = On long runs where the floor at line 31 is actually triggered. e.g., if clamped became [0.05, 0.97] (sum 1.02 > 1), w would be [0.0490, 0.9510] (sum 1.0).
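To see the floor and the renormalisation actually engage, the sketch below applies lines 31–32 to a hypothetical EMA state in which the RUL component has decayed to 0.004. The input vector is an assumed value chosen for illustration, not a number reported in the paper.

```python
import numpy as np

min_weight = 0.05
ema = np.array([0.004, 0.996])          # hypothetical late-training EMA state

clamped = np.maximum(ema, min_weight)   # floor lifts the first entry to 0.05
w = clamped / clamped.sum()             # rescale so the weights sum to 1

print("clamped =", clamped, "sum =", round(float(clamped.sum()), 4))
print("w       =", np.round(w, 4), "sum =", round(float(w.sum()), 4))
```

The renormalised RUL weight lands slightly below 0.05, because the floor is applied before the rescale. For K=2 the smallest weight a task can end up with is 0.05 / 1.05 ≈ 0.0476.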
Line 33: total = float(w[0] * l_rul + w[1] * l_h)

Weighted sum → single scalar. This is the value that PyTorch would call .backward() on. Conversion to Python float for clean printing and to make sure we hand back a built-in scalar (not a NumPy np.float64).

EXECUTION STATE
📚 float(numpy_scalar) = Converts the np.float64 produced by w[0] * l_rul + w[1] * l_h into a built-in Python float. Avoids 'total' carrying a NumPy scalar type that can print differently depending on the NumPy version.
step 1 (warmup) calculation = 0.5 · 75.00 + 0.5 · 0.3676 = 37.5000 + 0.1838 = 37.6838
step 2 (warmup) calculation = 0.5 · 59.25 + 0.5 · 0.3676 = 29.6250 + 0.1838 = 29.8088
step 3 (active) calculation = 0.495008 · 31.75 + 0.504992 · 0.3676 = 15.7165 + 0.1856 = 15.9021
Line 34: return total, w

Return both the scalar (for backward) and the weight vector (for logging / monitoring / TensorBoard). Python implicitly packs into a tuple.

EXECUTION STATE
⬆ return: total = Python float — the GABA-combined scalar to backprop through.
⬆ return: w = ndarray (2,) — the lambda* used this step. Logged so we can plot weight evolution alongside losses.
Line 37: def mse(pred, tgt):

Standard mean-squared error. The "standard" in "GABA + standard MSE": no per-sample weighting, no Huber, no NASA-aware kernel. Plain MSE — that is what makes the chapter's headline result attributable to GABA alone (not to a fancy loss surface).

EXECUTION STATE
⬇ input: pred = ndarray (B,). Predicted RUL values. In the live model these are clamped to [0, 125] cycles.
⬇ input: tgt = ndarray (B,). Ground-truth RUL values, also clamped to 125 (piecewise-linear convention).
⬆ returns = Python float — scalar mean-squared error.
Line 39: return float(np.mean((pred - tgt) ** 2))

Element-wise diff, square, then mean. Wrapped in float() for clean printing. Three vectorised NumPy operations replace the textbook Python loop.

EXECUTION STATE
📚 (pred - tgt) = Element-wise subtraction. Both ndarray (B,) → result ndarray (B,). No copy of pred or tgt.
📚 ** 2 = Element-wise square (Python's exponentiation operator, overridden by NumPy). Equivalent to np.square(...).
📚 np.mean(arr) = Reduction: sum(arr) / len(arr). For ndarray (B,), returns a scalar. axis defaults to None (full reduction).
── Step 1 trace ──
pred = [95., 60., 30., 110.]
tgt = [100., 65., 25., 125.]
pred − tgt = [−5., −5., 5., −15.]
(pred − tgt)² = [25., 25., 25., 225.]
MSE = (25 + 25 + 25 + 225) / 4 = 300 / 4 = 75.00
── Step 2 trace ──
pred = [96., 61., 31., 112.]
(pred − tgt)² = [16., 16., 36., 169.]
MSE = (16 + 16 + 36 + 169) / 4 = 237 / 4 = 59.25
── Step 3 trace ──
pred = [97., 62., 28., 115.]
(pred − tgt)² = [9., 9., 9., 100.]
MSE = (9 + 9 + 9 + 100) / 4 = 127 / 4 = 31.75
Line 42: def cross_entropy(logits, tgt):

Numerically-stable multi-class cross-entropy. Used for the 3-class health task (Normal / Early / Critical). Equivalent to torch.nn.functional.cross_entropy without the autograd graph.

EXECUTION STATE
⬇ input: logits = ndarray (B, C). Pre-softmax scores from the health head. C = 3 classes here. Values can be any real number — the function applies softmax internally.
→ demo logits (4×3) =
[[ 2.0, -1.0,  0.5],
 [ 0.5,  1.5, -0.5],
 [-0.5,  0.5,  2.0],
 [ 1.0,  0.5, -1.0]]
⬇ input: tgt = ndarray (B,) of int class indices in {0, 1, 2}. Derived from RUL via 30-cycle / 80-cycle thresholds in the live data pipeline.
→ demo tgt = [0, 1, 2, 0] (sample 0 is class 0, sample 1 is class 1, ...)
⬆ returns = Python float — mean negative log-likelihood across the batch.
Line 44: z = logits - logits.max(axis=1, keepdims=True)

Subtract per-row max for numerical stability before exp. Mathematically a no-op (softmax is shift-invariant: softmax(x) = softmax(x − c)); numerically prevents overflow when logits are large.

EXECUTION STATE
📚 ndarray.max(axis, keepdims) = Reduction along an axis. Returns the max value(s). With keepdims=True the reduced axis stays as size-1 instead of disappearing — this matters for broadcasting.
⬇ arg: axis=1 = Reduce along the class dimension. For each row (sample), find the max across the 3 classes. axis=0 would reduce across samples; axis=None would reduce to a single scalar.
⬇ arg: keepdims=True = Result shape (B, 1) instead of (B,). Without it, broadcasting logits(B,C) − rowmax(B,) would fail (dimensions don't align). With it, broadcast lifts (B,1) to (B,C) by repeating the column.
logits.max(axis=1, keepdims=True) =
[[2.0],
 [1.5],
 [2.0],
 [1.0]]  shape (4, 1)
z (after subtraction, demo) =
[[ 0.0, -3.0, -1.5],
 [-1.0,  0.0, -2.0],
 [-2.5, -1.5,  0.0],
 [ 0.0, -0.5, -2.0]]
→ invariant = Every row of z now has max = 0. exp(z) is therefore in (0, 1] — never overflows.
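The overflow that the shift prevents can be reproduced directly. The sketch below compares a naive log-softmax with the shifted version on deliberately large logits; the values 1000 and 999 are made up to force the failure.

```python
import numpy as np

logits = np.array([[1000.0, 999.0, 0.0]])  # made-up extreme logits

# Naive version: exp(1000) overflows float64 to inf, so the log-sum-exp is inf
# and every log-probability collapses to -inf.
with np.errstate(over="ignore", invalid="ignore"):
    naive = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

# Shifted version: the row max becomes 0, so every exp() argument is <= 0.
z = logits - logits.max(axis=1, keepdims=True)
stable = z - np.log(np.exp(z).sum(axis=1, keepdims=True))

print("naive  :", naive)
print("stable :", np.round(stable, 4))
```

The shifted result stays finite for all three classes, including the one whose exp() underflows to zero.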
Line 45: log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))

Log-softmax via the log-sum-exp identity: log p_i = z_i − log Σ_j exp(z_j). One numerically-stable line replacing softmax-then-log.

EXECUTION STATE
📚 np.exp(z) = Element-wise exponential. exp(0)=1, exp(−1)≈0.368, exp(−3)≈0.050. Returns ndarray with same shape as z.
📚 np.log(x) = Element-wise natural log (base e). log(1)=0, log(e)=1. Used here on the row sums.
📚 .sum(axis=1, keepdims=True) = Same axis/keepdims pattern as line 44 — sum across classes, keep result shape (B, 1) for broadcasting against z (B, C).
── Row 0: logits=[2, −1, 0.5], target=0 ──
z[0] = [0.0, −3.0, −1.5]
exp(z[0]) = [1.0000, 0.0498, 0.2231]
Σ exp(z[0]) = 1.2729
log Σ = 0.2413
log_p[0] = [−0.2413, −3.2413, −1.7413]
── Row 1: logits=[0.5, 1.5, −0.5], target=1 ──
log_p[1] = [−1.4076, −0.4076, −2.4076]
── Row 2: logits=[−0.5, 0.5, 2.0], target=2 ──
log_p[2] = [−2.7664, −1.7664, −0.2664]
── Row 3: logits=[1.0, 0.5, −1.0], target=0 ──
log_p[3] = [−0.5550, −1.0550, −2.5550]
Line 46: return float(-np.mean(log_p[np.arange(len(tgt)), tgt]))

Fancy-index pulls log_p[i, tgt[i]] for each i — the log-probability of the TRUE class for each sample. Negate to get NLL, mean to reduce, float() to coerce to Python scalar.

EXECUTION STATE
📚 np.arange(n) = Build a 1-D ndarray [0, 1, ..., n-1]. Used here as the row-index column for fancy indexing.
📚 fancy indexing = When two arrays of equal length are passed as indices, NumPy pairs them element-wise. log_p[[0,1,2,3], [0,1,2,0]] picks log_p[0,0], log_p[1,1], log_p[2,2], log_p[3,0]. Returns a 1-D array of length 4.
📚 -np.mean(...) = Sum the four log-probabilities, divide by 4 (batch size), negate. Standard NLL definition.
len(tgt) = 4 — batch size
np.arange(len(tgt)) = [0, 1, 2, 3]
tgt = [0, 1, 2, 0]
values pulled (log_p[i, tgt[i]]) = [−0.2413, −0.4076, −0.2664, −0.5550]
sum = −1.4702
mean = −1.4702 / 4 = −0.3676
⬆ return: CE = −(−0.3676) = 0.3676
Line 49: # ---------- Five fake training steps ----------

Visual section break. The next ~20 lines instantiate the GABA combiner and build a hand-crafted trace of three batches, then run the per-step pipeline. (The comment says 'five' historically; the current trace has 3 — enough to show one warmup pair plus one active step.)

Line 50: gaba = GABALossNumPy(beta=0.99, warmup_steps=2, min_weight=0.05)

Instantiate the combiner. Calls __init__ on line 15 with all-default values shown explicitly. After this line: gaba.ema = [0.5, 0.5], gaba.t = 0.

EXECUTION STATE
after construction = gaba.beta = 0.99, gaba.warmup_steps = 2, gaba.min_weight = 0.05, gaba.ema = [0.5, 0.5], gaba.t = 0
Lines 51–62: trace = [...]

Hand-built sequence of 3 batches with mildly improving predictions and a 1250×→1273× gradient imbalance — the realistic scale of g_rul vs g_health on C-MAPSS FD002 (paper Table 4). Each tuple has six fields.

EXECUTION STATE
tuple format = (rul_pred, rul_tgt, health_logits, health_tgt, g_rul, g_h)
step 1 row = rul_pred=[95, 60, 30, 110]; same logits; g=(250.0, 0.20) — ratio 1250×
step 2 row = rul_pred=[96, 61, 31, 112]; predictions nudge closer; g=(260.0, 0.21)
step 3 row = rul_pred=[97, 62, 28, 115]; closer still; g=(280.0, 0.22) — first ACTIVE step (post-warmup)
→ why same logits across steps? = We are isolating the GABA pipeline. Keeping logits constant means health_loss is constant across steps (0.3676), so any movement in 'total' comes from rul_loss or from the GABA weights — not from the health head.
Line 64: for step, (rp, rt, hl, ht, g_r, g_h) in enumerate(trace, start=1):

Iterate the trace with a 1-indexed counter. enumerate(iter, start=1) yields (1, item₀), (2, item₁), (3, item₂). Tuple-unpacking on the inner tuple gives us six named values per iteration.

LOOP TRACE · 3 iterations
step 1 (warmup, t→1)
rul_loss = MSE([95,60,30,110], [100,65,25,125]) = 75.00
health_loss = CE(logits, [0,1,2,0]) = 0.3676
branch = warmup (t=1 ≤ 2)
w = [0.5, 0.5]
self.ema (unchanged) = [0.5, 0.5]
total = 0.5 · 75.00 + 0.5 · 0.3676 = 37.6838
step 2 (warmup, t→2)
rul_loss = MSE([96,61,31,112], [100,65,25,125]) = 59.25
health_loss = 0.3676
branch = warmup (t=2 ≤ 2)
w = [0.5, 0.5]
self.ema (still unchanged) = [0.5, 0.5]
total = 0.5 · 59.25 + 0.5 · 0.3676 = 29.8088
step 3 (FIRST active, t→3)
rul_loss = MSE([97,62,28,115], [100,65,25,125]) = 31.75
health_loss = 0.3676
branch = active (t=3 > 2)
g (raw norms) = [280.0, 0.22]
tot = 280.22
raw closed form = [0.000785, 0.999215]
self.ema (1st update) = 0.99·[0.5,0.5] + 0.01·[0.000785, 0.999215] = [0.495008, 0.504992]
clamp at 0.05 = no-op (both > 0.05)
renormalise = sum already 1.0 → unchanged
w (active) = [0.495008, 0.504992]
total = 0.495008 · 31.75 + 0.504992 · 0.3676 = 15.9021
Line 65: rul_loss = mse(rp, rt)

Stage 2 of the per-step pipeline: standard MSE on the RUL predictions. Calls mse() defined on line 37.

EXECUTION STATE
rul_loss per step = step 1 → 75.00 step 2 → 59.25 step 3 → 31.75
Line 66: health_loss = cross_entropy(hl, ht)

Stage 3: cross-entropy on the health logits. Calls cross_entropy() defined on line 42.

EXECUTION STATE
health_loss per step = 0.3676 every step (logits and tgt are intentionally identical across the trace so we can isolate GABA's behaviour)
Line 67: total, w = gaba.step(rul_loss, health_loss, g_r, g_h)

Stage 4: GABA combine. Calls GABALossNumPy.step() which (a) increments t, (b) checks warmup, (c) on active branch computes raw closed-form weights from gradient norms, (d) updates the EMA, (e) floors at min_weight, (f) renormalises, (g) returns the weighted scalar plus the weight vector.

EXECUTION STATE
(total, w) per step = step 1 → (37.6838, [0.5, 0.5]) step 2 → (29.8088, [0.5, 0.5]) step 3 → (15.9021, [0.495008, 0.504992])
Lines 68–69: print formatted row

Aligned f-string spread over two physical lines (Python concatenates adjacent string literals). Watch the w pair: [0.5, 0.5] for steps 1–2 (warmup), then drift past 0.5 in step 3 as the EMA absorbs its first measurement.

EXECUTION STATE
📚 f-string format specs = :6.2f → 6-char wide, 2 decimal places (right-aligned, padded) :.4f → 4 decimal places :8.4f → 8-char wide, 4 decimals
Captured stdout =
step 1 | rul= 75.00  health=0.3676  w=(0.5000,0.5000)  total= 37.6838
step 2 | rul= 59.25  health=0.3676  w=(0.5000,0.5000)  total= 29.8088
step 3 | rul= 31.75  health=0.3676  w=(0.4950,0.5050)  total= 15.9021
Line 71: print(f"\nfinal ema = {gaba.ema}")

After 1 active step the EMA has barely moved off uniform, by design (β=0.99 absorbs only 1% of each new measurement). If the same ~1000× imbalance persisted, the EMA would keep decaying toward the closed-form target: about [0.485, 0.515] after three active steps, and below the 0.05 floor for RUL after roughly 230.

EXECUTION STATE
📚 \n = Newline escape inside the f-string. Prints a blank line before 'final ema = ...' to visually separate the loop output from the summary.
Final captured stdout = final ema = [0.49500785 0.50499215]
→ why so close to uniform? = Only ONE active step has elapsed. With β=0.99 the EMA needs about 69 active steps (ln 2 / 0.01) to absorb half of any sustained signal, so a 500-step run shows a much larger drift.
```python
"""Minimal GABA + standard MSE trainer in pure NumPy.

A simulated training loop showing the per-step pipeline of
GABATrainer._train_epoch (paper_ieee_tii/experiments/
fix_gaba_norm_ablation.py:196-218) without PyTorch autograd.
We hand-supply the gradient norms; the structure is identical.
"""

import numpy as np


class GABALossNumPy:
    """K=2 GABA combiner. Same logic as paper class, no autograd."""

    def __init__(self, beta=0.99, warmup_steps=2, min_weight=0.05):
        self.beta = beta
        self.warmup_steps = warmup_steps
        self.min_weight = min_weight
        self.ema = np.array([0.5, 0.5])
        self.t = 0

    def step(self, l_rul, l_h, g_rul, g_h):
        self.t += 1
        if self.t <= self.warmup_steps:
            w = np.array([0.5, 0.5])
        else:
            g = np.array([g_rul, g_h])
            tot = g.sum() + 1e-12
            raw = (tot - g) / tot
            self.ema = self.beta * self.ema + (1 - self.beta) * raw
            clamped = np.maximum(self.ema, self.min_weight)
            w = clamped / clamped.sum()
        total = float(w[0] * l_rul + w[1] * l_h)
        return total, w


def mse(pred, tgt):
    """Standard MSE: no sample weighting, no Huber, no kernel."""
    return float(np.mean((pred - tgt) ** 2))


def cross_entropy(logits, tgt):
    """Multi-class CE. logits (B, C); tgt (B,) of int class indices."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-np.mean(log_p[np.arange(len(tgt)), tgt]))


# ---------- Five fake training steps ----------
gaba = GABALossNumPy(beta=0.99, warmup_steps=2, min_weight=0.05)
trace = [
    # (rul_pred,            rul_tgt,            health_logits,     health_tgt, g_rul, g_h)
    (np.array([95., 60., 30., 110.]), np.array([100., 65., 25., 125.]),
     np.array([[2., -1., .5], [.5, 1.5, -.5], [-.5, .5, 2.], [1., .5, -1.]]),
     np.array([0, 1, 2, 0]), 250.0, 0.20),
    (np.array([96., 61., 31., 112.]), np.array([100., 65., 25., 125.]),
     np.array([[2., -1., .5], [.5, 1.5, -.5], [-.5, .5, 2.], [1., .5, -1.]]),
     np.array([0, 1, 2, 0]), 260.0, 0.21),
    (np.array([97., 62., 28., 115.]), np.array([100., 65., 25., 125.]),
     np.array([[2., -1., .5], [.5, 1.5, -.5], [-.5, .5, 2.], [1., .5, -1.]]),
     np.array([0, 1, 2, 0]), 280.0, 0.22),
]

for step, (rp, rt, hl, ht, g_r, g_h) in enumerate(trace, start=1):
    rul_loss = mse(rp, rt)
    health_loss = cross_entropy(hl, ht)
    total, w = gaba.step(rul_loss, health_loss, g_r, g_h)
    print(f"step {step} | rul={rul_loss:6.2f}  health={health_loss:.4f}  "
          f"w=({w[0]:.4f},{w[1]:.4f})  total={total:8.4f}")

print(f"\nfinal ema = {gaba.ema}")
```

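The output is three rows: the first two are warmup (weights pinned at 0.5/0.5); the third is the first active step (weights barely off uniform, which is the EMA at work). Keep feeding the combiner a steady 1000× imbalance and the EMA keeps drifting toward the raw target: about (0.485, 0.515) after three active steps, the min_weight floor starts to bind after roughly 230 active steps, and by 500 active steps the applied weights are near (0.048, 0.952). Use the slider in the visualisation above to follow the drift.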
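One way to probe the EMA's inertia is to iterate the same recursion for a chosen number of active steps. The sketch below reuses the update from GABALossNumPy.step with a constant gradient-norm pair. The 1000:1 ratio and the helper name weights_after are illustrative assumptions, not values taken from a real run.

```python
def weights_after(n_active, g_rul=1000.0, g_h=1.0, beta=0.99, min_weight=0.05):
    """Iterate the K=2 GABA recursion for n_active post-warmup steps.

    Returns (ema, w): the smoothed raw weights and the floored,
    renormalised weights that would multiply the two losses.
    """
    tot = g_rul + g_h + 1e-12
    raw = [(tot - g_rul) / tot, (tot - g_h) / tot]   # inverse-proportional target
    ema = [0.5, 0.5]                                  # uniform initial state
    for _ in range(n_active):
        ema = [beta * e + (1 - beta) * r for e, r in zip(ema, raw)]
    clamped = [max(e, min_weight) for e in ema]       # floor each task's weight
    s = sum(clamped)
    return ema, [c / s for c in clamped]


for n in (1, 3, 69, 500):
    ema, w = weights_after(n)
    print(f"n={n:3d}  ema=({ema[0]:.4f}, {ema[1]:.4f})  w=({w[0]:.4f}, {w[1]:.4f})")
```

Each step moves the EMA 1% of the way to its target, so the remaining gap shrinks by a factor of 0.99^n after n steps. The floor only starts to matter once the smaller entry drops below min_weight.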

PyTorch: The Paper’s GABATrainer Verbatim

Below is the actual constructor and inner-loop method from paper_ieee_tii/experiments/fix_gaba_norm_ablation.py (lines 125–218). Every line is annotated. The CRITICAL FIX comment is the real bug: an earlier version of the norm-ablation script omitted the shared_params= kwarg, which made GABALoss silently fall back to fixed 0.5/0.5 weights and corrupted six experiments before it was caught. The fix file’s docstring documents this in detail.

GABATrainer — paper code, lines 125-218
🐍fix_gaba_norm_ablation.py
1Module docstring (lines 1–6)

Names the source file and the contribution: this trainer is the inner loop of GABA + standard MSE. The CRITICAL FIX line is the actual paper bug — earlier code paths called gaba_loss(rul, health) WITHOUT shared_params, and GABALoss silently fell back to fixed 0.5/0.5 weights, making GABA experimentally indistinguishable from a plain baseline. fix_gaba_norm_ablation.py exists solely to re-run the corrupted experiments with the kwarg in place.

8import copy

Standard-library copy module. Used elsewhere in the trainer (not shown in this excerpt) to deep-copy the best model state dict on improvement so later in-place updates don't corrupt the snapshot.

EXECUTION STATE
📚 copy.deepcopy(obj) = Recursively duplicates every Python object inside obj. For a state_dict this duplicates every tensor — vital because PyTorch's optimiser.step() mutates parameters in place; a shallow copy would alias the live weights.
9import torch

PyTorch core. Provides torch.Tensor, autograd, device handling (cuda / cpu / mps), torch.no_grad context, and the entry point to the rest of the framework.

EXECUTION STATE
📚 torch = Tensor library with automatic differentiation. Tensors are GPU-capable ndarrays that record operations into a dynamic graph; calling .backward() walks the graph in reverse to populate .grad on every leaf.
10import torch.nn as nn

Neural-network building blocks. Provides nn.Module (base class with parameter registration), and pre-built layers like nn.Linear, nn.Conv1d, nn.LSTM, nn.MultiheadAttention, nn.LayerNorm, plus loss criteria nn.MSELoss and nn.CrossEntropyLoss used here.

EXECUTION STATE
📚 nn.MSELoss = Mean-squared-error loss as an nn.Module. forward(pred, tgt) → scalar. Default reduction='mean'. Equivalent to mean((pred − tgt)²) but properly registered for torch.compile and DDP.
📚 nn.CrossEntropyLoss = Combines log-softmax + NLL in one numerically-stable call. forward(logits, tgt_int) → scalar. Inputs: logits (B, C) raw scores; targets (B,) int class indices. Internally does the same log-sum-exp trick we wrote by hand in the NumPy code.
11import torch.optim as optim

Optimisation algorithms (SGD, Adam, AdamW, RMSprop) and learning-rate schedulers (under optim.lr_scheduler). The trainer uses AdamW + ReduceLROnPlateau.

EXECUTION STATE
📚 optim.AdamW = Adam with DECOUPLED weight decay (Loshchilov & Hutter 2019). The decay term is applied to parameters directly rather than added to the gradient — interacts more cleanly with adaptive learning rates than classic L2 regularisation.
📚 optim.lr_scheduler.ReduceLROnPlateau = Watches a metric (test-RMSE here). After 'patience' epochs without improvement, multiplies lr by 'factor'. Continues until min_lr. Adaptive — no schedule to hand-tune.
14class GABATrainer:

Plain Python class — NOT nn.Module. Trainers don't need to register parameters or be moved to device; they only orchestrate the model, optimiser, criteria, and EMA tracker. Holds references; does not own them.

15def __init__(self, model, rul_criterion, health_criterion, gaba_loss, optimizer, scheduler, device, config): (lines 15–16)

Constructor. Six PyTorch objects + a device + a config dict. Notice gaba_loss is a separate object (not buried inside the optimiser) so its 2 buffers (ema_weights, step_count) can be checkpointed alongside the model.

EXECUTION STATE
⬇ self = Newly-allocated GABATrainer instance — empty until we attach attributes inside __init__.
⬇ model = GABAModel (DualTaskEnhancedModel + get_shared_params). Has both rul_head and health_head sitting on a shared backbone (Conv1d → BN → BiLSTM → MultiheadAttention → fc trunk).
⬇ rul_criterion = nn.MSELoss() — the "standard" MSE. No per-sample weighting; that's the chapter's headline simplification — every gain is attributable to GABA, not to a fancy loss.
⬇ health_criterion = nn.CrossEntropyLoss() — combines log-softmax + NLL in one numerically-stable call. Eats raw logits (B, 3) and int targets (B,).
⬇ gaba_loss = GABALoss(beta=0.99, warmup_steps=100, min_weight=0.05). The K=2 adaptive combiner from §18.5 — the PyTorch version of the GABALossNumPy class above.
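⬇ optimizer = AdamW(all_params, lr=1e-3, weight_decay=1e-4). all_params = model.parameters() + gaba_loss.parameters(). GABALoss has no learnable parameters in this implementation, so the second term is empty; its EMA buffers are checkpointed through gaba_loss.state_dict(), not through the optimiser state.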
⬇ scheduler = ReduceLROnPlateau(mode="min", factor=0.5, patience=30, min_lr=5e-6). Halves lr after 30 epochs without test-RMSE improvement.
⬇ device = torch.device("cuda" if torch.cuda.is_available() else "cpu"). Single device handle reused everywhere to avoid CPU↔GPU shuffles.
⬇ config = Dict with epochs, patience, grad_clip, warmup_epochs, lr, ema_decay. Optional — falls back to defaults if None.
⬆ returns = None — Python __init__ always returns None.
17self.model = model.to(device)

Move every parameter and buffer of the model onto the target device once. .to() returns self, so subsequent calls on the model do not need .to(device) again.

EXECUTION STATE
📚 nn.Module.to(device) = Recursively moves every Parameter and registered Buffer of the module (and submodules) to device. In-place for nn.Module (returns self). Note: tensors created later (e.g., a freshly allocated mask) need their own .to(device) call.
18self.rul_criterion = rul_criterion

Keep nn.MSELoss() — applied in stage 2 of the inner loop (line 45). Plain reference; the criterion has no learnable parameters so we don't need to .to(device) it.

19self.health_criterion = health_criterion

Keep nn.CrossEntropyLoss() — applied in stage 3 of the inner loop (line 46). Also no learnable parameters; same reasoning.

20self.gaba_loss = gaba_loss.to(device)

Move GABALoss buffers (ema_weights, step_count) to GPU. Necessary so the autograd graph stays on one device — mixing CPU and GPU tensors inside backward() crashes.

EXECUTION STATE
→ why .to(device) here but not for criteria? = MSELoss / CrossEntropyLoss compute on whatever device their inputs live on (no internal state). GABALoss carries persistent buffers (the EMA), so those buffers must live on the same device as the model.
21self.optimizer = optimizer

Store the AdamW instance. AdamW combines Adam's adaptive per-parameter learning rates with decoupled weight decay (Loshchilov & Hutter 2019).

EXECUTION STATE
📚 AdamW vs Adam = Both use first/second moment estimates (m, v) for adaptive lr. Difference: Adam folds weight decay into the gradient, so the decay term is rescaled by the adaptive denominator like any other gradient component. AdamW applies decay directly to the parameters (decoupled). The two coincide only for plain SGD, up to a rescaling of the coefficient by lr; with adaptive lr the decoupled form is better behaved.
22self.scheduler = scheduler

Store the ReduceLROnPlateau scheduler. Monitors test-RMSE (registered in fit()), waits patience=30 epochs without improvement, then multiplies lr by factor=0.5. Continues halving until min_lr=5e-6.

23self.device = device

Cache the device handle for cheap access inside the per-batch loop. Reading a Python attribute is faster than calling torch.device() every iteration.

24cfg = config or {}

Defensive default. If the caller passed config=None (or any falsy value), substitute an empty dict so the .get() calls below all hit their defaults instead of raising AttributeError.

EXECUTION STATE
📚 truthy 'or' fallback = Python's 'or' short-circuits and returns the first truthy operand. None or {} → {}; {} or {} → {}; {'epochs': 500} or {} → {'epochs': 500}.
25self.epochs = cfg.get("epochs", 500)

Maximum epochs. Paper uses 500 with early stopping (typically converges in 80–150 epochs).

EXECUTION STATE
📚 dict.get(key, default) = Safe lookup: returns default if key absent instead of raising KeyError. Idiomatic for config dicts.
26self.patience = cfg.get("patience", 80)

Early-stopping patience. If test-RMSE does not improve for 80 consecutive epochs, training halts. Paper FD002/FD004 typically stop around epoch 130–180.
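The patience rule amounts to a counter that resets on every improvement. Here is a minimal sketch, assuming a lower-is-better metric and strict improvement. The function name and the synthetic RMSE curve are hypothetical, and the paper's fit() may differ in details such as tie handling.

```python
def early_stop_epoch(metric_per_epoch, patience=80):
    """Return (stop_epoch, best_epoch) under a simple patience rule.

    Epochs are 1-indexed. Training stops at the first epoch where the
    metric has failed to improve for `patience` consecutive epochs, or
    at the last epoch if that never happens.
    """
    best = float("inf")
    best_epoch = 0
    stale = 0                                  # epochs since last improvement
    for epoch, value in enumerate(metric_per_epoch, start=1):
        if value < best:
            best, best_epoch, stale = value, epoch, 0
        else:
            stale += 1
            if stale >= patience:
                return epoch, best_epoch
    return len(metric_per_epoch), best_epoch


# Synthetic RMSE curve: improves for 50 epochs, then plateaus.
curve = [20.0 - 0.1 * e for e in range(50)] + [15.5] * 450
print(early_stop_epoch(curve, patience=80))
```

With this rule the stop epoch always trails the best epoch by exactly `patience`, which is why the best-model snapshot, not the final weights, is the one worth keeping.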

27self.grad_clip = cfg.get("grad_clip", 1.0)

Global L2 gradient norm cap. Set to 1.0 — the standard value for RNN-based RUL models. Disabled (skipped) if set to 0.

28self.warmup_epochs = cfg.get("warmup_epochs", 10)

Linear lr-warmup duration. First 10 epochs ramp lr from 0.1·base to 1.0·base. Distinct from GABA's 100-step warmup, which runs at the BATCH level inside _train_epoch.

EXECUTION STATE
→ two warmups, two scales = lr-warmup = 10 EPOCHS (epoch-level, helps optimiser converge). GABA warmup = 100 BATCHES (step-level, lets gradient norms stabilise). Independent; each protects something different.
29self.base_lr = cfg.get("lr", 1e-3)

Target learning rate after warmup ends. Used inside fit() to set the per-epoch lr_scale for the linear warmup ramp.
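The excerpt does not show how fit() computes lr_scale, so the ramp below is only one plausible reading of "0.1·base to 1.0·base over 10 epochs". It assumes a 0-indexed epoch and a straight line that reaches 1.0 at epoch == warmup_epochs; the function name warmup_lr is hypothetical.

```python
def warmup_lr(epoch, base_lr=1e-3, warmup_epochs=10, start_scale=0.1):
    """Learning rate for a 0-indexed epoch under a linear warmup ramp.

    Assumed form: the scale rises linearly from start_scale at epoch 0
    to 1.0 at epoch == warmup_epochs, then stays at 1.0.
    """
    if warmup_epochs <= 0 or epoch >= warmup_epochs:
        return base_lr
    scale = start_scale + (1.0 - start_scale) * epoch / warmup_epochs
    return base_lr * scale


for epoch in (0, 5, 10, 50):
    print(epoch, warmup_lr(epoch))
```

In a real trainer the returned value would be written into each optimiser param_group's "lr" entry at the start of the epoch.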

30self.ema = ExponentialMovingAverage(model, decay=cfg.get("ema_decay", 0.999))

Polyak averaging of MODEL parameters (different from the GABA loss-weight EMA). Maintains a shadow copy that is averaged in at decay=0.999 every step. The shadow — not the live model — is what the evaluator scores between epochs.

EXECUTION STATE
decay = 0.999 = Each step: shadow = 0.999·shadow + 0.001·live. Time constant 1/(1−d) = 1000 steps. Smooths the noise that GABA-induced weight changes inject into the live trajectory.
→ two EMAs, do not confuse = EMA #1: gaba_loss.ema_weights (decay 0.99) — smooths the LOSS WEIGHTS λ. EMA #2: self.ema (decay 0.999) — smooths the MODEL PARAMETERS θ. Both are convex combinations; they operate on entirely different state.
31self.evaluator = Evaluator(device)

Wraps the test-set forward pass + metric computation (RMSE, NASA score, last-cycle metrics, health accuracy). Held by the trainer so fit() can call it between epochs.

33def _train_epoch(self, loader):

ONE epoch of training. This is the function that runs the 8-stage inner-loop pipeline once per batch. Called from fit() inside the epoch loop. Underscore-prefixed = conventional 'protected' — internal API.

EXECUTION STATE
⬇ self = Trainer instance — provides self.model, self.optimizer, self.gaba_loss, self.device, etc.
⬇ loader = DataLoader over MTLDatasetWrapper. Yields tuples (seq, rul_tgt, health_tgt, uid) of batch_size=256.
⬆ returns = None (in this excerpt). Real implementation returns aggregate per-epoch metrics.
34self.model.train()

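Switch the model to TRAINING mode. Activates dropout layers and BatchNorm running-statistics updates. Modules start in training mode, but any evaluation pass that calls .eval() leaves the model in eval mode. Without this call, every epoch after the first evaluation would train with dropout disabled and frozen BN statistics, and nothing would raise an error.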

EXECUTION STATE
📚 nn.Module.train(mode=True) = Recursively sets self.training=True on every submodule. Counterpart: .eval() flips it to False. Layers such as Dropout and BatchNorm check self.training to decide their behaviour; LayerNorm behaves identically in both modes.
35shared_params = list(self.model.get_shared_params())

Materialise the shared-backbone parameter list ONCE per epoch. Excludes rul_head and health_head — only parameters that BOTH tasks share. The gradient norm GABA cares about is the conflict on the SHARED trunk; per-head gradients have no balance problem.

EXECUTION STATE
📚 list(generator) = Materialise a generator into a list. Required because get_shared_params() is a generator and we need to iterate it many times (every batch). Calling list() once per epoch amortises the materialisation cost.
shared_params content = Every conv1d, BN, BiLSTM, MultiheadAttention, layer-norm, fc1, fc2, fc3 parameter on the shared trunk. NO rul_head or health_head parameters.
→ why exclude heads? = Each head sees only its own task's gradient — there is no 'balance' problem there. Balance is only meaningful where the two task gradients MEET, which is exactly the shared trunk.
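The filtering rule can be sketched on parameter names alone. Only the two head prefixes come from the text; the parameter names and the helper shared_param_names below are invented for illustration, and the (name, param) pairs stand in for what model.named_parameters() would yield.

```python
HEAD_PREFIXES = ("rul_head", "health_head")


def shared_param_names(named_params):
    """Yield names of parameters that sit on the shared trunk."""
    for name, _param in named_params:
        # str.startswith accepts a tuple and matches any of its entries.
        if not name.startswith(HEAD_PREFIXES):
            yield name


# Stand-in for model.named_parameters(): (name, parameter) pairs.
fake_named_params = [
    ("conv1.weight", None),
    ("lstm.weight_ih_l0", None),
    ("fc1.bias", None),
    ("rul_head.weight", None),
    ("health_head.weight", None),
]
shared = list(shared_param_names(fake_named_params))
print(shared)
```

Filtering by prefix keeps the rule robust to new trunk layers: anything not explicitly named as a head is treated as shared.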
37for batch in loader:

Iterate batches. PyTorch's DataLoader handles parallel CPU loading (num_workers) and, with pin_memory=True, stages each batch in page-locked host memory so the .to(device) copies below are fast. Each batch is a 4-tuple of pre-collated tensors.

38seq, rul_tgt, health_tgt, uid = batch

Tuple-unpack the 4-tuple yielded by MTLDatasetWrapper. uid (unit id) is unused during training — it's there only for per-unit evaluation in the test set.

EXECUTION STATE
seq = torch.Tensor (B=256, L=30, F=17). Sliding windows of normalised sensor readings.
rul_tgt = torch.Tensor (B,) of clipped RUL in [0, 125].
health_tgt = torch.Tensor (B,) of int class indices in {0, 1, 2}.
uid = torch.Tensor (B,) — engine id; unused inside this method.
39seq = seq.to(self.device)

Move the sequence tensor to GPU. With pin_memory=True in the DataLoader and non_blocking=True (omitted here for clarity) this can be asynchronous — the next batch's H2D transfer overlaps with the current batch's compute.

EXECUTION STATE
📚 Tensor.to(device) = Returns a tensor on the target device. If already there, returns self (no-op). Asynchronous when src is pinned memory + non_blocking=True.
40rul_tgt = rul_tgt.to(self.device).view(-1, 1)

Move target to GPU AND reshape (B,) → (B, 1). The model's rul_pred has shape (B, 1). With mismatched shapes MSELoss does not fail: it broadcasts, emits only a UserWarning, and returns a different loss value.

EXECUTION STATE
📚 .view(-1, 1) = Reshape, keeping memory layout. The -1 means 'infer this dim from the total element count'. (B,) → (B, 1): same data, new shape header. view never copies; it raises if the requested shape is incompatible with the tensor's strides.
→ matching shapes matter = If rul_pred is (B, 1) and rul_tgt is (B,), MSELoss broadcasts to (B, B) and averages over B² elements instead of B, with only a UserWarning to flag it. Reshape avoids the bug.
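The shape trap is easy to reproduce outside PyTorch, because NumPy applies the same broadcasting rule to a (B, 1) array minus a (B,) array. The sketch below reuses the first batch of the NumPy trace above; it illustrates the broadcasting only and is not the trainer's code.

```python
import numpy as np

pred = np.array([[95.0], [60.0], [30.0], [110.0]])   # shape (B, 1), like rul_pred
tgt = np.array([100.0, 65.0, 25.0, 125.0])           # shape (B,), un-reshaped target

bad_diff = pred - tgt                    # broadcasts to (B, B): every pred vs every tgt
good_diff = pred - tgt.reshape(-1, 1)    # stays (B, 1): element-wise pairs only

print(bad_diff.shape, float(np.mean(bad_diff ** 2)))
print(good_diff.shape, float(np.mean(good_diff ** 2)))
```

The mismatched version mixes in the error between each prediction and every other sample's target, so its mean is dominated by the spread of the targets rather than by the model's actual error.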
41health_tgt = health_tgt.to(self.device)

Move the class-index target to GPU. CrossEntropyLoss expects (B,) of int — no reshape needed.

43self.optimizer.zero_grad()

Clear .grad on every parameter from the previous step. PyTorch ACCUMULATES gradients into .grad by default (a feature for grad-accumulation patterns); for plain training we MUST zero them every step or gradients from previous batches will leak in.

EXECUTION STATE
📚 optimizer.zero_grad(set_to_none=True) = Walks the optimiser's param_groups and sets each .grad to None (the default since PyTorch 2.0; the argument itself was added in 1.7) or to zero. set_to_none=True is slightly faster: backward() allocates fresh storage on first write rather than zero-then-add.
44rul_pred, health_logits = self.model(seq) ← STAGE 1 (forward)

STAGE 1 — forward pass. Single call hits the shared backbone once and returns BOTH heads' outputs simultaneously. The autograd graph from this point on records both heads' computation history; calling .backward() on either loss later will propagate through the trunk for free.

EXECUTION STATE
📚 model(seq) calls __call__ = nn.Module overrides __call__ to invoke forward() with hooks (forward pre/post hooks, JIT tracing, etc.). Always call model(x) — never model.forward(x) directly — or you bypass the hooks.
rul_pred = torch.Tensor (B, 1) — predicted RUL clamped ≥ 0.
health_logits = torch.Tensor (B, 3) — pre-softmax class scores.
45rul_loss = self.rul_criterion(rul_pred, rul_tgt) ← STAGE 2

STAGE 2 — standard MSE. Returns a scalar tensor with grad_fn pointing back through the RUL head + the shared backbone. Plain MSE — no per-sample weighting — is what makes the chapter's headline result attributable to GABA alone.

EXECUTION STATE
📚 nn.MSELoss.__call__ = Computes mean((pred − tgt)²). Returns a 0-D tensor with grad_fn=MseLossBackward. The grad_fn pointer is what enables .backward() to walk the graph in reverse.
rul_loss type = torch.Tensor (0-D) with grad_fn — NOT a Python float.
46health_loss = self.health_criterion(health_logits, health_tgt) ← STAGE 3

STAGE 3 — cross-entropy. Returns a scalar tensor with grad_fn pointing back through the health head + the shared backbone. The two losses share an autograd subgraph (the trunk).

EXECUTION STATE
📚 nn.CrossEntropyLoss.__call__ = Internally: log_softmax(logits, dim=1) → gather(target) → mean(−). Numerically stable. Targets must be int (LongTensor). Logits are raw — DO NOT pass softmax(logits) yourself or you'll double-softmax.
48# CRITICAL FIX: pass shared_params so GABA can compute gradient norms

This comment is the actual paper's scar tissue. An earlier version of run_normalization_ablation.py omitted shared_params on the next line — GABALoss then took the warmup branch forever (because shared_params was None) and silently produced uniform 0.5/0.5 weights. Six experiments (FD002 ×5 seeds + FD004 seed 42) were corrupted before this was caught and re-run.

49loss = self.gaba_loss(rul_loss, health_loss, shared_params=shared_params) ← STAGE 4

STAGE 4 — GABA combine. GABALoss.forward computes per-task gradient norms via torch.autograd.grad(rul_loss, shared_params, retain_graph=True), runs the closed-form/EMA/floor pipeline, then returns w_rul·rul_loss + w_h·health_loss. The returned tensor has a grad_fn that, when .backward() is called, propagates through BOTH original losses with the GABA-chosen weights baked in.

EXECUTION STATE
📚 torch.autograd.grad(loss, params, retain_graph=True) = Computes ∂loss/∂params explicitly (returns a tuple of tensors, NOT populating .grad). retain_graph=True keeps the graph alive so we can do it again for the OTHER loss without re-running forward. This is the secret sauce that lets GABA measure both gradient norms before backward.
⬇ kwarg: shared_params = List of every shared-backbone parameter from line 35. Passing it activates the active branch of GABALoss.forward (line 86 of fix_gaba_norm_ablation.py); omitting it falls back to uniform 0.5/0.5 — the 6-experiment bug.
⬆ loss type = torch.Tensor (0-D) with grad_fn pointing back through BOTH rul_loss and health_loss with the GABA weights baked in.
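To make stage 4 concrete, here is a hypothetical re-implementation of the combiner as described above: two torch.autograd.grad calls with retain_graph=True, the closed form, the EMA, the floor, and the uniform fallback when shared_params is missing. It is a sketch built from this section's description, not the paper's GABALoss; the class name, the toy trunk and heads, and the random data are all assumptions.

```python
import torch
import torch.nn as nn


class GABALossSketch(nn.Module):
    """Hypothetical K=2 combiner; NOT the paper's GABALoss."""

    def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05):
        super().__init__()
        self.beta, self.warmup_steps, self.min_weight = beta, warmup_steps, min_weight
        self.register_buffer("ema_weights", torch.tensor([0.5, 0.5]))
        self.register_buffer("step_count", torch.zeros((), dtype=torch.long))

    @staticmethod
    def _grad_norm(loss, params):
        # Explicit gradients w.r.t. the shared trunk; .grad is left untouched.
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        return torch.stack([g.pow(2).sum() for g in grads if g is not None]).sum().sqrt()

    def forward(self, rul_loss, health_loss, shared_params=None):
        self.step_count.add_(1)
        if shared_params is None or int(self.step_count) <= self.warmup_steps:
            w = torch.full_like(self.ema_weights, 0.5)        # uniform fallback
        else:
            g = torch.stack([self._grad_norm(rul_loss, shared_params),
                             self._grad_norm(health_loss, shared_params)]).detach()
            tot = g.sum() + 1e-12
            raw = (tot - g) / tot                              # closed-form target
            self.ema_weights.mul_(self.beta).add_((1 - self.beta) * raw)
            clamped = self.ema_weights.clamp(min=self.min_weight)
            w = clamped / clamped.sum()
        # w carries no gradient, so backward() only flows through the two losses.
        return w[0] * rul_loss + w[1] * health_loss


torch.manual_seed(0)
trunk = nn.Linear(4, 8)                                        # toy shared backbone
rul_head, health_head = nn.Linear(8, 1), nn.Linear(8, 3)
feats = torch.relu(trunk(torch.randn(16, 4)))
rul_loss = nn.functional.mse_loss(rul_head(feats), 100 * torch.rand(16, 1))
health_loss = nn.functional.cross_entropy(health_head(feats), torch.randint(0, 3, (16,)))

gaba = GABALossSketch(warmup_steps=0)                          # skip warmup for the demo
loss = gaba(rul_loss, health_loss, shared_params=list(trunk.parameters()))
loss.backward()
print(gaba.ema_weights)
```

Note how the fallback branch is indistinguishable from warmup at the call site: nothing raises when shared_params is None. A stricter variant could raise once step_count exceeds warmup_steps, which would have turned the silent failure into a loud one.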
50loss.backward() ← STAGE 5

STAGE 5 — backward. Walk the autograd graph from the single combined scalar back through GABA → both heads → shared backbone. Populates .grad on every learnable parameter. ONE backward() call covers BOTH tasks because GABA produced a single combined scalar.

EXECUTION STATE
📚 Tensor.backward() = PyTorch's reverse-mode autodiff entry point. Walks the dynamic graph, applies chain rule, accumulates ∂loss/∂θ into θ.grad on each leaf tensor. Default: assumes the tensor is a scalar (0-D); for non-scalar use .backward(grad_tensor).
→ after this call = Every parameter p in model.parameters() has p.grad populated (a tensor of the same shape as p). Optimiser.step() will read these next.
52if self.grad_clip > 0:

Skip clipping entirely if disabled (grad_clip=0). Standard pattern — lets you ablate clipping by setting one config field instead of editing the trainer.

53torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) ← STAGE 6

STAGE 6 — global L2 norm cap. Computes total_norm = sqrt(Σ_p ‖g_p‖²) over every parameter, then if total_norm > max_norm scales every g_p by max_norm / total_norm. In place.

EXECUTION STATE
📚 clip_grad_norm_(params, max_norm) = PyTorch utility for global L2 norm clipping. Trailing underscore = in-place modifier convention. Returns the pre-clip total norm (often logged for monitoring).
⬇ arg 1: self.model.parameters() = Generator over every learnable parameter of the model. Including the heads — clipping is global, after GABA has done its job.
⬇ arg 2: self.grad_clip = 1.0 = Cap on the global gradient norm. For plain SGD this bounds the step at 1·lr in parameter space. AdamW rescales each step by its moment estimates, so here the clip mainly keeps one noisy batch from distorting m and v.
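The arithmetic of global-norm clipping fits in a few lines. This sketch works on plain Python lists instead of tensors and mirrors the rule described above; the function name and the small eps guard against division by zero are assumptions of the sketch.

```python
import math


def clip_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors so their joint L2 norm is <= max_norm.

    Returns (clipped_grads, pre_clip_norm). Gradients already inside
    the cap are returned with scale 1.0, i.e. unchanged.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [[g * scale for g in vec] for vec in grads], total_norm


grads = [[3.0, 4.0], [12.0]]     # joint norm = sqrt(9 + 16 + 144) = 13
clipped, norm = clip_global_norm(grads, max_norm=1.0)
print(norm, clipped)
```

Because one scale factor is applied to every vector, the direction of the full gradient is preserved; only its length changes.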
54self.optimizer.step() ← STAGE 7

STAGE 7 — AdamW step. Reads .grad on every parameter, updates the (m, v) moment estimates, applies the decoupled weight-decay term, then mutates parameters in place. After this call the model has new weights.

EXECUTION STATE
📚 optimizer.step() = Walks param_groups; for each parameter applies the optimiser-specific update rule. AdamW: m ← β₁·m + (1−β₁)·g; v ← β₂·v + (1−β₂)·g²; θ ← θ − lr·(m̂/(√v̂ + ε) + λ·θ), where m̂ and v̂ are the bias-corrected moments. The λ·θ term is the decoupled weight decay.
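The update rule above can be traced by hand for a single scalar parameter. The sketch below is a plain-Python illustration, not PyTorch's implementation; it uses the lr and weight_decay from this section and assumes the common defaults β₁=0.9, β₂=0.999, ε=1e-8. The gradient values fed in are arbitrary.

```python
import math


def adamw_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1e-4):
    """One AdamW update for a single scalar parameter (t is 1-indexed)."""
    theta = theta * (1 - lr * weight_decay)        # decoupled weight decay
    m = beta1 * m + (1 - beta1) * g                 # first-moment EMA
    v = beta2 * v + (1 - beta2) * g * g             # second-moment EMA
    m_hat = m / (1 - beta1 ** t)                    # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


theta, m, v = 1.0, 0.0, 0.0
for t, g in enumerate([250.0, 0.2, -3.0], start=1):
    theta, m, v = adamw_step(theta, g, m, v, t)
    print(t, round(theta, 6))
```

On the very first step the bias-corrected ratio m̂/√v̂ has magnitude 1, so a gradient of 250 and a gradient of 0.2 move the parameter by almost the same amount, roughly lr.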
55self.ema.update(self.model) ← STAGE 8

STAGE 8 — EMA shadow update. For each parameter: shadow = 0.999·shadow + 0.001·live. The live model continues training; the shadow is the smoothed checkpoint that the evaluator scores between epochs.

EXECUTION STATE
📚 ExponentialMovingAverage.update(model) = Iterates model.parameters() in lockstep with the internal shadow dict. shadow[name] ← decay·shadow[name] + (1−decay)·live. No autograd — these are plain tensor operations.
ema.shadow vs ema.live = Two parameter dicts. Optimiser steps live; ema.update averages live INTO shadow; the evaluator temporarily swaps shadow into the model for scoring, then restores live.
→ why bother? = Polyak averaging knocks roughly 0.5 RMSE off the test score on FD002/FD004 — a quiet but consistent contributor to almost every reported number in this book.
1"""GABATrainer — paper_ieee_tii/experiments/fix_gaba_norm_ablation.py:125-218
2
3The inner-loop pipeline (8 stages) of GABA + standard MSE training.
4Includes the CRITICAL fix that distinguishes correct GABA from a
5silent-fallback baseline: shared_params must be passed to gaba_loss().
6"""
7
8import copy
9import torch
10import torch.nn as nn
11import torch.optim as optim
12
13
14class GABATrainer:
15    def __init__(self, model, rul_criterion, health_criterion, gaba_loss,
16                 optimizer, scheduler, device, config):
17        self.model = model.to(device)
18        self.rul_criterion = rul_criterion
19        self.health_criterion = health_criterion
20        self.gaba_loss = gaba_loss.to(device)
21        self.optimizer = optimizer
22        self.scheduler = scheduler
23        self.device = device
24        cfg = config or {}
25        self.epochs = cfg.get("epochs", 500)
26        self.patience = cfg.get("patience", 80)
27        self.grad_clip = cfg.get("grad_clip", 1.0)
28        self.warmup_epochs = cfg.get("warmup_epochs", 10)
29        self.base_lr = cfg.get("lr", 1e-3)
30        self.ema = ExponentialMovingAverage(model, decay=cfg.get("ema_decay", 0.999))
31        self.evaluator = Evaluator(device)
32
33    def _train_epoch(self, loader):
34        self.model.train()
35        shared_params = list(self.model.get_shared_params())
36
37        for batch in loader:
38            seq, rul_tgt, health_tgt, uid = batch
39            seq = seq.to(self.device)
40            rul_tgt = rul_tgt.to(self.device).view(-1, 1)
41            health_tgt = health_tgt.to(self.device)
42
43            self.optimizer.zero_grad()
44            rul_pred, health_logits = self.model(seq)
45            rul_loss = self.rul_criterion(rul_pred, rul_tgt)
46            health_loss = self.health_criterion(health_logits, health_tgt)
47
48            # CRITICAL FIX: pass shared_params so GABA can compute gradient norms
49            loss = self.gaba_loss(rul_loss, health_loss, shared_params=shared_params)
50            loss.backward()
51
52            if self.grad_clip > 0:
53                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
54            self.optimizer.step()
55            self.ema.update(self.model)
The grand total of code that distinguishes GABA + standard MSE training from a fixed-weight dual-task baseline is exactly two lines: line 35 (gather shared_params) and line 49 (call the GABA combiner with that list). Everything else is what any modern PyTorch trainer already does.

Hyperparameters That Stay Fixed

For every dataset (FD001/FD002/FD003/FD004) and every random seed reported in the paper, the trainer uses the configuration below. Between runs only the dataset and the random seed change.

| Hyperparameter | Value | Source | What it controls |
|---|---|---|---|
| batch_size | 256 | DataLoader | Sequences per gradient step |
| sequence_length | 30 | CMAPSSDataset | Sliding-window history per sample |
| epochs (max) | 500 | config | Upper bound; early stop usually triggers earlier |
| patience | 80 | config | Epochs without test-RMSE improvement before stop |
| lr (base) | 1e-3 | AdamW | Target lr after warmup |
| weight_decay | 1e-4 | AdamW | Decoupled L2 regularisation strength |
| warmup_epochs | 10 | config | Epochs of linear lr ramp 0.1× → 1.0× base |
| scheduler | ReduceLROnPlateau(0.5, 30, 5e-6) | torch.optim.lr_scheduler | Halve lr after 30 stalled epochs |
| grad_clip | 1.0 | config | Global L2 norm cap on gradients |
| ema_decay | 0.999 | config | Polyak averaging time constant on weights |
| GABA β | 0.99 | GABALoss | EMA on task weights (time const ≈ 100 steps) |
| GABA warmup_steps | 100 | GABALoss | Batches of fixed 0.5/0.5 before adaptive |
| GABA min_weight | 0.05 | GABALoss | Floor on each task’s smoothed weight |

Do not retune these per dataset. The headline result of this chapter is that the same hyperparameters give state-of-the-art NASA score on FD002/FD004 without any per-dataset adjustment — that is what makes the ”single λ that converges within 10 epochs” claim non-trivial.
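One way to enforce the "do not retune" rule in code is to hold the fixed values in a read-only mapping. The sketch below copies the values from the table; the split into two mappings and the key names outside those read by GABATrainer.__init__ are illustrative assumptions.

```python
from types import MappingProxyType

# Keys read by GABATrainer.__init__ via cfg.get(...).
TRAINER_CONFIG = MappingProxyType({
    "epochs": 500,
    "patience": 80,
    "grad_clip": 1.0,
    "warmup_epochs": 10,
    "lr": 1e-3,
    "ema_decay": 0.999,
})

# Remaining settings; the grouping and key names here are illustrative.
# Note: the nested dicts are ordinary dicts and remain mutable.
FIXED_SETTINGS = MappingProxyType({
    "batch_size": 256,
    "sequence_length": 30,
    "weight_decay": 1e-4,
    "scheduler": {"factor": 0.5, "patience": 30, "min_lr": 5e-6},
    "gaba": {"beta": 0.99, "warmup_steps": 100, "min_weight": 0.05},
})

try:
    TRAINER_CONFIG["lr"] = 3e-4          # an attempt at per-dataset retuning
except TypeError as err:
    print("config is read-only:", err)
```

A MappingProxyType still supports .get() and truthiness, so it can be passed wherever the trainer expects a plain config dict.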

The Same Pattern In Other Domains

The eight-stage structure above is not specific to RUL. Anywhere two predictive tasks share a backbone and one task’s gradient dwarfs the other’s, the same trainer scaffolding applies — only the loss functions and the data loader change.

  • Adaptive chemotherapy dosing. Predict (a) tumour-size trajectory (regression, large gradient — large absolute targets) and (b) toxicity-event class (3-class classification, small gradient). GABA + MSE keeps the toxicity head from being drowned out by the regression task. Same trainer; different criteria.
  • Self-driving stack. Joint depth regression + semantic segmentation on a shared CNN trunk. Depth has a wide pixel-value range (large MSE gradient); seg has per-pixel CE (small gradient). Identical pipeline.
  • Algorithmic trading risk models. Predict (a) absolute return regression and (b) regime label classification. Returns can swing 10×–1000× larger than label-prob deltas. The same eight-stage trainer scaffolding works directly.
  • Climate downscaling. Joint precipitation regression + extreme-event classification on a shared U-Net backbone. The classification gradient vanishes relative to MSE on continuous mm/day output without a balancer like GABA.

In every case, the only thing that changes is what flows in (data loader) and what each head outputs. The eight stages — forward, two losses, GABA combine, backward, clip, step, EMA — are domain-agnostic.

Pitfalls (One Was The Actual Paper Bug)

Pitfall 1 — forgetting shared_params (the real bug). In run_normalization_ablation.py, line 185 called self.mtl_loss(rul_loss, health_loss) without the kwarg. GABALoss saw shared_params=None, took the warmup branch forever, and silently produced uniform 0.5/0.5 weights. Six experiments (FD002 ×5 seeds + FD004 seed 42) were corrupted before this was caught. fix_gaba_norm_ablation.py exists solely to re-run them with the kwarg in place.
Pitfall 2 — including head parameters in shared_params. If you accidentally include rul_head or health_head in the shared list, the gradient norms get inflated by per-task content that has nothing to do with the shared trunk and the balancer behaviour collapses. GABAModel.get_shared_params explicitly skips any name starting with rul_head or health_head for this reason.
Pitfall 3 — evaluating the live weights instead of the EMA shadow. The trainer pattern is train epoch → ema.apply_shadow → evaluator.evaluate → ema.restore. If you skip apply_shadow, you evaluate the noisy live weights and report ~0.5 RMSE worse than the paper; if you skip restore, the next epoch trains from the shadow instead of the live weights. EMA with decay=0.999 has been a quiet contributor to almost every reported number in this book.
Pitfall 4 — assuming the optimiser checkpoints GABA's state. Line 264 of the paper code uses all_params = list(model.parameters()) + list(gaba_loss.parameters()). GABALoss has no learnable parameters in this implementation, so that second list is empty and the optimiser never sees the combiner. The state that matters lives in its buffers (ema_weights, step_count), and buffers are saved by gaba_loss.state_dict(), not by optimizer.state_dict(). Checkpoint all three state dicts (model, optimiser, gaba_loss); otherwise a resumed run restarts the EMA at 0.5/0.5 and repeats the GABA warmup, so it is not reproducible.
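The apply_shadow / restore pattern from Pitfall 3 can be sketched with a dict of scalar "parameters". The method names follow the text; the class body, the decay of 0.9 (chosen so the effect is visible in three steps), and the fake optimiser updates are assumptions of this sketch.

```python
class ShadowEMA:
    """Minimal Polyak averager over a dict of named scalar parameters."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = dict(params)   # smoothed copy of the weights
        self.backup = None           # live weights parked during evaluation

    def update(self, params):
        d = self.decay
        for name, live in params.items():
            self.shadow[name] = d * self.shadow[name] + (1 - d) * live

    def apply_shadow(self, params):
        self.backup = dict(params)   # remember the live weights
        params.update(self.shadow)   # evaluate with the smoothed ones

    def restore(self, params):
        params.update(self.backup)   # training resumes from the live weights
        self.backup = None


params = {"w": 0.0}
ema = ShadowEMA(params, decay=0.9)
for step in range(1, 4):
    params["w"] = float(step)        # stand-in for optimizer.step()
    ema.update(params)

ema.apply_shadow(params)
eval_value = params["w"]             # what the evaluator would see
ema.restore(params)
print(eval_value, params["w"])
```

The evaluator sees the lagging, smoothed weight while training carries on from the live one; skipping either half of the swap breaks one of those two guarantees.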

Takeaway

One Sentence

GABA + Standard MSE training is a plain eight-stage PyTorch trainer where exactly one call, the GABA combine with shared_params, distinguishes it from a fixed-weight baseline; everything else is forward, two losses, backward, clip, step, EMA.

What To Remember

  • The pipeline is exactly 8 stages per batch. Stages 1–3 measure; stage 4 balances; stages 5–8 commit.
  • Standard MSE (no sample weighting) is a deliberate choice: it makes the chapter’s headline gain attributable to the balancer, not to the loss surface.
  • The two warmups operate on different scales: lr-warmup is 10 epochs at the epoch level; GABA warmup is 100 batches at the step level. They are independent and they protect different things.
  • The single critical glue line is self.gaba_loss(rul_loss, health_loss, shared_params=shared_params). Forget the kwarg → silent fallback to baseline behaviour → corrupted experiments.
  • The next section watches the weights $(\lambda_{rul}, \lambda_{h})$ converge over 200 epochs of a real C-MAPSS run.