Chapter 15

The Fixed 0.5/0.5 Combined Loss

AMNL Training Pipeline

An Unexpected Result

When the legacy paper team set up the multi-task loss, they ran the obvious sweep: try λ ∈ {0.1, 0.2, …, 0.9}, see which weight wins. The expectation was that since RUL is the primary task and health classification is auxiliary, more weight on RUL (e.g. 0.75/0.25) would help. The opposite happened: equal weights 0.5/0.5 won on every dataset, every seed, every comparison.

The headline. AMNL ships with FixedWeightLoss(rul_weight=0.5, health_weight=0.5) and never tunes it. Not because tuning is hard, but because the optimum sits at the symmetric point on every C-MAPSS subset. Ablating it costs you ~1 cycle of RMSE.

Combined Loss as a Convex Combination

The combiner is the simplest possible:

$$\mathcal{L}_{\text{total}} = \lambda \cdot \mathcal{L}_{\text{rul}} + (1 - \lambda) \cdot \mathcal{L}_{\text{hs}}$$

with $\lambda \in [0, 1]$. Because the weights sum to 1, the result is bounded between $\min(\mathcal{L}_{\text{rul}}, \mathcal{L}_{\text{hs}})$ and $\max(\mathcal{L}_{\text{rul}}, \mathcal{L}_{\text{hs}})$. The chain rule then sends the constant $\lambda$ down through autograd onto every shared parameter.
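The bound is easy to check numerically. A minimal sketch in plain Python, assuming two illustrative per-task loss values (1.95 and 1.10, not measurements) and a hypothetical helper name `convex_combine`:

```python
def convex_combine(loss_rul: float, loss_hs: float, lam: float) -> float:
    """Return lam * loss_rul + (1 - lam) * loss_hs for lam in [0, 1]."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError(f"lam must be in [0, 1], got {lam}")
    return lam * loss_rul + (1.0 - lam) * loss_hs


# Illustrative per-task losses (assumed values, not measurements).
loss_rul, loss_hs = 1.95, 1.10
lo, hi = min(loss_rul, loss_hs), max(loss_rul, loss_hs)

totals = {}
for k in range(0, 11):
    lam = k / 10
    totals[lam] = convex_combine(loss_rul, loss_hs, lam)
    # The combined value never leaves [lo, hi]; the tolerance only
    # absorbs floating-point rounding.
    assert lo - 1e-12 <= totals[lam] <= hi + 1e-12

print(f"lam=0.0 -> {totals[0.0]:.3f}")   # equals loss_hs
print(f"lam=0.5 -> {totals[0.5]:.3f}")   # midpoint of the two losses
print(f"lam=1.0 -> {totals[1.0]:.3f}")   # equals loss_rul
```

At the two extremes the combiner returns one task loss unchanged, and at 0.5 it returns their midpoint.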

Why a Module, not a function. The paper wraps the combiner as nn.Module so it can be swapped at the trainer level via the loss-registry factory (paper file core/loss_registry.py). The trainer always calls self.mtl_loss(rul_loss, health_loss, **extras); whether it's FixedWeightLoss, AMNLFixedLoss, or GABALoss is invisible to the trainer.
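To make the swap concrete, here is a rough sketch of what a registry-plus-factory arrangement could look like. It is written in plain Python so it runs without PyTorch; `LOSS_REGISTRY`, `register_loss`, `build_mtl_loss`, `RulOnlyLoss` and `trainer_step` are hypothetical names for illustration, and the real core/loss_registry.py may be organised differently.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: string key -> combiner class.
LOSS_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_loss(name: str):
    """Decorator that records a combiner class under a string key."""
    def decorator(cls):
        LOSS_REGISTRY[name] = cls
        return cls
    return decorator


@register_loss("fixed")
class FixedWeightLoss:
    """Plain-Python stand-in for the nn.Module version in this section."""

    def __init__(self, rul_weight: float = 0.5, health_weight: float = 0.5) -> None:
        self.rul_weight = rul_weight
        self.health_weight = health_weight

    def __call__(self, rul_loss, health_loss, **extras):
        # extras (e.g. shared_params, model) are accepted and ignored
        return self.rul_weight * rul_loss + self.health_weight * health_loss


@register_loss("rul_only")
class RulOnlyLoss:
    """Hypothetical second combiner, present only to show the swap."""

    def __call__(self, rul_loss, health_loss, **extras):
        return rul_loss


def build_mtl_loss(name: str, **config):
    """Factory: look up a combiner by name and construct it."""
    if name not in LOSS_REGISTRY:
        raise KeyError(f"unknown combiner {name!r}; known: {sorted(LOSS_REGISTRY)}")
    return LOSS_REGISTRY[name](**config)


def trainer_step(mtl_loss, rul_loss: float, health_loss: float) -> float:
    # The trainer makes the same call whatever the combiner is.
    return mtl_loss(rul_loss, health_loss, shared_params=None, model=None)


for name, config in [("fixed", {}),
                     ("fixed", {"rul_weight": 0.75, "health_weight": 0.25}),
                     ("rul_only", {})]:
    total = trainer_step(build_mtl_loss(name, **config), 1.95, 1.10)
    print(f"{name:>8s} {config} -> {total:.4f}")
```

The only line that changes between runs is the name passed to the factory; `trainer_step` never inspects the combiner type.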

Legacy Weight-Sweep Ablation

Mined from the legacy book's Chapter 10 ablation table. Per-dataset RMSE under each fixed λ, averaged across 5 seeds:

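| λ_RUL | FD001 | FD002 | FD003 | FD004 | Average |
|---|---|---|---|---|---|
| 0.1 | 13.2 | 17.8 | 13.9 | 21.5 | 16.6 |
| 0.2 | 12.4 | 16.5 | 13.1 | 20.3 | 15.6 |
| 0.3 | 11.8 | 15.9 | 12.4 | 19.4 | 14.9 |
| 0.4 | 11.3 | 15.6 | 11.9 | 18.8 | 14.4 |
| 0.5 (paper) | 10.8 | 13.9 | 11.2 | 17.4 | 13.3 |
| 0.6 | 11.1 | 15.4 | 11.7 | 18.2 | 14.1 |
| 0.7 | 11.6 | 15.8 | 12.2 | 18.9 | 14.6 |
| 0.75 (V7 baseline) | 11.6 | 15.8 | 12.2 | 18.9 | 14.6 |
| 0.8 | 12.2 | 16.4 | 12.8 | 19.6 | 15.3 |
| 0.9 | 12.9 | 17.2 | 13.5 | 20.8 | 16.1 |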

Every column bottoms out at λ = 0.5. Statistical tests in the legacy book confirm the gap from 0.4 and 0.6 is significant (p < 0.05) - not just noise.
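The claim can be checked directly against the table. A small sketch that transcribes the RMSE values from the ablation table above and finds the per-dataset argmin; the 0.75 (V7 baseline) row is left out so the grid stays evenly spaced, and the variable names are illustrative:

```python
# RMSE values copied from the ablation table (5-seed averages).
lambdas = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
rmse = {
    "FD001": [13.2, 12.4, 11.8, 11.3, 10.8, 11.1, 11.6, 12.2, 12.9],
    "FD002": [17.8, 16.5, 15.9, 15.6, 13.9, 15.4, 15.8, 16.4, 17.2],
    "FD003": [13.9, 13.1, 12.4, 11.9, 11.2, 11.7, 12.2, 12.8, 13.5],
    "FD004": [21.5, 20.3, 19.4, 18.8, 17.4, 18.2, 18.9, 19.6, 20.8],
}

best = {}
for ds, column in rmse.items():
    idx = min(range(len(column)), key=column.__getitem__)  # index of lowest RMSE
    runner_up = sorted(column)[1]                           # second-lowest RMSE
    best[ds] = lambdas[idx]
    print(f"{ds}: best lambda = {lambdas[idx]:.1f}, RMSE = {column[idx]:.1f}, "
          f"gap to runner-up = {runner_up - column[idx]:.1f}")

# Recompute the cross-dataset mean per lambda and locate its minimum too.
average = [sum(rmse[ds][i] for ds in rmse) / len(rmse) for i in range(len(lambdas))]
best_avg = lambdas[min(range(len(average)), key=average.__getitem__)]
print("best lambda on the average column:", best_avg)
```

The gap to the runner-up is smallest on FD001 and largest on FD002, which is worth keeping in mind when reading the significance statement: the margin is not uniform across subsets.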

Interactive: Slide λ, Read RMSE

Drag the λ knob; the vertical red line scrubs across the ablation. Each dataset's curve has its minimum marked. Notice that all five minima sit at λ = 0.5 - the symmetry is deeper than per-dataset scale.

Try this. Toggle off everything except FD002 and FD004 (the two multi-condition subsets). Their RMSE values are ~1.5× FD001/FD003, but their CURVES still bottom out at λ = 0.5. The optimum is invariant under dataset-difficulty rescaling.

Python: Simulate the Sweep

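A self-contained NumPy simulator of the combiner arithmetic behind the legacy ablation. The real ablation needs a 40-epoch training run per (λ, dataset) cell; here we use static per-task losses for clarity. The combiner algebra is identical, but with static losses the combined value is linear in λ, so this toy shows how the weights mix the two losses and cannot reproduce the U-shaped RMSE curves of the table.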

combine_losses() and simulate_sweep()
🐍weight_sweep_numpy.py
1import numpy as np

NumPy is the workhorse - we use np.argmin to find the minimum-loss λ in the sweep, plus np.random.seed for reproducibility.

EXECUTION STATE
📚 numpy = Library: ndarray + math + random + statistics.
as np = Universal alias.
4def combine_losses(loss_rul, loss_hs, lam) -> float:

The exact convex combination FixedWeightLoss applies in the paper. Three scalars in, one scalar out - no autograd, no tensors, just arithmetic.

EXECUTION STATE
⬇ input: loss_rul = Per-task RUL loss (scalar). Comes from moderate_weighted_mse_loss in §14.
⬇ input: loss_hs = Per-task health loss (scalar). Comes from F.cross_entropy.
⬇ input: lam = Mixing weight in [0, 1]. lam=0 ignores RUL; lam=1 ignores HS; lam=0.5 is paper-canonical.
⬆ returns = Python float - the combined scalar loss.
11if not 0.0 <= lam <= 1.0:

Defensive bounds check via Python's chained comparison. `a <= b <= c` is shorthand for `(a <= b) and (b <= c)`.

EXECUTION STATE
→ chained comparison = Python evaluates the middle operand only once and short-circuits. Equivalent to: (0.0 <= lam) and (lam <= 1.0).
12raise ValueError(f"lam must be in [0, 1], got {lam}")

Fail loudly on bad input. The f-string interpolates the offending value so the caller sees what was passed.

EXECUTION STATE
📚 raise = Python statement that throws an exception. Stops the function and propagates up the call stack.
⬇ exception type = ValueError - the convention for 'valid type, invalid value'.
13return lam * loss_rul + (1.0 - lam) * loss_hs

Two scalar multiplies and one add. Convex combination ⇒ the result is bounded between min(L_rul, L_hs) and max(L_rul, L_hs).

EXECUTION STATE
operator: * = Scalar multiply.
operator: 1.0 - = Implicit float subtraction. Forces float arithmetic even if lam comes in as int.
→ at lam = 0.5 = 0.5 · L_rul + 0.5 · L_hs - the AMNL paper baseline.
→ at lam = 0.75 = 0.75 · L_rul + 0.25 · L_hs - the V7 RUL-focused baseline (gets beaten in the ablation below).
⬆ return = Python float - the convex-combined loss.
16def simulate_sweep(loss_rul_per_dataset, loss_hs_per_dataset, lambdas) -> dict:

Toy weight sweep: for each dataset we have a typical per-task loss pair, and we sweep λ ∈ lambdas. Returns a dict {sweep, best}. Real ablation needs a full 40-epoch run per (λ, dataset) cell - we're using static losses here for clarity.

EXECUTION STATE
⬇ input: loss_rul_per_dataset = dict[str → float]. Per-dataset typical RUL loss after training.
⬇ input: loss_hs_per_dataset = dict[str → float]. Per-dataset typical health loss after training.
⬇ input: lambdas = List of λ values to sweep. e.g. [0.1, 0.2, …, 0.9].
⬆ returns = dict {sweep: per-dataset curve, best: argmin λ per dataset}.
24sweep: dict[str, list[float]] = {}

Dict to accumulate per-dataset λ-sweep curves.

EXECUTION STATE
→ type hint dict[str, list[float]] = Python 3.9+ generic syntax. Means dictionary with string keys and list-of-floats values.
25best: dict[str, float] = {}

Dict for the optimal λ per dataset.

26for ds, lr in loss_rul_per_dataset.items():

Iterate datasets. dict.items() yields (key, value) pairs.

EXECUTION STATE
📚 dict.items() = View of (key, value) pairs. Stable iteration order in Python 3.7+.
iter vars = ds (dataset name), lr (RUL loss for that dataset).
LOOP TRACE · 4 iterations
ds = 'FD001'
lr = 0.85
lh = 1.10
best λ = 0.9
ds = 'FD002'
lr = 1.95
lh = 1.10
best λ = 0.1
verdict = RUL loss above the health loss ⇒ the combined loss rises with λ, so the grid minimum is the smallest λ. The reverse holds for FD001 and FD003.
ds = 'FD003'
lr = 0.92
lh = 1.10
best λ = 0.9
ds = 'FD004'
lr = 2.10
lh = 1.10
best λ = 0.1
27lh = loss_hs_per_dataset[ds]

Look up the matching health loss by dataset key.

EXECUTION STATE
→ dict subscript = loss_hs_per_dataset[ds] looks up the value at key ds. KeyError if missing.
28ds_curve = [combine_losses(lr, lh, l) for l in lambdas]

List comprehension. For each λ in lambdas, compute the combined loss using lr/lh as the per-task losses.

EXECUTION STATE
→ list comprehension = [expr for var in iterable] - constructs a list by evaluating expr for each var.
⬆ result: ds_curve = (9,) list of combined losses, one per λ. e.g. for FD001: [1.075, 1.05, 1.025, …, 0.875] - linear in λ, so monotone rather than U-shaped for the toy data.
29sweep[ds] = ds_curve

Stash the curve under the dataset key.

30best[ds] = lambdas[int(np.argmin(ds_curve))]

Find the index of the smallest combined loss, then look up the matching λ.

EXECUTION STATE
📚 np.argmin(arr) = Index of the smallest element. With ties, returns the first occurrence.
📚 int(x) = Cast the numpy.int64 to a Python int so the list subscript is clean.
→ list subscript = lambdas[idx] retrieves the λ value at position idx.
⬆ result: best[ds] = Grid-optimal λ for this dataset. With our static toy losses the combined loss is linear in λ, so argmin sits at a grid endpoint: lam=0.9 where loss_rul < loss_hs (FD001, FD003) and lam=0.1 where loss_rul > loss_hs (FD002, FD004).
31return {"sweep": sweep, "best": best}

Pack the results into a dict.

EXECUTION STATE
⬆ return key: sweep = {ds: [combined losses per λ]}.
⬆ return key: best = {ds: optimal λ}.
35np.random.seed(0)

Repro - though this script is fully deterministic.

EXECUTION STATE
📚 np.random.seed(s) = Set NumPy's legacy global PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
36loss_rul = {'FD001': 0.85, 'FD002': 1.95, 'FD003': 0.92, 'FD004': 2.10}

Per-dataset typical AMNL RUL loss after training. FD002 and FD004 (multi-condition) are about 2× FD001/FD003 (single-condition) - matches the legacy ablation.

37loss_hs = {'FD001': 1.10, 'FD002': 1.10, 'FD003': 1.10, 'FD004': 1.10}

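The toy pins the health-task loss at 1.10 on every dataset, just above log K = log 3 ≈ 1.099, the cross-entropy of a uniform guess over K = 3 classes. That chance-level reference depends only on K, not on dataset complexity. (Cross-entropy itself has no upper bound; a confidently wrong classifier can exceed log K.)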

39lambdas = [round(0.1 * k, 2) for k in range(1, 10)]

Generate λ ∈ {0.1, 0.2, ..., 0.9}. round(_, 2) cleans up float-precision noise (0.1 * 3 ≈ 0.30000000000000004 in IEEE-754).

EXECUTION STATE
📚 range(start, stop) = Lazy iterator over [start, stop). [1, 10) ⇒ 1, 2, …, 9.
📚 round(number, ndigits) = Round to ndigits decimal places. Returns a float.
⬇ arg 2: ndigits = 2 = Two decimal places ⇒ 0.10, 0.20, …, 0.90 with no trailing noise.
⬆ result: lambdas = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
40out = simulate_sweep(loss_rul, loss_hs, lambdas)

Run the sweep.

EXECUTION STATE
⬆ result: out = {'sweep': {'FD001': [...], …}, 'best': {'FD001': 0.9, 'FD002': 0.1, …}}
42print(f"{'λ':>5s} | " + " | ".join(f"{ds:>6s}" for ds in loss_rul))

Print the table header. `str.join(iter)` concatenates the iterable with the separator. Generator expression inside join() is lazy.

EXECUTION STATE
📚 str.join(iterable) = String method. Joins the iterable's elements with `self` as the separator. Faster than `+=` in a loop.
→ :>5s = String, right-aligned, min width 5.
→ :>6s = String, right-aligned, min width 6.
Output = λ | FD001 | FD002 | FD003 | FD004
43for i, l in enumerate(lambdas):

Loop over λ values with index. enumerate pairs each item with its position.

EXECUTION STATE
📚 enumerate(iterable, start=0) = Pairs each item with its index.
iter vars = i (index), l (λ value).
44row = ' | '.join(f'{out["sweep"][ds][i]:>6.3f}' for ds in loss_rul)

Build one row of the table by joining four formatted floats with ' | '.

EXECUTION STATE
→ :>6.3f = Float, right-aligned, min width 6, 3 decimals.
→ 2-D index = out['sweep'][ds][i] = the combined loss for dataset ds at the i-th λ.
45print(f"{l:>5.2f} | {row}")

Format and print one row.

EXECUTION STATE
→ :>5.2f = Float, right-aligned, width 5, 2 decimals.
47print()

Blank line for readability.

48print("argmin per dataset:", {k: round(v, 2) for k, v in out["best"].items()})

Print the optimal λ per dataset. Dict comprehension + round() for clean printing.

EXECUTION STATE
→ dict comprehension = {key_expr: value_expr for k, v in iterable} - builds a dict by evaluating the two exprs for each pair.
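Output = argmin per dataset: {'FD001': 0.9, 'FD002': 0.1, 'FD003': 0.9, 'FD004': 0.1}
→ reading = With static losses the argmin lands on a grid endpoint, chosen by the sign of loss_rul - loss_hs. The toy illustrates the combiner arithmetic only; the λ = 0.5 optimum in the legacy table comes from retraining the model at each λ.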
Full listing (weight_sweep_numpy.py):

import numpy as np


def combine_losses(loss_rul: float, loss_hs: float, lam: float) -> float:
    """Convex combination of two task losses with FIXED weight lam.

        L = lam * L_rul + (1 - lam) * L_hs

    For lam = 0.5 this is the paper-canonical AMNL/Fixed combiner.
    """
    if not 0.0 <= lam <= 1.0:
        raise ValueError(f"lam must be in [0, 1], got {lam}")
    return lam * loss_rul + (1.0 - lam) * loss_hs


def simulate_sweep(loss_rul_per_dataset: dict[str, float],
                   loss_hs_per_dataset:  dict[str, float],
                   lambdas:              list[float]) -> dict:
    """Sweep λ ∈ lambdas; return per-dataset combined loss + best λ.

    Toy stand-in for the legacy ablation - in reality each (λ, dataset)
    cell needed a full 40-epoch training run, not just a static
    per-task loss combination.
    """
    sweep: dict[str, list[float]] = {}
    best:  dict[str, float]       = {}
    for ds, lr in loss_rul_per_dataset.items():
        lh        = loss_hs_per_dataset[ds]
        ds_curve  = [combine_losses(lr, lh, l) for l in lambdas]
        sweep[ds] = ds_curve
        best[ds]  = lambdas[int(np.argmin(ds_curve))]
    return {"sweep": sweep, "best": best}


# ---------- Worked example ----------
np.random.seed(0)
loss_rul = {"FD001": 0.85, "FD002": 1.95, "FD003": 0.92, "FD004": 2.10}
loss_hs  = {"FD001": 1.10, "FD002": 1.10, "FD003": 1.10, "FD004": 1.10}

lambdas = [round(0.1 * k, 2) for k in range(1, 10)]    # 0.1 .. 0.9
out     = simulate_sweep(loss_rul, loss_hs, lambdas)

print(f"{'λ':>5s} | " + " | ".join(f"{ds:>6s}" for ds in loss_rul))
for i, l in enumerate(lambdas):
    row = " | ".join(f"{out['sweep'][ds][i]:>6.3f}" for ds in loss_rul)
    print(f"{l:>5.2f} | {row}")

print()
print("argmin per dataset:", {k: round(v, 2) for k, v in out["best"].items()})
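One property of this toy is worth checking by hand. With static per-task losses the combined loss is linear in λ, with slope L_rul - L_hs, so over any grid its minimum falls on an endpoint. The sketch below predicts the grid argmin from the sign of that slope and compares it with a brute-force search, using the same toy loss values as the listing; `endpoint_argmin` is a hypothetical helper name introduced for illustration.

```python
def combined(loss_rul: float, loss_hs: float, lam: float) -> float:
    return lam * loss_rul + (1.0 - lam) * loss_hs


def endpoint_argmin(loss_rul: float, loss_hs: float, lambdas: list) -> float:
    """Predict the grid argmin from the sign of the slope loss_rul - loss_hs."""
    slope = loss_rul - loss_hs
    if slope > 0:
        return min(lambdas)   # loss rises with lambda: smallest lambda wins
    if slope < 0:
        return max(lambdas)   # loss falls with lambda: largest lambda wins
    return lambdas[0]         # flat line: every lambda ties


lambdas = [round(0.1 * k, 2) for k in range(1, 10)]
loss_rul = {"FD001": 0.85, "FD002": 1.95, "FD003": 0.92, "FD004": 2.10}
loss_hs = 1.10  # same toy health loss for every dataset

predicted, brute = {}, {}
for ds, lr in loss_rul.items():
    predicted[ds] = endpoint_argmin(lr, loss_hs, lambdas)
    brute[ds] = min(lambdas, key=lambda lam: combined(lr, loss_hs, lam))
    print(f"{ds}: slope = {lr - loss_hs:+.2f}, "
          f"predicted = {predicted[ds]}, brute force = {brute[ds]}")
```

The U-shaped RMSE curves in the ablation table arise because the trained model itself changes with λ, which a static combination of two fixed numbers cannot capture.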

PyTorch: The Paper's FixedWeightLoss

The exact paper class from paper_ieee_tii/grace/core/baselines.py lines 34-49. Two scalar attributes, one forward line, two helper methods. The smoke test verifies that d/d(rul_loss) of total equals exactly the configured rul_weight.

FixedWeightLoss(nn.Module) — paper-canonical
🐍fixed_weight_loss.py
1import torch

Top-level PyTorch.

EXECUTION STATE
📚 torch = Tensor library + autograd + nn modules + optim.
2import torch.nn as nn

Module containers - we subclass nn.Module so the loss plugs into the trainer like any other layer.

EXECUTION STATE
📚 nn.Module = Base class for all PyTorch models and losses.
3from typing import Dict

Type hint for the get_weights() return signature.

7class FixedWeightLoss(nn.Module):

The exact paper class - `paper_ieee_tii/grace/core/baselines.py` lines 34-49. Stateless except for the two weight scalars.

16def __init__(self, rul_weight=0.5, health_weight=0.5) -> None:

Two hyperparameters - paper defaults are 0.5/0.5.

EXECUTION STATE
⬇ input: rul_weight = 0.5 = Weight on the RUL branch. Paper choice.
⬇ input: health_weight = 0.5 = Weight on the health branch.
⬇ return type: -> None = Constructors return None by Python convention. The annotation is documentation.
18super().__init__()

Initialise nn.Module.

19self.rul_weight = rul_weight

Store as a plain Python float, not as a Parameter or buffer (we don't want it learnable).

20self.health_weight = health_weight

Same.

22def forward(self, rul_loss, health_loss, **kwargs) -> torch.Tensor:

Combine. The **kwargs absorbs trainer-passed extras (shared_params, model) that other combiners (GABA) need but FixedWeightLoss ignores - keeps the trainer code uniform across loss families.

EXECUTION STATE
⬇ input: rul_loss = 0-D scalar tensor with autograd graph (typically from moderate_weighted_mse_loss).
⬇ input: health_loss = 0-D scalar tensor with autograd graph (typically from F.cross_entropy).
⬇ input: **kwargs = Catch-all for trainer extras like shared_params, model. FixedWeightLoss ignores them; GABA reads them.
→ why **kwargs? = Lets the trainer call mtl_loss(rul_loss, hs_loss, shared_params=..., model=...) without knowing which combiner it has. Different combiners read different extras.
⬆ returns = 0-D tensor with autograd graph.
25return self.rul_weight * rul_loss + self.health_weight * health_loss

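Weighted sum - the same arithmetic as the NumPy block, but with autograd-tracked tensors. It is a convex combination whenever the two weights are non-negative and sum to 1 (as 0.5/0.5 and 0.75/0.25 do); the class itself does not enforce that. The weights become coefficients on the gradients during .backward().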

EXECUTION STATE
operator: * = Scalar × tensor broadcast. Each scalar weight scales its task's loss.
operator: + = Tensor add - keeps the autograd graph alive.
→ gradient flow = d(total)/d(rul_loss) = self.rul_weight; d(total)/d(health_loss) = self.health_weight. These constants flow back through autograd onto the model parameters via the chain rule.
⬆ result = 0-D scalar tensor.
27def get_weights(self) -> Dict[str, float]:

Logging helper - returns the current weights as a dict for TensorBoard / W&B.

EXECUTION STATE
⬆ returns = Dict {'rul_weight': float, 'health_weight': float}.
28return {"rul_weight": self.rul_weight, "health_weight": self.health_weight}

Static dict construction.

30def get_name(self) -> str:

Display helper. The trainer logs 'Fixed(0.50/0.50)' on stdout so you know which combiner is active.

31return f"Fixed({self.rul_weight:.2f}/{self.health_weight:.2f})"

Format string for human-readable name.

EXECUTION STATE
→ :.2f = Float, 2 decimals.
Output for default = Fixed(0.50/0.50)
35torch.manual_seed(0)

Repro.

EXECUTION STATE
📚 torch.manual_seed(s) = Set the global PyTorch PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
36rul_loss = torch.tensor(1.95, requires_grad=True)

Synthetic per-task RUL loss - typical FD002 value after AMNL training.

EXECUTION STATE
📚 torch.tensor(scalar, requires_grad) = Allocate a 0-D tensor. requires_grad=True so autograd tracks operations on it.
⬇ arg 1: data = 1.95 = Float scalar.
⬇ arg 2: requires_grad = True = Track for autograd so we can verify gradient flow below.
37health_loss = torch.tensor(1.10, requires_grad=True)

Synthetic health loss - set near log K = log 3 ≈ 1.099, the chance-level cross-entropy for K = 3 classes, which is the same across all datasets.

39amnl_combiner = FixedWeightLoss(rul_weight=0.5, health_weight=0.5)

Paper-canonical 0.5/0.5 combiner.

EXECUTION STATE
⬇ args = Both 0.5 - the AMNL choice.
40v7_combiner = FixedWeightLoss(rul_weight=0.75, health_weight=0.25)

V7 baseline. The legacy ablation showed this LOSES to 0.5/0.5 by ~1 cycle RMSE on average.

EXECUTION STATE
⬇ args = 0.75 / 0.25 - the V7 RUL-focused baseline.
42amnl_total = amnl_combiner(rul_loss, health_loss)

Calls amnl_combiner.__call__(...) which dispatches to forward().

EXECUTION STATE
→ call dispatch = PyTorch nn.Module.__call__ runs hooks (rare) then forward(). Works just like any other module.
⬆ result: amnl_total = 0-D tensor: 0.5 · 1.95 + 0.5 · 1.10 = 1.525.
43v7_total = v7_combiner(rul_loss, health_loss)

Same call with the V7 weights.

EXECUTION STATE
⬆ result: v7_total = 0-D tensor: 0.75 · 1.95 + 0.25 · 1.10 = 1.7375.
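→ comparison = V7 total is HIGHER than AMNL total: more weight on the larger loss ⇒ higher total. The scalar totals are not a measure of training quality, though. What differs in training is the gradient coefficients: 0.5/0.5 scales both task gradients equally, while 0.75/0.25 scales the RUL gradient 3× as much as the health gradient.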
45print(f"AMNL combiner: {amnl_combiner.get_name():<20s} total = {amnl_total.item():.4f}")

Format-string output. .item() pulls the Python float out of the 0-D tensor.

EXECUTION STATE
📚 .item() = 0-D tensor → Python float. Crashes on multi-element tensors.
→ :<20s = String, left-aligned, min width 20.
→ :.4f = Float, 4 decimals.
Output = AMNL combiner: Fixed(0.50/0.50) total = 1.5250
46print(f"V7 combiner: {v7_combiner.get_name():<20s} total = {v7_total.item():.4f}")

Same for V7.

EXECUTION STATE
Output = V7 combiner: Fixed(0.75/0.25) total = 1.7375
49amnl_total.backward()

Reverse-mode autograd. Populates rul_loss.grad and health_loss.grad with the partial derivatives of amnl_total.

EXECUTION STATE
📚 .backward(retain_graph=False) = Reverse-mode autograd. Frees the graph by default.
→ effect = rul_loss.grad ← d(amnl_total)/d(rul_loss) = 0.5; health_loss.grad ← d(amnl_total)/d(health_loss) = 0.5
50print(f"d/d(rul_loss) of AMNL total = {rul_loss.grad.item():.2f} (= rul_weight)")

Verify the gradient is exactly the rul_weight constant. Sanity-checks the autograd path.

EXECUTION STATE
Output = d/d(rul_loss) of AMNL total = 0.50 (= rul_weight)
51print(f"d/d(health_loss) of AMNL total = {health_loss.grad.item():.2f} (= health_weight)")

Same for the health side.

EXECUTION STATE
Output = d/d(health_loss) of AMNL total = 0.50 (= health_weight)
→ reading = Both partial derivatives equal their respective weights. By the chain rule, every shared parameter's gradient is 0.5 × its RUL-branch gradient plus 0.5 × its health-branch gradient - a clean linear superposition.
Full listing (fixed_weight_loss.py):

import torch
import torch.nn as nn
from typing import Dict


# Source: paper_ieee_tii/grace/core/baselines.py:34-49
class FixedWeightLoss(nn.Module):
    """Static (non-learnable) task weighting.

    Forward: total = w_rul · L_rul + w_hs · L_hs
    Default (0.5, 0.5) is the AMNL paper choice; (0.75, 0.25) is the
    V7 RUL-focused baseline; both are ablated in §15 and §16.
    """

    def __init__(self, rul_weight:    float = 0.5,
                       health_weight: float = 0.5) -> None:
        super().__init__()
        self.rul_weight    = rul_weight
        self.health_weight = health_weight

    def forward(self, rul_loss: torch.Tensor,
                      health_loss: torch.Tensor,
                      **kwargs) -> torch.Tensor:
        # **kwargs absorbs trainer extras (shared_params, model) that this combiner ignores
        return self.rul_weight * rul_loss + self.health_weight * health_loss

    def get_weights(self) -> Dict[str, float]:
        return {"rul_weight": self.rul_weight, "health_weight": self.health_weight}

    def get_name(self) -> str:
        return f"Fixed({self.rul_weight:.2f}/{self.health_weight:.2f})"


# ---------- Smoke test ----------
torch.manual_seed(0)
rul_loss    = torch.tensor(1.95, requires_grad=True)        # typical FD002 RUL loss
health_loss = torch.tensor(1.10, requires_grad=True)        # ≈ log 3 (chance level, K = 3)

amnl_combiner = FixedWeightLoss(rul_weight=0.5, health_weight=0.5)
v7_combiner   = FixedWeightLoss(rul_weight=0.75, health_weight=0.25)

amnl_total = amnl_combiner(rul_loss, health_loss)
v7_total   = v7_combiner(rul_loss, health_loss)

print(f"AMNL  combiner: {amnl_combiner.get_name():<20s}  total = {amnl_total.item():.4f}")
print(f"V7    combiner: {v7_combiner.get_name():<20s}  total = {v7_total.item():.4f}")

# Verify gradients flow correctly
amnl_total.backward()
print(f"d/d(rul_loss)    of AMNL total = {rul_loss.grad.item():.2f}     (= rul_weight)")
print(f"d/d(health_loss) of AMNL total = {health_loss.grad.item():.2f}     (= health_weight)")
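The smoke test checks the derivative with respect to the two loss scalars. The same superposition holds one level down, at a shared parameter. The sketch below illustrates this in plain Python with central finite differences instead of autograd; the two toy losses, the single shared parameter `theta`, and all helper names are assumptions made for the example, not part of the paper code.

```python
import math


def loss_rul(theta: float) -> float:
    # Toy squared error on one sample (input 2.0, target 3.0).
    return (2.0 * theta - 3.0) ** 2


def loss_hs(theta: float) -> float:
    # Toy logistic loss for a positive label with logit = theta.
    return math.log1p(math.exp(-theta))


def total(theta: float, w_rul: float = 0.5, w_hs: float = 0.5) -> float:
    return w_rul * loss_rul(theta) + w_hs * loss_hs(theta)


def central_diff(f, x: float, h: float = 1e-6) -> float:
    """Numerical derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


theta = 0.7
g_rul = central_diff(loss_rul, theta)     # gradient from the RUL branch
g_hs = central_diff(loss_hs, theta)       # gradient from the health branch
g_total = central_diff(total, theta)      # gradient of the combined loss

print(f"g_rul   = {g_rul:+.4f}")
print(f"g_hs    = {g_hs:+.4f}")
print(f"g_total = {g_total:+.4f}")
print(f"0.5*g_rul + 0.5*g_hs = {0.5 * g_rul + 0.5 * g_hs:+.4f}")
```

The last two printed values agree to within finite-difference error, which is the scalar version of what autograd does for every shared weight.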

When 0.5/0.5 Generalises

Equal weights win wherever (a) per-task losses are comparable in magnitude after AMNL-style sample weighting, and (b) the auxiliary task provides COMPLEMENTARY rather than competing structure. Test on your own data with a five-point sweep before committing.

| Domain | Primary task | Auxiliary task | Best λ_primary | Notes |
|---|---|---|---|---|
| RUL prediction (this book) | RUL regression | health classification | 0.50 | paper baseline |
| Battery SoH + fault type | SoH regression | fault classification | 0.50 | matches RUL pattern |
| Wind turbine RUL + fault tag | RUL regression | fault tag | 0.45-0.55 | near-symmetric |
| Object detection (multi-task) | bounding-box regression | class score | 0.30 (RUL-like) | GIoU loss differs in scale - sweep needed |
| Speech recognition (multi-task) | phoneme posteriors | word boundary detection | 0.40-0.60 | sensitive to dataset size |
| MRI tumour size + benign/malignant | size regression | diagnosis | 0.35 | diagnosis is harder ⇒ asymmetric optimum |

Five-point sweep is enough. Try λ ∈ {0.3, 0.4, 0.5, 0.6, 0.7} on a small training run (10-15 epochs). If 0.5 wins, freeze it. If something else wins, refine to a tighter sweep around that point. Never treat λ as a continuously-tunable hyperparameter.
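The sweep-then-freeze procedure can be sketched as a small harness. Everything here is illustrative: `coarse_sweep`, `refine_grid` and `fake_train_and_eval` are hypothetical names, and the fake evaluator is a made-up quadratic bowl standing in for a real 10-15 epoch run that returns validation RMSE.

```python
from typing import Callable, Dict, List, Tuple


def coarse_sweep(train_and_eval: Callable[[float], float],
                 grid: Tuple[float, ...] = (0.3, 0.4, 0.5, 0.6, 0.7)
                 ) -> Tuple[float, Dict[float, float]]:
    """Run one short training per lambda; return (best lambda, all scores)."""
    scores = {lam: train_and_eval(lam) for lam in grid}
    best = min(scores, key=scores.get)   # lowest validation RMSE wins
    return best, scores


def refine_grid(best: float, step: float = 0.05) -> List[float]:
    """Tighter three-point grid around a non-0.5 winner, clipped to [0, 1]."""
    points = [best - step, best, best + step]
    return [round(min(1.0, max(0.0, p)), 2) for p in points]


def fake_train_and_eval(lam: float) -> float:
    # Stand-in for a short training run. The bowl centred at 0.5 is an
    # assumption for the demo, not measured data.
    return 10.0 + 20.0 * (lam - 0.5) ** 2


best, scores = coarse_sweep(fake_train_and_eval)
if best == 0.5:
    print("0.5 wins the coarse sweep: freeze lambda = 0.5")
else:
    print(f"{best} wins: refine over {refine_grid(best)}, then freeze")
```

In practice `fake_train_and_eval` would be replaced by a function that trains with `FixedWeightLoss(rul_weight=lam, health_weight=1 - lam)` and returns the validation metric. The harness stops after at most one refinement, which keeps λ from turning into a continuously tuned knob.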

Three Combiner Pitfalls

Pitfall 1: Treating λ as a tunable hyperparameter. Tuning λ on validation invites overfitting to the val distribution. Pick from a coarse five-point sweep, freeze, never retune.
Pitfall 2: Forgetting that the per-task losses already have internal weighting. AMNL's moderate_weighted_mse_loss already up-weights near-failure samples by up to 2× WITHIN the RUL branch. Raising λ_rul from 0.5 to 0.75 multiplies the whole RUL branch by a further 1.5×, so the effective emphasis on near-failure samples reaches up to 2 × 1.5 = 3× - past §14.3's stable regime. Trust AMNL's sample weighting; let λ stay symmetric.
Pitfall 3: Skipping the **kwargs in forward(). The paper trainer passes shared_params=... and model=... to every combiner so GABA / GradNorm can use them. FixedWeightLoss ignores these but MUST accept them via **kwargs - otherwise the trainer crashes when you swap combiners.
The point. Three lines of math, one nn.Module, no learnable parameters. The 0.5/0.5 split is the AMNL paper's simplest design choice and also one of the most robust. §15.2 wires this combiner into the optimiser + scheduler stack.
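Pitfall 3 is easy to reproduce without PyTorch. In the sketch below, `StrictCombiner` and `TolerantCombiner` are hypothetical plain-Python stand-ins, and `trainer_call` imitates the call shape described above with placeholder extras:

```python
class StrictCombiner:
    """Hypothetical combiner that forgot **kwargs."""

    def __call__(self, rul_loss, health_loss):
        return 0.5 * rul_loss + 0.5 * health_loss


class TolerantCombiner:
    """Same arithmetic, but accepts and ignores trainer extras."""

    def __call__(self, rul_loss, health_loss, **kwargs):
        return 0.5 * rul_loss + 0.5 * health_loss


def trainer_call(mtl_loss):
    # The trainer passes the same extras to every combiner.
    return mtl_loss(1.95, 1.10, shared_params=[], model=None)


try:
    trainer_call(StrictCombiner())
    strict_ok = True
except TypeError as exc:
    strict_ok = False
    print("StrictCombiner failed:", exc)

tolerant_total = trainer_call(TolerantCombiner())
print(f"TolerantCombiner total: {tolerant_total:.4f}")
```

The strict version fails with a `TypeError` about an unexpected keyword argument the moment the trainer calls it, while the tolerant one returns the combined loss.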

Takeaway

  • Convex combination. $\mathcal{L}_{\text{total}} = \lambda \mathcal{L}_{\text{rul}} + (1 - \lambda) \mathcal{L}_{\text{hs}}$ with $\lambda = 0.5$.
  • Empirical optimum. Legacy ablation confirms λ = 0.5 wins on every C-MAPSS subset, every seed, every comparison.
  • Paper class. FixedWeightLoss(rul_weight=0.5, health_weight=0.5) - paper_ieee_tii/grace/core/baselines.py.
  • Module, not function. Lets the trainer swap GABA / FixedWeightLoss / AMNLFixedLoss with one line.
  • **kwargs absorbs trainer extras. shared_params / model are passed to every combiner; FixedWeightLoss ignores them.