Chapter 4

Why Multi-Task Learning?

Multi-Task Learning Theory

The Bilingual Child

Children raised bilingual learn each language slightly more slowly in the early years — and then catch up, often surpassing monolingual peers on a battery of cognitive measures by adulthood. The brain is not a finite vocabulary cup; learning a second language forces it to extract patterns at a higher level — phonology, grammar, abstraction — that benefit BOTH languages.

That same phenomenon, applied to neural networks, is multi-task learning (MTL). One backbone, several related tasks, all sharing parameters. The auxiliary task acts as a regulariser; it forces the shared layers to learn representations that generalise across both objectives. For RUL prediction the main task is regression (predict the cycle count to failure); the auxiliary task is classification (predict the discrete health state — normal, degrading, critical). The auxiliary task sharpens the same shared features the regressor reads from.

The recurring theme. Every model in this book is a dual-task model: shared backbone (Chapters 8-11), RUL regression head, health classification head. The whole research story (AMNL / GABA / GRACE) is about how to balance the two heads' gradients during training.

Multi-Task Learning, Formally

Suppose we have $K$ related tasks. For each task $k$ there is a dataset $\mathcal{D}_k = \{(\mathbf{x}_i, y_i^{(k)})\}$ and a per-task loss $\mathcal{L}_k$. The model is a function

$$f_{\theta_s, \theta_1, \ldots, \theta_K}(\mathbf{x}) \;=\; \bigl(g_1(s(\mathbf{x}; \theta_s); \theta_1),\, \ldots,\, g_K(s(\mathbf{x}; \theta_s); \theta_K)\bigr).$$

Here $s(\cdot; \theta_s)$ is the shared encoder (the “backbone”) with parameters $\theta_s$, and $g_k(\cdot; \theta_k)$ is the head for task $k$. The combined training objective is

$$\mathcal{L}_{\text{MTL}}(\theta_s, \theta_1, \ldots, \theta_K) \;=\; \sum_{k=1}^{K} \lambda_k \, \mathcal{L}_k.$$

The weights $\lambda_k$ determine how much each task influences the shared parameters. Choosing them (statically, dynamically, or adaptively from gradients) is the central problem of Sections 4.3, 14, 17, and 21.
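As a concrete instance of the weighted sum for $K = 2$, the NumPy sketch below combines a mean-squared-error term for RUL with a cross-entropy term for the health state. The toy predictions and targets, the helper names (`mse`, `cross_entropy`, `mtl_loss`), and the equal weights of 0.5 are illustrative assumptions, not the book's training code.

```python
import numpy as np

def mse(pred, target):
    """Mean squared error for the RUL regression task."""
    return float(np.mean((pred - target) ** 2))

def cross_entropy(logits, labels):
    """Mean cross-entropy computed from raw logits via a stable log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

def mtl_loss(losses, lambdas):
    """L_MTL = sum_k lambda_k * L_k."""
    return sum(lam * loss for lam, loss in zip(lambdas, losses))

# Toy batch of 2 engines (all values made up for illustration)
rul_pred = np.array([100.0, 50.0])
rul_true = np.array([110.0, 40.0])
health_logits = np.array([[2.0, 0.0, 0.0],
                          [0.0, 0.0, 2.0]])
health_true = np.array([0, 2])

l_rul = mse(rul_pred, rul_true)
l_health = cross_entropy(health_logits, health_true)
total = mtl_loss([l_rul, l_health], lambdas=[0.5, 0.5])

print("L_rul    :", round(l_rul, 4))
print("L_health :", round(l_health, 4))
print("L_MTL    :", round(total, 4))
```

With these made-up numbers the RUL term is 100.0 while the health term is roughly 0.24, so the two losses already differ by more than two orders of magnitude before any weighting. That scale gap is one reason the choice of $\lambda_k$ matters.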

Three Reasons MTL Helps

| Mechanism | What it gives you |
| --- | --- |
| Implicit regularisation | Auxiliary task constrains the shared representation, reducing overfitting |
| Implicit data augmentation | Each task acts as 'extra labels' for the same inputs - more learning signal per sample |
| Representation transfer | Features learned for the easier task often help the harder one |

On C-MAPSS FD002, switching from a single-task RUL regressor to a naive 0.5/0.5 MTL improves RMSE from 8.11 to 7.37 — a 9% reduction with zero architectural change beyond the auxiliary head. The gradient-aware variants (GABA / GRACE) push that further; the whole rest of this book is about extracting the maximum benefit from this single architectural idea.
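The 9% figure is the relative RMSE reduction, which can be checked with a couple of lines. The function name `relative_reduction` is a hypothetical helper; the two RMSE values are the ones quoted above.

```python
def relative_reduction(baseline, improved):
    """Fraction by which `improved` is lower than `baseline`."""
    return (baseline - improved) / baseline

single_task_rmse = 8.11   # single-task RUL regressor on FD002 (from the text)
naive_mtl_rmse = 7.37     # naive 0.5/0.5 MTL on FD002 (from the text)

reduction = relative_reduction(single_task_rmse, naive_mtl_rmse)
print(f"relative RMSE reduction: {reduction:.1%}")  # 9.1%
```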

Interactive: Single-Task vs MTL on FD002

Toggle between single-task and the two MTL variants. The diagram highlights which heads are active; the bar chart shows the actual FD002 numbers from the paper's Table I.


Python: A Tiny Multi-Task Network

Twenty lines of NumPy show the architectural pattern: one shared forward pass, two task-specific projections.

Shared trunk plus two heads in pure NumPy
🐍mtl_numpy.py
`import numpy as np`

Standard NumPy alias.

`np.random.seed(0)`

Reproducible weight initialisation.

`W_shared = np.random.randn(17, 32) * 0.1`

Shared backbone weight matrix. Maps the 17-D sensor input to a 32-D feature space. The 0.1 scale is a simple stand-in for Xavier init; it keeps activations bounded for the demo.

EXECUTION STATE
W_shared.shape = (17, 32) - shared by BOTH tasks
n_params = 17 * 32 = 544 (shared trunk weights, excluding the 32 biases)

`b_shared = np.zeros(32)`

Shared bias - one per shared output neuron.

`W_rul = np.random.randn(32, 1) * 0.1`

RUL-head weights. Maps 32-D shared features to a single regression output (cycles to failure).

EXECUTION STATE
W_rul.shape = (32, 1) - regression head
task-specific = ONLY the RUL head sees these gradients

`b_rul = np.zeros(1)`

RUL head bias scalar.

`W_health = np.random.randn(32, 3) * 0.1`

Health-head weights. Maps 32-D shared features to 3 classification logits (normal / degrading / critical).

EXECUTION STATE
W_health.shape = (32, 3) - classification head

`b_health = np.zeros(3)`

Health head bias - one per class.

`def forward(x):`

Single forward pass that produces BOTH task outputs. The shared trunk runs ONCE; only the heads differ.

EXECUTION STATE
input: x = (B, 17) - batch of B engines, one 17-value sensor reading per engine
returns = (rul_pred, health_logits) - the two task outputs

`h = np.maximum(0, x @ W_shared + b_shared)`

Shared backbone forward pass: linear + ReLU. Result h has shape (B, 32) - one shared feature vector per engine.

EXECUTION STATE
x @ W_shared = (B, 17) @ (17, 32) = (B, 32)
+ b_shared = Broadcasts (32,) across the batch axis
np.maximum(0, ...) = ReLU activation. Numbers ≥ 0 pass through; negatives clipped to 0.
h.shape = (4, 32) for the demo batch (B = 4)

`rul = (h @ W_rul + b_rul).squeeze(-1)`

RUL head. (B, 32) @ (32, 1) = (B, 1); squeeze removes the trailing size-1 axis.

EXECUTION STATE
.squeeze(-1) = Drops the size-1 last axis: (B, 1) -> (B,)
rul.shape = (4,)

`health = h @ W_health + b_health`

Health head. (B, 32) @ (32, 3) = (B, 3) class logits.

EXECUTION STATE
health.shape = (4, 3)

`return rul, health`

Two outputs from a single forward pass - this is the MTL contract.

`x = np.random.randn(4, 17).astype(np.float32)`

Fake batch of 4 engines.

`rul, health = forward(x)`

ONE forward call produces BOTH outputs. The shared trunk's compute is amortised across the two tasks.

EXECUTION STATE
→ why this matters = Single forward = one trip through the expensive shared layers. With 100 task heads we'd still pay the trunk cost only once.

`print("input x.shape :", x.shape)`

Verify input shape.

EXECUTION STATE
Output = input x.shape : (4, 17)

`print("RUL head rul.shape :", rul.shape)`

One scalar RUL prediction per engine, so rul is a 1-D array.

EXECUTION STATE
Output = RUL head rul.shape : (4,)

`print("Health head health.shape :", health.shape)`

Health head emits 3 logits per engine.

EXECUTION STATE
Output = Health head health.shape : (4, 3)
```python
import numpy as np

# ----- Hand-rolled multi-task forward pass -----
# Shared trunk: 17 -> 32 features. Two heads: regression (1) + classification (3).
np.random.seed(0)

# Shared parameters (would be learned in practice)
W_shared = np.random.randn(17, 32).astype(np.float32) * 0.1
b_shared = np.zeros(32, dtype=np.float32)

# Task-specific parameters
W_rul    = np.random.randn(32, 1).astype(np.float32) * 0.1
b_rul    = np.zeros(1, dtype=np.float32)
W_health = np.random.randn(32, 3).astype(np.float32) * 0.1
b_health = np.zeros(3, dtype=np.float32)


def forward(x):
    """x: (B, 17) - one sensor reading per engine.
    Returns (rul_pred, health_logits) - both tasks at once."""
    h = np.maximum(0, x @ W_shared + b_shared)        # ReLU activation
    rul    = (h @ W_rul + b_rul).squeeze(-1)          # (B,)
    health = h @ W_health + b_health                   # (B, 3)
    return rul, health


# ----- One forward pass on a batch of 4 engines -----
x = np.random.randn(4, 17).astype(np.float32)
rul, health = forward(x)

print("input  x.shape           :", x.shape)
print("shared output  h.shape   :", (4, 32))
print("RUL head    rul.shape    :", rul.shape)
print("Health head health.shape :", health.shape)
print("rul                       :", rul.round(3).tolist())

# input  x.shape           : (4, 17)
# shared output  h.shape   : (4, 32)
# RUL head    rul.shape    : (4,)
# Health head health.shape : (4, 3)
```
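The statement that each head's weights only see gradients from their own task, while the shared trunk receives gradients from both, can be verified by hand. The sketch below backpropagates through the same 17 → 32 → (1, 3) layout in NumPy under assumptions that are not part of the original listing: biases are omitted, the two losses are MSE and softmax cross-entropy, the targets are random made-up values, and both tasks are weighted 0.5.

```python
import numpy as np

rng = np.random.default_rng(0)
B, D, H, C = 4, 17, 32, 3
W_shared = rng.standard_normal((D, H)) * 0.1
W_rul = rng.standard_normal((H, 1)) * 0.1
W_health = rng.standard_normal((H, C)) * 0.1

x = rng.standard_normal((B, D))
rul_true = rng.uniform(0.0, 125.0, size=B)   # made-up regression targets
health_true = rng.integers(0, C, size=B)     # made-up class labels

# Forward pass (biases omitted to keep the sketch short)
z = x @ W_shared
h = np.maximum(0.0, z)
rul = (h @ W_rul).squeeze(-1)
logits = h @ W_health

# Gradient of each (unweighted) loss with respect to its own head output
d_rul = 2.0 * (rul - rul_true) / B                  # d MSE / d rul, (B,)
shifted = logits - logits.max(axis=1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
d_logits = probs.copy()
d_logits[np.arange(B), health_true] -= 1.0
d_logits /= B                                       # d CE / d logits, (B, C)

# Head gradients: each head is touched by exactly one loss
g_W_rul = h.T @ d_rul[:, None]        # (H, 1), from the MSE loss only
g_W_health = h.T @ d_logits           # (H, C), from the cross-entropy only

# Shared gradient: both losses flow back through the ReLU into W_shared
relu_mask = (z > 0).astype(z.dtype)
g_shared_from_rul = x.T @ ((d_rul[:, None] @ W_rul.T) * relu_mask)
g_shared_from_health = x.T @ ((d_logits @ W_health.T) * relu_mask)
lam_rul, lam_health = 0.5, 0.5
g_W_shared = lam_rul * g_shared_from_rul + lam_health * g_shared_from_health

print("||grad W_shared from RUL||    :", np.linalg.norm(g_shared_from_rul))
print("||grad W_shared from health|| :", np.linalg.norm(g_shared_from_health))
```

The two printed norms give a first look at how unevenly the tasks can push on the shared weights, even in a toy setting with random data.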

PyTorch: nn.Module With Two Heads

The PyTorch idiom is one nn.Module with three sub-modules: shared, head_rul, head_health. The forward returns a 2-tuple. This same skeleton scales up to the full CNN-BiLSTM-Attention backbone in Chapter 11.

A DualTaskMLP - the architectural skeleton of every model in this book
🐍dual_task_mlp.py
`import torch`

Top-level PyTorch.

`import torch.nn as nn`

Module + layer container.

`import torch.nn.functional as F`

Functional API (not strictly needed here, kept for habit).

`class DualTaskMLP(nn.Module):`

Captures the architectural skeleton of every model in Parts III - VII: shared trunk plus task-specific heads.

EXECUTION STATE
Inheritance = nn.Module - PyTorch's base class for layers / models / losses

`def __init__(self, in_dim=17, hidden=32, n_classes=3):`

Three knobs - input dim (number of sensors), shared hidden dim, number of health classes.

EXECUTION STATE
in_dim = 17 - C-MAPSS sensor count
hidden = 32 - shared feature dim
n_classes = 3 - normal / degrading / critical

`super().__init__()`

Initialise nn.Module bookkeeping.

`self.shared = nn.Sequential(...)`

Container that runs sub-modules in order. Linear + ReLU is the simplest possible 'backbone'. In Chapters 8-10 we replace this with CNN-BiLSTM-Attention, but the API contract stays the same.

EXECUTION STATE
nn.Sequential(*layers) = Calls each layer in order. Equivalent to writing the calls one by one but cleaner.

`self.head_rul = nn.Linear(hidden, 1)`

Single-output regression head. nn.Linear(in, out) creates a linear layer with an (out, in) weight and an (out,) bias.

EXECUTION STATE
params = 32 * 1 + 1 = 33 weights and biases

`self.head_health = nn.Linear(hidden, n_classes)`

3-class classification head.

EXECUTION STATE
params = 32 * 3 + 3 = 99 weights and biases

`def forward(self, x):`

Standard PyTorch forward. Runs the shared trunk once, splits into the two heads.

`h = self.shared(x)`

Run the shared backbone. h has shape (B, hidden).

EXECUTION STATE
h.shape = torch.Size([B, 32])

`rul = self.head_rul(h).squeeze(-1)`

Linear projection to (B, 1), then squeeze to (B,).

EXECUTION STATE
Why squeeze? = Trailing size-1 dim is awkward when you compute MSE against a (B,) target; squeeze drops it.

`health = self.head_health(h)`

3-class logits per engine. NO softmax here - we let F.cross_entropy in the training loop fuse softmax + log + NLL.

EXECUTION STATE
→ §3.5 reference = Always pass raw logits to F.cross_entropy.

`return rul, health`

Tuple return. The training loop will unpack and compute two losses - that is Section 4.3.

`torch.manual_seed(0)`

Determinism.

`model = DualTaskMLP()`

Instantiate with default args.

EXECUTION STATE
model.shared = nn.Sequential(nn.Linear(17, 32), nn.ReLU())
model.head_rul = nn.Linear(32, 1)
model.head_health = nn.Linear(32, 3)

`x = torch.randn(4, 17)`

Fake batch.

`rul, health = model(x)`

Calling the module like a function routes through __call__ to forward. Both outputs returned.

`print("rul.shape :", tuple(rul.shape))`

Verify regression-output shape.

EXECUTION STATE
Output = rul.shape : (4,)

`print("health.shape :", tuple(health.shape))`

Verify classification-output shape.

EXECUTION STATE
Output = health.shape : (4, 3)

`print("# params :", sum(p.numel() for p in model.parameters()))`

Parameter accounting. model.parameters() iterates over every learnable tensor in the module; .numel() returns its element count.

EXECUTION STATE
Output = # params : 708
Breakdown = shared: 17*32+32 = 576; rul: 33; health: 99; total = 708. The head counts already include their biases.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DualTaskMLP(nn.Module):
    """The skeleton of every model in this book: shared trunk + two heads."""

    def __init__(self, in_dim: int = 17, hidden: int = 32, n_classes: int = 3):
        super().__init__()
        # Shared trunk
        self.shared = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
        )
        # Task-specific heads
        self.head_rul    = nn.Linear(hidden, 1)
        self.head_health = nn.Linear(hidden, n_classes)

    def forward(self, x: torch.Tensor):
        h = self.shared(x)                          # (B, hidden) - SHARED
        rul    = self.head_rul(h).squeeze(-1)       # (B,)        - TASK 1
        health = self.head_health(h)                # (B, 3)      - TASK 2
        return rul, health


# ----- Use it -----
torch.manual_seed(0)
model = DualTaskMLP()
x = torch.randn(4, 17)
rul, health = model(x)

print("rul.shape    :", tuple(rul.shape))     # (4,)
print("health.shape :", tuple(health.shape))  # (4, 3)
print("# params     :", sum(p.numel() for p in model.parameters()))
# # params     : 708  (= 576 shared + 33 RUL head + 99 health head, biases included)
```
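The parameter count can be confirmed by hand without PyTorch. A linear layer with `n_in` inputs and `n_out` outputs holds `n_in * n_out` weights plus `n_out` biases; the helper name `linear_params` below is a hypothetical convenience, and the layer sizes are the defaults of `DualTaskMLP`.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

shared = linear_params(17, 32)   # 17*32 weights + 32 biases
rul = linear_params(32, 1)       # 32 weights + 1 bias
health = linear_params(32, 3)    # 96 weights + 3 biases
total = shared + rul + health

print(shared, rul, health, total)  # 576 33 99 708
```

ReLU contributes no learnable parameters, so the three linear layers account for the whole total.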

MTL Beyond RUL

| Domain | Main task | Auxiliary task(s) | Famous architecture |
| --- | --- | --- | --- |
| RUL (this book) | Regression: cycles to failure | Classification: health state | DualTaskModel + GABA |
| Self-driving | Steering angle | Lane / depth / object detection | Tesla HydraNet |
| NLP | Next-token prediction | Masked LM, NSP, sentence order | BERT, T5 |
| Speech | Phoneme recognition | Speaker ID, language ID | Whisper, w2v-BERT |
| Vision | Object classification | Bounding-box regression, segmentation | Mask R-CNN |
| Drug discovery | Binding affinity | Solubility, toxicity, ADMET | ChemBERTa MTL |
| Medical imaging | Lesion classification | Lesion segmentation, age regression | Multi-task U-Net |
| Recommenders | Click prediction | Dwell time, conversion, like, share | ESMM, MMoE |

Every row of the table is solved with the same structure we just coded: shared encoder plus task-specific heads, one combined loss. The architectural commit you make in Chapter 11 transfers to all of them.
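To illustrate how the same structure carries over to another row, here is a minimal NumPy sketch that generalises the two-head forward pass to any number of named heads. The head names are borrowed loosely from the Recommenders row and, like the function names `init_params` and `forward` and the layer sizes, are hypothetical choices for this example, not a description of ESMM or MMoE.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_params(in_dim, hidden, head_dims):
    """Create one shared trunk and one linear head per task name."""
    return {
        "W_s": rng.standard_normal((in_dim, hidden)) * 0.1,
        "b_s": np.zeros(hidden),
        "heads": {
            name: (rng.standard_normal((hidden, out_dim)) * 0.1, np.zeros(out_dim))
            for name, out_dim in head_dims.items()
        },
    }

def forward(params, x):
    """Run the shared trunk once, then every head on the same features."""
    h = np.maximum(0.0, x @ params["W_s"] + params["b_s"])
    return {name: h @ W + b for name, (W, b) in params["heads"].items()}

# Hypothetical task set: three scalar outputs sharing one encoder
head_dims = {"click": 1, "dwell_time": 1, "conversion": 1}
params = init_params(in_dim=17, hidden=32, head_dims=head_dims)
outputs = forward(params, rng.standard_normal((4, 17)))

for name, out in outputs.items():
    print(name, out.shape)
```

Adding a task is then a matter of adding one entry to `head_dims`; the trunk computation is unchanged and still runs once per batch.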

When MTL Hurts Instead of Helps

Negative transfer. If two tasks pull the shared representation in incompatible directions, MTL can be worse than single-task. On C-MAPSS the regression and classification objectives ARE compatible (both depend on degradation features); on a more adversarial task pair, the shared backbone may become a worse feature extractor for either task alone.
Gradient imbalance. Even when both tasks help, their gradients can have wildly different magnitudes. Section 12 will measure a 500x imbalance between RUL regression and health classification on shared parameters - which is exactly what GABA (Section 17) is designed to fix.
Wrong auxiliary task. Pick an auxiliary that doesn't share representational structure with the main task and you get noise. Health classification works on C-MAPSS because the health state is a direct function of RUL (Section 7.3); a more decoupled auxiliary like “model serial number prediction” would not help.
The chapter's theme. Multi-task learning promises better generalisation through parameter sharing. The promise is real but conditional - the auxiliary task must be related, the loss weighting must be sensible, and the gradients must be balanced. Sections 4.2 - 4.4 unpack each of those conditions.
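Two of the failure modes above can be checked numerically once per-task gradients on the shared parameters are available: the ratio of gradient norms indicates imbalance, and the cosine similarity indicates whether the tasks pull in compatible directions. The sketch below is a minimal diagnostic under stated assumptions; the function name `grad_diagnostics` is hypothetical, the gradient vectors are made up, and this is not the GABA or GRACE procedure.

```python
import numpy as np

def grad_diagnostics(g_a, g_b):
    """Return (norm ratio |g_a| / |g_b|, cosine similarity) of two gradients."""
    g_a = np.ravel(g_a).astype(float)
    g_b = np.ravel(g_b).astype(float)
    norm_a = np.linalg.norm(g_a)
    norm_b = np.linalg.norm(g_b)
    ratio = norm_a / norm_b                     # assumes norm_b > 0
    cosine = (g_a @ g_b) / (norm_a * norm_b)    # assumes both norms > 0
    return float(ratio), float(cosine)

# Made-up gradients on the same shared parameters (illustration only)
g_rul = np.array([30.0, -40.0])    # norm 50
g_health = np.array([0.6, 0.8])    # norm 1

ratio, cosine = grad_diagnostics(g_rul, g_health)
print(f"norm ratio = {ratio:.1f}, cosine = {cosine:.2f}")
```

In this toy case the first gradient is 50 times larger than the second, and the negative cosine (about -0.28) means the two tasks partly oppose each other on these parameters. A cosine near +1 would suggest compatible objectives, while a strongly negative value is a warning sign for negative transfer.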

Takeaway

  • MTL = shared backbone + task-specific heads. One forward pass produces multiple outputs.
  • The shared parameters get gradients from every task. Task-specific parameters only see their own task.
  • The mechanisms are regularisation, augmentation, and transfer. Empirically, MTL improves FD002 RMSE from 8.11 (single-task) to 7.37 (naive MTL).
  • The combined loss is $\mathcal{L} = \sum_k \lambda_k \mathcal{L}_k$. Choosing $\lambda_k$ well is the rest of the book.
  • The PyTorch skeleton is one Module with two heads. Same pattern, larger backbone in Chapter 11.