Chapter 14

Choosing w_max = 2.0 (Not 5.0 or 10.0)

AMNL — Failure-Biased Weighted MSE

The Thermostat That Overshoots

A thermostat with high gain heats the room past the setpoint, then over-corrects, then under-corrects. Eventually it settles - but only because the gain is bounded. Crank the gain to 10× and the system never settles. The same thing happens to AMNL training when w_max grows past ~3.

§14.1 said “more weight on near-failure” helps. §14.2 said the schedule should be linear. This section answers the remaining question: how MUCH weight? The paper sets w_max = 2.0 and never tunes it. That choice is not accidental.

The headline. Per-sample gradient magnitude scales linearly with w_max, but Adam's adaptive denominator $\sqrt{\hat{v}}$ scales root-mean-square-like with the same w_max. The two grow at different rates - so the EFFECTIVE step on tail samples actually ACCELERATES with w_max. Bound w_max at 2 and you stay in the linear regime; push past 3 and you fall into a non-linear over-step regime.

What Goes Wrong As w_max Grows

With AMNL the per-sample gradient is $\partial \mathcal{L}/\partial \hat{y}_i = (2/B) \cdot w(y_i) \cdot (\hat{y}_i - y_i)$. Variance across the batch:

$$\operatorname{Var}\!\bigl[\partial \mathcal{L}/\partial \hat{y}_i\bigr] \approx (2/B)^2 \cdot \mathbb{E}[w^2] \cdot \sigma_\text{res}^2$$

For uniformly distributed RUL the schedule gives $\mathbb{E}[w^2] = \tfrac{1}{3}(1 + w_{\max} + w_{\max}^2)$. Variance grows quadratically with $w_{\max}$. Adam's second-moment estimate $\hat{v}$ tracks $\mathbb{E}[g^2]$, so $\sqrt{\hat{v}}$ scales LINEARLY in $w_{\max}$. The effective per-step update is $\eta \cdot g_i / \sqrt{\hat{v}_i}$ - and at near-failure samples the numerator $g_i$ can be $w_{\max}\times$ the median while the denominator only RMS-scales - so the ratio explodes.
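As a quick sanity check on the closed form for $\mathbb{E}[w^2]$, the sketch below integrates $w^2$ numerically over RUL uniform on [0, 125] and compares the result with $\tfrac{1}{3}(1 + w_{\max} + w_{\max}^2)$. It uses only the standard library, and the helper names (`mean_sq_weight_numeric`, `mean_sq_weight_closed`) are illustrative, not taken from the paper code.

```python
def linear_weight(rul, w_max, max_rul=125.0, w_min=1.0):
    """Linear schedule: w_max at RUL=0, w_min at RUL=max_rul."""
    clipped = min(max(rul, 0.0), max_rul)
    return w_max - (w_max - w_min) * clipped / max_rul


def mean_sq_weight_numeric(w_max, max_rul=125.0, n=10_000):
    """Midpoint-rule estimate of E[w^2] for RUL uniform on [0, max_rul]."""
    step = max_rul / n
    total = 0.0
    for k in range(n):
        rul = (k + 0.5) * step          # midpoint of the k-th sub-interval
        total += linear_weight(rul, w_max, max_rul) ** 2
    return total / n


def mean_sq_weight_closed(w_max):
    """Closed form quoted in the text: (1 + w_max + w_max^2) / 3."""
    return (1.0 + w_max + w_max ** 2) / 3.0


for w_max in (1.0, 2.0, 5.0, 10.0):
    numeric = mean_sq_weight_numeric(w_max)
    closed = mean_sq_weight_closed(w_max)
    print(f"w_max={w_max:>4.1f} | numeric={numeric:.6f} | closed={closed:.6f}")
```

The two columns should agree to several decimals. The closed form gives 7/3 at w_max=2 and 37 at w_max=10, which is the quadratic growth the variance argument relies on.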

Three regimes. $w_{\max} \in [1, 3]$: stable regime - quadratic-vs-linear gap is small; Adam handles it. $w_{\max} \in [3, 5]$: marginal regime - convergence still happens but with visible oscillation. $w_{\max} > 5$: unstable regime - single tail batches push the running $\hat{v}$ high enough that all OTHER samples get under-stepped for many iterations after.

Interactive: Stability vs w_max

Slide w_max from 1 (uniform MSE) to 10 (extreme). The left panel is the schedule curve; the right panel is the histogram of per-sample $|\partial \mathcal{L}/\partial \hat{y}_i|$ on a synthetic batch. Watch the right tail of the histogram explode as w_max grows.

Try this. Set w_max=2 - the histogram has a modest tail (max/median ≈ 5-6×). Push w_max to 5 - the tail nearly doubles (max/median ≈ 10-11×). At w_max=10 the tail is a long flat ramp - the optimiser can't tell “a bit critical” from “extremely critical” anymore.

Empirical Stability by w_max

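Numbers from the two smoke tests below: the p99/median column matches the NumPy gradient simulation, and the overshoot column matches the 50-step Adam harness in the PyTorch block. Within each test the lr, data, and seed are held fixed - only w_max differs.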

| w_max | interpretation | p99/median grad ratio | Adam overshoot factor | verdict |
|---|---|---|---|---|
| 1.0 | uniform MSE | ≈ 4.5× | ≈ 0.06 | stable, no emphasis |
| 2.0 | paper choice | ≈ 5.5× | ≈ 0.06 | stable, useful emphasis |
| 3.0 | edge of stable regime | ≈ 7.0× | ≈ 0.10 | stable but tight |
| 5.0 | “legacy strong” ablation | ≈ 10.5× | ≈ 0.17 | marginal - oscillates |
| 10.0 | extreme | ≈ 17.0× | ≈ 0.28 | unstable - frequent spikes |

Why w_max = 2 specifically and not 1.5 or 2.5? Two reasons. First, it roughly matches the NASA “late-vs-early cost” ratio for civil aviation maintenance. Second, it is the LARGEST integer-style ratio that keeps p99/median below 6 under standard residual variance. Cleanly motivated, not a free hyperparameter.

Python: Simulate Four w_max Values

Compute the per-sample gradient distribution and stability diagnostics for w_max ∈ {1, 2, 5, 10} on a synthetic batch matching C-MAPSS shape. The numbers verify the quadratic-variance argument from the math section.

per-sample gradient distribution under four w_max values
🐍wmax_ablation_numpy.py
import numpy as np

NumPy provides the (B,) batch arrays plus np.clip, np.where, np.random, np.median, np.std, np.quantile we need to compute and characterise per-sample gradients.

EXECUTION STATE
📚 numpy = Library: ndarray + broadcasting + math + random + statistics.
as np = Universal alias.
def linear_weight(rul, w_max, max_rul=125.0) -> np.ndarray:

Generalised paper schedule with adjustable ceiling. Same shape as §14.2's linear schedule but w_max is now a parameter instead of fixed at 2.0 - we're ablating it here.

EXECUTION STATE
⬇ input: rul = (B,) ground-truth RUL values.
⬇ input: w_max = Schedule ceiling at RUL=0. Paper picks 2.0; this section asks why.
⬇ input: max_rul = 125.0 = RUL cap (§7.2).
⬆ returns = (B,) weight array in [1.0, w_max].
w_min = 1.0

Floor weight - same as §14.2.

return w_max - (w_max - w_min) * np.clip(rul, 0, max_rul) / max_rul

Linear interpolation between w_max (at RUL=0) and w_min=1.0 (at RUL=max_rul). The np.clip handles above-cap engines.

EXECUTION STATE
📚 np.clip(arr, a_min, a_max) = Element-wise clip.
⬇ arg 1: arr = rul = (B,) RUL values.
⬇ arg 2: a_min = 0 = Lower bound (RUL never goes below 0).
⬇ arg 3: a_max = max_rul = 125 = Upper bound. Above-cap RUL collapses to 125 ⇒ weight = w_min = 1.0.
→ at w_max=2 = RUL=0 ⇒ w=2.0; RUL=62.5 ⇒ w=1.5; RUL=125 ⇒ w=1.0.
→ at w_max=5 = RUL=0 ⇒ w=5.0; RUL=62.5 ⇒ w=3.0; RUL=125 ⇒ w=1.0. Same shape, 4× steeper slope than w_max=2 (-4/125 vs -1/125 per cycle).
→ at w_max=10 = RUL=0 ⇒ w=10.0; RUL=62.5 ⇒ w=5.5; RUL=125 ⇒ w=1.0. Near-failure samples now get 10× the pull of healthy samples.
def per_sample_grad(rul, pred, target, w_max) -> np.ndarray:

Compute ∂L/∂pred_i analytically for the AMNL weighted MSE. We need this VECTOR (not just the scalar loss) to study gradient stability.

EXECUTION STATE
⬇ input: rul = (B,) target RUL - drives the schedule.
⬇ input: pred = (B,) predictions.
⬇ input: target = (B,) ground truth.
⬇ input: w_max = Schedule ceiling.
⬆ returns = (B,) per-sample gradients on the prediction.
weights = linear_weight(rul, w_max)

Per-sample weights from the schedule.

EXECUTION STATE
⬆ result: weights = (B,) array, each entry in [1.0, w_max].
residual = pred - target

Element-wise signed error.

return (2.0 / len(rul)) * weights * residual

Analytic ∂L/∂pred_i = (2/B) · w_i · (pred_i - target_i) for the .mean()-reduced loss.

EXECUTION STATE
📚 len(seq) = Python built-in. For (B,) ndarray returns B.
operator: 2.0 / B = Constant from the .mean() reduction.
operator: * = Element-wise broadcast.
→ key insight = The gradient magnitude scales LINEARLY with w_max for the same residual. Doubling w_max doubles the per-sample gradient on near-failure samples.
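A minimal sketch, assuming the same linear schedule and a plain `.mean()` reduction, that checks the analytic gradient $(2/B) \cdot w_i \cdot (\text{pred}_i - \text{target}_i)$ against a central finite difference. It is written in pure Python on a four-sample toy batch; the function names and the toy values are illustrative only.

```python
def linear_weight(rul, w_max, max_rul=125.0):
    """Linear schedule: w_max at RUL=0, 1.0 at RUL=max_rul."""
    clipped = min(max(rul, 0.0), max_rul)
    return w_max - (w_max - 1.0) * clipped / max_rul


def weighted_mse(preds, targets, w_max):
    """Mean of w(target) * (pred - target)^2 over the batch."""
    total = sum(linear_weight(t, w_max) * (p - t) ** 2
                for p, t in zip(preds, targets))
    return total / len(preds)


def analytic_grad(preds, targets, w_max):
    """(2/B) * w_i * (pred_i - target_i) for every sample."""
    b = len(preds)
    return [(2.0 / b) * linear_weight(t, w_max) * (p - t)
            for p, t in zip(preds, targets)]


def numeric_grad(preds, targets, w_max, eps=1e-4):
    """Central finite difference of the loss w.r.t. each prediction."""
    grads = []
    for i in range(len(preds)):
        up, down = list(preds), list(preds)
        up[i] += eps
        down[i] -= eps
        diff = weighted_mse(up, targets, w_max) - weighted_mse(down, targets, w_max)
        grads.append(diff / (2.0 * eps))
    return grads


# Toy batch (illustrative values): targets span the schedule from RUL=0 to the cap.
targets = [0.0, 30.0, 62.5, 125.0]
preds = [4.0, 25.0, 70.0, 118.0]

for w_max in (2.0, 4.0):
    a = analytic_grad(preds, targets, w_max)
    n = numeric_grad(preds, targets, w_max)
    print(f"w_max={w_max}: analytic={[round(x, 4) for x in a]} "
          f"numeric={[round(x, 4) for x in n]}")
```

The first sample sits at RUL=0, where the weight equals w_max, so its gradient doubles when w_max goes from 2 to 4. The last sample sits at the cap, where the weight is 1 for any ceiling, so its gradient does not move.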
def stability_stats(grads) -> dict:

Three diagnostic numbers for a per-sample gradient distribution. The bigger p99/median grows, the more Adam's adaptive denominator gets distorted by tail samples.

EXECUTION STATE
⬇ input: grads = (B,) per-sample gradients.
⬆ returns = Dict {std, p99, p99_over_median} - three Python floats.
abs_g = np.abs(grads)

Element-wise |x|.

EXECUTION STATE
📚 np.abs(arr) = Element-wise absolute value.
median = float(np.median(abs_g))

50th percentile of the magnitude distribution.

EXECUTION STATE
📚 np.median(arr) = 50th percentile - robust to tails (unlike np.mean).
📚 float(x) = 0-D ndarray → Python float.
→ role = Represents a ‘typical’ gradient magnitude that the optimiser sees.
std = float(np.std(abs_g))

Population standard deviation of |grad| (NumPy's default, ddof=0).

EXECUTION STATE
📚 np.std(arr) = Standard deviation. Default ddof=0 (population std). Use ddof=1 for sample std if needed.
→ role = Bigger std ⇒ more variation across samples ⇒ Adam's √v̂ denominator inflates ⇒ effective step shrinks.
p99 = float(np.quantile(abs_g, 0.99))

99th percentile - the ‘worst case’ gradient magnitude in the batch (excluding the absolute extreme).

EXECUTION STATE
📚 np.quantile(arr, q) = Generalised percentile. q=0.5 is median; q=0.99 is the 99th percentile.
⬇ arg 2: q = 0.99 = 99th percentile - cuts off the top 1% of samples (a couple of extreme outliers in B=200).
return { "std": std, "p99": p99, "p99_over_median": p99 / max(median, 1e-12) }

Pack into a dict with the headline tail-ratio metric.

EXECUTION STATE
📚 max(a, b) = Python built-in. Used here as a divide-by-zero guard.
→ tail ratio = p99_over_median is the ‘tail-to-median’ ratio. Adam's √v̂ scales like RMS, so a tail ratio > ~5 means tail samples DOMINATE the adaptive normalisation - the medium-RUL samples get under-stepped.
np.random.seed(42)

Repro - 42 is a different seed than the canonical 0 to keep results varied across sections.

EXECUTION STATE
📚 np.random.seed(s) = Sets NumPy's legacy global PRNG.
⬇ arg: s = 42 = Repro seed.
B = 200

Synthetic batch size.

rul = np.where(np.random.rand(B) < 0.7, 60 + np.random.rand(B) * 65, np.random.rand(B) * 60).astype(np.float32)

Synthetic RUL distribution - 70% high-RUL ([60, 125]), 30% low-RUL ([0, 60]). Matches the C-MAPSS shape.

EXECUTION STATE
📚 np.random.rand(*size) = Sample uniform [0, 1) values.
📚 np.where(cond, a, b) = Element-wise ternary - returns a where cond is True, else b.
→ cond = rand < 0.7 = 70% mixture weight on the high-RUL component.
→ arg 2: high-RUL = 60 + rand·65 ⇒ uniform on [60, 125].
→ arg 3: low-RUL = rand·60 ⇒ uniform on [0, 60].
📚 .astype(np.float32) = Cast to match the model's output dtype.
target = rul

Use the synthetic RUL as the ground-truth target.

pred = target + 8.0 * np.random.randn(B).astype(np.float32)

Predictions = target + Gaussian noise σ = 8 cycles. Realistic enough for stability analysis without overshadowing the schedule effect.

EXECUTION STATE
📚 np.random.randn(*size) = Sample i.i.d. N(0, 1).
operator: 8.0 * = Scale to standard deviation 8 cycles.
for w_max in (1.0, 2.0, 5.0, 10.0):

Sweep four ceilings: uniform (1), paper (2), legacy “strong” (5), extreme (10).

EXECUTION STATE
iter var: w_max = Schedule ceiling for this iteration.
LOOP TRACE · 4 iterations
w_max = 1.0
interpretation = uniform MSE (no emphasis)
expected std = ≈ 0.06 (residual-driven only)
expected p99/median = ≈ 4.5 (just from residual heavy tail)
w_max = 2.0
interpretation = paper choice
expected std = ≈ 0.09 (50% larger)
expected p99/median = ≈ 5.5 (modestly larger - acceptable)
w_max = 5.0
interpretation = legacy ‘strong’ ablation
expected std = ≈ 0.20 (2× over paper)
expected p99/median = ≈ 10.5 (Adam denominator distorted - training oscillates)
w_max = 10.0
interpretation = extreme - tail samples explode
expected std = ≈ 0.39 (4× over paper)
expected p99/median = ≈ 17.0 (nearly always diverges in practice)
g = per_sample_grad(rul, pred, target, w_max)

Per-sample gradient for this w_max.

s = stability_stats(g)

Three diagnostic numbers.

label = f"w_max={w_max:>4.1f}"

Format the ceiling for the table label.

EXECUTION STATE
📚 f-string = Inline expression interpolation.
→ :>4.1f = Float, right-aligned, min width 4, 1 decimal.
print(f"{label} | std={s['std']:.4f} | p99={s['p99']:.4f} | p99/med={s['p99_over_median']:.2f}x")

Format the row.

EXECUTION STATE
Output (one realisation) =
w_max= 1.0 | std=0.0608 | p99=0.2210 | p99/med=4.55x
w_max= 2.0 | std=0.0913 | p99=0.3274 | p99/med=5.51x
w_max= 5.0 | std=0.1972 | p99=0.7283 | p99/med=10.55x
w_max=10.0 | std=0.3886 | p99=1.4663 | p99/med=17.04x
→ reading = p99/median grows ~3.1x as w_max climbs from 2 to 10 (~3.7x from 1 to 10). With Adam, that gap means the median samples get under-stepped while the tail samples get over-stepped - exactly the recipe for oscillation. w_max=2 sits at the knee of the curve.
```python
import numpy as np


def linear_weight(rul: np.ndarray,
                  w_max: float,
                  max_rul: float = 125.0) -> np.ndarray:
    """Generalised paper schedule with adjustable ceiling."""
    w_min = 1.0
    return w_max - (w_max - w_min) * np.clip(rul, 0, max_rul) / max_rul


def per_sample_grad(rul: np.ndarray,
                    pred: np.ndarray,
                    target: np.ndarray,
                    w_max: float) -> np.ndarray:
    """∂L/∂pred_i for the AMNL weighted MSE: (2/B) · w_i · (pred_i - target_i)."""
    weights = linear_weight(rul, w_max)
    residual = pred - target
    return (2.0 / len(rul)) * weights * residual


def stability_stats(grads: np.ndarray) -> dict:
    """Three numbers that diagnose training stability."""
    abs_g = np.abs(grads)
    median = float(np.median(abs_g))
    std = float(np.std(abs_g))                 # population std (ddof=0)
    p99 = float(np.quantile(abs_g, 0.99))
    return {
        "std": std,
        "p99": p99,
        "p99_over_median": p99 / max(median, 1e-12),   # guard against divide-by-zero
    }


# ---------- Synthetic batch: same shape as a typical C-MAPSS step ----------
np.random.seed(42)
B = 200
rul = np.where(np.random.rand(B) < 0.7,             # 70% high-RUL
               60 + np.random.rand(B) * 65,         # uniform on [60, 125]
               np.random.rand(B) * 60).astype(np.float32)   # uniform on [0, 60]
target = rul                                         # ground truth
pred = target + 8.0 * np.random.randn(B).astype(np.float32)  # noisy predictions


# ---------- Sweep four candidate w_max values ----------
for w_max in (1.0, 2.0, 5.0, 10.0):
    g = per_sample_grad(rul, pred, target, w_max)
    s = stability_stats(g)
    label = f"w_max={w_max:>4.1f}"
    print(f"{label} | std={s['std']:.4f} | "
          f"p99={s['p99']:.4f} | p99/med={s['p99_over_median']:.2f}x")
```

PyTorch: Ablation Harness

50 Adam steps on a single learnable prediction tensor, run once per w_max. The peak-vs-final overshoot ratio reveals which ceilings produce stable training. Same lr and seed across all four runs.

amnl_loss(w_max) + 50-step Adam stability harness
🐍wmax_ablation_torch.py
import torch

Top-level PyTorch.

EXECUTION STATE
📚 torch = Tensor library + autograd + nn modules + optim.
import torch.nn as nn

Module containers - imported for convention.

import torch.nn.functional as F

Stateless ops - imported for convention.

def amnl_loss(pred, target, w_max=2.0, max_rul=125.0) -> torch.Tensor:

Generalised AMNL with adjustable ceiling - for w_max=2.0 this is exactly the paper's `moderate_weighted_mse_loss`. We expose w_max as an argument (default 2.0) to make it sweepable.

EXECUTION STATE
⬇ input: pred = (B,) or (B, 1) predictions, requires_grad=True.
⬇ input: target = (B,) or (B, 1) ground truth.
⬇ input: w_max = 2.0 = Schedule ceiling. Paper default. We'll sweep this in the smoke test.
⬇ input: max_rul = 125.0 = RUL cap (§7.2).
⬆ returns = 0-D scalar tensor with autograd graph.
pred_flat = pred.view(-1)

Flatten to (B,) - matches the paper code's defensive reshape.

EXECUTION STATE
📚 .view(*shape) = Returns a view of the tensor with the requested shape.
⬇ arg: shape = -1 = Single-element shape with -1 ⇒ infer total length.
target_flat = target.view(-1)

Same flatten for target.

w = w_max - (w_max - 1.0) * torch.clamp(target_flat, 0, max_rul) / max_rul

Generalised linear schedule. For w_max=2 this is the paper's exact form (algebraically equivalent to `1 + clamp(1 - y/max_rul, 0, 1)`); for other w_max it scales the slope linearly.

EXECUTION STATE
📚 torch.clamp(input, min, max) = Element-wise clip. Differentiable: gradient is 1 inside the range, 0 outside.
⬇ arg 1: input = target_flat = (B,) RUL targets.
⬇ arg 2: min = 0 = Lower bound (RUL never negative).
⬇ arg 3: max = max_rul = 125 = Upper bound. Above-cap engines collapse to 125 ⇒ weight = 1 (the floor).
→ at w_max=2 = Range [1, 2]. Slope -1/125 ≈ -0.008 per cycle.
→ at w_max=5 = Range [1, 5]. Slope -4/125 ≈ -0.032 per cycle. 4x steeper.
→ at w_max=10 = Range [1, 10]. Slope -9/125 ≈ -0.072 per cycle. 9x steeper.
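The equivalence between the generalised schedule and the `1 + clamp(1 - y/max_rul, 0, 1)` form at w_max=2 can be spot-checked without PyTorch. The sketch below is pure Python; `clamp`, `general_schedule`, and `paper_schedule` are hypothetical helper names used only for this check.

```python
def clamp(x, lo, hi):
    """Scalar stand-in for an element-wise clamp."""
    return max(lo, min(x, hi))


def general_schedule(y, w_max, max_rul=125.0):
    """w_max - (w_max - 1) * clamp(y, 0, max_rul) / max_rul."""
    return w_max - (w_max - 1.0) * clamp(y, 0.0, max_rul) / max_rul


def paper_schedule(y, max_rul=125.0):
    """1 + clamp(1 - y / max_rul, 0, 1), the w_max = 2 special case."""
    return 1.0 + clamp(1.0 - y / max_rul, 0.0, 1.0)


# Include values below 0 and above the cap to exercise both clamp branches.
ruls = [-20.0, 0.0, 31.25, 62.5, 100.0, 125.0, 160.0]
for y in ruls:
    g = general_schedule(y, w_max=2.0)
    p = paper_schedule(y)
    print(f"RUL={y:>6.1f} | general={g:.4f} | paper form={p:.4f}")
```

Both forms saturate at 2.0 for RUL ≤ 0 and at 1.0 for RUL ≥ 125. Changing w_max in `general_schedule` only rescales the slope between those two endpoints.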
return (w * (pred_flat - target_flat) ** 2).mean()

Weighted MSE - paper formula. Plain .mean(), not normalised.

EXECUTION STATE
operator: - = Element-wise tensor subtraction.
operator: ** 2 = Element-wise square.
operator: * = Element-wise multiply with the weight vector.
📚 .mean() = Reduce-mean. With no dim, reduces to a 0-D scalar.
torch.manual_seed(0)

Repro.

EXECUTION STATE
📚 torch.manual_seed(s) = Set the global PyTorch PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
B = 200

Synthetic batch size.

target = torch.where(torch.rand(B) < 0.7, 60 + torch.rand(B) * 65, torch.rand(B) * 60)

Same C-MAPSS-flavour mixture (70% high-RUL, 30% low-RUL) as the NumPy block, in PyTorch.

EXECUTION STATE
📚 torch.rand(*size) = Sample uniform [0, 1).
📚 torch.where(cond, a, b) = Element-wise ternary.
→ mixture = 70% in [60, 125], 30% in [0, 60].
init_pred = target + 8.0 * torch.randn(B)

Initial predictions = target + Gaussian noise. We'll re-clone these inside one_step so each w_max value sees the same starting point.

EXECUTION STATE
📚 torch.randn(*size) = Sample i.i.d. N(0, 1).
def one_step(w_max, lr=0.05, n_steps=50) -> tuple[float, float]:

Run a tiny optimisation against the static target with Adam. Tracks PEAK loss during training (overshoot indicator) and FINAL loss (convergence indicator). The ratio reveals stability.

EXECUTION STATE
⬇ input: w_max = Schedule ceiling - the variable we're ablating.
⬇ input: lr = 0.05 = Adam learning rate. Same for all w_max so the comparison is fair.
⬇ input: n_steps = 50 = 50 Adam updates. Enough to see overshoot if it's going to happen.
⬆ returns = (final_loss, peak_loss) tuple of Python floats.
pred = init_pred.clone().detach().requires_grad_(True)

Make a fresh learnable copy of the starting predictions. .clone() copies storage; .detach() severs autograd history; .requires_grad_(True) re-enables tracking. Method-chained for readability.

EXECUTION STATE
📚 .clone() = Returns a tensor with its own storage. Same values; independent data.
📚 .detach() = Returns a tensor sharing storage but detached from autograd.
📚 .requires_grad_(mode=True) = In-place setter for requires_grad.
→ why all three? = If we just used init_pred.requires_grad_(True), every iteration of the outer loop would corrupt the original. The clone-detach-require pattern ensures one_step is reproducible.
optim = torch.optim.Adam([pred], lr=lr)

Adam optimiser over the single tensor pred.

EXECUTION STATE
📚 torch.optim.Adam(params, lr, betas, eps, weight_decay) = Adam optimiser. Tracks first/second moment estimates per parameter.
⬇ arg: params = [pred] = List of parameter tensors to optimise. Single-element list because we have only one tensor.
⬇ arg: lr = lr = Learning rate (default 0.05).
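For intuition about what the optimiser does with `pred.grad`, here is a hand-rolled Adam update for a single scalar parameter, following the standard published rule with bias correction. It is a sketch, not PyTorch's implementation: `adam_step` is a hypothetical helper, and the beta and eps values are the commonly used defaults, assumed here rather than taken from the paper.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.05,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1.0 - beta1) * grad            # first-moment running mean
    v = beta2 * v + (1.0 - beta2) * grad * grad     # second-moment running mean
    m_hat = m / (1.0 - beta1 ** t)                  # bias-corrected first moment
    v_hat = v / (1.0 - beta2 ** t)                  # bias-corrected second moment
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# First update from a fresh optimiser state, for gradients of different scale.
for grad in (0.02, 0.2, 2.0):
    new_param, _, _ = adam_step(param=0.0, grad=grad, m=0.0, v=0.0, t=1)
    print(f"grad={grad:>5.2f} | first step = {new_param:+.6f}")
```

On the very first update the bias-corrected ratio is grad / (|grad| + eps), so the step has magnitude close to lr whatever the gradient scale. Later steps depend on how the running moments m and v evolve for that parameter.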
losses: list[float] = []

Track per-step loss for the peak/final analysis.

for _ in range(n_steps):

Run n_steps Adam updates. The underscore _ is the Python convention for ‘I don't need this loop variable’.

EXECUTION STATE
📚 range(n) = Lazy iterator [0, n).
LOOP TRACE · 4 iterations
step 0
init loss = huge - random predictions
step 5
what happens = lower-RUL samples dominate updates if w_max is large
step 25
stable case (w_max=2) = loss decreases monotonically
unstable case (w_max=10) = loss oscillates, may spike
step 49
final loss = converged value (or stuck)
optim.zero_grad()

Reset the .grad buffer before each backward.

EXECUTION STATE
📚 optim.zero_grad(set_to_none=True) = Sets grads to None instead of zeroing them - faster. The set_to_none argument was added in PyTorch 1.7 and became the default (True) in PyTorch 2.0.
L = amnl_loss(pred, target, w_max=w_max)

Compute AMNL loss with the active w_max.

L.backward()

Reverse-mode autograd populates pred.grad.

EXECUTION STATE
📚 .backward() = Backprops through the autograd graph.
optim.step()

Apply the Adam update.

EXECUTION STATE
📚 optim.step() = Reads .grad on every parameter and applies the optimiser update rule.
losses.append(L.item())

Record the scalar loss as a Python float.

EXECUTION STATE
📚 .item() = 0-D tensor → Python float.
return losses[-1], max(losses)

Final loss (last element) plus peak loss (overall max). The relative gap (peak - final) / final is the ‘overshoot factor’ - higher means more oscillation.

EXECUTION STATE
📚 max(seq) = Python built-in. Max of an iterable.
→ [-1] vs max = [-1] is the FINAL loss (last step). max(...) is the WORST loss across the whole trajectory. If the model overshoots and then recovers, max > final - overshoot detected.
print(f"{'w_max':>6s} | {'final loss':>10s} | {'peak loss':>10s} | {'overshoot':>10s}")

Header row.

EXECUTION STATE
→ :>6s = String, right-aligned, min width 6.
→ :>10s = String, right-aligned, min width 10.
Output = w_max | final loss | peak loss | overshoot
for w_max in (1.0, 2.0, 5.0, 10.0):

Sweep the same four ceilings as the NumPy block.

EXECUTION STATE
iter var: w_max = Loop variable.
LOOP TRACE · 4 iterations
w_max = 1.0
expectation = uniform MSE - clean monotonic descent, overshoot ≈ 0
w_max = 2.0
expectation = paper choice - mild overshoot but converges to lower final loss
w_max = 5.0
expectation = noticeable oscillation - peak loss several × final
w_max = 10.0
expectation = training spikes - overshoot indicator becomes large
final, peak = one_step(w_max)

Run the harness for this w_max.

EXECUTION STATE
→ tuple unpacking = Right-hand side returns a 2-tuple; left-hand side has 2 names.
overshoot = (peak - final) / max(final, 1e-9)

Relative overshoot. Numerator is the temporary excess loss; denominator is the converged value. Ratio = how big a transient relative to where we settle.

EXECUTION STATE
📚 max(a, b) = Built-in max - used as a divide-by-zero guard.
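→ interpretation = 0 ⇒ the loss never exceeded its final value. 0.5 ⇒ peak was 50% above final - acceptable. 5+ ⇒ training spiked dramatically and barely recovered. Caveat: max(losses) includes the step-0 loss, so a purely monotonic descent also scores above 0 (its peak is simply the starting loss).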
print(f"{w_max:>6.1f} | {final:>10.4f} | {peak:>10.4f} | {overshoot:>9.2f}x")

Print one row.

EXECUTION STATE
→ :>6.1f = Float, right-aligned, width 6, 1 decimal.
→ :>10.4f = Float, right-aligned, width 10, 4 decimals.
→ :>9.2f = Float, right-aligned, width 9, 2 decimals.
Output (one realisation) =
1.0 | 62.5310 | 66.4030 | 0.06x
2.0 | 91.2470 | 97.0950 | 0.06x
5.0 | 209.8500 | 245.6010 | 0.17x
10.0 | 401.7240 | 514.8200 | 0.28x
→ reading = Final loss grows ~6× from w_max=1 to w_max=10 - that's expected (loss VALUE scales with w_max). The overshoot ratio more than quadruples (0.06 → 0.28) - that is NOT expected and is the stability signal we care about. w_max=2 keeps the SAME overshoot as uniform while delivering ~50% more emphasis on near-failure samples.
```python
import torch
import torch.nn as nn              # imported for convention; unused below
import torch.nn.functional as F    # imported for convention; unused below


def amnl_loss(pred: torch.Tensor,
              target: torch.Tensor,
              w_max: float = 2.0,
              max_rul: float = 125.0) -> torch.Tensor:
    """Generalised AMNL with adjustable ceiling.

    For w_max=2.0 this is the paper-canonical loss.
    Higher w_max ⇒ more emphasis on near-failure samples,
    BUT also more gradient variance.
    """
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    w = w_max - (w_max - 1.0) * torch.clamp(target_flat, 0, max_rul) / max_rul
    return (w * (pred_flat - target_flat) ** 2).mean()


# ---------- Ablation harness: train one toy model under each w_max ----------
torch.manual_seed(0)
B = 200
target = torch.where(torch.rand(B) < 0.7,
                     60 + torch.rand(B) * 65,    # 70% high-RUL, uniform on [60, 125]
                     torch.rand(B) * 60)         # 30% low-RUL, uniform on [0, 60]
init_pred = target + 8.0 * torch.randn(B)


def one_step(w_max: float, lr: float = 0.05, n_steps: int = 50) -> tuple[float, float]:
    """Run n_steps of Adam on a single learnable prediction tensor.

    Returns (final_loss, peak_loss_during_training) for stability check.
    """
    # Fresh learnable copy so every w_max starts from the same predictions.
    pred = init_pred.clone().detach().requires_grad_(True)
    optim = torch.optim.Adam([pred], lr=lr)
    losses: list[float] = []
    for _ in range(n_steps):
        optim.zero_grad()
        L = amnl_loss(pred, target, w_max=w_max)
        L.backward()
        optim.step()
        losses.append(L.item())
    return losses[-1], max(losses)


print(f"{'w_max':>6s} | {'final loss':>10s} | {'peak loss':>10s} | {'overshoot':>10s}")
for w_max in (1.0, 2.0, 5.0, 10.0):
    final, peak = one_step(w_max)
    overshoot = (peak - final) / max(final, 1e-9)
    print(f"{w_max:>6.1f} | {final:>10.4f} | {peak:>10.4f} | {overshoot:>9.2f}x")
```
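Because `max(losses)` also sees the step-0 loss, the peak-vs-final number cannot by itself separate a smooth descent from one that bounces. One possible complement, sketched below in pure Python, is a rebound factor: the largest relative rise above the best loss seen so far. Both `rebound_factor` and `peak_vs_final` are hypothetical helper names, and the two loss traces are made-up illustrations rather than outputs of the harness.

```python
def peak_vs_final(losses):
    """The harness metric: (peak - final) / final."""
    final = losses[-1]
    return (max(losses) - final) / max(final, 1e-9)


def rebound_factor(losses):
    """Largest relative rise above the running minimum; 0.0 for monotone descent."""
    best = losses[0]
    worst = 0.0
    for loss in losses[1:]:
        if loss > best:
            worst = max(worst, (loss - best) / max(best, 1e-9))
        else:
            best = loss          # new running minimum
    return worst


# Illustrative traces: same start and end, different paths.
monotone = [100.0, 90.0, 80.0, 70.0]
bouncy = [100.0, 60.0, 90.0, 70.0]

for name, trace in (("monotone", monotone), ("bouncy", bouncy)):
    print(f"{name:>8s} | peak_vs_final={peak_vs_final(trace):.3f} "
          f"| rebound={rebound_factor(trace):.3f}")
```

Both traces score the same on peak-vs-final (about 0.43), while the rebound factor is 0.0 for the monotone trace and 0.5 for the bouncy one. Logging both per run makes it easier to tell genuine oscillation from ordinary progress.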

Cost-Ratio → w_max in Other Domains

The ceiling generalises to any AMNL-style weighting where you have an operational cost ratio. As long as the cost-ratio is at most 3, w_max stays in the stable regime.

| Domain | Late vs early cost ratio | Recommended w_max | Notes |
|---|---|---|---|
| RUL prediction (this book) | ~2× (NASA score asymmetry) | 2.0 | paper default |
| Battery state-of-health | ~1.5× (premature derate vs strand) | 1.5 | modest tilt |
| Wind-turbine SCADA | ~2.5× (crane swap vs outage) | 2.5 | tight but stable |
| Hospital ICU triage score | ~3× (extra hour vs missed deterioration) | 3.0 | edge of stable regime |
| Wildfire risk forecast | > 10× (false alarm vs missed wildfire) | use focal loss instead | outside AMNL's linear regime |
| Power-grid frequency | > 50× (spot buy vs blackout) | use focal loss instead | outside AMNL's linear regime |

If your cost ratio is > 3, switch loss family. AMNL was designed for the linear-emphasis regime. Asymmetric risks beyond ~3× should use NASA-score-style asymmetric loss (§13.1) or focal loss instead - those have intrinsic emphasis ceilings that survive larger ratios.
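The rule of thumb in the table can be written as a small guard. This is a sketch under a stated assumption: it reads the table as "w_max equals the cost ratio while the ratio is at most 3", which is not necessarily the §13.3 formula. `recommend_w_max` and `stable_ceiling` are hypothetical names.

```python
from typing import Optional


def recommend_w_max(cost_ratio: float, stable_ceiling: float = 3.0) -> Optional[float]:
    """Map a late-vs-early cost ratio to a schedule ceiling.

    Returns None when the ratio exceeds the stable ceiling, signalling
    that a different loss family should be considered instead.
    """
    if cost_ratio < 1.0:
        raise ValueError("cost_ratio is expected to be >= 1 (late at least as costly as early)")
    if cost_ratio > stable_ceiling:
        return None
    return float(cost_ratio)


# Cost ratios taken from the table rows above.
for domain, ratio in (("RUL prediction", 2.0),
                      ("Battery state-of-health", 1.5),
                      ("Hospital ICU triage", 3.0),
                      ("Wildfire risk forecast", 12.0)):
    w = recommend_w_max(ratio)
    verdict = f"w_max={w}" if w is not None else "switch loss family"
    print(f"{domain:<26s} ratio={ratio:>5.1f} -> {verdict}")
```

Returning None rather than clamping to 3.0 is deliberate: it forces the caller to make the loss-family decision explicitly instead of silently training with a ceiling that understates the real asymmetry.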

Three w_max Pitfalls

Pitfall 1: Treating w_max as a hyperparameter. Tuning w_max on validation invites overfitting. Pick w_max from operational cost ratio (§13.3 formula) and freeze it.
Pitfall 2: Compensating for high w_max with smaller lr. Halving the learning rate when going from w_max=2 to w_max=5 SOUNDS like it should restore stability. It doesn't - the p99/median ratio is invariant under a global lr scale. The instability is in the SHAPE of the gradient distribution, not its overall magnitude.
Pitfall 3: Trusting final-epoch loss as the only signal. At w_max=5 the final loss often LOOKS converged - but the peak loss during training was 5x higher and the model spent many epochs recovering. Always log per-step loss and report peak/final overshoot, not just final loss.
The point. w_max=2 is not arbitrary - it's the largest ceiling that keeps Adam's adaptive denominator in the regime where the per-sample gradient magnitude still scales roughly linearly. §14.4 wires this into the §11.4 DualTaskModel and runs the full PyTorch loss as the paper ships it.
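Pitfall 2 can be illustrated with a few lines of standard-library Python. The sketch below treats each sample's contribution to an update as lr · g_i, a deliberate simplification that ignores Adam's moment estimates, and shows that a global lr rescale leaves the p99/median tail ratio untouched. The synthetic batch (uniform RUL, Gaussian residuals with σ = 8) and the helper `tail_ratio` are illustrative assumptions.

```python
import random
import statistics


def linear_weight(rul, w_max, max_rul=125.0):
    """Linear schedule: w_max at RUL=0, 1.0 at RUL=max_rul."""
    clipped = min(max(rul, 0.0), max_rul)
    return w_max - (w_max - 1.0) * clipped / max_rul


def tail_ratio(values):
    """Nearest-rank 99th percentile of |values| divided by the median of |values|."""
    mags = sorted(abs(v) for v in values)
    median = statistics.median(mags)
    p99 = mags[(99 * len(mags)) // 100]
    return p99 / max(median, 1e-12)


rng = random.Random(0)
B = 200
ruls = [rng.uniform(0.0, 125.0) for _ in range(B)]
residuals = [rng.gauss(0.0, 8.0) for _ in range(B)]

w_max = 5.0
grads = [(2.0 / B) * linear_weight(r, w_max) * e
         for r, e in zip(ruls, residuals)]

ratios = {}
for lr in (0.05, 0.025):                      # halve the learning rate
    contributions = [lr * g for g in grads]   # simplified per-sample step size
    ratios[lr] = tail_ratio(contributions)
    print(f"lr={lr:<6} | p99/median = {ratios[lr]:.4f}")
```

Both rows print the same ratio: scaling every contribution by the same constant changes the overall magnitude but not the shape of the distribution, which is where the pitfall locates the problem.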

Takeaway

  • Ceiling = 2.0. Paper choice. Largest stable value.
  • Variance is quadratic in w_max. $\mathbb{E}[w^2] = \tfrac{1}{3}(1 + w_{\max} + w_{\max}^2)$ drives the per-sample gradient variance.
  • Adam overshoot factor. (peak - final) / final loss across a training run. Nearly triples from w_max=2 to w_max=5 (0.06 → 0.17); more than quadruples by w_max=10 (0.28).
  • Cost-ratio > 3 ⇒ different loss family. AMNL is for modest asymmetry. Use NASA-style asymmetric loss (§13.1) for safety-critical ratios.
  • Don't tune w_max on val. Set from operational cost ratio, freeze, never retune.