Walk into a Toyota plant in Aichi Prefecture and you find an assembly line of about 80 stations. Each station does one thing — bolt the windshield, attach the door, route the wiring harness — and the car moves between them on a fixed cadence. The brilliance of the line is not in any one station; it is in the precise ordering and the cadence at which work moves. Skip a station, swap two, or break the cadence and the car comes out wrong.
GRACE training is the same. Fourteen operations execute in a fixed order: eight fire on every mini-batch (data fetch, forward, inner losses, GABA controller, backward, gradient clip, AdamW, EMA update) and six fire once per epoch (warmup, weight-decay schedule, evaluator, scheduler, early stopping, checkpointer). This section is a guided tour of the line: what each station does, where it lives in the paper repo, and how the whole thing produces the published 7.72 RMSE on FD002 in roughly 30 minutes on a single GPU.
The headline. GRACE's training loop is the paper's UnifiedTrainer.fit(), defined in grace/training/trainer.py (291 lines of Python). It composes ten reusable components (LR warmup, AdamW, ReduceLROnPlateau, EMA, gradient clipping, early stopping, checkpointer, gradient logger, plus the GABA controller and the WMSE inner loss) in a strict per-batch / per-epoch sequence.
Anatomy: 14 Stages, Two Cadences
Every stage of the pipeline runs at one of two cadences:
Per-batch — fires inside the inner loop of _train_epoch(), once for each mini-batch. Eight stages: data fetch, forward pass, inner losses, OUTER GABA controller, backward, gradient clip, AdamW step, EMA shadow update.
Per-epoch — fires inside fit(), once at the end of a full pass over the training set. Six stages: linear LR warmup (only the first 10 epochs), GRACE adaptive-weight-decay schedule, evaluator (with EMA-shadow weights), ReduceLROnPlateau scheduler step, early stopping check, best-checkpoint save.
The two cadences interact. Per-batch stages mutate the model; per-epoch stages decide whether to stop, whether to halve the learning rate, whether to save a checkpoint — based on what the per-batch stages have produced.
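The interaction between the two cadences can be sketched as a nested loop that only records which stage fires when. This is a minimal sketch: the stage names are hypothetical labels taken from the two lists above, not functions in the paper repo, and the real trainer does actual work at each stage.

```python
# Hypothetical stage labels; the real UnifiedTrainer invokes real callbacks here.
PER_BATCH = ["fetch", "forward", "inner_losses", "gaba", "backward",
             "clip", "adamw_step", "ema_update"]
PER_EPOCH = ["warmup_gate", "adaptive_wd", "evaluate_ema",
             "scheduler_step", "track_best", "early_stop_check"]


def run_schedule(num_epochs, batches_per_epoch):
    """Return the ordered (epoch, stage) events a two-cadence loop would fire."""
    events = []
    for epoch in range(num_epochs):
        # Inner loop: the per-batch cadence, once per mini-batch.
        for _ in range(batches_per_epoch):
            events.extend((epoch, stage) for stage in PER_BATCH)
        # Epoch boundary: the per-epoch cadence, once per full pass.
        events.extend((epoch, stage) for stage in PER_EPOCH)
    return events


events = run_schedule(num_epochs=2, batches_per_epoch=3)
stages = [stage for _, stage in events]
print(len(events), stages.count("ema_update"), stages.count("scheduler_step"))
```

With 70–110 batches per epoch, each per-epoch stage fires once for every 70–110 firings of each per-batch stage.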
Why this split matters. Confusing the two cadences is the most common source of bugs in MTL training loops. Calling scheduler.step() per batch instead of per epoch advances ReduceLROnPlateau's patience counter on every mini-batch, so on FD002 the learning rate decays ~70× faster than intended. Updating the EMA shadow only at end-of-epoch instead of every batch loses ~0.3 RMSE.
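The scheduler failure mode is easy to reproduce with a simplified plateau rule. This is a sketch, not PyTorch's ReduceLROnPlateau: it assumes no improvement threshold and no cooldown, and the 90 batches per epoch is an illustrative value inside the 70–110 range quoted in this section. The validation metric only changes at epoch boundaries, so a per-batch caller feeds the same stale value over and over.

```python
class PlateauHalver:
    """Simplified plateau scheduler (assumption: no threshold, no cooldown)."""

    def __init__(self, lr, factor=0.5, patience=30, min_lr=5e-6):
        self.lr, self.factor = lr, factor
        self.patience, self.min_lr = patience, min_lr
        self.best = float("inf")
        self.bad_calls = 0

    def step(self, metric):
        if metric < self.best:          # improvement resets the counter
            self.best, self.bad_calls = metric, 0
        else:
            self.bad_calls += 1
            if self.bad_calls > self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_calls = 0
        return self.lr


def final_lr(calls_per_epoch, epochs=40, base_lr=1e-3):
    """Simulate a 40-epoch plateau: the metric never improves after epoch 0."""
    sched = PlateauHalver(base_lr)
    lr = base_lr
    for _ in range(epochs):
        rmse_last = 10.0  # stale value: only recomputed at the epoch boundary
        for _ in range(calls_per_epoch):
            lr = sched.step(rmse_last)
    return lr


per_epoch = final_lr(calls_per_epoch=1)    # correct: one call per epoch
per_batch = final_lr(calls_per_epoch=90)   # bug: one call per mini-batch
print(per_epoch, per_batch)
```

Called once per epoch, the 40-epoch plateau costs a single halving (1e-3 to 5e-4). Called once per mini-batch with the same stale rmse_last, the scheduler exhausts its patience every 31 calls and drives the learning rate to the 5e-6 floor.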
Interactive: The Pipeline End-To-End
[Figure: training pipeline diagram (interactive in the original). Per-batch stages appear above per-epoch stages; each stage shows what it computes, where it lives in the paper repo (grace/training/trainer.py et al.), and its production hyperparameter values.]
Read the diagram top to bottom. The per-batch cadence runs many times before any per-epoch stage fires. By the end of an epoch, the model has typically seen 70–110 batches; only then does the evaluator pull EMA shadow weights, compute rmse_last on the test set, and feed that scalar into ReduceLROnPlateau, EarlyStopping, and Checkpointer.
One Mini-Batch Step, Walked Through
A single mini-batch executes the eight per-batch stages in order. The table below traces the seven that follow data fetch, for a 4-sample batch with y_rul = (20, 80, 5, 100) and a partly-trained model.
| Stage | Operation | Input | Output |
|---|---|---|---|
| 1. Forward | model(seq) | seq (4, 30, 14) | y_pred (4,), hp_logits (4, 3) |
| 2a. Inner WMSE | moderate_weighted_mse_loss | y_pred, y_rul | L_rul = 45.21 (scalar) |
| 2b. Inner CE | F.cross_entropy | hp_logits, hp_target | L_health = 0.8999 (scalar) |
| 3. GABA controller | gaba(L_rul, L_health, shared_params) | g_rul = 41.5, g_h = 0.082 | λ* = (0.0477, 0.9523), L = 3.013 |
| 4. Backward | loss.backward() | L_total (scalar) | θ.grad populated for ~1.7M params |
| 5. Clip | clip_grad_norm_(.., max_norm=1.0) | global ‖∇‖ ≈ 2.4 | rescaled to ‖∇‖ = 1.0 |
| 6. AdamW step | optimizer.step() | lr = 1e-3, β = (0.9, 0.999), wd = 5e-5 (after epoch 100) | θ ← θ - lr·m̂/(√v̂+ε) - lr·wd·θ |
| 7. EMA shadow | ema.update(model) | α = 0.999, current θ | shadow ← 0.999·shadow + 0.001·θ |
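The stage-3 numbers can be checked by hand. Below is a minimal sketch of the K=2 closed form described in the NumPy walkthrough (inverse-ratio weights, a 0.05 floor, renormalisation), assuming the gradient-norm EMA has converged, applied to the table's inputs:

```python
import numpy as np


def gaba_k2(g_rul, g_health, eps_floor=0.05):
    """K=2 inverse-ratio weights with a floor, renormalised to sum to 1."""
    raw = np.array([g_health, g_rul]) / (g_rul + g_health)  # bigger norm -> smaller weight
    floored = np.maximum(raw, eps_floor)                      # no task below 5%
    return floored / floored.sum()


lam = gaba_k2(g_rul=41.5, g_health=0.082)
L_rul, L_health = 45.21, 0.8999          # stage 2a / 2b outputs from the table
L_total = lam[0] * L_rul + lam[1] * L_health
print(np.round(lam, 4), round(float(L_total), 3))
```

The weights reproduce the table's λ* at four decimals. L_total lands within 0.001 of the table's 3.013; the small gap is consistent with the table's inputs being rounded.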
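The per-batch sequence is what the paper's _train_epoch() loops over. Every value in that table comes from the paper's production code. The Python script below reproduces stages 1–7 in pure NumPy for a single scalar parameter, with a toy linear model standing in for the forward pass; it reuses the table's gradient norms (41.5 and 0.082), so λ* matches, but its loss values come from toy data and differ from the table's.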
One Epoch Boundary, Walked Through
End of epoch e: the per-batch loop has finished. Now the per-epoch operations run, in this order:
Warmup gate. If $e < W$ (default $W = 10$), set $\eta = \eta_0 \cdot (0.1 + 0.9\,e/W)$. Otherwise leave $\eta$ alone: the scheduler owns it from here.
Adaptive weight decay (GRACE-only). If $e > 100$, set $\mathrm{wd} = \mathrm{wd}_0 \cdot 0.5$; if $e > 200$, set $\mathrm{wd} = \mathrm{wd}_0 \cdot 0.1$. Late training benefits from less regularisation as the loss landscape flattens.
EMA-aware evaluation. ema.apply_shadow(model) swaps the live weights for shadow weights; the evaluator runs the full test loop; ema.restore(model) swaps them back. The reported rmse_last, NASA score, and health accuracy all come from the shadow.
Scheduler step. scheduler.step(rmse_last): if rmse_last hasn't improved for 30 epochs, halve the LR. Gated by $e \ge W$ so the scheduler doesn't fight warmup.
Track best. If rmse_last < best_rmse, update best_rmse, save the model state and the EMA shadow, persist a checkpoint.
Early stopping. If no improvement for P=80 consecutive epochs, restore best weights and break out of the epoch loop.
Gradient log flush. grad_logger.log_epoch() records the GABA stats ($g_{\mathrm{rul}}$, $g_{\mathrm{health}}$, $\lambda_i^{\mathrm{raw}}$, $\lambda_i^{*}$) for the per-dataset CSV that chapter 21 §3 reads from.
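Three of these per-epoch rules can be sketched in a few lines: the warmup gate, the adaptive weight-decay schedule, and best-checkpoint tracking with early stopping. This is a minimal sketch under stated assumptions: the function and class names are hypothetical, evaluation and the scheduler are omitted, and stopping fires on the P-th consecutive non-improving epoch (whether the production check counts to P or P+1 is not stated here).

```python
def warmup_lr(epoch, base_lr=1e-3, warmup=10):
    """Warmup gate: the ramp LR for epoch < W, else None (the scheduler owns the LR)."""
    if epoch < warmup:
        return base_lr * (0.1 + 0.9 * epoch / warmup)
    return None


def adaptive_wd(epoch, wd0=1e-4):
    """GRACE-style weight-decay schedule: halve after epoch 100, divide by 10 after 200."""
    if epoch > 200:
        return wd0 * 0.1
    if epoch > 100:
        return wd0 * 0.5
    return wd0


class BestTracker:
    """Track the best rmse_last and signal early stopping after `patience` misses."""

    def __init__(self, patience=80):
        self.patience = patience
        self.best_rmse = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def update(self, epoch, rmse_last):
        if rmse_last < self.best_rmse:
            # This is where the trainer would save model + EMA shadow.
            self.best_rmse, self.best_epoch, self.bad_epochs = rmse_last, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience  # True -> restore best weights, stop


tracker = BestTracker(patience=80)
history = [10.0, 9.0, 8.0, 7.0, 6.0] + [6.5] * 100  # improves, then plateaus
stop_epoch = None
for epoch, rmse_last in enumerate(history):
    if tracker.update(epoch, rmse_last):
        stop_epoch = epoch
        break
print(stop_epoch, tracker.best_epoch, tracker.best_rmse)
```

In this run the best RMSE is recorded at epoch 4, and training stops 80 non-improving epochs later, at epoch 84.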
Python: A Minimal fit() Loop From Scratch
Every per-batch stage of the production pipeline, by hand. One scalar parameter, four samples, no PyTorch — just enough algebra to see what every callback does.
One training step, every callback, in NumPy
🐍grace_minimal_step.py
Line 1: docstring
Names the contract: a complete one-step training loop in NumPy. Every callback the paper's UnifiedTrainer wires up (warmup, GABA, clip, AdamW, EMA) appears here as a few lines of explicit math.
Line 8: import numpy as np
NumPy is Python's numerical-array library. It provides ndarray (N-dimensional array) plus the math operations on it. Every algebraic step in this script — element-wise +, -, *, **, plus np.clip / np.maximum / np.sqrt / .sum / .mean — runs as optimised C under the hood, not slow Python loops.
EXECUTION STATE
numpy = Library for numerical computing — ndarray (homogeneous N-D array), broadcasting rules, linear algebra, RNG, and a large math library. Pre-installed in every Anaconda / Colab environment.
as np = Alias. Universal Python convention so you write np.array(), np.maximum(), etc. instead of numpy.array() everywhere.
→ why NumPy here? = GRACE's production code uses PyTorch tensors. Switching to NumPy strips away autograd and GPU plumbing so every callback (warmup, GABA, clip, AdamW, EMA) is exposed as a few lines of explicit math.
Line 10: rng = np.random.default_rng(42)
Modern NumPy random API. Pinned seed so the random feature x is reproducible across runs.
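Line 14: y_rul = np.array([20., 80., 5., 100.], dtype=float)
Four ground-truth RUL targets: two near-failure (5, 20), one mid-life (80), one early-life (100). The mix spreads w(y) across [1.20, 1.96], covering most of the failure-bias ramp. All four targets sit on the ramp's linear part, so neither clip bound is exercised.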
EXECUTION STATE
📚 np.array(object, dtype) = NumPy constructor: build an ndarray from any Python iterable. Stores values in a contiguous C buffer (not a Python list), enabling vectorised math.
⬇ arg 1: [20., 80., 5., 100.] = Plain Python list of floats. The trailing dot (20. vs 20) marks them as float literals — keeps everyone in float-land.
⬇ arg 2: dtype=float = Force float64 (8 bytes per element). Without this, NumPy could infer int64 if all values were whole numbers, breaking later divides like y/125.
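Line 15: hp_target = np.array([2, 1, 2, 0])
Health labels: 2 (critical), 1 (degrading), 0 (healthy). Note that the thresholds y<30→2, y<70→1, else 0 would label the second sample (y=80) as 0, not 1. The demo holds hp_target as a static input that is never read (the toy model only updates theta_rul), so the mismatch does not change the output.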
EXECUTION STATE
hp_target (4,) = [2, 1, 2, 0]
Line 16: x = rng.normal(0, 1, size=(4,))
Random per-sample features. One scalar per sample — a stand-in for what the CNN-BiLSTM-Attention backbone would produce on real C-MAPSS sensor windows.
EXECUTION STATE
📚 rng.normal(loc, scale, size) = NumPy Generator method: draw samples from N(loc, scale²). Cleaner replacement for legacy np.random.normal — uses the modern PCG64 bit generator and respects the seeded `rng` from line 10.
⬇ arg 1: loc = 0 = Mean of the Gaussian. We center features at 0 so the toy linear model y_pred = θ·x·100 + 50 has predictions centred at 50.
⬇ arg 2: scale = 1 = Standard deviation. σ=1 → ~68% of features land in [-1, 1], ~95% in [-2, 2].
⬇ arg 3: size = (4,) = Output shape. (4,) is a 1-D array of 4 elements (a vector). For a matrix you would pass (rows, cols), e.g. size=(256, 14).
→ why a tuple? = size accepts an int or a tuple of ints. (4,) is the 1-element tuple syntax — the trailing comma distinguishes it from the integer (4) wrapped in parens.
⬆ result: x (4,) (illustrative) = [+0.30, -1.04, +0.75, +0.94] — exact values depend on seed 42.
Line 20: theta_rul = 0.5
Single learnable scalar. Stand-in for the 1.7M-parameter backbone. We track its update path through the pipeline.
Line 21: theta_h = 0.5
Health-head weight. Defined for symmetry but not updated in this demo (we hold L_health constant).
Line 22: m_rul, v_rul = 0.0, 0.0
AdamW state: first and second moment estimates. Zero-initialised, then bias-corrected via the (1 - β^t) divisor on every step.
EXECUTION STATE
m_rul = First moment (gradient mean). 0 at step 0.
v_rul = Second moment (squared-gradient mean). 0 at step 0.
Line 23: shadow_rul = theta_rul
EMA shadow buffer. Initialised to the same value as theta_rul (ExponentialMovingAverage.__init__ in callbacks.py:62).
Line 24: step_count = 0
AdamW timestep counter. Used in the bias-correction divisor (1 - β^t).
Line 26: base_lr = 1e-3
Target learning rate after warmup. Paper default for C-MAPSS.
Line 27: warmup_epochs = 10
Number of epochs over which to linearly ramp the learning rate from 0.1·base_lr to base_lr.
Line 28: beta1, beta2 = 0.9, 0.999
AdamW momentum coefficients. β₁=0.9 weights the first moment (mean), β₂=0.999 weights the second moment (uncentered variance).
EXECUTION STATE
β₁ = 0.9 = Effective memory of ~10 steps for the gradient mean.
β₂ = 0.999 = Effective memory of ~1000 steps for the squared gradient.
Line 29: eps_adam = 1e-8
Numerical safety floor in AdamW. Added to sqrt(v_hat) before division to avoid 0/0 when a parameter has never seen a gradient.
Line 30: weight_decay = 1e-4
AdamW weight decay coefficient. Decoupled from the gradient — applied as a separate -lr·wd·θ shrink. GRACE schedules this down at epoch 100/200.
Line 31: ema_alpha = 0.999
EMA decay for the shadow weights. Effective memory of 1000 steps. Eval at end-of-epoch swaps in shadow_rul.
Line 34: def lr_at(epoch) → float
Returns the learning rate for a given epoch. Implements GRACE's linear warmup callback — mirrors the production grace/training/callbacks.py:LRWarmup.get_lr (~5 lines of Python). Called once per epoch by UnifiedTrainer.fit before the per-batch loop runs.
EXECUTION STATE
⬇ input: epoch = Integer epoch index, starting at 0. Trainer maintains it in trainer.py and increments at end-of-epoch.
→ epoch purpose = Decides whether we are inside the warmup window (epoch < W) or past it. The whole branch logic on line 36 hinges on this scalar.
→ closes over = Reads two module-level globals: warmup_epochs (10) and base_lr (1e-3). Kept as globals so the demo stays a single flat script.
⬆ returns float = The learning rate for this epoch. 1e-4 at epoch 0, 5.5e-4 at epoch 5, 1e-3 from epoch 10 onward.
Line 36: if epoch < warmup_epochs:
Branch on warmup phase. Outside warmup, return the constant base_lr.
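Line 37: return base_lr * (0.1 + 0.9 * epoch / warmup_epochs)
Linear ramp: epoch=0 → 0.1·base_lr, rising by 0.09·base_lr per epoch. At epoch=W the formula would give exactly base_lr, which is what line 38 returns, so the handoff is continuous.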
EXECUTION STATE
epoch=0 = lr = 1e-3 * (0.1 + 0.0) = 1e-4 — gentle start, no large initial step.
epoch=5 = lr = 1e-3 * (0.1 + 0.45) = 5.5e-4.
epoch=10 = Branch not taken; line 38 returns base_lr = 1e-3, the same value the ramp formula would give (1e-3 * (0.1 + 0.9)).
Line 38: return base_lr
Post-warmup: scheduler takes over (ReduceLROnPlateau).
Line 41: def w_failure(y, max_rul=125.0):
Failure-biased per-sample weight. Same closed form as core/weighted_mse.py:moderate_weighted_mse_loss.
Line 42: return 1.0 + np.clip(1.0 - y / max_rul, 0.0, 1.0)
Failure-bias ramp: weight is 2.0 at y=0 (failure imminent) and decays linearly to 1.0 at y ≥ 125 (healthy). Forces the optimiser to spend more error budget on near-failure samples — exactly where wrong predictions cost a real plant.
EXECUTION STATE
📚 np.clip(a, a_min, a_max) = NumPy function: element-wise saturate every value in `a` to the closed interval [a_min, a_max]. Anything below a_min becomes a_min; anything above a_max becomes a_max; values in range pass through.
⬇ arg 1: 1.0 - y / max_rul = The unclipped ramp. y/125 is the fraction of the 125-cycle cap already used, so 1 - y/125 is the fraction remaining: 1.0 at y=0, 0 at y=125, and negative for y>125.
⬇ arg 2: a_min = 0.0 = Lower clip. Stops the ramp from going negative for super-healthy samples (y > 125 in raw RUL). Without it, w could fall BELOW 1.
⬇ arg 3: a_max = 1.0 = Upper clip. Caps the ramp at +1 — only effective if y were ever NEGATIVE (it shouldn't be). Belt-and-braces.
→ np.clip vs np.maximum vs np.minimum = np.clip(a, lo, hi) ≡ np.minimum(np.maximum(a, lo), hi). Single function, two-sided saturation. np.maximum is one-sided (just a floor).
Line 45: def gaba_lambda(g_rul, g_health, eps_floor=0.05):
GABA's closed-form weight allocator for K=2 tasks. Compressed for the demo: we skip the EMA-smoothing stage that the production controller runs in grace/core/gaba_loss.py, that is, we assume the EMA has fully converged so EMA(g) ≈ g. Returns the two outer-axis weights λ_rul, λ_health that turn (L_rul, L_health) into a single scalar L_total.
EXECUTION STATE
⬇ input: g_rul = Per-task gradient L2 norm of L_rul w.r.t. the SHARED backbone parameters. Real production value at epoch 50 on FD002 ≈ 41.5.
→ g_rul purpose = Tells GABA how loud the RUL task's gradient signal is. Bigger g_rul ⇒ smaller λ_rul (downweight a loud task).
⬇ input: g_health = Same statistic for L_health. ≈ 0.082 at the same checkpoint — about 506× smaller than g_rul.
⬇ input: eps_floor = 0.05 (default) = Floor on every λ_i. Prevents collapse — no task can be weighted at less than 5% even if its gradient dominates. Paper default.
→ why 0.05? = Empirically: 0.01 lets the small-gradient task starve; 0.10 over-weights it. 0.05 is the validated sweet spot from chapter 21 §3.
⬆ returns: (2,) ndarray = Final λ* after floor + renorm. Sums to exactly 1.0 — a proper convex combination ready for L_total = λ[0]·L_rul + λ[1]·L_health.
Line 47: S = g_rul + g_health
K=2 normaliser.
Line 48: raw = np.array([g_health / S, g_rul / S])
GABA's inverse-ratio rule: the task with the BIGGER per-task gradient norm gets the SMALLER weight. Intuition — that task is already getting a loud gradient signal, so dial it down to let the quieter task be heard.
EXECUTION STATE
📚 np.array(list) = Build a 2-element ndarray from the Python list of two scalars. Equivalent to np.array([0.001972, 0.998028]).
→ element 0: g_health / S = 0.082 / 41.582 = 0.001972 — λ_rul (small, because g_rul dominates).
→ element 1: g_rul / S = 41.5 / 41.582 = 0.998028 — λ_health (large, because g_health is tiny).
→ why the swap? = Index 0 of the result is for RUL; we put g_health/S there because RUL's OWN gradient is huge. The closed form is λ_i ∝ 1/g_i, scaled to sum to 1.
raw (2,) = [0.001972, 0.998028]
Line 49: flo = np.maximum(raw, eps_floor)
Per-element floor. Prevents either λ from collapsing to ~0 — a task with weight near 0 receives no gradient and effectively stops training.
EXECUTION STATE
📚 np.maximum(a, b) = Element-wise binary max. For each index i, returns max(a[i], b[i]). Broadcasts a scalar `b` against an array `a`. Distinct from np.max which is a REDUCTION (returns a single value).
⬇ arg 1: raw (2,) = [0.001972, 0.998028]
⬇ arg 2: eps_floor = 0.05 = Scalar. Broadcasts to compare against every element of raw.
→ element-wise trace = max(0.001972, 0.05) = 0.05 ← floor kicks in
max(0.998028, 0.05) = 0.998028 ← passes through
⬆ result: flo (2,) = [0.05, 0.998028]
Line 50: return flo / flo.sum()
Renormalise so the two λ* sum to exactly 1.0 after the floor. Without this step, applying the floor breaks the simplex constraint (sum ≠ 1) and L_total would no longer be a convex combination of the per-task losses.
EXECUTION STATE
📚 ndarray.sum(axis=None) = ndarray method: REDUCE the array to a single scalar by summing all elements. With axis= it sums along that axis. axis=None (default) sums everything.
/ (division operator) = Element-wise division. Scalar broadcasts: each element of flo gets divided by the same flo.sum() scalar.
Line 63: L_rul = (w * residual_sq).mean()
Failure-biased MSE on the RUL task, the INNER axis of GRACE. Multiplies each squared residual by its failure-bias weight before averaging across the batch.
EXECUTION STATE
📚 ndarray.mean(axis=None) = ndarray method: arithmetic mean. With axis=None (default) reduces every element to a single scalar. Equivalent to .sum() / .size.
* (multiply operator) = Element-wise product (NOT matrix multiply — that's @). Pairs each w[i] with each residual_sq[i].
Line 74: g_theta = lam[0] * (2.0 / 4) * (w * (y_pred - y_rul) * 100.0 * x).sum()
Closed-form gradient of L_total w.r.t. theta_rul. Chain rule: dL_total/dtheta_rul = λ_rul · dL_rul/dtheta_rul, since L_health does not depend on theta_rul.
EXECUTION STATE
λ_rul = lam[0] = 0.0477 = Outer-axis weight on the RUL task.
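g_theta (illustrative) = On seed 42 ≈ +618.7 BEFORE clip, far above the 1.0 threshold, so the clip fires. (The ≈ 2.4 in the mini-batch table is the production global norm, not this toy's.)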
Line 77: g_norm = abs(g_theta)
Scalar gradient norm. In production, clip_grad_norm_ uses the L2 norm across ALL parameters at once.
Line 78: if g_norm > 1.0:
Trigger clipping if the global norm exceeds the threshold.
Line 79: g_theta = g_theta / g_norm
Rescale so |g_theta| = 1.0. Direction preserved; magnitude clipped.
EXECUTION STATE
g_theta post-clip = Sign of original g_theta, magnitude exactly 1.0.
→ in production = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0). Same algebra, applied to the concatenated parameter gradient vector.
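Production clipping works on the whole gradient vector, not per tensor. Here is a minimal NumPy sketch of the global-norm rule, assuming a small 1e-6 safety term like the one torch.nn.utils.clip_grad_norm_ adds to the norm; the gradient arrays are made-up values chosen so the norm is easy to check by hand:

```python
import numpy as np


def clip_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale every gradient array by one shared factor so the joint L2 norm <= max_norm."""
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + eps))   # never scales gradients up
    return [g * scale for g in grads], total_norm


grads = [np.array([1.2, -1.6]),               # e.g. a bias gradient
         np.array([[0.0, 1.5], [0.0, 0.0]])]  # e.g. a weight-matrix gradient
clipped, total = clip_global_norm(grads)
new_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in clipped)))
print(total, round(new_norm, 6))
```

The joint norm is sqrt(1.44 + 2.56 + 2.25) = 2.5, so every array is multiplied by the same factor of about 0.4. That shared factor is why the direction of the full gradient is preserved.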
Line 82: step_count += 1
Advance the AdamW timestep. Needed for bias correction below.
Line 87: update = lr * m_hat / (np.sqrt(v_hat) + eps_adam) + lr * weight_decay * theta_rul
The full AdamW step in one line. Two contributions are added: (1) the bias-corrected, variance-normalised gradient step, and (2) the DECOUPLED weight-decay shrink. The decoupling — applying wd as a SEPARATE additive term, not as part of the gradient — is what differentiates AdamW from plain Adam-with-L2.
EXECUTION STATE
📚 np.sqrt(x) = NumPy element-wise square root. For scalars returns √x. For ndarrays returns sqrt of each element. Used here to convert v_hat (uncentered variance estimate) into a standard-deviation-like denominator.
Term 1: lr * m_hat / (np.sqrt(v_hat) + eps_adam) = Adaptive step. m_hat is the bias-corrected mean gradient; sqrt(v_hat) is the bias-corrected RMS gradient. Their ratio is dimensionless ≈ ±1, so the effective per-parameter step is ≈ lr regardless of the gradient's scale.
→ eps_adam = 1e-8 = Numerical floor inside the denominator. If v_hat = 0 (parameter has never seen a gradient), sqrt(v_hat)+ε = 1e-8 — finite, no divide-by-zero.
Term 2: lr * weight_decay * theta_rul = Decoupled weight-decay shrink. Pulls θ toward 0 INDEPENDENTLY of the gradient. Equivalent to applying θ ← (1 - lr·wd)·θ before the gradient step.
→ AdamW vs Adam+L2 = Adam+L2 adds wd·θ to the GRADIENT, then divides by sqrt(v_hat) — so the effective decay ends up scaled by 1/√v_hat (uneven across parameters). AdamW skips the divide for the wd term, keeping decay uniform.
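The difference between decoupled and coupled decay shows up most clearly on a parameter whose loss gradient is zero. This is a toy NumPy sketch, not the PyTorch optimisers. It uses two parameters, a constant loss gradient of 100 on the first and 0 on the second, and wd = 1e-2, which is larger than the paper's 1e-4 purely so the effect is visible in 100 steps.

```python
import numpy as np


def train(decoupled, steps=100, lr=1e-3, wd=1e-2, beta1=0.9, beta2=0.999, eps=1e-8):
    """Run Adam-style updates on two parameters; return their final values."""
    theta = np.array([1.0, 1.0])
    loss_grad = np.array([100.0, 0.0])   # loud parameter vs. parameter with no loss gradient
    m, v = np.zeros(2), np.zeros(2)
    for t in range(1, steps + 1):
        if decoupled:
            g = loss_grad
            theta = theta * (1 - lr * wd)     # AdamW: shrink outside the adaptive step
        else:
            g = loss_grad + wd * theta        # Adam+L2: decay folded into the gradient
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        m_hat = m / (1 - beta1 ** t)          # bias correction
        v_hat = v / (1 - beta2 ** t)
        theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta


adamw = train(decoupled=True)
adam_l2 = train(decoupled=False)
print("AdamW   :", adamw.round(4))
print("Adam+L2 :", adam_l2.round(4))
```

The zero-gradient parameter barely moves under AdamW (about 0.1% in 100 steps) but loses close to 10% under Adam+L2. The decay term wd·θ is divided by its own tiny √v̂, so a gentle decay turns into a full-size adaptive step. On the loud parameter the two optimisers behave almost identically, because wd·θ is negligible next to a gradient of 100.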
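Line 91: shadow_rul = ema_alpha * shadow_rul + (1 - ema_alpha) * theta_rul
EMA shadow update: the shadow moves 0.1% of the way toward the freshly updated theta_rul.
→ eval semantics = End-of-epoch evaluation calls .apply_shadow (swap weights → shadow), runs the test loop, then .restore (swap back). The shadow gives smoother test metrics on partly-converged models.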
1"""Minimal GRACE training loop, by hand. One epoch, four samples.
23Stitches together every stage of UnifiedTrainer.fit() — warmup,
4weighted MSE, GABA, gradient clip, AdamW, EMA — using only NumPy so
5the algebra is fully visible.
6"""78import numpy as np
910rng = np.random.default_rng(42)111213# ---------- Mini-batch (4 samples, 1 feature each) ----------14y_rul = np.array([20.,80.,5.,100.], dtype=float)15hp_target = np.array([2,1,2,0])16x = rng.normal(0,1, size=(4,))171819# ---------- Two scalar ‘parameters’: the toy model ----------20theta_rul =0.5# RUL head weight21theta_h =0.5# Health head weight (unused in update for brevity)22m_rul, v_rul =0.0,0.0# AdamW first/second moments23shadow_rul = theta_rul # EMA shadow24step_count =02526base_lr =1e-327warmup_epochs =1028beta1, beta2 =0.9,0.99929eps_adam =1e-830weight_decay =1e-431ema_alpha =0.999323334deflr_at(epoch):35"""Linear warmup: 0.1*lr_0 -> lr_0 over warmup_epochs."""36if epoch < warmup_epochs:37return base_lr *(0.1+0.9* epoch / warmup_epochs)38return base_lr
394041defw_failure(y, max_rul=125.0):42return1.0+ np.clip(1.0- y / max_rul,0.0,1.0)434445defgaba_lambda(g_rul, g_health, eps_floor=0.05):46"""K=2 closed form + floor + renorm. Skip EMA for the demo (assume converged)."""47 S = g_rul + g_health
48 raw = np.array([g_health / S, g_rul / S])49 flo = np.maximum(raw, eps_floor)50return flo / flo.sum()515253# ---------- ONE training step ----------54epoch =5055lr = lr_at(epoch)5657# 1. Forward: toy linear model y_pred = theta_rul * x * 100 + 5058y_pred = theta_rul * x *100.0+50.05960# 2. Inner losses61w = w_failure(y_rul)62residual_sq =(y_pred - y_rul)**263L_rul =(w * residual_sq).mean()64L_health =0.8999# held constant for the demo6566# 3. Outer: GABA closed form (use measured ratio from chapter 18 §1)67g_rul, g_h =41.5,0.08268lam = gaba_lambda(g_rul, g_h)69L_total = lam[0]* L_rul + lam[1]* L_health
7071# 4. Backward: scalar grad of L_total w.r.t. theta_rul72# dL_total/dtheta_rul = lam[0] * d/dtheta_rul ((1/N) sum w*(theta*100x - y_rul)^2)73# = lam[0] * (2/N) * sum( w * (y_pred - y_rul) * 100 * x )74g_theta = lam[0]*(2.0/4)*(w *(y_pred - y_rul)*100.0* x).sum()7576# 5. Gradient clip (clip if |g| > 1.0)77g_norm =abs(g_theta)78if g_norm >1.0:79 g_theta = g_theta / g_norm
8081# 6. AdamW update82step_count +=183m_rul = beta1 * m_rul +(1- beta1)* g_theta
84v_rul = beta2 * v_rul +(1- beta2)* g_theta **285m_hat = m_rul /(1- beta1 ** step_count)86v_hat = v_rul /(1- beta2 ** step_count)87update = lr * m_hat /(np.sqrt(v_hat)+ eps_adam)+ lr * weight_decay * theta_rul
88theta_rul = theta_rul - update
8990# 7. EMA shadow91shadow_rul = ema_alpha * shadow_rul +(1- ema_alpha)* theta_rul
929394print(f"epoch={epoch}, lr={lr:.4e}")95print(f"y_pred = {y_pred.round(2).tolist()}")96print(f"w(y) = {w.round(2).tolist()}")97print(f"L_rul = {L_rul:.4f} L_health = {L_health:.4f}")98print(f"lambda* = ({lam[0]:.4f}, {lam[1]:.4f})")99print(f"L_total = {L_total:.4f}")100print(f"g_theta = {g_theta:+.6f} (post-clip)")101print(f"theta_rul = {theta_rul:.6f} (post-AdamW)")102print(f"shadow_rul = {shadow_rul:.6f} (post-EMA)")
What the toy reveals. AdamW's update has two contributions: the bias-corrected gradient step and the decoupled weight-decay shrink. EMA barely moves the shadow on a single step (0.5 → 0.499999), but the effect compounds: with α = 0.999 the shadow's time constant is about 1000 steps, so 1000 steps after a change the old value still carries weight 0.999^1000 ≈ 0.37. That lag is what smooths out late-training noise. Gradient clipping rescales the global gradient vector, preserving its direction.
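The apply_shadow / restore pattern and the ~1000-step horizon can be sketched with a small helper. The class below is hypothetical: it borrows the method names used in this section but is not the paper's ExponentialMovingAverage, and it stores parameters as a dict of NumPy arrays rather than an nn.Module.

```python
import numpy as np


class EMA:
    """Hypothetical shadow-weight helper over a dict of NumPy arrays."""

    def __init__(self, params, alpha=0.999):
        self.alpha = alpha
        self.shadow = {k: v.copy() for k, v in params.items()}
        self.backup = None

    def update(self, params):
        # shadow <- alpha * shadow + (1 - alpha) * live, every mini-batch
        for k, v in params.items():
            self.shadow[k] = self.alpha * self.shadow[k] + (1 - self.alpha) * v

    def apply_shadow(self, params):
        # Swap shadow weights in for evaluation, keeping a copy of the live ones.
        self.backup = {k: v.copy() for k, v in params.items()}
        for k in params:
            params[k][...] = self.shadow[k]

    def restore(self, params):
        # Put the live weights back after evaluation.
        for k in params:
            params[k][...] = self.backup[k]
        self.backup = None


params = {"theta": np.array([0.0])}
ema = EMA(params)
params["theta"][...] = 1.0          # live weight jumps from 0 to 1 and stays there
for _ in range(1000):
    ema.update(params)

lagged = float(ema.shadow["theta"][0])
ema.apply_shadow(params)
during_eval = float(params["theta"][0])
ema.restore(params)
after_eval = float(params["theta"][0])
print(round(lagged, 4), round(during_eval, 4), after_eval)
```

After 1000 updates the shadow has covered only about 63% of the jump (1 - 0.999^1000). Evaluation sees that smoothed value, then training resumes on the untouched live weight.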
PyTorch: The Paper's UnifiedTrainer.fit()
The same pipeline, now driven through UnifiedTrainer. Roughly 30 lines of caller code wire every callback to the right cadence. The actual fit loop lives in grace/training/trainer.py; what you see below is the surface that the paper's experiments/phase1_cmapss.py calls into.
Production GRACE training in 30 lines
🐍grace_train_pipeline.py
Line 1: docstring
Names the contract. The train_grace function is what every Phase-1 experiment calls — see grace/experiments/phase1_cmapss.py:run_single_experiment for the production caller.
Line 9: import torch
Core PyTorch.
Line 10: import torch.nn as nn
nn.CrossEntropyLoss is the only direct nn use here; everything else lives in paper modules.
Line 11: import torch.optim as optim
AdamW + ReduceLROnPlateau scheduler.
EXECUTION STATE
📚 torch.optim = PyTorch optimisers and learning-rate schedulers. AdamW = Adam with decoupled weight decay (Loshchilov & Hutter, 2019).
Line 12: from torch.utils.data import DataLoader
Standard PyTorch mini-batch iterator. Shuffles, batches, and collates the underlying Dataset.
EXECUTION STATE
📚 DataLoader(dataset, batch_size, shuffle, num_workers) = Returns an iterable of batches. Each batch is a 4-tuple (seq, rul, health, uid) when wrapped via MTLDatasetWrapper.
Line 14: from grace.core.loss_registry import get_loss
Factory that returns any of the 9 MTL loss variants by name. We ask for ‘gaba’.
Named architecture configs: ‘cmapss’, ‘ncmapss_20feat’, ‘ncmapss_32feat’. Each returns a ModelConfig dataclass with input_size, hidden_size, etc.
Pin everything. Sets torch.manual_seed, np.random.seed, random.seed, and torch.backends.cudnn.deterministic = True.
Line 26: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Pick the compute device. Use GPU if a CUDA-capable card is visible; otherwise fall back to CPU. The chosen device is later passed to model.to(device) and to every per-batch tensor.to(device).
EXECUTION STATE
📚 torch.device(string) = PyTorch class: lightweight handle representing a compute device. Stores type ('cuda' / 'cpu' / 'mps') and optional index. Used everywhere PyTorch needs to know WHERE a tensor lives.
📚 torch.cuda.is_available() = Probes the CUDA driver. Returns True iff a GPU + the matching CUDA driver are both installed AND PyTorch was built with CUDA support. Cheap to call; safe at module import.
→ ternary expression = Python conditional: VALUE_IF_TRUE if CONDITION else VALUE_IF_FALSE. Here returns 'cuda' or 'cpu' depending on the probe.
→ why it matters = Production runs use a single A100 / H100 — 12-second epochs on FD002. Pure CPU works but is ~50× slower (≈10 min per epoch) because the BiLSTM does not vectorise across time steps.
⬆ result: device = device(type='cuda') on a GPU box
device(type='cpu') on a laptop
Eval loader. shuffle=False — the evaluator needs unit-id ordering for last-cycle scoring.
Line 31: mc = get_model_config('cmapss')
Load the named architecture config. Returns a dataclass with input_size=14, hidden_size=256, cnn_channels=(32,64,32), num_attn_heads=4, fc_dims=(256,128), dropout=0.3.
EXECUTION STATE
mc.input_size = 14 — number of selected sensors after the paper's feature filter.
mc.dropout = 0.3 — applied after BiLSTM and FC layers.
Line 32: backbone = UnifiedBackbone(...)
Shared CNN-BiLSTM-Attention trunk. Roughly 1.7M parameters. Returns a (B, hidden_size) latent per window.
EXECUTION STATE
📚 UnifiedBackbone(input_size, hidden_size, cnn_channels, num_attn_heads, fc_dims, dropout, use_attention, use_residual) = From grace/models/backbone.py. Architecture: Conv1d ×3 → BiLSTM → MultiHeadAttention → FC. use_attention=True and use_residual=True are the paper defaults.
Instantiate the OUTER controller. Returns a GABALoss nn.Module with EMA buffer, step counter, and the four-stage closed form.
EXECUTION STATE
⬇ beta=0.99 = EMA smoothing — see chapter 21 §2 for the four-stage trace.
⬇ warmup_steps=100 = First 100 mini-batches use uniform 1/K weights, NOT the closed form. Stops the controller from over-reacting before the EMA has any history.
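A stateful version of the controller can be sketched by putting back the EMA and warmup that the NumPy demo skips. This is a hypothetical reconstruction from the description in this section: an EMA of per-task gradient norms with beta = 0.99, uniform 1/K weights for the first 100 mini-batches, inverse-ratio weights, a 0.05 floor, and renormalisation. The real GABALoss in grace/core/gaba_loss.py may order or initialise these stages differently; seeding the EMA with the first observation is an assumption.

```python
import numpy as np


class GabaSketch:
    """Hypothetical GABA-style weight allocator: EMA -> inverse ratio -> floor -> renorm."""

    def __init__(self, num_tasks=2, beta=0.99, warmup_steps=100, eps_floor=0.05):
        self.k, self.beta = num_tasks, beta
        self.warmup_steps, self.eps_floor = warmup_steps, eps_floor
        self.ema = None          # assumption: seeded with the first observation
        self.steps = 0

    def weights(self, grad_norms):
        g = np.asarray(grad_norms, dtype=float)
        if self.ema is None:
            self.ema = g.copy()
        else:
            self.ema = self.beta * self.ema + (1 - self.beta) * g
        self.steps += 1
        if self.steps <= self.warmup_steps:
            return np.full(self.k, 1.0 / self.k)   # warmup: uniform 1/K
        inv = 1.0 / self.ema                        # loud task -> small weight
        raw = inv / inv.sum()
        floored = np.maximum(raw, self.eps_floor)
        return floored / floored.sum()


ctrl = GabaSketch()
history = [ctrl.weights([41.5, 0.082]) for _ in range(101)]
print(history[99], history[100].round(4))
```

Mini-batch 100 still gets the uniform (0.5, 0.5) split; from mini-batch 101 the closed form takes over and, with the norms held at the chapter's values, returns the same λ* ≈ (0.0477, 0.9523) as the converged two-task formula.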
Build the optimiser parameter list. Concatenates the model's learnable tensors (weights, biases, BatchNorm γ/β) with any learnable parameters in the MTL loss. Critical for Uncertainty (Kendall et al.) which has learnable σ_rul, σ_health; harmless no-op for GABA which has none.
EXECUTION STATE
📚 nn.Module.parameters() = Generator that yields every learnable Tensor registered on the module (and its submodules, recursively). Yields nn.Parameter objects — Tensors with requires_grad=True that the optimiser owns.
📚 list(generator) = Python builtin: materialise a generator into a list. Needed here because parameters() yields lazily — list() forces full enumeration so the result can be concatenated with +.
+ (list concatenation) = Python list operator: returns a NEW list with all elements of the left followed by all of the right. Different from element-wise + on tensors.
→ why concatenate? = AdamW takes a single iterable of parameters. We hand it the union of model params + MTL-loss params so a single optimizer.step() updates both.
→ forgetting this is a silent footgun = If you write optim.AdamW(model.parameters()) without the mtl_loss part, the σ scalars in Uncertainty / GradNorm never receive a gradient step → algorithm degrades to fixed-σ → published numbers cannot be reproduced. See Pitfall 2 below.
Instantiate the AdamW optimiser with the paper's defaults. Decoupled weight decay (Loshchilov & Hutter, 2019). The trainer later schedules weight_decay down at epoch 100 / 200 via the adaptive_weight_decay flag.
EXECUTION STATE
📚 optim.AdamW(params, lr, betas=(0.9,0.999), eps=1e-8, weight_decay) = PyTorch optimiser. Per-parameter adaptive learning rate via 1st/2nd moment estimates with bias correction, plus DECOUPLED L2-style weight decay. Update rule: θ ← θ - lr·m̂/(√v̂+ε) - lr·wd·θ. The wd term is added independently of the gradient — that decoupling is what fixes Adam-with-L2's scale-mismatch bug.
⬇ arg 1: params = The list of learnable Tensors built on line 43. AdamW allocates one m and one v buffer per parameter (~3.4M floats of optimiser state for a 1.7M-param model).
⬇ arg 2: lr = 1e-3 = Base learning rate. Standard transformer-era default. Linearly warmed up from 1e-4 over the first 10 epochs by the trainer.
⬇ arg 3: weight_decay = 1e-4 = L2-style shrink coefficient applied OUTSIDE the gradient term. Halved at epoch 100, divided by 10 at epoch 200 by the adaptive schedule.
→ defaults left at PyTorch values = betas=(0.9, 0.999) — momentum coefficients for m and v.
eps=1e-8 — numerical floor under the sqrt(v) denominator.
We keep these defaults, so they are not passed explicitly.
→ AdamW vs Adam+L2 (one line each) = Adam+L2: g ← g + wd·θ before the m/v updates, so wd·θ is divided by √v̂+ε along with the gradient ← wd interacts with the variance scaling
AdamW: θ ← θ - lr·m̂/(√v̂+ε) - lr·wd·θ ← wd applied independently
⬆ result: optimizer = AdamW instance with .step(), .zero_grad(), .state_dict(). Holds m and v in optimizer.state[param].
Build the LR scheduler. ReduceLROnPlateau watches a metric (validation rmse_last) and halves the learning rate whenever the metric stops improving. The opening parenthesis here begins a multi-line constructor call; arguments continue on line 46.
EXECUTION STATE
📚 optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, min_lr) = PyTorch scheduler. Wraps the optimiser. Call scheduler.step(metric) ONCE PER EPOCH with the validation metric. The scheduler tracks the best value seen; if no improvement for `patience` calls, multiplies every param-group's lr by `factor`, floored at `min_lr`.
Constructor arguments for ReduceLROnPlateau, on the continuation line.
EXECUTION STATE
⬇ arg 1: optimizer = The AdamW instance from line 44. The scheduler mutates each param-group's 'lr' key in-place; AdamW reads that key on the next .step().
⬇ arg 2: mode='min' = Smaller-is-better. Use 'max' for accuracy, 'min' for loss/error metrics like RMSE. Determines whether the scheduler considers a metric improved.
⬇ arg 3: factor=0.5 = Multiplicative LR reduction. lr ← lr · 0.5 = halve. Standard for plateau schedulers; aggressive (0.1) is too jumpy on this problem.
⬇ arg 4: patience=30 = Epochs of no-improvement before triggering a reduction. 30 is loose enough to ride out a noisy plateau without early panic.
⬇ arg 5: min_lr=5e-6 = Floor. After many reductions, lr stops at 5e-6, keeping the optimiser's step size non-zero so the EMA shadow can keep tracking and parameters don't freeze.
→ trace example = Initial lr=1e-3.
After 1st plateau: lr=5e-4
After 2nd plateau: lr=2.5e-4
... (continues halving every 30 stagnant epochs) ...
Floor: lr=5e-6
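The trace above can be generated directly. This short plain-Python sketch models only the reductions, not the patience bookkeeping. It lists every learning rate ReduceLROnPlateau can visit with factor=0.5 and min_lr=5e-6, starting from 1e-3:

```python
def plateau_trace(lr0=1e-3, factor=0.5, min_lr=5e-6):
    """Learning rates visited by successive plateau reductions, down to the floor."""
    trace = [lr0]
    while trace[-1] > min_lr:
        trace.append(max(trace[-1] * factor, min_lr))
    return trace


trace = plateau_trace()
print(len(trace) - 1, [f"{lr:.3g}" for lr in trace])
```

The floor is reached after eight reductions (the last one clamps 3.9e-6 up to 5e-6). PyTorch reduces once the count of non-improving epochs exceeds patience and then resets that count (the default cooldown is 0). With patience=30, each reduction therefore needs at least 31 consecutive non-improving epochs, and reaching the floor takes at least 8 × 31 = 248 of them.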
Line 47: )
Closing parenthesis of the ReduceLROnPlateau constructor. The scheduler instance is now bound to the local name `scheduler` and ready for per-epoch .step(rmse_last) calls.
Line 49: trainer = UnifiedTrainer(
Wire everything. UnifiedTrainer holds the model, the MTL loss, the per-sample RUL+health criteria, the optimiser, the scheduler, the device, and the config dict. fit() does the rest.
EXECUTION STATE
📚 UnifiedTrainer(model, mtl_loss, rul_criterion, health_criterion, optimizer, scheduler, device, config) = Source: grace/training/trainer.py:UnifiedTrainer (291 lines). Holds every component the loop needs and exposes a single .fit(train_loader, test_loader) entry point. Owns the per-batch / per-epoch cadence split.
→ rul_criterion = moderate_weighted_mse_loss = Paper's failure-biased MSE — passed by reference (it's a pure function, no state).
→ health_criterion = nn.CrossEntropyLoss() = 3-class CE on the health head — instantiated inline since it has no learnable parameters.
Line 50: `model=model, mtl_loss=mtl_loss,`
First two keyword arguments to UnifiedTrainer.__init__. The trainer stores references so .step() can call model(seq) and mtl_loss(L_rul, L_health, shared_params) inside the per-batch loop.
model=model = The DualTaskModel instance (~1.7M params). Trainer calls model.train() / model.eval() to flip dropout + BatchNorm modes.
mtl_loss=mtl_loss = The GABALoss instance. Trainer calls it once per batch with (L_rul, L_health, shared_params) → returns the scalar L_total ready for .backward().
Line 51: `rul_criterion=moderate_weighted_mse_loss,`
Per-sample RUL loss — passed as a callable, not an instance, because it is a pure function with no state.
Line 52: `health_criterion=nn.CrossEntropyLoss(),`
Per-sample health loss. Standard PyTorch CE module, instantiated here (parentheses) because it is an nn.Module, not a function.
📚 nn.CrossEntropyLoss(weight=None, reduction='mean') = PyTorch loss module: combines log_softmax + NLLLoss in one numerically-stable op. Forward signature: criterion(logits (B, C), targets (B,)) → scalar. C=3 classes here (healthy / degrading / critical).
→ why instantiate ()? = nn.CrossEntropyLoss is a class. The () creates an instance whose .__call__ implements the loss. Compare to moderate_weighted_mse_loss above — that is already a function, so no ().
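Both claims above can be checked directly: the CE module applied to `(B, 3)` logits and `(B,)` integer targets matches `log_softmax` followed by `nll_loss`, and a plain function with the same `(prediction, target)` signature works wherever a callable criterion is expected. `toy_rul_loss` is a hypothetical stand-in (plain MSE), NOT the paper's `moderate_weighted_mse_loss`, whose weighting is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 3)            # (B, C): health-head logits, C = 3 classes
targets = torch.randint(0, 3, (8,))   # (B,): integer class labels

health_criterion = nn.CrossEntropyLoss()   # a class, so () creates the callable instance
ce = health_criterion(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)  # the two-step equivalent


def toy_rul_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a per-sample RUL criterion: plain MSE."""
    return ((pred - target) ** 2).mean()


rul_criterion = toy_rul_loss               # already a function, so passed without ()
rul = rul_criterion(torch.zeros(8), torch.ones(8))
print(ce.shape, torch.allclose(ce, manual), rul.item())
```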
Line 53: `optimizer=optimizer, scheduler=scheduler, device=device,`
Hand the trainer the already-configured optimiser, scheduler, and device. Trainer calls optimizer.step() per batch and scheduler.step(metric) per epoch.
optimizer=optimizer = AdamW instance from line 44. Trainer calls .zero_grad() before each backward and .step() after gradient clipping.
scheduler=scheduler = ReduceLROnPlateau instance from line 45. Called once per epoch with the validation rmse_last.
device=device = torch.device. The trainer moves the model and every batch to it with .to(device).
Line 54: `config={"epochs": 500, "patience": 80, "grad_clip": 1.0,`
Training-loop hyperparameters bundled as a dict, continued on lines 55-57. The trainer dereferences keys via config.get('key', default) for graceful fallback.
epochs=500 = Maximum budget. Early stopping usually fires earlier (epoch 130-230 in practice on FD002).
patience=80 = Early stopping patience. No-improvement epoch budget — if rmse_last has not improved for 80 epochs, stop.
grad_clip=1.0 = Threshold for torch.nn.utils.clip_grad_norm_(parameters, max_norm=1.0). Applied AFTER backward, BEFORE optimizer.step().
warmup_epochs=10 = Linear LR ramp budget — 1e-4 at epoch 0 climbs to 1e-3 at epoch 10, then ReduceLROnPlateau owns LR.
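The warmup rule can be sketched as a function of the epoch index. The source gives only the endpoints (1e-4 at epoch 0, 1e-3 at epoch 10), so the straight-line interpolation and the helper names `warmup_lr` / `apply_warmup` are assumptions, not the code in trainer.py.

```python
from typing import Optional

import torch
from torch import optim


def warmup_lr(epoch: int, base_lr: float = 1e-3, start_lr: float = 1e-4,
              warmup_epochs: int = 10) -> Optional[float]:
    """Hypothetical linear ramp: start_lr at epoch 0, base_lr at epoch `warmup_epochs`.

    Returns None afterwards, meaning ReduceLROnPlateau owns the LR from then on.
    """
    if epoch > warmup_epochs:
        return None
    return start_lr + (base_lr - start_lr) * epoch / warmup_epochs


def apply_warmup(optimizer: optim.Optimizer, epoch: int) -> None:
    """Overwrite every param group's lr at the start of a warmup epoch."""
    lr = warmup_lr(epoch)
    if lr is not None:
        for group in optimizer.param_groups:
            group["lr"] = lr


param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([param], lr=1e-3, weight_decay=1e-4)
for epoch in range(15):
    apply_warmup(optimizer, epoch)
    # ... _train_epoch(), evaluation and scheduler.step(rmse) would follow here ...
print([round(warmup_lr(e), 6) for e in (0, 5, 10)])
```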
Line 56: `'lr': 1e-3, 'adaptive_weight_decay': True,`
Base learning rate + GRACE's adaptive weight-decay schedule (drops to 0.5× at epoch 100, 0.1× at epoch 200).
adaptive_weight_decay=True = GRACE-only addition over GABA. Trainer.fit() at trainer.py:131-134 reduces wd in three stages.
Line 57: `'initial_weight_decay': 1e-4},`
Starting weight-decay value before the schedule kicks in. The closing } here ends the config dict.
initial_weight_decay = 1e-4 = Baseline AdamW wd. Halved to 5e-5 at epoch 100; cut to 1e-5 at epoch 200. Late training benefits from less L2-style shrink as the loss landscape flattens.
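The three-stage weight-decay schedule can be sketched the same way. The breakpoints and factors come from the text above; the step-function form, the helper names, and applying it by writing each AdamW param group's `weight_decay` key are assumptions rather than a copy of trainer.py:131-134.

```python
import torch
from torch import optim


def adaptive_weight_decay(epoch: int, initial_wd: float = 1e-4) -> float:
    """Hypothetical step schedule: 1x, then 0.5x from epoch 100, then 0.1x from epoch 200."""
    if epoch < 100:
        return initial_wd
    if epoch < 200:
        return initial_wd * 0.5
    return initial_wd * 0.1


def apply_weight_decay(optimizer: optim.Optimizer, epoch: int) -> None:
    """Write the scheduled value into every param group; AdamW reads it on each .step()."""
    wd = adaptive_weight_decay(epoch)
    for group in optimizer.param_groups:
        group["weight_decay"] = wd


param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([param], lr=1e-3, weight_decay=1e-4)
for epoch in (0, 100, 200):
    apply_weight_decay(optimizer, epoch)
    print(epoch, optimizer.param_groups[0]["weight_decay"])
```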
Line 58: `)`
Closing parenthesis of the UnifiedTrainer(...) constructor call. The trainer instance is now bound to the local name `trainer` and ready to receive .fit().
Line 60: `return trainer.fit(train_loader, test_loader)`
Run the training loop for up to 500 epochs (early stopping usually ends it sooner). Returns a dict with per-epoch history (loss, rmse_last, NASA, HA, λ_rul, g_rul, ...) and the best-metrics tuple.
📚 trainer.fit(train_loader, test_loader) = trainer.py:106. For each epoch: warmup, weight-decay schedule, _train_epoch, EMA-eval, scheduler step, history append, early-stopping check. ~30 minutes on a single GPU for FD002.
⬇ arg 1: train_loader = DataLoader from line 28. Iterated each epoch — yields ~110 mini-batches of 256 samples on FD002.
⬇ arg 2: test_loader = DataLoader from line 29. Iterated end-of-epoch through the EMA-shadow weights, NOT shuffled.
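The per-batch / per-epoch order of trainer.fit() can be illustrated with a self-contained toy. Everything here is a hypothetical stand-in: a 4-feature linear regression replaces DualTaskModel, plain MSE replaces WMSE + GABA, the weight-decay schedule is omitted, the EMA decay is lowered to 0.9 because the toy runs only a few hundred steps, and the warmup ramps from base_lr/5 rather than GRACE's 1e-4 → 1e-3.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, TensorDataset


def fit_sketch(epochs: int = 60, patience: int = 15, warmup_epochs: int = 5,
               base_lr: float = 5e-2, ema_decay: float = 0.9, seed: int = 0) -> dict:
    """Toy stand-in for UnifiedTrainer.fit(): same cadence, much simpler content."""
    torch.manual_seed(seed)
    x = torch.randn(1024, 4)
    y = x @ torch.tensor([1.0, -2.0, 0.5, 3.0]) + 0.1 * torch.randn(1024)
    train_loader = DataLoader(TensorDataset(x[:768], y[:768]), batch_size=64, shuffle=True)
    val_x, val_y = x[768:], y[768:]

    model = nn.Linear(4, 1)
    shadow = copy.deepcopy(model)  # EMA shadow; evaluation uses it, not the live model
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-5)
    best_rmse, best_state, bad_epochs, history = float("inf"), None, 0, []

    for epoch in range(epochs):
        if epoch < warmup_epochs:  # per epoch: linear LR warmup
            for group in optimizer.param_groups:
                group["lr"] = base_lr * (epoch + 1) / warmup_epochs
        model.train()
        for xb, yb in train_loader:  # per batch: forward, loss, backward, clip, step, EMA
            optimizer.zero_grad()
            loss = F.mse_loss(model(xb).squeeze(-1), yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            with torch.no_grad():
                for p_s, p in zip(shadow.parameters(), model.parameters()):
                    p_s.mul_(ema_decay).add_(p, alpha=1 - ema_decay)
        with torch.no_grad():  # per epoch: evaluate the EMA shadow
            rmse = torch.sqrt(F.mse_loss(shadow(val_x).squeeze(-1), val_y)).item()
        scheduler.step(rmse)  # once per epoch, never inside the batch loop
        history.append(rmse)
        if rmse < best_rmse:  # best-checkpoint save + early-stopping counter
            best_rmse, best_state, bad_epochs = rmse, copy.deepcopy(shadow.state_dict()), 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    return {"best_rmse": best_rmse, "history": history, "best_state": best_state}


result = fit_sketch()
print(f"{len(result['history'])} epochs, best RMSE {result['best_rmse']:.3f}")
```

The point of the sketch is the placement, not the numbers: clipping sits between backward() and step(), the EMA update follows every step, and scheduler.step() and the early-stopping check run once per epoch on the shadow's metric.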
```python
 1  """End-to-end GRACE training with the paper's UnifiedTrainer.
 2
 3  Boots a DualTaskModel + GABA + WMSE + AdamW + ReduceLROnPlateau + EMA
 4  + early stopping + checkpointer in roughly 30 lines. The actual fit
 5  loop lives in grace/training/trainer.py:UnifiedTrainer.fit (291 lines)
 6  — shown here is the surface that drives it.
 7  """
 8
 9  import torch
10  import torch.nn as nn
11  import torch.optim as optim
12  from torch.utils.data import DataLoader
13
14  from grace.core.loss_registry import get_loss
15  from grace.core.weighted_mse import moderate_weighted_mse_loss
16  from grace.models.backbone import UnifiedBackbone
17  from grace.models.dual_task_model import DualTaskModel
18  from grace.models.model_configs import get_model_config
19  from grace.training.trainer import UnifiedTrainer
20  from grace.training.seed_utils import set_seed
21
22
23  def train_grace(train_ds, test_ds, dataset_name="FD002", seed=42):
24      """Single-seed GRACE training run. Returns best metrics."""
25      set_seed(seed)
26      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
28      train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
29      test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)
30
31      mc = get_model_config("cmapss")
32      backbone = UnifiedBackbone(
33          input_size=mc.input_size, hidden_size=mc.hidden_size,
34          cnn_channels=mc.cnn_channels, num_attn_heads=mc.num_attn_heads,
35          fc_dims=mc.fc_dims, dropout=mc.dropout,
36          use_attention=True, use_residual=True,
37      )
38      model = DualTaskModel(backbone, num_health_states=3, dropout=0.3)
39
40      mtl_loss = get_loss("gaba", beta=0.99, warmup_steps=100,
41                          min_weight=0.05, n_tasks=2)
42
43      params = list(model.parameters()) + list(mtl_loss.parameters())
44      optimizer = optim.AdamW(params, lr=1e-3, weight_decay=1e-4)
45      scheduler = optim.lr_scheduler.ReduceLROnPlateau(
46          optimizer, mode="min", factor=0.5, patience=30, min_lr=5e-6,
47      )
48
49      trainer = UnifiedTrainer(
50          model=model, mtl_loss=mtl_loss,
51          rul_criterion=moderate_weighted_mse_loss,
52          health_criterion=nn.CrossEntropyLoss(),
53          optimizer=optimizer, scheduler=scheduler, device=device,
54          config={"epochs": 500, "patience": 80, "grad_clip": 1.0,
55                  "use_ema": True, "ema_decay": 0.999, "warmup_epochs": 10,
56                  "lr": 1e-3, "adaptive_weight_decay": True,
57                  "initial_weight_decay": 1e-4},
58      )
59
60      return trainer.fit(train_loader, test_loader)
```
Order of construction matters. The MTL loss must be instantiated before the optimiser, because list(mtl_loss.parameters()) is concatenated into the optimiser's parameter list (line 43). Some MTL variants, such as Uncertainty and GradNorm, have learnable parameters of their own. Forgetting to include them in the optimiser silently disables learning of those scalars and produces a different algorithm from the paper's.
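The effect of leaving `mtl_loss.parameters()` out of the optimiser shows up in a small experiment. `UncertaintyWeighting` below is a hypothetical Kendall-style loss (Σᵢ exp(−sᵢ)·Lᵢ + sᵢ with sᵢ = log σᵢ²), swapped in only because GABA has no learnable parameters to demonstrate with; `train_once` is likewise illustrative.

```python
import torch
import torch.nn as nn
from torch import optim


class UncertaintyWeighting(nn.Module):
    """Hypothetical Kendall-style weighting: sum_i exp(-s_i) * L_i + s_i, s_i = log sigma_i^2."""

    def __init__(self, n_tasks: int = 2):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(n_tasks))

    def forward(self, *task_losses: torch.Tensor) -> torch.Tensor:
        losses = torch.stack(task_losses)
        return (torch.exp(-self.log_vars) * losses + self.log_vars).sum()


def train_once(include_loss_params: bool) -> torch.Tensor:
    torch.manual_seed(0)
    model = nn.Linear(4, 2)
    mtl_loss = UncertaintyWeighting(n_tasks=2)  # built BEFORE the optimiser
    params = list(model.parameters())
    if include_loss_params:
        params += list(mtl_loss.parameters())
    optimizer = optim.AdamW(params, lr=1e-2)
    x, y = torch.randn(32, 4), torch.randn(32, 2)
    for _ in range(20):
        optimizer.zero_grad()
        out = model(x)
        l1 = ((out[:, 0] - y[:, 0]) ** 2).mean()
        l2 = ((out[:, 1] - y[:, 1]) ** 2).mean()
        mtl_loss(l1, l2).backward()  # log_vars get .grad either way
        optimizer.step()             # ...but are only updated if the optimiser holds them
    return mtl_loss.log_vars.detach().clone()


print("included:", train_once(True))
print("omitted: ", train_once(False))
```

When omitted, `log_vars` accumulate `.grad` but are never stepped, so they stay at exactly zero and every task keeps weight exp(0) = 1.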
Production Hyperparameter Defaults
Every default below comes from grace/experiments/config.py:ExperimentConfig and is what produced the published 7.72 RMSE on FD002. Section 22·2 shows the search protocol that justifies each choice.
| Hyperparameter | Value | Why this value |
|---|---|---|
| epochs (max budget) | 500 | Loose cap; early stopping fires at epoch 130-230 in practice. |
| batch_size | 256 | Fits in GPU memory; large enough to reduce per-batch λ noise. |
| lr (AdamW base) | 1e-3 | Standard transformer-era default; warmup ramps up from 1e-4. |
| weight_decay (initial) | 1e-4 | Mild regularisation; halved at epoch 100, cut to 0.1× (1e-5) at epoch 200. |
| β₁, β₂ (AdamW) | 0.9, 0.999 | PyTorch defaults; not tuned per dataset. |
| grad_clip max_norm | 1.0 | Defends against BiLSTM gradient spikes. |
| EMA decay (ema_alpha) | 0.999 | ~1000-step memory; smooths late-training noise. |
| warmup_epochs | 10 | Linear ramp 1e-4 → 1e-3. |
| scheduler factor / patience | 0.5 / 30 | Halve the LR once more than 30 epochs pass without improvement. |
| min_lr (scheduler floor) | 5e-6 | Keeps parameter updates non-zero in late training. |
| GABA β (EMA) | 0.99 | ~100-step memory on the outer axis. |
| GABA min_weight (floor) | 0.05 | 5% per task; prevents collapse. |
| GABA warmup_steps | 100 | Uniform weights for the first 100 mini-batches. |
| WMSE max_rul | 125 | Standard piecewise-linear RUL cap on C-MAPSS. |
| EarlyStopping patience | 80 | No-improvement epoch budget. |
| seeds (reproducibility) | [42, 123, 456, 789, 1024] | 5-seed averages; published numbers are means. |
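For reference, the table's defaults can be collected into one frozen dataclass. `TrainingDefaults` and its field names are hypothetical (they follow the table, not necessarily grace/experiments/config.py:ExperimentConfig). The `memory_steps` helper applies the usual 1/(1 − β) rule of thumb behind the "1000-step" and "100-step" memory figures.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainingDefaults:
    """Hypothetical mirror of the table above; not the repo's ExperimentConfig."""
    epochs: int = 500
    batch_size: int = 256
    lr: float = 1e-3
    weight_decay: float = 1e-4
    betas: tuple = (0.9, 0.999)
    grad_clip: float = 1.0
    ema_decay: float = 0.999
    warmup_epochs: int = 10
    scheduler_factor: float = 0.5
    scheduler_patience: int = 30
    min_lr: float = 5e-6
    gaba_beta: float = 0.99
    gaba_min_weight: float = 0.05
    gaba_warmup_steps: int = 100
    wmse_max_rul: int = 125
    early_stopping_patience: int = 80
    seeds: tuple = (42, 123, 456, 789, 1024)

    @staticmethod
    def memory_steps(decay: float) -> float:
        """Effective averaging window of an EMA with the given decay: 1 / (1 - decay)."""
        return 1.0 / (1.0 - decay)


cfg = TrainingDefaults()
print(round(cfg.memory_steps(cfg.ema_decay)), round(cfg.memory_steps(cfg.gaba_beta)))
```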
The Same Pipeline In Other Domains
The 14-stage skeleton is not specific to RUL prediction. Every domain that does multi-task supervised learning on a shared backbone uses some subset of it:
| Domain | What replaces WMSE? | What replaces GABA? | What replaces NASA-score eval? |
|---|---|---|---|
| Self-driving perception | Focal loss for rare classes (pedestrians, motorbikes) | GradNorm (gradient-norm balancing) or PCGrad (gradient surgery for conflicting tasks) | mAP at IoU thresholds + tail-class recall |
| Speech recognition | CTC loss + attention CE | Uncertainty (Kendall): homoscedastic σ scaling | Word error rate on rare-vocabulary subsets |
| Medical segmentation + classification | Dice + focal CE | DWA (Liu et al.): loss-ratio tracking | Per-anatomy Dice + sensitivity at boundary regions |
| Recommender systems | Listwise rank loss + dwell-time MSE | Custom controller: task-priority schedule | NDCG@10 + revenue lift on cold-start cohorts |
| Robotics policy learning | SmoothL1 on torques + binary CE on success | PCGrad: orthogonal projection of conflicting gradients | Success rate + safety-violation rate |
The skeleton stays. The choices for inner loss, outer controller, and evaluator change with the domain's loss geometry and risk profile. Building a new MTL system rarely needs new pipeline code — it needs new content in three slots.
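The three slots can be made explicit in code. This is a hypothetical sketch, not an API from the GRACE repo: `multitask_loss` takes a dict of per-task inner losses (slot 1) and an outer controller (slot 2); the evaluator (slot 3) would be a separate callable run once per epoch. The usage plugs in the robotics row of the table, with a fixed-weight controller standing in for a loss-weighting controller such as GABA or Uncertainty.

```python
from typing import Callable, Dict, Sequence

import torch
import torch.nn as nn

InnerLoss = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
Controller = Callable[[Sequence[torch.Tensor]], torch.Tensor]


def fixed_weights(weights: Sequence[float]) -> Controller:
    """Simplest outer controller: a constant weighted sum of the task losses."""
    def combine(task_losses: Sequence[torch.Tensor]) -> torch.Tensor:
        return sum(w * loss for w, loss in zip(weights, task_losses))
    return combine


def multitask_loss(outputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor],
                   inner: Dict[str, InnerLoss], controller: Controller) -> torch.Tensor:
    """Slot 1 (inner loss per task) feeds slot 2 (outer controller)."""
    task_losses = [criterion(outputs[name], targets[name]) for name, criterion in inner.items()]
    return controller(task_losses)


# Robotics row of the table: SmoothL1 on torques + binary CE on success.
outputs = {"torque": torch.zeros(4, 7), "success": torch.zeros(4)}
targets = {"torque": torch.zeros(4, 7), "success": torch.ones(4)}
inner = {"torque": nn.SmoothL1Loss(), "success": nn.BCEWithLogitsLoss()}
total = multitask_loss(outputs, targets, inner, fixed_weights([1.0, 0.5]))
print(total.item())  # 0.5 * log(2): SmoothL1 is 0 here and BCE at logit 0 is log(2)
```

Gradient-surgery methods such as PCGrad act on per-task gradients rather than on the loss values, so they would need a hook between backward and the optimiser step instead of this loss-level controller slot.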
Pitfalls When Wiring The Pipeline
Pitfall 1: scheduler.step() per batch instead of per epoch
ReduceLROnPlateau treats every .step() call as one epoch toward its patience budget. Calling scheduler.step(loss.item()) inside the per-batch loop, a common transcription error carried over from OneCycleLR-style schedulers that are meant to step per batch, multiplies the rate of decay by the number of batches per epoch (~70 on FD001). With patience=30, a cut can fire after just 31 stagnant batches, so the LR can reach its 5e-6 floor within the first five epochs and the model never leaves its initial neighbourhood. Symptom: training loss stuck near its first-epoch value.
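The size of this mistake can be estimated by calling the same scheduler either once per epoch or ~70 times per epoch with a loss that never improves. This is a worst-case sketch (real per-batch losses occasionally set new bests, which resets the counter), and `epochs_until_floor` is a hypothetical helper.

```python
import torch
from torch import optim


def epochs_until_floor(steps_per_epoch: int, max_epochs: int = 300):
    """Return the first epoch index at which the LR has hit its 5e-6 floor, or None."""
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = optim.AdamW([param], lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=30, min_lr=5e-6,
    )
    for epoch in range(max_epochs):
        for _ in range(steps_per_epoch):
            scheduler.step(1.0)  # a loss that never improves: the worst case
        if optimizer.param_groups[0]["lr"] <= 5e-6:
            return epoch
    return None


print("per batch (~70 calls/epoch):", epochs_until_floor(70))
print("per epoch (1 call/epoch):  ", epochs_until_floor(1))
```

Under these assumptions the per-batch version reaches the floor during epoch index 3 (the fourth epoch), versus epoch 248 when the scheduler is stepped once per epoch.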
Pitfall 2: forgetting to include mtl_loss.parameters() in the optimiser
GABA itself has no learnable parameters (its lambdas are derived). Uncertainty, GradNorm, and AMNL-v7 do. If you write optim.AdamW(model.parameters(), ...) without appending mtl_loss.parameters(), those parameters (for example, the learnable σ terms in Uncertainty) still accumulate gradients but never receive an optimiser step. Uncertainty then degenerates to fixed-σ scaling, and the published numbers cannot be reproduced.
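A cheap guard against this pitfall is to check, right after building the optimiser, that every trainable parameter of the model and the MTL loss sits in some param group. `missing_from_optimizer` and `LearnableWeights` are hypothetical names for illustration.

```python
import torch
import torch.nn as nn
from torch import optim


def missing_from_optimizer(optimizer: optim.Optimizer, *modules: nn.Module) -> list:
    """Names of trainable parameters in `modules` that no optimiser param group holds."""
    held = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return [
        f"{type(module).__name__}.{name}"
        for module in modules
        for name, p in module.named_parameters()
        if p.requires_grad and id(p) not in held
    ]


class LearnableWeights(nn.Module):
    """Hypothetical MTL loss holding learnable per-task scalars."""

    def __init__(self, n_tasks: int = 2):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(n_tasks))


model, mtl_loss = nn.Linear(4, 2), LearnableWeights()
bad = optim.AdamW(model.parameters(), lr=1e-3)
good = optim.AdamW(list(model.parameters()) + list(mtl_loss.parameters()), lr=1e-3)
print(missing_from_optimizer(bad, model, mtl_loss))   # ['LearnableWeights.log_vars']
print(missing_from_optimizer(good, model, mtl_loss))  # []
```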
Pitfall 3: evaluating without swapping in the EMA shadow
End-of-epoch evaluation should run on shadow weights, not live weights. Skipping apply_shadow() produces a noisier rmse_last, which downstream fools both ReduceLROnPlateau (false plateau triggers) and EarlyStopping (false-positive stops). Symptom: rmse_last fluctuates by ±2 cycles from epoch to epoch instead of descending smoothly.
Pitfall 4: in-place model modification during evaluation
apply_shadow() overwrites param.data in place. If you forget the matching restore() call, the next training batch starts from the shadow weights instead of the live ones: the live trajectory is discarded, and AdamW's moment estimates, accumulated along that trajectory, are applied at a different point. The result is slow divergence and confusing loss curves. Always pair the two calls.
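Pitfalls 3 and 4 both come down to pairing apply_shadow() with restore(). A minimal sketch, assuming a parameter-only EMA (buffers such as BatchNorm running statistics are not averaged here) and a hypothetical `EMA` class that borrows only the method names from the text, wraps the pair in a context manager so the restore cannot be forgotten:

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


class EMA:
    """Minimal parameter EMA with paired apply_shadow()/restore(); a sketch, not GRACE's class."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.model, self.decay = model, decay
        self.shadow = {n: p.detach().clone() for n, p in model.named_parameters()}
        self.backup = {}

    @torch.no_grad()
    def update(self) -> None:
        """Per batch, after optimizer.step(): shadow <- decay * shadow + (1 - decay) * live."""
        for n, p in self.model.named_parameters():
            self.shadow[n].mul_(self.decay).add_(p, alpha=1 - self.decay)

    @torch.no_grad()
    def apply_shadow(self) -> None:
        """Back up the live weights, then overwrite them in place with the shadow."""
        for n, p in self.model.named_parameters():
            self.backup[n] = p.detach().clone()
            p.copy_(self.shadow[n])

    @torch.no_grad()
    def restore(self) -> None:
        """Copy the backed-up live weights back into the model."""
        for n, p in self.model.named_parameters():
            p.copy_(self.backup[n])
        self.backup = {}

    @contextmanager
    def shadow_weights(self):
        """Evaluate inside this block; restore() runs even if evaluation raises."""
        self.apply_shadow()
        try:
            yield self.model
        finally:
            self.restore()


model = nn.Linear(2, 1)
ema = EMA(model, decay=0.5)
with torch.no_grad():
    model.weight.fill_(1.0)  # pretend an optimizer step moved the live weights
ema.update()
with ema.shadow_weights() as m:
    shadow_w = m.weight.detach().clone()  # evaluation would run here
live_w = model.weight.detach().clone()    # back to the live weights
```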
Pitfall 5: shuffle=True on the test loader
The evaluator computes last-cycle-per-unit NASA scores by selecting the final window of each engine unit. That selection assumes the loader yields windows in original temporal order. Setting shuffle=True on the test DataLoader scrambles the order; ‘last’ becomes random; the NASA score becomes meaningless. Always shuffle=False for the test loader.
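The order dependence is easy to see on a toy example. `last_window_per_unit` mimics the order-based selection described above; `last_by_cycle` is a hypothetical order-independent alternative that assumes each window carries its end-cycle index, which the source does not say the GRACE loader provides.

```python
import torch


def last_window_per_unit(unit_ids: torch.Tensor, preds: torch.Tensor) -> dict:
    """Order-dependent: keep the last prediction seen for each unit (needs shuffle=False)."""
    last = {}
    for uid, pred in zip(unit_ids.tolist(), preds.tolist()):
        last[uid] = pred  # later windows overwrite earlier ones
    return last


def last_by_cycle(unit_ids: torch.Tensor, cycles: torch.Tensor, preds: torch.Tensor) -> dict:
    """Order-independent alternative: pick the window with the largest end cycle per unit."""
    best = {}
    for uid, cyc, pred in zip(unit_ids.tolist(), cycles.tolist(), preds.tolist()):
        if uid not in best or cyc > best[uid][0]:
            best[uid] = (cyc, pred)
    return {uid: pred for uid, (_, pred) in best.items()}


# Two engine units, three windows each, in temporal order (preds = predicted RUL).
unit_ids = torch.tensor([1, 1, 1, 2, 2, 2])
cycles = torch.tensor([10, 11, 12, 10, 11, 12])
preds = torch.tensor([30.0, 20.0, 10.0, 50.0, 40.0, 35.0])
perm = torch.tensor([2, 0, 1, 5, 3, 4])  # one possible shuffled order

ordered = last_window_per_unit(unit_ids, preds)
shuffled = last_window_per_unit(unit_ids[perm], preds[perm])
robust = last_by_cycle(unit_ids[perm], cycles[perm], preds[perm])
print(ordered, shuffled, robust)
```

With the shuffled order, the order-based selection picks windows that are not the final ones for either unit, while the cycle-based version still recovers the same predictions as the ordered pass.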
Takeaway
GRACE training is a 14-stage assembly line driven by UnifiedTrainer.fit(): 8 stages per batch, 6 stages per epoch.
The OUTER controller (GABA) is one of those 14 stages — a single nn.Module call. Everything else is standard PyTorch: AdamW, ReduceLROnPlateau, EMA, gradient clipping, early stopping, checkpointer, gradient logger.
The per-batch / per-epoch split is non-negotiable. Confusing the cadences (most commonly scheduler.step() in the wrong place) corrupts the learning rate trajectory and the published numbers cannot be reproduced.
Production defaults: AdamW lr=1e-3, weight_decay=1e-4 (decayed to 1e-5 by epoch 200), grad_clip=1.0, EMA α=0.999, warmup=10 epochs, scheduler patience=30, early stopping=80, seeds = [42, 123, 456, 789, 1024].
The skeleton is domain-agnostic. Replace WMSE with focal / Dice / CTC, replace GABA with GradNorm / Uncertainty / DWA, replace NASA scoring with mAP / WER / Dice — and you have a working multi-task trainer for self-driving perception, speech, medical imaging, recommenders, or robot policies.