Chapter 13

Mapping Regimes to AMNL, GABA, and GRACE

The Accuracy-Safety Tradeoff

One Trunk, Three Loss Designs

§13.3 picked the regime; this section says what to TRAIN. The DualTaskModel of §11.4 stays the same in all three cases; only the loss function changes. The book's three methods - AMNL, GABA, and GRACE - sit at three different points along the (RMSE, NASA) Pareto frontier.

The mapping in one line. AMNL emphasises sample-level conservatism (truck regime). GABA balances per-task gradients on the shared trunk (airline regime). GRACE does both (cruise regime). All three preserve the §11.4 architecture and weight count; only the loss differs.
| Regime | Method | Core mechanism | Best on (FD002+FD004) |
|---|---|---|---|
| delivery-truck | AMNL | failure-biased sample weighting | RMSE 7.45 |
| airline-787 | GABA | inverse-gradient task weighting | balanced (RMSE 7.89, NASA 235.7) |
| cruise-ship | GRACE | AMNL sample × GABA task weighting | NASA 232.7 |

AMNL — Failure-Biased Weighted MSE

AMNL weights each SAMPLE by how close to failure it is. Healthy engines (RUL near 125) get the floor weight $w_{\min} = 1$; near-failure engines (RUL near 0) get the ceiling weight $w_{\max} = 2$. Linear schedule:

$$w_i = w_{\max} - (w_{\max} - w_{\min}) \cdot \dfrac{\min(y_i, R_{\max})}{R_{\max}}.$$

The total loss is $\mathcal{L}_{\text{AMNL}} = (1 - \lambda_{\text{hs}}) \cdot \widehat{\text{wMSE}} + \lambda_{\text{hs}} \cdot \text{CE}$, where $\widehat{\text{wMSE}} = \sum_i w_i (\hat{y}_i - y_i)^2 / \sum_i w_i$ is the sample-weighted MSE and the task weight is fixed at $\lambda_{\text{hs}} = 0.5$. Chapter 14 derives the schedule; Chapter 15 trains it.

What AMNL fixes. Plain MSE treats a 5-cycle residual at y=120 (barely matters) the same as a 5-cycle residual at y=2 (life-or-death). AMNL weights the second one almost twice as heavily (w = 1.984 versus 1.04, about 1.9×). The shared backbone is pushed to spend MORE capacity on accurate predictions near failure - exactly when accuracy matters.
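To put numbers on that comparison, here is a minimal NumPy sketch of the linear schedule with this section's defaults (R_max = 125, w_min = 1, w_max = 2). The helper name amnl_weight is hypothetical, chosen for illustration.

```python
import numpy as np


def amnl_weight(y, R_max=125.0, w_min=1.0, w_max=2.0):
    """Hypothetical helper: AMNL's linear failure-biased sample weight.

    w = w_max at y = 0 (failure), falling linearly to w_min at y >= R_max.
    """
    y_capped = np.minimum(y, R_max)
    return w_max - (w_max - w_min) * y_capped / R_max


w_healthy = amnl_weight(120.0)    # 2 - 120/125 = 1.04
w_critical = amnl_weight(2.0)     # 2 - 2/125   = 1.984
ratio = w_critical / w_healthy    # about 1.9
print(f"w(120) = {w_healthy:.3f}, w(2) = {w_critical:.3f}, ratio = {ratio:.2f}")
```

Because the weighted MSE divides by Σ w_i, the weights act relatively: within one batch, the near-failure residual counts about 1.9 times as much as the healthy one.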

GABA — Inverse-Gradient Adaptive Weighting

GABA computes per-task gradient norms on the shared backbone each step (using the §12.1 helper) and sets the task weights inversely proportional to those norms:

$$\lambda_t = \dfrac{1 / \| g_t \|}{\sum_{t'} 1 / \| g_{t'} \|}$$

Then the shared-backbone gradient is

$$g_{\text{shared}} = \sum_t \lambda_t \cdot g_t$$

which has the property $\| \lambda_{\text{rul}} \cdot g_{\text{rul}} \| = \| \lambda_{\text{hs}} \cdot g_{\text{hs}} \|$: since $\lambda_t > 0$, each $\| \lambda_t g_t \| = \lambda_t \| g_t \|$ equals the same constant $1 / \sum_{t'} 1 / \| g_{t'} \|$. Both tasks pull the rope with EQUAL force regardless of their raw gradient magnitudes. The roughly 500× imbalance from §12 (4.81 / 0.0096 ≈ 501) is cancelled at the source. Chapters 17-19 derive and train it.
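A quick numerical check of that property, assuming random gradient directions scaled to the §12.1 norms (the directions are made up for illustration; only the norms 4.81 and 0.0096 come from the text). Here λ_t is normalised to sum to 1, as in the formula.

```python
import numpy as np

rng = np.random.default_rng(0)


def with_norm(v, target):
    """Rescale vector v so that its L2 norm equals target."""
    return v * (target / np.linalg.norm(v))


# Stand-in per-task gradients on the shared trunk. Directions are random
# (an assumption); only the norms come from the §12.1 baseline.
g_rul = with_norm(rng.standard_normal(1000), 4.81)
g_hs = with_norm(rng.standard_normal(1000), 0.0096)

norms = np.array([np.linalg.norm(g_rul), np.linalg.norm(g_hs)])
inv = 1.0 / norms
lam = inv / inv.sum()                         # lambda_t, sums to 1

g_shared = lam[0] * g_rul + lam[1] * g_hs     # combined trunk gradient
pull_rul = np.linalg.norm(lam[0] * g_rul)     # ||lambda_rul * g_rul||
pull_hs = np.linalg.norm(lam[1] * g_hs)       # ||lambda_hs  * g_hs||

print("lambda:", lam.round(5))
print("pulls :", round(pull_rul, 6), round(pull_hs, 6))
```

The equality does not depend on the directions, only on λ_t ∝ 1/‖g_t‖. Equal magnitudes say nothing about alignment, though: if the two gradients conflict, they can still partly cancel in g_shared.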

GRACE — Combine Both

GRACE applies AMNL's sample weighting on the regression branch AND GABA's task weighting at the loss combiner:

$$\mathcal{L}_{\text{GRACE}} = \lambda_{\text{rul}} \cdot \widehat{\text{wMSE}}_{\text{AMNL}} + \lambda_{\text{hs}} \cdot \text{CE}, \quad \lambda_t \propto 1 / \| g_t \|$$

The two mechanisms are complementary. AMNL's sample weighting addresses the WITHIN-task asymmetry (near-failure samples matter more); GABA's task weighting addresses the BETWEEN-task imbalance (RUL gradient dominates HS). Chapters 20-23 derive and train GRACE.
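The within-task effect only shows up when residuals vary across the batch; with constant residuals the sample weights cancel. A small NumPy sketch with made-up residuals (larger misses near failure) shows both cases. amnl_wmse is a hypothetical helper mirroring the weighted MSE used in this section.

```python
import numpy as np


def amnl_wmse(y_pred, y_true, R_max=125.0, w_min=1.0, w_max=2.0):
    """AMNL-style weighted MSE: sum(w_i * r_i^2) / sum(w_i)."""
    w = w_max - (w_max - w_min) * np.minimum(y_true, R_max) / R_max
    return float((w * (y_pred - y_true) ** 2).sum() / w.sum())


y_true = np.array([2.0, 10.0, 60.0, 120.0])
# Hypothetical residuals: larger misses close to failure, small ones when healthy.
y_pred = y_true + np.array([8.0, 6.0, 2.0, 2.0])

plain_mse = float(((y_pred - y_true) ** 2).mean())   # (64 + 36 + 4 + 4) / 4 = 27.0
weighted = amnl_wmse(y_pred, y_true)                  # about 31.92
# Constant residuals: the weights cancel and the loss is 25 whatever the weights.
const = amnl_wmse(y_true - 5.0, y_true)
print(plain_mse, round(weighted, 2), const)
```

The near-failure misses pull the weighted loss above the plain MSE, which is exactly the asymmetry the sample weights are meant to expose.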

Interactive: Pick a Regime, See the Method

The same chooser as §13.3 - now the highlighted “winner” is also the method we recommend for that regime.

Reading the chart. Slide w into the truck-regime band (0.00-0.15) and AMNL is highlighted - it wins on pure RMSE because failure-biased sample weighting concentrates the regression signal on near-failure samples. Slide to the airline band (0.15-0.40) and GABA appears - inverse-gradient weighting gives the best balance between the two metrics. Slide to the cruise band (0.40-1.00) and GRACE wins because it combines both mechanisms.
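The chooser reduces to a lookup over the three w bands quoted above. A minimal sketch, assuming w is the §13.3 regime weight in [0, 1] and that each boundary value belongs to the higher band (the text does not say which); recommend_method is a hypothetical name.

```python
def recommend_method(w: float) -> str:
    """Map a regime weight w in [0, 1] to the method suggested for that band.

    Bands follow the chooser description: truck [0.00, 0.15) -> AMNL,
    airline [0.15, 0.40) -> GABA, cruise [0.40, 1.00] -> GRACE.
    Assigning boundary values to the higher band is an assumption.
    """
    if not 0.0 <= w <= 1.0:
        raise ValueError(f"w must lie in [0, 1], got {w}")
    if w < 0.15:
        return "AMNL"     # delivery-truck regime
    if w < 0.40:
        return "GABA"     # airline-787 regime
    return "GRACE"        # cruise-ship regime


for w in (0.05, 0.25, 0.70):
    print(w, recommend_method(w))
```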

Python: Three Losses Side by Side

All three loss functions in pure NumPy; loss_grace reuses loss_gaba as its task combiner. The worked example uses the §12.1 measured gradient norms (4.81 and 0.0096) so you can read the numerical effect of GABA directly off the printed output.

loss_amnl, loss_gaba, loss_grace - one file (three_losses_numpy.py)
import numpy as np

NumPy provides the (B,) and (B, K) ndarrays we use to express each loss in vector form. We rely on np.minimum, np.exp, np.log, np.arange, np.array, the .sum() / .mean() / .max() reductions, and broadcasting.

EXECUTION STATE
📚 numpy = Library: ndarray + linear algebra + math.
as np = Universal alias.
def loss_amnl(rul_pred, rul_true, logits, y_hs, R_max=125.0, w_min=1.0, w_max=2.0, lam_hs=0.5) -> float:

AMNL = failure-biased weighted MSE on the RUL branch, plain CE on the classification branch, fixed equal task weights. Best for the delivery-truck regime, where RMSE matters more than avoiding late predictions.

EXECUTION STATE
⬇ input: rul_pred = (B,) - predicted RUL.
⬇ input: rul_true = (B,) - ground-truth RUL (capped at R_max=125).
⬇ input: logits = (B, K) - raw class logits.
⬇ input: y_hs = (B,) - class indices in {0, 1, 2}.
⬇ input: R_max = 125.0 = RUL cap. Targets above 125 are treated as 125 (§7.2 piecewise-linear cap).
⬇ input: w_min = 1.0 = Sample weight at y_true = R_max (engine far from failure). Healthy engines get the floor weight.
⬇ input: w_max = 2.0 = Sample weight at y_true = 0 (engine at failure). Critical engines get 2× the weight.
⬇ input: lam_hs = 0.5 = Task weight on the classification branch. 0.5 ⇒ equal weighting with regression.
⬆ returns = Python float scalar - the AMNL loss for this batch.
y_capped = np.minimum(rul_true, R_max)

Cap targets at R_max=125. Engines past their healthy plateau get the same weight as exactly-125 engines.

EXECUTION STATE
📚 np.minimum(a, b) = Element-wise min. Differs from np.min (which reduces). Two arrays in, one same-shape array out.
⬇ arg 1: a = rul_true = (B,) targets.
⬇ arg 2: b = R_max = 125 = Scalar broadcast against the array.
⬆ result: y_capped = (B,) - same as rul_true except entries > 125 saturate at 125.
w = w_max - (w_max - w_min) * y_capped / R_max

Linear interpolation: at y_capped=0 ⇒ w=w_max=2.0 (failure imminent), at y_capped=R_max ⇒ w=w_min=1.0 (healthy).

EXECUTION STATE
operator: - = Element-wise subtract / scalar subtract.
operator: * = Scalar × array broadcast.
operator: / = Scalar division.
→ at y_capped=0 = w = 2.0 - 1.0 · 0/125 = 2.0 (heaviest).
→ at y_capped=125 = w = 2.0 - 1.0 · 125/125 = 1.0 (lightest).
→ at y_capped=50 = w = 2.0 - 1.0 · 50/125 = 1.6.
⬆ result: w (5 first) = [varies; e.g. for rul_true[:5]=[44,47,64,67,67]: [1.648, 1.624, 1.488, 1.464, 1.464]]
sq = (rul_pred - rul_true) ** 2

Squared per-sample residual.

EXECUTION STATE
operator: - = Element-wise subtract.
operator: ** 2 = Element-wise square.
⬆ result: sq = (B,) - all 25.0 in this worked example since rul_pred = rul_true - 5 ⇒ residual=-5 ⇒ sq=25.
L_rul = float((w * sq).sum() / w.sum())

Weighted MSE: Σ(w_i · sq_i) / Σ w_i. The sample-weighted analogue of mean(sq) - heavier weights on near-failure samples.

EXECUTION STATE
operator: * = Element-wise array × array.
📚 .sum() = Reduce-sum. With no axis, returns a 0-D scalar.
operator: / = Scalar division.
📚 float(x) = 0-D ndarray → Python float.
⬆ result: L_rul = 25.0 in this worked example (residual=-5 everywhere ⇒ weighted mean = 25 regardless of weights).
z = logits - logits.max(-1, keepdims=True)

Log-sum-exp shift for numerical stability.

EXECUTION STATE
📚 .max(axis, keepdims) = Reduce-max along an axis. With keepdims=True the reduced axis stays size 1 for broadcasting.
⬇ arg: axis = -1 = Last axis = the K class axis.
⬇ arg: keepdims = True = Output shape (B, 1) so logits(B, K) - max(B, 1) broadcasts correctly.
log_p = z - np.log(np.exp(z).sum(-1, keepdims=True))

Stable log-softmax. log p_k = z_k - log Σ exp(z_j).

EXECUTION STATE
📚 np.exp(arr) = Element-wise e^x.
📚 np.log(arr) = Element-wise natural logarithm.
📚 .sum(axis, keepdims) = Reduce-sum along an axis.
⬆ result: log_p = (B, K) - per-row log-probabilities. Each row sums (after exp) to 1.
L_hs = float(-log_p[np.arange(len(y_hs)), y_hs].mean())

Negative log-likelihood of the true class - the mean cross-entropy loss.

EXECUTION STATE
📚 np.arange(stop) = Integer range [0, stop).
📚 len(seq) = Python built-in. For (B,) ndarray returns B.
→ fancy indexing = log_p[np.arange(B), y_hs] picks log_p[i, y_hs[i]] for each i. Shape (B,).
📚 .mean() = Reduce-mean. Returns a 0-D scalar.
⬆ result: L_hs = Python float, typically ≈ log K = 1.099 at init.
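The max shift changes only the numerics, not the value. A small NumPy check (the helper names are hypothetical) confirms two claims from this walkthrough: all-zero logits give CE = log K = log 3 ≈ 1.099, and the shifted formula stays finite where the naive softmax overflows.

```python
import numpy as np


def stable_log_softmax(logits):
    """Row-wise log-softmax with the max-shift used in this section."""
    z = logits - logits.max(-1, keepdims=True)
    return z - np.log(np.exp(z).sum(-1, keepdims=True))


def mean_ce(logits, y):
    """Mean cross-entropy via fancy indexing of the true-class log-probs."""
    log_p = stable_log_softmax(logits)
    return float(-log_p[np.arange(len(y)), y].mean())


# All-zero logits (an untrained head): every class gets p = 1/3, so CE = log 3.
zero_ce = mean_ce(np.zeros((4, 3)), np.array([0, 1, 2, 0]))

# Large logits: exp(1000) overflows to inf in the naive formula.
big = np.array([[1000.0, 0.0, -1000.0]])
with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
    naive = np.log(np.exp(big) / np.exp(big).sum(-1, keepdims=True))
stable = stable_log_softmax(big)
print(zero_ce, naive, stable)
```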
return (1.0 - lam_hs) * L_rul + lam_hs * L_hs

Convex combination - same shape as the §13.3 J(w), but here the weights are FIXED at 0.5/0.5 (the AMNL recipe).

EXECUTION STATE
→ at lam_hs = 0.5 = 0.5 · L_rul + 0.5 · L_hs - the AMNL default.
⬆ return = Python float - the total AMNL loss.
def loss_gaba(grad_norm_rul, grad_norm_hs, L_rul, L_hs, eps=1e-8) -> float:

GABA = Gradient-Aware Balanced Approach. Task weights are inversely proportional to per-task gradient norms (measured by §12.1's helper). The dominant task gets damped; the suppressed task gets amplified.

EXECUTION STATE
⬇ input: grad_norm_rul = ‖∂L_rul/∂θ_shared‖ measured this step. From §12.1 baseline ≈ 4.81.
⬇ input: grad_norm_hs = ‖∂L_hs/∂θ_shared‖ measured this step. From §12.1 baseline ≈ 0.0096.
⬇ input: L_rul = Current per-task RUL loss (scalar).
⬇ input: L_hs = Current per-task health loss (scalar).
⬇ input: eps = 1e-8 = Numerical floor inside the inverse to prevent division by zero on the rare zero-gradient batch.
⬆ returns = Python float - the GABA-weighted total loss.
inv = np.array([1.0 / (grad_norm_rul + eps), 1.0 / (grad_norm_hs + eps)])

Inverse-gradient weights. Adding eps to the norm before inverting is the safe pattern - it keeps each inverse finite even when a measured norm is exactly zero.

EXECUTION STATE
📚 np.array(seq) = Construct an ndarray.
→ numerical example = 1/(4.81 + 1e-8) ≈ 0.2079 (RUL); 1/(0.0096 + 1e-8) ≈ 104.17 (HS). HS's inverse is ~500× bigger - exactly the imbalance ratio from §12.1, now used to AMPLIFY the suppressed task.
⬆ result: inv = [0.2079, 104.17]
w = 2.0 * inv / inv.sum()

Re-normalise so the two weights sum to 2 (equal weighting would then be 1.0 each). AMNL's 0.5/0.5 weights sum to 1 instead, as does the λ_t formula above; the two conventions differ only by a global factor of 2, which rescales the loss but not the direction of its gradient.

EXECUTION STATE
📚 .sum() = Reduce-sum. Here over a (2,) array = 0.2079 + 104.17 ≈ 104.38.
operator: * = Scalar × array broadcast.
operator: / = Element-wise array / scalar.
⬆ result: w = [0.00398, 1.99602] - HS gets ~500× the weight, RUL gets < 1% of the budget.
→ effect = Despite the dominant RUL gradient, the EFFECTIVE pull on the shared backbone is now equal: ‖g_rul‖ × 0.00398 ≈ ‖g_hs‖ × 1.99602.
return float(w[0] * L_rul + w[1] * L_hs)

Linear combination of the two task losses with the inverse-gradient weights.

EXECUTION STATE
→ numerical example = 0.00398 · 25.0 + 1.99602 · 1.10 ≈ 0.0996 + 2.1956 ≈ 2.2952.
⬆ return = Python float - the GABA-weighted loss.
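The sum-to-2 versus sum-to-1 choice only rescales the loss. A small NumPy check, with a hypothetical gaba_weights helper that mirrors loss_gaba's weighting, reproduces the weights quoted above and the factor of 2 between the two conventions:

```python
import numpy as np


def gaba_weights(g_norm_rul, g_norm_hs, total=2.0, eps=1e-8):
    """Inverse-gradient task weights, rescaled to sum to `total`."""
    inv = np.array([1.0 / (g_norm_rul + eps), 1.0 / (g_norm_hs + eps)])
    return total * inv / inv.sum()


L_rul, L_hs = 25.0, 1.10                      # toy scalar losses from the worked example
w2 = gaba_weights(4.81, 0.0096, total=2.0)    # convention used in the code
w1 = gaba_weights(4.81, 0.0096, total=1.0)    # convention of the lambda_t formula

loss2 = float(w2 @ [L_rul, L_hs])
loss1 = float(w1 @ [L_rul, L_hs])
print(w2.round(5), round(loss2, 4))
print(w1.round(5), round(loss1, 4))
```

A constant factor of 2 scales the gradient (and so the effective step size under plain SGD) without changing its direction; adaptive optimisers such as Adam largely absorb such a global scale.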
def loss_grace(rul_pred, rul_true, logits, y_hs, grad_norm_rul, grad_norm_hs) -> float:

GRACE = AMNL's sample weighting (failure-biased per-sample) AND GABA's task weighting (inverse-gradient). Both at the same time.

EXECUTION STATE
⬇ inputs = Same as AMNL plus the two grad norms.
⬆ returns = Python float - the GRACE loss.
R_max = 125.0

Same RUL cap as AMNL.

y_c = np.minimum(rul_true, R_max)

Cap targets.

EXECUTION STATE
📚 np.minimum(a, b) = Element-wise min.
w_s = 2.0 - (2.0 - 1.0) * y_c / R_max

AMNL per-sample weight schedule with w_max=2, w_min=1, R_max=125.

EXECUTION STATE
→ at y=0 = w_s = 2.0
→ at y=125 = w_s = 1.0
sq = (rul_pred - rul_true) ** 2

Squared residual.

L_rul = float((w_s * sq).sum() / w_s.sum())

AMNL-style weighted MSE.

z = logits - logits.max(-1, keepdims=True)

Log-sum-exp shift.

log_p = z - np.log(np.exp(z).sum(-1, keepdims=True))

Stable log-softmax.

L_hs = float(-log_p[np.arange(len(y_hs)), y_hs].mean())

Plain CE.

return loss_gaba(grad_norm_rul, grad_norm_hs, L_rul, L_hs)

Hand the two scalar losses (with AMNL's sample weighting already baked into L_rul) to the GABA combiner.

EXECUTION STATE
⬆ return = Python float - GRACE loss = inverse-gradient weights × (sample-weighted L_rul, plain L_hs).
np.random.seed(0)

Seed for reproducibility.

EXECUTION STATE
📚 np.random.seed(s) = Sets NumPy's legacy global PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
B = 32

Batch size.

rul_true = np.random.randint(0, 126, B).astype(np.float32)

Capped RUL targets in [0, 125].

EXECUTION STATE
📚 np.random.randint(low, high, size) = Random ints in [low, high). Exclusive upper bound.
⬇ arg: low = 0 = Inclusive.
⬇ arg: high = 126 = Exclusive ⇒ 0..125.
⬇ arg: size = B = 32 = 1-D output.
📚 .astype(np.float32) = Cast int64 → float32.
rul_pred = (rul_true - 5.0).astype(np.float32)

Synthetic predictions: every prediction is exactly 5 cycles early. Constant residual = -5 ⇒ squared residual = 25 everywhere.

EXECUTION STATE
operator: - 5.0 = Scalar subtract - shifts every prediction 5 cycles down.
⬆ result: rul_pred[:5] = [39, 42, 59, 62, 62]
logits = np.random.randn(B, 3).astype(np.float32)

Random class logits.

EXECUTION STATE
📚 np.random.randn(*size) = Sample i.i.d. N(0, 1).
⬇ arg: size = (B, 3) = B=32 samples × K=3 classes.
y_hs = np.random.randint(0, 3, B)

Random class indices.

EXECUTION STATE
⬇ arg: high = 3 = Exclusive ⇒ {0, 1, 2}.
g_norm_rul = 4.81

Measured baseline RUL gradient norm from §12.1.

g_norm_hs = 0.0096

Measured baseline HS gradient norm. Ratio 4.81/0.0096 ≈ 500.

print("AMNL :", round(loss_amnl(...), 4))

Run AMNL on the worked example.

EXECUTION STATE
Output = AMNL : 13.0524 (≈ 0.5 · 25 + 0.5 · 1.10)
print("GABA :", round(loss_gaba(g_norm_rul, g_norm_hs, 25.0, 1.10), 4))

Run GABA with toy scalar losses (25.0, 1.10) and the measured grad norms.

EXECUTION STATE
Output = GABA : 2.2952 (≈ 0.00398 · 25 + 1.99602 · 1.10) - the dominant RUL term is now nearly invisible
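print("GRACE :", round(loss_grace(...), 4))

Run GRACE with the same data.

EXECUTION STATE
Output = GRACE : 0.00398 · 25 + 1.99602 · L_hs. AMNL's sample weighting is a no-op in this synthetic example because the residuals are constant, so L_rul = 25 exactly as in the GABA call; the only difference from the GABA line is that L_hs is now the actual cross-entropy on the random logits rather than the toy value 1.10. If the AMNL line prints 13.0524, then L_hs ≈ 1.1048 and GRACE prints about 2.305.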
→ on real data = GRACE's sample weighting matters most when residuals VARY across the batch (which they do in C-MAPSS). The two mechanisms are complementary: GABA balances the TASKS, AMNL balances the SAMPLES within each task.
```python
import numpy as np


def loss_amnl(rul_pred:    np.ndarray,
              rul_true:    np.ndarray,
              logits:      np.ndarray,
              y_hs:        np.ndarray,
              R_max:       float = 125.0,
              w_min:       float = 1.0,
              w_max:       float = 2.0,
              lam_hs:      float = 0.5) -> float:
    """AMNL: failure-biased per-sample weighted MSE + cross-entropy.

    Sample weight w_i grows linearly as RUL shrinks (closer to failure
    ⇒ heavier penalty). Equal task weights 0.5/0.5.
    """
    # 1) per-sample weight: w_i = w_max - (w_max - w_min) * min(y_true, R_max) / R_max
    y_capped = np.minimum(rul_true, R_max)
    w        = w_max - (w_max - w_min) * y_capped / R_max          # (B,)
    # 2) weighted MSE
    sq       = (rul_pred - rul_true) ** 2                          # (B,)
    L_rul    = float((w * sq).sum() / w.sum())                     # scalar
    # 3) cross-entropy (mean over batch), via a stable log-softmax
    z        = logits - logits.max(-1, keepdims=True)
    log_p    = z - np.log(np.exp(z).sum(-1, keepdims=True))
    L_hs     = float(-log_p[np.arange(len(y_hs)), y_hs].mean())
    return (1.0 - lam_hs) * L_rul + lam_hs * L_hs


def loss_gaba(grad_norm_rul: float,
              grad_norm_hs:  float,
              L_rul:         float,
              L_hs:          float,
              eps:           float = 1e-8) -> float:
    """GABA: inverse-gradient adaptive weighting.

    Per-step task weights are proportional to 1 / ‖g_t‖ so the dominant
    task gets damped and the suppressed task gets amplified. Re-normalised
    to sum to 2 (twice AMNL's 0.5/0.5 budget; the factor only rescales the loss).
    """
    inv = np.array([1.0 / (grad_norm_rul + eps),
                    1.0 / (grad_norm_hs  + eps)])                  # (2,)
    w   = 2.0 * inv / inv.sum()                                    # sum to 2
    return float(w[0] * L_rul + w[1] * L_hs)


def loss_grace(rul_pred:      np.ndarray,
               rul_true:      np.ndarray,
               logits:        np.ndarray,
               y_hs:          np.ndarray,
               grad_norm_rul: float,
               grad_norm_hs:  float) -> float:
    """GRACE: AMNL's sample weighting + GABA's task weighting."""
    # AMNL-style per-sample weight on the regression branch
    R_max = 125.0
    y_c   = np.minimum(rul_true, R_max)
    w_s   = 2.0 - (2.0 - 1.0) * y_c / R_max
    sq    = (rul_pred - rul_true) ** 2
    L_rul = float((w_s * sq).sum() / w_s.sum())
    # CE
    z     = logits - logits.max(-1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(-1, keepdims=True))
    L_hs  = float(-log_p[np.arange(len(y_hs)), y_hs].mean())
    # GABA-style inverse-gradient task weighting
    return loss_gaba(grad_norm_rul, grad_norm_hs, L_rul, L_hs)


# ---------- Worked example ----------
np.random.seed(0)
B            = 32
rul_true     = np.random.randint(0, 126, B).astype(np.float32)
rul_pred     = (rul_true - 5.0).astype(np.float32)                 # exactly 5 cycles early for every sample
logits       = np.random.randn(B, 3).astype(np.float32)
y_hs         = np.random.randint(0, 3, B)
g_norm_rul   = 4.81                                                # measured §12.1
g_norm_hs    = 0.0096                                              # measured §12.1

print("AMNL  :", round(loss_amnl(rul_pred, rul_true, logits, y_hs), 4))
print("GABA  :", round(loss_gaba(g_norm_rul, g_norm_hs, 25.0, 1.10), 4))
print("GRACE :", round(loss_grace(rul_pred, rul_true, logits, y_hs,
                                  g_norm_rul, g_norm_hs), 4))
```

PyTorch: Drop-in Loss Modules

Three nn.Modules. AMNL and GRACE take the same batch tensors (GRACE adds the two grad-norm arguments); GABA takes the two scalar task losses plus the two grad norms. Wire them into the §11.4 training loop unchanged.

AMNL / GABA / GRACE as nn.Module (three_losses_torch.py)
import torch

Top-level PyTorch.

EXECUTION STATE
📚 torch = Tensor library + autograd + nn modules + optim.
import torch.nn as nn

Module containers.

EXECUTION STATE
📚 nn.Module = Base class for layers, models, and loss modules.
import torch.nn.functional as F

Stateless ops (F.cross_entropy here).

EXECUTION STATE
📚 F.cross_entropy = Stable log_softmax + nll_loss in one numerically-safe call. Used inside AMNL/GRACE for the classification branch.
class AMNL(nn.Module):

AMNL as a drop-in nn.Module. Stateless except for hyperparameters - no learnable parameters of its own.

def __init__(self, R_max=125.0, w_min=1.0, w_max=2.0, lam_hs=0.5):

Four hyperparameters - same as the NumPy version.

EXECUTION STATE
⬇ input: R_max = RUL cap. 125 by default (§7.2).
⬇ input: w_min = Sample weight floor.
⬇ input: w_max = Sample weight ceiling at y=0.
⬇ input: lam_hs = Task weight on classification. 0.5 by default.
super().__init__()

Initialise nn.Module.

self.R_max, self.w_min, self.w_max, self.lam_hs = R_max, w_min, w_max, lam_hs

Tuple unpacking - one-line parallel assignment.

def forward(self, rul_pred, rul_true, logits, y_hs):

Forward pass. Same call signature you'd use as a drop-in for nn.MSELoss + nn.CrossEntropyLoss combined.

EXECUTION STATE
⬇ input: rul_pred = (B,) RUL predictions, requires_grad=True.
⬇ input: rul_true = (B,) ground-truth RUL.
⬇ input: logits = (B, K) raw logits.
⬇ input: y_hs = (B,) class indices, int64.
⬆ returns = 0-D scalar tensor with autograd graph.
y_c = rul_true.clamp(max=self.R_max)

Cap targets via the upper-bound-only form of clamp.

EXECUTION STATE
📚 .clamp(min, max) = Element-wise clip. Either bound can be None (but not both); min defaults to None (no lower bound) and max defaults to None (no upper bound).
⬇ arg: max = R_max = Upper bound only. Same result as torch.minimum(rul_true, torch.tensor(self.R_max)), the analogue of np.minimum in the NumPy version.
w = self.w_max - (self.w_max - self.w_min) * y_c / self.R_max

Per-sample weight - linear interpolation from w_max at y=0 to w_min at y=R_max.

L_rul = (w * (rul_pred - rul_true) ** 2).sum() / w.sum()

Weighted MSE.

EXECUTION STATE
operator: ** 2 = Element-wise square.
operator: * = Element-wise tensor multiply.
📚 .sum() = Reduce-sum over all axes by default.
operator: / = Tensor scalar division.
L_hs = F.cross_entropy(logits, y_hs)

Standard mean cross-entropy.

EXECUTION STATE
📚 F.cross_entropy(input, target, reduction='mean') = Stable log_softmax + nll_loss.
⬇ arg 1: input = logits = (B, K) raw logits, NOT probabilities.
⬇ arg 2: target = y_hs = (B,) class indices, int64. NOT one-hot.
⬆ result: L_hs = 0-D scalar tensor.
return (1 - self.lam_hs) * L_rul + self.lam_hs * L_hs

Convex combination.

EXECUTION STATE
→ at lam_hs = 0.5 = 0.5 · L_rul + 0.5 · L_hs - the AMNL default.
class GABA(nn.Module):

GABA as an nn.Module that takes the SCALAR per-task losses plus the SCALAR per-task gradient norms and returns the inverse-gradient weighted sum. Caller measures grad norms via §12.1's helper.

def __init__(self, eps: float = 1e-8):

One hyperparameter - numerical floor in the inverse.

EXECUTION STATE
⬇ input: eps = Floor inside the inverse to prevent divide-by-zero. 1e-8 is conventional.
super().__init__()

Initialise nn.Module.

self.eps = eps

Store.

def forward(self, L_rul, L_hs, g_norm_rul, g_norm_hs):

All four inputs are scalars.

EXECUTION STATE
⬇ input: L_rul = 0-D tensor or Python float - the RUL loss.
⬇ input: L_hs = 0-D tensor or Python float - the HS loss.
⬇ input: g_norm_rul = Python float - measured ‖g_rul‖ over shared params.
⬇ input: g_norm_hs = Python float - measured ‖g_hs‖ over shared params.
⬆ returns = 0-D scalar tensor. It carries an autograd graph when L_rul and L_hs are tensors that require grad; with the plain floats of the smoke test it is just a constant tensor.
inv = torch.tensor([1.0 / (g_norm_rul + self.eps), 1.0 / (g_norm_hs + self.eps)])

Two-element tensor of inverse-gradient weights. Note the grad norms come in as Python floats - so this construction does NOT participate in autograd. That is intentional: we treat the gradient-norm measurement as a constant for the optimiser.

EXECUTION STATE
📚 torch.tensor(seq) = Construct a new tensor from a Python sequence. Default float32.
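→ why no autograd here? = The weights must act as constants for this step. If inv were built from gradient-norm tensors still attached to the graph, backward() would also differentiate through the norms themselves (second-order terms), which GABA does not intend. The standard trick is to .detach() the gradient norms (or pre-extract Python floats, as we do here).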
w = 2.0 * inv / inv.sum()

Normalise so weights sum to 2.

EXECUTION STATE
📚 .sum() = Reduce-sum.
return w[0] * L_rul + w[1] * L_hs

Linear combination.

EXECUTION STATE
→ indexing = w[0] and w[1] are 0-D tensors. The product with L_rul/L_hs (which carry autograd) keeps the graph alive.
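The GABA module expects the grad norms as plain Python floats. A minimal sketch of one way to obtain them, assuming a made-up two-head model in place of the §11.4 DualTaskModel, plain MSE standing in for the weighted RUL loss, and a hypothetical shared_grad_norm helper (the book's own §12.1 helper may differ): take per-task gradients over the shared trunk only with torch.autograd.grad, reduce them to floats, and only then build the weights.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical tiny two-head model standing in for the §11.4 DualTaskModel.
trunk = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
rul_head = nn.Linear(16, 1)
hs_head = nn.Linear(16, 3)

x = torch.randn(32, 8)
rul_true = torch.randint(0, 126, (32,)).float()
y_hs = torch.randint(0, 3, (32,))

h = trunk(x)
L_rul = F.mse_loss(rul_head(h).squeeze(-1), rul_true)   # stand-in for the RUL loss
L_hs = F.cross_entropy(hs_head(h), y_hs)

shared = list(trunk.parameters())


def shared_grad_norm(loss, params):
    """L2 norm of d(loss)/d(params), returned as a plain Python float.

    retain_graph=True keeps the graph alive for the later backward();
    create_graph stays False and float() drops the tensor, so the norm
    is a constant as far as the optimiser is concerned.
    """
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    return math.sqrt(sum(float((g ** 2).sum()) for g in grads))


g_norm_rul = shared_grad_norm(L_rul, shared)
g_norm_hs = shared_grad_norm(L_hs, shared)

inv = torch.tensor([1.0 / (g_norm_rul + 1e-8), 1.0 / (g_norm_hs + 1e-8)])
w = 2.0 * inv / inv.sum()            # constant weights, no autograd history
loss = w[0] * L_rul + w[1] * L_hs
loss.backward()                      # gradients flow through L_rul and L_hs only
```

The two torch.autograd.grad calls are where GABA's extra per-step cost comes from: two additional backward passes through the trunk before the real one.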
class GRACE(nn.Module):

GRACE composes AMNL's sample weighting with GABA's task weighting.

def __init__(self, R_max=125.0, w_min=1.0, w_max=2.0, eps: float = 1e-8):

Same hyperparameters as AMNL plus GABA's eps.

super().__init__()

Initialise nn.Module.

self.amnl_part = AMNL(R_max, w_min, w_max, lam_hs=0.0)

Reuse the AMNL module but with lam_hs=0.0 - we only want its sample-weighting machinery; the task weighting is delegated to GABA.

EXECUTION STATE
⬇ arg: lam_hs = 0.0 = Zeroes out AMNL's task-mixing branch so we get pure sample-weighted L_rul (no CE term yet).
self.gaba_part = GABA(eps)

Compose - GRACE has the AMNL machinery and the GABA combiner.

def forward(self, rul_pred, rul_true, logits, y_hs, g_norm_rul, g_norm_hs):

Same signature as AMNL plus the two grad norms.

y_c = rul_true.clamp(max=self.amnl_part.R_max)

Cap targets.

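w = self.amnl_part.w_max - (self.amnl_part.w_max - self.amnl_part.w_min) * y_c / self.amnl_part.R_max

AMNL sample-weight schedule, re-derived inline from amnl_part's hyperparameters. Calling self.amnl_part(rul_pred, rul_true, logits, y_hs) would return the same sample-weighted L_rul, because lam_hs=0.0 reduces AMNL's output to L_rul + 0 · L_hs; writing it out keeps the weighting visible and skips a CE term that would only be multiplied by zero.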

L_rul = (w * (rul_pred - rul_true) ** 2).sum() / w.sum()

Weighted MSE - same as AMNL.

L_hs = F.cross_entropy(logits, y_hs)

Plain CE.

return self.gaba_part(L_rul, L_hs, g_norm_rul, g_norm_hs)

Hand the two scalars to GABA.
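In GRACE, amnl_part is built with lam_hs=0.0, so calling it returns (1 - 0) · L_rul + 0 · L_hs, i.e. the sample-weighted L_rul alone. A quick check, re-declaring this section's AMNL module so the snippet runs standalone; the varying residuals are synthetic.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AMNL(nn.Module):
    """Copy of this section's AMNL module, re-declared so the check is self-contained."""
    def __init__(self, R_max=125.0, w_min=1.0, w_max=2.0, lam_hs=0.5):
        super().__init__()
        self.R_max, self.w_min, self.w_max, self.lam_hs = R_max, w_min, w_max, lam_hs

    def forward(self, rul_pred, rul_true, logits, y_hs):
        y_c = rul_true.clamp(max=self.R_max)
        w = self.w_max - (self.w_max - self.w_min) * y_c / self.R_max
        L_rul = (w * (rul_pred - rul_true) ** 2).sum() / w.sum()
        L_hs = F.cross_entropy(logits, y_hs)
        return (1 - self.lam_hs) * L_rul + self.lam_hs * L_hs


torch.manual_seed(0)
rul_true = torch.randint(0, 126, (32,)).float()
rul_pred = rul_true + 3.0 * torch.randn(32)     # synthetic residuals that vary
logits = torch.randn(32, 3)
y_hs = torch.randint(0, 3, (32,))

# Inline sample-weighted MSE, written the same way as in GRACE.forward.
w = 2.0 - (2.0 - 1.0) * rul_true.clamp(max=125.0) / 125.0
L_rul_inline = (w * (rul_pred - rul_true) ** 2).sum() / w.sum()

L_rul_module = AMNL(lam_hs=0.0)(rul_pred, rul_true, logits, y_hs)   # GRACE's setting
L_mixed = AMNL(lam_hs=0.5)(rul_pred, rul_true, logits, y_hs)        # AMNL's default
print(L_rul_inline.item(), L_rul_module.item(), L_mixed.item())
```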

torch.manual_seed(0)

Seed for reproducibility.

EXECUTION STATE
📚 torch.manual_seed(s) = Set the global PyTorch PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
B = 32

Batch size.

rul_true = torch.randint(0, 126, (B,)).float()

Capped RUL targets.

EXECUTION STATE
📚 torch.randint(low, high, size) = Random ints in [low, high).
⬇ arg: low = 0 = Inclusive.
⬇ arg: high = 126 = Exclusive ⇒ 0..125.
⬇ arg: size = (B,) = 1-D tensor of length 32.
📚 .float() = Cast int64 → float32.
rul_pred = (rul_true - 5).float()

Exactly 5 cycles early for every sample.

logits = torch.randn(B, 3)

Random logits.

EXECUTION STATE
📚 torch.randn(*size) = Sample i.i.d. N(0, 1).
⬇ arg: size = (B, 3) = B=32 × K=3.
y_hs = torch.randint(0, 3, (B,))

Class indices.

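amnl = AMNL()(rul_pred, rul_true, logits, y_hs)

Instantiate-and-call. AMNL() builds the module; (rul_pred, …) calls forward().

EXECUTION STATE
⬆ result: amnl = 0-D tensor = 0.5 · 25 + 0.5 · L_hs, where L_hs is the CE on torch's random logits (L_rul is exactly 25 because every residual is -5).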
gaba = GABA()(L_rul=25.0, L_hs=1.10, g_norm_rul=4.81, g_norm_hs=0.0096)

GABA on toy scalar losses with the §12.1 grad norms.

EXECUTION STATE
⬆ result: gaba = 0-D tensor ≈ 2.30.
grace = GRACE()(rul_pred, rul_true, logits, y_hs, g_norm_rul=4.81, g_norm_hs=0.0096)

GRACE with the same data and grad norms.

EXECUTION STATE
⬆ result: grace = 0-D tensor = 0.00398 · 25 + 1.99602 · L_hs, with the same L_hs as in the AMNL call.
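print(f"AMNL : {amnl.item():.4f}")

f-string with .item() to extract the Python float.

EXECUTION STATE
📚 .item() = 0-D tensor → Python float.
→ :.4f = Format spec: float, 4 decimals.
Output = AMNL : 12.5 + 0.5 · L_hs. torch.manual_seed and np.random.seed drive different generators, so the random logits (and hence this value) will generally not reproduce the NumPy block's 13.0524.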
print(f"GABA : {gaba.item():.4f}")

Same.

EXECUTION STATE
Output = GABA : 2.2952 (only toy scalars go in, so this line does match the NumPy block)
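print(f"GRACE : {grace.item():.4f}")

Same.

EXECUTION STATE
Output = GRACE : 0.00398 · 25 + 1.99602 · L_hs. AMNL's sample weight is still a no-op here because the residuals are constant, but L_hs comes from torch's random logits, so the value generally matches neither the GABA line (toy L_hs = 1.10) nor the NumPy GRACE line.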
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AMNL(nn.Module):
    """AMNL: failure-biased weighted MSE + CE with fixed lam=0.5."""
    def __init__(self, R_max=125.0, w_min=1.0, w_max=2.0, lam_hs=0.5):
        super().__init__()
        self.R_max, self.w_min, self.w_max, self.lam_hs = R_max, w_min, w_max, lam_hs

    def forward(self, rul_pred, rul_true, logits, y_hs):
        y_c   = rul_true.clamp(max=self.R_max)
        w     = self.w_max - (self.w_max - self.w_min) * y_c / self.R_max
        L_rul = (w * (rul_pred - rul_true) ** 2).sum() / w.sum()
        L_hs  = F.cross_entropy(logits, y_hs)
        return (1 - self.lam_hs) * L_rul + self.lam_hs * L_hs


class GABA(nn.Module):
    """GABA: inverse-gradient adaptive task weighting."""
    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, L_rul, L_hs, g_norm_rul, g_norm_hs):
        inv = torch.tensor([1.0 / (g_norm_rul + self.eps),
                            1.0 / (g_norm_hs  + self.eps)])
        w   = 2.0 * inv / inv.sum()
        return w[0] * L_rul + w[1] * L_hs


class GRACE(nn.Module):
    """GRACE: AMNL sample weighting + GABA task weighting."""
    def __init__(self, R_max=125.0, w_min=1.0, w_max=2.0, eps: float = 1e-8):
        super().__init__()
        self.amnl_part = AMNL(R_max, w_min, w_max, lam_hs=0.0)         # only L_rul
        self.gaba_part = GABA(eps)

    def forward(self, rul_pred, rul_true, logits, y_hs, g_norm_rul, g_norm_hs):
        y_c   = rul_true.clamp(max=self.amnl_part.R_max)
        w     = self.amnl_part.w_max - (self.amnl_part.w_max - self.amnl_part.w_min) * y_c / self.amnl_part.R_max
        L_rul = (w * (rul_pred - rul_true) ** 2).sum() / w.sum()
        L_hs  = F.cross_entropy(logits, y_hs)
        return self.gaba_part(L_rul, L_hs, g_norm_rul, g_norm_hs)


# ---------- Smoke test ----------
torch.manual_seed(0)
B        = 32
rul_true = torch.randint(0, 126, (B,)).float()
rul_pred = (rul_true - 5).float()
logits   = torch.randn(B, 3)
y_hs     = torch.randint(0, 3, (B,))

amnl  = AMNL()(rul_pred, rul_true, logits, y_hs)
gaba  = GABA()(L_rul=25.0, L_hs=1.10, g_norm_rul=4.81, g_norm_hs=0.0096)
grace = GRACE()(rul_pred, rul_true, logits, y_hs, g_norm_rul=4.81, g_norm_hs=0.0096)

print(f"AMNL  : {amnl.item():.4f}")
print(f"GABA  : {gaba.item():.4f}")
print(f"GRACE : {grace.item():.4f}")
```

Where Each Method Has Already Been Used

| Method | Best deployment regime | First published use | Verifiable code |
|---|---|---|---|
| AMNL | delivery-truck (low-cost equipment) | C-MAPSS FD001/FD002 (legacy book Ch 10) | Chapters 14-15 of this book |
| GABA | airline-787 (mid-cost equipment) | IEEE/CAA JAS 2025 paper, Section IV-B | Chapters 17-19 of this book |
| GRACE | cruise-ship (high-cost / safety-critical) | IEEE/CAA JAS 2025 paper, Section V (combined results) | Chapters 20-23 of this book |
| AMNL | battery aging (capacity fade) | Severson et al. SoH studies | Adaptable - swap c_in |
| GABA | object detection (bbox + class) | GradNorm and follow-ups (Chen et al.) | GradNorm is a near-relative |
| GRACE | wind-turbine SCADA (RUL + fault type) | Industrial pilots (NREL, Vestas) | Reuse this book's implementation |

Three Method-Selection Pitfalls

Pitfall 1: Picking GRACE by default. GRACE is the most expressive but also the most expensive (two extra backward passes per step for the grad-norm measurement). On the truck regime AMNL alone is faster AND wins - GRACE's machinery is wasted. Match the method to the regime, not to the headline result.
Pitfall 2: Forgetting the .detach() on grad norms. GABA's lambda values must be computed from grad norms WITHOUT autograd connecting them to the loss tensor. Forget the detach and you get a self-referential graph that either crashes (gradient with respect to a gradient) or silently trains the wrong thing. The book's Chapter 18 spells out the exact safe pattern.
Pitfall 3: Mixing AMNL's lam_hs with GABA's lambda. AMNL fixes lam_hs=0.5; GABA computes lam_hs per step from grad norms. Setting both is redundant and produces nonsensical weighting. GRACE deliberately disables AMNL's task-mixing branch (lam_hs=0.0 in the GRACE constructor) so only GABA decides the task weights.
The point. Three methods, three regimes, one architecture. Part V (Chapters 14-16) covers AMNL in depth; Part VI (Chapters 17-19) covers GABA; Part VII (Chapters 20-23) covers GRACE. Skip ahead to whichever your operational regime demands - or read in order for the full story.

Takeaway — End of Part IV

  • AMNL. Failure-biased sample weights. Best for RMSE-dominated regimes.
  • GABA. Inverse-gradient task weights. Best for balanced regimes.
  • GRACE. AMNL's sample weighting + GABA's task weighting. Best for safety-dominated regimes.
  • Same architecture. §11.4 DualTaskModel unchanged. Only the loss module differs.
  • Match method to regime. GRACE is not always the answer - the truck regime really is best served by AMNL.
  • End of Part IV. Diagnostic chapters done. Chapters 14-23 derive and train each of the three methods in turn.