§14.1-§14.3 designed AMNL one piece at a time: the sample-weight schedule, the linear shape, the w_max ceiling. This section plugs them into the §11.4 DualTaskModel and runs one full training step end to end.
Five components: three from the paper code, two straight from PyTorch.

| Component | Source file / module | Purpose |
|---|---|---|
| DualTaskModel | paper_ieee_tii/grace/models/dual_task_model.py | shared backbone + RUL & health heads |
| moderate_weighted_mse_loss | paper_ieee_tii/grace/core/weighted_mse.py | AMNL RUL loss with sample weights |
| F.cross_entropy | torch.nn.functional | health-branch loss |
| FixedWeightLoss(0.5, 0.5) | paper_ieee_tii/grace/core/baselines.py | 0.5/0.5 task combiner |
| torch.optim.Adam(lr=1e-3) | torch.optim | optimiser (paper default) |
Plug-and-play. Swapping AMNL → GABA in Part VI changes ONE line: the mtl_loss assignment. Everything else - DualTaskModel, weighted_mse, Adam - stays unchanged. That is the point of the paper's factory pattern.
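A minimal sketch of why the swap is a one-line change, in plain Python so it runs without PyTorch. The paper's FixedWeightLoss is an nn.Module in grace/core/baselines.py. The class below is a hypothetical stand-in that only reproduces the fixed-weight arithmetic. MTL_LOSSES and make_mtl_loss are illustrative names, not the paper's factory. A GABA combiner would be registered the same way, but its internals belong to Part VI and are not sketched here.

```python
from typing import Callable, Dict

Combiner = Callable[[float, float], float]


class FixedWeightLoss:
    """Hypothetical stand-in for the paper's FixedWeightLoss (an nn.Module there)."""

    def __init__(self, w_rul: float, w_hs: float) -> None:
        self.w_rul = w_rul
        self.w_hs = w_hs

    def __call__(self, rul_loss: float, hs_loss: float) -> float:
        # Fixed linear combination of the two task losses.
        return self.w_rul * rul_loss + self.w_hs * hs_loss


# Hypothetical registry: the loop only needs "something that maps
# (rul_loss, hs_loss) to one scalar", so an ablation changes one string.
MTL_LOSSES: Dict[str, Callable[[], Combiner]] = {
    "amnl": lambda: FixedWeightLoss(0.5, 0.5),
    "rul_heavy": lambda: FixedWeightLoss(0.7, 0.3),  # e.g. the wind-turbine row later in this section
}


def make_mtl_loss(name: str) -> Combiner:
    return MTL_LOSSES[name]()


mtl_loss = make_mtl_loss("amnl")   # the ONE line an ablation changes
print(mtl_loss(10.0, 2.0))         # 6.0
```

The rest of the loop only ever calls mtl_loss(rul_loss, hs_loss). Keeping the combiner behind that single call is what makes the swap local.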
The Files in paper_ieee_tii
For reference, the paper code is organised as follows. Every section referenced in this chapter maps to one of the files below.

| Path (under paper_ieee_tii/) | Section | Lines (approx.) |
|---|---|---|
| grace/core/weighted_mse.py | §14.1, §14.2 | 39 |
| grace/core/baselines.py (FixedWeightLoss) | §14.4 | ~17 of 437 |
| grace/core/baselines.py (BaseMTLLoss) | §14.4 | ~12 of 437 |
| grace/models/dual_task_model.py | §11.4 | 54 |
| grace/models/task_heads.py | §11.2, §11.3 | 53 |
| grace/models/backbone.py | §8-§11.1 | 152 |
| grace/training/trainer.py | §14.4 | 290 |
Reproduce the paper. The command cd paper_ieee_tii && python experiments/train_amnl_v7.py --dataset FD002 --seed 0 runs the full AMNL pipeline. The training step inside that script is exactly the function walked through in the PyTorch section below.
Interactive: One Step, Eight Stages
Watch the gradient stage. At stage 6 (backward) the autograd engine flows the loss gradient back through the heads, into the FC funnel, then into the attention block, BiLSTM, and CNN. The §12 imbalance shows up here - on the shared backbone the RUL gradient is ~500× the HS gradient, even with FixedWeightLoss(0.5, 0.5).
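A rough way to see this imbalance numerically is a NumPy sketch on the same single-Linear toy used in the next subsection, not the paper's CNN+BiLSTM+Attention backbone. It computes the two task terms of the gradient with respect to the shared features z separately and compares their norms. The random init scales, the shapes and the omission of the output clamp are assumptions made to keep the measurement simple. The exact ratio will therefore differ from the ~500× quoted for the real model.

```python
import numpy as np

rng = np.random.default_rng(0)
B, F_in, D, K = 4, 14, 256, 3

x_bar = rng.standard_normal((B, F_in)) / np.sqrt(30)  # time-averaged inputs (assumed scale)
W_back = rng.standard_normal((F_in, D)) * 0.05
W_rul = rng.standard_normal((D, 1)) * 0.05
W_hs = rng.standard_normal((D, K)) * 0.05

rul_tgt = np.array([5.0, 40.0, 90.0, 125.0])
hs_tgt = np.array([2, 1, 0, 0])

z = x_bar @ W_back                  # (B, D) shared features
rul_pred = (z @ W_rul)[:, 0]        # clamp omitted for this measurement
logits = z @ W_hs                   # (B, K)

# RUL term: d(0.5 * weighted MSE)/dz
weights = 1.0 + np.clip(1.0 - rul_tgt / 125.0, 0.0, 1.0)
d_rul = (2.0 / B) * weights * (rul_pred - rul_tgt)       # (B,)
g_rul = 0.5 * d_rul[:, None] * W_rul.T                   # (B, D)

# Health term: d(0.5 * cross-entropy)/dz
p = np.exp(logits - logits.max(axis=1, keepdims=True))
p /= p.sum(axis=1, keepdims=True)
onehot = np.eye(K)[hs_tgt]
g_hs = 0.5 * ((p - onehot) / B) @ W_hs.T                 # (B, D)

ratio = np.linalg.norm(g_rul) / np.linalg.norm(g_hs)
print(f"||RUL grad|| / ||HS grad|| on shared features: {ratio:.0f}x")
```

Both terms share the 1/B factor. The RUL term scales with the weighted residual, which is tens of RUL units here, while each entry of p - onehot is at most 1 in magnitude. That asymmetry survives the equal 0.5 task weights.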
Python: Manual Step from Scratch
Pure NumPy reference - same algorithm as the paper, but with all gradients computed analytically so the chain rule is visible. The toy backbone is a single Linear layer (real backbone is CNN+BiLSTM+Attention) but the loss machinery is identical.
NumPy is the only dependency of this from-scratch reference implementation. We use ndarray, broadcasting, np.maximum, np.clip, np.exp, np.log, np.arange, plus matmul (@) for the analytic gradients.
def amnl_training_step(seq, rul_tgt, health_tgt, params, lr=1e-3, max_rul=125.0): ONE gradient-descent step. Mirrors the paper's _train_epoch inner loop (paper_ieee_tii/grace/training/trainer.py:243-284), but with a single-Linear backbone instead of CNN+BiLSTM+Attention so the algebra is visible. Inputs:
- seq is (B, T, F_in), a batch of windows. The smoke test uses (4, 30, 14).
- rul_tgt is (B, 1), the capped RUL targets. The PyTorch trainer gets this shape via .view(-1, 1).
- health_tgt is (B,), integer class labels.
z = seq.mean(axis=1) @ params["W_back"]: Toy backbone: average over time, then project. .mean(axis=1) reduces over the time axis, collapsing (B, T, F_in) → (B, F_in); the real model uses CNN+BiLSTM+Attention here, and we use the mean for clarity. The real backbone's output has the same shape (B, 256) and plays the same role (the shared 256-D feature vector).
rul_pred = np.maximum(0.0, z @ params["W_rul"]): Linear projection to a scalar, then a ReLU-style clamp at 0. This matches the paper's torch.clamp(rul_head(features), min=0.0) in DualTaskModel.forward. np.maximum takes the element-wise max of its two arguments:
- 0.0, the lower bound, since RUL must be non-negative;
- z @ W_rul, the raw (B, 256) @ (256, 1) = (B, 1) prediction.
grad_rul (step 6): Backprop through W_rul. The Boolean mask rul_pred[:, 0] > 0 is True where the prediction was not clipped. It handles the clamp on the output: if the pre-clamp output was negative, that sample contributes zero gradient. The same mask must also gate the RUL term that flows into the shared W_back. [:, None] adds a trailing axis, (B,) → (B, 1), so the matmul shapes match.
grad_back: Backprop into the SHARED W_back. Both heads contribute via the chain rule, each scaled by its 0.5 task weight from FixedWeightLoss. THIS is the line where the §12 gradient imbalance shows up: the RUL term dominates the sum even with the 0.5 weight. Both terms carry the same 1/B factor. The RUL term, however, is proportional to the weighted residual w · (ŷ - y), which is tens to hundreds of RUL units early in training. Each entry of the health term (p - onehot) is at most 1 in magnitude. Even with 0.5/0.5 task weights, RUL therefore dominates the shared-backbone gradient by ~500×.
- params["W_back"] -= lr * grad_back: Plain SGD update on the shared backbone. The real paper uses Adam.
- params["W_rul"] -= lr * 0.5 * grad_rul: Update the RUL head, with the 0.5 task weight applied here.
- params["W_hs"] -= lr * 0.5 * grad_hs: Update the health head, likewise with its 0.5 task weight.
return {...}: Logging dict for the trainer to consume. Its keys (rul_loss, health_loss, total_loss, weights_min, weights_max) match what the paper trainer logs.
np.random.seed(0): Reproducibility. Seeds NumPy's legacy global PRNG, which the smoke test's np.random.randn calls draw from.
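Output: total_loss ≈ 3.59 × 10³, dominated by rul_loss. The initial predictions are near zero, so rul_loss ≈ (1.96 · 5² + 1.68 · 40² + 1.28 · 90² + 1.00 · 125²) / 4 ≈ 7.18 × 10³. Reading: the total is essentially 0.5 · rul_loss, because health_loss starts near ln 3 ≈ 1.1 for three near-uniform classes. AMNL's sample weighting acts WITHIN rul_loss; it does NOT fix the BETWEEN-task imbalance.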
weights = 1.0 + np.clip(1.0 - rul_tgt[:, 0] / max_rul, 0, 1): Schedule range. With targets [5, 40, 90, 125] the weights are [1.96, 1.68, 1.28, 1.00]. The smoke test therefore prints weights : [1.00, 1.96] (min, max).
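The schedule values above follow directly from the weight line in step 2. A quick check, assuming the smoke test's max_rul = 125:

```python
import numpy as np

# AMNL sample-weight schedule (step 2 of the NumPy step): 2.0 at RUL 0,
# falling linearly to 1.0 at max_rul and staying at 1.0 above it.
max_rul = 125.0
rul_tgt = np.array([5.0, 40.0, 90.0, 125.0])
weights = 1.0 + np.clip(1.0 - rul_tgt / max_rul, 0.0, 1.0)
print(weights.round(2), weights.min(), weights.max())
```

The smallest target (5, closest to failure) receives the largest weight. A target at or above the cap gets the neutral weight 1.0.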
import numpy as np


def amnl_training_step(seq: np.ndarray,
                       rul_tgt: np.ndarray,
                       health_tgt: np.ndarray,
                       params: dict,
                       lr: float = 1e-3,
                       max_rul: float = 125.0) -> dict:
    """One AMNL gradient-descent step in pure NumPy.

    Mirrors paper_ieee_tii/grace/training/trainer.py::_train_epoch but with
    a 1-Linear backbone instead of CNN+BiLSTM+Attention, so the algebra is
    visible. Updates `params` in place and returns a logging dict.
    """
    # 1. Forward
    x_bar = seq.mean(axis=1)                           # (B, F_in) time-averaged input
    z = x_bar @ params["W_back"]                       # (B, 256) shared features
    rul_pred = np.maximum(0.0, z @ params["W_rul"])    # (B, 1) clamped >= 0
    health_logits = z @ params["W_hs"]                 # (B, 3)

    # 2. AMNL sample weights
    weights = 1.0 + np.clip(1.0 - rul_tgt[:, 0] / max_rul, 0, 1)   # (B,)

    # 3. Weighted MSE on RUL branch
    residual = (rul_pred - rul_tgt)[:, 0]              # (B,)
    rul_loss = float((weights * residual ** 2).mean())

    # 4. Cross-entropy on health branch (stable log-softmax)
    z_max = health_logits.max(-1, keepdims=True)
    log_p = (health_logits - z_max) - np.log(
        np.exp(health_logits - z_max).sum(-1, keepdims=True)
    )
    health_loss = float(-log_p[np.arange(len(health_tgt)), health_tgt].mean())

    # 5. Combine: FixedWeightLoss(0.5, 0.5)
    total_loss = 0.5 * rul_loss + 0.5 * health_loss

    # 6. Backward (analytic - the tedious part)
    active = rul_pred[:, 0] > 0                        # clamp mask: no gradient where clipped
    d_rul = (2.0 / len(rul_tgt)) * weights * residual * active   # ∂L_rul/∂(pre-clamp output)
    grad_rul = z.T @ d_rul[:, None]                    # (256, 1)

    p = np.exp(log_p)
    onehot = np.zeros_like(p)
    onehot[np.arange(len(health_tgt)), health_tgt] = 1.0
    d_logits = (p - onehot) / len(health_tgt)          # (B, 3)
    grad_hs = z.T @ d_logits                           # (256, 3)
    grad_back = x_bar.T @ (                            # mask reaches W_back via d_rul
        0.5 * d_rul[:, None] * params["W_rul"].T
        + 0.5 * d_logits @ params["W_hs"].T
    )                                                  # (F_in, 256)

    # 7. Adam-style step (simplified to plain SGD here for clarity)
    params["W_back"] -= lr * grad_back
    params["W_rul"] -= lr * 0.5 * grad_rul
    params["W_hs"] -= lr * 0.5 * grad_hs

    return {
        "rul_loss": rul_loss,
        "health_loss": health_loss,
        "total_loss": total_loss,
        "weights_min": float(weights.min()),
        "weights_max": float(weights.max()),
    }


# ---------- Smoke test ----------
np.random.seed(0)
B, T, F_in = 4, 30, 14
params = {
    "W_back": np.random.randn(F_in, 256).astype(np.float32) * 0.05,
    "W_rul": np.random.randn(256, 1).astype(np.float32) * 0.05,
    "W_hs": np.random.randn(256, 3).astype(np.float32) * 0.05,
}
seq = np.random.randn(B, T, F_in).astype(np.float32)
rul_tgt = np.array([[5.0], [40.0], [90.0], [125.0]], dtype=np.float32)
health_tgt = np.array([2, 1, 0, 0])

stats = amnl_training_step(seq, rul_tgt, health_tgt, params)
print(f"rul_loss    : {stats['rul_loss']:.4f}")
print(f"health_loss : {stats['health_loss']:.4f}")
print(f"total_loss  : {stats['total_loss']:.4f}")
print(f"weights     : [{stats['weights_min']:.2f}, {stats['weights_max']:.2f}]")
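The analytic gradients in step 6 are the error-prone part of the NumPy step above. A central finite-difference check can confirm them. It is sketched here for the RUL head only, with fixed random features z standing in for the backbone output. The harness is an independent test, not part of the paper code. The seed, the probed coordinates and the 1e-3 tolerance are arbitrary choices. Away from the clamp kink, the numeric and analytic values should agree closely.

```python
import numpy as np

rng = np.random.default_rng(1)
B, D = 4, 256
z = rng.standard_normal((B, D))                 # fixed shared features (float64)
W_rul = rng.standard_normal((D, 1)) * 0.05
rul_tgt = np.array([5.0, 40.0, 90.0, 125.0])
weights = 1.0 + np.clip(1.0 - rul_tgt / 125.0, 0.0, 1.0)


def rul_loss(W: np.ndarray) -> float:
    pred = np.maximum(0.0, z @ W)[:, 0]         # clamped prediction, as in the NumPy step
    return float((weights * (pred - rul_tgt) ** 2).mean())


# Analytic gradient: same formula as step 6, clamp mask included
pre = (z @ W_rul)[:, 0]
residual = np.maximum(0.0, pre) - rul_tgt
d_rul = (2.0 / B) * weights * residual * (pre > 0)
grad_analytic = z.T @ d_rul[:, None]            # (D, 1)

# Central differences on a few coordinates of W_rul
eps = 1e-6
idx = [0, 17, 101, 255]
grad_numeric = []
for i in idx:
    W_plus, W_minus = W_rul.copy(), W_rul.copy()
    W_plus[i, 0] += eps
    W_minus[i, 0] -= eps
    grad_numeric.append((rul_loss(W_plus) - rul_loss(W_minus)) / (2 * eps))

max_err = max(abs(g - grad_analytic[i, 0]) for g, i in zip(grad_numeric, idx))
print(f"max |numeric - analytic| = {max_err:.2e}")
```

The loss is piecewise quadratic in W_rul, so central differences carry essentially no truncation error here. The remaining discrepancy is floating-point round-off.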
PyTorch: The Paper's Step
The exact training step from paper_ieee_tii/grace/training/trainer.py (lines 249-283), factored into a function. Reproduces the AMNL paper if you wire it into a DataLoader and run for the paper's 40 epochs.
amnl_step_torch.py: amnl_training_step(), paper-canonical PyTorch
- import torch: Top-level PyTorch (tensor library, autograd, nn, optim).
- import torch.nn as nn: Module containers.
- import torch.nn.functional as F: Stateless ops. F.cross_entropy, the health-branch criterion, combines log_softmax and nll_loss in one numerically stable call.
- from torch.utils.data import DataLoader: The standard PyTorch DataLoader. The paper trainer takes one per split.
- moderate_weighted_mse_loss: AMNL's RUL loss, paper-canonical from paper_ieee_tii/grace/core/weighted_mse.py. Exactly the function from §14.1 / §14.2.
- from grace.core.baselines import FixedWeightLoss: The 0.5/0.5 task combiner, paper-canonical from paper_ieee_tii/grace/core/baselines.py. AMNL ships with FixedWeightLoss(0.5, 0.5); GABA replaces it in Part VI.
amnl_training_step(): ONE training step. It is exactly the body of the paper's _train_epoch inner loop, factored out into a function for clarity. All five components are passed in, so any of them can be swapped for ablations. Arguments:
- model: the DualTaskModel (the §11.4 architecture).
- optimizer: a torch.optim.Optimizer (Adam in the paper).
- mtl_loss: FixedWeightLoss(0.5, 0.5) for AMNL, GABALoss in Part VI.
- rul_criterion: moderate_weighted_mse_loss, the AMNL RUL loss.
- grad_clip = 1.0: the gradient-norm clip threshold (paper standard).

It returns a dict {loss, rul_loss, hs_loss, grad_norm} for logging.
model.train(): Switch to training mode. This activates dropout and makes BatchNorm use batch statistics. The paper trainer always calls this before the loop. .train(mode=True) sets self.training = True on the module and all sub-modules.
rul_tgt = rul_tgt.view(-1, 1): Reshape to (B, 1) so it matches the model's rul_pred shape. moderate_weighted_mse_loss flattens both internally, but matching shapes upfront keeps everything tidy. .view(*shape) returns a view with the requested shape, sharing storage. In (-1, 1), the -1 means "infer" and the 1 fixes the second dim, so (B,) → (B, 1).
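optimizer.zero_grad(): Reset .grad before the new backward pass. The paper uses set_to_none=True, the default since PyTorch 2.0. Gradients are then set to None instead of being zero-filled, which is slightly faster.

F.cross_entropy on the health branch takes two arguments and returns one value:
- input: hs_logits, (B, 3) raw logits, NOT probabilities.
- target: hs_tgt, (B,) int64 class indices, NOT one-hot.
- result: hs_loss, a 0-D scalar tensor.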
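F.cross_entropy's numerical stability comes from working in log space with the largest logit subtracted first. Below is a NumPy sketch of that idea, using the same trick as stage 4 of the from-scratch step. cross_entropy here is an illustrative function of our own, not PyTorch's implementation. It takes raw logits and integer class targets, mirroring the argument contract described above.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, target: np.ndarray) -> float:
    """Mean negative log-likelihood of integer targets under softmax(logits)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)   # largest logit becomes 0
    log_p = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_p[np.arange(len(target)), target].mean())


equal = cross_entropy(np.zeros((4, 3)), np.array([2, 1, 0, 0]))
huge = cross_entropy(np.array([[1000.0, 0.0, -1000.0]]), np.array([1]))
print(equal, np.log(3.0))   # uniform logits give ln 3 for three classes
print(huge)                 # 1000.0, finite even though exp(1000) would overflow
```

Equal logits reproduce the ln 3 ≈ 1.1 starting value seen for the health loss at initialisation.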
loss = mtl_loss(rul_loss, hs_loss): Calls FixedWeightLoss.forward(rul_loss, hs_loss) ⇒ 0.5 · rul_loss + 0.5 · hs_loss. FixedWeightLoss is an nn.Module, so trainer code can swap in GABALoss in Part VI without conditionals or any other change to the loop. The result, loss, is a 0-D scalar tensor whose autograd graph stretches back to the model parameters.
loss.backward(): Reverse-mode autograd through the whole graph, from the heads back into the backbone. .backward(retain_graph=False) accumulates into .grad on every leaf with requires_grad=True, then frees the graph. Afterwards every parameter in DualTaskModel has .grad populated.
Gradient clipping with torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0): Compute the L2 norm of all gradients, as if concatenated into one vector. If that total norm exceeds max_norm, scale every grad in place by max_norm / total_norm. The trailing underscore marks the op as in-place. The arguments and result are:
- parameters = model.parameters(), an iterator over all learnable params.
- max_norm = grad_clip = 1.0, the paper default, which catches the rare exploding-gradient batch on AMNL.
- The return value grad_norm is a 0-D tensor holding the PRE-clip total grad norm.
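The behaviour just described can be sketched in NumPy: one global L2 norm over all gradient arrays, then a single shared scale factor. clip_grad_norm below is a hypothetical helper, not PyTorch's code. It follows the documented contract (the pre-clip norm is returned, and scaling happens in place only when the norm exceeds max_norm). The small eps guard against division by zero is an assumption, and PyTorch's own implementation differs in internal details.

```python
import numpy as np


def clip_grad_norm(grads: list, max_norm: float, eps: float = 1e-6) -> float:
    """Scale every array in `grads` in place so their joint L2 norm is <= max_norm.

    Returns the total norm measured BEFORE clipping.
    """
    total_norm = float(np.sqrt(sum(float((g ** 2).sum()) for g in grads)))
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        for g in grads:
            g *= scale            # in place, like the trailing underscore in PyTorch
    return total_norm


grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]   # joint norm 5
pre = clip_grad_norm(grads, max_norm=1.0)
post = float(np.sqrt(sum((g ** 2).sum() for g in grads)))
print(pre, post)   # 5.0, then just under 1.0
```

A single shared factor preserves the direction of the full gradient (here the 3:4 ratio between the two arrays) and only shortens it.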
optimizer.step(): Reads .grad on every parameter and applies the optimiser's update rule. For Adam that is θ ← θ - lr · m̂ / (√v̂ + ε).
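The Adam rule above, with bias correction, can be sketched in NumPy for a single parameter array. β1 = 0.9, β2 = 0.999 and ε = 1e-8 are torch.optim.Adam's defaults, and lr = 1e-3 is the paper's learning rate. The function name and the explicit state dict are illustrative, not torch.optim internals.

```python
import numpy as np


def adam_step(theta, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update of `theta` in place; `state` holds m, v and step count t."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad          # 1st moment
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2     # 2nd moment
    m_hat = state["m"] / (1 - beta1 ** state["t"])                 # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    theta -= lr * m_hat / (np.sqrt(v_hat) + eps)


theta = np.zeros(3)
state = {"m": np.zeros(3), "v": np.zeros(3), "t": 0}
adam_step(theta, np.array([500.0, -0.01, 2.0]), state)
print(theta)   # each entry moves by about lr, opposite to its gradient's sign
```

On the first step m̂ = g and v̂ = g², so every coordinate moves by roughly lr regardless of gradient scale. This per-coordinate normalisation does not remove the task imbalance: on shared backbone parameters, Adam sees the already-summed RUL + health gradient.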
return {...}: Logging dict with keys loss, rul_loss, hs_loss, grad_norm. .item() turns each 0-D tensor into a Python float so it can be JSON-serialised or written to TensorBoard.
torch.manual_seed(0): Reproducibility; seeds the global PyTorch PRNG.

Reading: rul_loss is ~4400× hs_loss at init, so even with FixedWeightLoss(0.5, 0.5) the total is dominated by the RUL term. AMNL fixes the WITHIN-RUL sample weighting, so near-failure samples pull harder on the head. It does NOT fix the BETWEEN-task imbalance; Part VI does that.
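The five-component recipe transfers wherever you have (a) a shared backbone, (b) a regression branch with a known cost asymmetry, and (c) a classification or auxiliary branch. In practice the changes are confined to the dataset and the max_rul argument of the sample-weight schedule. As the table shows, some domains also adjust the task weights or swap in a binary health loss.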
| Domain | DualTaskModel input | RUL loss | Health loss | Combiner |
|---|---|---|---|---|
| RUL prediction (this book) | (B, 30, 14) C-MAPSS | moderate_weighted_mse_loss | F.cross_entropy | FixedWeightLoss(0.5, 0.5) |
| Battery SoH + fault type | (B, 100, 5) cycling data | moderate_weighted_mse_loss(max_rul=1.0) | F.cross_entropy | FixedWeightLoss(0.5, 0.5) |
| Wind turbine RUL + fault tag | (B, 144, 12) SCADA | moderate_weighted_mse_loss(max_rul=720) | F.cross_entropy | FixedWeightLoss(0.7, 0.3) |
| MRI tumour growth + benign/malignant | (B, 6, vol) follow-ups | moderate_weighted_mse_loss(max_rul=20) | F.binary_cross_entropy_with_logits | FixedWeightLoss(0.5, 0.5) |
| Bridge crack growth + condition rating | (B, T, sensors) strain | moderate_weighted_mse_loss(max_rul=Lcr) | F.cross_entropy | FixedWeightLoss(0.6, 0.4) |
| Disk RUL + SMART anomaly type | (B, 30, 16) SMART | moderate_weighted_mse_loss(max_rul=180) | F.cross_entropy | FixedWeightLoss(0.5, 0.5) |
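The sample weight depends on the target only through y / max_rul. Retargeting the schedule to a new domain therefore amounts to choosing max_rul in that domain's target units. A NumPy sketch reusing the §14 weight formula with three of the max_rul values from the table (the units are whatever each domain's targets use):

```python
import numpy as np


def amnl_weights(y: np.ndarray, max_rul: float) -> np.ndarray:
    # Same schedule as the C-MAPSS case: 2.0 at y = 0, 1.0 at y >= max_rul.
    return 1.0 + np.clip(1.0 - y / max_rul, 0.0, 1.0)


frac = np.array([0.0, 0.25, 0.5, 1.0, 1.5])      # target position relative to the cap
caps = [("C-MAPSS", 125.0), ("battery", 1.0), ("wind turbine", 720.0)]
rows = {name: amnl_weights(frac * cap, cap) for name, cap in caps}
for name, w in rows.items():
    print(f"{name:>12}: {w}")                    # identical rows: 2.0, 1.75, 1.5, 1.0, 1.0
```

Targets beyond the cap (frac 1.5) fall back to the neutral weight 1.0 in every domain.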
Three Integration Pitfalls
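Pitfall 1: Forgetting .view(-1, 1) on rul_tgt. DualTaskModel returns rul_pred of shape (B, 1). moderate_weighted_mse_loss flattens both tensors internally, so it is safe on its own. Any criterion that does not flatten, such as a hand-written (pred - tgt) ** 2 or a loss swapped in for an ablation, is not safe: it broadcasts a (B, 1) prediction against a (B,) target into a (B, B) residual. This hidden bug still produces plausible-looking loss values. Always reshape upfront.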
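The failure mode is easy to reproduce in NumPy, which follows the same broadcasting rules as PyTorch. A perfect prediction still gets a large loss once the (B, 1) and (B,) arrays broadcast against each other. The values are made up for illustration.

```python
import numpy as np

pred = np.array([[10.0], [20.0], [30.0]])   # (B, 1), like DualTaskModel's rul_pred
tgt = np.array([10.0, 20.0, 30.0])          # (B,), target NOT reshaped

bad = (pred - tgt) ** 2                     # broadcasts to (B, B): every pred vs every target
good = (pred - tgt.reshape(-1, 1)) ** 2     # (B, 1): the shape .view(-1, 1) gives you

print(bad.shape, bad.mean())    # (3, 3) and a plausible-looking nonzero loss
print(good.shape, good.mean())  # (3, 1) and 0.0 for a perfect prediction
```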
Pitfall 2: Skipping clip_grad_norm_. AMNL's sample weights amplify outlier batches, because near-failure samples with large residuals receive up to 2× weight. Without gradient clipping at 1.0, an occasional spike batch can push the parameters into a region the optimiser cannot escape from. The paper standard is grad_clip=1.0; never disable it.
Pitfall 3: Calling model.eval() instead of model.train(). In eval mode dropout is OFF, and BatchNorm uses its running statistics, which are not updated in eval mode. Training in this mode silently removes dropout's regularisation and yields a suboptimal model that overfits. Always call model.train() at the start of the step.
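The effect of the mode flag on dropout can be seen in a NumPy sketch of inverted dropout, the scheme torch.nn.Dropout uses. During training it zeroes each unit with probability p and scales the survivors by 1/(1 - p); in eval mode it is the identity. The function below is illustrative, not the PyTorch implementation.

```python
import numpy as np


def dropout(x: np.ndarray, p: float, training: bool, rng: np.random.Generator) -> np.ndarray:
    if not training or p == 0.0:
        return x                              # eval mode: identity, no noise
    keep = rng.random(x.shape) >= p           # keep each unit with probability 1 - p
    return x * keep / (1.0 - p)               # rescale so the expected output equals x


rng = np.random.default_rng(0)
x = np.ones(100_000)
train_out = dropout(x, p=0.3, training=True, rng=rng)
eval_out = dropout(x, p=0.3, training=False, rng=rng)
print((train_out == 0).mean(), train_out.mean())   # about 0.3 zeroed, mean still about 1.0
print(np.array_equal(eval_out, x))                 # True: no regularising noise at all
```

Training with the eval-mode branch would skip the noise entirely. This is the missing regularisation the pitfall warns about.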
The point. AMNL is paper-canonical with five components: DualTaskModel, weighted MSE, cross-entropy, FixedWeightLoss(0.5/0.5), Adam. The training step factors cleanly so swapping the combiner (FixedWeightLoss → GABALoss) is the only edit needed for Part VI.
Takeaway: End of Chapter 14
Five files, one step. DualTaskModel, weighted_mse, cross_entropy, FixedWeightLoss, Adam.
Shape contract. seq (B, T, c_in) → (rul_pred (B, 1), hs_logits (B, K)). rul_tgt (B,) → reshape to (B, 1) before the loss.
Grad clip = 1.0. Paper default. Never disable.
FixedWeightLoss(0.5, 0.5) ⇒ AMNL fixes WITHIN-task weighting only. Part VI swaps this for GABA to fix BETWEEN-task imbalance.
End of Chapter 14. Chapter 15 is the AMNL training pipeline (data → optimiser → checkpointing → results).