Chapter 4

The Loss-Combination Problem

Multi-Task Learning Theory

Two Voices in a Choir

In a duet, the two voices are not interchangeable. A mezzo-soprano and a baritone have different ranges, different volumes, different tonal weights. A naive sound engineer who slides both faders to the same level produces something the mezzo dominates — or, with the wrong room, a baritone that buries her. The job of mixing is to WEIGHT the channels so the result sounds balanced.

Multi-task learning has the same problem. The RUL regression loss is on the scale of squared cycles (tens to hundreds, and far more early in training); the health classification loss is on the scale of cross-entropy (typically between zero and about two). Add them with equal weight and the regression dominates the optimisation. Picking the weights well (statically, dynamically, or adaptively) is the central engineering problem of every paper in this corner of the literature.

The chapter so far. Section 4.1 said WHY MTL. Section 4.2 said WHO sees gradients. Section 4.3 (this one) says HOW to combine the losses. Section 4.4 will examine the gradient-level consequences.

The Weighted-Sum Formulation

The dominant approach is a single scalar loss formed as a weighted sum of per-task losses:

$$\mathcal{L}_{\text{total}}(\theta) \;=\; \sum_{k=1}^{K} \lambda_k \, \mathcal{L}_k(\theta), \qquad \lambda_k \ge 0.$$

Three properties of this formulation matter. (1) Smooth and differentiable: when each per-task loss is differentiable, gradients flow through the sum into every shared parameter. (2) One scalar to optimise: standard back-prop and optimisers work unchanged. (3) Pareto-aware: for strictly positive weights, the minimiser of the weighted sum lies on the Pareto frontier of the per-task losses, a tradeoff curve where reducing one task's loss requires increasing another's. Different choices of $\lambda$ select different points on that curve.
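The weighted sum itself is a one-liner. The sketch below is a minimal illustration, not code from the book: the helper name `weighted_total_loss` is hypothetical, and the two loss values are the illustrative ones used in the NumPy listing later in this section. It only enforces the two constraints visible in the formula (one weight per task, no negative weights).

```python
from typing import Sequence


def weighted_total_loss(losses: Sequence[float], lambdas: Sequence[float]) -> float:
    """Return sum_k lambda_k * L_k after checking the constraints in the formula."""
    if len(losses) != len(lambdas):
        raise ValueError("need exactly one weight per task loss")
    if any(lam < 0 for lam in lambdas):
        raise ValueError("weights must be non-negative")
    return sum(lam * loss for lam, loss in zip(lambdas, losses))


# Two tasks, equal weights (hypothetical helper, illustrative loss values).
total = weighted_total_loss([12.5, 0.8], [0.5, 0.5])
print(f"{total:.3f}")  # prints 6.650
```

The same function works unchanged for any number of tasks, which is the practical meaning of "one scalar to optimise".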

Interactive: The Pareto Frontier

The diagram below sketches a toy Pareto frontier between two synthetic losses. Slide $\lambda$ from 0 to 1 and watch the optimum move along the curve. Notice the curve is non-linear: at the extreme ends, gaining a tiny bit on the favoured task costs a lot on the other.

[Interactive figure: toy Pareto frontier between two synthetic losses, with a slider for $\lambda$]

Real C-MAPSS frontiers look qualitatively similar but with much wider scale gaps; one of the reasons GABA (Section 17) outperforms a simple $\lambda$ sweep is that it moves the operating point during training instead of committing to one upfront.
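A toy frontier of the kind described above can be reproduced in closed form. The sketch below assumes two quadratic losses over a single scalar parameter, $(\theta - a)^2$ and $(\theta - b)^2$; the function name `pareto_point` and the values of `a` and `b` are hypothetical choices for illustration, not the losses behind the interactive figure. For this pair the minimiser of the weighted sum is $\theta^* = \lambda a + (1 - \lambda) b$.

```python
def pareto_point(lam: float, a: float = 0.0, b: float = 1.0):
    """Minimiser of lam*(t-a)^2 + (1-lam)*(t-b)^2 and both losses at that point."""
    theta = lam * a + (1.0 - lam) * b  # closed-form argmin of the weighted sum
    return theta, (theta - a) ** 2, (theta - b) ** 2


# Sweep lambda from 0 to 1 and record the operating point for each value.
frontier = [pareto_point(i / 10) for i in range(11)]
for i, (theta, loss_a, loss_b) in enumerate(frontier):
    print(f"lambda={i / 10:.1f}  theta*={theta:.2f}  L_a={loss_a:.3f}  L_b={loss_b:.3f}")
```

In this toy case the losses along the sweep are $(1 - \lambda)^2$ and $\lambda^2$, so near either end one loss is almost flat while the other is still changing quickly, which is the non-linearity the text points out.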

Five Ways to Choose the Weights

| Strategy | Where lambda comes from | Section |
| --- | --- | --- |
| Fixed equal | Set 1/K and forget | Section 4.1 baseline |
| Fixed unequal | Hand-tuned per task | AMNL (Sections 14-16) |
| Inverse-magnitude | 1/||grad|| or 1/L per task | Heuristic; precursor to GABA |
| Uncertainty-weighted | Learnable per-task sigmas | Kendall et al. 2018 |
| Adaptive (GABA) | Inverse-gradient-norm with EMA | Sections 17-20 (the paper) |

The first two are static; the next three update during training. The static ones are easy to reason about and reproducible; the dynamic ones tend to outperform them because the right $\lambda$ can change over the course of training (early epochs benefit from one balance, late epochs from another).
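To make "inverse-gradient-norm with EMA" concrete before Section 17, here is a generic sketch of that idea. It is an assumption-laden illustration, not GABA's actual update rule: the function name `update_weights`, the smoothing factor `beta`, and the fixed gradient norms fed into it are all hypothetical.

```python
def update_weights(ema_norms, grad_norms, beta=0.9, eps=1e-12):
    """Smooth per-task gradient norms with an EMA, then weight each task by the
    inverse of its smoothed norm, normalised so the weights sum to 1."""
    ema_norms = [beta * e + (1.0 - beta) * g for e, g in zip(ema_norms, grad_norms)]
    inverse = [1.0 / (e + eps) for e in ema_norms]
    total = sum(inverse)
    return ema_norms, [v / total for v in inverse]


# Hypothetical, constant gradient norms: task 0 is 500x larger than task 1.
ema = [1.0, 1.0]
for _ in range(200):
    ema, lambdas = update_weights(ema, [500.0, 1.0])
print([round(v, 4) for v in lambdas])  # prints [0.002, 0.998]
```

The EMA is what separates this from the plain inverse-magnitude heuristic: a single noisy batch moves the weights only slightly, at the cost of reacting more slowly when the balance really changes.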

Python: One Loss, Five Weighting Strategies

Five ways to combine the same two task losses
loss_weighting.py

import numpy as np

Standard alias.

L_rul = 12.5

A typical per-batch RUL regression loss (squared error in cycles). Big because RUL targets are large.

EXECUTION STATE
L_rul = 12.5 - dominated by squared cycle errors

L_health = 0.8

A typical per-batch health classification loss (cross-entropy). Small because a chance-level prediction on a 3-class problem already scores only log(3) ~ 1.1. Cross-entropy is unbounded above in principle, but a trained classifier stays below that reference.

EXECUTION STATE
L_health = 0.8 - below the chance-level value log(3) ~ 1.1
→ the imbalance = L_rul is ~16x larger than L_health JUST FROM THE METRIC SCALE. Equal weighting silently makes RUL dominate the gradient.

J_fixed_equal = 0.5 * L_rul + 0.5 * L_health

Naive equal weighting. Result is dominated by the larger L_rul.

EXECUTION STATE
Output = fixed equal J = 6.650
→ contribution split = RUL provides 6.25 (94%); Health provides 0.4 (6%). 'Equal' on paper, hugely imbalanced in practice.

J_fixed_quarter = 0.25 * L_rul + 0.75 * L_health

Tilt toward health. Reduces RUL's nominal contribution (3.125 of 3.725, still 84%) but does not address the per-sample gradient imbalance.

EXECUTION STATE
Output = fixed 25/75 J = 3.725

weights_inv = np.array([1.0 / L_rul, 1.0 / L_health])

Inverse-magnitude weighting: bigger-loss task gets smaller weight, so the contributions roughly equalise.

EXECUTION STATE
weights_inv (raw) = [0.080, 1.250]

weights_inv /= weights_inv.sum()

Normalise to sum to 1 - convex combination.

EXECUTION STATE
weights_inv (normalised) = [0.060, 0.940]

J_inv = weights_inv[0] * L_rul + weights_inv[1] * L_health

Combined. Each task now contributes about 0.752.

EXECUTION STATE
Output = inverse mag J = 1.504

log_sigma_rul = 1.5

Kendall et al. 2018: learn a per-task uncertainty parameter and weight each loss by the inverse of its variance, plus a log-regulariser.

EXECUTION STATE
log_sigma = Stored in log-space for stability; the actual uncertainty is exp(log_sigma).

log_sigma_health = 0.3

Health task, smaller uncertainty.

sig2_rul = np.exp(2 * log_sigma_rul)

sigma^2 = exp(2 * log_sigma). For RUL: e^3 ~ 20.1.

EXECUTION STATE
sig2_rul = 20.09

sig2_health = np.exp(2 * log_sigma_health)

For health: e^0.6 ~ 1.82.

EXECUTION STATE
sig2_health = 1.82

J_unc = ... uncertainty-weighted total

Kendall's loss as used in this demo: each task contributes L_k / (2 sigma^2_k) plus log(sigma_k). The sigma^2 in the denominator down-weights tasks the model is uncertain about; the log(sigma) regulariser stops sigma from blowing up. (The original paper derives a factor of 1 / sigma^2, without the 2, for classification terms; the demo applies the regression form to both tasks.)

EXECUTION STATE
Output = uncertainty J = 2.331
→ the arithmetic = 12.5 / 40.17 + 1.5 + 0.8 / 3.64 + 0.3 = 0.311 + 1.5 + 0.220 + 0.3

lambda_rul, lambda_health = 0.005, 0.995

GABA's converged weights on real C-MAPSS training are extreme: ~0.5% on RUL, ~99.5% on health. The reason will become clear in Section 17 - it roughly mirrors the inverse of the ~500x gradient imbalance.

EXECUTION STATE
lambda_rul = 0.005
lambda_health = 0.995
→ the surprise = Looks 'unfair' but brings the two gradient contributions to the same order of magnitude, because at equal lambda RUL's gradient is ~500x larger than health's.

J_gaba = lambda_rul * L_rul + lambda_health * L_health

Combined under GABA's adaptive weights. Numerically tiny.

EXECUTION STATE
Output = GABA J = 0.859
```python
import numpy as np

# Synthetic per-task losses for one batch
L_rul    = 12.5    # large because RUL targets span 0..125 cycles
L_health = 0.8     # small: chance-level 3-class cross-entropy is only ~log(3)

# Five weighting schemes you will see across this book ----------------
# 1. Fixed equal
J_fixed_equal   = 0.5 * L_rul + 0.5 * L_health

# 2. Fixed unequal (e.g., AMNL's 0.5/0.5 with the weight on the larger-loss task halved)
J_fixed_quarter = 0.25 * L_rul + 0.75 * L_health

# 3. Inverse-magnitude (rough first attempt at "equalising contributions")
weights_inv  = np.array([1.0 / L_rul, 1.0 / L_health])
weights_inv /= weights_inv.sum()          # normalise to a convex combination
J_inv = weights_inv[0] * L_rul + weights_inv[1] * L_health

# 4. Uncertainty-weighted (Kendall, Gal, Cipolla 2018)
log_sigma_rul    = 1.5     # learnable; we fix for this demo
log_sigma_health = 0.3
sig2_rul    = np.exp(2 * log_sigma_rul)       # sigma^2 = exp(2 * log_sigma)
sig2_health = np.exp(2 * log_sigma_health)
J_unc = (L_rul / (2 * sig2_rul) + log_sigma_rul +
         L_health / (2 * sig2_health) + log_sigma_health)

# 5. GABA: adaptive, gradient-magnitude-driven (Section 17)
# Placeholder; the real version measures gradient norms at each step
lambda_rul, lambda_health = 0.005, 0.995    # GABA's typical converged values
J_gaba = lambda_rul * L_rul + lambda_health * L_health

print(f"fixed equal     J = {J_fixed_equal:.3f}")
print(f"fixed 25/75     J = {J_fixed_quarter:.3f}")
print(f"inverse mag     J = {J_inv:.3f}")
print(f"uncertainty     J = {J_unc:.3f}")
print(f"GABA            J = {J_gaba:.3f}")

# fixed equal     J = 6.650
# fixed 25/75     J = 3.725
# inverse mag     J = 1.504
# uncertainty     J = 2.331
# GABA            J = 0.859
```
The numerical surprise. With L_rul = 12.5 and L_health = 0.8, "equal" weighting ($\lambda = 0.5$) gives RUL 94% of the contribution to L_total. To actually balance the two contributions, set $\lambda \cdot 12.5 = (1 - \lambda) \cdot 0.8$, which gives $\lambda = 0.8 / 13.3 \approx 0.06$. The metrics fool you; only the gradient norm tells the truth, which is exactly what Section 4.4 measures.
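The uncertainty-weighted scheme in the NumPy listing above fixes log-sigma by hand. As a rough sketch of what "learnable" means there, the snippet below runs plain gradient descent on the per-task term J(s) = L / (2 exp(2s)) + s, with s = log(sigma). It assumes the loss value L is held constant, which real training does not do; the helper name `fit_log_sigma`, the step count, and the learning rate are hypothetical.

```python
import math


def fit_log_sigma(loss, steps=200, lr=0.1):
    """Gradient descent on J(s) = loss / (2 * exp(2s)) + s for a fixed loss value."""
    s = 0.0
    for _ in range(steps):
        grad = 1.0 - loss * math.exp(-2.0 * s)  # dJ/ds
        s -= lr * grad
    return s


for name, loss in [("rul", 12.5), ("health", 0.8)]:
    s = fit_log_sigma(loss)
    sigma2 = math.exp(2.0 * s)
    print(f"{name}: sigma^2={sigma2:.3f}  weight={1 / (2 * sigma2):.3f}  "
          f"weighted loss={loss / (2 * sigma2):.3f}")
# prints:
# rul: sigma^2=12.500  weight=0.040  weighted loss=0.500
# health: sigma^2=0.800  weight=0.625  weighted loss=0.500
```

Under that fixed-loss assumption the stationary point is sigma^2 = L, so the learned weight becomes 1 / (2L) and both tasks end up contributing 0.5. In other words, this form of uncertainty weighting behaves like an inverse-magnitude rule whose weights are found by the optimiser rather than set by hand.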

PyTorch: Combining Losses in the Training Loop

One backward pass through L_total = lambda * L_rul + (1 - lambda) * L_health
combined_loss_loop.py

import torch

Top-level PyTorch.

import torch.nn as nn

Layers.

import torch.nn.functional as F

F.mse_loss + F.cross_entropy live here.

class DualTaskMLP(nn.Module):

Same model from Sections 4.1 and 4.2 - condensed inline.

torch.manual_seed(0)

Reproducible weight initialisation.

model = DualTaskMLP()

Instantiate.

opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

Standard optimiser. AdamW for decoupled weight decay.

x = torch.randn(64, 17)

Fake batch of 64 engines.

y_rul = torch.randn(64) * 30 + 60

Fake regression targets. Big numbers (mean 60, std 30) to mimic the RUL scale.

EXECUTION STATE
→ why these values? = Real RUL targets span [0, 125] with mean ~60. Mimicking the scale matters because L_rul's magnitude depends on it.

y_health = torch.randint(0, 3, (64,))

Fake 3-class targets. Long tensor as F.cross_entropy expects.

EXECUTION STATE
y_health.dtype = torch.int64 (long)

LAMBDA = 0.5

Static weight for the demo. Section 17's GABA replaces this with an adaptive controller.

for step in range(3):

Three training steps. Real training runs hundreds of epochs.

opt.zero_grad()

Clear gradients from the previous step. Without this PyTorch ACCUMULATES gradients across calls - useful for gradient accumulation, fatal if forgotten.

pred_rul, pred_health = model(x)

Forward pass returns both task outputs in one call.

L_rul = F.mse_loss(pred_rul, y_rul)

Mean squared error for the regression task. Returns 0-D tensor.

L_health = F.cross_entropy(pred_health, y_health)

Cross-entropy for classification. Returns 0-D tensor.

L_total = LAMBDA * L_rul + (1 - LAMBDA) * L_health

THE LOSS COMBINATION LINE. Convex mixture of two scalar losses. The whole rest of this book - Sections 4.4, 14, 17, 21 - is about how to choose this combination well.

EXECUTION STATE
L_total.requires_grad = True - autograd will backprop through both branches

L_total.backward()

Single backward pass through the combined loss. The shared parameters receive LAMBDA times the RUL gradient plus (1 - LAMBDA) times the health gradient - which is exactly what Section 4.4 examines.

opt.step()

Apply the gradient update.

print(...)

Per-step losses. Note L_rul ~ 4000: the untrained model predicts values near 0 while the targets have mean 60 and std 30, so the initial MSE is close to 60^2 + 30^2 = 4500. L_health ~ 1.2 because a random model on 3 classes scores about log(3) ~ 1.1.

EXECUTION STATE
Output step 0 = L_rul=4250.871 L_health=1.205 L_total=2126.038
→ the imbalance live = L_total is dominated 99.97% by L_rul. With static lambda=0.5 the optimiser is effectively training only the regression task.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DualTaskMLP(nn.Module):
    """Shared trunk with one regression head (RUL) and one 3-class head (health)."""

    def __init__(self):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(17, 32), nn.ReLU())
        self.head_rul    = nn.Linear(32, 1)
        self.head_health = nn.Linear(32, 3)

    def forward(self, x):
        h = self.shared(x)
        return self.head_rul(h).squeeze(-1), self.head_health(h)


# ----- Training step with combined loss -----
torch.manual_seed(0)
model = DualTaskMLP()
opt   = torch.optim.AdamW(model.parameters(), lr=1e-3)

x         = torch.randn(64, 17)
y_rul     = torch.randn(64) * 30 + 60                 # synthetic RUL targets
y_health  = torch.randint(0, 3, (64,))                # synthetic health labels
LAMBDA    = 0.5                                        # task weight (static)

for step in range(3):
    opt.zero_grad()

    pred_rul, pred_health = model(x)

    L_rul    = F.mse_loss(pred_rul,    y_rul)
    L_health = F.cross_entropy(pred_health, y_health)
    L_total  = LAMBDA * L_rul + (1 - LAMBDA) * L_health

    L_total.backward()
    opt.step()

    print(f"step {step}: L_rul={L_rul.item():.3f}  L_health={L_health.item():.3f}  L_total={L_total.item():.3f}")

# Example output (illustrative; exact values depend on the seed and PyTorch version):
# step 0: L_rul=4250.871  L_health=1.205  L_total=2126.038
# step 1: L_rul=3873.015  L_health=1.198  L_total=1937.107
# step 2: L_rul=3528.504  L_health=1.190  L_total=1764.847
```
The training-loop pattern. One forward pass; two per-task losses; one combined loss; one backward; one optimiser step. This pattern repeats verbatim in Chapter 15 (AMNL training), Chapter 20 (GABA), and Chapter 22 (GRACE) — what changes is how lambda is chosen.
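The imbalance reported by the loop above is measured on loss values. A gradient-side check on the shared parameters can be sketched as follows. To stay dependency-light this version uses NumPy with a purely linear shared layer and hand-derived gradients, which is a simplifying assumption rather than the DualTaskMLP from the listing; all variable names are hypothetical. In PyTorch the same quantities would normally come from `torch.autograd.grad` applied to each per-task loss.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, h = 64, 17, 8
x = rng.normal(size=(n, d))
y_rul = rng.normal(size=n) * 30 + 60        # synthetic RUL targets
y_health = rng.integers(0, 3, size=n)       # synthetic 3-class labels

W = rng.normal(size=(d, h)) * 0.1           # shared layer (linear, for brevity)
a = rng.normal(size=h) * 0.1                # RUL head
B = rng.normal(size=(h, 3)) * 0.1           # health head

feat = x @ W

# Gradient of the mean-squared-error loss with respect to the shared weights W.
resid = feat @ a - y_rul
g_rul = x.T @ np.outer(2.0 * resid / n, a)

# Gradient of the mean cross-entropy loss with respect to W.
logits = feat @ B
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
probs[np.arange(n), y_health] -= 1.0        # softmax minus one-hot
g_health = x.T @ ((probs / n) @ B.T)

norm_rul, norm_health = np.linalg.norm(g_rul), np.linalg.norm(g_health)
print(f"||g_rul|| = {norm_rul:.3f}  ||g_health|| = {norm_health:.3f}  "
      f"ratio = {norm_rul / norm_health:.1f}")
```

Even in this stripped-down setting the regression gradient on the shared weights is far larger than the classification gradient, purely because of the target scale. The exact ratio depends on the initialisation and should not be read as the C-MAPSS figure.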

The Same Tradeoff in Other Domains

| Domain | Task A | Task B | Weighting strategy |
| --- | --- | --- | --- |
| RUL (this book) | Regression (cycles) | Classification (3 classes) | Adaptive (GABA / GRACE) |
| Self-driving | Steering angle | Object detection | Tesla HydraNet (per-task heads, learned losses) |
| Multi-modal LLM | Text autoregression | Image-caption matching | Static lambda + warm-up |
| Speech | Phoneme recognition | Speaker classification | Uncertainty-weighted |
| Object detection | Box regression (smooth-L1) | Class probability (CE) | Heuristic 1.0 / 1.0 |
| Generative models | Reconstruction (MSE) | KL divergence (latent prior) | Beta-VAE annealing |
| Reinforcement learning | Policy gradient | Value function (MSE) | Coefficient sweep |

Each row has its own body of literature on how to set the weights. The patterns transfer; what changes is the magnitude of each loss and the criticality of each task.

Two Pitfalls

Pitfall 1: The metric magnitude lies. Equal weights do not produce equal contributions. Always check the per-task gradient norm on shared parameters - that is the only honest diagnostic.
Pitfall 2: lambda sweeps are exponentially expensive. With 2 tasks you can sweep one knob; with 5 tasks you have 4 knobs, and a grid of 10 values per knob already means 10^4 training runs. Adaptive methods (GABA, uncertainty-weighting) skip the sweep by adjusting lambda during training from gradient or loss statistics.
What the rest of the book is about. Choosing lambda well. Section 4.4 examines the gradients themselves; Sections 14, 17, 21 give three concrete weighting strategies (failure-biased weights, GABA, GRACE) and benchmark them.
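The cost behind Pitfall 2 is simple counting. The sketch below assumes a full grid over the K - 1 free weights, with the last weight fixed by normalisation, and ignores combinations that a sum-to-one constraint would rule out, so it gives an upper bound. The function name `sweep_runs` is hypothetical.

```python
def sweep_runs(num_tasks: int, grid_points: int) -> int:
    """Training runs needed for a full grid over the K - 1 free weights."""
    return grid_points ** (num_tasks - 1)


for k in (2, 3, 5):
    print(f"{k} tasks, 10 values per weight: {sweep_runs(k, 10)} runs")
# prints:
# 2 tasks, 10 values per weight: 10 runs
# 3 tasks, 10 values per weight: 100 runs
# 5 tasks, 10 values per weight: 10000 runs
```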

Takeaway

  • The combined loss is $\sum_k \lambda_k \mathcal{L}_k$. Smooth, differentiable, single-scalar. Standard optimisers work.
  • Equal lambda is rarely balanced. Per-task losses live on different scales; equal lambda silently lets the larger-scale loss dominate.
  • Five weighting strategies, one design choice. Fixed equal, fixed unequal, inverse-magnitude, uncertainty, adaptive. The book's contribution lives in the last bucket.
  • The training loop pattern is universal. Forward, two losses, one combined loss, backward, step. lambda is the only knob you change between AMNL, GABA, GRACE.