Two Voices, One Microphone
Imagine two presenters sharing one microphone. The first presenter has a booming voice; the second is soft-spoken. Whoever speaks louder dominates the room, regardless of what they say. The fair fix is not to gag the loud one - it is to TURN UP the soft one. Multi-task learning has the same problem at the gradient level: whichever task pulls hardest on the shared backbone wins, regardless of which task is more important.
§12 measured the asymmetry: on FD002 at initialisation, the RUL regression loss produces a gradient norm of ≈ 5.0 on the shared backbone, while health classification produces only ≈ 0.01. That is a 500× imbalance. Health is 'turned off' not because the task is unimportant but because its gradient cannot compete with the loudness of the regression head.
Inverse-Gradient Intuition
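For two tasks, RUL and health, we want a weight pair $(\lambda_{\text{rul}}, \lambda_{\text{health}})$ with $\lambda_{\text{rul}} + \lambda_{\text{health}} = 1$ such that the EFFECTIVE backbone update from each task is balanced. The effective contribution of task $i$ to the backbone is $\lambda_i \cdot g_i$, where $g_i$ is the gradient norm of task $i$'s loss on the shared parameters. Setting the two contributions equal:

$$\lambda_{\text{rul}}\, g_{\text{rul}} = \lambda_{\text{health}}\, g_{\text{health}}, \qquad \lambda_{\text{rul}} + \lambda_{\text{health}} = 1.$$

Solving:

$$\lambda_{\text{rul}} = \frac{g_{\text{health}}}{g_{\text{rul}} + g_{\text{health}}}, \qquad \lambda_{\text{health}} = \frac{g_{\text{rul}}}{g_{\text{rul}} + g_{\text{health}}}.$$

For our 500× ratio: $\lambda_{\text{rul}} = 0.01/5.01 \approx 0.002$ and $\lambda_{\text{health}} = 5/5.01 \approx 0.998$. Health gets nearly all the weight, because RUL was already winning the gradient race.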
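Here is a quick numeric check of the derivation. It uses the §12 norms quoted above and plain Python, and confirms that the closed-form weights make the two effective contributions $\lambda_i \cdot g_i$ identical.

```python
# Closed-form K=2 GABA weights for the §12 gradient norms quoted above.
g_rul, g_health = 5.0, 0.01

lam_rul = g_health / (g_rul + g_health)      # about 0.002
lam_health = g_rul / (g_rul + g_health)      # about 0.998

# Effective backbone contributions lambda_i * g_i should match.
contrib_rul = lam_rul * g_rul
contrib_health = lam_health * g_health

print(f"lambda_rul    = {lam_rul:.4f}")
print(f"lambda_health = {lam_health:.4f}")
print(f"contributions = {contrib_rul:.5f}, {contrib_health:.5f}")
```

Both contributions come out at 0.05/5.01 ≈ 0.00998. The soft voice has been turned up exactly to the level of the loud one.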
The K-Task Formula
For K tasks, the same idea generalises (paper eq. 4):
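$$\lambda_i = \frac{\sum_j g_j - g_i}{(K-1)\sum_j g_j}$$

The numerator $\sum_j g_j - g_i$ is the sum of the OTHER tasks' gradient norms. Tasks with small gradients have a large 'sum of others' and therefore get large weights. Setting $K = 2$ recovers the closed form above. For $K > 2$ the products $\lambda_i g_i$ are no longer exactly equal, but the weights remain inversely ordered by gradient norm.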
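The first line below substitutes $K = 2$ into eq. 4 term by term and shows the reduction to the closed form. The second line works a small $K = 3$ case with illustrative norms $g = (1, 2, 3)$, chosen only for easy arithmetic:

$$
\lambda_{\text{rul}} = \frac{(g_{\text{rul}} + g_{\text{health}}) - g_{\text{rul}}}{(2 - 1)\,(g_{\text{rul}} + g_{\text{health}})} = \frac{g_{\text{health}}}{g_{\text{rul}} + g_{\text{health}}}
$$

$$
K = 3,\ \ g = (1, 2, 3),\ \ \textstyle\sum_j g_j = 6:\qquad
\lambda = \left(\tfrac{6-1}{2\cdot 6},\ \tfrac{6-2}{2\cdot 6},\ \tfrac{6-3}{2\cdot 6}\right) = \left(\tfrac{5}{12},\ \tfrac{4}{12},\ \tfrac{3}{12}\right),\qquad
\lambda_i g_i = \left(\tfrac{5}{12},\ \tfrac{8}{12},\ \tfrac{9}{12}\right)
$$

The weights sum to 1 and fall as $g_i$ rises, while the products $\lambda_i g_i$ differ.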
| Property | What it gives you | Why it matters |
|---|---|---|
| Sums to 1.0 | Σ λ_i = (K · S − S) / ((K-1) · S) = 1 | No need to manually re-normalise |
| Non-negative | All numerators ≥ 0 since g_i ≤ Σ g_j | No negative loss weights |
| Equal at balance | If all g_i equal, all λ_i = 1/K | Degenerates to vanilla equal-weight when nothing needs fixing |
| Inverse-monotonic | If g_i ↑ then λ_i ↓ | Larger-gradient task gets less weight - the whole point |
| Zero hyperparameters in raw form | No tuning constants - β, warmup, min_weight are stabilisers | Makes GABA hard to misconfigure |
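The table's properties are easy to confirm numerically. The sketch below assumes only NumPy. The helper name `gaba_raw` is hypothetical: it reimplements eq. 4 without the epsilon and checks the properties on random gradient norms spanning several decades.

```python
import numpy as np


def gaba_raw(g: np.ndarray) -> np.ndarray:
    """Hypothetical helper: raw K-task GABA weights, (S - g_i) / ((K - 1) * S)."""
    K = g.shape[0]                      # requires K >= 2
    S = g.sum()
    return (S - g) / ((K - 1) * S)


rng = np.random.default_rng(0)
for _ in range(1000):
    K = int(rng.integers(2, 8))                     # 2..7 tasks
    g = 10.0 ** rng.uniform(-3, 2, size=K)          # norms between 1e-3 and 1e2
    lam = gaba_raw(g)
    assert np.isclose(lam.sum(), 1.0)               # sums to 1
    assert np.all(lam >= 0.0)                       # non-negative
    order = np.argsort(g)                           # ascending gradient norm ...
    assert np.all(np.diff(lam[order]) <= 1e-12)     # ... gives non-increasing weight

# Equal at balance: identical norms give 1/K each.
assert np.allclose(gaba_raw(np.full(4, 0.7)), 0.25)
print("all properties hold")
```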
Interactive: Drag Gradients, Watch Weights
Drag the two log-scale sliders. The top bar plot shows the gradient norms on a log axis (real units span 5 decades); the right plot shows the resulting raw GABA weights on a linear axis. The bottom line plot traces 200 simulated training steps with the paper's EMA(β = 0.99) smoothing and 100-step warmup.
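The bottom plot's trajectory can be approximated offline. This sketch assumes the widget follows the same update rule as the PyTorch GABALoss later in this section: equal weights and a frozen EMA for steps 1 to 100, then ema ← β·ema + (1 − β)·raw with β = 0.99. It also holds the gradient norms fixed. The function name `simulate_ema` is hypothetical.

```python
import numpy as np


def simulate_ema(g: np.ndarray, steps: int = 200, warmup: int = 100,
                 beta: float = 0.99) -> np.ndarray:
    """Hypothetical helper: EMA-smoothed GABA weights per step for fixed norms g."""
    K = g.shape[0]
    raw = (g.sum() - g) / ((K - 1) * g.sum())   # raw K-task GABA weights
    ema = np.full(K, 1.0 / K)                   # buffer starts uniform
    history = np.empty((steps, K))
    for step in range(1, steps + 1):
        if step > warmup:                       # EMA only moves after warmup
            ema = beta * ema + (1.0 - beta) * raw
        history[step - 1] = ema
    return history


hist = simulate_ema(np.array([100.0, 0.001]))   # (g_rul, g_health)
print(hist[99])    # step 100, last warmup step: still uniform
print(hist[-1])    # step 200, after 100 adaptive steps
```

Under these assumptions, λ_rul is still about 0.18 at step 200, even with the extreme slider pair. Each adaptive step closes only 1% of the gap to the raw weight, so 100 steps close 1 − 0.99^100 ≈ 63% of it.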
Try this. Set both gradients equal (say, both at 1.0). The bars become identical and $\lambda_{\text{rul}} = \lambda_{\text{health}} = 0.5$: GABA degenerates to vanilla equal weighting when there is nothing to balance. Now move g_rul to 100 and g_health to 0.001. Watch λ_health snap to ≈ 1.0 and λ_rul to ≈ 0.0. That extreme ratio motivates the `min_weight = 0.05` floor we add in code.
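The floor and re-normalisation can be sketched in NumPy. One consequence is worth knowing: because the clamped vector is re-normalised, the floored weight ends up slightly below `min_weight`. The helper `floor_and_renormalise` is a hypothetical name mirroring the clamp-then-divide lines of GABALoss. GABALoss applies these lines to the EMA-smoothed weights; here they are applied to raw weights to show the extreme slider case.

```python
import numpy as np


def floor_and_renormalise(w: np.ndarray, min_weight: float = 0.05) -> np.ndarray:
    """Hypothetical helper: clamp each weight at min_weight, then rescale to sum to 1."""
    floored = np.maximum(w, min_weight)
    return floored / floored.sum()


g = np.array([100.0, 0.001])       # extreme slider setting: (g_rul, g_health)
raw = (g.sum() - g) / g.sum()      # K = 2 GABA weights, since (K - 1) = 1
print(raw)                         # lambda_rul close to 1e-5, lambda_health close to 1
print(floor_and_renormalise(raw))  # lambda_rul about 0.0476, not exactly 0.05
```

For K = 2, suppose the small weight w has dropped below the floor. The weight actually applied is then 0.05 / (1.05 − w), which tends to 0.05/1.05 ≈ 0.0476 as w → 0.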
Python: GABA Weights From Gradient Norms
Implement the K-task formula in pure NumPy. Three lines of meaningful code, plus a smoke test on the realistic 500× ratio. The closed-form K=2 specialisation is verified to match the K-task formula exactly.
- `import numpy as np`: NumPy is the numerical computing library used here for vectorised K-task arithmetic. The K-task GABA formula is naturally a single line of vector ops (subtract, scale, divide), exactly NumPy's sweet spot. All numerics run as compiled C, so even a K=1000 task vector evaluates in microseconds.
- `def gaba_weights(grad_norms)`: The core GABA weight function. Takes a vector of K per-task gradient norms and returns a vector of K weights that sum to 1.0. The whole GABA idea fits in three lines of body code.
- Docstring: Documents the formula (paper eq. 4), the central insight (smaller gradient → larger weight), the input shape (K,), and the output shape (K,). Stored on the function as `__doc__` and shown by `help(gaba_weights)`.
- `K = grad_norms.shape[0]`: Reads the number of tasks from the input array shape. `.shape` returns a tuple of dimension sizes; for a 1-D array of length K, `.shape == (K,)` and `.shape[0] == K`.
- `total = grad_norms.sum() + 1e-12`: Sums all K per-task gradient norms into a scalar Σ g_j, then adds a tiny epsilon so the next line can never divide by zero (for K ≥ 2).
- `weights = (total - grad_norms) / ((K - 1) * total)`: The K-task GABA formula in one vectorised line. The numerator is the sum-of-OTHER-tasks gradient norm; the denominator normalises so the weights sum to 1.0.
- `return weights`: Returns the K-vector of GABA weights. The caller multiplies element-wise by the per-task losses and sums to get the combined loss.
- Smoke-test divider comment: Below this point we exercise the function with measured (not synthetic) gradient norms from the C-MAPSS dataset.
- Section 12.3 comment: Cross-reference. §12.3 measured per-task gradient norms across all four C-MAPSS subsets and found a median ratio of ~500×, which is what motivates GABA in the first place.
- `g_rul = 5.000`: Realistic RUL gradient norm at training initialisation. Not a synthetic value: this is the median measured across FD001-004 in §12.3.
- `g_health = 0.010`: Realistic health-classification gradient norm at initialisation. About 500× smaller than the RUL norm: the asymmetry GABA exists to fix.
- `grad_norms = np.array([g_rul, g_health])`: Packs the two scalars into a 1-D ndarray of length 2 so the K-task function can consume them.
- `weights = gaba_weights(grad_norms)`: Applies the GABA weight function defined above. The interesting computation happens inside; the result is a 2-vector.
- The `g_rul` print: Pretty-prints g_rul. The `:.3f` format spec means "float, 3 digits after the decimal".
- The `g_health` print: Same pattern as the `g_rul` print, a 3-decimal float for g_health.
- The `ratio` print: Computes and prints the gradient ratio. `:.1f` means 1 decimal place. The trailing `x` is literal text (the multiplier suffix).
- The `lambda_rul` print: Prints λ_rul = `weights[0]`. `:.4f` gives a 4-decimal float, needed because the value is small (0.002).
- The `lambda_health` print: Prints λ_health = `weights[1]`.
- The `sum` print: Sanity check: Σ_i λ_i must equal 1.0 (the epsilon shifts it by a negligible ~1e-13). If not, the formula is broken.
- Closed-form divider comment: Below this divider we verify that the K-task formula reduces to the simple K=2 closed form derived in "Inverse-Gradient Intuition" above.
- `expected_rul = g_health / (g_rul + g_health)`: Closed-form K=2 specialisation. Setting K=2 in the K-task formula gives λ_rul = (Σ g_j − g_rul) / ((2−1) · Σ g_j) = g_health / (g_rul + g_health).
- `expected_health = g_rul / (g_rul + g_health)`: K=2 specialisation for health. By symmetry, λ_health = g_rul / (g_rul + g_health).
- The closed-form `lambda_rul` print: Prints the closed-form RUL weight. The leading `\n` inserts a blank line, visually separating the closed-form block from the K-task block above.
- The closed-form `lambda_health` print: Prints the closed-form health weight. It equals `weights[1]` from the K-task formula, QED.
```python
import numpy as np


def gaba_weights(grad_norms: np.ndarray) -> np.ndarray:
    """Compute GABA inverse-gradient task weights for K tasks.

    Formula (paper eq. 4):
        lambda_i = (sum_j g_j - g_i) / ((K - 1) * sum_j g_j)

    Tasks with SMALLER gradients get LARGER weights, because the
    backbone is already being pulled hard by the larger-gradient task.

    Args:
        grad_norms: shape (K,) - per-task L2 gradient norms on shared params.

    Returns:
        weights: shape (K,) - non-negative, sum to 1.0.
    """
    K = grad_norms.shape[0]                             # number of tasks (K >= 2)
    total = grad_norms.sum() + 1e-12                    # epsilon guards against 0/0
    weights = (total - grad_norms) / ((K - 1) * total)  # sum-of-others / normaliser
    return weights


# ---------- Smoke test on real predictive-maintenance numbers ----------
# From section 12.3: g_rul is ~500x g_health on shared backbone params at init.
g_rul = 5.000     # paper-realistic RUL gradient norm
g_health = 0.010  # paper-realistic health gradient norm
grad_norms = np.array([g_rul, g_health])

weights = gaba_weights(grad_norms)

print(f"g_rul = {g_rul:.3f}")
print(f"g_health = {g_health:.3f}")
print(f"ratio = {g_rul / g_health:.1f}x")
print(f"lambda_rul = {weights[0]:.4f}")
print(f"lambda_health = {weights[1]:.4f}")
print(f"sum = {weights.sum():.4f}")

# ---------- Verify the closed-form K=2 simplification ----------
expected_rul = g_health / (g_rul + g_health)
expected_health = g_rul / (g_rul + g_health)
print(f"\nclosed-form lambda_rul = {expected_rul:.4f}")
print(f"closed-form lambda_health = {expected_health:.4f}")
```

Output:

```text
g_rul = 5.000
g_health = 0.010
ratio = 500.0x
lambda_rul = 0.0020
lambda_health = 0.9980
sum = 1.0000

closed-form lambda_rul = 0.0020
closed-form lambda_health = 0.9980
```

PyTorch: GABALoss From The Paper
Now the production version - the real GABALoss from paper_ieee_tii/grace/core/gaba.py. EMA buffers, warmup window, gradient computation via torch.autograd.grad, min_weight floor, and re-normalisation. Smoke test runs five training-style iterations on a tiny 14-dim toy backbone so the full machinery executes.
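GABALoss leans on two behaviours of `torch.autograd.grad`. First, it returns gradients without accumulating them into `.grad`. Second, with `retain_graph=True` the same graph can be differentiated again. The self-contained check below runs on a toy two-head model whose layer sizes and losses are arbitrary assumptions. It also shows the crash that happens when `retain_graph` is left at its default.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
shared = nn.Linear(4, 8)      # toy shared layer
head_a = nn.Linear(8, 1)      # toy task-A head
head_b = nn.Linear(8, 1)      # toy task-B head
params = list(shared.parameters())

x = torch.randn(16, 4)
feat = shared(x)
loss_a = head_a(feat).pow(2).mean()
loss_b = head_b(feat).abs().mean()

# Per-task gradients on the shared layer only; retain_graph keeps the graph alive.
grads_a = torch.autograd.grad(loss_a, params, retain_graph=True)
grads_b = torch.autograd.grad(loss_b, params, retain_graph=True)
grad_untouched = all(p.grad is None for p in params)   # .grad was not written

norm_a = torch.sqrt(sum(g.pow(2).sum() for g in grads_a))
norm_b = torch.sqrt(sum(g.pow(2).sum() for g in grads_b))
print(grad_untouched, f"||g_a|| = {norm_a.item():.4f}, ||g_b|| = {norm_b.item():.4f}")

# The graph is still intact, so the usual combined backward works afterwards.
(0.5 * loss_a + 0.5 * loss_b).backward()

# Without retain_graph, the first call frees the shared layer's saved tensors.
feat2 = shared(x)
l1 = head_a(feat2).pow(2).mean()
l2 = head_b(feat2).abs().mean()
torch.autograd.grad(l1, params)
try:
    torch.autograd.grad(l2, params)   # needs the freed tensors
    crashed = False
except RuntimeError:
    crashed = True
print("second call without retain_graph crashed:", crashed)
```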
- `import torch`: Core PyTorch package. Provides the Tensor class, the autograd engine, and device abstractions. Everything in this file depends on torch.
- `import torch.nn as nn`: PyTorch's neural-network module, aliased to `nn` for brevity. Provides `nn.Module` (the base class for all stateful components), layer primitives like `nn.Linear`, and the `nn.functional` submodule for stateless ops.
- `from typing import List, Optional`: Type-hint aliases from the standard library. `List[X]` denotes a list of X; `Optional[X]` means "X or None". Used in the `forward()` signature to document that `shared_params` can be a list of Parameters or None.
- `class GABALoss(nn.Module)`: Defines GABALoss as a subclass of `nn.Module`. This makes it a stateful PyTorch component that can register buffers (EMA state, step count), which move with `.to(device)` and serialise into `state_dict`. This matches the canonical paper code in grace/core/gaba.py.
- Class docstring: Identifies this as the canonical paper implementation. It should match the production file line-for-line, so anything debugged here applies to production.
- `def __init__(...)`: Constructor. All four hyperparameters are paper defaults; only `n_tasks` needs changing for K ≠ 2.
- `super().__init__()`: Calls `nn.Module`'s constructor. This is required because it sets up the internal dicts that `register_buffer()` and `register_parameter()` write to. Forgetting it raises AttributeError on the next `register_buffer` call.
- `self.beta = beta`: Stores the EMA coefficient as an instance attribute so `forward()` can read it.
- `self.warmup_steps = warmup_steps`: Stores the warmup step count.
- `self.min_weight = min_weight`: Stores the per-task weight floor.
- `self.n_tasks = n_tasks`: Stores K. The same `n_tasks` value sizes the EMA weights buffer below.
- Buffer comment ("EMA-smoothed weights start at uniform 1/K each"): Marks the buffer-registration block. Buffers are persistent tensor state that move with `.to(device)` and save/load via `state_dict`.
- `self.register_buffer("ema_weights", ...)`: Registers a non-learnable tensor as a buffer. EMA weights are persistent state: they don't receive gradients (so they're not Parameters), but they DO need to move to GPU and serialise with the model.
- `self.register_buffer("step_count", ...)`: Registers the warmup step counter as a buffer. `dtype=torch.long` is a 64-bit integer, which never overflows for any realistic training horizon.
- `def forward(...)`: Called once per training step. Takes the two per-task scalar losses plus the list of shared backbone parameters. Returns the GABA-weighted combined loss tensor, on which the caller can call `.backward()`.
- `K = 2`: Hard-coded K=2 for this RUL+health convenience overload. The general K-task version lives in `forward_k()` in the paper file.
- `device = rul_loss.device`: Reads the device (cpu/cuda/mps) from the first loss so new tensors are built on the same device. Mixing devices raises a runtime error.
- `self.step_count += 1`: Increments the warmup counter in place. PyTorch supports in-place `+=` on tensors, which modifies the buffer storage directly.
- `losses = [rul_loss, health_loss]`: Wraps the two losses into a list so the per-task loop can iterate uniformly. Order matters: index 0 = RUL throughout.
- Warmup comment: Introduces the warmup branch. During warmup, gradient norms are too noisy to trust, so we fall back to uniform weighting.
- `if shared_params is None or ...`: Branches into equal weighting if either (a) the caller didn't pass `shared_params` (forcing the skip), OR (b) we're still inside the warmup window.
- `weights = torch.ones(K, device=device) / K`: Builds the equal-weights vector [1/K, ..., 1/K] on the right device. For K=2 that's [0.5, 0.5].
- `else:`: Adaptive branch. Warmup is done AND `shared_params` was passed, so we measure per-task gradients and compute GABA weights.
- Gradient-norm comment: Introduces the gradient-measurement block.
- `grad_norms = torch.zeros(K, device=device)`: Pre-allocates the gradient-norm vector, which is filled in entry by entry inside the loop.
- `for i, loss_i in enumerate(losses)`: Loops over the K=2 tasks. `enumerate` yields (index, value) pairs so we can write back into `grad_norms[i]`.
- `torch.autograd.grad(...)`: Computes ∂loss_i / ∂shared_params functionally, WITHOUT writing into the `.grad` attributes. This is critical for 'peeking' at per-task gradients without contaminating the main backward pass that the optimizer relies on. `retain_graph=True` keeps the graph alive for the next task and for the caller's final `.backward()`.
- `grad_norms[i] = ...`: Computes the global L2 norm of the gradient by summing squared per-parameter norms and then taking the square root. This is equivalent to flattening all grads into one giant vector and calling `.norm()`, but cheaper because it skips the concatenation. If no shared parameter receives a gradient, `sum()` returns the Python int 0 and `.sqrt()` fails, so the code assumes every task reaches the backbone.
- `total_norm = grad_norms.sum() + 1e-12`: Sums the K=2 task gradient norms and adds epsilon for divide-by-zero safety (mirrors the NumPy version).
- `raw_weights = ...`: The K-task GABA formula in vectorised PyTorch. The math is identical to the NumPy version; broadcasting handles the scalar−vector subtraction and vector/scalar division.
- EMA comment: Introduces the EMA-smoothing block. Without smoothing, weights would lurch around with each batch's gradient noise.
- `ema_w = self.beta * self.ema_weights + (1.0 - self.beta) * raw_weights`: Exponential moving average update. With β = 0.99 the new value is 99% old + 1% current, giving a slow, stable response to batch-to-batch gradient noise.
- `self.ema_weights = ema_w.detach()`: Saves the new EMA weights into the buffer. The `.detach()` guarantees the buffer never carries an autograd graph from one step to the next. Here the norms are already detached, so it acts as a safeguard. If any upstream term did carry history, the buffer would chain a graph that grows by one step every iteration, eventually exhausting GPU memory.
- Floor comment: Introduces the safety net: clamp small weights up to `min_weight`, then renormalise so they sum to 1 again.
- `weights = ema_w.clamp(min=self.min_weight)`: Floors each weight at `min_weight` = 0.05. Prevents any task from being driven to zero (which could happen at the K=2 limit when one task's gradient norm is huge).
- `weights = weights / weights.sum()`: Renormalises so the floored weights sum back to 1.0. This is required because `clamp()` can push the sum above 1. As a side effect, a floored weight ends up slightly below `min_weight` after the division.
- Combine comment: Marks the loss-combination block. Below it: the weighted sum Σ λ_i · L_i, returned as the final scalar.
- `total = torch.zeros((), device=device)`: Initialises the combined-loss accumulator as a 0-dim zero tensor on the right device. The empty tuple `()` specifies a scalar shape.
- `for w, l in zip(weights, losses)`: Iterates task weight and task loss in lockstep. `zip` yields tuples (w_i, loss_i) for i = 0..K-1.
- `total = total + w * l`: Accumulates one weighted task-loss term. Out-of-place addition builds a fresh tensor each time instead of mutating the accumulator, the safer idiom when autograd is tracking the result.
- `return total`: Returns the GABA-weighted combined loss. The caller calls `.backward()` on it to update both heads and the shared backbone in one go.
- Smoke-test divider: Below this comment we drive GABALoss with a tiny toy backbone for 5 steps to verify the full machinery executes (warmup, `autograd.grad`, EMA, clamp, renorm).
- `torch.manual_seed(0)`: Sets the global PyTorch PRNG seed for reproducibility. Affects every subsequent `torch.randn` / `torch.rand` / `torch.randint` and the random init of `nn.Linear` weights.
- `backbone = nn.Linear(14, 32)`: Tiny shared backbone for the smoke test that maps 14 sensor inputs → 32 hidden features. It mimics §11's feature dimension after a CNN encoder, but is linear so the test runs in milliseconds.
- `rul_head = nn.Linear(32, 1)`: RUL regression head: 32 features → 1 scalar prediction.
- `hp_head = nn.Linear(32, 3)`: Health classification head: 32 features → 3 logits (3 health states).
- `gaba = GABALoss(...)`: Instantiates GABA. We override `warmup_steps` from 100 (paper default) to 2 so the adaptive code path fires within the 5-step smoke test.
- `x = torch.randn(64, 14)`: Synthetic input: a 64-sample batch of 14-dim sensor readings drawn from the standard normal.
- `rul_target = torch.rand(64, 1) * 125.0`: Synthetic RUL targets in [0, 125) cycles. `torch.rand` is uniform on [0, 1), and we scale by the C-MAPSS clip value.
- `hp_target = torch.randint(0, 3, (64,))`: Synthetic 3-class health labels (0, 1 or 2) for the batch.
- `for step in range(5)`: Five training-style iterations. Printed steps 0-1 are warmup (equal weights); steps 2-4 use adaptive GABA. Note that this loop has NO `backward()` or `optimizer.step()`, so model params and losses don't change between iterations; the only thing drifting is the EMA buffer.
- `feat = backbone(x)`: Forward through the shared backbone: (64, 14) → (64, 32). Calling the module runs its forward pass and records the autograd graph.
- `rul_pred = rul_head(feat)`: RUL head: (64, 32) → (64, 1). Raw scalar predictions, no activation.
- `hp_logits = hp_head(feat)`: Health head: (64, 32) → (64, 3). Raw 3-class logits; the softmax happens inside `cross_entropy`.
- `rul_loss = ((rul_pred - rul_target) ** 2).mean()`: Plain MSE on RUL: element-wise residual, square, mean over the 64-sample batch.
- `health_loss = nn.functional.cross_entropy(...)`: Standard 3-class cross-entropy. Combines log_softmax and nll_loss in one numerically stable call.
- `total = gaba(...)`: Applies the GABA-weighted combination. `shared_params` is the explicit list of `nn.Parameter`s in the BACKBONE only (not the heads). These are the parameters whose gradient norms drive the per-task weighting.
- First `print` line: Pretty-prints the per-step metrics. `.item()` extracts Python scalars; `:7.2f` means width 7, 2 decimals.
- Second f-string line: Continuation of the print statement (Python implicit string concatenation). Reads the EMA buffer directly to expose the current λ values. At these values the floor is inactive, so they equal the weights actually applied.
```python
import torch
import torch.nn as nn
from typing import List, Optional


class GABALoss(nn.Module):
    """Gradient-Aware Balanced Adaptation - paper code (grace/core/gaba.py).

    EMA-smoothed inverse-gradient weights with warmup + min_weight floor.
    """

    def __init__(self, beta: float = 0.99, warmup_steps: int = 100,
                 min_weight: float = 0.05, n_tasks: int = 2) -> None:
        super().__init__()
        self.beta = beta
        self.warmup_steps = warmup_steps
        self.min_weight = min_weight
        self.n_tasks = n_tasks

        # EMA-smoothed weights start at uniform 1/K each.
        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
        self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))

    def forward(self, rul_loss: torch.Tensor, health_loss: torch.Tensor,
                shared_params: Optional[List[nn.Parameter]] = None) -> torch.Tensor:
        K = 2
        device = rul_loss.device
        self.step_count += 1
        losses = [rul_loss, health_loss]

        # During warmup -> equal weighting.
        if shared_params is None or self.step_count.item() <= self.warmup_steps:
            weights = torch.ones(K, device=device) / K
        else:
            # Per-task gradient norms on shared params.
            grad_norms = torch.zeros(K, device=device)
            for i, loss_i in enumerate(losses):
                # retain_graph=True: the graph is reused by the next task and the caller's backward().
                grads = torch.autograd.grad(loss_i, shared_params, retain_graph=True,
                                            create_graph=False, allow_unused=True)
                # Global L2 norm; assumes each task reaches at least one shared param.
                grad_norms[i] = sum((g.detach().norm() ** 2 for g in grads if g is not None)).sqrt()

            total_norm = grad_norms.sum() + 1e-12
            raw_weights = (total_norm - grad_norms) / ((K - 1) * total_norm)

            # EMA smoothing.
            ema_w = self.beta * self.ema_weights + (1.0 - self.beta) * raw_weights
            self.ema_weights = ema_w.detach()

            # Floor + re-normalise.
            weights = ema_w.clamp(min=self.min_weight)
            weights = weights / weights.sum()

        # Combine.
        total = torch.zeros((), device=device)
        for w, l in zip(weights, losses):
            total = total + w * l
        return total


# ---------- Smoke test ----------
torch.manual_seed(0)
backbone = nn.Linear(14, 32)  # tiny shared backbone
rul_head = nn.Linear(32, 1)   # task 1
hp_head = nn.Linear(32, 3)    # task 2
gaba = GABALoss(beta=0.99, warmup_steps=2, min_weight=0.05, n_tasks=2)

x = torch.randn(64, 14)
rul_target = torch.rand(64, 1) * 125.0
hp_target = torch.randint(0, 3, (64,))

for step in range(5):
    feat = backbone(x)
    rul_pred = rul_head(feat)
    hp_logits = hp_head(feat)
    rul_loss = ((rul_pred - rul_target) ** 2).mean()
    health_loss = nn.functional.cross_entropy(hp_logits, hp_target)

    total = gaba(rul_loss, health_loss, shared_params=list(backbone.parameters()))
    print(f"step {step} | rul_loss = {rul_loss.item():7.2f} | health_loss = {health_loss.item():.4f}"
          f" | lambda = ({gaba.ema_weights[0]:.4f}, {gaba.ema_weights[1]:.4f})")
```

Output:

```text
step 0 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.5000, 0.5000)
step 1 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.5000, 0.5000)
step 2 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.4950, 0.5050)
step 3 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.4901, 0.5099)
step 4 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.4853, 0.5147)
```

Inverse-Gradient In Other Domains
The inverse-gradient idea generalises. Anywhere two or more objectives share parameters and have unequal gradient magnitudes, GABA-style weighting applies:
| Domain | Tasks sharing parameters | Why GABA helps |
|---|---|---|
| RUL prediction (this book) | RUL regression + health classification | 500× gradient imbalance at init |
| Multi-modal speech (Whisper-style) | ASR transcription + language detection + speaker ID | Cross-entropy heads have unequal label vocab |
| Object detection (Detectron2) | Bounding-box regression + class CE + objectness logit | Bbox L1 loss dwarfs CE when boxes are big |
| Recommender systems (multi-head DLRM) | Click prediction + dwell time + share/save | Click logit gradient dwarfs dwell-time MSE |
| Robotics (multi-task imitation) | End-effector position + gripper open/close + force control | Force MSE gradient is tiny vs position MSE |
| Drug discovery (molecular property) | logP regression + toxicity classification + binding score | Each property has different label scale |
In every row, the larger-gradient task otherwise dominates the shared trunk; inverse-gradient weighting equalises the backbone update so all tasks contribute comparably.
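For readers porting the idea to one of these settings, the weighting step can be written for any K. What follows is a hypothetical sketch, not the paper's `forward_k()`. The function name `gaba_combine`, the three toy heads (loosely standing in for box regression, class CE and objectness), their sizes and targets, and the omission of EMA, warmup and the floor are all assumptions made to keep it short.

```python
import torch
import torch.nn as nn
from typing import List, Tuple


def gaba_combine(losses: List[torch.Tensor], shared_params: List[nn.Parameter]
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical K-task sketch of eq. 4: raw inverse-gradient weights, no EMA/warmup/floor."""
    K = len(losses)  # needs K >= 2
    norms = []
    for loss in losses:
        grads = torch.autograd.grad(loss, shared_params, retain_graph=True,
                                    allow_unused=True)
        # Start from a zero tensor so a task that skips the trunk still yields a tensor.
        sq = sum((g.pow(2).sum() for g in grads if g is not None), loss.new_zeros(()))
        norms.append(sq.sqrt())
    g = torch.stack(norms)                       # (K,) per-task gradient norms
    total = g.sum() + 1e-12
    weights = (total - g) / ((K - 1) * total)    # larger norm -> smaller weight
    combined = sum(w * l for w, l in zip(weights, losses))
    return combined, weights, g


# Toy three-head setup (arbitrary sizes and targets, for illustration only).
torch.manual_seed(0)
trunk = nn.Linear(10, 16)
box_head, cls_head, obj_head = nn.Linear(16, 4), nn.Linear(16, 5), nn.Linear(16, 1)

x = torch.randn(32, 10)
feat = trunk(x)
box_loss = (box_head(feat) - 50.0).abs().mean()          # L1 against large targets
cls_loss = nn.functional.cross_entropy(cls_head(feat), torch.randint(0, 5, (32,)))
obj_loss = nn.functional.binary_cross_entropy_with_logits(
    obj_head(feat).squeeze(1), torch.rand(32).round())

combined, weights, norms = gaba_combine([box_loss, cls_loss, obj_loss],
                                        list(trunk.parameters()))
combined.backward()   # the graph was retained, so this still works
print("norms  :", norms)
print("weights:", weights)
```

The task with the largest trunk gradient norm always receives the smallest weight. In a real port you would still add the EMA, warmup and floor stabilisers from GABALoss.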
Three Pitfalls Of Naive Implementations
- `retain_graph=True` on `torch.autograd.grad`. We compute K per-task gradients on the same forward graph, then backward through the weighted combination. Without `retain_graph=True`, the first gradient call frees the graph's saved tensors, so the second per-task call (or the final backward) raises a RuntimeError.
- `.detach()` on the EMA buffer. If `self.ema_weights` retains autograd history across steps, every `backward()` walks all the way back to epoch 0. Memory grows unboundedly until OOM. Always write `self.ema_weights = ema_w.detach()`.

Takeaway
- Two voices, one microphone. Multi-task learning is dominated by the loudest gradient unless you deliberately equalise.
- K-task GABA formula: $\lambda_i = (\sum_j g_j - g_i) / ((K-1)\sum_j g_j)$. Larger gradient ⇒ smaller weight.
- Closed-form K=2: $\lambda_{\text{rul}} = g_{\text{health}} / (g_{\text{rul}} + g_{\text{health}})$. Drops out of the K-task formula by setting K=2.
- Implementation: 3 lines of NumPy or 50 lines of PyTorch. NumPy is for understanding; PyTorch adds EMA + warmup + min_weight + buffer state.
- Zero hyperparameters in the raw formula. Stabilisers (β, warmup, min_weight) have safe defaults; you almost never need to tune them per dataset.
- Generalises beyond RUL. Anywhere shared parameters serve unequal-gradient objectives, GABA-style inverse weighting applies.