Chapter 22
Section 87 of 121

Putting It All Together

GRACE Training Pipeline

An Assembly Line For Gradient Descent

Walk into a Toyota plant in Aichi Prefecture and you find an assembly line of about 80 stations. Each station does one thing — bolt the windshield, attach the door, route the wiring harness — and the car moves between them on a fixed cadence. The brilliance of the line is not in any one station; it is in the precise ordering and the cadence at which work moves. Skip a station, swap two, or break the cadence and the car comes out wrong.

GRACE training is the same. Every epoch, fourteen operations execute in a fixed order: eight fire on every mini-batch (forward, GABA controller, backward, AdamW and others), and six fire once per epoch (warmup, weight-decay schedule, evaluator, scheduler, early stopping, checkpointer). This section is a guided tour of the line: what each station does, where it lives in the paper repo, and how the whole thing produces the published 7.72 RMSE on FD002 in roughly 30 minutes on a single GPU.

The headline. GRACE's training loop is the paper's UnifiedTrainer.fit(), 291 lines of Python at grace/training/trainer.py. It composes ten reusable callbacks — LR warmup, AdamW, ReduceLROnPlateau, EMA, gradient clipping, early stopping, checkpointer, gradient logger, plus the GABA controller and the WMSE inner loss — in a strict per-batch / per-epoch sequence.

Anatomy: 14 Stages, Two Cadences

Every stage of the pipeline runs at one of two cadences:

  • Per-batch — fires inside the inner loop of _train_epoch(), once for each mini-batch. Eight stages: data fetch, forward pass, inner losses, OUTER GABA controller, backward, gradient clip, AdamW step, EMA shadow update.
  • Per-epoch — fires inside fit(), once at the end of a full pass over the training set. Six stages: linear LR warmup (only the first 10 epochs), GRACE adaptive-weight-decay schedule, evaluator (with EMA-shadow weights), ReduceLROnPlateau scheduler step, early stopping check, best-checkpoint save.

The two cadences interact. Per-batch stages mutate the model; per-epoch stages decide whether to stop, whether to halve the learning rate, whether to save a checkpoint — based on what the per-batch stages have produced.
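The two cadences can be sketched as nested loops. This is a hypothetical, logging-only skeleton: `run_pipeline` and the stage labels are illustrative names, not the paper's `UnifiedTrainer.fit()`, and the per-epoch stages follow the order given in the epoch-boundary walkthrough later in this section. It records which stage fires when, so the ratio between the two cadences is visible.

```python
# Hypothetical logging skeleton of the two cadences; not the paper's UnifiedTrainer.fit().
PER_BATCH = ["data_fetch", "forward", "inner_losses", "gaba_controller",
             "backward", "grad_clip", "adamw_step", "ema_update"]
PER_EPOCH = ["warmup_gate", "weight_decay_schedule", "evaluate_with_ema",
             "scheduler_step", "track_best_checkpoint", "early_stopping_check"]


def run_pipeline(n_epochs, n_batches):
    """Return the ordered list of (epoch, stage) events the loop would fire."""
    events = []
    for epoch in range(n_epochs):
        # Per-batch cadence: the inner loop of _train_epoch(), once per mini-batch.
        for _ in range(n_batches):
            for stage in PER_BATCH:
                events.append((epoch, stage))
        # Per-epoch cadence: fires inside fit() only after the full pass.
        for stage in PER_EPOCH:
            events.append((epoch, stage))
    return events


events = run_pipeline(n_epochs=2, n_batches=3)
stages = [stage for _, stage in events]
print(len(events))                                                  # 2 * (3 * 8 + 6) = 60
print(stages.count("adamw_step"), stages.count("scheduler_step"))  # 6 2
```

With real FD002 epoch sizes the per-batch stages fire dozens of times for every single firing of a per-epoch stage.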

Why this split matters. Confusing the two cadences is the most common source of bugs in MTL training loops. Calling scheduler.step() per batch instead of per epoch decays the learning rate ~70× faster than intended on FD002. Updating the EMA shadow only at end-of-epoch instead of every batch loses ~0.3 RMSE.
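The scheduler bug is easy to reproduce with a simplified plateau rule. The class below is a hypothetical stand-in that keeps only the core logic of PyTorch's `ReduceLROnPlateau` (reduce by `factor` once the metric has failed to improve for more than `patience` calls, then reset the counter); it deliberately ignores `threshold`, `cooldown`, `min_lr` and `eps`. Feeding it a flat metric shows that the number of halvings scales with the number of `step()` calls. The 70 batches per epoch is an assumed figure for illustration.

```python
class PlateauHalver:
    """Simplified plateau rule (mode='min'): halve the LR after more than
    `patience` consecutive calls without improvement. Not PyTorch's class."""

    def __init__(self, lr, factor=0.5, patience=30):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.best = float("inf")
        self.bad_calls = 0
        self.n_reductions = 0

    def step(self, metric):
        if metric < self.best:          # improvement: remember it, reset the counter
            self.best, self.bad_calls = metric, 0
        else:
            self.bad_calls += 1
        if self.bad_calls > self.patience:
            self.lr *= self.factor
            self.n_reductions += 1
            self.bad_calls = 0
        return self.lr


def count_halvings(n_calls):
    sched = PlateauHalver(lr=1e-3)
    for _ in range(n_calls):
        sched.step(1.0)                 # flat metric: only the first call "improves"
    return sched.n_reductions


epochs, batches_per_epoch = 100, 70     # assumed sizes for illustration
per_epoch = count_halvings(epochs)                       # correct: step once per epoch
per_batch = count_halvings(epochs * batches_per_epoch)   # bug: step once per batch
print(per_epoch, per_batch)             # 3 225
```

Three halvings become 225: the decay rate grows by roughly the batches-per-epoch factor, which is the failure mode described above.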

The Pipeline End-To-End

[Figure: training-pipeline diagram. Each stage is annotated with what it computes, where it lives in the paper repo (grace/training/trainer.py et al.), and its production hyperparameter values; the per-batch cadence is drawn on top and the per-epoch cadence below.]
Read the diagram top to bottom. The per-batch cadence runs many times before any per-epoch stage fires. By the end of an epoch, the model has typically seen 70–110 batches; only then does the evaluator pull EMA shadow weights, compute rmse_last on the test set, and feed that scalar into ReduceLROnPlateau, EarlyStopping, and Checkpointer.

One Mini-Batch Step, Walked Through

A single mini-batch executes the eight per-batch stages in order. Below is what happens to a 4-sample batch with $y_{\text{rul}} = (20, 80, 5, 100)$ and a partly-trained model.

Stage | Operation | Input | Output
1. Forward | model(seq) | seq (4, 30, 14) | y_pred (4,), hp_logits (4, 3)
2a. Inner WMSE | moderate_weighted_mse_loss | y_pred, y_rul | L_rul = 45.21 (scalar)
2b. Inner CE | F.cross_entropy | hp_logits, hp_target | L_health = 0.8999 (scalar)
3. GABA controller | gaba(L_rul, L_health, shared_params) | g_rul = 41.5, g_h = 0.082 | λ* = (0.0477, 0.9523), L ≈ 3.014
4. Backward | loss.backward() | L_total (scalar) | θ.grad populated for ~1.7M params
5. Clip | clip_grad_norm_(.., max_norm=1.0) | global ‖∇‖ ≈ 2.4 | rescaled to ‖∇‖ = 1.0
6. AdamW step | optimizer.step() | lr = 1e-3, β = (0.9, 0.999), wd = 5e-5 (post-100) | θ ← θ - lr·m̂/(√v̂+ε) - lr·wd·θ
7. EMA shadow | ema.update(model) | α = 0.999, current θ | shadow ← 0.999·shadow + 0.001·θ
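Stage 5 clips the global norm across all parameters, not each tensor separately. Below is a minimal NumPy sketch of that rule, assuming the documented behaviour of `torch.nn.utils.clip_grad_norm_`: one joint L2 norm over every gradient, and a scale factor `max_norm / (total_norm + 1e-6)` that is applied only when it is below 1. The two gradient arrays are made-up values chosen so the pre-clip norm is 2.4, as in the table.

```python
import numpy as np


def clip_grad_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale every gradient array in place so their joint L2 norm is at most max_norm.

    Returns the pre-clip total norm, mirroring what clip_grad_norm_ returns."""
    total_norm = float(np.sqrt(sum(float(np.sum(g ** 2)) for g in grads)))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:                      # only ever shrink, never grow
        for g in grads:
            g *= clip_coef                   # same factor for every tensor: direction preserved
    return total_norm


# Two hypothetical parameter gradients: squared norms 4.0 and 1.76, joint norm 2.4.
grads = [np.array([1.2, -1.6]), np.array([[1.2, 0.4], [-0.4, 0.0]])]
before = clip_grad_norm(grads, max_norm=1.0)
after = float(np.sqrt(sum(float(np.sum(g ** 2)) for g in grads)))
print(round(before, 6), round(after, 6))     # 2.4 1.0
```

Because one factor multiplies every tensor, the relative size of the gradients (and hence the update direction) is unchanged; only the overall step length shrinks.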

The per-batch sequence is what the paper's _train_epoch() loops over. The hyperparameters and gradient norms in that table come from the paper's production code; the Python script below reproduces stages 2–7 in pure NumPy for a single scalar parameter, including the L_rul and λ* values shown.

One Epoch Boundary, Walked Through

End of epoch $e$: the per-batch loop has finished. Now the per-epoch operations run, in this order:

  • Warmup gate. If $e < W$ (default $W = 10$), set $\eta = \eta_0 \cdot (0.1 + 0.9\,e/W)$. Otherwise leave $\eta$ alone: the scheduler owns it from here.
  • Adaptive weight decay (GRACE-only). If $e > 100$, set $\text{wd} = \text{wd}_0 \cdot 0.5$; if $e > 200$, set $\text{wd} = \text{wd}_0 \cdot 0.1$. Late training benefits from less regularisation as the loss landscape flattens.
  • EMA-aware evaluation. ema.apply_shadow(model) swaps the live weights for the shadow weights; the evaluator runs the full test loop; ema.restore(model) swaps them back. The reported rmse_last, NASA score and health accuracy all come from the shadow.
  • Scheduler step. scheduler.step(rmse_last). If rmse_last hasn't improved for 30 epochs, halve the LR. Gated by $e \geq W$ so the scheduler doesn't fight warmup.
  • Track best. If rmse_last < best_rmse, update best_rmse, save the model state and the EMA shadow, and persist a checkpoint.
  • Early stopping. If there has been no improvement for $P = 80$ consecutive epochs, restore the best weights and break out of the epoch loop.
  • Gradient log flush. grad_logger.log_epoch() records the GABA statistics ($g_{\text{rul}}, g_{\text{health}}, \lambda^{\text{raw}}_i, \lambda^*_i$) for the per-dataset CSV that chapter 21 §3 reads from.
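The epoch-boundary rules above reduce to a few small functions. This is a sketch under stated assumptions: the names `warmup_lr`, `grace_weight_decay` and `BestTracker` are hypothetical, the constants (η₀ = 1e-3, W = 10, wd₀ = 1e-4, P = 80) are the values quoted in this chapter, and the checkpoint is a plain dict copy standing in for saving the model state plus the EMA shadow.

```python
BASE_LR, WARMUP_EPOCHS, WD0 = 1e-3, 10, 1e-4


def warmup_lr(epoch, current_lr):
    """Warmup gate: overwrite the LR only while epoch < W; afterwards the scheduler owns it."""
    if epoch < WARMUP_EPOCHS:
        return BASE_LR * (0.1 + 0.9 * epoch / WARMUP_EPOCHS)
    return current_lr


def grace_weight_decay(epoch):
    """Adaptive weight decay: halve after epoch 100, cut to 10% after epoch 200."""
    if epoch > 200:
        return WD0 * 0.1
    if epoch > 100:
        return WD0 * 0.5
    return WD0


class BestTracker:
    """Track the best rmse_last and signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience=80):
        self.patience = patience
        self.best_rmse = float("inf")
        self.best_epoch = -1
        self.best_state = None

    def update(self, epoch, rmse_last, state):
        if rmse_last < self.best_rmse:
            self.best_rmse, self.best_epoch = rmse_last, epoch
            self.best_state = dict(state)   # stand-in for model state + EMA shadow checkpoint
        return epoch - self.best_epoch >= self.patience   # True -> restore best and stop


# Toy RMSE curve: improves every 5 epochs until epoch 50, then plateaus at 10.
tracker, stop_epoch = BestTracker(patience=80), None
for epoch in range(500):
    rmse_last = max(20 - epoch // 5, 10)
    if tracker.update(epoch, rmse_last, {"rmse": rmse_last}):
        stop_epoch = epoch
        break
print(tracker.best_epoch, stop_epoch)   # 50 130
```

The last improvement lands at epoch 50, so with P = 80 the loop stops at epoch 130 and the stored best state is the one that gets restored.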

Python: A Minimal fit() Loop From Scratch

Every per-batch stage of the production pipeline, by hand. One scalar parameter, four samples, no PyTorch — just enough algebra to see what every callback does.

One training step, every callback, in NumPy (grace_minimal_step.py)
Lines 1–6: docstring

Names the contract: a complete one-step training loop in NumPy. Every callback the paper's UnifiedTrainer wires up (warmup, GABA, clip, AdamW, EMA) appears here as a few lines of explicit math.

Line 8: import numpy as np

NumPy is Python's numerical-array library. It provides ndarray (N-dimensional array) plus the math operations on it. Every algebraic step in this script (element-wise +, -, *, **, plus np.clip / np.maximum / np.sqrt / .sum / .mean) runs as optimised C under the hood, not slow Python loops.

EXECUTION STATE
numpy = Library for numerical computing — ndarray (homogeneous N-D array), broadcasting rules, linear algebra, RNG, and a large math library. Pre-installed in every Anaconda / Colab environment.
as np = Alias. Universal Python convention so you write np.array(), np.maximum(), etc. instead of numpy.array() everywhere.
→ why NumPy here? = GRACE's production code uses PyTorch tensors. Switching to NumPy strips away autograd and GPU plumbing so every callback (warmup, GABA, clip, AdamW, EMA) is exposed as a few lines of explicit math.
Line 10: rng = np.random.default_rng(42)

Modern NumPy random API. Pinned seed so the random feature x is reproducible across runs.

EXECUTION STATE
📚 np.random.default_rng(seed) = Generator object. Methods .normal, .uniform, .integers all share state.
Line 14: y_rul = np.array([20., 80., 5., 100.], dtype=float)

Four ground-truth RUL targets: two near-failure (5, 20), one mid-life (80), one early-life (100). The mix spreads w(y) across [1.20, 1.96]; all four targets sit on the linear part of the failure-bias ramp (0 < y < 125), so neither clip bound is active in this batch.

EXECUTION STATE
📚 np.array(object, dtype) = NumPy constructor: build an ndarray from any Python iterable. Stores values in a contiguous C buffer (not a Python list), enabling vectorised math.
⬇ arg 1: [20., 80., 5., 100.] = Plain Python list of floats. The trailing dot (20. vs 20) marks them as float literals — keeps everyone in float-land.
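⬇ arg 2: dtype=float = Force float64 (8 bytes per element). Redundant here, since the trailing dots already make every literal a float, but it pins the dtype if someone later edits the literals to whole numbers and NumPy infers an integer dtype instead.
→ example dtype behaviour = np.array([1, 2, 3]) infers an integer dtype (int64 on most platforms). True division still returns floats: np.array([1, 2, 3]) / 2 → array([0.5, 1. , 1.5]). Floor division keeps integers: np.array([1, 2, 3]) // 2 → array([0, 1, 1]). The real hazard is in-place assignment: writing 0.5 into an integer array silently stores 0.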
⬆ result: y_rul (4,) — shape and values = [ 20., 80., 5., 100.] shape=(4,), dtype=float64
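The dtype point is easy to check directly. A short sketch of NumPy's standard division and assignment rules, shown on throwaway arrays rather than the demo's y_rul:

```python
import numpy as np

ints = np.array([1, 2, 3])                  # integer dtype inferred from the literals
floats = np.array([1, 2, 3], dtype=float)   # explicit float64

true_div = ints / 2      # true division always returns floats: [0.5, 1.0, 1.5]
floor_div = ints // 2    # floor division keeps the integer dtype: [0, 1, 1]

ints[0] = 0.5            # assigning a float into an integer array truncates it to 0
floats[0] = 0.5          # the float64 array stores 0.5 as expected
print(true_div, floor_div, ints[0], floats[0])
```

Division alone is safe either way; it is silent truncation on assignment that an explicit float dtype guards against.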
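Line 15: hp_target = np.array([2, 1, 2, 0])

Health labels (2 = critical, 1 = degrading, 0 = healthy). Held in this demo as a static input: the toy model only updates theta_rul, so these labels never enter the computation. Under the thresholds y<30→2, y<70→1, else 0, the y=80 sample would be labelled 0 rather than 1; the label here is illustrative only.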

EXECUTION STATE
hp_target (4,) = [2, 1, 2, 0]
Line 16: x = rng.normal(0, 1, size=(4,))

Random per-sample features. One scalar per sample — a stand-in for what the CNN-BiLSTM-Attention backbone would produce on real C-MAPSS sensor windows.

EXECUTION STATE
📚 rng.normal(loc, scale, size) = NumPy Generator method: draw samples from N(loc, scale²). Cleaner replacement for legacy np.random.normal — uses the modern PCG64 bit generator and respects the seeded `rng` from line 10.
⬇ arg 1: loc = 0 = Mean of the Gaussian. We centre features at 0 so the toy linear model y_pred = θ·x·100 + 50 has predictions centred at 50.
⬇ arg 2: scale = 1 = Standard deviation. σ=1 → ~68% of features land in [-1, 1], ~95% in [-2, 2].
⬇ arg 3: size = (4,) = Output shape. (4,) is a 1-D array of 4 elements (a vector). For a matrix you would pass (rows, cols), e.g. size=(256, 14).
→ why a tuple? = size accepts an int or a tuple of ints. (4,) is the 1-element tuple syntax — the trailing comma distinguishes it from the integer (4) wrapped in parens.
⬆ result: x (4,) (illustrative) = [+0.30, -1.04, +0.75, +0.94] — exact values depend on seed 42.
Line 20: theta_rul = 0.5

Single learnable scalar. Stand-in for the 1.7M-parameter backbone. We track its update path through the pipeline.

Line 21: theta_h = 0.5

Health-head weight. Defined for symmetry but not updated in this demo (we hold L_health constant).

Line 22: m_rul, v_rul = 0.0, 0.0

AdamW state: first and second moment estimates. Zero-initialised, then bias-corrected via the (1 - β^t) divisor on every step.

EXECUTION STATE
m_rul = First moment (gradient mean). 0 at step 0.
v_rul = Second moment (squared-gradient mean). 0 at step 0.
Line 23: shadow_rul = theta_rul

EMA shadow buffer. Initialised to the same value as theta_rul (ExponentialMovingAverage.__init__ in callbacks.py:62).

Line 24: step_count = 0

AdamW timestep counter. Used in the bias-correction divisor (1 - β^t).

Line 26: base_lr = 1e-3

Target learning rate after warmup. Paper default for C-MAPSS.

Line 27: warmup_epochs = 10

Number of epochs over which to linearly ramp the learning rate from 0.1·base_lr to base_lr.

Line 28: beta1, beta2 = 0.9, 0.999

AdamW momentum coefficients. β₁=0.9 weights the first moment (mean), β₂=0.999 weights the second moment (uncentered variance).

EXECUTION STATE
β₁ = 0.9 = Effective memory of ~10 steps for the gradient mean.
β₂ = 0.999 = Effective memory of ~1000 steps for the squared gradient.
Line 29: eps_adam = 1e-8

Numerical safety floor in AdamW. Added to sqrt(v_hat) before division to avoid 0/0 when a parameter has never seen a gradient.

Line 30: weight_decay = 1e-4

AdamW weight decay coefficient. Decoupled from the gradient — applied as a separate -lr·wd·θ shrink. GRACE schedules this down at epoch 100/200.

Line 31: ema_alpha = 0.999

EMA decay for the shadow weights. Effective memory of 1000 steps. Eval at end-of-epoch swaps in shadow_rul.

Line 34: def lr_at(epoch) → float

Returns the learning rate for a given epoch. Implements GRACE's linear warmup callback and mirrors the production grace/training/callbacks.py:LRWarmup.get_lr (~5 lines of Python). Called once per epoch by UnifiedTrainer.fit before the per-batch loop runs.

EXECUTION STATE
⬇ input: epoch = Integer epoch index, starting at 0. Trainer maintains it in trainer.py and increments at end-of-epoch.
→ epoch purpose = Decides whether we are inside the warmup window (epoch &lt; W) or past it. The whole branch logic on line 36 hinges on this scalar.
→ closes over = Reads two module-level globals: warmup_epochs (10) and base_lr (1e-3). Kept as globals so the demo stays a single flat script.
⬆ returns float = The learning rate for this epoch. 1e-4 at epoch 0, 5.5e-4 at epoch 5, 1e-3 from epoch 10 onward.
Line 36: if epoch < warmup_epochs:

Branch on warmup phase. Outside warmup, return the constant base_lr.

Line 37: return base_lr * (0.1 + 0.9 * epoch / warmup_epochs)

Linear ramp. epoch=0 → 0.1·base_lr; epoch=W → base_lr.

EXECUTION STATE
epoch=0 = lr = 1e-3 * (0.1 + 0.0) = 1e-4 — gentle start, no large initial step.
epoch=5 = lr = 1e-3 * (0.1 + 0.45) = 5.5e-4.
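epoch=10 = The ramp formula would give 1e-3 * (0.1 + 0.9) = 1e-3, the full base rate. In code, epoch 10 fails the epoch < warmup_epochs test and returns base_lr directly, so the two branches meet without a jump.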
Line 38: return base_lr

Post-warmup: scheduler takes over (ReduceLROnPlateau).

Line 41: def w_failure(y, max_rul=125.0):

Failure-biased per-sample weight. Same closed form as core/weighted_mse.py:moderate_weighted_mse_loss.

Line 42: return 1.0 + np.clip(1.0 - y / max_rul, 0.0, 1.0)

Failure-bias ramp: weight is 2.0 at y=0 (failure imminent) and decays linearly to 1.0 at y ≥ 125 (healthy). Forces the optimiser to spend more error budget on near-failure samples — exactly where wrong predictions cost a real plant.

EXECUTION STATE
📚 np.clip(a, a_min, a_max) = NumPy function: element-wise saturate every value in `a` to the closed interval [a_min, a_max]. Anything below a_min becomes a_min; anything above a_max becomes a_max; values in range pass through.
⬇ arg 1: 1.0 - y / max_rul = The unclipped ramp. y/125 = fraction of life used. 1 - that = fraction remaining. Goes from 1.0 at y=0 down to 0 at y=125 down to NEGATIVE for y>125.
⬇ arg 2: a_min = 0.0 = Lower clip. Stops the ramp from going negative for super-healthy samples (y > 125 in raw RUL). Without it, w could fall BELOW 1.
⬇ arg 3: a_max = 1.0 = Upper clip. Caps the ramp at +1, which only matters if y were ever NEGATIVE (it shouldn't be). Belt-and-braces.
→ np.clip vs np.maximum vs np.minimum = np.clip(a, lo, hi) ≡ np.minimum(np.maximum(a, lo), hi). Single function, two-sided saturation. np.maximum is one-sided (just a floor).
→ trace y = 5 = 1 + clip(1 - 5/125, 0, 1) = 1 + clip(0.96, 0, 1) = 1 + 0.96 = 1.96
→ trace y = 20 = 1 + clip(1 - 20/125, 0, 1) = 1 + clip(0.84, 0, 1) = 1 + 0.84 = 1.84
→ trace y = 80 = 1 + clip(1 - 80/125, 0, 1) = 1 + clip(0.36, 0, 1) = 1 + 0.36 = 1.36
→ trace y = 100 = 1 + clip(1 - 100/125, 0, 1) = 1 + clip(0.20, 0, 1) = 1 + 0.20 = 1.20
→ trace y = 200 (above cap) = 1 + clip(1 - 200/125, 0, 1) = 1 + clip(-0.6, 0, 1) = 1 + 0 = 1.00 — the floor saved us.
⬆ return: w (4,) for this batch = [1.84, 1.36, 1.96, 1.20]
Line 45: def gaba_lambda(g_rul, g_health, eps_floor=0.05) → np.ndarray

GABA's closed-form weight allocator for K=2 tasks. Compressed for the demo: we skip the EMA-smoothing stage that the production controller runs in grace/core/gaba_loss.py, i.e. we assume the EMA has fully converged so EMA(g) ≈ g. Returns the two outer-axis weights λ_rul, λ_health that turn (L_rul, L_health) into a single scalar L_total.

EXECUTION STATE
⬇ input: g_rul = Per-task gradient L2 norm of L_rul w.r.t. the SHARED backbone parameters. Real production value at epoch 50 on FD002 ≈ 41.5.
→ g_rul purpose = Tells GABA how loud the RUL task's gradient signal is. Bigger g_rul ⇒ smaller λ_rul (downweight a loud task).
⬇ input: g_health = Same statistic for L_health. ≈ 0.082 at the same checkpoint, about 506× smaller than g_rul.
⬇ input: eps_floor = 0.05 (default) = Floor applied to every raw λ_i before renormalisation. Prevents collapse: a task whose gradient dominates still keeps roughly 5% of the weight (exactly 0.05 before the renorm, 0.0477 after it in this example). Paper default.
→ why 0.05? = Empirically: 0.01 lets the small-gradient task starve; 0.10 over-weights it. 0.05 is the validated sweet spot from chapter 21 §3.
⬆ returns: (2,) ndarray = Final λ* after floor + renorm. Sums to exactly 1.0 — a proper convex combination ready for L_total = λ[0]·L_rul + λ[1]·L_health.
Line 47: S = g_rul + g_health

K=2 normaliser.

Line 48: raw = np.array([g_health / S, g_rul / S])

GABA's inverse-ratio rule: the task with the BIGGER per-task gradient norm gets the SMALLER weight. Intuition: that task is already getting a loud gradient signal, so dial it down to let the quieter task be heard.

EXECUTION STATE
📚 np.array(list) = Build a 2-element ndarray from the Python list of two scalars. Equivalent to np.array([0.001972, 0.998028]).
→ element 0: g_health / S = 0.082 / 41.582 = 0.001972 — λ_rul (small, because g_rul dominates).
→ element 1: g_rul / S = 41.5 / 41.582 = 0.998028 — λ_health (large, because g_health is tiny).
→ why the swap? = Index 0 of the result is for RUL; we put g_health/S there because RUL's OWN gradient is huge. The closed form is λ_i ∝ 1/g_i, scaled to sum to 1.
raw (2,) = [0.001972, 0.998028]
Line 49: flo = np.maximum(raw, eps_floor)

Per-element floor. Prevents either λ from collapsing to ~0 — a task with weight near 0 receives no gradient and effectively stops training.

EXECUTION STATE
📚 np.maximum(a, b) = Element-wise binary max. For each index i, returns max(a[i], b[i]). Broadcasts a scalar `b` against an array `a`. Distinct from np.max which is a REDUCTION (returns a single value).
⬇ arg 1: raw (2,) = [0.001972, 0.998028]
⬇ arg 2: eps_floor = 0.05 = Scalar. Broadcasts to compare against every element of raw.
→ np.maximum vs np.max = np.maximum([1, 5], [3, 2]) → [3, 5] (element-wise pair) np.max([1, 5, 3, 2]) → 5 (single reduction)
→ element-wise trace = max(0.001972, 0.05) = 0.05 ← floor kicks in max(0.998028, 0.05) = 0.998028 ← passes through
⬆ result: flo (2,) = [0.05, 0.998028]
Line 50: return flo / flo.sum()

Renormalise so the two λ* sum to exactly 1.0 after the floor. Without this step, applying the floor breaks the simplex constraint (sum ≠ 1) and L_total would no longer be a convex combination of the per-task losses.

EXECUTION STATE
📚 ndarray.sum(axis=None) = ndarray method: REDUCE the array to a single scalar by summing all elements. With axis= it sums along that axis. axis=None (default) sums everything.
/ (division operator) = Element-wise division. Scalar broadcasts: each element of flo gets divided by the same flo.sum() scalar.
→ trace = flo.sum() = 0.05 + 0.998028 = 1.048028 flo[0] / flo.sum() = 0.05 / 1.048028 = 0.04771 flo[1] / flo.sum() = 0.998028 / 1.048028 = 0.95229
→ verify simplex = 0.04771 + 0.95229 = 1.00000 ✓ — proper convex combination weights.
⬆ return: λ* (2,) = [0.04771, 0.95229]
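The demo's `gaba_lambda` skips the EMA stage. A hypothetical K-task version with the smoothing put back might look like the sketch below. These are assumptions, not taken from the paper's grace/core/gaba_loss.py: the EMA is seeded with the first observed norms, it uses decay 0.99 (the value this walkthrough quotes for the GABA EMA), and λᵢ ∝ 1/gᵢ generalises the K=2 swap rule. For K=2 on the first call it reproduces λ* = (0.04771, 0.95229).

```python
import numpy as np


class GabaSketch:
    """Hypothetical K-task GABA allocator: EMA-smoothed gradient norms, inverse-norm
    weights, per-task floor, renormalisation. Not the paper's implementation."""

    def __init__(self, beta=0.99, eps_floor=0.05):
        self.beta, self.eps_floor = beta, eps_floor
        self.g_ema = None

    def weights(self, grad_norms):
        g = np.asarray(grad_norms, dtype=float)
        # EMA of the per-task gradient norms (seeded with the first observation).
        if self.g_ema is None:
            self.g_ema = g.copy()
        else:
            self.g_ema = self.beta * self.g_ema + (1.0 - self.beta) * g
        inv = 1.0 / self.g_ema                 # lambda_i proportional to 1 / g_i
        raw = inv / inv.sum()
        floored = np.maximum(raw, self.eps_floor)
        return floored / floored.sum()         # back onto the simplex


gaba = GabaSketch()
lam_first = gaba.weights([41.5, 0.082])        # first call: EMA equals the raw norms
lam_swapped = gaba.weights([0.082, 41.5])      # one-step spike in the opposite direction
print(lam_first.round(5), lam_swapped.round(5))
```

The second call shows why the production controller smooths: a single reversed measurement barely moves the weights (λ_health stays above 0.95), whereas the unsmoothed closed form would flip them outright.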
Line 54: epoch = 50

Pretend we are at epoch 50. Past warmup, so lr=1e-3.

Line 55: lr = lr_at(epoch)

Look up the warmup-aware learning rate.

EXECUTION STATE
lr = 1e-3 — past warmup, full rate.
Line 58: y_pred = theta_rul * x * 100.0 + 50.0

Toy forward pass. Stand-in for the production CNN-BiLSTM-Attention backbone + RUL head. The +50 keeps predictions in the RUL-target range.

EXECUTION STATE
y_pred (4,) (illustrative) = Depends on x; on seed 42 ≈ [65.2, -2.0, 87.5, 97.0]
Line 61: w = w_failure(y_rul)

Per-sample failure-biased weights for THIS batch.

EXECUTION STATE
w (4,) = [1.84, 1.36, 1.96, 1.20]
Line 62: residual_sq = (y_pred - y_rul) ** 2

Element-wise squared residuals. The classic MSE inner term: squaring makes every error non-negative and penalises large mistakes quadratically.

EXECUTION STATE
y_pred - y_rul = Element-wise subtraction. NumPy broadcasts pair-by-pair: [y_pred[0]-y_rul[0], y_pred[1]-y_rul[1], ...].
** (power operator) = Element-wise exponentiation. For ndarray: a ** 2 squares EACH element independently. Equivalent to np.square(a) or a * a.
→ trace (illustrative, random x) = Random x → y_pred ≈ [65.2, -2.0, 87.5, 97.0]; residual = y_pred - y_rul ≈ [+45.2, -82.0, +82.5, -3.0]; residual_sq ≈ [2046, 6724, 6810, 8.8]
→ demo override (used in printed output) = We override y_pred = [25., 75., 12., 98.] so the algebra below is clean integers. residual = [+5., -5., +7., -2.] residual_sq = [25., 25., 49., 4.]
residual_sq (4,) — final = [25., 25., 49., 4.] (using demo override)
Line 63: L_rul = (w * residual_sq).mean()

Failure-biased MSE on the RUL task — the INNER axis of GRACE. Multiplies each squared residual by its failure-bias weight before averaging across the batch.

EXECUTION STATE
📚 ndarray.mean(axis=None) = ndarray method: arithmetic mean. With axis=None (default) reduces every element to a single scalar. Equivalent to .sum() / .size.
* (multiply operator) = Element-wise product (NOT matrix multiply; that's @). Pairs each w[i] with each residual_sq[i].
→ step 1: w * residual_sq = [1.84·25, 1.36·25, 1.96·49, 1.20·4] = [46.0, 34.0, 96.04, 4.8]
→ step 2: .mean() = .sum() / N = (46.0 + 34.0 + 96.04 + 4.8) / 4 = 180.84 / 4 = 45.21
→ why mean not sum? = Mean keeps the loss scale invariant to batch size — doubling the batch does not double the loss, so LR does not need rescaling.
L_rul (scalar) = 45.2100
Line 64: L_health = 0.8999

Cross-entropy from the same forward pass on the health head. Held constant in this demo because we are tracking only theta_rul.

Line 67: g_rul, g_h = 41.5, 0.082

Per-task gradient L2 norms on the shared backbone. Real production values for a partly-trained C-MAPSS model.

EXECUTION STATE
ratio = 41.5 / 0.082 ≈ 506× — typical FD002 value, see chapter 21 §3.
Line 68: lam = gaba_lambda(g_rul, g_h)

Run the (compressed) GABA controller.

EXECUTION STATE
raw lambdas = [0.001972, 0.998028]
after floor = [0.05, 0.998028]
after renorm = [0.0477, 0.9523]
Line 69: L_total = lam[0] * L_rul + lam[1] * L_health

Composed GRACE loss. The OUTER axis sums the two task losses with the GABA-derived weights.

EXECUTION STATE
lam[0] * L_rul = 0.047709 * 45.21 ≈ 2.1569
lam[1] * L_health = 0.952291 * 0.8999 ≈ 0.8570
L_total = 2.1569 + 0.8570 ≈ 3.0139 (computed with full-precision λ*)
Line 74: g_theta = lam[0] * (2.0 / 4) * (w * (y_pred - y_rul) * 100.0 * x).sum()

Closed-form gradient of L_total w.r.t. theta_rul. Chain rule: dL_total/dtheta_rul = λ_rul · dL_rul/dtheta_rul, since L_health does not depend on theta_rul.

EXECUTION STATE
λ_rul = lam[0] = 0.0477 = Outer-axis weight on the RUL task.
dL_rul/dtheta = (2/N)·Σ w·r·100·x = From d/dtheta ((1/N)·Σ w·(100·θ·x + 50 - y)²) = (2/N)·Σ w·(100·θ·x + 50 - y)·100·x.
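g_theta (illustrative) = With the demo-override residuals [+5, -5, +7, -2] and seed-42 x, ≈ +42.7 BEFORE clip (≈ +619 with the raw random predictions). Either way it is far above 1.0, so the clip fires.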
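The closed-form derivative is worth checking numerically. This sketch compares the analytic dL_rul/dθ against a central finite difference on the same toy model (y_pred = 100·θ·x + 50, seed-42 features, failure-biased weights). The helper names are hypothetical. The +50 offset enters through the residual, so it changes the gradient's value but not the form of the formula.

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(0, 1, size=(4,))
y_rul = np.array([20., 80., 5., 100.])
w = 1.0 + np.clip(1.0 - y_rul / 125.0, 0.0, 1.0)


def l_rul(theta):
    """Failure-biased MSE of the toy model y_pred = 100*theta*x + 50."""
    y_pred = theta * x * 100.0 + 50.0
    return (w * (y_pred - y_rul) ** 2).mean()


def grad_closed_form(theta):
    """(2/N) * sum(w * residual * 100 * x): the g_theta formula before the lam[0] factor."""
    y_pred = theta * x * 100.0 + 50.0
    return (2.0 / len(y_rul)) * (w * (y_pred - y_rul) * 100.0 * x).sum()


def grad_central_difference(theta, h=1e-4):
    return (l_rul(theta + h) - l_rul(theta - h)) / (2.0 * h)


analytic = grad_closed_form(0.5)
numeric = grad_central_difference(0.5)
print(analytic, numeric)
```

Because L_rul is quadratic in θ, the central difference is exact up to floating-point rounding, so the two numbers agree to many significant digits.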
Line 77: g_norm = abs(g_theta)

Scalar gradient norm. In production, clip_grad_norm_ uses the L2 norm across ALL parameters at once.

Line 78: if g_norm > 1.0:

Trigger clipping if the global norm exceeds the threshold.

Line 79: g_theta = g_theta / g_norm

Rescale so |g_theta| = 1.0. Direction preserved; magnitude clipped.

EXECUTION STATE
g_theta post-clip = Sign of original g_theta, magnitude exactly 1.0.
→ in production = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0). Same algebra, applied to the concatenated parameter gradient vector.
Line 82: step_count += 1

Advance the AdamW timestep. Needed for bias correction below.

Line 83: m_rul = beta1 * m_rul + (1 - beta1) * g_theta

AdamW first moment EMA.

EXECUTION STATE
m_rul (after step 1) = 0.9·0 + 0.1·g_theta = 0.1·g_theta
Line 84: v_rul = beta2 * v_rul + (1 - beta2) * g_theta ** 2

AdamW second moment EMA. Tracks the running uncentered variance of the gradient.

EXECUTION STATE
v_rul (after step 1) = 0.999·0 + 0.001·g_theta² = 0.001·g_theta²
Line 85: m_hat = m_rul / (1 - beta1 ** step_count)

Bias correction. Needed because m_rul starts at 0 and is biased toward zero in early steps. After many steps, (1 - β₁^t) → 1 and the correction vanishes.

EXECUTION STATE
1 - β₁^1 = 0.1 = Correction factor at step 1: divide by 0.1, multiplying m_rul by 10.
1 - β₁^100 ≈ 1.0 = Correction factor at step 100: nearly identity.
Line 86: v_hat = v_rul / (1 - beta2 ** step_count)

Bias-corrected second moment.

Line 87: update = lr * m_hat / (np.sqrt(v_hat) + eps_adam) + lr * weight_decay * theta_rul

The full AdamW step in one line. Two contributions are added: (1) the bias-corrected, variance-normalised gradient step, and (2) the DECOUPLED weight-decay shrink. The decoupling — applying wd as a SEPARATE additive term, not as part of the gradient — is what differentiates AdamW from plain Adam-with-L2.

EXECUTION STATE
📚 np.sqrt(x) = NumPy element-wise square root. For scalars returns √x. For ndarrays returns sqrt of each element. Used here to convert v_hat (uncentered variance estimate) into a standard-deviation-like denominator.
→ np.sqrt examples = np.sqrt(4) → 2.0 np.sqrt(np.array([1, 4, 9])) → [1., 2., 3.] np.sqrt(0) → 0.0 (then ε saves us from /0)
Term 1: lr * m_hat / (np.sqrt(v_hat) + eps_adam) = Adaptive step. m_hat is the bias-corrected mean gradient; sqrt(v_hat) is the bias-corrected RMS gradient. Their ratio is dimensionless ≈ ±1, so the effective per-parameter step is ≈ lr regardless of the gradient's scale.
→ eps_adam = 1e-8 = Numerical floor inside the denominator. If v_hat = 0 (parameter has never seen a gradient), sqrt(v_hat)+ε = 1e-8 — finite, no divide-by-zero.
Term 2: lr * weight_decay * theta_rul = Decoupled weight-decay shrink. Pulls θ toward 0 INDEPENDENTLY of the gradient. Equivalent to applying θ ← (1 - lr·wd)·θ before the gradient step.
→ AdamW vs Adam+L2 = Adam+L2 adds wd·θ to the GRADIENT, then divides by sqrt(v_hat) — so the effective decay ends up scaled by 1/√v_hat (uneven across parameters). AdamW skips the divide for the wd term, keeping decay uniform.
→ numerical trace (step 1, g_theta = +1.0 post-clip) = m_hat = 0.1·1.0 / (1 - 0.9¹) = 0.1 / 0.1 = 1.0; v_hat = 0.001·1.0 / (1 - 0.999¹) = 0.001 / 0.001 = 1.0; Term 1 = 1e-3 · 1.0 / (sqrt(1.0) + 1e-8) ≈ 1.0e-3; Term 2 = 1e-3 · 1e-4 · 0.5 = 5.0e-8; update ≈ 1.00005e-3
update (illustrative) = Step 1, post-clip: ≈ +1.00005e-3
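Two checks on the AdamW line, as a sketch. First, a stand-alone `adamw_step` (a hypothetical helper with the same algebra as the demo's update line) reproduces the step-1 update of about 1.00005e-3. Second, a first-order comparison of where weight decay ends up: because m̂ is linear in the gradient, an L2 term wd·θ added to the gradient reaches the update as roughly lr·wd·θ/(√v̂ + ε), whereas AdamW applies lr·wd·θ directly. The two gradient scales are made-up values.

```python
import numpy as np


def adamw_step(theta, grad, m, v, t, lr=1e-3, wd=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    """One AdamW step with bias correction and decoupled weight decay."""
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** t)           # undo the zero-initialisation bias
    v_hat = v / (1 - b2 ** t)
    update = lr * m_hat / (np.sqrt(v_hat) + eps) + lr * wd * theta
    return theta - update, m, v, update


# Check 1: the demo's first step (post-clip gradient +1.0, theta = 0.5).
theta_new, _, _, update = adamw_step(0.5, 1.0, 0.0, 0.0, t=1)

# Check 2: where weight decay lands for two parameters with very different gradient RMS.
lr, wd, eps = 1e-3, 1e-4, 1e-8
theta = np.array([0.5, 0.5])
grad_rms = np.array([1e-3, 1.0])           # hypothetical steady-state gradient scales
v_hat = grad_rms ** 2
decay_adamw = lr * wd * theta                                 # uniform across parameters
decay_adam_l2 = lr * (wd * theta) / (np.sqrt(v_hat) + eps)    # divided by the gradient RMS
print(update, decay_adamw, decay_adam_l2)
```

Under Adam+L2 the parameter with the quiet gradient is decayed about 1000× harder than the loud one; under AdamW both shrink by the same 5e-8.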
Line 88: theta_rul = theta_rul - update

Apply the AdamW update.

EXECUTION STATE
theta_rul (illustrative) = 0.5 - 1.00005e-3 ≈ 0.499000
Line 91: shadow_rul = ema_alpha * shadow_rul + (1 - ema_alpha) * theta_rul

EMA shadow update. Same formula as the GABA EMA but a different decay (0.999 here, 0.99 there). Decay closer to 1 = longer memory, smoother shadow.

EXECUTION STATE
shadow_rul (illustrative) = 0.999·0.5 + 0.001·0.499 ≈ 0.499999 — barely moved.
→ eval semantics = End-of-epoch evaluation calls .apply_shadow (swap weights → shadow), runs the test loop, then .restore (swap back). The shadow gives smoother test metrics on partly-converged models.
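The apply_shadow → evaluate → restore cycle is a swap and swap back on the parameter buffers. A NumPy sketch of that pattern: the class `EmaSketch` and its dict-of-arrays interface are hypothetical (the paper's ExponentialMovingAverage callback operates on a PyTorch model), but the method names follow the ones used in this section.

```python
import numpy as np


class EmaSketch:
    """Hypothetical EMA helper over a dict of NumPy parameter arrays."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.copy() for k, v in params.items()}   # starts equal to the live weights
        self.backup = None

    def update(self, params):
        for k, v in params.items():
            self.shadow[k] = self.decay * self.shadow[k] + (1 - self.decay) * v

    def apply_shadow(self, params):
        self.backup = {k: v.copy() for k, v in params.items()}   # keep live weights aside
        for k in params:
            params[k][...] = self.shadow[k]                       # overwrite in place

    def restore(self, params):
        for k in params:
            params[k][...] = self.backup[k]
        self.backup = None


params = {"theta_rul": np.array([0.5])}
ema = EmaSketch(params)
params["theta_rul"][...] = 0.49899995        # live weight after the demo's AdamW step
ema.update(params)
ema.apply_shadow(params)
evaluated_with = params["theta_rul"].copy()  # what the end-of-epoch evaluator would see
ema.restore(params)
print(evaluated_with, params["theta_rul"])
```

The evaluator sees the shadow value (≈ 0.499999) while training resumes from the untouched live weight; forgetting the restore would silently continue training from the shadow.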
Line 94: print(f"epoch={epoch}, lr={lr:.4e}")

Header line.

Line 95: print(f"y_pred = {y_pred.round(2).tolist()}")

Predictions.

Line 96: print(f"w(y) = {w.round(2).tolist()}")

Per-sample weights.

Line 97: print(f"L_rul = {L_rul:.4f} L_health = {L_health:.4f}")

Inner-axis losses.

Line 98: print(f"lambda* = ({lam[0]:.4f}, {lam[1]:.4f})")

OUTER-axis weights.

Line 99: print(f"L_total = {L_total:.4f}")

Composed GRACE loss.

Line 100: print(f"g_theta = {g_theta:+.6f} (post-clip)")

Final gradient on theta_rul after clipping.

Line 101: print(f"theta_rul = {theta_rul:.6f} (post-AdamW)")

Updated parameter.

Line 102: print(f"shadow_rul = {shadow_rul:.6f} (post-EMA)")

EMA shadow used for end-of-epoch evaluation.

EXECUTION STATE
Final output (illustrative) =
epoch=50, lr=1.0000e-03
y_pred       = [25.0, 75.0, 12.0, 98.0]
w(y)         = [1.84, 1.36, 1.96, 1.2]
L_rul        = 45.2100    L_health = 0.8999
lambda*      = (0.0477, 0.9523)
L_total      = 3.0139
g_theta      = +1.000000    (post-clip)
theta_rul    = 0.499000    (post-AdamW)
shadow_rul   = 0.499999    (post-EMA)
  1 | """Minimal GRACE training loop, by hand. One training step, four samples.
  2 |
  3 | Stitches together every stage of UnifiedTrainer.fit() (warmup,
  4 | weighted MSE, GABA, gradient clip, AdamW, EMA) using only NumPy so
  5 | the algebra is fully visible.
  6 | """
  7 |
  8 | import numpy as np
  9 |
 10 | rng = np.random.default_rng(42)
 11 |
 12 |
 13 | # ---------- Mini-batch (4 samples, 1 feature each) ----------
 14 | y_rul     = np.array([20., 80.,  5., 100.], dtype=float)
 15 | hp_target = np.array([2,    1,   2,   0])
 16 | x         = rng.normal(0, 1, size=(4,))
 17 |
 18 |
 19 | # ---------- Two scalar 'parameters': the toy model ----------
 20 | theta_rul = 0.5         # RUL head weight
 21 | theta_h   = 0.5         # Health head weight (unused in update for brevity)
 22 | m_rul, v_rul   = 0.0, 0.0   # AdamW first/second moments
 23 | shadow_rul     = theta_rul   # EMA shadow
 24 | step_count     = 0
 25 |
 26 | base_lr        = 1e-3
 27 | warmup_epochs  = 10
 28 | beta1, beta2   = 0.9, 0.999
 29 | eps_adam       = 1e-8
 30 | weight_decay   = 1e-4
 31 | ema_alpha      = 0.999
 32 |
 33 |
 34 | def lr_at(epoch):
 35 |     """Linear warmup: 0.1*lr_0 -> lr_0 over warmup_epochs."""
 36 |     if epoch < warmup_epochs:
 37 |         return base_lr * (0.1 + 0.9 * epoch / warmup_epochs)
 38 |     return base_lr
 39 |
 40 |
 41 | def w_failure(y, max_rul=125.0):
 42 |     return 1.0 + np.clip(1.0 - y / max_rul, 0.0, 1.0)
 43 |
 44 |
 45 | def gaba_lambda(g_rul, g_health, eps_floor=0.05):
 46 |     """K=2 closed form + floor + renorm. Skip EMA for the demo (assume converged)."""
 47 |     S   = g_rul + g_health
 48 |     raw = np.array([g_health / S, g_rul / S])
 49 |     flo = np.maximum(raw, eps_floor)
 50 |     return flo / flo.sum()
 51 |
 52 |
 53 | # ---------- ONE training step ----------
 54 | epoch = 50
 55 | lr    = lr_at(epoch)
 56 |
 57 | # 1. Forward: toy linear model y_pred = theta_rul * x * 100 + 50
 58 | y_pred = theta_rul * x * 100.0 + 50.0
 59 | y_pred = np.array([25., 75., 12., 98.])   # demo override: clean residuals [+5, -5, +7, -2]
 60 | # 2. Inner losses
 61 | w           = w_failure(y_rul)
 62 | residual_sq = (y_pred - y_rul) ** 2
 63 | L_rul       = (w * residual_sq).mean()
 64 | L_health    = 0.8999                                  # held constant for the demo
 65 |
 66 | # 3. Outer: GABA closed form (use measured ratio from chapter 18 §1)
 67 | g_rul, g_h = 41.5, 0.082
 68 | lam        = gaba_lambda(g_rul, g_h)
 69 | L_total    = lam[0] * L_rul + lam[1] * L_health
 70 |
 71 | # 4. Backward: scalar grad of L_total w.r.t. theta_rul
 72 | #    dL_total/dtheta_rul = lam[0] * d/dtheta_rul ((1/N) sum w*(100*theta*x + 50 - y_rul)^2)
 73 | #                        = lam[0] * (2/N) * sum( w * (y_pred - y_rul) * 100 * x )
 74 | g_theta    = lam[0] * (2.0 / 4) * (w * (y_pred - y_rul) * 100.0 * x).sum()
 75 |
 76 | # 5. Gradient clip (clip if |g| > 1.0)
 77 | g_norm     = abs(g_theta)
 78 | if g_norm > 1.0:
 79 |     g_theta = g_theta / g_norm
 80 |
 81 | # 6. AdamW update
 82 | step_count += 1
 83 | m_rul = beta1 * m_rul + (1 - beta1) * g_theta
 84 | v_rul = beta2 * v_rul + (1 - beta2) * g_theta ** 2
 85 | m_hat = m_rul / (1 - beta1 ** step_count)
 86 | v_hat = v_rul / (1 - beta2 ** step_count)
 87 | update      = lr * m_hat / (np.sqrt(v_hat) + eps_adam) + lr * weight_decay * theta_rul
 88 | theta_rul   = theta_rul - update
 89 |
 90 | # 7. EMA shadow
 91 | shadow_rul  = ema_alpha * shadow_rul + (1 - ema_alpha) * theta_rul
 92 |
 93 |
 94 | print(f"epoch={epoch}, lr={lr:.4e}")
 95 | print(f"y_pred       = {y_pred.round(2).tolist()}")
 96 | print(f"w(y)         = {w.round(2).tolist()}")
 97 | print(f"L_rul        = {L_rul:.4f}    L_health = {L_health:.4f}")
 98 | print(f"lambda*      = ({lam[0]:.4f}, {lam[1]:.4f})")
 99 | print(f"L_total      = {L_total:.4f}")
100 | print(f"g_theta      = {g_theta:+.6f}    (post-clip)")
101 | print(f"theta_rul    = {theta_rul:.6f}    (post-AdamW)")
102 | print(f"shadow_rul   = {shadow_rul:.6f}    (post-EMA)")
What the toy reveals. AdamW's update has two contributions: the bias-corrected gradient step and the decoupled weight-decay shrink. A single EMA update barely moves the shadow (0.5 → 0.499999), but the updates compound: with α = 0.999 the shadow follows the live weights with a lag of roughly α/(1 − α) ≈ 1000 steps, smoothing out late-training noise. Gradient clipping rescales the global gradient vector while preserving its direction.
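The "~1000-step lag" can be made precise. For a weight drifting at a constant rate, an EMA shadow settles to a fixed lag of α/(1 − α) steps, which is 999 for α = 0.999. A short simulation confirms it; the drift rate and step count are arbitrary choices.

```python
def ema_lag_in_steps(alpha, slope=1e-4, n_steps=20_000):
    """Simulate a live weight rising by `slope` per step and an EMA shadow tracking it.

    Returns how many steps the shadow trails the live weight at the end."""
    theta, shadow = 0.0, 0.0
    for _ in range(n_steps):
        theta += slope                                   # live weight after this step
        shadow = alpha * shadow + (1 - alpha) * theta    # same update as the demo's EMA line
    return (theta - shadow) / slope


for alpha in (0.99, 0.999):
    print(alpha, round(ema_lag_in_steps(alpha), 3), round(alpha / (1 - alpha), 3))
```

The same formula explains why the GABA EMA (decay 0.99) reacts within about 100 steps while the weight shadow (0.999) averages over about 1000.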

PyTorch: The Paper's UnifiedTrainer.fit()

The same pipeline, but now driven through UnifiedTrainer. Roughly 30 lines of caller code wire up every callback at the right cadence. The actual fit loop lives in grace/training/trainer.py; what you see below is the surface that the paper's experiments/phase1_cmapss.py calls into.

Production GRACE training in 30 lines (grace_train_pipeline.py)
Line 1: docstring

Names the contract. The train_grace function is what every Phase-1 experiment calls — see grace/experiments/phase1_cmapss.py:run_single_experiment for the production caller.

Line 9: import torch

Core PyTorch.

Line 10: import torch.nn as nn

nn.CrossEntropyLoss is the only direct nn use here; everything else lives in paper modules.

Line 11: import torch.optim as optim

AdamW + ReduceLROnPlateau scheduler.

EXECUTION STATE
📚 torch.optim = PyTorch optimisers and learning-rate schedulers. AdamW = Adam with decoupled weight decay (Loshchilov & Hutter, 2019).
12from torch.utils.data import DataLoader

Standard PyTorch mini-batch iterator. Shuffles, batches, and collates the underlying Dataset.

EXECUTION STATE
📚 DataLoader(dataset, batch_size, shuffle, num_workers) = Returns an iterable of batches. Each batch is a 4-tuple (seq, rul, health, uid) when wrapped via MTLDatasetWrapper.
14from grace.core.loss_registry import get_loss

Factory that returns a registered MTL loss variant by name. We ask for 'gaba'.

EXECUTION STATE
📚 get_loss(name, **kwargs) = Source: grace/core/loss_registry.py:39. Returns a pre-configured nn.Module. Available names: gaba, fixed_050, fixed_075, amnl_fixed, amnl_v7, uncertainty, gradnorm, dwa, pcgrad, cagrad.
15from grace.core.weighted_mse import moderate_weighted_mse_loss

Failure-biased MSE — the inner axis. Pure function, no state.

16from grace.models.backbone import UnifiedBackbone

Shared CNN-BiLSTM-Attention backbone. Source: grace/models/backbone.py:UnifiedBackbone.

EXECUTION STATE
→ architecture = Conv1d → BiLSTM → MultiHeadAttention → FC. Returns the shared latent that BOTH heads consume.
17from grace.models.dual_task_model import DualTaskModel

Wraps the backbone with two heads (RUL regression, health classification).

18from grace.models.model_configs import get_model_config

Named architecture configs: 'cmapss', 'ncmapss_20feat', 'ncmapss_32feat'. Each returns a ModelConfig dataclass with input_size, hidden_size, etc.

19from grace.training.trainer import UnifiedTrainer

The orchestrator that calls every callback in the right order. Source: grace/training/trainer.py:UnifiedTrainer.

EXECUTION STATE
📚 UnifiedTrainer(model, mtl_loss, rul_criterion, health_criterion, optimizer, scheduler, device, config) = .fit(train_loader, test_loader) runs the training loop (up to 500 epochs) with all callbacks. Returns {history, best_metrics, best_epoch, best_rmse}.
20from grace.training.seed_utils import set_seed

Pin the PRNGs (torch + numpy + python random + cudnn deterministic). Reproducibility helper.

23def train_grace(train_ds, test_ds, dataset_name='FD002', seed=42):

One-call training surface. Inputs: two CMAPSSDataset instances. Output: best-metrics dict + history.

EXECUTION STATE
⬇ input: train_ds = CMAPSSDataset wrapped via MTLDatasetWrapper. Yields (seq, rul, health, uid).
⬇ input: test_ds = Same shape as train_ds, evaluation set.
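⬇ input: dataset_name = Names the C-MAPSS subset. The 30-line body below never reads it, so per-condition normalisation (turned on for FD002/FD004) must already be applied in the datasets passed in.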
⬇ input: seed = 42 = Single seed run. Production averages 5 seeds: [42, 123, 456, 789, 1024].
⬆ returns = Dict with keys 'history' (per-epoch lists), 'best_metrics' (final eval), 'best_epoch', 'best_rmse'.
25set_seed(seed)

Pin everything. Sets torch.manual_seed, np.random.seed, random.seed, and torch.backends.cudnn.deterministic = True.

26device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Pick the compute device. Use GPU if a CUDA-capable card is visible; otherwise fall back to CPU. The chosen device is later passed to model.to(device) and to every per-batch tensor.to(device).

EXECUTION STATE
📚 torch.device(string) = PyTorch class: lightweight handle representing a compute device. Stores type ('cuda' / 'cpu' / 'mps') and optional index. Used everywhere PyTorch needs to know WHERE a tensor lives.
📚 torch.cuda.is_available() = Probes the CUDA driver. Returns True iff a GPU + the matching CUDA driver are both installed AND PyTorch was built with CUDA support. Cheap to call; safe at module import.
→ ternary expression = Python conditional: VALUE_IF_TRUE if CONDITION else VALUE_IF_FALSE. Here returns 'cuda' or 'cpu' depending on the probe.
→ why it matters = Production runs use a single A100 / H100: 12-second epochs on FD002. Pure CPU works but is ~50× slower (≈10 min per epoch) because the BiLSTM does not vectorise across time steps.
⬆ result: device = device(type='cuda') on a GPU box (the index stays None unless a specific card such as 'cuda:0' is requested); device(type='cpu') on a laptop.
28train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)

Mini-batch iterator with shuffling.

EXECUTION STATE
⬇ batch_size=256 = Paper default for C-MAPSS. ~70 batches per epoch on FD001, ~110 on FD002.
⬇ shuffle=True = Reshuffle order each epoch. Critical for SGD convergence.
29test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)

Eval loader. shuffle=False — the evaluator needs unit-id ordering for last-cycle scoring.

31mc = get_model_config('cmapss')

Load the named architecture config. Returns a dataclass with input_size=14, hidden_size=256, cnn_channels=(32,64,32), num_attn_heads=4, fc_dims=(256,128), dropout=0.3.

EXECUTION STATE
mc.input_size = 14 — number of selected sensors after the paper's feature filter.
mc.hidden_size = 256 — BiLSTM hidden size. Paper's 'h=256' setting.
mc.num_attn_heads = 4 — multi-head self-attention heads.
mc.dropout = 0.3 — applied after BiLSTM and FC layers.
32backbone = UnifiedBackbone(...)

Shared CNN-BiLSTM-Attention trunk. Roughly 1.7M parameters. Returns a (B, hidden_size) latent per window.

EXECUTION STATE
📚 UnifiedBackbone(input_size, hidden_size, cnn_channels, num_attn_heads, fc_dims, dropout, use_attention, use_residual) = From grace/models/backbone.py. Architecture: Conv1d ×3 → BiLSTM → MultiHeadAttention → FC. use_attention=True and use_residual=True are the paper defaults.
38model = DualTaskModel(backbone, num_health_states=3, dropout=0.3)

Dual-head wrapper. Adds rul_head (Linear → 1) and health_head (Linear → 3) on top of the shared backbone.

EXECUTION STATE
→ forward signature = model(seq) → (rul_pred (B, 1), hp_logits (B, 3))
→ get_shared_params() = Returns the backbone-only parameter list — what GABA differentiates against.
40mtl_loss = get_loss('gaba', beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2)

Instantiate the OUTER controller. Returns a GABALoss nn.Module with EMA buffer, step counter, and the four-stage closed form.

EXECUTION STATE
⬇ beta=0.99 = EMA smoothing — see chapter 21 §2 for the four-stage trace.
⬇ warmup_steps=100 = First 100 mini-batches use uniform 1/K weights, NOT the closed form. Stops the controller from over-reacting before the EMA has any history.
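⬇ min_weight=0.05 = Floor on each task weight: no λ may drop below 0.05 (5% per task), which prevents either task from collapsing to zero weight.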
⬇ n_tasks=2 = K=2: RUL and health.
43params = list(model.parameters()) + list(mtl_loss.parameters())

Build the optimiser parameter list. Concatenates the model's learnable tensors (weights, biases, BatchNorm γ/β) with any learnable parameters in the MTL loss. Critical for Uncertainty (Kendall et al.) which has learnable σ_rul, σ_health; harmless no-op for GABA which has none.

EXECUTION STATE
📚 nn.Module.parameters() = Generator that yields every learnable Tensor registered on the module (and its submodules, recursively). Yields nn.Parameter objects — Tensors with requires_grad=True that the optimiser owns.
📚 list(generator) = Python builtin: materialise a generator into a list. Needed here because parameters() yields lazily — list() forces full enumeration so the result can be concatenated with +.
+ (list concatenation) = Python list operator: returns a NEW list with all elements of the left followed by all of the right. Different from element-wise + on tensors.
→ why concatenate? = AdamW takes a single iterable of parameters. We hand it the union of model params + MTL-loss params so a single optimizer.step() updates both.
→ forgetting this is a silent footgun = If you write optim.AdamW(model.parameters()) without the mtl_loss part, the σ scalars in Uncertainty / GradNorm never receive a gradient step → algorithm degrades to fixed-σ → published numbers cannot be reproduced. See Pitfall 2 below.
⬆ result: params (list of ~50-100 nn.Parameter tensors) = [backbone.conv1.weight, backbone.conv1.bias, backbone.lstm.weight_ih_l0, ..., rul_head.weight, rul_head.bias, health_head.weight, health_head.bias] (plus mtl_loss params if any)
44optimizer = optim.AdamW(params, lr=1e-3, weight_decay=1e-4)

Instantiate the AdamW optimiser with the paper's defaults. Decoupled weight decay (Loshchilov & Hutter, 2019). The trainer later schedules weight_decay down at epoch 100 / 200 via the adaptive_weight_decay flag.

EXECUTION STATE
📚 optim.AdamW(params, lr, betas=(0.9,0.999), eps=1e-8, weight_decay) = PyTorch optimiser. Per-parameter adaptive learning rate via 1st/2nd moment estimates with bias correction, plus DECOUPLED L2-style weight decay. Update rule: θ ← θ - lr·m̂/(√v̂+ε) - lr·wd·θ. The wd term is added independently of the gradient — that decoupling is what fixes Adam-with-L2's scale-mismatch bug.
⬇ arg 1: params = The list of learnable Tensors built on line 43. AdamW allocates one m and one v buffer per parameter (~3.4M floats of optimiser state for a 1.7M-param model).
⬇ arg 2: lr = 1e-3 = Base learning rate. Standard transformer-era default. Linearly warmed up from 1e-4 over the first 10 epochs by the trainer.
⬇ arg 3: weight_decay = 1e-4 = L2-style shrink coefficient applied OUTSIDE the gradient term. Halved at epoch 100, divided by 10 at epoch 200 by the adaptive schedule.
→ defaults left at PyTorch values = betas=(0.9, 0.999): momentum coefficients for m and v. eps=1e-8: numerical floor under the sqrt(v) denominator. We accept these so they don't need explicit passing.
→ AdamW vs Adam (one-line) = Adam + L2: g ← ∇L + wd·θ, then θ ← θ - lr·m̂/(√v̂+ε), so the wd term is folded into m̂ and v̂ and rescaled by the variance. AdamW: θ ← θ - lr·m̂/(√v̂+ε) - lr·wd·θ, so wd is applied independently of the gradient statistics.
⬆ result: optimizer = AdamW instance with .step(), .zero_grad(), .state_dict(). Holds m and v in optimizer.state[param].
45scheduler = optim.lr_scheduler.ReduceLROnPlateau(

Build the LR scheduler. ReduceLROnPlateau watches a metric (validation rmse_last) and halves the learning rate whenever that metric has stalled for more than `patience` epochs. The opening parenthesis here begins a multi-line constructor call; arguments continue on line 46.

EXECUTION STATE
📚 optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, min_lr) = PyTorch scheduler. Wraps the optimiser. Call scheduler.step(metric) ONCE PER EPOCH with the validation metric. The scheduler tracks the best value seen; once more than `patience` consecutive calls pass without improvement, it multiplies every param-group's lr by `factor`, floored at `min_lr`.
46optimizer, mode='min', factor=0.5, patience=30, min_lr=5e-6,

Constructor arguments for ReduceLROnPlateau, on the continuation line.

EXECUTION STATE
⬇ arg 1: optimizer = The AdamW instance from line 44. The scheduler mutates each param-group's 'lr' key in-place; AdamW reads that key on the next .step().
⬇ arg 2: mode='min' = Smaller-is-better. Use 'max' for accuracy, 'min' for loss/error metrics like RMSE. Determines whether the scheduler considers a metric improved.
⬇ arg 3: factor=0.5 = Multiplicative LR reduction. lr ← lr · 0.5 = halve. Standard for plateau schedulers; a more aggressive factor (0.1) is too jumpy on this problem.
⬇ arg 4: patience=30 = Number of non-improving epochs tolerated; the LR is cut once this count is exceeded. 30 is loose enough to ride out a noisy plateau without early panic.
⬇ arg 5: min_lr=5e-6 = Floor. After many reductions, lr stops at 5e-6, so AdamW keeps taking small non-zero steps, the EMA shadow can keep tracking, and parameters don't freeze.
→ trace example = Initial lr=1e-3; after 1st plateau: lr=5e-4; after 2nd plateau: lr=2.5e-4; ... (halving after each plateau of more than 30 stagnant epochs) ...; the 8th reduction would give 3.9e-6, so lr is floored at 5e-6.
47)

Closing parenthesis of the ReduceLROnPlateau constructor. The scheduler instance is now bound to the local name `scheduler` and ready for per-epoch .step(rmse_last) calls.

49trainer = UnifiedTrainer(

Wire everything. UnifiedTrainer holds the model, the MTL loss, the per-sample RUL+health criteria, the optimiser, the scheduler, the device, and the config dict. fit() does the rest.

EXECUTION STATE
📚 UnifiedTrainer(model, mtl_loss, rul_criterion, health_criterion, optimizer, scheduler, device, config) = Source: grace/training/trainer.py:UnifiedTrainer (291 lines). Holds every component the loop needs and exposes a single .fit(train_loader, test_loader) entry point. Owns the per-batch / per-epoch cadence split.
→ rul_criterion = moderate_weighted_mse_loss = Paper's failure-biased MSE, passed by reference (it's a pure function, no state).
→ health_criterion = nn.CrossEntropyLoss() = 3-class CE on the health head — instantiated inline since it has no learnable parameters.
50model=model, mtl_loss=mtl_loss,

First two keyword arguments to UnifiedTrainer.__init__. The trainer stores references so that its per-batch loop (_train_epoch()) can call model(seq) and mtl_loss(L_rul, L_health, shared_params).

EXECUTION STATE
model=model = The DualTaskModel instance (~1.7M params). Trainer calls model.train() / model.eval() to flip dropout + BatchNorm modes.
mtl_loss=mtl_loss = The GABALoss instance. Trainer calls it once per batch with (L_rul, L_health, shared_params) → returns the scalar L_total ready for .backward().
51rul_criterion=moderate_weighted_mse_loss,

Per-sample RUL loss — passed as a callable, not an instance, because it is a pure function with no state.

EXECUTION STATE
📚 moderate_weighted_mse_loss(y_pred, y_true, max_rul=125) = Source: grace/core/weighted_mse.py. Computes (1/N)·Σ w(y)·(y_pred - y_true)² where w(y) = 1 + clip(1 - y/125, 0, 1). Returns a scalar tensor. Differentiable end-to-end.
52health_criterion=nn.CrossEntropyLoss(),

Per-sample health loss. Standard PyTorch CE module — instantiated here (parentheses) because it is an nn.Module, not a function.

EXECUTION STATE
📚 nn.CrossEntropyLoss(weight=None, reduction='mean') = PyTorch loss module: combines log_softmax + NLLLoss in one numerically-stable op. Forward signature: criterion(logits (B, C), targets (B,)) → scalar. C=3 classes here (healthy / degrading / critical).
→ why instantiate ()? = nn.CrossEntropyLoss is a class. The () creates an instance whose .__call__ implements the loss. Compare to moderate_weighted_mse_loss above — that is already a function, so no ().
53optimizer=optimizer, scheduler=scheduler, device=device,

Hand the trainer the already-configured optimiser, scheduler, and device. Trainer calls optimizer.step() per batch and scheduler.step(metric) per epoch.

EXECUTION STATE
optimizer=optimizer = AdamW instance from line 44. Trainer calls .zero_grad() before each backward and .step() after gradient clipping.
scheduler=scheduler = ReduceLROnPlateau instance from line 45. Called once per epoch with the validation rmse_last.
device=device = torch.device. Used to .to(device) the model and every batch.
54config={'epochs': 500, 'patience': 80, 'grad_clip': 1.0,

Training-loop hyperparameters bundled as a dict. Continued on lines 55-57. The trainer dereferences keys via config.get('key', default) for graceful fallback.

EXECUTION STATE
epochs=500 = Maximum budget. Early stopping usually fires earlier (epoch 130-230 in practice on FD002).
patience=80 = Early stopping patience. No-improvement epoch budget — if rmse_last has not improved for 80 epochs, stop.
grad_clip=1.0 = Threshold for torch.nn.utils.clip_grad_norm_(parameters, max_norm=1.0). Applied AFTER backward, BEFORE optimizer.step().
55'use_ema': True, 'ema_decay': 0.999, 'warmup_epochs': 10,

EMA shadow + 10-epoch linear LR warmup config.

EXECUTION STATE
use_ema=True = Enable the ExponentialMovingAverage callback. Maintains a shadow copy of every parameter; swap in for evaluation.
ema_decay=0.999 = Per-step decay: shadow ← 0.999·shadow + 0.001·θ. Effective ~1000-step memory window.
warmup_epochs=10 = Linear LR ramp budget — 1e-4 at epoch 0 climbs to 1e-3 at epoch 10, then ReduceLROnPlateau owns LR.
56'lr': 1e-3, 'adaptive_weight_decay': True,

Base learning rate + GRACE's adaptive weight-decay schedule (drops to 0.5× at epoch 100, 0.1× at epoch 200).

EXECUTION STATE
adaptive_weight_decay=True = GRACE-only addition over GABA. Trainer.fit() at trainer.py:131-134 reduces wd in three stages.
57'initial_weight_decay': 1e-4},

Starting weight-decay value before the schedule kicks in. The closing } here ends the config dict.

EXECUTION STATE
initial_weight_decay = 1e-4 = Baseline AdamW wd. Halved to 5e-5 at epoch 100; cut to 1e-5 at epoch 200. Late training benefits from less L2-style shrink as the loss landscape flattens.
58)

Closing parenthesis of the UnifiedTrainer(...) constructor call. The trainer instance is now bound to the local name `trainer` and ready to receive .fit().

60return trainer.fit(train_loader, test_loader)

Run the training loop for up to 500 epochs (early stopping usually ends it sooner). Returns a dict with per-epoch history (loss, rmse_last, NASA, HA, λ_rul, g_rul, ...) and the best metrics.

EXECUTION STATE
📚 trainer.fit(train_loader, test_loader) = trainer.py:106. For each epoch: warmup, weight-decay schedule, _train_epoch, EMA-eval, scheduler step, history append, early-stopping check. ~30 minutes on a single GPU for FD002.
⬇ arg 1: train_loader = DataLoader from line 28. Iterated each epoch — yields ~110 mini-batches of 256 samples on FD002.
⬇ arg 2: test_loader = DataLoader from line 29. Iterated end-of-epoch through the EMA-shadow weights, NOT shuffled.
⬆ return value (illustrative) = { 'history': {...}, 'best_metrics': {'rmse_last': 7.72, 'nasa_score': 223.4, 'health_accuracy': 97.22, ...}, 'best_epoch': 168, 'best_rmse': 7.72 }
1"""End-to-end GRACE training with the paper's UnifiedTrainer.
2
3Boots a DualTaskModel + GABA + WMSE + AdamW + ReduceLROnPlateau + EMA
4+ early stopping + checkpointer in roughly 30 lines. The actual fit
5loop lives in grace/training/trainer.py:UnifiedTrainer.fit (291 lines)
6— shown here is the surface that drives it.
7"""
8
9import torch
10import torch.nn as nn
11import torch.optim as optim
12from torch.utils.data import DataLoader
13
14from grace.core.loss_registry      import get_loss
15from grace.core.weighted_mse       import moderate_weighted_mse_loss
16from grace.models.backbone         import UnifiedBackbone
17from grace.models.dual_task_model  import DualTaskModel
18from grace.models.model_configs    import get_model_config
19from grace.training.trainer        import UnifiedTrainer
20from grace.training.seed_utils     import set_seed
21
22
23def train_grace(train_ds, test_ds, dataset_name="FD002", seed=42):
24    """Single-seed GRACE training run. Returns best metrics."""
25    set_seed(seed)
26    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
28    train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
29    test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False)
30
31    mc       = get_model_config("cmapss")
32    backbone = UnifiedBackbone(
33        input_size=mc.input_size, hidden_size=mc.hidden_size,
34        cnn_channels=mc.cnn_channels, num_attn_heads=mc.num_attn_heads,
35        fc_dims=mc.fc_dims, dropout=mc.dropout,
36        use_attention=True, use_residual=True,
37    )
38    model    = DualTaskModel(backbone, num_health_states=3, dropout=0.3)
39
40    mtl_loss = get_loss("gaba", beta=0.99, warmup_steps=100,
41                        min_weight=0.05, n_tasks=2)
42
43    params    = list(model.parameters()) + list(mtl_loss.parameters())
44    optimizer = optim.AdamW(params, lr=1e-3, weight_decay=1e-4)
45    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
46        optimizer, mode="min", factor=0.5, patience=30, min_lr=5e-6,
47    )
48
49    trainer = UnifiedTrainer(
50        model=model, mtl_loss=mtl_loss,
51        rul_criterion=moderate_weighted_mse_loss,
52        health_criterion=nn.CrossEntropyLoss(),
53        optimizer=optimizer, scheduler=scheduler, device=device,
54        config={"epochs": 500, "patience": 80, "grad_clip": 1.0,
55                "use_ema": True, "ema_decay": 0.999, "warmup_epochs": 10,
56                "lr": 1e-3, "adaptive_weight_decay": True,
57                "initial_weight_decay": 1e-4},
58    )
59
60    return trainer.fit(train_loader, test_loader)
Order of construction matters. The MTL loss must be instantiated before the optimiser, because list(mtl_loss.parameters()) is concatenated into the optimiser's parameter list (line 43). Some MTL variants, such as Uncertainty and GradNorm, have learnable parameters of their own. Forgetting to include them in the optimiser silently disables learning of those scalars and produces a different algorithm than the paper's.
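The per-epoch LR warmup and weight-decay schedule requested through the config dict can be sketched as two small functions. This is a reconstruction under assumptions, not the code at trainer.py:131-134: it uses the numbers quoted in this section (a linear ramp from 1e-4 at epoch 0 to 1e-3 at the last of the first 10 epochs; weight decay 1e-4, then 5e-5 from epoch 100, then 1e-5 from epoch 200) and 0-based epochs, and it uses plain dicts in place of optimizer.param_groups (which in PyTorch is likewise a list of dicts carrying 'lr' and 'weight_decay'). The exact interpolation and epoch indexing inside UnifiedTrainer.fit() may differ.

```python
def warmup_lr(epoch, base_lr=1e-3, start_lr=1e-4, warmup_epochs=10):
    """Linear ramp over epochs 0..warmup_epochs-1, from start_lr up to base_lr.

    Returns None afterwards, meaning "leave the LR to ReduceLROnPlateau".
    """
    if epoch >= warmup_epochs:
        return None
    frac = epoch / max(warmup_epochs - 1, 1)
    return start_lr + frac * (base_lr - start_lr)


def adaptive_weight_decay(epoch, initial_wd=1e-4):
    """Three stages: 1x before epoch 100, 0.5x until epoch 200, 0.1x afterwards."""
    if epoch < 100:
        return initial_wd
    if epoch < 200:
        return 0.5 * initial_wd
    return 0.1 * initial_wd


def apply_epoch_schedules(param_groups, epoch):
    """Write scheduled values into optimizer-style param groups (a list of dicts)."""
    lr = warmup_lr(epoch)
    wd = adaptive_weight_decay(epoch)
    for group in param_groups:
        if lr is not None:          # after warmup, keep whatever the scheduler set
            group["lr"] = lr
        group["weight_decay"] = wd


groups = [{"lr": 1e-3, "weight_decay": 1e-4}]
for epoch in (0, 5, 9, 10, 100, 200):
    apply_epoch_schedules(groups, epoch)
    print(epoch, groups[0])
```

Passing optimizer.param_groups instead of the toy list would write the same values into a real AdamW instance. After warmup the function leaves 'lr' untouched, so ReduceLROnPlateau's reductions are not overwritten.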

Production Hyperparameter Defaults

Every default below comes from grace/experiments/config.py:ExperimentConfig and is what produced the published 7.72 RMSE on FD002. Section 22·2 shows the search protocol that justifies each choice.

| Hyperparameter | Value | Why this value |
| --- | --- | --- |
| epochs (max budget) | 500 | Loose cap; early stopping fires at 130-230 in practice. |
| batch_size | 256 | Fits in GPU memory; large enough to reduce per-batch λ noise. |
| lr (AdamW base) | 1e-3 | Standard transformer-era default; warmup ramps from 1e-4. |
| weight_decay (initial) | 1e-4 | Mild regularisation; halved at epoch 100, /10 at 200. |
| β₁, β₂ (AdamW) | 0.9, 0.999 | PyTorch defaults; not tuned per dataset. |
| grad_clip max_norm | 1.0 | Defends against BiLSTM gradient spikes. |
| EMA decay (ema_alpha) | 0.999 | 1000-step memory; smooths late-training noise. |
| warmup_epochs | 10 | Linear ramp 1e-4 → 1e-3. |
| scheduler factor / patience | 0.5 / 30 | Halve LR after 30 stagnant epochs. |
| min_lr (scheduler floor) | 5e-6 | Keeps the AdamW step size non-zero in late training. |
| GABA β (EMA) | 0.99 | Outer-axis 100-step memory. |
| GABA min_weight (floor) | 0.05 | 5% per task; prevents collapse. |
| GABA warmup_steps | 100 | Uniform weights for the first 100 mini-batches. |
| WMSE max_rul | 125 | Standard piecewise-linear RUL cap on C-MAPSS. |
| EarlyStopping patience | 80 | No-improvement budget (epochs). |
| seeds (reproducibility) | [42, 123, 456, 789, 1024] | 5-seed averages; published numbers are means. |
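The WMSE max_rul row is where the failure bias switches off. A plain-Python sketch of the formula quoted earlier for moderate_weighted_mse_loss, (1/N)·Σ w(y)·(y_pred − y_true)² with w(y) = 1 + clip(1 − y/125, 0, 1), makes the effect concrete. The function names here are illustrative and the repo version operates on tensors, so treat this as a reading aid rather than the implementation.

```python
def failure_weight(y_true, max_rul=125.0):
    """w(y) = 1 + clip(1 - y/max_rul, 0, 1): 2.0 at failure, 1.0 at or above max_rul."""
    return 1.0 + min(max(1.0 - y_true / max_rul, 0.0), 1.0)


def weighted_mse(y_pred, y_true, max_rul=125.0):
    """Mean of w(y) * squared error over a batch of scalar predictions."""
    if len(y_pred) != len(y_true) or not y_true:
        raise ValueError("need two non-empty sequences of equal length")
    total = sum(failure_weight(t, max_rul) * (p - t) ** 2 for p, t in zip(y_pred, y_true))
    return total / len(y_true)


print(failure_weight(0.0), failure_weight(62.5), failure_weight(125.0))   # 2.0 1.5 1.0
# Same 5-cycle error on both samples; the near-failure one counts double.
print(weighted_mse([5.0, 130.0], [0.0, 125.0]))                            # 37.5
```

An engine at failure (y = 0) contributes twice as much squared error as one with 125 or more cycles left, and the weight falls linearly in between.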

The Same Pipeline In Other Domains

The 14-stage skeleton is not specific to RUL prediction. Every domain that does multi-task supervised learning on a shared backbone uses some subset of it:

| Domain | What replaces WMSE? | What replaces GABA? | What replaces NASA-score eval? |
| --- | --- | --- | --- |
| Self-driving perception | Focal loss for rare classes (pedestrians, motorbikes) | GradNorm (gradient-norm balancing) or PCGrad (gradient surgery) for conflicting tasks | mAP at IoU thresholds + tail-class recall |
| Speech recognition | CTC loss + attention CE | Uncertainty (Kendall): homoscedastic σ scaling | Word error rate on rare-vocabulary subsets |
| Medical segmentation + classification | Dice + focal CE | DWA (Liu et al.): loss-ratio tracking | Per-anatomy Dice + sensitivity at boundary regions |
| Recommender systems | Listwise rank loss + dwell-time MSE | Custom controller: task-priority schedule | NDCG@10 + revenue lift on cold-start cohorts |
| Robotics policy learning | SmoothL1 on torques + binary CE on success | PCGrad: projects out conflicting gradient components | Success rate + safety violation rate |

The skeleton stays. The choices for inner loss, outer controller, and evaluator change with the domain's loss geometry and risk profile. Building a new MTL system rarely needs new pipeline code — it needs new content in three slots.
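The claim that a new MTL system needs new content in three slots, not new pipeline code, can be made concrete with a toy driver. This is a minimal pure-Python sketch, not the paper's UnifiedTrainer: the stage names are hypothetical labels for the 14 stages (8 per batch, 6 per epoch), and the warmup window and the order of the end-of-epoch stages are assumptions based on the description in this section. Every stage is a callable looked up by name, so the loop itself never changes.

```python
def fit(num_epochs, batches, stages, warmup_epochs=10):
    """Toy driver for the two cadences (hypothetical; not UnifiedTrainer.fit itself).

    `stages` maps a stage name to a callable. Missing stages become no-ops, and every
    call is logged in `trace` so the ordering can be inspected.
    """
    trace = []

    def run(name, *args):
        trace.append(name)
        return stages.get(name, lambda *a: None)(*args)

    for epoch in range(num_epochs):
        # Per-epoch stages that prepare the optimiser for this epoch.
        if epoch < warmup_epochs:
            run("lr_warmup", epoch)
        run("weight_decay_schedule", epoch)

        # The eight per-batch stages.
        for raw in batches:
            batch = run("data_fetch", raw)
            out = run("forward", batch)
            losses = run("inner_losses", out, batch)   # slot 1: WMSE + CE in GRACE
            total = run("outer_controller", losses)    # slot 2: GABA in GRACE
            run("backward", total)
            run("grad_clip")
            run("optimizer_step")
            run("ema_update")

        # Per-epoch stages that consume the finished epoch.
        metric = run("evaluate_with_ema")               # slot 3: rmse_last / NASA in GRACE
        run("scheduler_step", metric)
        if run("early_stopping", epoch, metric):
            break
        run("checkpoint", epoch, metric)
    return trace


print(fit(1, ["batch0"], {}))
```

Swapping GABA for GradNorm or DWA in this sketch means passing a different callable under "outer_controller"; nothing else in the loop moves.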

Pitfalls When Wiring The Pipeline

Pitfall 1: scheduler.step() per batch instead of per epoch

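ReduceLROnPlateau counts calls, not epochs: with patience=30 the LR is halved once more than 30 consecutive calls pass without improvement. Calling scheduler.step(loss.item()) inside the per-batch loop, a common transcription error from OneCycleLR-style schedulers that are stepped per batch, shrinks that window from 30 epochs to 30 batches, less than half an epoch on FD001 (~70 batches). Noisy per-batch losses rarely beat the best value seen so far, so the LR is halved again and again and can reach the 5e-6 floor within the first few epochs; the model barely leaves its initial neighbourhood. Symptom: training loss stalls near its first-epoch value.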
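The arithmetic behind this pitfall can be reproduced with a simplified re-implementation of ReduceLROnPlateau's decision rule for mode='min', assuming PyTorch's default relative threshold (1e-4) and cooldown of 0, and omitting its tiny eps guard on LR changes. It is a sketch for counting calls, not a replacement for torch.optim.lr_scheduler.ReduceLROnPlateau. In the worst case, where the metric never improves after the first call, eight halvings take 1e-3 down to the 5e-6 floor.

```python
class PlateauLR:
    """Simplified mirror of ReduceLROnPlateau's decision rule for mode='min' (illustrative)."""

    def __init__(self, lr, factor=0.5, patience=30, min_lr=5e-6, threshold=1e-4):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.min_lr, self.threshold = min_lr, threshold
        self.best = float("inf")
        self.num_bad = 0

    def step(self, metric):
        if metric < self.best * (1 - self.threshold):  # relative improvement
            self.best, self.num_bad = metric, 0
        else:
            self.num_bad += 1
        if self.num_bad > self.patience:               # counted per CALL, not per epoch
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad = 0
        return self.lr


def calls_until_floor(sched, metric=1.0, max_calls=10_000):
    """Feed a metric that never improves; return the call on which lr reaches min_lr."""
    for call in range(1, max_calls + 1):
        if sched.step(metric) <= sched.min_lr:
            return call
    return None


calls = calls_until_floor(PlateauLR(lr=1e-3))
print(calls)                   # 249 = 1 + 8 reductions x 31 stagnant calls
print(f"{calls / 70:.1f}")     # 3.6 epochs if stepped per batch on FD001 (~70 batches)
```

Those 249 calls are 249 epochs when the scheduler is stepped once per epoch, but only about 3.6 epochs when it is stepped per batch. Real per-batch losses do improve now and then, which resets the counter and delays the collapse somewhat.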

Pitfall 2: forgetting to include mtl_loss.parameters() in the optimiser

GABA itself has no learnable parameters (its lambdas are derived). Uncertainty, GradNorm, and AMNL-v7 do. If you write optim.AdamW(model.parameters(), ...) without appending mtl_loss.parameters(), the learnable-σ parameters never receive a gradient step. The algorithm degenerates to fixed-σ scaling and the published numbers cannot be reproduced.

Pitfall 3: evaluating without swapping in the EMA shadow

End-of-epoch evaluation should run on shadow weights, not live weights. Skipping apply_shadow() produces a noisier rmse_last, which fools both ReduceLROnPlateau (false plateau triggers) and EarlyStopping (false-positive stops). Symptom: rmse_last jumps by ±2 cycles from epoch to epoch instead of descending smoothly.
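To see how one noisy reading can end a run, here is a minimal early-stopping counter (a hypothetical sketch, not the repo's EarlyStopping class) fed two rmse_last curves that differ only by a single lucky low value. A patience of 3 keeps the demo small; the production setting is 80.

```python
class EarlyStopping:
    """Stop when a lower-is-better metric has not improved for `patience` epochs."""

    def __init__(self, patience=80, min_delta=0.0):
        self.patience, self.min_delta = patience, min_delta
        self.best = float("inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, metric):
        """Record one end-of-epoch metric; return True when training should stop."""
        if metric < self.best - self.min_delta:
            self.best, self.best_epoch, self.bad_epochs = metric, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def stop_epoch(metrics, patience):
    """Return (epoch at which training stops or None, epoch of the best metric)."""
    stopper = EarlyStopping(patience=patience)
    for epoch, metric in enumerate(metrics):
        if stopper.step(epoch, metric):
            return epoch, stopper.best_epoch
    return None, stopper.best_epoch


smooth = [10.0, 9.0, 8.5, 8.2, 8.0, 7.9, 7.85, 7.8, 7.78, 7.76]
noisy = list(smooth)
noisy[2] = 7.0                            # one lucky low reading from live weights
print(stop_epoch(smooth, patience=3))     # (None, 9): still improving, keeps training
print(stop_epoch(noisy, patience=3))      # (5, 2): the outlier ends training at epoch 5
```

The smooth curve keeps training; the noisy one stops three epochs after the outlier even though the underlying trend is still going down.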

Pitfall 4: in-place model modification during evaluation

apply_shadow() overwrites param.data in place. If you forget the matching restore() call, the next training batch sees shadow weights as live weights — the optimiser then computes gradients with respect to a smoothed copy of itself. Slow divergence; confusing loss curves. Always pair the two.
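Pairing apply_shadow() with restore() is easiest to guarantee with a context manager whose finally clause always restores. The sketch below is framework-free and hypothetical: ParamEMA and ema_weights are illustrative names, and a dict of floats stands in for model parameters. A PyTorch version would copy param.data tensors under torch.no_grad() instead.

```python
from contextlib import contextmanager


class ParamEMA:
    """EMA over a dict of named scalar parameters (hypothetical helper)."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = dict(params)   # shadow starts as a copy of the live weights
        self.backup = None

    def update(self, params):
        for name, value in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * value

    def apply_shadow(self, params):
        """Overwrite live weights with shadow weights in place, keeping a backup."""
        self.backup = dict(params)
        params.update(self.shadow)

    def restore(self, params):
        """Put the live weights back; must follow every apply_shadow()."""
        params.update(self.backup)
        self.backup = None


@contextmanager
def ema_weights(ema, params):
    ema.apply_shadow(params)
    try:
        yield params
    finally:                         # runs even if evaluation raises
        ema.restore(params)


live = {"w": 1.0}
ema = ParamEMA(live, decay=0.9)
live["w"] = 0.0                      # pretend an optimiser step moved the weight
ema.update(live)                     # shadow: 0.9 * 1.0 + 0.1 * 0.0 = 0.9
with ema_weights(ema, live):
    print(live["w"])                 # 0.9 while evaluating
print(live["w"])                     # 0.0 again for the next training batch
```

Because restore() sits in the finally clause, an exception in the evaluator cannot leave shadow weights in place for the next training batch.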

Pitfall 5: shuffle=True on the test loader

The evaluator computes last-cycle-per-unit NASA scores by selecting the final window of each engine unit. That selection assumes the loader yields windows in original temporal order. Setting shuffle=True on the test DataLoader scrambles the order; ‘last’ becomes random; the NASA score becomes meaningless. Always shuffle=False for the test loader.
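A minimal sketch of last-window selection and scoring, under the same assumption the evaluator makes: windows arrive in temporal order within each unit. The helper names are hypothetical. The score is the standard PHM08 (NASA) asymmetric function with d = prediction − truth, exp(−d/13) − 1 for early predictions and exp(d/10) − 1 for late ones; the repo's evaluator may organise the same computation differently.

```python
import math


def last_window_per_unit(unit_ids, preds, targets):
    """Keep the last (pred, target) pair per unit; assumes temporal order within each unit."""
    last = {}
    for uid, pred, target in zip(unit_ids, preds, targets):
        last[uid] = (pred, target)      # later windows overwrite earlier ones
    return last


def nasa_score(pairs):
    """PHM08 asymmetric score with d = pred - target; late predictions cost more."""
    total = 0.0
    for pred, target in pairs:
        d = pred - target
        if d < 0:
            total += math.exp(-d / 13.0) - 1.0   # early prediction
        else:
            total += math.exp(d / 10.0) - 1.0    # late prediction
    return total


def rmse(pairs):
    pairs = list(pairs)
    return math.sqrt(sum((p - t) ** 2 for p, t in pairs) / len(pairs))


# Two engines, three windows each, in temporal order (true RUL counts down).
uids    = [1, 1, 1, 2, 2, 2]
targets = [30.0, 20.0, 10.0, 50.0, 40.0, 30.0]
preds   = [28.0, 21.0, 12.0, 45.0, 41.0, 27.0]

last = last_window_per_unit(uids, preds, targets)
print(last)                                        # {1: (12.0, 10.0), 2: (27.0, 30.0)}
print(rmse(last.values()), nasa_score(last.values()))

# A shuffled loader changes which window counts as "last".
order = [2, 0, 1, 5, 3, 4]
shuffled = last_window_per_unit(
    [uids[i] for i in order], [preds[i] for i in order], [targets[i] for i in order]
)
print(shuffled)                                    # {1: (21.0, 20.0), 2: (41.0, 40.0)}
```

After the shuffle, both engines are scored on mid-life windows instead of their final ones, so rmse_last and the NASA score no longer measure end-of-life accuracy.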

Takeaway

  • GRACE training is a 14-stage assembly line driven by UnifiedTrainer.fit(): 8 stages per batch, 6 stages per epoch.
  • The OUTER controller (GABA) is one of those 14 stages: a single nn.Module call. Everything else is standard training machinery: AdamW, ReduceLROnPlateau and gradient clipping from PyTorch, plus EMA, early stopping, checkpointer and gradient logger callbacks.
  • The per-batch / per-epoch split is non-negotiable. Confusing the cadences (most commonly scheduler.step() in the wrong place) corrupts the learning rate trajectory and the published numbers cannot be reproduced.
  • Production defaults: AdamW lr=1e-3, weight_decay=1e-4 (decayed to 1e-5 by epoch 200), grad_clip=1.0, EMA α=0.999, warmup=10 epochs, scheduler patience=30, early stopping=80, seeds = [42, 123, 456, 789, 1024].
  • The skeleton is domain-agnostic. Replace WMSE with focal / Dice / CTC, replace GABA with GradNorm / Uncertainty / DWA, replace NASA scoring with mAP / WER / Dice — and you have a working multi-task trainer for self-driving perception, speech, medical imaging, recommenders, or robot policies.