Chapter 4

A Gradient-Level View of MTL

Multi-Task Learning Theory

Tug-of-War on the Shared Parameters

Two children pull a rope from opposite ends. One is twice the size of the other, so although each pulls away from the centre, the two pulls have very different magnitudes. The rope's actual motion is the SUM of the two forces, and it moves toward the bigger child. Worse, if the smaller child suddenly redirects her pull by 90 degrees, the rope barely responds to the change: her pull is too small to matter.

That is what happens to the shared parameters of an MTL model during training. Every backward pass produces one gradient vector per task. The optimiser follows their (lambda-weighted) sum. If one task's gradient is much larger than the other's — or if their directions disagree — the shared parameters are being yanked one way while the other task's preferences are ignored.

The whole research thesis in one sentence. Loss weighting (Section 4.3) only matters insofar as it produces sensible gradients. The honest diagnostic is the gradient norm, not the loss value.
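To see why the loss value alone can mislead, here is a minimal sketch with two made-up scalar losses. Both functions and the evaluation point are illustrative assumptions, not taken from the C-MAPSS experiments. They report the same loss value, yet their gradients differ by two orders of magnitude.

```python
# Two hypothetical losses with equal value but very different gradients
# at theta = 0.1. All constants are chosen purely for illustration.

def loss_steep(theta):
    return 50.0 * theta ** 2        # value 0.5 at theta = 0.1

def grad_steep(theta):
    return 100.0 * theta            # gradient 10.0 at theta = 0.1

def loss_flat(theta):
    return 0.49 + 0.1 * theta       # value 0.5 at theta = 0.1

def grad_flat(theta):
    return 0.1                      # constant gradient

theta = 0.1
loss_ratio = loss_steep(theta) / loss_flat(theta)
grad_ratio = abs(grad_steep(theta)) / abs(grad_flat(theta))

print(f"loss ratio     = {loss_ratio:.2f}x")
print(f"gradient ratio = {grad_ratio:.0f}x")
```

Comparing loss values would suggest the two tasks are balanced; comparing gradient norms shows that one of them would dominate a shared parameter.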

The Gradient Decomposition

With combined loss $\mathcal{L} = \sum_k \lambda_k \mathcal{L}_k$, the gradient with respect to a shared parameter $\theta_s$ decomposes linearly:

$$\nabla_{\theta_s} \mathcal{L} \;=\; \sum_{k=1}^{K} \lambda_k \, \nabla_{\theta_s} \mathcal{L}_k.$$

This follows from the linearity of differentiation: the derivative of a weighted sum is the weighted sum of the derivatives. The combined update is a convex combination of per-task gradients when $\lambda_k \ge 0$ and $\sum_k \lambda_k = 1$. Two properties of this linear combination determine whether learning succeeds:

| Property | Symbol | Failure mode |
| --- | --- | --- |
| Per-task magnitude | $\lVert \nabla_{\theta_s} \mathcal{L}_k \rVert$ | If they differ by orders of magnitude, the larger task dominates |
| Per-task direction | $\hat{\nabla}_{\theta_s} \mathcal{L}_k$ | If two tasks pull opposite ways, gradients partially cancel |
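The linear decomposition can be checked numerically. The sketch below uses two made-up quadratic losses on a single scalar parameter and central finite differences; the loss functions, weights, and evaluation point are all assumptions chosen for illustration.

```python
# Check: gradient of the weighted sum == weighted sum of per-task gradients.

def loss_a(theta):
    return 50.0 * (theta - 2.0) ** 2    # hypothetical steep task loss

def loss_b(theta):
    return 0.1 * (theta + 1.0) ** 2     # hypothetical shallow task loss

def numerical_grad(f, theta, eps=1e-6):
    """Central finite-difference approximation of df/dtheta."""
    return (f(theta + eps) - f(theta - eps)) / (2.0 * eps)

lam_a, lam_b = 0.5, 0.5
theta = 0.3

def combined(t):
    return lam_a * loss_a(t) + lam_b * loss_b(t)

grad_of_sum = numerical_grad(combined, theta)
sum_of_grads = (lam_a * numerical_grad(loss_a, theta)
                + lam_b * numerical_grad(loss_b, theta))

print(f"gradient of combined loss : {grad_of_sum:.4f}")
print(f"weighted sum of gradients : {sum_of_grads:.4f}")
```

Both numbers agree up to finite-difference error. Analytically the value is 0.5 * (-170) + 0.5 * 0.26 = -84.87, which is almost entirely the steep task's contribution.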

Two Failure Modes: Magnitude and Direction

Magnitude imbalance. Suppose $\|\nabla \mathcal{L}_1\| = 500 \cdot \|\nabla \mathcal{L}_2\|$, as we will measure on real C-MAPSS in Section 12. With $\lambda_1 = \lambda_2 = 0.5$, the combined gradient is approximately equal to $\frac{1}{2}\nabla \mathcal{L}_1$, and task 2's influence is essentially noise.
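A short worked calculation makes the "essentially noise" claim concrete. It uses only the 500x ratio and the equal weights assumed above; the shorthand $g_k = \nabla \mathcal{L}_k$ is introduced here for brevity.

```latex
\begin{aligned}
g_{\text{total}} &= \tfrac{1}{2}\, g_1 + \tfrac{1}{2}\, g_2,
  \qquad \|g_1\| = 500\,\|g_2\| \\[4pt]
\left\| g_{\text{total}} - \tfrac{1}{2}\, g_1 \right\|
  &= \tfrac{1}{2}\,\|g_2\|
   = \tfrac{1}{500}\,\left\| \tfrac{1}{2}\, g_1 \right\| \\[4pt]
\Rightarrow\quad
\frac{\left\| g_{\text{total}} - \tfrac{1}{2}\, g_1 \right\|}
     {\left\| \tfrac{1}{2}\, g_1 \right\|}
  &= \frac{1}{500} = 0.2\%
\end{aligned}
```

Whatever direction task 2 prefers, it shifts the combined gradient by only 0.2% of the dominant term's length.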

Direction conflict. Suppose two tasks have equal magnitudes but opposite directions: $\nabla \mathcal{L}_1 = -\nabla \mathcal{L}_2$. With equal weights their weighted sum is zero, and the optimiser does not move. With static lambda, the only escape is to abandon one task.
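Direction conflict is usually quantified with the cosine similarity between the two task gradients. The sketch below, using made-up 2-D unit vectors and hypothetical helper names, shows full cancellation at 180 degrees and partial cancellation at 120 degrees.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return math.sqrt(dot(a, a))

def cosine(a, b):
    return dot(a, b) / (norm(a) * norm(b))

def combine(g1, g2, lam=0.5):
    """Weighted sum lam * g1 + (1 - lam) * g2 that the optimiser follows."""
    return [lam * x + (1.0 - lam) * y for x, y in zip(g1, g2)]

g1 = [1.0, 0.0]
full_conflict = [-1.0, 0.0]                        # exactly opposite to g1
partial_conflict = [math.cos(math.radians(120)),   # 120 degrees away from g1
                    math.sin(math.radians(120))]

for name, g2 in [("full", full_conflict), ("partial", partial_conflict)]:
    total = combine(g1, g2)
    print(f"{name:8s} cos = {cosine(g1, g2):+.2f}  "
          f"||combined|| = {norm(total):.2f}")
```

A cosine of -1 wipes out the update entirely; a cosine of -0.5 halves it relative to either unit gradient alone.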

Magnitude vs direction. The paper's key empirical finding (Section 12, Zhou et al. 2025) is that on prognostic data, magnitude imbalance dominates direction conflict. That is why GABA (Section 17) addresses magnitude only, and why gradient-surgery methods like PCGrad (Section 25) underperform gradient-balancing methods on this domain.
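The distinction can be illustrated with a minimal two-task sketch of the projection step used by gradient-surgery methods: when two gradients conflict (negative dot product), each is projected onto the normal plane of the other. The vectors are made-up numbers with a 100x magnitude ratio and the helper names are hypothetical. This illustrates the projection idea only; it is not the implementation evaluated in Section 25.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return math.sqrt(dot(a, a))

def project_out_conflict(g, other):
    """If g conflicts with `other`, remove g's component along `other`."""
    d = dot(g, other)
    if d >= 0.0:                       # no conflict: leave g unchanged
        return list(g)
    scale = d / dot(other, other)
    return [x - scale * y for x, y in zip(g, other)]

g_rul = [100.0, 0.0]                   # large gradient (toy values)
g_health = [-0.6, 0.8]                 # unit norm, conflicts with g_rul

# Each projection uses the ORIGINAL gradient of the other task.
g_rul_pc = project_out_conflict(g_rul, g_health)
g_health_pc = project_out_conflict(g_health, g_rul)

ratio_before = norm(g_rul) / norm(g_health)
ratio_after = norm(g_rul_pc) / norm(g_health_pc)

print(f"conflict after surgery : {dot(g_rul_pc, g_health):+.3f}")
print(f"magnitude ratio before : {ratio_before:.1f}x")
print(f"magnitude ratio after  : {ratio_after:.1f}x")
```

In this example the projection removes the conflict, but the two gradients still differ by the same 100x factor, so the combined update remains dominated by the larger task.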

Interactive: Combine Two Gradient Vectors

Below: two task gradients drawn as 2-D arrows from the same origin. Slide lambda; slide the angle of $g_{\text{health}}$; slide the magnitude ratio. The green arrow is what the optimiser actually follows.

[Interactive visualisation: gradient combination]

Set the magnitude ratio to 20x and the green arrow snaps to the red one across most of the lambda range: magnitude wins. Now drop the ratio to 1x and rotate the health gradient to 180 degrees: the green arrow nearly disappears, because the opposing gradients cancel. These are the two failure modes the rest of the book engineers around.
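The same experiment can be reproduced offline with a few lines of Python. This is a sketch under stated assumptions: the larger arrow is taken to be the RUL gradient along the x-axis, the health gradient has unit length, and the function names are hypothetical rather than part of the visualisation's code.

```python
import math

def combined_arrow(lam, angle_deg, ratio):
    """Return lam * g_rul + (1 - lam) * g_health for 2-D toy gradients.

    g_rul points along +x with length `ratio`; g_health has unit length
    and is rotated `angle_deg` degrees away from g_rul.
    """
    a = math.radians(angle_deg)
    g_rul = (ratio, 0.0)
    g_health = (math.cos(a), math.sin(a))
    return (lam * g_rul[0] + (1 - lam) * g_health[0],
            lam * g_rul[1] + (1 - lam) * g_health[1])

def angle_from_rul(vec):
    """Angle in degrees between the combined arrow and the RUL arrow."""
    return math.degrees(math.atan2(vec[1], vec[0]))

# Magnitude imbalance: 20x ratio, gradients at right angles.
for lam in (0.3, 0.5, 0.7):
    v = combined_arrow(lam, angle_deg=90, ratio=20)
    print(f"lambda = {lam:.1f}: {angle_from_rul(v):.1f} deg away from RUL")

# Direction conflict: equal magnitudes, opposite directions.
v = combined_arrow(0.5, angle_deg=180, ratio=1)
print(f"opposed at 1x: ||combined|| = {math.hypot(*v):.3f}")
```

With a 20x ratio the combined arrow stays within a few degrees of the RUL arrow for these lambda values; only when lambda is pushed close to zero does the health gradient regain visible influence.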

Python: Per-Task Gradient Norms by Hand

A toy in NumPy that quantifies what the diagram shows. With a 100x magnitude imbalance and equal lambda, RUL contributes 99% of the total gradient. The corresponding measurement on real C-MAPSS produces a 500x ratio — meaning health contributes ~0.2% of the total update step.

A 100x imbalance + lambda=0.5 = 99% RUL contribution

gradient_contributions.py

```python
import numpy as np

# Toy: 1-D shared parameter, two task losses with a known gradient profile.
np.random.seed(0)  # kept for reproducibility; this toy draws no random numbers

# Pretend the shared parameter w receives gradients from two tasks.
# Task 1 (RUL regression): large gradient, because the squared-error
# gradient grows with the size of the RUL error. 5.0 is a round demo number;
# Section 12 measures the real value on the actual shared layers.
g_rul = np.array([5.0])
# Task 2 (health classification): small gradient, because the cross-entropy
# gradient with respect to the logits is bounded. 100x smaller in this toy.
g_health = np.array([0.05])

# Combined gradient under naive equal weighting (static lambda = 0.5).
# This is the gradient the optimiser actually sees.
lambda_ = 0.5
g_total = lambda_ * g_rul + (1 - lambda_) * g_health

print(f"||g_rul||    = {np.linalg.norm(g_rul):.3f}")
print(f"||g_health|| = {np.linalg.norm(g_health):.3f}")
print(f"ratio        = {np.linalg.norm(g_rul) / np.linalg.norm(g_health):.1f}x")
print(f"||g_total||  = {np.linalg.norm(g_total):.3f}")

# What fraction of the combined gradient magnitude comes from each task?
# Approximate in general: exact only when the two task gradients are aligned.
contrib_rul = lambda_ * np.linalg.norm(g_rul) / np.linalg.norm(g_total)
contrib_health = (1 - lambda_) * np.linalg.norm(g_health) / np.linalg.norm(g_total)
print(f"contribution from RUL    : {100 * contrib_rul:.1f}%")
print(f"contribution from health : {100 * contrib_health:.1f}%")

# Output:
# ||g_rul||    = 5.000
# ||g_health|| = 0.050
# ratio        = 100.0x
# ||g_total||  = 2.525
# contribution from RUL    : 99.0%
# contribution from health : 1.0%
```

The punchline: even though lambda = 0.5 ("equal"), 99% of the actual gradient comes from RUL. The shared parameters move in directions RUL prefers; health barely steers them. The combined norm of 2.525 (= 0.5 * 5 + 0.5 * 0.05) is almost exactly the RUL term alone. On real C-MAPSS the ratio is ~500x rather than the 100x of this toy.
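The norm-based shares in the toy are exact only when the two gradients are aligned. As a hedged alternative, one can measure each task's share by projecting its weighted gradient onto the combined gradient; by construction these projection shares sum to 1. The 2-D vectors below are made up for illustration, and `projection_share` is a hypothetical helper, not part of the book's code.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return math.sqrt(dot(a, a))

lam = 0.5
g_rul = [3.0, 0.0]
g_health = [0.0, 4.0]      # orthogonal to g_rul, so the two are NOT aligned

w_rul = [lam * x for x in g_rul]
w_health = [(1 - lam) * x for x in g_health]
g_total = [a + b for a, b in zip(w_rul, w_health)]

# Norm-based shares (as in the 1-D toy): can add up to more than 100%.
norm_share_rul = norm(w_rul) / norm(g_total)
norm_share_health = norm(w_health) / norm(g_total)

def projection_share(weighted_grad, total):
    """Fraction of `total` explained by `weighted_grad` along total's direction."""
    return dot(weighted_grad, total) / dot(total, total)

proj_share_rul = projection_share(w_rul, g_total)
proj_share_health = projection_share(w_health, g_total)

print(f"norm shares       : {norm_share_rul:.2f} + {norm_share_health:.2f}")
print(f"projection shares : {proj_share_rul:.2f} + {proj_share_health:.2f}")
```

For these orthogonal gradients the norm shares are 0.60 and 0.80, which add up to 1.40, while the projection shares are 0.36 and 0.64. When the gradients are aligned, as in the 1-D toy, the two measures coincide.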

PyTorch: torch.autograd.grad for Per-Task Inspection

Real measurements use torch.autograd.grad — it computes per-task gradients without mutating the model's .grad attributes, so you can inspect each task's signal in isolation. This is the same primitive Section 18 uses inside the GABA controller.

Measure ||grad_RUL|| and ||grad_health|| separately

per_task_gradient_norms.py

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Same DualTaskMLP from Sections 4.1 and 4.2
class DualTaskMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(17, 32), nn.ReLU())
        self.head_rul = nn.Linear(32, 1)
        self.head_health = nn.Linear(32, 3)

    def forward(self, x):
        h = self.shared(x)
        return self.head_rul(h).squeeze(-1), self.head_health(h)


# ----- Compute per-task gradient norms separately -----
torch.manual_seed(0)                    # determinism
model = DualTaskMLP()
x = torch.randn(64, 17)                 # fake batch
y_rul = torch.randn(64) * 30 + 60       # fake regression targets at C-MAPSS scale
y_health = torch.randint(0, 3, (64,))   # fake 3-class targets

# Single forward pass producing both task outputs
pred_rul, pred_health = model(x)
L_rul = F.mse_loss(pred_rul, y_rul)
L_health = F.cross_entropy(pred_health, y_health)

# Pull out just the SHARED trunk's parameters
shared_params = list(model.shared.parameters())

# autograd.grad(L, params) returns a tuple of gradients, one tensor per
# parameter, WITHOUT mutating .grad. retain_graph=True keeps the graph alive
# so we can backprop again through the SAME forward pass for L_health.
g_rul_per_param = torch.autograd.grad(L_rul, shared_params, retain_graph=True)
# Last backward call on this graph, so retain_graph is not needed.
g_health_per_param = torch.autograd.grad(L_health, shared_params)


def total_norm(grads):
    """Flatten every per-parameter gradient, concatenate, take one L2 norm."""
    return torch.cat([g.flatten() for g in grads]).norm().item()


n_rul = total_norm(g_rul_per_param)
n_health = total_norm(g_health_per_param)

print(f"||grad_RUL||    over shared params: {n_rul:.4f}")
print(f"||grad_health|| over shared params: {n_health:.4f}")
print(f"ratio                              : {n_rul / n_health:.1f}x")

# Reported output (varies with seed and initialisation):
# ||grad_RUL||    over shared params: 8.6453
# ||grad_health|| over shared params: 0.0449
# ratio                              : 192.4x
```

The headline number is the ratio: ~192x in this toy setup. Real C-MAPSS measurements (Section 12) push this to 500x, measured using n=4,120 samples across 20 training runs, and the ratio is consistent across architectures and seeds. This is the core empirical observation that motivates the entire research thesis.

The diagnostic. Run this once per epoch on a real training run; log the ratio. If it is huge (say, >10x) and your weighting is static, you are training a single-task model with the other head along for the ride.
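One way to act on this advice is a small monitor that receives the two norms each epoch and flags large ratios. The class name, the `threshold` default of 10 (taken from the ">10x" rule of thumb above), and the example norms are assumptions for illustration; in a real run the norms would come from the `total_norm` helper.

```python
class GradientRatioMonitor:
    """Hypothetical helper: logs ||grad_RUL|| / ||grad_health|| per epoch."""

    def __init__(self, threshold=10.0, eps=1e-12):
        self.threshold = threshold
        self.eps = eps              # guards against division by zero
        self.history = []           # list of (epoch, ratio) tuples

    def update(self, epoch, n_rul, n_health):
        ratio = n_rul / max(n_health, self.eps)
        self.history.append((epoch, ratio))
        return ratio

    def imbalanced_epochs(self):
        """Epochs whose ratio exceeds the threshold in either direction."""
        return [epoch for epoch, ratio in self.history
                if ratio > self.threshold or ratio < 1.0 / self.threshold]


monitor = GradientRatioMonitor()

# Made-up per-epoch norms standing in for total_norm(...) outputs.
fake_norms = [(8.6, 0.045), (4.0, 0.5), (1.2, 0.9)]
for epoch, (n_rul, n_health) in enumerate(fake_norms):
    ratio = monitor.update(epoch, n_rul, n_health)
    print(f"epoch {epoch}: ratio = {ratio:.1f}x")

print("imbalanced epochs:", monitor.imbalanced_epochs())
```

Checking the ratio in both directions is a design choice: a ratio far below 1 means the classification head is the one dominating, which is just as much an imbalance.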

Gradient Conflicts in Other ML Areas

| Domain | Source of imbalance | Common fix |
| --- | --- | --- |
| RUL (this book) | Regression vs classification scale | GABA / inverse-gradient (this paper) |
| Reinforcement learning | Policy vs value loss | Coefficient sweep, PPO clip |
| GANs | Generator vs discriminator | Spectral normalisation, learning-rate balancing |
| NLP fine-tuning | Pre-train objective vs downstream task | Learning-rate warm-up, layer freezing |
| Self-driving | Steering vs detection vs depth | Loss weighting search, MTL routing |
| Federated learning | Per-client gradient drift | FedProx, FedAvg with momentum |
| Generative VAE | Reconstruction vs KL term | Beta-annealing |

Every row is gradient combat. The mathematical machinery is the same; what differs is which tool the field has settled on.

Preview: How Each Method Addresses This

| Method | Section | What it does to gradients |
| --- | --- | --- |
| Fixed equal weighting | Baseline (Section 24) | Ignores the imbalance - regression dominates |
| AMNL | Section 14 | Reshapes the regression LOSS (failure-biased) - same gradient profile, different shape |
| GABA | Sections 17-20 | ADAPTIVELY rescales lambda inversely to gradient norm - kills magnitude imbalance |
| GRACE | Sections 21-23 | GABA + AMNL - balanced gradients with safety-tilted loss shape |
| GradNorm | Section 24 | Auxiliary loss that drives lambda toward equal-rate convergence |
| PCGrad | Section 25 | Projects out the direction conflict - addresses the WRONG failure mode here |

The throughline. Every method in Parts V-VIII is a different answer to the same question: how do we get the shared parameters to receive sensible gradients from both tasks simultaneously? The next chapter (Part II) returns to data; Parts V-VII implement the three answers we recommend.
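As a rough illustration of the magnitude-balancing idea, the sketch below sets each lambda proportional to the inverse of its task's gradient norm and normalises the weights to sum to 1. This is a generic inverse-gradient-norm rule written for this example, not the GABA controller itself, whose details are covered in Sections 17-20. The function name and the norms are hypothetical.

```python
def inverse_norm_weights(norms, eps=1e-12):
    """Weights proportional to 1 / ||g_k||, normalised to sum to 1."""
    inverse = [1.0 / max(n, eps) for n in norms]   # eps avoids division by zero
    total = sum(inverse)
    return [w / total for w in inverse]

# Toy norms with a 100x imbalance (illustrative values only).
task_names = ["rul", "health"]
grad_norms = [5.0, 0.05]

lams = inverse_norm_weights(grad_norms)
weighted_norms = [lam * n for lam, n in zip(lams, grad_norms)]

for name, lam, wn in zip(task_names, lams, weighted_norms):
    print(f"{name:7s} lambda = {lam:.4f}  lambda * ||g|| = {wn:.4f}")
```

After rescaling, both tasks contribute the same weighted gradient norm, so neither dominates the shared update by magnitude alone. The rule says nothing about direction, which is consistent with the claim that magnitude is the failure mode that matters on this data.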

Takeaway

  • Gradient = sum of per-task contributions. $\nabla \mathcal{L} = \sum_k \lambda_k \nabla \mathcal{L}_k$ on shared parameters.
  • Two failure modes. Magnitude imbalance (one task dominates) and direction conflict (two tasks cancel).
  • On C-MAPSS, magnitude wins. The 500x ratio in Section 12 explains why magnitude-balancing methods (GABA) beat direction-surgery methods (PCGrad).
  • The diagnostic is autograd. torch.autograd.grad(L_k, shared_params) per task, then a single L2 norm. Run it once per epoch and watch the ratio.
  • The rest of the book is solutions to this problem. AMNL changes the loss shape; GABA balances magnitudes; GRACE does both.