Chapter 14

PyTorch Implementation

AMNL — Failure-Biased Weighted MSE

How the Pieces Wire Together

§14.1-§14.3 designed AMNL one piece at a time: the sample-weight schedule, the linear shape, the w_max ceiling. This section plugs them into the §11.4 DualTaskModel and runs one full training step end to end.

Five components, three from the paper's code and two from PyTorch:

| Component | Source | Purpose |
| --- | --- | --- |
| DualTaskModel | paper_ieee_tii/grace/models/dual_task_model.py | shared backbone + RUL & health heads |
| moderate_weighted_mse_loss | paper_ieee_tii/grace/core/weighted_mse.py | AMNL RUL loss with sample weights |
| F.cross_entropy | torch.nn.functional | health-branch loss |
| FixedWeightLoss(0.5, 0.5) | paper_ieee_tii/grace/core/baselines.py | 0.5/0.5 task combiner |
| torch.optim.Adam(lr=1e-3) | torch.optim | optimiser - paper default |

Plug-and-play. Swapping AMNL → GABA in Part VI changes ONE line: the mtl_loss assignment. Everything else - DualTaskModel, weighted_mse, Adam - stays unchanged. That is the point of the paper's factory pattern.
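
The one-line swap can be sketched without the paper's classes. In the sketch below, fixed_weight, inverse_magnitude and training_step are hypothetical stand-ins (they are not FixedWeightLoss, GABA, or the paper's trainer), and the loss values are illustrative. The only point is the contract: the step accepts any callable that maps two task losses to one scalar.

```python
from typing import Callable

# Hypothetical stand-ins: any callable (rul_loss, hs_loss) -> total can act as a combiner.
Combiner = Callable[[float, float], float]


def fixed_weight(w_rul: float = 0.5, w_hs: float = 0.5) -> Combiner:
    """Fixed task weights, in the spirit of FixedWeightLoss(0.5, 0.5)."""
    return lambda rul_loss, hs_loss: w_rul * rul_loss + w_hs * hs_loss


def inverse_magnitude() -> Combiner:
    """Illustrative adaptive combiner (NOT GABA): each task is weighted by the other's loss share."""
    def combine(rul_loss: float, hs_loss: float) -> float:
        total = rul_loss + hs_loss
        return (hs_loss / total) * rul_loss + (rul_loss / total) * hs_loss
    return combine


def training_step(mtl_loss: Combiner, rul_loss: float, hs_loss: float) -> float:
    # The step never inspects which combiner it was given.
    return mtl_loss(rul_loss, hs_loss)


mtl_loss = fixed_weight(0.5, 0.5)        # the ONE line that changes
fixed_total = training_step(mtl_loss, 4800.0, 1.1)

mtl_loss = inverse_magnitude()           # swapped combiner, same step
adaptive_total = training_step(mtl_loss, 4800.0, 1.1)
print(fixed_total, adaptive_total)
```

Because training_step depends only on the call signature, the rest of the loop is untouched by the swap.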

The Files in paper_ieee_tii

For reference, the paper code is organised as follows. Every section in this chapter corresponds to a specific file below.

| Path | Section | Lines (approx) |
| --- | --- | --- |
| grace/core/weighted_mse.py | §14.1, §14.2 | 39 |
| grace/core/baselines.py (FixedWeightLoss) | §14.4 | ~17 of 437 |
| grace/core/baselines.py (BaseMTLLoss) | §14.4 | ~12 of 437 |
| grace/models/dual_task_model.py | §11.4 | 54 |
| grace/models/task_heads.py | §11.2, §11.3 | 53 |
| grace/models/backbone.py | §8-§11.1 | 152 |
| grace/training/trainer.py | §14.4 | 290 |

Reproduce the paper. cd paper_ieee_tii && python experiments/train_amnl_v7.py --dataset FD002 --seed 0 runs the full AMNL pipeline. The training step inside that script is exactly the function shown in the PyTorch block below.

Interactive: One Step, Eight Stages

Click through the eight stages. Each stage corresponds to one or two lines of the PyTorch block - see the colour match.

Watch the gradient stage. At stage 6 (backward) the autograd engine flows the loss gradient back through the heads, into the FC funnel, then into the attention block, BiLSTM, and CNN. The §12 imbalance shows up here - on the shared backbone the RUL gradient is ~500× the HS gradient, even with FixedWeightLoss(0.5, 0.5).

Python: Manual Step from Scratch

Pure NumPy reference - same algorithm as the paper, but with all gradients computed analytically so the chain rule is visible. The toy backbone is a single Linear layer (real backbone is CNN+BiLSTM+Attention) but the loss machinery is identical.

amnl_training_step() — analytic forward + backward
🐍amnl_step_numpy.py
import numpy as np

NumPy is the only dependency of this from-scratch reference implementation. We use ndarray, broadcasting, np.maximum, np.clip, np.exp, np.log, np.arange, plus matmul (@) for the analytic gradients.

EXECUTION STATE
📚 numpy = Library: ndarray + broadcasting + linear algebra + math.
as np = Universal alias.
def amnl_training_step(seq, rul_tgt, health_tgt, params, lr=1e-3, max_rul=125.0) -> dict:

ONE gradient-descent step. Mirrors the paper's _train_epoch inner loop (paper_ieee_tii/grace/training/trainer.py:243-284) but with a single-Linear backbone instead of CNN+BiLSTM+Attention so the algebra is visible.

EXECUTION STATE
⬇ input: seq = (B, T, F_in) - batch of windows. Shape (4, 30, 14) in the smoke test.
⬇ input: rul_tgt = (B, 1) - capped RUL targets. The trainer reshapes via .view(-1, 1).
⬇ input: health_tgt = (B,) - integer class labels.
⬇ input: params = Dict of weight matrices: W_back, W_rul, W_hs.
⬇ input: lr = 1e-3 = Learning rate. Paper uses Adam at this lr; we use plain SGD here for clarity.
⬇ input: max_rul = 125.0 = Paper RUL cap.
⬆ returns = dict {rul_loss, health_loss, total_loss, weights_min, weights_max} for logging.
z = seq.mean(axis=1) @ params['W_back']

Toy backbone: average over time then project. Real backbone is CNN→BiLSTM→Attention; the result is the same shape (B, 256) and the same role (shared 256-D feature vector).

EXECUTION STATE
📚 .mean(axis) = Reduce-mean. axis=1 averages over time, collapsing (B, T, F_in) → (B, F_in).
⬇ arg: axis=1 = Time axis. Real model uses CNN+BiLSTM+Attention here; we use mean for clarity.
@ = NumPy matmul. (B, F_in) @ (F_in, 256) = (B, 256).
⬆ result: z = (4, 256) shared feature matrix.
rul_pred = np.maximum(0.0, z @ params['W_rul'])

Linear projection to a scalar, then ReLU/clamp ≥ 0 - matches the paper's torch.clamp(rul_head(features), min=0.0) in DualTaskModel.forward.

EXECUTION STATE
📚 np.maximum(a, b) = Element-wise max of two arrays/scalars.
⬇ arg 1: a = 0.0 = Lower bound. RUL must be non-negative.
⬇ arg 2: b = z @ W_rul = (B, 256) @ (256, 1) = (B, 1) raw RUL prediction.
@ = NumPy matmul.
⬆ result: rul_pred = (4, 1) - non-negative RUL predictions.
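health_logits = z @ params['W_hs']

Raw class logits. NO softmax here - the stable log-softmax is applied by hand in the cross-entropy step further down (in the PyTorch version, F.cross_entropy applies log_softmax internally).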

EXECUTION STATE
@ = NumPy matmul. (B, 256) @ (256, 3) = (B, 3).
⬆ result: health_logits = (4, 3) raw logits.
weights = 1.0 + np.clip(1.0 - rul_tgt[:, 0] / max_rul, 0, 1)

AMNL paper formula. rul_tgt[:, 0] selects the first column of the (B, 1) target into a (B,) vector. The clip floors above-cap engines at w=1.

EXECUTION STATE
📚 np.clip(arr, a_min, a_max) = Element-wise clip.
→ indexing [:, 0] = Drop the trailing dim. (B, 1) → (B,) for cleaner broadcasting.
→ at rul_tgt = [5, 40, 90, 125] = weights = [1 + clip(0.96, 0, 1), 1 + clip(0.68, 0, 1), 1 + clip(0.28, 0, 1), 1 + clip(0, 0, 1)] = [1.96, 1.68, 1.28, 1.00]
⬆ result: weights = [1.96, 1.68, 1.28, 1.00] ← linear in target
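
As a quick check of the schedule, the formula can be wrapped in a small function and evaluated on a few targets. This is a minimal sketch: amnl_weights is a name chosen here for illustration, not a function from the paper's code, and the targets 0 and 150 are added to show both ends of the clip.

```python
import numpy as np


def amnl_weights(rul_tgt: np.ndarray, max_rul: float = 125.0) -> np.ndarray:
    """Linear sample weights: 2.0 at RUL = 0, falling to 1.0 at RUL >= max_rul."""
    return 1.0 + np.clip(1.0 - rul_tgt / max_rul, 0.0, 1.0)


# 0 and 150 are extra illustrative targets; the other four match the smoke test.
targets = np.array([0.0, 5.0, 40.0, 90.0, 125.0, 150.0])
weights = amnl_weights(targets)
print(np.round(weights, 2))
```

The target of 150 lies above the cap, so the clip holds its weight at 1.0 rather than letting it drop below 1.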
residual = (rul_pred - rul_tgt)[:, 0]

Element-wise residual, then drop the trailing dim.

EXECUTION STATE
operator: - = Element-wise subtract. (B, 1) - (B, 1) = (B, 1).
→ indexing [:, 0] = (B, 1) → (B,).
⬆ result: residual = (4,) signed errors.
rul_loss = float((weights * residual ** 2).mean())

Paper's moderate_weighted_mse_loss. Element-wise weight × squared residual, plain .mean() reduction (NOT normalised by weight sum), cast to Python float for logging.

EXECUTION STATE
operator: * = Element-wise.
operator: ** 2 = Element-wise square.
📚 .mean() = Reduce-mean to a 0-D scalar.
📚 float(x) = 0-D ndarray → Python float.
⬆ result: rul_loss = Scalar - the weighted MSE.
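
The choice of reduction changes the loss scale. The sketch below contrasts the plain mean (divide by B) with a weight-normalised mean (divide by the sum of weights) on the four smoke-test targets. The all-zero predictions are an assumption standing in for a freshly initialised model.

```python
import numpy as np

rul_tgt = np.array([5.0, 40.0, 90.0, 125.0])
rul_pred = np.zeros_like(rul_tgt)          # assumption: predictions are ~0 at init
weights = 1.0 + np.clip(1.0 - rul_tgt / 125.0, 0.0, 1.0)
sq_err = (rul_pred - rul_tgt) ** 2

plain_mean = (weights * sq_err).mean()                        # divide by B
weight_normalised = (weights * sq_err).sum() / weights.sum()  # divide by sum of weights

print(f"plain mean        : {plain_mean:.1f}")
print(f"weight-normalised : {weight_normalised:.1f}")
```

Every weight is at least 1, so the sum of weights is at least B and the plain mean is always the larger of the two. With the plain mean, up-weighting near-failure samples also raises the overall loss scale.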
z_max = health_logits.max(-1, keepdims=True)

Log-sum-exp shift for numerical stability.

EXECUTION STATE
📚 .max(axis, keepdims) = Reduce-max along an axis.
⬇ arg: axis = -1 = Last axis = 3 classes.
⬇ arg: keepdims = True = Output shape (B, 1) for broadcasting.
log_p = (health_logits - z_max) - np.log(np.exp(health_logits - z_max).sum(-1, keepdims=True))

Stable log-softmax: log p_k = z_k - log Σ exp(z_j) with the max-subtraction trick.

EXECUTION STATE
📚 np.exp(arr) = Element-wise e^x.
📚 np.log(arr) = Element-wise log.
📚 .sum(axis, keepdims) = Reduce-sum.
⬆ result: log_p = (B, 3) log-probabilities.
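
To see why the max-subtraction matters, the sketch below compares a naive log-softmax with the shifted version on deliberately large logits. Both helper names are illustrative and the logits are made up for the demonstration.

```python
import numpy as np


def log_softmax_naive(logits: np.ndarray) -> np.ndarray:
    return logits - np.log(np.exp(logits).sum(-1, keepdims=True))


def log_softmax_stable(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(-1, keepdims=True)   # largest entry becomes 0
    return shifted - np.log(np.exp(shifted).sum(-1, keepdims=True))


big = np.array([[1000.0, 1001.0, 1002.0]])             # illustrative extreme logits
with np.errstate(over="ignore", invalid="ignore"):
    naive = log_softmax_naive(big)                     # exp(1000) overflows to inf
stable = log_softmax_stable(big)

small = np.array([[0.1, -0.3, 0.2]])
print("naive :", naive)
print("stable:", stable)
print("agree on small logits:", np.allclose(log_softmax_naive(small), log_softmax_stable(small)))
```

Both forms are mathematically identical; the shift only keeps exp() inside floating-point range.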
health_loss = float(-log_p[np.arange(len(health_tgt)), health_tgt].mean())

Negative log-likelihood of the true class - the mean cross-entropy.

EXECUTION STATE
📚 np.arange(stop) = Integer range [0, stop).
📚 len(seq) = Python built-in. For (B,) returns B.
→ fancy indexing = log_p[np.arange(B), health_tgt] picks log_p[i, health_tgt[i]] per i. Shape (B,).
⬆ result: health_loss = Scalar - the mean CE.
total_loss = 0.5 * rul_loss + 0.5 * health_loss

Equal-weighted combine - paper's FixedWeightLoss(0.5, 0.5). Part VI replaces this with GABA's adaptive weights.

EXECUTION STATE
→ why 0.5/0.5? = Paper's baseline. AMNL handles WITHIN-task weighting (samples); the task weights stay equal.
⬆ result: total_loss = Scalar - the AMNL training objective.
d_rul = (2.0 / len(rul_tgt)) * weights * residual

Analytic ∂L_rul/∂rul_pred = (2/B) · w · residual (per sample). Same formula as §14.1 line 11.

EXECUTION STATE
→ derivation = L_rul = (1/B) Σ w_i · (pred - tgt)². Differentiate wrt pred_i: (2/B) · w_i · (pred - tgt).
⬆ result: d_rul = (B,) per-sample gradient on the prediction.
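
The analytic formula can be checked against central finite differences. This is a minimal sketch: weighted_mse is a local helper written for the check (not the paper's moderate_weighted_mse_loss), and the predictions are arbitrary illustrative values.

```python
import numpy as np


def weighted_mse(pred: np.ndarray, tgt: np.ndarray, w: np.ndarray) -> float:
    return float((w * (pred - tgt) ** 2).mean())


tgt = np.array([5.0, 40.0, 90.0, 125.0])
pred = np.array([12.0, 33.0, 70.0, 110.0])    # arbitrary illustrative predictions
w = 1.0 + np.clip(1.0 - tgt / 125.0, 0.0, 1.0)

# Analytic gradient: (2/B) * w_i * (pred_i - tgt_i)
analytic = (2.0 / len(tgt)) * w * (pred - tgt)

# Central finite differences, one prediction at a time
eps = 1e-4
numeric = np.zeros_like(pred)
for i in range(len(pred)):
    step = np.zeros_like(pred)
    step[i] = eps
    numeric[i] = (weighted_mse(pred + step, tgt, w)
                  - weighted_mse(pred - step, tgt, w)) / (2 * eps)

print("analytic:", analytic)
print("numeric :", numeric)
```

The loss is quadratic in each prediction, so the central difference matches the analytic gradient up to rounding.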
grad_rul = z.T @ (d_rul * (rul_pred[:, 0] > 0))[:, None]

Backprop through W_rul. The (rul_pred > 0) mask handles the ReLU/clamp on the output - if the pre-clamp output was negative, the gradient through that sample is 0.

EXECUTION STATE
operator: > 0 = Boolean mask. True where rul_pred is strictly positive (non-clipped).
→ [:, None] = Add a trailing axis: (B,) → (B, 1) so the matmul shapes match.
→ .T = NumPy transpose. z.T has shape (256, B).
@ = Matmul: (256, B) @ (B, 1) = (256, 1).
⬆ result: grad_rul = (256, 1) - matches W_rul shape.
p = np.exp(log_p)

Softmax probabilities, recovered for the gradient computation.

EXECUTION STATE
⬆ result: p = (B, 3) probabilities.
onehot = np.zeros_like(p); onehot[np.arange(len(health_tgt)), health_tgt] = 1.0

One-hot encoding via fancy indexing.

EXECUTION STATE
📚 np.zeros_like(arr) = Allocate zeros with same shape and dtype.
→ fancy assign = Sets onehot[i, health_tgt[i]] = 1 for each i.
d_logits = (p - onehot) / len(health_tgt)

Analytic ∂L_hs/∂logits = (p - onehot) / B.

EXECUTION STATE
→ bound = Each element of (p - onehot) is in [-1, 1]. Per-element gradient bounded by 1/B.
⬆ result: d_logits = (B, 3) per-sample logit gradient.
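
The same kind of check works for the cross-entropy gradient, and it also confirms the 1/B bound. In this sketch, cross_entropy is a local NumPy helper and the logits are random illustrative values.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    shifted = logits - logits.max(-1, keepdims=True)
    log_p = shifted - np.log(np.exp(shifted).sum(-1, keepdims=True))
    return float(-log_p[np.arange(len(labels)), labels].mean())


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3))              # illustrative logits
labels = np.array([2, 1, 0, 0])

# Analytic gradient: (softmax - onehot) / B
shifted = logits - logits.max(-1, keepdims=True)
p = np.exp(shifted) / np.exp(shifted).sum(-1, keepdims=True)
onehot = np.eye(3)[labels]
analytic = (p - onehot) / len(labels)

# Central finite differences, one logit at a time
eps = 1e-5
numeric = np.zeros_like(logits)
for idx in np.ndindex(*logits.shape):
    bump = np.zeros_like(logits)
    bump[idx] = eps
    numeric[idx] = (cross_entropy(logits + bump, labels)
                    - cross_entropy(logits - bump, labels)) / (2 * eps)

max_abs_diff = np.abs(analytic - numeric).max()
print("max |analytic - numeric| =", max_abs_diff)
print("max |gradient element|   =", np.abs(analytic).max(), "(bound 1/B =", 1 / len(labels), ")")
```

Each row of the gradient sums to zero because the softmax probabilities and the one-hot row both sum to one.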
grad_hs = z.T @ d_logits

Backprop through W_hs.

EXECUTION STATE
@ = Matmul: (256, B) @ (B, 3) = (256, 3).
⬆ result: grad_hs = (256, 3) - matches W_hs shape.
grad_back = (seq.mean(axis=1)).T @ ( 0.5 * (d_rul * (rul_pred[:, 0] > 0))[:, None] * params['W_rul'].T + 0.5 * d_logits @ params['W_hs'].T )

Backprop into the SHARED W_back. Both heads contribute via the chain rule, weighted by the 0.5 task weights from FixedWeightLoss. The RUL term carries the same (rul_pred > 0) clamp mask as grad_rul, because the backbone gradient also flows through the clamp. THIS is the line where the §12 gradient imbalance shows up - the RUL term dominates the sum even with the 0.5 weight.

EXECUTION STATE
→ 0.5 * (d_rul * mask)[:, None] * W_rul.T = RUL contribution to ∂L/∂z, shape (B, 256); mask = (rul_pred[:, 0] > 0).
→ 0.5 * d_logits @ W_hs.T = Health contribution to ∂L/∂z, shape (B, 256).
→ @ = (F_in, B) @ (B, 256) = (F_in, 256).
⬆ result: grad_back = (F_in, 256) - matches W_back shape.
→ THIS is the imbalance = RUL contribution scale ≈ residual size (~10-100); HS contribution scale ≈ (p - onehot) / B (~0.01-0.05). Even with 0.5/0.5 task weights, RUL dominates the shared-backbone gradient by ~500×.
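
The imbalance can be made concrete by computing the two contributions to ∂L/∂z separately and comparing their norms. This is a toy sketch under stated assumptions: random head weights at the 0.05 init scale, hand-picked positive predictions (so the clamp is inactive), and a uniform softmax. The ratio it prints depends on those choices and is not the ~500× figure measured on the real backbone.

```python
import numpy as np

rng = np.random.default_rng(0)
B, D, K = 4, 256, 3
W_rul = rng.normal(size=(D, 1)) * 0.05
W_hs = rng.normal(size=(D, K)) * 0.05

rul_tgt = np.array([5.0, 40.0, 90.0, 125.0])
rul_pred = np.array([10.0, 30.0, 60.0, 100.0])    # assumed, all > 0 so the clamp is inactive
weights = 1.0 + np.clip(1.0 - rul_tgt / 125.0, 0.0, 1.0)
d_rul = (2.0 / B) * weights * (rul_pred - rul_tgt)        # (B,)

p = np.full((B, K), 1.0 / K)                      # assumed uniform softmax, typical near init
onehot = np.eye(K)[np.array([2, 1, 0, 0])]
d_logits = (p - onehot) / B                               # (B, K)

g_rul = 0.5 * d_rul[:, None] * W_rul.T                    # (B, D) RUL share of dL/dz
g_hs = 0.5 * d_logits @ W_hs.T                            # (B, D) health share of dL/dz

ratio = np.linalg.norm(g_rul) / np.linalg.norm(g_hs)
print(f"|g_rul| / |g_hs| = {ratio:.1f}")
```

The RUL share scales with the residual size while the health share is capped by the 1/B bound on d_logits, so the gap widens as residuals grow.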
params['W_back'] -= lr * grad_back

Plain SGD update on the shared backbone. Real paper uses Adam.

params['W_rul'] -= lr * 0.5 * grad_rul

Update the RUL head. grad_rul is the gradient of rul_loss alone, so the 0.5 task weight from FixedWeightLoss is applied here, in the update.

params['W_hs'] -= lr * 0.5 * grad_hs

Update health head.

return { ... }

Logging dict for the trainer to consume.

EXECUTION STATE
⬆ keys = rul_loss, health_loss, total_loss, weights_min, weights_max - matches what paper trainer logs.
np.random.seed(0)

Repro.

EXECUTION STATE
📚 np.random.seed(s) = Sets NumPy's legacy global PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
B, T, F_in = 4, 30, 14

Smoke-test sizes - 4 engines, 30-cycle window, 14 sensors.

params = { ... }

Three weight matrices initialised at small Gaussian scale.

EXECUTION STATE
📚 np.random.randn(*size) = Sample i.i.d. N(0, 1).
📚 .astype(np.float32) = Cast to float32.
→ * 0.05 = Small init scale to keep activations near zero.
seq = np.random.randn(B, T, F_in).astype(np.float32)

Random window data.

rul_tgt = np.array([[5.0], [40.0], [90.0], [125.0]], dtype=np.float32)

Hand-picked targets spanning the RUL range - 5 (critical), 40 (degrading), 90 (healthy), 125 (cap). Lets us see weights[i] differ by sample.

health_tgt = np.array([2, 1, 0, 0])

Class indices: critical, degrading, normal, normal. Aligned with the RUL ranges.

stats = amnl_training_step(seq, rul_tgt, health_tgt, params)

Run the step.

print(f"rul_loss : {stats['rul_loss']:.4f}")

Show the RUL loss.

EXECUTION STATE
→ :.4f = Float, 4 decimals.
Output = rul_loss : ~7180 (huge - with near-zero predictions at init the residuals are essentially the target magnitudes, so the loss approaches (1.96·5² + 1.68·40² + 1.28·90² + 1.00·125²) / 4 = 7182.5)
print(f"health_loss : {stats['health_loss']:.4f}")

Cross-entropy sits near log K = log 3 ≈ 1.099 at init, because near-zero logits give a near-uniform softmax.

EXECUTION STATE
Output = health_loss : ~1.10 (uniform softmax at init)
print(f"total_loss : {stats['total_loss']:.4f}")

0.5 · rul_loss + 0.5 · health_loss.

EXECUTION STATE
Output = total_loss : ~3590 (dominated by rul_loss)
→ reading = Total is essentially equal to 0.5 · rul_loss because health_loss starts near log 3 ≈ 1.1. AMNL's sample weighting helps WITHIN the rul_loss; it does NOT fix the BETWEEN-task imbalance.
print(f"weights : [{stats['weights_min']:.2f}, {stats['weights_max']:.2f}]")

Schedule range. With targets [5, 40, 90, 125] we get [1.96, 1.68, 1.28, 1.00].

EXECUTION STATE
Output = weights : [1.00, 1.96]
import numpy as np


def amnl_training_step(seq:        np.ndarray,
                       rul_tgt:    np.ndarray,
                       health_tgt: np.ndarray,
                       params:     dict,
                       lr:         float = 1e-3,
                       max_rul:    float = 125.0) -> dict:
    """One AMNL gradient-descent step in pure NumPy.

    Mirrors paper_ieee_tii/grace/training/trainer.py::_train_epoch but with
    a 1-Linear backbone instead of CNN+BiLSTM+Attention, so the algebra is
    visible. Updates params in place and returns a logging dict.
    """
    # 1. Forward
    z             = seq.mean(axis=1) @ params["W_back"]            # (B, 256) shared features
    rul_pred      = np.maximum(0.0, z @ params["W_rul"])           # (B, 1) clamped ≥ 0
    health_logits = z @ params["W_hs"]                             # (B, 3)

    # 2. AMNL sample weights
    weights = 1.0 + np.clip(1.0 - rul_tgt[:, 0] / max_rul, 0, 1)   # (B,)

    # 3. Weighted MSE on RUL branch
    residual = (rul_pred - rul_tgt)[:, 0]                          # (B,)
    rul_loss = float((weights * residual ** 2).mean())

    # 4. Cross-entropy on health branch (stable log-softmax)
    z_max = health_logits.max(-1, keepdims=True)
    log_p = (health_logits - z_max) - np.log(
        np.exp(health_logits - z_max).sum(-1, keepdims=True)
    )
    health_loss = float(-log_p[np.arange(len(health_tgt)), health_tgt].mean())

    # 5. Combine: FixedWeightLoss(0.5, 0.5)
    total_loss = 0.5 * rul_loss + 0.5 * health_loss

    # 6. Backward (analytic - the tedious part)
    d_rul    = (2.0 / len(rul_tgt)) * weights * residual           # ∂L_rul/∂rul_pred
    grad_rul = z.T @ (d_rul * (rul_pred[:, 0] > 0))[:, None]       # (256, 1)

    p        = np.exp(log_p)
    onehot   = np.zeros_like(p)
    onehot[np.arange(len(health_tgt)), health_tgt] = 1.0
    d_logits = (p - onehot) / len(health_tgt)                      # (B, 3)
    grad_hs  = z.T @ d_logits                                      # (256, 3)
    # The RUL term needs the same clamp mask as grad_rul: no gradient flows
    # to the backbone through a prediction that was clamped to 0.
    grad_back = (seq.mean(axis=1)).T @ (
        0.5 * (d_rul * (rul_pred[:, 0] > 0))[:, None] * params["W_rul"].T +
        0.5 * d_logits @ params["W_hs"].T
    )                                                              # (F_in, 256)

    # 7. Parameter update (plain SGD here for clarity; the paper uses Adam)
    params["W_back"] -= lr * grad_back
    params["W_rul"]  -= lr * 0.5 * grad_rul                        # 0.5 = RUL task weight
    params["W_hs"]   -= lr * 0.5 * grad_hs                         # 0.5 = health task weight

    return {
        "rul_loss":    rul_loss,
        "health_loss": health_loss,
        "total_loss":  total_loss,
        "weights_min": float(weights.min()),
        "weights_max": float(weights.max()),
    }


# ---------- Smoke test ----------
np.random.seed(0)
B, T, F_in = 4, 30, 14
params = {
    "W_back": np.random.randn(F_in, 256).astype(np.float32) * 0.05,
    "W_rul":  np.random.randn(256, 1).astype(np.float32)    * 0.05,
    "W_hs":   np.random.randn(256, 3).astype(np.float32)    * 0.05,
}
seq        = np.random.randn(B, T, F_in).astype(np.float32)
rul_tgt    = np.array([[5.0], [40.0], [90.0], [125.0]], dtype=np.float32)
health_tgt = np.array([2, 1, 0, 0])

stats = amnl_training_step(seq, rul_tgt, health_tgt, params)
print(f"rul_loss    : {stats['rul_loss']:.4f}")
print(f"health_loss : {stats['health_loss']:.4f}")
print(f"total_loss  : {stats['total_loss']:.4f}")
print(f"weights     : [{stats['weights_min']:.2f}, {stats['weights_max']:.2f}]")

PyTorch: The Paper's Step

The exact training step from paper_ieee_tii/grace/training/trainer.py (lines 249-283), factored into a function. Reproduces the AMNL paper if you wire it into a DataLoader and run for the paper's 40 epochs.

amnl_training_step() — paper-canonical PyTorch
🐍amnl_step_torch.py
import torch

Top-level PyTorch.

EXECUTION STATE
📚 torch = Tensor library + autograd + nn + optim.
import torch.nn as nn

Module containers.

import torch.nn.functional as F

Stateless ops - F.cross_entropy is the health-branch criterion.

EXECUTION STATE
📚 F.cross_entropy = Stable log_softmax + nll_loss in one numerically-safe call.
from torch.utils.data import DataLoader

Standard PyTorch DataLoader. Paper trainer takes one of these per split.

from grace.core.weighted_mse import moderate_weighted_mse_loss

AMNL's RUL loss - paper-canonical from paper_ieee_tii/grace/core/weighted_mse.py. Exactly the function from §14.1 / §14.2.

from grace.core.baselines import FixedWeightLoss

The 0.5/0.5 task combiner - paper-canonical from paper_ieee_tii/grace/core/baselines.py. AMNL ships with FixedWeightLoss(0.5, 0.5); GABA replaces this in Part VI.

from grace.models.dual_task_model import DualTaskModel

The §11.4 architecture - paper-canonical from paper_ieee_tii/grace/models/dual_task_model.py.

def amnl_training_step(model, optimizer, mtl_loss, rul_criterion, hs_criterion, seq, rul_tgt, hs_tgt, grad_clip=1.0) -> dict:

ONE training step - exactly the body of the paper's _train_epoch inner loop, factored out into a function for clarity. All five components are passed in - swap any of them for ablations.

EXECUTION STATE
⬇ input: model = DualTaskModel - the §11.4 architecture.
⬇ input: optimizer = torch.optim.Optimizer (Adam in the paper).
⬇ input: mtl_loss = FixedWeightLoss(0.5, 0.5) for AMNL. GABALoss for Part VI.
⬇ input: rul_criterion = moderate_weighted_mse_loss - the AMNL RUL loss.
⬇ input: hs_criterion = F.cross_entropy.
⬇ input: seq = (B, T, c_in) - batch of windows.
⬇ input: rul_tgt = (B,) or (B, 1) - capped RUL targets.
⬇ input: hs_tgt = (B,) - class indices.
⬇ input: grad_clip = 1.0 = Clip gradient norm to this value. Paper standard.
⬆ returns = Dict {loss, rul_loss, hs_loss, grad_norm} for logging.
model.train()

Switch to training mode. Activates dropout and uses batch stats for BatchNorm. The paper trainer always calls this before the loop.

EXECUTION STATE
📚 .train(mode=True) = Sets self.training = True on the module and all sub-modules.
rul_tgt = rul_tgt.view(-1, 1)

Reshape to (B, 1) so it matches the model's rul_pred shape. moderate_weighted_mse_loss flattens both internally, but matching shapes upfront keeps everything tidy and protects any loss that does not flatten from (B, 1) vs (B,) broadcasting.

EXECUTION STATE
📚 .view(*shape) = Returns a view with the requested shape, sharing storage.
⬇ arg: shape = (-1, 1) = -1 means infer; 1 fixes the second dim. (B,) → (B, 1).
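optimizer.zero_grad()

Reset .grad before the new backward. With set_to_none=True the gradients are set to None instead of being zero-filled, which is faster and uses less memory.

EXECUTION STATE
📚 optimizer.zero_grad(set_to_none=True) = The set_to_none argument was added in PyTorch 1.7; True has been the default since PyTorch 2.0. On older versions pass it explicitly.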
rul_pred, hs_logits = model(seq)

DualTaskModel.forward returns the (rul_pred, health_logits) tuple. rul_pred is already clamped ≥ 0 by the model.

EXECUTION STATE
⬆ result: rul_pred = (B, 1) non-negative scalar per engine.
⬆ result: hs_logits = (B, 3) raw logits.
rul_loss = rul_criterion(rul_pred, rul_tgt)

Calls moderate_weighted_mse_loss(rul_pred, rul_tgt, max_rul=125.0). Returns a 0-D scalar tensor with autograd graph.

EXECUTION STATE
→ reduction = Plain .mean() (paper formula). NOT normalised by Σw.
⬆ result: rul_loss = 0-D scalar tensor.
hs_loss = hs_criterion(hs_logits, hs_tgt)

Calls F.cross_entropy(hs_logits, hs_tgt). Standard mean cross-entropy.

EXECUTION STATE
📚 F.cross_entropy(input, target, reduction='mean') = Stable log_softmax + nll_loss.
⬇ arg 1: input = hs_logits = (B, 3) raw logits, NOT probabilities.
⬇ arg 2: target = hs_tgt = (B,) int64 class indices, NOT one-hot.
⬆ result: hs_loss = 0-D scalar tensor.
loss = mtl_loss(rul_loss, hs_loss)

Calls FixedWeightLoss.forward(rul_loss, hs_loss) ⇒ 0.5 · rul_loss + 0.5 · hs_loss. The Module wrapper means we can swap to GABALoss in Part VI without changing the rest of the loop.

EXECUTION STATE
→ why a Module? = FixedWeightLoss is an nn.Module so swap-in/swap-out works in trainer code without conditionals.
⬆ result: loss = 0-D scalar tensor with autograd graph stretching back to model parameters.
loss.backward()

Reverse-mode autograd through the whole graph: heads → backbone → input.

EXECUTION STATE
📚 .backward(retain_graph=False) = Backprops through the graph, accumulating into .grad on every leaf with requires_grad=True. Frees the graph.
→ effect = Every parameter in DualTaskModel now has .grad populated.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

Compute the L2 norm of all gradients (concatenated) and rescale them so the norm does not exceed grad_clip. The trailing underscore marks it as in-place.

EXECUTION STATE
📚 torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0) = Computes the total norm of `parameters[*].grad`; if it exceeds `max_norm`, scales every grad in place by max_norm / total_norm.
⬇ arg 1: parameters = model.parameters() = Iterator over all learnable params.
⬇ arg 2: max_norm = grad_clip = 1.0 = Paper default. Catches the rare exploding-grad batch on AMNL.
⬆ result: grad_norm = 0-D scalar tensor - the PRE-clip total grad norm.
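
What global-norm clipping does can be sketched in NumPy. This is an illustrative re-implementation of the idea, not PyTorch's source: clip_global_norm is a hypothetical helper, and the small eps in the denominator is an assumption added to avoid division by zero.

```python
import numpy as np


def clip_global_norm(grads: list, max_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradient arrays in place so their joint L2 norm is at most max_norm.

    Returns the norm measured BEFORE clipping.
    """
    total_norm = float(np.sqrt(sum(float((g ** 2).sum()) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + eps))   # never scale gradients up
    for g in grads:
        g *= scale                                    # in-place, like the trailing underscore
    return total_norm


def joint_norm(grads: list) -> float:
    return float(np.sqrt(sum(float((g ** 2).sum()) for g in grads)))


big = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]    # joint norm = 5
pre_clip = clip_global_norm(big, max_norm=1.0)
post_clip = joint_norm(big)

small = [np.array([0.3, 0.4])]                            # joint norm = 0.5, below the cap
small_pre = clip_global_norm(small, max_norm=1.0)
print(pre_clip, post_clip, small_pre, joint_norm(small))
```

All arrays share one scale factor, so the direction of the overall gradient is preserved; only its length changes. Gradients already under the cap are left untouched.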
optimizer.step()

Apply the Adam update: θ ← θ - lr · m̂ / (√v̂ + ε).

EXECUTION STATE
📚 optimizer.step() = Reads .grad on every parameter and applies the optimiser update rule.
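
The update rule can be written out for a single step. The sketch below is a plain NumPy version of the standard Adam formulas with the default betas and eps quoted in this section; adam_step is a hypothetical helper, not the torch.optim implementation, and the gradient values are made up to span several orders of magnitude.

```python
import numpy as np


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; returns the new (theta, m, v). t is the 1-based step count."""
    m = beta1 * m + (1.0 - beta1) * grad            # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * grad ** 2       # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)                  # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v


theta = np.zeros(3)
grad = np.array([500.0, -0.01, 2.0])                # deliberately different scales
theta, m, v = adam_step(theta, grad, m=np.zeros(3), v=np.zeros(3), t=1)
print(theta)
```

On the first step the bias-corrected ratio m̂ / √v̂ equals the sign of the gradient, so every coordinate moves by about lr regardless of gradient magnitude. That per-coordinate normalisation is one reason the plain-SGD NumPy reference and the Adam-based paper step behave differently in practice.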
return { ... }

Logging dict. .item() extracts each scalar tensor as a Python float so we can JSON-serialise / write to TensorBoard.

EXECUTION STATE
📚 .item() = 0-D tensor → Python float.
⬆ keys = loss, rul_loss, hs_loss, grad_norm.
torch.manual_seed(0)

Repro.

EXECUTION STATE
📚 torch.manual_seed(s) = Set the global PyTorch PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
model = DualTaskModel(c_in=14, lstm_hidden=256, num_heads=8, shared_dim=32, num_classes=3)

Instantiate the §11.4 architecture with paper defaults.

EXECUTION STATE
⬇ arg: c_in = 14 = 14 informative C-MAPSS sensors.
⬇ arg: lstm_hidden = 256 = Per-direction BiLSTM hidden size.
⬇ arg: num_heads = 8 = Multi-head attention heads.
⬇ arg: shared_dim = 32 = Output of the FC funnel.
⬇ arg: num_classes = 3 = Health classes.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Paper-default Adam at lr=1e-3.

EXECUTION STATE
📚 torch.optim.Adam(params, lr, betas, eps, weight_decay) = Adam optimiser. Default betas=(0.9, 0.999), eps=1e-8.
⬇ arg: params = model.parameters() = All ~3.4M parameters.
⬇ arg: lr = 1e-3 = Paper default.
mtl_loss = FixedWeightLoss(rul_weight=0.5, health_weight=0.5)

Equal-weighted task combiner - the AMNL baseline.

EXECUTION STATE
⬇ arg: rul_weight = 0.5 = Half the budget on RUL.
⬇ arg: health_weight = 0.5 = Half the budget on health.
rul_criterion = moderate_weighted_mse_loss

First-class function reference - no call. Will be called inside the step.

hs_criterion = F.cross_entropy

First-class function reference.

seq = torch.randn(4, 30, 14, requires_grad=False)

Synthetic batch of 4 engines, 30-cycle window, 14 sensors.

EXECUTION STATE
📚 torch.randn(*size, requires_grad) = Sample i.i.d. N(0, 1).
⬇ arg: size = (4, 30, 14) = Smoke-test batch size.
⬇ arg: requires_grad = False = Inputs are constant - autograd does not need to track them.
rul_tgt = torch.tensor([5.0, 40.0, 90.0, 125.0])

Hand-picked targets spanning the RUL range.

EXECUTION STATE
📚 torch.tensor(seq) = Allocate from a Python sequence; the dtype is inferred (float32 for Python floats, int64 for Python ints).
hs_tgt = torch.tensor([2, 1, 0, 0])

Class indices: critical, degrading, normal, normal.

stats = amnl_training_step(model, optimizer, mtl_loss, rul_criterion, hs_criterion, seq, rul_tgt, hs_tgt)

Run one step.

for k, v in stats.items():

Iterate the logging dict.

EXECUTION STATE
📚 dict.items() = View of (key, value) pairs.
iter vars = k (str), v (float).
LOOP TRACE · 4 iterations
k = 'loss'
v = ≈ 2400 (dominated by rul_loss; AMNL does not yet rebalance tasks)
k = 'rul_loss'
v = ≈ 4800 (huge - residuals near init are large)
k = 'hs_loss'
v = ≈ 1.10 (close to log K = log 3 ≈ 1.099, the cross-entropy of a uniform softmax)
k = 'grad_norm'
v = ≈ 0.95 (just under the clip threshold)
print(f"{k:>12s} : {v:>10.4f}")

Format-string row. :>12s right-aligns the key to width 12; :>10.4f formats the value with 4 decimals at width 10.

EXECUTION STATE
→ :>12s = String, right-aligned, min width 12.
→ :>10.4f = Float, right-aligned, width 10, 4 decimals.
Output (one realisation) = loss : 2407.4231 rul_loss : 4813.7461 hs_loss : 1.1001 grad_norm : 0.9521
→ reading = rul_loss is ~4400× hs_loss at init. Even with FixedWeightLoss(0.5, 0.5) the total is dominated by the RUL term. AMNL fixes the WITHIN-RUL sample weighting (so near-failure samples pull harder on the head); it does NOT fix the BETWEEN-task imbalance (Part VI does that).
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Paper-canonical pieces (paper_ieee_tii/grace/...)
from grace.core.weighted_mse import moderate_weighted_mse_loss   # AMNL RUL loss
from grace.core.baselines    import FixedWeightLoss              # 0.5/0.5 combiner
from grace.models.dual_task_model import DualTaskModel            # §11.4 model


def amnl_training_step(model:         DualTaskModel,
                       optimizer:     torch.optim.Optimizer,
                       mtl_loss:      nn.Module,
                       rul_criterion: Callable,
                       hs_criterion:  Callable,
                       seq:           torch.Tensor,
                       rul_tgt:       torch.Tensor,
                       hs_tgt:        torch.Tensor,
                       grad_clip:     float = 1.0) -> dict:
    """One AMNL training step - paper-canonical recipe.

    Mirrors paper_ieee_tii/grace/training/trainer.py::_train_epoch lines 249-283.
    """
    model.train()                                                  # dropout on, BatchNorm uses batch stats
    rul_tgt = rul_tgt.view(-1, 1)                                  # match head output shape

    optimizer.zero_grad()
    rul_pred, hs_logits = model(seq)

    rul_loss = rul_criterion(rul_pred, rul_tgt)
    hs_loss  = hs_criterion(hs_logits, hs_tgt)
    loss     = mtl_loss(rul_loss, hs_loss)                         # 0.5 · rul + 0.5 · hs

    loss.backward()
    # clip_grad_norm_ returns the total norm measured BEFORE clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

    return {
        "loss":      loss.item(),
        "rul_loss":  rul_loss.item(),
        "hs_loss":   hs_loss.item(),
        "grad_norm": grad_norm.item(),
    }


# ---------- Smoke test ----------
torch.manual_seed(0)
model         = DualTaskModel(c_in=14, lstm_hidden=256, num_heads=8,
                              shared_dim=32, num_classes=3)
optimizer     = torch.optim.Adam(model.parameters(), lr=1e-3)
mtl_loss      = FixedWeightLoss(rul_weight=0.5, health_weight=0.5)
rul_criterion = moderate_weighted_mse_loss                         # paper-canonical
hs_criterion  = F.cross_entropy

seq     = torch.randn(4, 30, 14, requires_grad=False)
rul_tgt = torch.tensor([5.0, 40.0, 90.0, 125.0])
hs_tgt  = torch.tensor([2, 1, 0, 0])

stats = amnl_training_step(model, optimizer, mtl_loss,
                           rul_criterion, hs_criterion,
                           seq, rul_tgt, hs_tgt)

for k, v in stats.items():
    print(f"{k:>12s} : {v:>10.4f}")

Drop-In for Other Domains

The five-component recipe transfers wherever you have (a) a shared backbone, (b) a regression branch with a known cost asymmetry, and (c) a classification or auxiliary branch. The only files that change are the dataset and the sample-weight schedule.

| Domain | DualTaskModel input | RUL loss | Health loss | Combiner |
| --- | --- | --- | --- | --- |
| RUL prediction (this book) | (B, 30, 14) C-MAPSS | moderate_weighted_mse_loss | F.cross_entropy | FixedWeightLoss(0.5, 0.5) |
| Battery SoH + fault type | (B, 100, 5) cycling data | moderate_weighted_mse_loss(max_rul=1.0) | F.cross_entropy | FixedWeightLoss(0.5, 0.5) |
| Wind turbine RUL + fault tag | (B, 144, 12) SCADA | moderate_weighted_mse_loss(max_rul=720) | F.cross_entropy | FixedWeightLoss(0.7, 0.3) |
| MRI tumour growth + benign/malign | (B, 6, vol) follow-ups | moderate_weighted_mse_loss(max_rul=20) | F.binary_cross_entropy_with_logits | FixedWeightLoss(0.5, 0.5) |
| Bridge crack growth + condition rating | (B, T, sensors) strain | moderate_weighted_mse_loss(max_rul=Lcr) | F.cross_entropy | FixedWeightLoss(0.6, 0.4) |
| Disk RUL + SMART anomaly type | (B, 30, 16) SMART | moderate_weighted_mse_loss(max_rul=180) | F.cross_entropy | FixedWeightLoss(0.5, 0.5) |

Three Integration Pitfalls

Pitfall 1: Forgetting .view(-1, 1) on rul_tgt. DualTaskModel returns rul_pred of shape (B, 1). moderate_weighted_mse_loss flattens both tensors internally, so it tolerates the mismatch, but any loss that subtracts directly (a hand-written (pred - tgt) ** 2, for example) broadcasts (B, 1) against (B,) into a (B, B) residual - a hidden bug that produces plausible-looking loss values. Always reshape upfront.
Pitfall 2: Skipping clip_grad_norm_. AMNL's sample weights up-weight near-failure samples by as much as 2×, so a batch with large residuals yields an even larger gradient. Without clipping at 1.0, an occasional spike batch can push the parameters far enough to destabilise training. The paper standard is grad_clip=1.0; keep it enabled.
Pitfall 3: Calling model.eval() instead of model.train(). In eval mode dropout is OFF and BatchNorm normalises with its running statistics instead of batch statistics. Training in this mode raises no error, but it removes the dropout regularisation and never updates the running statistics, so the model tends to overfit. Always call model.train() at the start of the step.
The point. AMNL is paper-canonical with five components: DualTaskModel, weighted MSE, cross-entropy, FixedWeightLoss(0.5/0.5), Adam. The training step factors cleanly so swapping the combiner (FixedWeightLoss → GABALoss) is the only edit needed for Part VI.
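
Pitfall 1 is easy to reproduce in NumPy, whose broadcasting rules match PyTorch's for this case. The sketch below uses illustrative predictions; a loss that subtracts without reshaping silently averages over a (B, B) matrix of mismatched pairs.

```python
import numpy as np

rul_pred = np.array([[10.0], [30.0], [60.0], [100.0]])    # (B, 1), like the model output
rul_tgt = np.array([5.0, 40.0, 90.0, 125.0])              # (B,), as a dataset might return it

wrong = rul_pred - rul_tgt                    # broadcasts to (B, B): every pred vs every target
right = rul_pred - rul_tgt.reshape(-1, 1)     # stays (B, 1): each pred vs its own target

wrong_mse = float((wrong ** 2).mean())
right_mse = float((right ** 2).mean())
print(wrong.shape, right.shape)
print(wrong_mse, right_mse)
```

Both results are finite, positive scalars, which is why the bug goes unnoticed without an explicit shape check.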

Takeaway — End of Chapter 14

  • Five components, one step. DualTaskModel, weighted_mse, cross_entropy, FixedWeightLoss, Adam.
  • Shape contract. seq (B, T, c_in) → (rul_pred (B, 1), hs_logits (B, K)). rul_tgt (B,) → reshape to (B, 1) before the loss.
  • Grad clip = 1.0. Paper default. Never disable.
  • FixedWeightLoss(0.5, 0.5) ⇒ AMNL fixes WITHIN-task weighting only. Part VI swaps this for GABA to fix BETWEEN-task imbalance.
  • End of Chapter 14. Chapter 15 is the AMNL training pipeline (data → optimiser → checkpointing → results).