Chapter 15

Full Training Script Walkthrough

AMNL Training Pipeline

Putting Eight Sections in One Script

Sections §14.1-§15.4 each isolated one piece of the AMNL pipeline. This section stitches them all together into a single end-to-end training script - the same code shape the paper ships in paper_ieee_tii/experiments/train_amnl_v7.py. Twelve stages, <200 lines.

One driver function. train_amnl(dataset_name, epochs, lr) takes three knobs; everything else is a paper-canonical default. Run it for each C-MAPSS subset, averaging over seeds, to reproduce the paper's Table II AMNL row.

Interactive: 12-Stage Pipeline

Each stage maps to a paper-file location and the book section that derives it. Setup runs once; the epoch + eval block runs 200 times.

Read the colour groups. Blue = data / model setup (one-time). Green = loss stack (one-time). Amber = optimiser stack (one-time). Pink = per-epoch train. Purple = per-epoch eval + scheduling. The setup stages (blue, green, amber) run once; the per-epoch stages (pink, purple) then cycle for 200 iterations.

Paper Files at a Glance

| Stage | Paper file | Book section |
|---|---|---|
| 1-2. Data + DataLoader | grace/data/cmapss_dataset.py | §7 |
| 3. DualTaskModel | grace/models/dual_task_model.py | §11.4 |
| 4. AMNL RUL loss | grace/core/weighted_mse.py | §14.1 |
| 5. FixedWeightLoss combiner | grace/core/baselines.py:34-49 | §15.1 |
| 6. AdamW + ReduceLROnPlateau | experiments/train_amnl_v7.py:480-496 | §15.2 |
| 7. EMA + clip_grad | grace/training/callbacks.py:54-87 + utils | §15.3 |
| 8-12. Training loop driver | grace/training/trainer.py:126-217 | §14.4 + §15.5 |

To reproduce the paper. cd paper_ieee_tii && python experiments/train_amnl_v7.py --dataset FD002 --seed 0 invokes exactly the pipeline below. The paper's Table II numbers come from running this for 5 seeds (0-4) per dataset and averaging.
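The five-seed protocol is easy to script. The sketch below only builds the command lines for seeds 0-4 and averages RMSE values that you collect yourself. How train_amnl_v7.py reports its final RMSE is not shown in this section, so the helper names (build_commands, mean_rmse) and the manual collection step are assumptions for illustration.

```python
import statistics


def build_commands(dataset: str, seeds=range(5)) -> list:
    """One argv list per seed, using only the flags quoted in the text."""
    return [
        ["python", "experiments/train_amnl_v7.py",
         "--dataset", dataset, "--seed", str(seed)]
        for seed in seeds
    ]


def mean_rmse(per_seed_rmse) -> float:
    """Average per-seed RMSE values (hypothetical helper; values supplied by you)."""
    return statistics.mean(per_seed_rmse)


commands = build_commands("FD002")
for argv in commands:
    # Run from inside paper_ieee_tii/, e.g. with subprocess.run(argv, check=True).
    print(" ".join(argv))
```

Keeping the commands as argv lists rather than one shell string avoids quoting problems if they are later passed to subprocess.run.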

Python: Pseudo-Pipeline

Conceptual NumPy walkthrough. Each call wraps a stub or a paper-imported component. The pseudo-output at the bottom shows what the real run would print on FD002.

train_one_dataset() — 12-stage walkthrough
🐍train_one_dataset_pseudo.py
import numpy as np

We only use NumPy for np.random.seed in the smoke test; the rest is conceptual Python that doesn't actually run end-to-end (the real code is the PyTorch block below).

EXECUTION STATE
📚 numpy = Library: ndarray + math + random + statistics. Used here only for the seed.
as np = Universal alias.
def train_one_dataset(dataset_name='FD002', epochs=200, lr=1e-3, weight_decay=1e-4, grad_clip=1.0, ema_decay=0.999) -> dict:

End-to-end driver. Six hyperparameters - all paper-canonical defaults. The function body wires together every piece from §14-§15.4 in the order the paper trainer does.

EXECUTION STATE
⬇ input: dataset_name = 'FD002' = Which C-MAPSS subset to train on. Multi-condition by default.
⬇ input: epochs = 200 = Paper default.
⬇ input: lr = 1e-3 = AdamW initial lr (§15.2).
⬇ input: weight_decay = 1e-4 = AdamW decoupled weight decay (§15.2).
⬇ input: grad_clip = 1.0 = Joint-norm clip (§15.3).
⬇ input: ema_decay = 0.999 = Weight EMA β (§15.3).
⬆ returns = {history: per-epoch dict, best_rmse: scalar} - the whole training trace.
train_ds = build_dataset(dataset_name, split="train")

Step 1 - build the training dataset. CMAPSSFullDataset (paper file grace/data/cmapss_dataset.py) returns (seq, rul, health, uid) tuples with a 30-cycle sliding window.

EXECUTION STATE
→ build_dataset = Stub for paper's CMAPSSFullDataset constructor. Real call is more verbose with normalisation params.
⬇ arg 1: dataset_name = 'FD001' / 'FD002' / 'FD003' / 'FD004'.
⬇ arg 2: split="train" = Loads the training partition. The test partition has different RUL truncation (only the LAST window per engine is reported).
⬆ result: train_ds = Dataset of ~17,000 (seq, rul, health, uid) tuples for FD002 train split.
test_ds = build_dataset(dataset_name, split="test")

Same for the held-out test partition.

EXECUTION STATE
→ split test = C-MAPSS provides explicit train/test splits per subset. Test windows are LAST-cycle-only per engine.
⬆ result: test_ds = Test dataset (~260 last-cycle windows for FD002).
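To make the train/test difference concrete, here is a minimal sketch of the two windowing modes on a toy engine. It assumes plain Python lists and a window of 30 cycles; the function names are hypothetical, and how the paper's dataset class treats engines shorter than one window is not covered in this section.

```python
def sliding_windows(series, window=30):
    """All length-`window` slices, one per valid end position (training mode)."""
    return [series[i:i + window] for i in range(len(series) - window + 1)]


def last_window(series, window=30):
    """Only the final window of an engine (test-mode reporting)."""
    return series[-window:]


engine = list(range(1, 36))              # toy engine with 35 cycles, numbered 1..35
train_windows = sliding_windows(engine)
test_window = last_window(engine)

print(len(train_windows))                # 35 - 30 + 1 = 6 windows
print(test_window[0], test_window[-1])   # cycles 6 and 35
```

An engine with N cycles therefore contributes N - 29 training windows but only one test window, which is why the test set is so much smaller than the training set.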
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

Wraps the dataset in a torch.utils.data.DataLoader. Yields (seq, rul, hs, uid) batches.

EXECUTION STATE
📚 DataLoader = PyTorch's standard batching/loading utility.
⬇ arg: batch_size = 32 = Paper default for C-MAPSS.
⬇ arg: shuffle = True = Reshuffle every epoch.
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

Same but no shuffling - test order is deterministic.

EXECUTION STATE
⬇ arg: shuffle = False = Test order matches the engine ID order so per-engine reporting works.
model = DualTaskModel(c_in=14, lstm_hidden=256, num_heads=8, shared_dim=32, num_classes=3)

Step 3 - build the §11.4 model with paper defaults.

EXECUTION STATE
⬇ arg: c_in = 14 = Number of informative C-MAPSS sensors.
⬇ arg: lstm_hidden = 256 = Per-direction BiLSTM hidden size (§9).
⬇ arg: num_heads = 8 = Multi-head attention heads (§10).
⬇ arg: shared_dim = 32 = FC funnel output (§11.1).
⬇ arg: num_classes = 3 = Health classes (§11.3).
rul_criterion = moderate_weighted_mse_loss

First-class function reference - no call. Paper's AMNL RUL loss (§14.1).

hs_criterion = cross_entropy

Stable log_softmax + nll_loss for the health branch.

mtl_loss = FixedWeightLoss(rul_weight=0.5, health_weight=0.5)

Paper's 0.5/0.5 combiner (§15.1).

EXECUTION STATE
⬇ args = Both 0.5 - the AMNL choice. Replace this line with GABALoss to switch to Part VI.
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

AdamW base (§15.2).

EXECUTION STATE
📚 AdamW = Adam with decoupled weight decay. Paper default.
⬇ arg: lr = 1e-3 = Base step size; the warmup multiplier scales it during the first 10 epochs.
⬇ arg: weight_decay = 1e-4 = Decoupled weight-decay strength.
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=30, min_lr=5e-6)

Reactive lr scheduler (§15.2).

EXECUTION STATE
⬇ arg: factor = 0.5 = 50% reduction per trigger.
⬇ arg: patience = 30 = Epochs without improvement before firing.
⬇ arg: min_lr = 5e-6 = Floor for the lr.
ema = ExponentialMovingAverage(model, decay=ema_decay)

Weight-EMA tracker (§15.3). Snapshots all parameters now; updates per step.

EXECUTION STATE
⬇ arg: decay = 0.999 = Half-life ≈ 693 steps.
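The half-life figure follows from solving decay^n = 0.5 for n. A short check, using only the standard library (the function name is illustrative):

```python
import math


def ema_half_life(decay: float) -> float:
    """Steps until an old value's weight in the EMA has decayed to one half."""
    return math.log(0.5) / math.log(decay)


half_life = ema_half_life(0.999)
print(f"{half_life:.1f}")  # 692.8
```

With batch_size=32, the number of optimiser steps per epoch depends on the dataset size, so how many epochs one half-life spans varies by subset.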
history = {"rmse": [], "nasa": [], "rul_loss": [], "hs_loss": [], "lr": []}

Five lists for per-epoch logging.

best_rmse = float("inf")

Initialise running best to +∞ so the first epoch always counts as an improvement.

EXECUTION STATE
📚 float("inf") = IEEE-754 +∞.
for epoch in range(epochs):

Per-epoch loop. 200 iterations.

LOOP TRACE · 6 iterations
epoch 0
warmup factor = 0.10 → lr = 1e-4
rmse = ≈ 25-40 (random init)
epoch 9
warmup factor = 0.91 → lr = 9.1e-4 (last warmup epoch)
epoch 10
warmup factor = 1.00 → lr = 1e-3 (post-warmup)
scheduler = starts watching from here
epoch 60
rmse = ≈ 14-15 (descent done)
epoch 90
scheduler = first plateau cut → lr = 5e-4
epoch 199
rmse = ≈ 13.36 (paper FD002+FD004 average)
lr = ≈ 6e-5 (after several plateau cuts)
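for pg in optimizer.param_groups:

Iterate AdamW's parameter groups (typically just one).

pg["lr"] = lr * (0.1 + 0.9 * min(epoch, 10) / 10)

Apply linear warmup (§15.2). After epoch 9 the formula returns lr · 1.0 = lr. Writing pg["lr"] on every later epoch would overwrite any cut made by ReduceLROnPlateau, so the assignment should only run while epoch <= 10.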

EXECUTION STATE
📚 min(a, b) = Built-in. Caps the warmup at epoch 10.
→ at epoch 0 = lr · (0.1 + 0) = 0.1 · lr.
→ at epoch 9 = lr · (0.1 + 0.81) = 0.91 · lr.
→ at epoch 10+ = lr · 1.0 (saturates).
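The warmup factor is easy to tabulate in isolation. The helper below is a standalone restatement of the formula above (the name warmup_lr and its keyword arguments are illustrative, not taken from the paper code):

```python
def warmup_lr(base_lr, epoch, warmup_epochs=10, start_factor=0.1):
    """Linear ramp from start_factor * base_lr at epoch 0 to base_lr at warmup_epochs."""
    frac = min(epoch, warmup_epochs) / warmup_epochs
    return base_lr * (start_factor + (1.0 - start_factor) * frac)


schedule = {epoch: warmup_lr(1e-3, epoch) for epoch in (0, 5, 9, 10, 50)}
for epoch, value in schedule.items():
    print(f"epoch {epoch:>2d}: lr = {value:.2e}")
```

The factor reaches 1.0 exactly at epoch 10 and stays there, which is the point at which control of lr passes to the plateau scheduler.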
train_stats = train_epoch(model, train_loader, rul_criterion, hs_criterion, mtl_loss, optimizer, ema, grad_clip=grad_clip)

Run one full epoch over the training loader. Inside: forward → losses → mtl_loss → backward → clip_grad → optim.step → ema.update (§14.4 + §15.3). Returns the avg per-task losses.

EXECUTION STATE
→ train_epoch wrapper = Encapsulates the §14.4 training-step body. Real paper code is paper_ieee_tii/grace/training/trainer.py:243-284.
⬇ args = All seven training-time pieces: model, loader, two criteria, combiner, optimiser, ema. grad_clip is the only kwarg.
⬆ result: train_stats = {loss, rul_loss, hs_loss, grad_norm} averaged over the epoch.
ema.apply_shadow(model)

Swap live weights with EMA shadow values BEFORE evaluation (§15.3). The model now uses the smoothed weights for the metric computation.

eval_stats = evaluate(model, test_loader)

Compute test metrics with the smoothed weights. Returns RMSE (last-cycle, all-cycle), NASA score, R², health accuracy, F1.

EXECUTION STATE
→ evaluate = Wrapper around paper's grace/training/evaluator.py.
⬆ result: eval_stats = Dict with rmse_last, rmse_all, nasa_score, r2_last, health_accuracy, health_f1.
ema.restore(model)

Put live weights back so training resumes correctly. Forgetting this line silently breaks the next training step.
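The swap-and-restore pattern can be shown on a plain dict of floats. This is a toy stand-in, not the paper's ExponentialMovingAverage class: TinyEMA and its internals are assumptions, and only the three method names (update, apply_shadow, restore) mirror the calls used in the walkthrough.

```python
class TinyEMA:
    """Toy EMA over a dict of floats; illustrates update / apply_shadow / restore."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = dict(params)   # smoothed copy, initialised from the live weights
        self.backup = {}

    def update(self, params):
        for name, value in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * value

    def apply_shadow(self, params):
        self.backup = dict(params)   # remember live weights
        params.update(self.shadow)   # evaluate with smoothed weights

    def restore(self, params):
        params.update(self.backup)   # put live weights back
        self.backup = {}


params = {"w": 1.0}
ema = TinyEMA(params, decay=0.9)
params["w"] = 2.0                    # stand-in for one optimiser step
ema.update(params)                   # shadow = 0.9 * 1.0 + 0.1 * 2.0

ema.apply_shadow(params)
eval_w = params["w"]                 # smoothed value seen during evaluation
ema.restore(params)
live_w = params["w"]                 # live value that training continues from
print(eval_w, live_w)
```

If restore were skipped here, params["w"] would stay at the smoothed value and the next training step would start from the wrong weights.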

if epoch >= 10:

Only call scheduler.step() AFTER warmup ends (§15.2).

scheduler.step(eval_stats["rmse"])

ReduceLROnPlateau records the monitored metric and cuts lr by `factor` once the metric has failed to improve for more than `patience` consecutive epochs.

EXECUTION STATE
⬇ arg: metrics = rmse = RMSE from this epoch's evaluation on test_loader (the script holds out no separate validation split).
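The plateau rule itself is only a few lines of bookkeeping. The sketch below is a simplified re-implementation for intuition: PlateauTracker is a hypothetical class, and it omits the improvement threshold and cooldown options that the real torch.optim.lr_scheduler.ReduceLROnPlateau provides.

```python
class PlateauTracker:
    """Simplified plateau rule: cut lr after more than `patience` non-improving epochs."""

    def __init__(self, lr, factor=0.5, patience=30, min_lr=5e-6):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric):
        if metric < self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.bad_epochs = 0
        return self.lr


tracker = PlateauTracker(lr=1e-3, patience=2)   # short patience so the cut is visible
lrs = [tracker.step(rmse) for rmse in (10.0, 9.0, 9.0, 9.0, 9.0)]
print(lrs)
```

Because every call updates best and bad_epochs, calling step during warmup would start the patience count before the learning rate has reached its base value.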
if eval_stats["rmse"] < best_rmse:

Track best epoch. The paper trainer also deep-copies the model state and EMA shadow here for checkpointing.

best_rmse = eval_stats["rmse"]

Update running best.
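Why a deep copy rather than a plain reference? A toy illustration, with a nested dict standing in for the model state (the paper trainer's actual checkpoint format is not shown in this section, so the structure here is an assumption):

```python
import copy

state = {"w": [0.0]}                 # stand-in for model weights
best_rmse, best_state = float("inf"), None

for rmse in (20.0, 15.0, 16.0):
    state["w"][0] += 1.0             # stand-in for one epoch of training
    if rmse < best_rmse:
        best_rmse = rmse
        best_state = copy.deepcopy(state)   # snapshot; a bare reference would keep changing

print(best_rmse, best_state["w"], state["w"])  # 15.0 [2.0] [3.0]
```

The snapshot keeps the weights from the best epoch even though training went on to modify the live state.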

history["rmse"].append(eval_stats["rmse"])

Append per-epoch metrics for plotting / logging.

history["nasa"].append(eval_stats["nasa_score"])

Same for NASA.

history["rul_loss"].append(train_stats["rul_loss"])

Per-task loss tracking.

history["hs_loss"].append(train_stats["hs_loss"])

Same.

history["lr"].append(optimizer.param_groups[0]["lr"])

Track the lr trajectory.

return {"history": history, "best_rmse": best_rmse}

Final return.

EXECUTION STATE
⬆ return = Dict with full per-epoch history + the best RMSE seen. Caller plots / saves.
np.random.seed(0)

Repro.

EXECUTION STATE
📚 np.random.seed(s) = Set NumPy's global PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
print("training pipeline (pseudo)")

Header for the pseudo-output.

EXECUTION STATE
Output (paper-realistic) =
training pipeline (pseudo)
  - dataset       : FD002
  - epochs        : 200
  - paper RMSE    : 13.36 (FD002 + FD004 average)
  - paper NASA    : 1302  (FD002 + FD004 average)
# Pseudo-pipeline. Real Python uses PyTorch (next block); this version
# wires the abstract pieces together so the SHAPES and CONTROL FLOW are
# visible without DL framework specifics.
# build_dataset, DataLoader, DualTaskModel, train_epoch, evaluate, etc. are
# stand-ins for the paper components, so only the smoke test at the bottom runs.

import numpy as np


def train_one_dataset(dataset_name: str = "FD002",
                      epochs: int = 200,
                      lr: float = 1e-3,
                      weight_decay: float = 1e-4,
                      grad_clip: float = 1.0,
                      ema_decay: float = 0.999) -> dict:
    """End-to-end AMNL training for one C-MAPSS subset.

    Returns history dict with per-epoch loss / RMSE / NASA score.
    """
    # 1-2. Data: build dataset and loader (§7)
    train_ds = build_dataset(dataset_name, split="train")
    test_ds = build_dataset(dataset_name, split="test")
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    # 3. Model: §11.4 DualTaskModel
    model = DualTaskModel(c_in=14, lstm_hidden=256, num_heads=8,
                          shared_dim=32, num_classes=3)

    # 4-5. Loss stack: §14.1 + §15.1
    rul_criterion = moderate_weighted_mse_loss
    hs_criterion = cross_entropy
    mtl_loss = FixedWeightLoss(rul_weight=0.5, health_weight=0.5)

    # 6. Optimiser + scheduler: §15.2
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=30, min_lr=5e-6)

    # 7. EMA + grad clip: §15.3
    ema = ExponentialMovingAverage(model, decay=ema_decay)

    # ---------- Per-epoch loop ----------
    history = {"rmse": [], "nasa": [], "rul_loss": [], "hs_loss": [], "lr": []}
    best_rmse = float("inf")
    for epoch in range(epochs):
        # 8. Warmup (§15.2). Guarded so that later epochs do not overwrite
        #    the learning-rate cuts made by the plateau scheduler.
        if epoch <= 10:
            for pg in optimizer.param_groups:
                pg["lr"] = lr * (0.1 + 0.9 * min(epoch, 10) / 10)

        # 9. Training step (§14.4)
        train_stats = train_epoch(model, train_loader,
                                  rul_criterion, hs_criterion, mtl_loss,
                                  optimizer, ema,
                                  grad_clip=grad_clip)

        # 10. Eval with EMA shadow (§15.3 + §13)
        ema.apply_shadow(model)
        eval_stats = evaluate(model, test_loader)
        ema.restore(model)

        # 11. Scheduler step (§15.2), only once warmup has finished
        if epoch >= 10:
            scheduler.step(eval_stats["rmse"])

        # 12. Track best
        if eval_stats["rmse"] < best_rmse:
            best_rmse = eval_stats["rmse"]

        history["rmse"].append(eval_stats["rmse"])
        history["nasa"].append(eval_stats["nasa_score"])
        history["rul_loss"].append(train_stats["rul_loss"])
        history["hs_loss"].append(train_stats["hs_loss"])
        history["lr"].append(optimizer.param_groups[0]["lr"])

    return {"history": history, "best_rmse": best_rmse}


# ---------- Smoke test ----------
np.random.seed(0)
print("training pipeline (pseudo)")
print("  - dataset       : FD002")
print("  - epochs        : 200")
print("  - paper RMSE    : 13.36 (FD002 + FD004 average)")
print("  - paper NASA    : 1302  (FD002 + FD004 average)")
# In real code: out = train_one_dataset()

PyTorch: Paper Trainer Driver

Production version. Imports every paper-canonical piece by name and runs them in the trainer's exact order. The per-epoch print emits the same shape the paper trainer logs.

train_amnl() — paper-canonical end-to-end
🐍train_amnl.py
import torch

Top-level PyTorch.

EXECUTION STATE
📚 torch = Tensor library + autograd + nn modules + optim.
import torch.nn.functional as F

F.cross_entropy for the health branch.

import torch.optim as optim

Optimisers and lr schedulers.

from torch.utils.data import DataLoader

Standard DataLoader.

from grace.data.cmapss_dataset import CMAPSSFullDataset

Paper's sliding-window dataset (§7).

from grace.models.dual_task_model import DualTaskModel

§11.4 architecture.

from grace.core.weighted_mse import moderate_weighted_mse_loss

§14.1 AMNL RUL loss.

from grace.core.baselines import FixedWeightLoss

§15.1 0.5/0.5 combiner.

from grace.training.callbacks import ExponentialMovingAverage

§15.3 EMA helper.

from grace.training.evaluator import evaluate_model

§13 RMSE + NASA computation.

def train_amnl(dataset_name="FD002", epochs=200, lr=1e-3) -> dict:

Top-level driver. Three knobs - everything else is paper default.

EXECUTION STATE
⬇ input: dataset_name = Which C-MAPSS subset.
⬇ input: epochs = 200 = Paper default.
⬇ input: lr = 1e-3 = Paper default.
⬆ returns = Dict {history, best_rmse}.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Pick GPU when available, fall back to CPU. Paper code does this on every entrypoint.

EXECUTION STATE
📚 torch.device(string) = Construct a device handle from a string.
📚 torch.cuda.is_available() = Returns True if a CUDA-enabled GPU is detected and usable.
train_ds = CMAPSSFullDataset(dataset_name, split="train")

Build the training partition.

test_ds = CMAPSSFullDataset(dataset_name, split="test")

Test partition.

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

Training DataLoader with parallel workers and pinned memory for fast GPU transfers.

EXECUTION STATE
⬇ arg: batch_size = 32 = Paper default.
⬇ arg: shuffle = True = Reshuffle every epoch.
⬇ arg: num_workers = 4 = 4 background processes prefetch batches concurrently with training. Speedup is dataset-dependent.
⬇ arg: pin_memory = True = Allocate batches in pinned (page-locked) host memory for faster .to(cuda) transfers. Essentially free if you have a GPU.
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

Test loader. shuffle=False because we report per-engine results in input order.

EXECUTION STATE
⬇ arg: shuffle = False = Determinism for per-engine reporting.
⬇ arg: num_workers = 2 = Less parallelism for test - smaller dataset, fewer batches.
model = DualTaskModel(c_in=14, lstm_hidden=256, num_heads=8, shared_dim=32, num_classes=3).to(device)

Build §11.4 model and move to device. Method-chained .to(device) returns the same module so the assignment captures the moved version.

EXECUTION STATE
📚 .to(device) = Move every parameter and buffer to the device. For an nn.Module this modifies the module in place and returns self (unlike Tensor.to, which returns a new tensor when a copy is needed).
rul_criterion = moderate_weighted_mse_loss

First-class function reference. No call.

hs_criterion = F.cross_entropy

First-class function reference.

mtl_loss = FixedWeightLoss(rul_weight=0.5, health_weight=0.5)

Paper combiner. Replace this line for ablations - everything else stays.

optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4, betas=(0.9, 0.999), eps=1e-8)

AdamW with paper-canonical hyperparameters.

EXECUTION STATE
⬇ arg: weight_decay = 1e-4 = Decoupled weight decay (§15.2).
⬇ arg: betas = (0.9, 0.999) = Canonical Adam moment decays.
⬇ arg: eps = 1e-8 = Numerical stabiliser in the denominator.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=30, min_lr=5e-6)

Reactive scheduler.

EXECUTION STATE
⬇ arg: mode = "min" = Watch a metric expected to DECREASE (RMSE).
⬇ arg: factor = 0.5 = 50% reduction per trigger.
⬇ arg: patience = 30 = Epochs without improvement.
⬇ arg: min_lr = 5e-6 = Floor.
ema = ExponentialMovingAverage(model, decay=0.999)

EMA tracker (§15.3). Snapshots all model parameters now.

history = {"rmse": [], "nasa": [], "lr": []}

Per-epoch logging dict.

best_rmse = float("inf")

Initial running best.

for epoch in range(epochs):

Per-epoch loop. 200 iterations.

if epoch <= 10:

Warmup window. The bound is inclusive so that epoch 10 sets the factor to 1.0 before the scheduler takes over; stopping at epoch 9 would leave lr stuck at 0.91 · lr.

for pg in optimizer.param_groups:

Iterate AdamW's parameter groups.

pg["lr"] = lr * (0.1 + 0.9 * epoch / 10)

Linear warmup formula. Same as §15.2.

model.train()

Switch to training mode. Activates dropout, makes BatchNorm use batch stats. CRITICAL - eval mode silently disables dropout and corrupts AMNL's sample-weighted regularisation.

for seq, rul_tgt, hs_tgt, _ in train_loader:

Batch loop. Underscore _ discards the uid (used in evaluation but not training).

EXECUTION STATE
iter vars = seq (B, T, F), rul_tgt (B,), hs_tgt (B,), _ uid (B,).
seq, rul_tgt, hs_tgt = seq.to(device), rul_tgt.to(device).view(-1, 1), hs_tgt.to(device)

Move tensors to device and reshape rul_tgt to (B, 1) to match the model output. Multi-assignment unpacks the right-hand-side tuple.

EXECUTION STATE
📚 .view(-1, 1) = Reshape - infer first dim, fix second to 1. (B,) → (B, 1).
optimizer.zero_grad()

Reset .grad before backward.

rul_pred, hs_logits = model(seq)

DualTaskModel forward returns the (rul, logits) tuple.

rul_loss = rul_criterion(rul_pred, rul_tgt)

AMNL weighted MSE.

hs_loss = hs_criterion(hs_logits, hs_tgt)

Cross-entropy.

loss = mtl_loss(rul_loss, hs_loss)

FixedWeightLoss combine: 0.5 · rul_loss + 0.5 · hs_loss.

loss.backward()

Reverse-mode autograd. Populates every model parameter's .grad.

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Joint-norm clip (§15.3). Trailing underscore = in-place. Runs AFTER backward, BEFORE step.

EXECUTION STATE
⬇ arg: max_norm = 1.0 = Paper default.
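"Joint-norm" means one norm is computed over all gradients together and every gradient is scaled by the same coefficient. A pure-Python sketch of that rule follows; the function name is illustrative, and unlike clip_grad_norm_ it returns new lists instead of modifying gradients in place.

```python
import math


def clip_by_joint_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradient vectors by one shared coefficient if their joint L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm


grads = [[3.0], [0.0, 4.0]]                       # two "parameters", joint norm 5
clipped, norm_before = clip_by_joint_norm(grads)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
print(norm_before, round(norm_after, 4))          # 5.0 1.0
```

Because the scale is shared, the direction of the overall update is preserved; only its length shrinks.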
optimizer.step()

Apply AdamW update with the (clipped) gradients.

ema.update(model)

Pull EMA shadow toward freshly-updated live weights. ORDER MATTERS - must come AFTER optimizer.step().

ema.apply_shadow(model)

Swap shadow IN for evaluation.

eval_stats = evaluate_model(model, test_loader, device)

Compute RMSE / NASA / accuracy / F1. Returns dict.

ema.restore(model)

Put live weights back. Forgetting this silently breaks training.

if epoch >= 10:

Skip scheduler.step during warmup.

scheduler.step(eval_stats["rmse"])

ReduceLROnPlateau records the metric and possibly cuts lr.

if eval_stats["rmse"] < best_rmse:

Track best.

best_rmse = eval_stats["rmse"]

Update.

history["rmse"].append(eval_stats["rmse"])

Log per-epoch.

history["nasa"].append(eval_stats["nasa_score"])

Same for NASA.

history["lr"].append(optimizer.param_groups[0]["lr"])

Track lr trajectory.

if epoch % 20 == 0 or epoch == epochs - 1:

Print every 20 epochs, plus the final epoch, to keep stdout manageable.

print(f"epoch={epoch:>3d}  rmse={eval_stats['rmse']:.3f}  nasa={eval_stats['nasa_score']:.1f}  lr={optimizer.param_groups[0]['lr']:.2e}")

Per-epoch log line.

EXECUTION STATE
→ :>3d = Integer, right-aligned, width 3.
→ :.3f = Float, 3 decimals.
→ :.1f = Float, 1 decimal.
→ :.2e = Float in scientific, 2 decimals.
Output (one realisation, selected lines) =
epoch=  0  rmse=32.412  nasa=4520.3  lr=1.00e-04
epoch= 20  rmse=18.234  nasa=2103.5  lr=1.00e-03
epoch= 60  rmse=14.812  nasa=1456.2  lr=1.00e-03
epoch=120  rmse=13.612  nasa=1325.4  lr=2.50e-04
epoch=199  rmse=13.367  nasa=1302.1  lr=6.25e-05
→ reading = RMSE drops fastest during warmup + early steady-state. Plateau cuts visible at epoch ~90 and ~120 from the lr column. Final RMSE ≈ 13.36 - matches the paper Table II for FD002 + FD004 average.
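The four format specifiers can be checked in isolation. The helper below (format_log is an illustrative name) reuses the log template with the epoch-0 values from the sample output:

```python
def format_log(epoch: int, rmse: float, nasa: float, lr: float) -> str:
    """Build one log line with the same format specifiers as the trainer's print."""
    return f"epoch={epoch:>3d}  rmse={rmse:.3f}  nasa={nasa:.1f}  lr={lr:.2e}"


line = format_log(0, 32.412, 4520.3, 1e-4)
print(line)  # epoch=  0  rmse=32.412  nasa=4520.3  lr=1.00e-04
```

The fixed width on the epoch field keeps the columns aligned once epochs reach three digits.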
return {"history": history, "best_rmse": best_rmse}

Final return.

if __name__ == "__main__":

Standard Python entrypoint guard. Only runs when the script is executed directly (not when imported). Lets the same file be both a library and a CLI.

EXECUTION STATE
→ __name__ = Special variable. Equals '__main__' when running as a script; module name when imported.
out = train_amnl("FD002", epochs=200)

Run the full pipeline.

print(f"FD002 best RMSE : {out['best_rmse']:.3f}")

Final summary.

EXECUTION STATE
Output (paper-realistic) = FD002 best RMSE : 13.367
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

# Paper-canonical pieces (paper_ieee_tii/grace/...)
from grace.data.cmapss_dataset import CMAPSSFullDataset           # §7
from grace.models.dual_task_model import DualTaskModel            # §11.4
from grace.core.weighted_mse import moderate_weighted_mse_loss    # §14.1
from grace.core.baselines import FixedWeightLoss                  # §15.1
from grace.training.callbacks import ExponentialMovingAverage     # §15.3
from grace.training.evaluator import evaluate_model               # §13


def train_amnl(dataset_name: str = "FD002",
               epochs: int = 200,
               lr: float = 1e-3) -> dict:
    """Reproduce the paper's AMNL training run on one C-MAPSS subset.

    Source: paper_ieee_tii/experiments/train_amnl_v7.py (top-level driver).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1-2. Data
    train_ds = CMAPSSFullDataset(dataset_name, split="train")
    test_ds = CMAPSSFullDataset(dataset_name, split="test")
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False,
                             num_workers=2, pin_memory=True)

    # 3. Model
    model = DualTaskModel(c_in=14, lstm_hidden=256, num_heads=8,
                          shared_dim=32, num_classes=3).to(device)

    # 4-5. Loss stack
    rul_criterion = moderate_weighted_mse_loss
    hs_criterion = F.cross_entropy
    mtl_loss = FixedWeightLoss(rul_weight=0.5, health_weight=0.5)

    # 6. Optim + sched
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4,
                            betas=(0.9, 0.999), eps=1e-8)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
                                                     factor=0.5, patience=30,
                                                     min_lr=5e-6)

    # 7. EMA
    ema = ExponentialMovingAverage(model, decay=0.999)

    # ---------- Training loop ----------
    history = {"rmse": [], "nasa": [], "lr": []}
    best_rmse = float("inf")
    for epoch in range(epochs):
        # 8. Warmup. Inclusive bound: epoch 10 sets the factor to 1.0, after
        #    which the plateau scheduler alone controls lr.
        if epoch <= 10:
            for pg in optimizer.param_groups:
                pg["lr"] = lr * (0.1 + 0.9 * epoch / 10)

        # 9. Train (evaluate_model leaves the model in eval mode, so re-enter train mode)
        model.train()
        for seq, rul_tgt, hs_tgt, _ in train_loader:
            seq = seq.to(device)
            rul_tgt = rul_tgt.to(device).view(-1, 1)   # (B,) -> (B, 1) to match rul_pred
            hs_tgt = hs_tgt.to(device)
            optimizer.zero_grad()
            rul_pred, hs_logits = model(seq)
            rul_loss = rul_criterion(rul_pred, rul_tgt)
            hs_loss = hs_criterion(hs_logits, hs_tgt)
            loss = mtl_loss(rul_loss, hs_loss)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            ema.update(model)          # must follow optimizer.step()

        # 10. Evaluate with EMA weights, then restore the live weights
        ema.apply_shadow(model)
        eval_stats = evaluate_model(model, test_loader, device)
        ema.restore(model)

        # 11. Scheduler step (after warmup)
        if epoch >= 10:
            scheduler.step(eval_stats["rmse"])

        # 12. Track best
        if eval_stats["rmse"] < best_rmse:
            best_rmse = eval_stats["rmse"]

        history["rmse"].append(eval_stats["rmse"])
        history["nasa"].append(eval_stats["nasa_score"])
        history["lr"].append(optimizer.param_groups[0]["lr"])

        # Log every 20 epochs and on the final epoch
        if epoch % 20 == 0 or epoch == epochs - 1:
            print(f"epoch={epoch:>3d}  rmse={eval_stats['rmse']:.3f}  nasa={eval_stats['nasa_score']:.1f}  lr={optimizer.param_groups[0]['lr']:.2e}")

    return {"history": history, "best_rmse": best_rmse}


if __name__ == "__main__":
    out = train_amnl("FD002", epochs=200)
    print(f"FD002 best RMSE : {out['best_rmse']:.3f}")

Drop-In for Other PHM Domains

The 12-stage pipeline transfers wherever you have (a) a sliding-window time-series dataset, (b) a primary regression target, and (c) an auxiliary classification target. Swap the dataset and the c_in; everything else is unchanged.

| Domain | Dataset class | c_in | max_rul | Other changes |
|---|---|---|---|---|
| RUL prediction (this book) | CMAPSSFullDataset | 14 | 125 | none |
| N-CMAPSS DS02 | NCMAPSSDataset | 20 | 100 | model_configs.ncmapss_20feat |
| Battery SoH + fault type | BatteryDataset | 5 | 1.0 | max_rul=1.0 in moderate_weighted_mse_loss |
| Wind-turbine SCADA | SCADADataset | 12 | 720 | longer windows (T=144) |
| MRI tumour growth + benign/malignant | MRIFollowupDataset | vol | 20 | regression on volume, BCE for binary classification |
| Disk RUL + SMART anomaly type | BackblazeDataset | 16 | 180 | daily windows instead of cycles |

The recipe is the model. Most PHM applications can swap CMAPSSFullDataset for their own dataset class and reuse the rest of the script verbatim. The paper code is structured to make this easy - all the knobs live in three places: dataset class, c_in, and max_rul.
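One way to keep those three knobs together is a small per-domain config record. This is a sketch under assumptions: DomainConfig, DOMAINS and model_kwargs are hypothetical names, dataset classes are referenced by name only, and only the first three table rows are included.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DomainConfig:
    """The three per-domain knobs named in the text (hypothetical container)."""
    dataset_class: str   # class name only; the classes themselves live in the paper code
    c_in: int
    max_rul: float


DOMAINS = {
    "cmapss": DomainConfig("CMAPSSFullDataset", c_in=14, max_rul=125),
    "ncmapss_ds02": DomainConfig("NCMAPSSDataset", c_in=20, max_rul=100),
    "battery": DomainConfig("BatteryDataset", c_in=5, max_rul=1.0),
}


def model_kwargs(domain: str) -> dict:
    """Model arguments for a domain: only c_in changes, the rest stay at the defaults used above."""
    cfg = DOMAINS[domain]
    return dict(c_in=cfg.c_in, lstm_hidden=256, num_heads=8,
                shared_dim=32, num_classes=3)


print(model_kwargs("battery"))
```

Whether num_classes=3 suits a given auxiliary task depends on that domain's label set, so treat it as a placeholder when moving away from C-MAPSS.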

Three End-to-End Pitfalls

Pitfall 1: Calling scheduler.step() during warmup. ReduceLROnPlateau records the val metric on every call. If you call scheduler.step(val) during warmup (where lr is being externally set), the scheduler's internal 'best' and its bad-epoch counter already reflect the warmup epochs, so the first post-warmup cut fires too early. ALWAYS guard with if epoch >= warmup_epochs:.
Pitfall 2: Forgetting ema.restore() after eval. If you call apply_shadow for val and forget restore, training continues with the SHADOW weights. Updates accumulate in the wrong place; the next eval sees corrupted shadow values. The loss curves look plausible, but the runs are irreproducible.
Pitfall 3: model.eval() mode during training. The evaluate_model helper internally calls model.eval(). If the trainer doesn't re-enter model.train() at the start of the next epoch, dropout stays off and AMNL's sample-weighted regularisation is silently broken. The paper trainer puts model.train() at the top of every _train_epoch call.

The point. Twelve stages in under 100 lines reproduce the paper's AMNL row of Table II. Every component lives in paper_ieee_tii/grace/; the driver lives in experiments/train_amnl_v7.py. Chapter 16 reports what those numbers actually look like at convergence.

Takeaway — End of Chapter 15

  • Twelve stages. 7 setup, 5 per-epoch. Wired together by paper_ieee_tii/grace/training/trainer.py.
  • Three knobs. train_amnl(dataset_name, epochs, lr); everything else is paper-canonical default.
  • Order matters. warmup → train → eval(EMA) → restore → scheduler → track best.
  • ~13.36 RMSE on FD002+FD004 avg. Paper Table II row. Reproduce it by averaging seeds 0-4 per dataset.
  • End of Chapter 15. Chapter 16 reports empirical results across all four C-MAPSS subsets.