Chapter 22

Full Training Script Walkthrough

GRACE Training Pipeline

A Director's Master Shot List

Walk onto a film set during principal photography and you find the director holding a master shot list — one page per scene, every camera angle numbered, every actor blocked, every light cued. The list is the contract between intention and execution. Production runs all the way through; nothing is improvised.

The paper's training script is that same artefact for the published 7.72 RMSE on FD002. Every line is a numbered shot, and to reproduce the result you do not interpret the script, you execute it. This section walks through the entire production entry point run_single_experiment() from grace/experiments/phase1_cmapss.py, plus the UnifiedTrainer.fit() body and the per-batch _train_epoch() inner loop, then traces what actually happened across the 190 epochs of one published GABA run.

The headline. Production training is roughly 200 lines of Python: about 60 lines to set up the dataset, model, loss and optimiser, about 130 lines of UnifiedTrainer.fit() to drive the (at most) 500-epoch loop and save checkpoints, and about 15 lines to persist results. Every component is composed from chapters 5–21; nothing new is invented in chapter 22.

The Three Phases Of A Production Run

Phase | Function | What it owns | Time budget
--- | --- | --- | ---
1. Setup | run_single_experiment() (60 lines) | Seed pin, datasets, MTL wrapper, DataLoaders, backbone, dual-task model, MTL loss, criterion, AdamW, ReduceLROnPlateau, output dir. | ~10 seconds
2. Main loop | UnifiedTrainer.fit() (130 lines) | Epoch counter, warmup, adaptive WD schedule, _train_epoch call, EMA-aware eval, scheduler step, best tracking, early stopping, history accumulation. | ~20–30 minutes
3. Persist | phase1_cmapss tail (15 lines) | Build serialisable dict, json.dump to results.json, log to stderr. | ~1 second

Time budget on a single A100: ~30 minutes per seed × 5 seeds × 4 datasets × 7 methods = 70 hours for the full Phase 1 sweep. With a 4-GPU host running seeds in parallel, the whole sweep finishes in ~18 hours.
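The sweep arithmetic is easy to sanity-check in code. This is a sketch only: sweep_hours is a hypothetical helper, and it assumes every run takes the same ~30 minutes and that runs pack onto GPUs with no scheduling overhead.

```python
import math


def sweep_hours(minutes_per_run: float, seeds: int, datasets: int,
                methods: int, gpus: int = 1) -> float:
    """Wall-clock hours for a sweep of independent single-GPU runs.

    Assumes identical run times and zero scheduling overhead: each GPU
    executes ceil(runs / gpus) runs back to back.
    """
    runs = seeds * datasets * methods
    waves = math.ceil(runs / gpus)  # back-to-back slots per GPU
    return waves * minutes_per_run / 60.0


print(sweep_hours(30, seeds=5, datasets=4, methods=7))          # 70.0
print(sweep_hours(30, seeds=5, datasets=4, methods=7, gpus=4))  # 17.5
```

The ceiling matters when runs do not divide evenly across devices: with 3 GPUs the 140 runs need 47 back-to-back slots, about 23.5 hours.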

A Real Run From The Paper Repo

Consider the actual per-epoch trajectory of GABA on FD002 with seed 42: 190 epochs of real data from grace/outputs/phase1/FD002/gaba/seed_42/logs/gradient_stats.csv, which logs $\lambda^*_{\text{rul}}$ and the gradient ratio at every epoch.

What the trajectory tells you. At epoch 0, $\lambda^*_{\text{rul}} = 0.164$ because the GABA controller is still in its 100-step warmup phase, so the closed-form weighting is not yet in force. By epoch 5 (post-warmup) the closed form is active and $\lambda^*_{\text{rul}} \to 0.0006$, because the gradient ratio measures ~2400×. The floor-renormalised steady state hovers around $0.003$ for the rest of training. The best epoch lands at 109; early stopping fires at epoch 189 after 80 stagnant epochs (109 + 80 = 189, matching patience=80).

Phase 1 — Setup: run_single_experiment()

The PyTorch walkthrough below is the verbatim production function. The setup is dataset-aware (per-condition normalisation auto-enabled for FD002/FD004) and method-aware (weighted MSE on for GRACE, off for GABA). The output directory is constructed deterministically, so the artefact path encodes the run's full identity: output_dir/phase1/FD002/gaba/seed_42.
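A minimal sketch of that dataset- and method-aware setup, under stated assumptions: the helpers derive_flags and run_dir_for are hypothetical, and the rule "per-condition normalisation for FD002/FD004, weighted MSE only for GRACE" is our reading of the paragraph above, not code from the repo. The method key "grace" is likewise assumed; ExperimentConfig may set these flags differently.

```python
from pathlib import Path

# C-MAPSS subsets recorded under six operating conditions.
MULTI_CONDITION = {"FD002", "FD004"}


def derive_flags(dataset: str, method: str) -> dict:
    """Hypothetical: the two setup switches described in the text."""
    return {
        "per_condition_norm": dataset in MULTI_CONDITION,
        "use_weighted_mse": method == "grace",  # assumed method key
    }


def run_dir_for(output_dir: str, dataset: str, method: str, seed: int) -> Path:
    """Deterministic artefact path: the run's identity lives in the path."""
    return Path(output_dir) / "phase1" / dataset / method / f"seed_{seed}"


print(derive_flags("FD002", "gaba"))
print(run_dir_for("grace/outputs", "FD002", "gaba", 42).as_posix())
```

Because the path is a pure function of (dataset, method, seed), rerunning a seed overwrites its own artefacts and never collides with another row of the table.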

Two non-obvious responsibilities of the setup:

  • Test scaler reuse. scaler_params=train_ds.get_scaler_params() on the test dataset prevents data leakage. The model sees the test set normalised by the TRAIN-FIT scaler, not the test scaler.
  • MTL parameter union. params = list(model.parameters()) + list(mtl_loss.parameters()) ensures any LEARNABLE parameters in the MTL loss (Uncertainty σ's, GradNorm weights) are included in the optimiser. GABA has none; the union pattern is uniform across all 9 method variants in loss_registry.py.
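The leakage point is easiest to see on synthetic numbers. The sketch below assumes a per-feature z-score scaler (CMAPSSDataset's actual normalisation, and what get_scaler_params() returns, may differ) and gives the test split a hypothetical +0.5 sensor shift. Reusing the train-fit parameters preserves that shift; refitting on the test split erases it using test-set statistics, which is exactly the information the model must not see.

```python
import numpy as np


def fit_scaler(x: np.ndarray) -> dict:
    """Per-feature z-score parameters, fit on ONE split only."""
    std = x.std(axis=0)
    return {"mean": x.mean(axis=0), "std": np.where(std > 0, std, 1.0)}


def apply_scaler(x: np.ndarray, params: dict) -> np.ndarray:
    return (x - params["mean"]) / params["std"]


rng = np.random.default_rng(0)
train = rng.normal(0.0, 1.0, size=(5000, 2))
test = rng.normal(0.5, 1.0, size=(1000, 2))  # hypothetical +0.5 shift

train_params = fit_scaler(train)                   # fit on TRAIN only
test_ok = apply_scaler(test, train_params)         # reuse: shift survives
test_leaky = apply_scaler(test, fit_scaler(test))  # refit: shift erased

print(test_ok.mean(axis=0).round(2))     # close to [0.5, 0.5]
print(test_leaky.mean(axis=0).round(2))  # essentially zero
```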

Phase 2 — Main Loop: UnifiedTrainer.fit()

fit() is 130 lines, but the structural pattern is a single nested loop:

Lines (in trainer.py) | What runs | Cadence
--- | --- | ---
126 | for epoch in range(epochs): | Outer loop (≤500 iterations)
128 | self.warmup.apply(self.optimizer, epoch) | Per epoch: sets LR via warmup ramp
131-134 | if adaptive_weight_decay and epoch > 100: rescale wd | Per epoch: GRACE schedule
137 | train_stats = self._train_epoch(train_loader, epoch) | Per epoch: drives all 8 per-batch stages
140-144 | ema.apply_shadow(model); evaluate; ema.restore(model) | Per epoch: EMA-aware eval
161-165 | if not warmup: scheduler.step(rmse_last) | Per epoch: ReduceLROnPlateau
168-176 | if rmse_last < best: snapshot weights + checkpointer.save | Per epoch: best tracking
179-193 | history[*].append(...) | Per epoch: log accumulation
210 | if early_stopping(rmse_last, model): break | Per epoch: patience check

The order matters. Warmup sets the LR BEFORE training so the very first batch already runs at the correct rate. Adaptive WD runs AFTER warmup so the schedule's epoch-100 trigger applies to the post-warmup state. Eval runs AFTER the inner training loop with EMA shadow weights. Scheduler step runs only after warmup (otherwise the scheduler would fight the linear ramp). Early stopping runs LAST so the current epoch's improvement is counted before the patience check.
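The ordering rules can be exercised without a model. Below is a pure-Python sketch of the epoch body's bookkeeping, replayed over a fixed list of validation RMSEs instead of real training. EarlyStopping and PlateauLR are simplified stand-ins written for this illustration, not the repo's callbacks.py: PlateauLR loosely follows PyTorch's ReduceLROnPlateau (factor, patience, min_lr) but omits its threshold and cooldown options, and the production early_stopping(rmse_last, model) also receives the model, which this sketch drops.

```python
from dataclasses import dataclass


@dataclass
class EarlyStopping:
    """Stop after `patience` consecutive epochs without a min_delta improvement."""
    patience: int
    min_delta: float = 1e-4
    best: float = float("inf")
    bad_epochs: int = 0

    def __call__(self, metric: float) -> bool:
        if metric < self.best - self.min_delta:
            self.best, self.bad_epochs = metric, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


@dataclass
class PlateauLR:
    """Simplified ReduceLROnPlateau: multiply LR by `factor` once more than
    `patience` consecutive epochs fail to improve (no threshold/cooldown)."""
    lr: float
    factor: float = 0.5
    patience: int = 2
    min_lr: float = 5e-6
    best: float = float("inf")
    bad_epochs: int = 0

    def step(self, metric: float) -> None:
        if metric < self.best:
            self.best, self.bad_epochs = metric, 0
            return
        self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.bad_epochs = 0


def run(val_rmse, base_lr=1e-3, warmup_epochs=2):
    """Replay the per-epoch bookkeeping over a given list of validation RMSEs."""
    sched, stopper = PlateauLR(lr=base_lr), EarlyStopping(patience=4)
    best, best_epoch, log = float("inf"), -1, []
    for epoch, rmse in enumerate(val_rmse):
        # 1. Warmup sets the LR before this epoch's first batch.
        if epoch < warmup_epochs:
            lr = base_lr * (0.1 + 0.9 * epoch / warmup_epochs)
        else:
            lr = sched.lr
        # 2. Training and EMA-aware evaluation would run here; rmse is given.
        # 3. The scheduler steps only after warmup, so it never fights the ramp.
        if epoch >= warmup_epochs:
            sched.step(rmse)
        # 4. Best tracking (and checkpointing) before the patience check.
        if rmse < best:
            best, best_epoch = rmse, epoch
        log.append((epoch, lr))
        # 5. Early stopping last: this epoch's improvement is already counted.
        if stopper(rmse):
            break
    return best_epoch, best, log


best_epoch, best, log = run([5.0, 4.0, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6])
print(best_epoch, best, len(log))  # 2 3.0 7
```

Note the one-epoch delay the ordering implies: the plateau reduction decided at epoch 5 only takes effect at epoch 6, because step 1 reads the LR before step 3 updates it.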

Inside Phase 2: Per-Epoch _train_epoch()

Inside _train_epoch (trainer.py:243-290), the per-batch loop runs the eight stages walked through in section 22.1. The shape of one iteration:

  • for batch in loader: — pull (seq, rul, health, uid).
  • self.optimizer.zero_grad() — reset accumulated gradients.
  • rul_pred, health_logits = self.model(seq) — forward pass.
  • rul_loss = self.rul_criterion(rul_pred, rul_tgt), health_loss = self.health_criterion(health_logits, health_tgt) — inner WMSE + CE.
  • shared_params = self.model.get_shared_params() — backbone-only param list.
  • loss = self.mtl_loss(rul_loss, health_loss, shared_params=shared_params) — OUTER GABA controller. Returns the combined loss.
  • loss.backward() — autograd pass.
  • torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) — gradient clip.
  • self.optimizer.step() — AdamW update.
  • if self.ema: self.ema.update(self.model) — EMA shadow update.
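The clip step is global: torch.nn.utils.clip_grad_norm_ measures one L2 norm over every parameter tensor and rescales all of them by the same factor, so relative magnitudes between layers survive. The NumPy mini-fit later in this section clips a single 3-vector; a multi-tensor NumPy sketch (function name ours, including the small 1e-6 guard PyTorch adds to the denominator) looks like this:

```python
import numpy as np


def clip_grad_norm(grads, max_norm: float, eps: float = 1e-6) -> float:
    """Scale a list of gradient arrays IN PLACE so their joint L2 norm <= max_norm.

    One norm over all arrays, one shared scale factor, as with
    torch.nn.utils.clip_grad_norm_ at its default norm_type=2.
    Returns the norm measured before clipping.
    """
    total = float(np.sqrt(sum(float(np.sum(g ** 2)) for g in grads)))
    coef = max_norm / (total + eps)
    if coef < 1.0:
        for g in grads:
            g *= coef  # in-place, like the PyTorch version
    return total


grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]  # joint norm 5
before = clip_grad_norm(grads, max_norm=1.0)
after = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
print(before, round(after, 4))  # 5.0 1.0
```

Clipping each tensor separately would instead change the direction of the overall update; the shared factor preserves it.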

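Eight stages × ~110 batches per epoch on FD002 = ~880 stage executions per epoch. Each batch is one forward+backward pass, so the 500-epoch cap allows at most ~55,000 iterations (440,000 stage executions); the seed-42 GABA run traced above stopped after 190 epochs, roughly 20,900 iterations. This is the budget behind the published 7.72 RMSE.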
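The iteration arithmetic as a quick check (iteration_budget is a hypothetical helper; ~110 batches per epoch is the FD002 figure quoted above):

```python
def iteration_budget(batches_per_epoch: int, epochs: int, stages_per_batch: int = 8):
    """Return (optimiser steps, per-batch stage executions) for a run."""
    steps = batches_per_epoch * epochs
    return steps, steps * stages_per_batch


print(iteration_budget(110, 500))  # (55000, 440000): the 500-epoch cap
print(iteration_budget(110, 190))  # (20900, 167200): epochs 0-189 of the seed-42 run
```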

Python: A Minimal End-To-End Mini-fit()

This script wires together, by hand, every callback the production trainer uses, in 120 lines of NumPy. The model is linear regression on synthetic data, trivial on purpose, because the scaffolding is what matters. Replace the closed-form gradient on lines 79-80 with loss.backward() and you have the shape of the production loop.

Production training loop, minimal NumPy version
grace_minimal_fit.py
1docstring

States the contract: a fully self-contained mini training loop that exercises every callback in the paper's UnifiedTrainer.fit(): warmup, AdamW, gradient clip, EMA, ReduceLROnPlateau, early stopping. Replace the linear closed-form gradient with PyTorch autograd and you have the production loop.

11import numpy as np

Numerical-array library. Single dependency for the whole demo.

13rng = np.random.default_rng(42)

Modern NumPy RNG. seed=42 makes the synthetic data + init reproducible.

17N_TRAIN, N_TEST = 1024, 256

Sample counts. Ratio 4:1 — enough for a stable RMSE and a bit of validation headroom.

18x_train = rng.normal(0, 1, (N_TRAIN, 3))

Three features, drawn from N(0, 1).

EXECUTION STATE
📚 rng.normal(loc, scale, size) = Gaussian samples. shape (N, 3) means N rows × 3 features.
19y_train = ( 2.0 * x_train[:, 0] + 0.5 * x_train[:, 1] - x_train[:, 2] + rng.normal(0, 0.3, N_TRAIN) )

Linear regression target with σ=0.3 noise. The TRUE coefficients are [2.0, 0.5, -1.0] — what the model should recover.

EXECUTION STATE
📚 x_train[:, k] = Slice all rows, column k. Selects the k-th feature for every sample.
true coefs (target) = [2.0, 0.5, -1.0]
23x_test = rng.normal(0, 1, (N_TEST, 3))

Test features.

24y_test = 2.0 * x_test[:, 0] + 0.5 * x_test[:, 1] - x_test[:, 2]

Test targets — NOISE-FREE so the recovered RMSE measures model error, not data noise.

28theta = rng.normal(0, 0.1, size=3)

Tiny-magnitude initialisation. Standard for linear models: close to zero, so AdamW's weight decay has no immediate effect.

EXECUTION STATE
theta (3,) initial = Around [0, 0, 0] but not exactly zero. Tiny random nudges break symmetry.
29m, v = np.zeros(3), np.zeros(3)

AdamW first/second moment accumulators. Zero-init; bias-corrected per step.

30shadow = theta.copy()

EMA shadow weights. Initialised to the live theta. Updated every step with α=0.99.

EXECUTION STATE
📚 .copy() = Independent ndarray with the same values. Required so the in-place update theta -= ... does not leak into shadow.
34lr0, beta1, beta2, eps = 1e-2, 0.9, 0.999, 1e-8

AdamW core hyperparameters. lr0 = base learning rate after warmup. β=(0.9, 0.999) is the literature default.

35weight_decay = 1e-4

AdamW decoupled weight-decay coefficient. Same as the paper's production default.

36ema_alpha = 0.99

EMA shadow decay, giving roughly a 100-step (about 6-epoch) memory. The paper's production value of 0.999 (roughly 1000-step memory) suits runs of tens of thousands of steps. This toy takes at most 1,600 steps, so at 0.999 the shadow would still carry a large share of the near-zero initialisation, its RMSE would keep falling every epoch, and the plateau and early-stopping branches would never be exercised.

37warmup_epochs = 5

Linear LR ramp budget. Smaller than the paper's 10 because this synthetic toy converges quickly.

38patience_es = 20

Early-stopping patience. Smaller than the paper's 80 because the toy stabilises faster.

39batch_size = 64

Mini-batch size. 1024/64 = 16 batches per epoch.

40n_epochs = 100

Maximum epoch budget: at most 100 × 16 = 1,600 optimiser steps.

43def lr_at(epoch):

Linear warmup schedule. Mirrors callbacks.py:LRWarmup.get_lr.

44"""Linear LR warmup."""

Docstring.

45if epoch < warmup_epochs:

Branch on warmup phase.

46return lr0 * (0.1 + 0.9 * epoch / warmup_epochs)

Linear ramp from 0.1·lr0 at epoch 0 towards lr0; the full rate is reached at epoch warmup_epochs via the return on line 47.

EXECUTION STATE
epoch=0 → lr = 1e-2 * (0.1 + 0) = 1e-3
epoch=4 → lr = 1e-2 * (0.1 + 0.72) = 8.2e-3 (last warmup epoch)
47return lr0

Post-warmup value: epoch ≥ 5 skips the branch and returns lr0 = 1e-2 (full rate). The training loop itself switches to scheduler_lr after warmup, so this line only matters if lr_at is called directly.

50def predict(theta, x):

Linear forward pass: y_pred = x @ theta. The simplest possible model.

EXECUTION STATE
⬇ input: theta (3,) = Current parameters.
⬇ input: x (B, 3) = Batch features.
⬆ returns = (B,) predictions.
52return x @ theta

Matrix-vector product: x (B,3) @ theta (3,) → y (B,). Each row dot-products with the parameter vector.

EXECUTION STATE
📚 @ operator = Python matrix multiply. NumPy: x @ theta = np.dot(x, theta) for these shapes.
55def rmse(theta, x, y):

RMSE on a held-out set. Used for both validation tracking and final reporting.

56return float(np.sqrt(((predict(theta, x) - y) ** 2).mean()))

Predict → squared residual → mean → sqrt, then cast to a plain float.

60best_rmse = float('inf')

Best-so-far RMSE. Initialised to infinity so any real value wins.

61best_theta = theta.copy()

Snapshot of the best parameters. Updated whenever val RMSE improves.

62no_improve = 0

Counter for ReduceLROnPlateau and EarlyStopping. Reset on improvement.

63scheduler_lr = lr0

Learning rate the scheduler maintains AFTER warmup. Halved on plateaus.

64step = 0

Global step counter. Used for AdamW bias correction.

65history = []

Per-epoch log. Each tuple = (epoch, lr, val_rmse). Used for plotting and final reporting.

69for epoch in range(n_epochs):

Outer loop over epochs. Mirrors UnifiedTrainer.fit's for-loop in trainer.py:126.

70lr = lr_at(epoch) if epoch < warmup_epochs else scheduler_lr

Choose the learning rate. Warmup phase uses the linear ramp; after that, the scheduler value applies.

EXECUTION STATE
📚 conditional expression = Python ternary: <expr_if_true> if <cond> else <expr_if_false>. Compact form of an if/else assignment.
73idx = rng.permutation(N_TRAIN)

Shuffle order for THIS epoch. Equivalent to DataLoader(shuffle=True). New permutation per epoch.

EXECUTION STATE
📚 rng.permutation(n) = Returns a random permutation of integers 0..n-1.
74for start in range(0, N_TRAIN, batch_size):

Inner loop over mini-batches. 16 iterations at batch_size=64.

75b = idx[start:start + batch_size]

Slice of the shuffled indices for this batch.

76xb, yb = x_train[b], y_train[b]

Fancy indexing: pulls rows of x_train at positions b. yb is the matching subset of targets.

79residual = predict(theta, xb) - yb

Per-sample prediction error. Shape (B,).

80g = (2.0 / len(b)) * (xb.T @ residual)

Closed-form MSE gradient: dL/dθ = (2/B)·X^T·(Xθ - y). For PyTorch this would be loss.backward(); here we have the closed form because the model is linear.

EXECUTION STATE
📚 .T = Transpose. xb (B, 3) → xb.T (3, B).
📚 @ result = (3, B) @ (B,) → (3,): the gradient vector for theta.
83gn = np.linalg.norm(g)

Global gradient L2 norm. The clip threshold is applied to this scalar.

EXECUTION STATE
📚 np.linalg.norm(v) = Vector L2 norm: sqrt(sum(v_i^2)).
84if gn > 1.0:

Clip only when the gradient norm exceeds the threshold.

85g = g / gn

Rescale to unit norm. Direction preserved; magnitude exactly 1.0.

88step += 1

Advance global step counter for bias correction.

89m = beta1 * m + (1 - beta1) * g

AdamW first moment EMA.

90v = beta2 * v + (1 - beta2) * g ** 2

AdamW second moment EMA.

91m_hat = m / (1 - beta1 ** step)

Bias-correct the first moment.

92v_hat = v / (1 - beta2 ** step)

Bias-correct the second moment.

93theta -= lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)

AdamW update. Two terms: variance-normalised step + decoupled weight-decay shrink.

96shadow = ema_alpha * shadow + (1 - ema_alpha) * theta

EMA shadow update. With ema_alpha = 0.99 the shadow tracks theta with a lag of roughly 100 steps (about 6 epochs), smoothing late-training noise.

99val_rmse = rmse(shadow, x_test, y_test)

End-of-epoch evaluation. Critical: use the SHADOW, not theta. This mirrors UnifiedTrainer.fit lines 140-144 (apply_shadow → eval → restore).

100history.append((epoch, lr, val_rmse))

Record this epoch's metrics for the final report.

103if val_rmse < best_rmse - 1e-4:

Improvement check with min_delta=1e-4. Mirrors callbacks.py:EarlyStopping line 38. The 1e-4 stops floating-point noise from looking like progress.

104best_rmse = val_rmse

Update the best-so-far RMSE.

105best_theta = shadow.copy()

Snapshot the EMA shadow, i.e. exactly the weights that produced best_rmse, as the new best. .copy() ensures it doesn't alias future updates.

106no_improve = 0

Reset patience counter.

107else:

No improvement.

108no_improve += 1

Increment patience counter.

111if no_improve > 0 and no_improve % 5 == 0:

ReduceLROnPlateau-style trigger. Every 5 stagnant epochs, halve the LR. The paper uses scheduler_patience=30 in production; 5 is enough for the toy.

112scheduler_lr = max(scheduler_lr * 0.5, 1e-6)

Halve the LR, floor at 1e-6.

114if no_improve >= patience_es:

Early-stopping trigger. After 20 stagnant epochs, break out of the loop.

115print(f"early stop at epoch {epoch}")

Notify and break.

116break

Exit the outer epoch loop.

118print(f"best epoch: {history[-1][0] - no_improve} | rmse {best_rmse:.4f}")

Final report. The best epoch is `current_epoch - no_improve` because that is the last epoch at which best_rmse was updated.

119print(f"final theta:  {best_theta.round(3).tolist()}")

Recovered coefficients. Should be close to the target [2.0, 0.5, -1.0].

EXECUTION STATE
Output (format only; exact values depend on the RNG stream, and an "early stop at epoch <n>" line may come first) =
best epoch: <epoch> | rmse <best shadow RMSE>
final theta:  [≈2.0, ≈0.5, ≈-1.0]
target theta: [2.0, 0.5, -1.0]
120print(f"target theta: [2.0, 0.5, -1.0]")

The truth, for visual comparison. Because the test targets are noise-free, the reported RMSE measures coefficient error alone. That error cannot shrink to zero: the training labels carry σ=0.3 noise, and the least-squares standard error on 1,024 unit-variance samples is about 0.3/√1024 ≈ 0.009 per coefficient, so expect agreement to roughly the second decimal place.

1"""End-to-end mini training loop. NumPy + closed-form gradient.
2
3Stitches every UnifiedTrainer.fit() callback into one self-contained
4script: data, model, AdamW, gradient clip, EMA, LR warmup,
5ReduceLROnPlateau, early stopping, best-weights snapshot.
6
7Single-task synthetic regression, so there is no GABA controller
8here. The point is the SCAFFOLDING, not the model.
9"""
10
11import numpy as np
12
13rng = np.random.default_rng(42)
14
15
16# ---------- Synthetic data ----------
17N_TRAIN, N_TEST = 1024, 256
18x_train = rng.normal(0, 1, (N_TRAIN, 3))
19y_train = (
20    2.0 * x_train[:, 0] + 0.5 * x_train[:, 1] - x_train[:, 2]
21    + rng.normal(0, 0.3, N_TRAIN)
22)
23x_test = rng.normal(0, 1, (N_TEST, 3))
24y_test = 2.0 * x_test[:, 0] + 0.5 * x_test[:, 1] - x_test[:, 2]
25
26
27# ---------- Model: 3-feature linear ----------
28theta = rng.normal(0, 0.1, size=3)
29m, v = np.zeros(3), np.zeros(3)
30shadow = theta.copy()
31
32
33# ---------- Hyperparameters ----------
34lr0, beta1, beta2, eps = 1e-2, 0.9, 0.999, 1e-8
35weight_decay  = 1e-4
36ema_alpha     = 0.99
37warmup_epochs = 5
38patience_es   = 20
39batch_size    = 64
40n_epochs      = 100
41
42
43def lr_at(epoch):
44    """Linear LR warmup."""
45    if epoch < warmup_epochs:
46        return lr0 * (0.1 + 0.9 * epoch / warmup_epochs)
47    return lr0
48
49
50def predict(theta, x):
51    """Linear forward pass."""
52    return x @ theta
53
54
55def rmse(theta, x, y):
56    return float(np.sqrt(((predict(theta, x) - y) ** 2).mean()))
57
58
59# ---------- Training state ----------
60best_rmse = float("inf")
61best_theta = theta.copy()
62no_improve = 0
63scheduler_lr = lr0
64step = 0
65history = []
66
67
68# ---------- Training loop ----------
69for epoch in range(n_epochs):
70    lr = lr_at(epoch) if epoch < warmup_epochs else scheduler_lr
71
72    # Shuffle and iterate mini-batches
73    idx = rng.permutation(N_TRAIN)
74    for start in range(0, N_TRAIN, batch_size):
75        b = idx[start:start + batch_size]
76        xb, yb = x_train[b], y_train[b]
77
78        # Forward + backward (closed-form gradient for linear MSE)
79        residual = predict(theta, xb) - yb           # (B,)
80        g = (2.0 / len(b)) * (xb.T @ residual)       # (3,)
81
82        # Gradient clip (global norm 1.0)
83        gn = np.linalg.norm(g)
84        if gn > 1.0:
85            g = g / gn
86
87        # AdamW update
88        step += 1
89        m = beta1 * m + (1 - beta1) * g
90        v = beta2 * v + (1 - beta2) * g ** 2
91        m_hat = m / (1 - beta1 ** step)
92        v_hat = v / (1 - beta2 ** step)
93        theta -= lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)
94
95        # EMA shadow
96        shadow = ema_alpha * shadow + (1 - ema_alpha) * theta
97
98    # Per-epoch eval (on EMA shadow)
99    val_rmse = rmse(shadow, x_test, y_test)
100    history.append((epoch, lr, val_rmse))
101
102    # Track best + early stopping
103    if val_rmse < best_rmse - 1e-4:
104        best_rmse = val_rmse
105        best_theta = shadow.copy()  # the weights that were evaluated
106        no_improve = 0
107    else:
108        no_improve += 1
109
110    # ReduceLROnPlateau (patience 5 here for the toy)
111    if no_improve > 0 and no_improve % 5 == 0:
112        scheduler_lr = max(scheduler_lr * 0.5, 1e-6)
113
114    if no_improve >= patience_es:
115        print(f"early stop at epoch {epoch}")
116        break
117
118print(f"best epoch: {history[-1][0] - no_improve} | rmse {best_rmse:.4f}")
119print(f"final theta:  {best_theta.round(3).tolist()}")
120print(f"target theta: [2.0, 0.5, -1.0]")
What the toy shows. AdamW + warmup + EMA + clip + ReduceLROnPlateau + EarlyStopping recover the true coefficients $[2.0,\ 0.5,\ -1.0]$ on the synthetic regression within the 100-epoch budget, to the accuracy the noisy training labels allow (about 0.01 per coefficient). The same scaffolding applied to a 1.7M-parameter CNN-BiLSTM-Attention backbone produces the published GRACE result in 100–200 epochs.

PyTorch: The Paper's run_single_experiment, Verbatim

The actual production entry point, line by line. Every constructor call, every keyword argument and every config-dict key has a card explaining what it does and why this value was chosen. Every published row of the paper's 140-experiment table came out of this function.

grace/experiments/phase1_cmapss.py:run_single_experiment
1docstring

Names the contract: this is the paper's production single-seed entry point. Every published row of Table I (FD001-FD004 × 7 methods × 5 seeds = 140 runs) was produced by calling this function.

8import json

Standard library JSON. Used to persist the per-seed results.json artefact.

9from pathlib import Path

Modern path manipulation. Path objects support / for joining and .mkdir(parents=True). Replaces the older os.path API.

11import torch

Core PyTorch.

12import torch.nn as nn

Module/parameter machinery.

13import torch.optim as optim

AdamW + ReduceLROnPlateau scheduler.

14from torch.utils.data import DataLoader

Mini-batch iterator over the MTL-wrapped CMAPSSDataset.

16from ..experiments.config import ExperimentConfig

Centralised configuration dataclass. All hyperparameters live here — see chapter 22 §2 for the full schema.

17from ..core.loss_registry import get_loss

Factory returning any of the 9 MTL loss variants by name.

18from ..core.weighted_mse import moderate_weighted_mse_loss

Failure-biased MSE — the INNER axis from chapter 21.

19from ..models.backbone import UnifiedBackbone

CNN-BiLSTM-Attention shared trunk. ~1.7M parameters at the C-MAPSS config.

20from ..models.dual_task_model import DualTaskModel

Wraps the backbone with rul_head (Linear → 1) and health_head (Linear → 3).

21from ..models.model_configs import get_model_config

Named architecture configs: 'cmapss', 'ncmapss_20feat', 'ncmapss_32feat'.

22from ..data.cmapss_dataset import CMAPSSDataset

Sliding-window dataset over the C-MAPSS sensor matrices.

23from ..data.health_labels import rul_to_health_3class

Maps RUL → {0=healthy, 1=degrading, 2=critical} based on the 30/70 thresholds.

24from ..data.mtl_wrapper import MTLDatasetWrapper

Wraps a single-target dataset to add a per-sample health label, producing the (seq, rul, health, uid) 4-tuple.

25from ..training.trainer import UnifiedTrainer

The 14-stage assembly line walked through in section 22.1.

26from ..training.seed_utils import set_seed

Reproducibility pin from section 22.3 — five PRNG streams + cuDNN flags.

29def run_single_experiment(cfg: ExperimentConfig, seed: int):

Single-seed experiment. Inputs: a config and a seed. Output: serialisable dict with best_epoch, best_rmse, and the full metrics block.

EXECUTION STATE
⬇ input: cfg = Production-default ExperimentConfig with the dataset, mtl_method, and all training hyperparameters set.
⬇ input: seed = One of [42, 123, 456, 789, 1024] for the published runs.
⬆ returns = Dict with 'dataset', 'method', 'seed', 'best_epoch', 'best_rmse', 'metrics'.
31set_seed(seed)

Pin everything. Five PRNG streams + cuDNN deterministic. Section 22.3 details.

32device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Auto-select GPU when available; fall back to CPU.

35train_ds = CMAPSSDataset(

Construct the training dataset.

36dataset_name=cfg.dataset, data_dir=cfg.data_dir,

Which subset (FD001..FD004) and where the CSVs live.

37sequence_length=cfg.sequence_length, train=True,

30-cycle sliding window. train=True picks the train CSV split.

38random_seed=seed, per_condition_norm=cfg.per_condition_norm,

per_condition_norm=True for FD002/FD004 (multi-condition); the dataset class normalises sensors WITHIN each operating regime.

40test_ds = CMAPSSDataset(

Test dataset. train=False.

41dataset_name=cfg.dataset, data_dir=cfg.data_dir,

Same args as train_ds.

42sequence_length=cfg.sequence_length, train=False,

Same window length, test split.

43scaler_params=train_ds.get_scaler_params(),

CRITICAL: the test set must be normalised with the TRAIN scaler. Re-fitting on the test set is data leakage.

EXECUTION STATE
→ leakage risk = If the test set fits its own scaler, the model gets implicit access to test-set statistics — typical RUL inflation of 2-3 cycles.
44random_seed=seed, per_condition_norm=cfg.per_condition_norm,

Match train settings.

46train_mtl = MTLDatasetWrapper(train_ds, rul_to_health_3class)

Add the health label per sample. Wrapper.__getitem__ returns (seq, rul, health, uid) where health = rul_to_health_3class(rul).

47test_mtl = MTLDatasetWrapper(test_ds, rul_to_health_3class)

Same wrapper for test.

48train_loader = DataLoader(train_mtl, batch_size=cfg.batch_size, shuffle=True)

batch_size=256 (default). shuffle=True reorders the windows every epoch, which stochastic-gradient convergence relies on.

49test_loader = DataLoader(test_mtl, batch_size=cfg.batch_size, shuffle=False)

shuffle=False is critical: the evaluator's last-cycle NASA scoring assumes the loader yields windows in original temporal order.

52mc = get_model_config(cfg.model_config)

Fetch the named architecture config. cfg.model_config defaults to 'cmapss' (input_size=17, hidden_size=256, 8 heads, fc_dims=(256, 64, 32), dropout=0.3).

53backbone = UnifiedBackbone(

Build the shared CNN-BiLSTM-Attention trunk. Constructor args populated from mc.

54input_size=mc.input_size, hidden_size=mc.hidden_size,

Sensor count + BiLSTM hidden width. (17, 256) for C-MAPSS.

55cnn_channels=mc.cnn_channels, num_attn_heads=mc.num_attn_heads,

(64, 128, 64) Conv1d channels and 8 attention heads.

56fc_dims=mc.fc_dims, dropout=mc.dropout,

(256, 64, 32) FC stack and 0.3 dropout (uniform across all datasets).

57use_attention=mc.use_attention, use_residual=mc.use_residual,

Both True for the production architecture. Disabling either is an ablation (chapter 27).

59model = DualTaskModel(backbone, num_health_states=mc.num_health_states,

Wrap the backbone with two heads.

EXECUTION STATE
→ forward signature = model(seq) → (rul_pred (B,), hp_logits (B, 3))
→ get_shared_params() = Returns the backbone-only param list — what GABA differentiates against.
60dropout=mc.dropout)

Per-head dropout for regularisation.

63mtl_loss = get_loss(cfg.mtl_method)

Factory returns the configured MTL loss. For GRACE, this is GABALoss(beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2).

64rul_criterion = (moderate_weighted_mse_loss if cfg.use_weighted_mse

Pick the per-sample RUL loss. GRACE uses weighted MSE; GABA uses plain MSE.

65else nn.MSELoss())

Fallback: standard PyTorch MSE.

68params = list(model.parameters()) + list(mtl_loss.parameters())

Concatenate model parameters with any LEARNABLE params in the MTL loss. GABA has none; Uncertainty has σ_rul, σ_health; GradNorm has the loss weights themselves.

69optimizer = optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

AdamW with paper defaults. lr=1e-3, weight_decay=1e-4. The trainer schedules wd down at epochs 100/200 if cfg.adaptive_weight_decay=True.

70scheduler = optim.lr_scheduler.ReduceLROnPlateau(

Halve the LR when rmse_last has not improved for scheduler_patience epochs.

71optimizer, mode="min", factor=cfg.scheduler_factor,

mode='min' = smaller is better. factor=0.5 = halve.

72patience=cfg.scheduler_patience, min_lr=cfg.min_lr,

patience=30 epochs. Floor at 5e-6.

76run_dir = Path(cfg.output_dir) / "phase1" / cfg.dataset / cfg.mtl_method / f"seed_{seed}"

Per-run output directory. Path operator overload joins; f-string interpolates the seed.

EXECUTION STATE
Example = grace/outputs/phase1/FD002/gaba/seed_42
77run_dir.mkdir(parents=True, exist_ok=True)

Create the directory, creating parents if needed and not erroring if it already exists.

EXECUTION STATE
📚 .mkdir(parents=True, exist_ok=True) = Pathlib method. parents=True is like `mkdir -p`; exist_ok=True suppresses FileExistsError. Idempotent.
80trainer = UnifiedTrainer(

Construct the orchestrator. Section 22.1 walks the 14-stage pipeline.

81model=model, mtl_loss=mtl_loss,

Pass the model and the MTL loss controller.

82rul_criterion=rul_criterion,

Per-sample RUL loss (weighted MSE for GRACE).

83health_criterion=nn.CrossEntropyLoss(),

Standard 3-class CE on the health head.

84optimizer=optimizer, scheduler=scheduler, device=device,

Already-configured optimiser, scheduler, and device.

85config={

Training-loop configuration dict.

86"epochs": cfg.epochs, "patience": cfg.patience,

epochs=500, early-stopping patience=80.

87"grad_clip": cfg.grad_clip, "use_ema": cfg.use_ema,

grad_clip=1.0 global L2-norm threshold. use_ema=True enables shadow-weight evaluation.

88"ema_decay": cfg.ema_decay, "warmup_epochs": cfg.warmup_epochs,

ema_decay=0.999, warmup_epochs=10.

89"lr": cfg.lr,

Base lr (1e-3) — used by LRWarmup to compute the linear ramp.

90"checkpoint_dir": str(run_dir / "checkpoints"),

Where to write best-RMSE checkpoints. Per-run, per-seed.

91"log_dir": str(run_dir / "logs"),

Where to dump the per-epoch gradient_stats.csv that the timeline viz reads from.

92"adaptive_weight_decay": cfg.adaptive_weight_decay,

GRACE-only flag. Halve wd at epoch 100, /10 at 200.

93"initial_weight_decay": cfg.initial_weight_decay,

Starting wd value before the schedule kicks in: 1e-4, matching the weight_decay passed to AdamW above. (PyTorch's own AdamW default is weight_decay=1e-2.)

98results = trainer.fit(train_loader, test_loader)

Run the full 500-epoch (max) loop. Returns {history, best_metrics, best_epoch, best_rmse}. Typically 20-30 minutes on a single A100.

100serialisable = {

Build the JSON-friendly result.

101"dataset": cfg.dataset, "method": cfg.mtl_method, "seed": seed,

Provenance: which row of the published table this artifact is.

102"best_epoch": results["best_epoch"],

The epoch at which the rmse_last minimum was reached: 109 for the published FD002/GABA/seed=42 run.

103"best_rmse": results["best_rmse"],

rmse_last at best_epoch. The 7.42 in chapter 21 §1's GABA cell on FD002 is exactly this.

104"metrics": {k: v for k, v in results["best_metrics"].items()},

Full metrics dict copy. dict comprehension keeps the original `results` untouched.

106with open(run_dir / "results.json", "w") as f:

Open the per-seed results file for writing. Path / 'results.json' joins; 'w' truncates if the file exists.

EXECUTION STATE
📚 with open(path, mode) as f: = Context manager. f.close() is called automatically on block exit, even if an exception fires.
107json.dump(serialisable, f, indent=2, default=str)

Pretty-print JSON. indent=2 for human readability. default=str converts any non-JSON-native objects to strings.

108return serialisable

Return the dict for the caller (run_phase1 in the same file aggregates 5 seeds).

EXECUTION STATE
⬆ return (illustrative) =
{ "dataset": "FD002", "method": "gaba", "seed": 42, "best_epoch": 109, "best_rmse": 7.42, "metrics": {"rmse_last": 7.42, "nasa_score": 203.96, "health_accuracy": 97.43, ...} }
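default=str is a safety net, and it is worth knowing what it does to values: anything json cannot encode is passed through str(). That keeps json.dump from raising, but a NumPy float32 metric silently becomes a string in results.json. A small sketch with illustrative values (not repo output; the run_dir key is our addition) shows both effects:

```python
import json
from pathlib import PurePosixPath

import numpy as np

serialisable = {
    "dataset": "FD002", "method": "gaba", "seed": 42,
    "run_dir": PurePosixPath("grace/outputs/phase1/FD002/gaba/seed_42"),  # not JSON-native
    "best_rmse": np.float32(7.42),  # float32 is not a Python float subclass
}
text = json.dumps(serialisable, indent=2, default=str)
loaded = json.loads(text)

print(loaded["run_dir"])                  # grace/outputs/phase1/FD002/gaba/seed_42
print(type(loaded["best_rmse"]).__name__)  # str, not float
```

If downstream aggregation across seeds expects numbers, casting metrics with float() before dumping avoids the silent string conversion.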
1"""Paper's production training entry point — verbatim.
2
3Source: paper_ieee_tii/grace/experiments/phase1_cmapss.py:37-156.
4Single-seed runner that produces one row of the published 140-experiment
5table.
6"""
7
8import json
9from pathlib import Path
10
11import torch
12import torch.nn as nn
13import torch.optim as optim
14from torch.utils.data import DataLoader
15
16from ..experiments.config import ExperimentConfig
17from ..core.loss_registry         import get_loss
18from ..core.weighted_mse          import moderate_weighted_mse_loss
19from ..models.backbone            import UnifiedBackbone
20from ..models.dual_task_model     import DualTaskModel
21from ..models.model_configs       import get_model_config
22from ..data.cmapss_dataset        import CMAPSSDataset
23from ..data.health_labels         import rul_to_health_3class
24from ..data.mtl_wrapper           import MTLDatasetWrapper
25from ..training.trainer           import UnifiedTrainer
26from ..training.seed_utils        import set_seed
27
28
29def run_single_experiment(cfg: ExperimentConfig, seed: int):
30    """One seed of one method on one dataset. Returns serialisable dict."""
31    set_seed(seed)
32    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
34    # ---- Data ----
35    train_ds = CMAPSSDataset(
36        dataset_name=cfg.dataset, data_dir=cfg.data_dir,
37        sequence_length=cfg.sequence_length, train=True,
38        random_seed=seed, per_condition_norm=cfg.per_condition_norm,
39    )
40    test_ds = CMAPSSDataset(
41        dataset_name=cfg.dataset, data_dir=cfg.data_dir,
42        sequence_length=cfg.sequence_length, train=False,
43        scaler_params=train_ds.get_scaler_params(),
44        random_seed=seed, per_condition_norm=cfg.per_condition_norm,
45    )
46    train_mtl = MTLDatasetWrapper(train_ds, rul_to_health_3class)
47    test_mtl  = MTLDatasetWrapper(test_ds,  rul_to_health_3class)
48    train_loader = DataLoader(train_mtl, batch_size=cfg.batch_size, shuffle=True)
49    test_loader  = DataLoader(test_mtl,  batch_size=cfg.batch_size, shuffle=False)
50
51    # ---- Model ----
52    mc = get_model_config(cfg.model_config)
53    backbone = UnifiedBackbone(
54        input_size=mc.input_size, hidden_size=mc.hidden_size,
55        cnn_channels=mc.cnn_channels, num_attn_heads=mc.num_attn_heads,
56        fc_dims=mc.fc_dims, dropout=mc.dropout,
57        use_attention=mc.use_attention, use_residual=mc.use_residual,
58    )
59    model = DualTaskModel(backbone, num_health_states=mc.num_health_states,
60                          dropout=mc.dropout)
61
62    # ---- Loss ----
63    mtl_loss = get_loss(cfg.mtl_method)
64    rul_criterion = (moderate_weighted_mse_loss if cfg.use_weighted_mse
65                     else nn.MSELoss())
66
67    # ---- Optimiser + scheduler ----
68    params    = list(model.parameters()) + list(mtl_loss.parameters())
69    optimizer = optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
70    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
71        optimizer, mode="min", factor=cfg.scheduler_factor,
72        patience=cfg.scheduler_patience, min_lr=cfg.min_lr,
73    )
74
75    # ---- Output dir ----
76    run_dir = Path(cfg.output_dir) / "phase1" / cfg.dataset / cfg.mtl_method / f"seed_{seed}"
77    run_dir.mkdir(parents=True, exist_ok=True)
78
79    # ---- Trainer ----
80    trainer = UnifiedTrainer(
81        model=model, mtl_loss=mtl_loss,
82        rul_criterion=rul_criterion,
83        health_criterion=nn.CrossEntropyLoss(),
84        optimizer=optimizer, scheduler=scheduler, device=device,
85        config={
86            "epochs": cfg.epochs, "patience": cfg.patience,
87            "grad_clip": cfg.grad_clip, "use_ema": cfg.use_ema,
88            "ema_decay": cfg.ema_decay, "warmup_epochs": cfg.warmup_epochs,
89            "lr": cfg.lr,
90            "checkpoint_dir": str(run_dir / "checkpoints"),
91            "log_dir":        str(run_dir / "logs"),
92            "adaptive_weight_decay": cfg.adaptive_weight_decay,
93            "initial_weight_decay":  cfg.initial_weight_decay,
94        },
95    )
96
97    # ---- Run + persist ----
98    results = trainer.fit(train_loader, test_loader)
99
100    serialisable = {
101        "dataset": cfg.dataset, "method": cfg.mtl_method, "seed": seed,
102        "best_epoch": results["best_epoch"],
103        "best_rmse":  results["best_rmse"],
104        "metrics":    {k: v for k, v in results["best_metrics"].items()},
105    }
106    with open(run_dir / "results.json", "w") as f:
107        json.dump(serialisable, f, indent=2, default=str)
108    return serialisable
The MTL parameter union pattern (line 68). list(mtl_loss.parameters()) looks innocuous, and for GABA it returns an empty list. For Uncertainty (Kendall et al.) it returns the two learnable σ scalars; for GradNorm (Chen et al.) it returns the per-task weight parameters. Without the concatenation the optimiser never updates those parameters, so each method degenerates to fixed weighting: constant σ for Uncertainty, constant task weights for GradNorm. The published tables for those baselines depend on this one line.
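To see why the union matters, here is a minimal pure-Python sketch. It assumes one common parameterisation of Kendall et al.'s uncertainty weighting (a learnable log-variance s_i per task, total loss Σ exp(-s_i)·L_i + s_i) and holds the two task losses fixed so that only s moves. The helper names are hypothetical, and a plain gradient-descent loop stands in for AdamW; this is not the repo's Uncertainty implementation.

```python
import math


def uncertainty_total_loss(task_losses, log_vars):
    """One common Kendall-style form: sum_i exp(-s_i) * L_i + s_i."""
    return sum(math.exp(-s) * loss + s for loss, s in zip(task_losses, log_vars))


def grad_log_vars(task_losses, log_vars):
    """Analytic gradient of the total loss w.r.t. each s_i: 1 - exp(-s_i) * L_i."""
    return [1.0 - math.exp(-s) * loss for loss, s in zip(task_losses, log_vars)]


def optimise_log_vars(task_losses, registered_with_optimiser, steps=2000, lr=0.05):
    """Gradient descent on s_i; the flag mimics whether mtl_loss.parameters() reached the optimiser."""
    log_vars = [0.0] * len(task_losses)  # s_i = 0 means every task weight exp(-s_i) is 1
    if not registered_with_optimiser:
        return log_vars  # never handed to the optimiser: weights stay fixed at 1.0
    for _ in range(steps):
        grads = grad_log_vars(task_losses, log_vars)
        log_vars = [s - lr * g for s, g in zip(log_vars, grads)]
    return log_vars


losses = [4.0, 0.25]  # e.g. a large RUL loss and a small health loss
learned = optimise_log_vars(losses, registered_with_optimiser=True)
frozen = optimise_log_vars(losses, registered_with_optimiser=False)
print("learned weights:", [round(math.exp(-s), 3) for s in learned])
print("frozen weights: ", [math.exp(-s) for s in frozen])
```

With the parameters registered, each s_i converges to log L_i, so the task weights exp(-s_i) approach 1/L_i (0.25 and 4.0 here) and rebalance the two losses. Without registration they stay at 1.0, which is exactly the fixed-weight degeneration described above.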

What Happens Across 190 Epochs

Reading the timeline above with the trainer code in mind, here is the narrative of the published GABA-on-FD002-seed-42 run:

| Epoch range | What's happening | Observable |
|---|---|---|
| 0–9 | LR warmup + GABA warmup_steps. The model is in tiny-LR territory; GABA emits uniform 0.5/0.5 weights for the first ~110 batches. | λ_rul stays near 0.16; the gradient ratio is still moderate (~100×) because the model is essentially untrained. |
| 10–30 | Full LR. GABA enters its closed-form regime. The gradient ratio explodes to ~2500× as the model learns coarse features and the RUL gradient norm g_rul shoots up. | λ_rul drops sharply to ~5×10⁻⁴. Training loss falls fastest in this window. |
| 30–100 | Steady descent. EMA smoothing settles λ_rul at ~2×10⁻³. ReduceLROnPlateau hasn't fired yet because rmse_last keeps improving. | λ_rul climbs slowly as the gradient ratio narrows from ~2000× to ~500×. |
| 100 | Adaptive weight decay halves (1e-4 → 5e-5). Optimisation enters its second phase. | Vertical orange dashed line in the chart. |
| 100–109 | Best window. The model reaches its lowest rmse_last at epoch 109 (RMSE = 7.42). The checkpointer saves model + optimiser + EMA shadow. | Violet dashed line at the best epoch. |
| 109–189 | Patience countdown: 80 stagnant epochs. ReduceLROnPlateau fires at ~epoch 140 (LR halved) and again at ~epoch 170 (halved again). | λ_rul keeps drifting upward (the gradient ratio shrinks as the model becomes more accurate); rmse_last oscillates ±0.5 around 7.5. |
| 189 | Early stopping fires. Best weights are restored and the final eval is recomputed with the best EMA shadow. | Red dashed line. Run terminates. |
The GABA controller is doing its job throughout. At every training step the closed form $\lambda^{\text{raw}} \propto 1/g_{\text{rul}}$ is recomputed; EMA smooths the per-batch fluctuation; the floor (0.05) enforces a minimum per-task weight; renormalisation restores sum-to-one. The trajectory is the controller's record, not a tuned schedule.
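The four stages in that paragraph can be sketched as a small controller. This is a hypothetical illustration, not GABA's code: it assumes two or more tasks weighted by normalised inverse gradient norms, an EMA decay of beta=0.9, an EMA initialised on the first post-warmup step, and the order closed form, then EMA, then floor, then renormalise, as listed above.

```python
class GabaLikeController:
    """Hypothetical inverse-gradient-norm task weighter (sketch, not the paper's code).

    Assumed pipeline per step: warmup -> closed form lambda_i proportional to 1/g_i
    -> EMA smoothing -> per-task floor -> renormalise to sum to one.
    """

    def __init__(self, beta=0.9, floor=0.05, warmup_steps=0):
        self.beta = beta                  # EMA decay (assumed value)
        self.floor = floor                # minimum weight applied before renormalising
        self.warmup_steps = warmup_steps  # steps that emit uniform weights
        self.step = 0
        self.ema = None

    def weights(self, grad_norms):
        self.step += 1
        n = len(grad_norms)
        if self.step <= self.warmup_steps:
            return [1.0 / n] * n  # uniform, i.e. 0.5/0.5 for two tasks
        inverse = [1.0 / max(g, 1e-12) for g in grad_norms]  # closed form: 1/g_i
        total = sum(inverse)
        raw = [v / total for v in inverse]
        if self.ema is None:
            self.ema = raw  # first post-warmup step initialises the EMA
        else:
            self.ema = [self.beta * e + (1.0 - self.beta) * r
                        for e, r in zip(self.ema, raw)]
        floored = [max(w, self.floor) for w in self.ema]
        norm = sum(floored)
        return [w / norm for w in floored]  # renormalise so the weights sum to one


ctrl = GabaLikeController(warmup_steps=2)
for g_rul, g_health in [(3.0, 1.0), (3.0, 1.0), (3.0, 1.0), (100.0, 1.0)]:
    print(ctrl.weights([g_rul, g_health]))
```

One design consequence is worth noticing: because renormalisation runs after the floor, a weight that was clamped to 0.05 ends slightly below it (about 0.048 when the raw RUL weight is 1/101 and no EMA history exists).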

The Same Skeleton For Other Domains

The 150-line script generalises by replacing three things: dataset, backbone, MTL loss. Everything else — AdamW, scheduler, warmup, EMA, clip, early stopping, checkpointer — is unchanged.

| Domain | Replace dataset with | Replace backbone with | Replace MTL loss with |
|---|---|---|---|
| Self-driving 3D detection | nuScenes / Waymo Open Dataset class with multi-modal input | PointPillars or CenterPoint encoder | PCGrad or GradNorm: the per-task gradients can conflict on far-vs-near objects |
| Speech recognition | LibriSpeech with mel-spectrogram windows | Conformer encoder | Uncertainty (Kendall): the CTC and attention-decoder losses have different scales |
| Medical imaging | BraTS (multi-modal MRI) / MIMIC-CXR (paired image + text report) | UNet++ or SegFormer encoder | DWA (Liu et al.): the loss-ratio history suits multi-modal medical tasks |
| Recommender system | MovieLens-style click + rating sequences | BERT4Rec (Transformer) encoder | Custom controller: task priorities (CTR > revenue > dwell) are business-driven |
| Robot policy learning | Robosuite multi-task buffer | Diffusion policy backbone | PCGrad: the per-task gradients can flip sign around the task switch |

The skeleton invariant is what makes ML papers reproducible: the contribution lives in the three replaceable slots, not in the surrounding training loop. A reviewer who knows the skeleton can read any new MTL paper's training script in five minutes by spotting which slot is the contribution.
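The slot structure can be made explicit in a few lines. This is a hypothetical sketch: ExperimentSlots, run_skeleton and placeholder_loop are illustrative names, and strings stand in for real datasets, backbones and MTL losses so it runs without any ML dependencies.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class ExperimentSlots:
    """Hypothetical container for the three replaceable slots."""
    make_dataset: Callable[[int], Any]   # seed -> dataset
    make_backbone: Callable[[], Any]     # () -> feature extractor
    make_mtl_loss: Callable[[], Any]     # () -> task-weighting strategy


def run_skeleton(slots: ExperimentSlots, seed: int,
                 fixed_loop: Callable[[Any, Any, Any], Dict[str, Any]]) -> Dict[str, Any]:
    """Build the three slot components, then hand them to the unchanged training loop."""
    dataset = slots.make_dataset(seed)
    backbone = slots.make_backbone()
    mtl_loss = slots.make_mtl_loss()
    return fixed_loop(dataset, backbone, mtl_loss)


def placeholder_loop(dataset, backbone, mtl_loss):
    # Stands in for AdamW + scheduler + warmup + EMA + clip + early stopping + checkpointing.
    return {"dataset": dataset, "backbone": backbone, "mtl_loss": mtl_loss}


cmapss = ExperimentSlots(lambda seed: f"CMAPSS FD002 (seed {seed})",
                         lambda: "UnifiedBackbone", lambda: "GABA")
speech = ExperimentSlots(lambda seed: f"LibriSpeech (seed {seed})",
                         lambda: "Conformer", lambda: "Uncertainty")
for slots in (cmapss, speech):
    print(run_skeleton(slots, seed=42, fixed_loop=placeholder_loop))
```

Swapping domains touches only the ExperimentSlots instance; placeholder_loop, standing in for the shared training loop, is passed unchanged to both runs.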

Pitfalls When Reading Or Modifying The Script

Pitfall 1: changing one line without checking the cadence

The placement of scheduler.step() AFTER warmup (line 161 in trainer.py) is deliberate. Moving it before warmup makes the scheduler fight the warmup ramp; moving it inside _train_epoch turns ReduceLROnPlateau into a per-batch decay. Either change silently breaks reproduction of the published result.
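The per-epoch cadence can be sketched in pure Python. This is not UnifiedTrainer.fit(): ToyPlateauScheduler and fit_cadence are hypothetical, the RMSE curve is synthetic, and warmup_epochs=10, patience=80, base_lr=1e-3, scheduler patience 10 and factor 0.5 are assumptions chosen to resemble the epoch table above. The point is the ordering: the scheduler is consulted once per epoch, and only after warmup ends.

```python
class ToyPlateauScheduler:
    """Simplified stand-in for torch.optim.lr_scheduler.ReduceLROnPlateau(mode="min").

    Ignores the real scheduler's threshold and cooldown options.
    """

    def __init__(self, lr, factor=0.5, patience=10, min_lr=1e-6):
        self.lr, self.factor, self.patience, self.min_lr = lr, factor, patience, min_lr
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric):
        if metric < self.best:
            self.best, self.bad_epochs = metric, 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_epochs = 0


def fit_cadence(val_rmse, base_lr=1e-3, warmup_epochs=10, patience=80,
                wd_switch_epoch=100, weight_decays=(1e-4, 5e-5)):
    """Per-epoch ordering; val_rmse[epoch] plays the role of the EMA-aware eval."""
    scheduler = ToyPlateauScheduler(base_lr)
    best_rmse, best_epoch, log = float("inf"), -1, []
    for epoch, rmse in enumerate(val_rmse):
        # 1. Warmup: linear LR ramp; the scheduler is not consulted yet.
        if epoch < warmup_epochs:
            lr = base_lr * (epoch + 1) / warmup_epochs
        else:
            lr = scheduler.lr
        # 2. Adaptive weight decay: switch to the smaller value after wd_switch_epoch.
        wd = weight_decays[0] if epoch < wd_switch_epoch else weight_decays[1]
        # 3. _train_epoch(lr, wd) would run here; 4. the EMA-aware eval yields rmse.
        log.append((epoch, lr, wd, rmse))
        # 5. Scheduler: once per epoch, and only after warmup has finished.
        if epoch >= warmup_epochs:
            scheduler.step(rmse)
        # 6. Best tracking (a checkpoint would be saved here).
        if rmse < best_rmse:
            best_rmse, best_epoch = rmse, epoch
        # 7. Early stopping after `patience` epochs without improvement.
        elif epoch - best_epoch >= patience:
            return {"best_epoch": best_epoch, "best_rmse": best_rmse,
                    "stopped_at": epoch, "log": log}
    return {"best_epoch": best_epoch, "best_rmse": best_rmse, "stopped_at": None, "log": log}


# Synthetic curve: improves through epoch 109, then plateaus.
curve = [20.0 - 0.1 * e for e in range(110)] + [9.5] * 390
result = fit_cadence(curve)
print(result["best_epoch"], result["stopped_at"])
```

On this synthetic curve the sketch picks epoch 109 as best and stops at 189; moving step 5 inside a per-batch loop would let the same patience counter expire within a handful of epochs.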

Pitfall 2: forgetting to pass the train scaler to the test set

The scaler_params=train_ds.get_scaler_params() argument (line 43) is the difference between leaked and clean evaluation. Skipping it lets the test set fit its own scaler — the model implicitly sees test-set statistics, RMSE typically drops by 2–3 cycles, and the result is not publishable.
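A minimal sketch with one synthetic sensor shows the mechanism. It assumes a per-feature z-score scaler; what get_scaler_params() actually returns (z-score or min-max, per-condition statistics when per_condition_norm is on) is not shown in this section, and fit_scaler / apply_scaler are hypothetical helpers.

```python
from statistics import fmean, pstdev


def fit_scaler(rows):
    """Per-feature (mean, std) from the given rows; a zero std falls back to 1."""
    return [(fmean(col), pstdev(col) or 1.0) for col in zip(*rows)]


def apply_scaler(rows, params):
    """Z-score each feature with the supplied (mean, std) pairs."""
    return [[(x - mean) / std for x, (mean, std) in zip(row, params)] for row in rows]


train = [[0.0], [2.0], [4.0]]   # one sensor, training engines
test = [[10.0], [12.0]]         # same sensor, drifted test engines

clean = apply_scaler(test, fit_scaler(train))   # train statistics only
leaked = apply_scaler(test, fit_scaler(test))   # test set fits its own scaler
print("clean: ", clean)
print("leaked:", leaked)
```

With the train scaler the drifted test readings stay far from zero (about 4.9 and 6.1), so the model must cope with the shift. The leaked scaler re-centres them to exactly -1 and +1, silently removing the distribution shift the benchmark is meant to test.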

Pitfall 3: shuffle=True on the test loader

Line 49: shuffle=False. The evaluator's per-unit last-cycle NASA score depends on the loader yielding windows in original temporal order. Setting shuffle=True makes ‘last cycle per engine’ pick a random window per engine; the NASA score becomes meaningless.
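The failure can be reproduced with a few toy windows. This sketch assumes the evaluator keeps the last window it encounters for each unit id (last_window_per_unit is a hypothetical helper), and nasa_score uses the standard asymmetric C-MAPSS scoring function with constants 13 for early and 10 for late predictions. The predictions and true RULs are made up.

```python
import math


def nasa_score(predicted, true):
    """Asymmetric C-MAPSS score: late predictions (d > 0) are penalised more than early ones."""
    total = 0.0
    for p, t in zip(predicted, true):
        d = p - t
        total += math.exp(-d / 13.0) - 1.0 if d < 0 else math.exp(d / 10.0) - 1.0
    return total


def last_window_per_unit(unit_ids, predictions):
    """Keep the last prediction seen per engine; correct only for temporally ordered input."""
    last = {}
    for unit, pred in zip(unit_ids, predictions):
        last[unit] = pred  # later windows overwrite earlier ones
    return last


# Two engines; (unit, prediction) pairs in original temporal order.
ordered = [(1, 30.0), (1, 25.0), (1, 20.0), (2, 50.0), (2, 45.0)]
# The same windows as a shuffled loader might yield them.
shuffled = [(1, 20.0), (2, 45.0), (1, 30.0), (2, 50.0), (1, 25.0)]

true_final_rul = {1: 18.0, 2: 47.0}
for name, windows in [("ordered", ordered), ("shuffled", shuffled)]:
    units, preds = zip(*windows)
    picked = last_window_per_unit(units, preds)
    score = nasa_score([picked[u] for u in (1, 2)], [true_final_rul[u] for u in (1, 2)])
    print(name, picked, round(score, 3))
```

The ordered pass scores the true final-cycle windows; the shuffled pass scores whichever window happened to arrive last, so the reported NASA score changes even though no prediction changed.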

Pitfall 4: not creating output_dir before the trainer

Line 77's run_dir.mkdir(parents=True, exist_ok=True) runs BEFORE the trainer is constructed (line 80). The Checkpointer fails fast if the directory does not exist; doing the mkdir inside the trainer would couple two responsibilities.
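The ordering can be sketched with pathlib alone. make_run_dir mirrors the path expression in the listing; require_existing_parent is a hypothetical stand-in for the Checkpointer's fail-fast check (its real implementation is not shown in this section), and the argument values are illustrative.

```python
import tempfile
from pathlib import Path


def make_run_dir(output_dir, dataset, method, seed):
    """Same layout as the listing: <output>/phase1/<dataset>/<method>/seed_<n>."""
    run_dir = Path(output_dir) / "phase1" / dataset / method / f"seed_{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)  # idempotent: safe on re-runs
    return run_dir


def require_existing_parent(checkpoint_dir):
    """Hypothetical fail-fast check a checkpointer might run in its constructor."""
    parent = Path(checkpoint_dir).parent
    if not parent.is_dir():
        raise FileNotFoundError(f"run directory {parent} does not exist")
    return parent


with tempfile.TemporaryDirectory() as out:
    run_dir = make_run_dir(out, "FD002", "gaba", 42)
    make_run_dir(out, "FD002", "gaba", 42)            # second call does not raise
    require_existing_parent(run_dir / "checkpoints")  # passes: mkdir ran first
    layout = run_dir.relative_to(out).as_posix()
    try:
        require_existing_parent(Path(out) / "never_created" / "checkpoints")
        failed_fast = False
    except FileNotFoundError:
        failed_fast = True  # a missing run dir is reported before any training starts
print(layout, failed_fast)
```

Keeping directory creation in the entry point and only validation in the checkpointer preserves the separation of responsibilities described above.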

Pitfall 5: skipping the params union for non-GABA methods

The line params = list(model.parameters()) + list(mtl_loss.parameters()) is harmless for GABA (the second list is empty). Removing it for ‘cleanliness’ silently breaks Uncertainty, GradNorm, and AMNL-v7 — their learnable parameters never get optimised.

Takeaway

  • Production training is run_single_experiment at grace/experiments/phase1_cmapss.py:37 — ~150 lines that compose every component from chapters 5–21 into one reproducible call.
  • Three phases: setup (60 lines, ~10 s), main loop (130 lines, ~30 min), persist (15 lines, ~1 s). Each row of the published 140-experiment table was produced by one call to this function.
  • The fit() body has a strict cadence: warmup → adaptive WD → train_epoch → EMA-eval → scheduler → best track → early stop. Reordering breaks reproducibility.
  • The published GABA-FD002-seed-42 run reaches its best at epoch 109 (RMSE = 7.42) and early-stops at epoch 189. The timeline visualisation plots this trajectory from the run's real gradient_stats.csv data.
  • The 150-line skeleton is domain-agnostic: replace the dataset, backbone and MTL loss, and everything else stays unchanged. A new MTL paper typically contributes to one of those three slots and reuses the rest.