Chapter 17

Equalizing Task Contributions

Inverse-Gradient Balancing: The Idea

Two Voices, One Microphone

Imagine two presenters sharing one microphone. The first presenter has a booming voice; the second is soft-spoken. Whoever speaks louder dominates the room, regardless of what they say. The fair fix is not to gag the loud one - it is to TURN UP the soft one. Multi-task learning has the same problem at the gradient level: whichever task pulls hardest on the shared backbone wins, regardless of which task is more important.

§12 measured the asymmetry: on FD002 at initialisation, RUL regression has a gradient norm of ≈ 5.0 on the shared backbone, while health classification has only ≈ 0.01. That is a 500× imbalance. Health is ‘turned off’ not because the task is unimportant but because its gradient cannot compete with the loudness of the regression head.

The GABA idea in one line. Each step, compute per-task gradient norms; give MORE weight to the task with the SMALLER gradient. Smooth with EMA. Done.

Inverse-Gradient Intuition

For two tasks, RUL and health, we want a weight pair $(\lambda_{\text{rul}}, \lambda_{\text{health}})$ with $\lambda_{\text{rul}} + \lambda_{\text{health}} = 1$ such that the EFFECTIVE backbone update from each task is balanced. The effective contribution of task $i$ to the backbone is $\lambda_i \cdot g_i$. Setting the two contributions equal:

$$\lambda_{\text{rul}} \cdot g_{\text{rul}} = \lambda_{\text{health}} \cdot g_{\text{health}}, \qquad \lambda_{\text{rul}} + \lambda_{\text{health}} = 1.$$

Solving:

$$\lambda_{\text{rul}} = \frac{g_{\text{health}}}{g_{\text{rul}} + g_{\text{health}}}, \qquad \lambda_{\text{health}} = \frac{g_{\text{rul}}}{g_{\text{rul}} + g_{\text{health}}}.$$

For our 500× ratio: $\lambda_{\text{rul}} = 0.01/5.01 \approx 0.002$ and $\lambda_{\text{health}} \approx 0.998$. Health gets nearly all the weight - because RUL was already winning the gradient race.
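A minimal numerical check of the closed form in plain Python, using the §12 norms (5.0 and 0.01): the two effective contributions $\lambda_i \cdot g_i$ come out identical, which is exactly the condition imposed above.

```python
# Two-task closed-form GABA weights on the §12 gradient norms.
g_rul, g_health = 5.0, 0.01

lam_rul = g_health / (g_rul + g_health)
lam_health = g_rul / (g_rul + g_health)

# Effective backbone contribution of each task: lambda_i * g_i.
contrib_rul = lam_rul * g_rul
contrib_health = lam_health * g_health

print(f"lambda_rul    = {lam_rul:.4f}")      # 0.0020
print(f"lambda_health = {lam_health:.4f}")   # 0.9980
print(f"contributions = {contrib_rul:.5f} vs {contrib_health:.5f}")  # 0.00998 vs 0.00998
```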

The K-Task Formula

For K tasks, the same idea generalises (paper eq. 4):

$$\lambda_i = \frac{\sum_j g_j - g_i}{(K - 1) \sum_j g_j}$$

The numerator $\sum_j g_j - g_i$ is the sum of the OTHER tasks' gradient norms. Tasks with small gradients have a large ‘sum-of-others’ and therefore get large weights. Setting $K=2$ recovers the closed form above.
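Spelled out for $K = 2$, where the sum contains only the two tasks:

$$\lambda_{\text{rul}} = \frac{(g_{\text{rul}} + g_{\text{health}}) - g_{\text{rul}}}{(2 - 1)(g_{\text{rul}} + g_{\text{health}})} = \frac{g_{\text{health}}}{g_{\text{rul}} + g_{\text{health}}}, \qquad \lambda_{\text{health}} = \frac{(g_{\text{rul}} + g_{\text{health}}) - g_{\text{health}}}{(2 - 1)(g_{\text{rul}} + g_{\text{health}})} = \frac{g_{\text{rul}}}{g_{\text{rul}} + g_{\text{health}}}.$$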

| Property | What it gives you | Why it matters |
|---|---|---|
| Sums to 1.0 | Σ λ_i = (K · S − S) / ((K − 1) · S) = 1, where S = Σ_j g_j | No need to manually re-normalise |
| Non-negative | All numerators ≥ 0, since g_i ≤ Σ_j g_j | No negative loss weights |
| Equal at balance | If all g_i are equal, all λ_i = 1/K | Degenerates to vanilla equal weighting when nothing needs fixing |
| Inverse-monotonic | If g_i ↑ (others fixed) then λ_i ↓ | The larger-gradient task gets less weight - the whole point |
| Zero hyperparameters in raw form | No tuning constants; β, warmup and min_weight are stabilisers | Makes GABA hard to misconfigure |
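The properties in the table are easy to spot-check numerically. This sketch re-implements the formula as a small NumPy helper (the name gaba_weights mirrors the function defined later in this section; the random test vectors and seed are arbitrary choices) and asserts each property for K = 2, 3 and 5.

```python
import numpy as np


def gaba_weights(grad_norms):
    """K-task GABA weights, same formula as paper eq. 4 (no epsilon guard)."""
    g = np.asarray(grad_norms, dtype=float)
    total = g.sum()
    return (total - g) / ((len(g) - 1) * total)


rng = np.random.default_rng(0)
for K in (2, 3, 5):
    g = rng.uniform(0.01, 10.0, size=K)       # arbitrary positive norms
    w = gaba_weights(g)
    assert np.isclose(w.sum(), 1.0)            # sums to 1
    assert np.all(w >= 0)                      # non-negative
    # Inverse-monotonic: ordering tasks by gradient norm reverses the weight order.
    assert np.all(np.diff(w[np.argsort(g)]) <= 0)

print(gaba_weights([1.0, 1.0, 1.0]))    # equal at balance: every weight is 1/3
w3 = gaba_weights([5.0, 0.01, 1.0])
print(w3)                               # approx [0.084 0.499 0.417]
print(w3 * np.array([5.0, 0.01, 1.0]))  # effective contributions lambda_i * g_i
```

One consequence worth noticing: for the K = 3 example the effective contributions λ_i · g_i (about 0.42, 0.005 and 0.42) are not all equal. The K-task formula keeps the inverse ordering and the sum-to-one normalisation, while exact equalisation of contributions is a property of the K = 2 case.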

Interactive: Drag Gradients, Watch Weights

Drag the two log-scale sliders. The top bar plot shows the gradient norms on a log axis (real units span 5 decades); the right plot shows the resulting raw GABA weights on a linear axis. The bottom line plot traces 200 simulated training steps with the paper's EMA (β = 0.99) smoothing and 100-step warmup.

Try this. Set both gradients equal (say both at 1.0). The bars become identical and $\lambda_{\text{rul}} = \lambda_{\text{health}} = 0.5$ - GABA degenerates to vanilla equal weighting when there is nothing to balance. Now move g_rul to 100 and g_health to 0.001. Watch λ_health snap to ≈ 1.0 and λ_rul to ≈ 0.0. That extreme ratio motivates the min_weight = 0.05 floor we add in code.
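The bottom plot of the visualizer can be approximated offline. This is a hypothetical re-creation, not the widget's code: it assumes constant gradient norms, uniform weights during the 100-step warmup, an EMA (β = 0.99) that starts updating only after warmup, and a min_weight floor applied by clamping and then renormalising. The helper names gaba_weights and simulate are illustrative.

```python
import numpy as np


def gaba_weights(g):
    """Raw K-task GABA weights (paper eq. 4) with a small epsilon guard."""
    g = np.asarray(g, dtype=float)
    total = g.sum() + 1e-12
    return (total - g) / ((len(g) - 1) * total)


def simulate(g_rul, g_health, steps=200, beta=0.99, warmup=100, min_weight=0.05):
    """Hypothetical offline re-creation of the visualizer's weight trace.

    Assumptions (not taken from the widget): constant gradient norms,
    uniform weights during warmup, EMA updates only after warmup,
    and a floor applied by clamping followed by renormalisation.
    """
    K = 2
    ema = np.full(K, 1.0 / K)
    history = []
    for step in range(1, steps + 1):
        if step <= warmup:
            w = np.full(K, 1.0 / K)                    # equal weighting
        else:
            raw = gaba_weights([g_rul, g_health])
            ema = beta * ema + (1.0 - beta) * raw      # EMA smoothing
            w = np.clip(ema, min_weight, None)         # min_weight floor
            w = w / w.sum()                            # renormalise to sum 1
        history.append(w)
    return np.array(history)


hist = simulate(5.0, 0.01)
print(hist[99])    # last warmup step: [0.5 0.5]
print(hist[-1])    # after 100 adaptive steps: roughly [0.184 0.816]

extreme = simulate(100.0, 0.001, steps=2000)
print(extreme[-1])  # floor binds; renormalised weight lands near 0.048
```

With the §12 norms the RUL weight decays from 0.5 toward the raw value of about 0.002, reaching roughly 0.18 after 100 adaptive steps. In the extreme 100 vs 0.001 setting the floor binds; because renormalisation happens after clamping, the floored weight settles near 0.048 rather than exactly 0.05.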

Python: GABA Weights From Gradient Norms

Implement the K-task formula in pure NumPy. Three lines of meaningful code, plus a smoke test on the realistic 500× ratio. The closed-form K=2 specialisation is verified to match the K-task formula exactly.

gaba_weights() — pure NumPy K-task formula
gaba_weights_numpy.py
Line 1: `import numpy as np`

NumPy is the numerical computing library used here for vectorised K-task arithmetic. The K-task GABA formula is naturally a single line of vector ops (subtract, scale, divide) - exactly NumPy's sweet spot. All numerics run as compiled C, so even a K=1000 task vector evaluates in microseconds.

EXECUTION STATE
📚 numpy = Library for numerical computing - provides ndarray (N-dim arrays), broadcasting, element-wise math, linear algebra, and reductions. The ndarray dtype is fixed (float64 here), enabling vectorised C loops under the hood.
as np = Universal alias. We will write np.array(), grad_norms.sum(), etc. throughout this file. Without this alias we'd have to type numpy.array(...) every time.
→ why NumPy here = GABA needs Σ g_j (a reduction) and (Σ g_j − g) / ((K−1) Σ g_j) (broadcasting). Both are one-liners in NumPy, vs. an explicit Python loop with manual scalar accumulation.
Line 4: `def gaba_weights(grad_norms: np.ndarray) -> np.ndarray:`

The core GABA weight function. Takes a vector of K per-task gradient norms; returns a vector of K weights that sum to 1.0. The whole GABA idea fits in three lines of body code.

EXECUTION STATE
⬇ input: grad_norms (K,) = 1-D ndarray of L2 gradient norms, one per task. Each entry g_i = ||∂L_i / ∂θ_shared||_2 measured on the shared backbone parameters.
→ grad_norms purpose = Tells GABA how hard each task is pulling on the shared backbone. Large norm = this task is dominating the joint gradient and should get LESS weight; small norm = this task is being drowned out and needs MORE weight.
→ realistic value = [5.0, 0.01] - measured in §12.3 on FD002 at init. RUL regression norm is ~500× the health classification norm.
→ : np.ndarray annotation = Type hint declaring the expected input is a NumPy array. Python doesn't enforce this at runtime, but tools like mypy/pylint use it for static checks.
→ -> np.ndarray annotation = Type hint declaring the function returns a NumPy array. Documents the contract.
⬆ returns = 1-D ndarray of K weights, non-negative, summing to 1.0. weights[i] is the loss multiplier for task i.
Line 5 (docstring): `"""Compute GABA inverse-gradient task weights for K tasks."""`

Documents the formula (paper eq. 4), the central insight (smaller gradient → larger weight), the input shape (K,), and the output shape (K,). Stored on the function as __doc__ and shown by help(gaba_weights).

Line 19: `K = grad_norms.shape[0]`

Read the number of tasks from the input array shape. .shape returns a tuple of dimension sizes; for a 1-D array of length K, .shape == (K,) and .shape[0] == K.

EXECUTION STATE
📚 .shape = ndarray attribute: tuple of dimension sizes. For 2D matrix (3, 4) shape = (3, 4); for 1D vector (K,) shape = (K,). Reading .shape[0] picks out the first dimension.
→ example = np.array([5.0, 0.01]).shape → (2,); .shape[0] → 2
K = Integer. Number of tasks. K=2 for RUL+health in this book. The K-task formula generalises to any K ≥ 2.
→ numerical value here = K = 2 (RUL regression + health classification)
Line 20: `total = grad_norms.sum() + 1e-12`

Sum all K per-task gradient norms into a scalar Σ g_j, then add a tiny epsilon so we can never divide by zero in the next line.

EXECUTION STATE
📚 .sum() = ndarray method: sums all elements (no axis argument → reduce over every axis). Equivalent to np.sum(grad_norms). For a 1-D array, returns a NumPy scalar (np.float64 here).
→ example = np.array([5.0, 0.01]).sum() → 5.01
+ 1e-12 = Numerical safeguard (10⁻¹²). Without it, if every task had exactly zero gradient (perfect convergence) we'd divide by zero on the next line. The value is tiny enough not to perturb realistic numbers.
→ why this small = 1e-12 ≪ any realistic g_j (which are ≥ 1e-6 in practice), so it's invisible to real arithmetic but saves us from a NaN.
total = Scalar Σ g_j over all tasks (plus the eps).
→ numerical value = 5.0 + 0.01 + 1e-12 ≈ 5.010
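The epsilon guard is easy to probe in isolation. This sketch uses an illustrative variant, gaba_weights_eps, that exposes eps as a parameter (the original function hard-codes it). Without the epsilon, an all-zero gradient vector produces NaN; with it, the weights stay finite. Note the side effect: in that degenerate case each weight becomes $1/(K-1)$, so the weights sum to $K/(K-1)$ rather than 1.

```python
import numpy as np


def gaba_weights_eps(grad_norms, eps=1e-12):
    """Illustrative variant of gaba_weights with the epsilon exposed as a parameter."""
    K = grad_norms.shape[0]
    total = grad_norms.sum() + eps
    return (total - grad_norms) / ((K - 1) * total)


zeros = np.zeros(2)
with np.errstate(invalid="ignore"):            # silence the 0/0 warning
    no_eps = gaba_weights_eps(zeros, eps=0.0)
with_eps = gaba_weights_eps(zeros)

print(no_eps)     # [nan nan]
print(with_eps)   # [1. 1.]: finite, but sums to 2 rather than 1
print(gaba_weights_eps(np.zeros(3)).sum())   # 1.5 = K / (K - 1) for K = 3
```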
Line 21: `weights = (total - grad_norms) / ((K - 1) * total)`

The K-task GABA formula in one vectorised line. The numerator is the sum-of-OTHER-tasks gradient norm; the denominator normalises so weights sum to 1.0.

EXECUTION STATE
📚 broadcasting (scalar - vector) = NumPy auto-expands `total` (scalar) to match grad_norms (vector). The subtraction runs element-wise: result[i] = total - grad_norms[i].
total - grad_norms = = Σ_j g_j − g_i for each i = sum of gradient norms from OTHER tasks. Tasks with small g_i get a LARGE numerator; tasks with large g_i get a SMALL numerator. THIS is the inverse-gradient mechanic.
→ numerical value = [5.010 − 5.0, 5.010 − 0.01] = [0.010, 5.000]
(K - 1) * total = Normalisation constant chosen so the K weights sum to exactly (K-1)/(K-1) = 1. For K=2 this is just `total`; for K=3 it's 2·total; etc.
→ numerical value = (2 − 1) × 5.010 = 5.010
/ (vector / scalar) = Broadcasting again - each numerator element is divided by the same scalar denominator. Result is shape (K,).
weights = λ_i for each task. Verify: Σ_i λ_i = (K·S − S) / ((K−1)·S) = (K−1)·S / ((K−1)·S) = 1. ✓
→ numerical value = [0.010 / 5.010, 5.000 / 5.010] = [0.001996, 0.998004]
→ reading the result = λ_rul = 0.002, λ_health = 0.998. Health (small gradient) gets ~99.8% of the loss weight; RUL is suppressed because it was already winning the gradient race.
Line 22: `return weights`

Return the K-vector of GABA weights. Caller multiplies element-wise by per-task losses and sums to get the combined loss.

EXECUTION STATE
⬆ return: weights (K,) = [0.001996, 0.998004]. Sums to 1.0000. Use as: combined_loss = (weights * task_losses).sum().
Line 25: `# ---------- Smoke test on real predictive-maintenance numbers ----------`

Section divider comment. Below this point we exercise the function with measured (not synthetic) gradient norms from the C-MAPSS dataset.

Line 26: `# From section 12.3: g_rul is ~500x g_health on shared backbone params at init.`

Cross-reference. §12.3 measured per-task gradient norms across all four C-MAPSS subsets and found a median ratio of ~500× - which is what motivates GABA in the first place.

Line 27: `g_rul = 5.000 # paper-realistic RUL gradient norm`

Realistic RUL gradient norm at training initialisation. Not a synthetic value: this is the median measured across FD001-004 in §12.3.

EXECUTION STATE
g_rul = Scalar = 5.0. L2 norm of ∂(RUL MSE) / ∂θ_shared at init.
→ why so big = MSE on RUL ∈ [0, 125] cycles produces per-sample residuals of order 10-100. Squared, then differentiated, that gives ∂loss/∂params of order 5-10 in L2 norm.
Line 28: `g_health = 0.010 # paper-realistic health gradient norm`

Realistic health classification gradient norm at training initialisation. About 500× smaller than the RUL norm - the asymmetry GABA exists to fix.

EXECUTION STATE
g_health = Scalar = 0.01. L2 norm of ∂(health CE) / ∂θ_shared at init.
→ why so small = Cross-entropy gradient per logit is bounded by 1.0 (= |softmax_p − one_hot|). With only 3 classes and a backbone of millions of parameters, the contribution to each shared parameter is tiny.
→ 500× ratio = g_rul / g_health = 5.0 / 0.01 = 500. This is empirical, measured on real data - not a contrived example.
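The cross-entropy bound quoted above can be checked numerically. The gradient of softmax cross-entropy with respect to the logits is softmax(z) − one_hot(y), so every entry lies in [−1, 1] and its L2 norm is at most √2, whereas the squared-error gradient 2·(ŷ − y) grows with the residual. This sketch samples random 3-class logits (the logit scale and sample count are arbitrary choices) and records the largest values seen.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes = 3
max_entry, max_l2 = 0.0, 0.0

for _ in range(10_000):
    logits = rng.normal(scale=5.0, size=n_classes)
    p = np.exp(logits - logits.max())
    p /= p.sum()                               # softmax probabilities
    y = np.zeros(n_classes)
    y[rng.integers(n_classes)] = 1.0           # one-hot target
    grad = p - y                               # d(cross-entropy)/d(logits)
    max_entry = max(max_entry, float(np.abs(grad).max()))
    max_l2 = max(max_l2, float(np.linalg.norm(grad)))

# Contrast: squared-error gradient w.r.t. the prediction for a 100-cycle miss.
mse_grad = 2.0 * (0.0 - 100.0)                 # -200.0, grows with the residual

print(f"largest |dCE/dlogit| = {max_entry:.3f}")   # never above 1
print(f"largest L2 norm      = {max_l2:.3f}")      # never above sqrt(2) = 1.414
print(f"MSE gradient         = {mse_grad:.1f}")
```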
Line 29: `grad_norms = np.array([g_rul, g_health])`

Pack the two scalars into a 1-D ndarray of length 2 so the K-task function can consume them.

EXECUTION STATE
📚 np.array(list) = Builds an ndarray from a Python list/tuple. Infers dtype from the contents (float64 here, since both inputs are Python floats).
→ arg: [g_rul, g_health] = Python list of two floats: [5.0, 0.01]. Order is meaningful - index 0 will be RUL, index 1 will be health, throughout the rest of the program.
grad_norms (2,) = [5.0, 0.01]
Line 31: `weights = gaba_weights(grad_norms)`

Apply the GABA weight function we defined above. The interesting computation happens inside; the result is a 2-vector.

EXECUTION STATE
⬇ arg: grad_norms = [5.0, 0.01] - the realistic 500× imbalance.
weights (2,) = [0.001996, 0.998004]
→ reading = RUL gets ≈ 0.2% of the loss combination. Health gets ≈ 99.8%. The dominating-gradient task (RUL) is heavily suppressed; the under-trained task (health) is amplified.
Line 33: `print(f"g_rul = {g_rul:.3f}")`

Pretty-print g_rul. The :.3f format spec means 'float, 3 digits after the decimal'.

EXECUTION STATE
📚 f-string :.3f = Format specifier inside an f-string. .3 = 3 decimal places, f = fixed-point. Example: f'{3.14159:.3f}' → '3.142'
Output = g_rul = 5.000
Line 34: `print(f"g_health = {g_health:.3f}")`

Same pattern as line 33 - 3-decimal float for g_health.

EXECUTION STATE
Output = g_health = 0.010
Line 35: `print(f"ratio = {g_rul / g_health:.1f}x")`

Compute and print the gradient ratio. :.1f = 1 decimal place. The trailing 'x' is just literal text (the multiplier suffix).

EXECUTION STATE
g_rul / g_health = 5.0 / 0.01 = 500.0
Output = ratio = 500.0x
Line 36: `print(f"lambda_rul = {weights[0]:.4f}")`

Print λ_rul = weights[0]. :.4f = 4-decimal float, needed because the value is small (0.002).

EXECUTION STATE
weights[0] = 0.001996 (numpy indexing: index 0 = RUL)
Output = lambda_rul = 0.0020
Line 37: `print(f"lambda_health = {weights[1]:.4f}")`

Print λ_health = weights[1].

EXECUTION STATE
weights[1] = 0.998004
Output = lambda_health = 0.9980
Line 38: `print(f"sum = {weights.sum():.4f}")`

Sanity check: Σ_i λ_i must equal 1.0. If not, the formula is broken.

EXECUTION STATE
weights.sum() = 0.001996 + 0.998004 = 1.000000 (modulo float-rounding)
Output = sum = 1.0000
Line 40: `# ---------- Verify the closed-form K=2 simplification ----------`

Below this divider we verify that the K-task formula reduces to the simple K=2 closed form derived in the 'Inverse-Gradient Intuition' section above.

Line 41: `expected_rul = g_health / (g_rul + g_health)`

Closed-form K=2 specialisation. Setting K=2 in the K-task formula: λ_rul = (Σ g_j − g_rul) / ((2−1) · Σ g_j) = g_health / (g_rul + g_health).

EXECUTION STATE
g_rul + g_health = 5.0 + 0.01 = 5.01
g_health / (...) = 0.01 / 5.01 = 0.001996
expected_rul = 0.001996 - matches weights[0] exactly. Confirms the K-task formula collapses to the K=2 closed form.
Line 42: `expected_health = g_rul / (g_rul + g_health)`

K=2 specialisation for health. By symmetry: λ_health = g_rul / (g_rul + g_health).

EXECUTION STATE
g_rul / (g_rul + g_health) = 5.0 / 5.01 = 0.998004
expected_health = 0.998004 - matches weights[1] exactly.
Line 43: `print(f"\nclosed-form lambda_rul = {expected_rul:.4f}")`

Print the closed-form RUL weight. The leading \n inserts a blank line, visually separating the closed-form block from the K-task block above.

EXECUTION STATE
\n inside f-string = Newline character. Printed first, before the rest of the string. Acts as a vertical separator.
Output = (blank line) closed-form lambda_rul = 0.0020
Line 44: `print(f"closed-form lambda_health = {expected_health:.4f}")`

Print the closed-form health weight. Equals weights[1] from the K-task formula - QED.

EXECUTION STATE
Output = closed-form lambda_health = 0.9980
Final output =
g_rul    = 5.000
g_health = 0.010
ratio    = 500.0x
lambda_rul     = 0.0020
lambda_health  = 0.9980
sum            = 1.0000

closed-form lambda_rul    = 0.0020
closed-form lambda_health = 0.9980
```python
import numpy as np


def gaba_weights(grad_norms: np.ndarray) -> np.ndarray:
    """Compute GABA inverse-gradient task weights for K tasks.

    Formula (paper eq. 4):
        lambda_i = (sum_j g_j - g_i) / ((K - 1) * sum_j g_j)

    Tasks with SMALLER gradients get LARGER weights, because the
    backbone is already being pulled hard by the larger-gradient task.

    Args:
        grad_norms: shape (K,) - per-task L2 gradient norms on shared params.

    Returns:
        weights: shape (K,) - non-negative, sum to 1.0.
    """
    K = grad_norms.shape[0]
    total = grad_norms.sum() + 1e-12
    weights = (total - grad_norms) / ((K - 1) * total)
    return weights


# ---------- Smoke test on real predictive-maintenance numbers ----------
# From section 12.3: g_rul is ~500x g_health on shared backbone params at init.
g_rul    = 5.000     # paper-realistic RUL gradient norm
g_health = 0.010     # paper-realistic health gradient norm
grad_norms = np.array([g_rul, g_health])

weights = gaba_weights(grad_norms)

print(f"g_rul    = {g_rul:.3f}")
print(f"g_health = {g_health:.3f}")
print(f"ratio    = {g_rul / g_health:.1f}x")
print(f"lambda_rul     = {weights[0]:.4f}")
print(f"lambda_health  = {weights[1]:.4f}")
print(f"sum            = {weights.sum():.4f}")

# ---------- Verify the closed-form K=2 simplification ----------
expected_rul    = g_health / (g_rul + g_health)
expected_health = g_rul    / (g_rul + g_health)
print(f"\nclosed-form lambda_rul    = {expected_rul:.4f}")
print(f"closed-form lambda_health = {expected_health:.4f}")
```
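A usage sketch of the weights in a combined loss, following the `(weights * task_losses).sum()` pattern noted above. The per-task loss values (≈ 4500 for RUL MSE and ≈ 1.10 for health cross-entropy) are the untrained-model magnitudes quoted in the PyTorch walkthrough below and are used here purely for illustration.

```python
import numpy as np


def gaba_weights(grad_norms: np.ndarray) -> np.ndarray:
    """Same K-task formula as above, repeated so this snippet runs on its own."""
    K = grad_norms.shape[0]
    total = grad_norms.sum() + 1e-12
    return (total - grad_norms) / ((K - 1) * total)


grad_norms = np.array([5.0, 0.01])      # [RUL, health], §12 values
task_losses = np.array([4500.0, 1.10])  # illustrative untrained-model losses

weights = gaba_weights(grad_norms)
weighted_terms = weights * task_losses
combined_loss = weighted_terms.sum()

print(f"weights        = {weights.round(4)}")   # [0.002 0.998]
print(f"weighted terms = {weighted_terms.round(3)}")
print(f"combined loss  = {combined_loss:.3f}")
```

The weighted RUL term (about 9.0) still exceeds the weighted health term (about 1.1). That is expected: GABA equalises each task's pull on the shared backbone gradient, not the raw loss values.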

PyTorch: GABALoss From The Paper

Now the production version - the real GABALoss from paper_ieee_tii/grace/core/gaba.py. It adds EMA buffers, a warmup window, gradient computation via torch.autograd.grad, a min_weight floor, and re-normalisation. The smoke test runs five training-style iterations on a tiny 14-dim toy backbone so the full machinery executes.

GABALoss — paper code from grace/core/gaba.py
gaba_loss_torch.py
Line 1: `import torch`

Core PyTorch package. Provides the Tensor class, autograd engine, and device abstractions. Everything in this file depends on torch.

EXECUTION STATE
📚 torch = Tensor library with built-in automatic differentiation. We use: torch.Tensor, torch.zeros, torch.ones, torch.tensor, torch.randn, torch.rand, torch.randint, torch.manual_seed, torch.autograd.grad.
Line 2: `import torch.nn as nn`

PyTorch's neural-network module, aliased to nn for brevity. Provides nn.Module (the base class for all stateful components), layer primitives like nn.Linear, and the nn.functional submodule for stateless ops.

EXECUTION STATE
📚 torch.nn = The nn submodule. Contains nn.Module, nn.Linear, nn.functional.cross_entropy, etc.
as nn = Conventional alias. Lets us write nn.Linear(...) instead of torch.nn.Linear(...).
Line 3: `from typing import List, Optional`

Type-hint aliases from the standard library. List[X] denotes a list of X; Optional[X] means 'X or None'. Used in the forward() signature to document that shared_params can be a list of Parameters or None.

EXECUTION STATE
📚 typing.List = Generic list alias. List[nn.Parameter] = list whose items are nn.Parameter.
📚 typing.Optional = Optional[X] is shorthand for Union[X, None]. Documents that the argument is allowed to be None.
Line 6: `class GABALoss(nn.Module):`

Define GABALoss as a subclass of nn.Module. This makes it a stateful PyTorch component that can register buffers (EMA state, step count) which auto-move to the right device and serialize into state_dict. This matches the canonical paper code in grace/core/gaba.py.

EXECUTION STATE
📚 nn.Module = Base class for all PyTorch modules. Provides hooks for parameters (learnable), buffers (non-learnable state), submodule registration, .to(device), .state_dict(), .train()/.eval() modes.
→ why a Module not a function = GABA needs persistent state (ema_weights, step_count) that survives across batches and is checkpoint-compatible. A plain function would lose state every call; a Module gives us that for free via register_buffer.
Line 7 (docstring): `"""Gradient-Aware Balanced Adaptation - paper code (grace/core/gaba.py)."""`

Identifies this as the canonical paper implementation. Should match the production file line-for-line so anything debugged here applies to production.

Line 12: `def __init__(self, beta: float = 0.99, warmup_steps: int = 100, min_weight: float = 0.05, n_tasks: int = 2) -> None:`

Constructor. All four hyperparameters are paper defaults; only n_tasks needs changing for K ≠ 2.

EXECUTION STATE
⬇ input: beta = 0.99 = EMA smoothing coefficient. Update rule: ema_w = β·ema_w + (1−β)·raw_w. Higher β = slower response, lower variance. Paper default: 0.99 (effective window ≈ 100 steps).
→ why 0.99 = Half-life ≈ ln(2)/(1−β) = 69 steps. With training batches of 64 samples, that's ~4400 samples — a sensible smoothing window for noisy per-batch gradient norms.
⬇ input: warmup_steps = 100 = Steps with equal weighting before adaptive logic kicks in. Paper default: 100. Lets the model warm up before we trust per-task gradient norms.
⬇ input: min_weight = 0.05 = Floor on per-task weight after EMA. Prevents any task from being driven to zero (which would happen at the K=2 limit when one gradient → ∞). Paper default: 0.05.
⬇ input: n_tasks = 2 = K — number of tasks. RUL + health for this book. Override for multi-objective extensions (e.g. RUL + health + KL + monotonicity).
→ : float / int annotations = Per-arg type hints. Python doesn't enforce them at runtime, but mypy/IDEs use them for static checks.
→ -> None annotation = Constructor returns None (Python convention - __init__ never returns a value).
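The half-life arithmetic for β = 0.99 can be checked directly. The ln(2)/(1 − β) expression is the small-(1 − β) approximation; the exact half-life solves β^n = 1/2. The batch size of 64 is the one quoted in the β note above.

```python
import math

beta = 0.99
effective_window = 1.0 / (1.0 - beta)              # ~100 steps
half_life_approx = math.log(2) / (1.0 - beta)      # ~69.3 steps
half_life_exact = math.log(0.5) / math.log(beta)   # ~68.97 steps (beta**n == 0.5)

print(f"effective window   = {effective_window:.1f} steps")
print(f"half-life (approx) = {half_life_approx:.2f} steps")
print(f"half-life (exact)  = {half_life_exact:.2f} steps")
print(f"samples at batch 64 ~ {half_life_approx * 64:.0f}")
```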
Line 14: `super().__init__()`

Call nn.Module's constructor. Required: it sets up the internal dicts that register_buffer() and register_parameter() write to. Forgetting this raises AttributeError on the next register_buffer call.

EXECUTION STATE
📚 super() = Returns a proxy object for the parent class (nn.Module here). super().__init__() runs nn.Module.__init__() with no args.
Line 15: `self.beta = beta`

Store the EMA coefficient as an instance attribute so forward() can read it.

EXECUTION STATE
self.beta = 0.99 (or whatever the caller passed)
Line 16: `self.warmup_steps = warmup_steps`

Store the warmup step count as an instance attribute.

EXECUTION STATE
self.warmup_steps = 100 (paper default) or 2 (smoke-test override)
Line 17: `self.min_weight = min_weight`

Store the per-task weight floor as an instance attribute.

EXECUTION STATE
self.min_weight = 0.05
Line 18: `self.n_tasks = n_tasks`

Store K as an instance attribute. Used to size the EMA weights buffer below.

EXECUTION STATE
self.n_tasks = 2 for RUL + health
Line 20: `# EMA-smoothed weights start at uniform 1/K each.`

Comment marking the buffer-registration block. Buffers are persistent tensor state that move with .to(device) and save/load via state_dict.

Line 21: `self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)`

Register a non-learnable tensor as a buffer. EMA weights are persistent state — they don't receive gradients (so they're not Parameters) but they DO need to move to GPU and serialize with the model.

EXECUTION STATE
📚 register_buffer(name, tensor) = nn.Module method. Registers a tensor under the given name. After this call, self.ema_weights returns the tensor; the tensor moves with .to(device); state_dict()['ema_weights'] saves and loads it.
⬇ arg 1: 'ema_weights' = Buffer name. Becomes self.ema_weights and the key in state_dict.
⬇ arg 2: torch.ones(n_tasks) / n_tasks = Initial value tensor.
📚 torch.ones(*size) = Tensor of all ones with the given shape. torch.ones(2) → [1.0, 1.0].
→ / n_tasks = Element-wise scalar divide. [1.0, 1.0] / 2 = [0.5, 0.5]. For K=3: torch.ones(3)/3 = [0.333, 0.333, 0.333].
→ initial value here = [0.5, 0.5] - uniform K-task weighting at construction time.
→ why a buffer not a Parameter = Parameters get gradients (and would be silently updated by optimizer.step()). Buffers don't — they're pure state.
Line 22: `self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))`

Register the warmup step counter as a buffer. dtype=torch.long is 64-bit integer — never overflows for any realistic training horizon.

EXECUTION STATE
📚 torch.tensor(value, dtype) = Build a tensor from a Python value with explicit dtype. Scalar value → 0-dim (scalar) tensor.
⬇ arg 1: 0 = Initial value of the counter. We'll increment in forward().
⬇ arg 2: dtype=torch.long = 64-bit signed integer. Step counts must be exact integers (no float drift), and the range (up to 2^63 − 1) is far beyond any realistic training horizon.
→ why a buffer not a Python int = A Python int wouldn't survive checkpoint save/load and can't live on GPU. As a buffer, step_count is part of state_dict and moves with the model.
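A minimal sketch of why buffers are the right container for this state. TinyState is a hypothetical stand-in that registers the same two buffers as GABALoss and nothing else; it shows that buffers appear in state_dict(), survive a load_state_dict() round trip, and are invisible to parameters(), so an optimizer never touches them.

```python
import torch
import torch.nn as nn


class TinyState(nn.Module):
    """Hypothetical stand-in holding only GABALoss-style buffers."""

    def __init__(self, n_tasks: int = 2) -> None:
        super().__init__()
        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
        self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))


m = TinyState()
m.step_count += 1                       # in-place update, as in forward()
print(sorted(m.state_dict().keys()))    # ['ema_weights', 'step_count']
print(len(list(m.parameters())))        # 0: nothing for an optimizer to update

restored = TinyState()
restored.load_state_dict(m.state_dict())
print(restored.step_count.item())       # 1: the counter survived the round trip
```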
Line 24: `def forward(self, rul_loss: torch.Tensor, health_loss: torch.Tensor, shared_params: Optional[List[nn.Parameter]] = None) -> torch.Tensor:`

Called once per training step. Takes the two per-task scalar losses plus the list of shared backbone parameters; returns the GABA-weighted combined loss tensor that the caller can call .backward() on.

EXECUTION STATE
⬇ input: rul_loss = 0-dim tensor. The MSE-on-RUL loss for the current batch.
→ realistic value = ≈ 4500 (huge because RUL ∈ [0, 125] and untrained predictions are random Linear outputs)
⬇ input: health_loss = 0-dim tensor. The cross-entropy-on-health loss for the current batch.
→ realistic value = ≈ 1.10 (close to ln(3) = 1.0986 — random predictions over 3 classes)
⬇ input: shared_params = List of nn.Parameter objects (or None). The backbone params we measure per-task gradient norms on. Pass None to force equal weighting (warmup behaviour).
→ Optional[List[nn.Parameter]] hint = Documents that shared_params is either a list of nn.Parameters OR None.
⬆ returns = 0-dim tensor = λ_rul · rul_loss + λ_health · health_loss after EMA + clamp + renorm. Caller does total.backward() to update everything.
Line 26: `K = 2`

Hard-coded K=2 for this RUL+health convenience overload. The general K-task version lives in forward_k() in the paper file.

EXECUTION STATE
K = 2 - matches len(losses) on line 29.
Line 27: `device = rul_loss.device`

Read the device (cpu/cuda/mps) from the first loss so we can build new tensors on the same device. Mixing devices raises a runtime error.

EXECUTION STATE
📚 .device = Tensor attribute: the device the tensor lives on (torch.device('cpu'), torch.device('cuda:0'), torch.device('mps'), etc.).
device = Whatever device rul_loss was computed on. Used in torch.ones(K, device=device) and torch.zeros((), device=device) below.
Line 28: `self.step_count += 1`

Increment the warmup counter in place. PyTorch supports in-place += on tensors; it modifies the buffer storage directly.

EXECUTION STATE
self.step_count = Becomes 1 on the first call, 2 on the second, etc. Used by the warmup check below.
Line 29: `losses = [rul_loss, health_loss]`

Wrap the two losses into a list so the per-task loop can iterate uniformly. Order matters: index 0 = RUL throughout.

EXECUTION STATE
losses = [rul_loss_tensor, health_loss_tensor]. len(losses) == K == 2.
Line 31: `# During warmup -> equal weighting.`

Comment introducing the warmup branch. During warmup, gradient norms are too noisy to trust, so we fall back to uniform weighting.

Line 32: `if shared_params is None or self.step_count.item() <= self.warmup_steps:`

Branch into equal weighting if either: (a) the caller didn't pass shared_params (forcing the skip), OR (b) we're still inside the warmup window.

EXECUTION STATE
📚 .item() = Tensor method: extract the Python scalar from a 0-dim tensor. Not strictly required: comparing a 0-dim tensor with a Python int via <= also works and yields a 0-dim bool tensor. .item() returns a plain Python int, which makes the intent explicit.
shared_params is None = True when the caller explicitly passes None. Use case: an evaluation pass that doesn't need adaptive weighting.
self.step_count.item() <= self.warmup_steps = True for steps 1..warmup_steps. With warmup=2 (smoke test): True for steps 1 and 2; False from step 3 onward.
→ why warmup = Early in training the model has no useful features yet, so per-task gradient norms are noise. Equal weighting prevents pathological initial GABA weights from poisoning the EMA buffer.
Line 33: `weights = torch.ones(K, device=device) / K`

Build the equal-weights vector [1/K, 1/K, ..., 1/K] on the right device. For K=2 that's [0.5, 0.5].

EXECUTION STATE
📚 torch.ones(*size, device) = All-ones tensor with the given shape on the given device. torch.ones(2, device='cpu') → tensor([1.0, 1.0]).
⬇ arg: K = 2 = Shape of the output tensor. K=2 ⇒ shape (2,).
⬇ kwarg: device = device = Same device as the input losses. Avoids cross-device errors in the weighted-sum step.
/ K = Scalar divide. [1, 1] / 2 = [0.5, 0.5].
weights = tensor([0.5, 0.5])
Line 34: `else:`

Adaptive branch: warmup is done AND shared_params was passed, so we measure per-task gradients and compute GABA weights.

Line 35: `# Per-task gradient norms on shared params.`

Comment introducing the gradient-measurement block.

Line 36: `grad_norms = torch.zeros(K, device=device)`

Pre-allocate the gradient-norm vector. We'll fill in entries one by one inside the loop.

EXECUTION STATE
📚 torch.zeros(*size, device) = All-zeros tensor of the given shape on the given device.
⬇ arg: K = 2 = Output shape (2,).
⬇ kwarg: device = device = Match the loss device.
grad_norms (initial) = tensor([0.0, 0.0])
Line 37: `for i, loss_i in enumerate(losses):`

Loop over the K=2 tasks. enumerate yields (index, value) pairs so we can write back into grad_norms[i].

LOOP TRACE · 2 iterations
i=0, loss_i = rul_loss
grads = Tuple of gradient tensors, one per shared_param. For nn.Linear(14,32) backbone: grads = (∂rul_loss/∂W (32,14), ∂rul_loss/∂b (32,))
grad_norms[0] = ≈ 76.53 (toy example with random untrained backbone — the realistic FD002 value would be ~5.0 from §12.3)
i=1, loss_i = health_loss
grads = Tuple of gradient tensors of CE loss w.r.t. the same backbone params.
grad_norms[1] = ≈ 0.367 (toy example — the realistic FD002 value would be ~0.01)
after loop
grad_norms = tensor([76.53, 0.37]) — ratio ≈ 209× in this toy run
Line 38: `grads = torch.autograd.grad(loss_i, shared_params, retain_graph=True, create_graph=False, allow_unused=True)`

Compute ∂loss_i / ∂shared_params functionally, WITHOUT writing into the .grad attributes. Critical for 'peeking' at per-task gradients without contaminating the main backward pass that the optimizer relies on.

EXECUTION STATE
📚 torch.autograd.grad(outputs, inputs, ...) = Functional autograd. Computes ∂outputs/∂inputs and returns a tuple of gradient tensors. Unlike loss.backward(), it does NOT accumulate into param.grad; it returns the values directly. Essential when you need to inspect gradients without affecting subsequent optimizer steps.
⬇ arg: outputs = loss_i = The scalar loss to differentiate. Must be a 0-dim tensor (or you must supply grad_outputs for non-scalar).
⬇ arg: inputs = shared_params = List of tensors to compute gradients w.r.t. Returns one gradient per input.
⬇ kwarg: retain_graph = True = DON'T free the autograd graph after this call. Required because (a) we're going to call autograd.grad() AGAIN for the second task, and (b) the caller will eventually call total.backward() through the weighted combination.
→ without retain_graph = Defaults to the value of create_graph (False here). The first autograd.grad() call would free the graph; the second call (i=1) would crash with 'Trying to backward through the graph a second time'.
⬇ kwarg: create_graph = False = Don't track the gradients-of-gradients. We don't need second-order; this saves memory and compute.
⬇ kwarg: allow_unused = True = Tolerate parameters that don't affect loss_i (for example, a head-specific param accidentally passed in). Returns None for those entries instead of raising RuntimeError.
→ without allow_unused = Default is False. If shared_params contains anything not on loss_i's graph, you get RuntimeError: 'One of the differentiated Tensors appears to not have been used in the graph'.
⬆ result: grads = Tuple of tensors (or None entries). One gradient tensor per shared_param, with shape matching the param.
Line 40: `grad_norms[i] = sum((g.detach().norm() ** 2 for g in grads if g is not None)).sqrt()`

Compute the global L2 norm of the gradient by summing squared per-param norms, then taking the square root. Equivalent to flattening all grads into one giant vector and calling .norm(), but cheaper because we skip the concatenation.

EXECUTION STATE
📚 generator expression (g ... for g in grads ...) = Lazy iterator. Yields one value at a time without materialising a list. Used inside sum() it streams the squared norms one by one.
📚 .detach() = Tensor method: returns a view with no autograd history. With create_graph=False the gradients from autograd.grad already carry no history, so this is a defensive no-op here; it guarantees the norm computation is never tracked into the eventual backward() through total.
📚 .norm() = Tensor method: L2 norm by default (ord=2). For tensor t, t.norm() = sqrt(Σ t_ij²).
** 2 = Element-wise square of the scalar norm. Sets up the sum-of-squared-norms pattern.
if g is not None = Filter out None entries that allow_unused=True returned for unused parameters.
📚 sum(iterable) = Python builtin. Sums the iterable starting from 0 by default. Returns a 0-d tensor here because each yielded value is a 0-d tensor.
.sqrt() = Final square root: ||g||_global = sqrt(Σ_p ||g_p||²). Recovers the L2 norm of the concatenated gradient.
→ why sum-of-squared-norms = ||(a, b)||² = ||a||² + ||b||² (Pythagoras in product space). So sqrt(Σ_p ||g_p||²) = ||(g_1, g_2, ..., g_P)||₂. We avoid an expensive cat() over all parameters.
→ numerical value (toy run) = i=0 (RUL): grad_norms[0] ≈ 76.53. i=1 (health): grad_norms[1] ≈ 0.37. Ratio ≈ 209×.
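A quick numerical check of the sum-of-squared-norms identity, as a sketch with random stand-in gradients (it does not reproduce the 76.53 / 0.37 toy-run values; the None entry mimics what allow_unused=True can return):

```python
import torch

torch.manual_seed(0)
# Stand-ins for per-parameter gradients (shapes match the smoke test's backbone).
grads = (torch.randn(32, 14), torch.randn(32), None)

# Streamed form used in GABALoss: sqrt of the summed squared per-tensor norms.
streamed = sum(g.norm() ** 2 for g in grads if g is not None).sqrt()

# Reference form: flatten everything into one vector and take its L2 norm.
flat = torch.cat([g.reshape(-1) for g in grads if g is not None])
reference = flat.norm()

print(streamed.item(), reference.item())
```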
total_norm = grad_norms.sum() + 1e-12

Sum across the K=2 task gradient norms; add epsilon for divide-by-zero safety (mirrors the NumPy version).

EXECUTION STATE
📚 .sum() = Tensor method: sum all elements. For a (K,) vector returns a 0-d scalar.
1e-12 = Numerical safeguard. Same role as in the NumPy version.
total_norm = 0-d tensor ≈ 76.90 (76.53 + 0.37) for the toy run.
raw_weights = (total_norm - grad_norms) / ((K - 1) * total_norm)

The K-task GABA formula in vectorised PyTorch. Identical math to the NumPy version; broadcasting handles scalar−vector subtraction and vector/scalar division.

EXECUTION STATE
total_norm - grad_norms = Element-wise. For grad_norms=[76.53, 0.37], total=76.90: [76.90 − 76.53, 76.90 − 0.37] = [0.37, 76.53].
(K - 1) * total_norm = Normaliser. For K=2: 1 × 76.90 = 76.90.
raw_weights = [0.37/76.90, 76.53/76.90] ≈ [0.0048, 0.9952]
→ reading = Raw GABA puts ~99.5% of the weight on health (the small-gradient task) and ~0.5% on RUL.
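The same formula as a small standalone function (gaba_raw_weights is a hypothetical name, not part of the paper code), applied to the toy-run norms and to a made-up K = 3 case. For K = 2 the per-task contributions λ_i · g_i come out equal; for K = 3 the weights still sum to 1 and favour the quieter tasks, but the products are no longer exactly equal.

```python
import torch


def gaba_raw_weights(grad_norms: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Raw K-task GABA weights: (sum_j g_j - g_i) / ((K - 1) * sum_j g_j)."""
    K = grad_norms.numel()
    total = grad_norms.sum() + eps
    return (total - grad_norms) / ((K - 1) * total)


# K = 2 with the toy-run norms quoted above (RUL, health).
g2 = torch.tensor([76.53, 0.37])
w2 = gaba_raw_weights(g2)
print(w2)        # roughly [0.0048, 0.9952]
print(w2 * g2)   # per-task contributions lambda_i * g_i: equal for K = 2

# K = 3 with hypothetical norms: weights sum to 1 and favour quiet tasks,
# but lambda_i * g_i are no longer exactly equal.
g3 = torch.tensor([10.0, 1.0, 1.0])
w3 = gaba_raw_weights(g3)
print(w3, w3 * g3)
```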
# EMA smoothing.

Comment introducing the EMA-smoothing block. Without smoothing, weights would lurch around with each batch's gradient noise.

ema_w = self.beta * self.ema_weights + (1.0 - self.beta) * raw_weights

Exponential moving average update. New value is 99% old + 1% current: a slow, stable response to batch-to-batch gradient noise.

EXECUTION STATE
self.beta = 0.99
self.ema_weights = Previous EMA-smoothed weights. After warmup this starts at [0.5, 0.5] (initial buffer value).
(1.0 - self.beta) = 0.01 — the fraction of the new raw_weights that bleeds into the EMA each step.
raw_weights = [0.0048, 0.9952] - this step&apos;s unsmoothed weights.
ema_w = 0.99 · [0.5, 0.5] + 0.01 · [0.0048, 0.9952] ≈ [0.4950, 0.5050]
→ reading = After ONE adaptive step (smoke-test step 2, i.e. step_count=3), the weights barely budge: that's β=0.99 doing its job. With constant raw_weights, the remaining gap shrinks by a factor of 0.99 per step, so approaching ~[0.005, 0.995] takes on the order of 500 steps (0.99^500 ≈ 0.007).
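To see how slowly β = 0.99 moves, the EMA recurrence can be iterated with raw_weights held fixed at the toy-run values (an assumption that matches the smoke test, where nothing is trained). This sketch reproduces the printed λ_rul values for the first three adaptive steps and compares step 500 against the closed form ema_n = raw + (ema_0 − raw) · βⁿ.

```python
import torch

beta = 0.99
raw = torch.tensor([0.37 / 76.90, 76.53 / 76.90])  # toy-run raw weights, held fixed
ema = torch.tensor([0.5, 0.5])                      # initial buffer value (1/K each)

history = []
for step in range(500):
    ema = beta * ema + (1.0 - beta) * raw
    history.append(ema[0].item())                   # track lambda_rul

print([round(v, 4) for v in history[:3]])           # first three adaptive steps
# Closed form for a constant target: ema_n = raw + (ema_0 - raw) * beta**n
closed = raw[0] + (0.5 - raw[0]) * beta ** 500
print(history[-1], closed.item())
```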
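self.ema_weights = ema_w.detach()

Save the new EMA weights into the buffer. The .detach() guarantees the buffer never carries autograd history from one step to the next. In this listing the inputs are already graph-free (the per-task gradients come from autograd.grad with create_graph=False and are detached before the norm), so here it acts as a safeguard; it becomes essential in any variant where the weights are computed from tensors that still carry a graph.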

EXECUTION STATE
📚 .detach() = Tensor method: returns a new tensor with the same data but no autograd history. After detach, any backward() through this tensor stops here: there's nothing upstream to propagate to.
→ memory bug if omitted = Forgetting .detach() here is the most common GABA bug in variants whose weights carry a graph (for example, norms computed with create_graph=True and no inner detach). self.ema_weights would then keep earlier steps' graphs alive, chaining each step to the previous ones, so memory grows with every iteration (or backward() fails when it reaches an already-freed graph).
→ why we still detach even though raw_weights looks graph-free = In this listing raw_weights is built from grad_norms, which carries no history, but nothing in the arithmetic enforces that. If a later change lets history into raw_weights, ema_w inherits it and the graph leaks across steps via self.ema_weights. Detaching at the one place where state persists closes that door.
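Whether a tensor carries autograd history can be checked directly with requires_grad. The sketch below uses a hypothetical one-layer model: with create_graph=False (as in the listing) the gradient norm is history-free, while a variant with create_graph=True and no inner .detach() produces a weight whose graph would be dragged into the EMA buffer unless detached. The weight formula here is a stand-in, not the GABA formula.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
loss = layer(torch.randn(8, 4)).pow(2).mean()
params = list(layer.parameters())

# As in GABALoss: create_graph=False, so the returned gradients carry no history.
grads = torch.autograd.grad(loss, params, retain_graph=True, create_graph=False)
norm_plain = sum(g.norm() ** 2 for g in grads).sqrt()

# Hypothetical variant: create_graph=True and no inner .detach(), so the norm
# stays differentiable w.r.t. the parameters, and anything built from it too.
grads_hist = torch.autograd.grad(loss, params, create_graph=True)
norm_hist = sum(g.norm() ** 2 for g in grads_hist).sqrt()
raw_w = norm_hist / (norm_hist + 1.0)              # stand-in for a derived weight

ema_buffer = torch.tensor(0.5)
ema_leaky = 0.99 * ema_buffer + 0.01 * raw_w       # would chain graphs across steps
ema_safe = ema_leaky.detach()                      # what the listing stores

print(norm_plain.requires_grad, ema_leaky.requires_grad, ema_safe.requires_grad)
```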
# Floor + re-normalise.

Comment introducing the safety net: clamp small weights up to min_weight, then renormalise so they sum to 1 again.

weights = ema_w.clamp(min=self.min_weight)

Floor each weight at min_weight=0.05. Prevents any task from being driven to zero (which could happen at the K=2 limit when one task's gradient norm is huge).

EXECUTION STATE
📚 .clamp(min, max) = Tensor method: element-wise clip. clamp(min=0.05) sets any element below 0.05 to 0.05; elements above are unchanged. Equivalent to torch.maximum(t, torch.tensor(0.05)); note that torch.maximum needs a tensor, not a Python float, as its second argument.
⬇ kwarg: min = self.min_weight = Lower bound = 0.05 (paper default). No max kwarg, so no upper bound.
→ effect on toy run = ema_w ≈ [0.4950, 0.5050]. Both values are already > 0.05, so clamp is a no-op for this step. It would matter much later in training when ema_w[0] drifts below 0.05.
→ effect at convergence = If raw GABA gives [0.002, 0.998], after EMA convergence ema_w ≈ [0.002, 0.998]. clamp(min=0.05) → [0.05, 0.998]. THEN we renormalise on the next line.
weights = weights / weights.sum()

Renormalise so the floored weights sum back to 1.0. Required because clamp() can push the sum above 1.

EXECUTION STATE
weights.sum() = Scalar tensor. After a no-op clamp, equals (close to) 1.0 already.
weights / weights.sum() = Project back onto the simplex (Σ λ_i = 1). For smoke-test step 2 (step_count=3): [0.4950, 0.5050] / 1.000 ≈ [0.4950, 0.5050] (no-op).
→ at convergence = [0.05, 0.998] / 1.048 ≈ [0.0477, 0.9523]. The floor moves about 4.6 percentage points of weight from health to RUL to keep RUL alive.
→ why renormalise after clamp = Without it, the weights could sum to more than 1 after clamping (up to 1.05 for K=2 with min_weight=0.05), scaling the combined loss artificially. Renormalising keeps the loss magnitude comparable to the equal-weighted baseline.
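The floor-and-renormalise step on the converged example above, as a sketch (the [0.002, 0.998] input is the illustrative value from the text, not a measured one):

```python
import torch

min_weight = 0.05
ema_w = torch.tensor([0.002, 0.998])           # illustrative converged EMA weights

floored = ema_w.clamp(min=min_weight)          # [0.05, 0.998], sums to 1.048
# Same effect via torch.maximum, which needs a tensor (not a float) as `other`.
floored_alt = torch.maximum(ema_w, torch.tensor(min_weight))
weights = floored / floored.sum()              # back onto the simplex

print(floored.sum().item(), weights)
```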
# Combine.

Comment marking the loss-combination block. Below: weighted sum λ_i · L_i, returned as the final scalar.

total = torch.zeros((), device=device)

Initialise the combined-loss accumulator as a 0-dim zero tensor on the right device. Empty tuple () specifies a scalar shape.

EXECUTION STATE
📚 torch.zeros(shape, device) = Tensor of zeros with given shape and device. shape=() ⇒ 0-dim scalar tensor.
⬇ arg: () = Empty tuple = 0-dim shape. Different from (1,) which would be a 1-element vector.
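⬇ kwarg: device = device = Match the loss device so the accumulator lives where the losses live. PyTorch does tolerate a 0-d CPU tensor in arithmetic with CUDA tensors, but matching devices is the safe habit: device mismatches on non-scalar tensors raise an error.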
total (initial) = tensor(0.) — 0-dim scalar.
for w, l in zip(weights, losses):

Iterate task-weight and task-loss in lockstep. zip yields tuples (w_i, loss_i) for i = 0..K-1.

LOOP TRACE · 2 iterations
step 2 of smoke test (step_count=3), w=0.4950, l=rul_loss=4521.15
w * l = 0.4950 × 4521.15 ≈ 2238.0
total = 0.0 + 2238.0 = 2238.0
step 2 of smoke test (step_count=3), w=0.5050, l=health_loss=1.1936
w * l = 0.5050 × 1.1936 ≈ 0.6028
total = 2238.0 + 0.6028 ≈ 2238.6
total = total + w * l

Accumulate one weighted task-loss term. Out-of-place addition: PyTorch prefers this over += for autograd safety on accumulator tensors.

EXECUTION STATE
w * l = Scalar × scalar tensor product. Both are 0-d so the product is 0-d.
+ (out-of-place) = Creates a new tensor. In-place += on `total` could break the autograd graph if `total` is being referenced elsewhere.
return total

Return the GABA-weighted combined loss. The caller calls .backward() on this to compute gradients for both heads and the shared backbone in one pass; an optimizer.step() then applies the update.

EXECUTION STATE
⬆ return: total = 0-dim tensor = λ_rul · rul_loss + λ_health · health_loss after EMA, clamp, renorm. Has full autograd graph back to the model parameters.
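A sketch of the combination step, checking that the listing's accumulator loop equals a vectorised dot product and that each loss receives its weight as its gradient. The loss values are the toy-run numbers wrapped in hypothetical leaf tensors; in real training they would come from the model.

```python
import torch

weights = torch.tensor([0.4950, 0.5050])
losses = [torch.tensor(4521.15, requires_grad=True),
          torch.tensor(1.1936, requires_grad=True)]

# Loop form, as in the listing; () is a 0-dim shape, (1,) would be a vector.
print(torch.zeros(()).shape, torch.zeros((1,)).shape)  # torch.Size([]) torch.Size([1])
total_loop = torch.zeros(())
for w, l in zip(weights, losses):
    total_loop = total_loop + w * l

# Equivalent one-liner: stack the 0-d losses into a (K,) vector and dot with weights.
total_vec = torch.dot(weights, torch.stack(losses))

total_vec.backward()   # d total / d loss_i = lambda_i (weights are constants)
print(total_loop.item(), total_vec.item(), [l.grad.item() for l in losses])
```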
# ---------- Smoke test ----------

Section divider. Below this comment we drive the GABALoss with a tiny toy backbone for 5 steps to verify the full machinery executes (warmup, autograd.grad, EMA, clamp, renorm).

torch.manual_seed(0)

Set the global PyTorch PRNG seed for reproducibility. Affects every subsequent torch.randn / torch.rand / torch.randint and the random init of nn.Linear weights.

EXECUTION STATE
📚 torch.manual_seed(s) = Sets the seed of the default CPU PRNG (and the CUDA PRNG via a wrapper). All sampling that follows is deterministic given this seed.
⬇ arg: 0 = Seed value. Any non-negative int works; 0 is conventional.
backbone = nn.Linear(14, 32)

Tiny shared backbone for the smoke test: maps 14 sensor inputs → 32 hidden features. Mimics §11's feature dimension after a CNN encoder, but linear so the test runs in milliseconds.

EXECUTION STATE
📚 nn.Linear(in_features, out_features, bias=True) = Fully-connected layer. Stores W of shape (out_features, in_features) and b of shape (out_features,). Forward: y = x @ W.T + b.
⬇ arg 1: in_features = 14 = Input dimension — 14 sensor channels in C-MAPSS.
⬇ arg 2: out_features = 32 = Output dimension — 32 latent features. Matches the §11 CNN feature size.
→ param count = W (32, 14) = 448 weights + b (32,) = 32 biases = 480 params. These are the SHARED parameters GABA measures gradients on.
rul_head = nn.Linear(32, 1)

RUL regression head: 32 features → 1 scalar prediction.

EXECUTION STATE
⬇ args (32, 1) = 32 inputs, 1 output. Standard regression head.
→ param count = 32 weights + 1 bias = 33 params. NOT shared — head-specific.
hp_head = nn.Linear(32, 3)

Health classification head: 32 features → 3 logits (3 health states).

EXECUTION STATE
⬇ args (32, 3) = 32 inputs, 3 logits. Output goes to softmax via cross_entropy.
→ param count = 32×3 = 96 weights + 3 biases = 99 params. Head-specific, not shared.
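The three parameter counts above can be confirmed with a short helper (n_params is a hypothetical name, not part of the paper code):

```python
import torch.nn as nn

backbone = nn.Linear(14, 32)
rul_head = nn.Linear(32, 1)
hp_head = nn.Linear(32, 3)


def n_params(module: nn.Module) -> int:
    """Total number of scalar parameters (weights + biases) in a module."""
    return sum(p.numel() for p in module.parameters())


print(n_params(backbone), n_params(rul_head), n_params(hp_head))  # 480 33 99
```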
gaba = GABALoss(beta=0.99, warmup_steps=2, min_weight=0.05, n_tasks=2)

Instantiate GABA. We override warmup_steps from 100 (paper default) to 2 so the adaptive code path fires within the 5-step smoke test.

EXECUTION STATE
⬇ kwarg: beta = 0.99 = Paper default - keep.
⬇ kwarg: warmup_steps = 2 = Smoke-test override. Paper uses 100. With warmup=2, step_count 1-2 (printed steps 0-1) use equal weighting and step_count 3-5 (printed steps 2-4) use adaptive weighting.
⬇ kwarg: min_weight = 0.05 = Paper default - keep.
⬇ kwarg: n_tasks = 2 = K=2 (RUL + health).
x = torch.randn(64, 14)

Synthetic input: 64-sample batch of 14-dim sensor readings, drawn from the standard normal.

EXECUTION STATE
📚 torch.randn(*size) = Tensor of i.i.d. samples from N(0, 1). Shape arg is given as *args (positional ints).
⬇ args: (64, 14) = Output shape — 64 rows × 14 cols. Mimics 64 engine snapshots × 14 sensor channels.
x = shape (64, 14), float32, values ~ N(0, 1)
rul_target = torch.rand(64, 1) * 125.0

Synthetic RUL targets in [0, 125] cycles. torch.rand is uniform on [0, 1), then we scale by the C-MAPSS clip value.

EXECUTION STATE
📚 torch.rand(*size) = Uniform [0, 1) tensor. (Different from torch.randn which is standard normal.)
⬇ args: (64, 1) = Output shape — column vector of 64 targets.
* 125.0 = Scale uniform [0, 1) → uniform [0, 125). Matches the C-MAPSS RUL clip.
rul_target = shape (64, 1), float32, values in [0, 125)
hp_target = torch.randint(0, 3, (64,))

Synthetic 3-class health labels for the batch.

EXECUTION STATE
📚 torch.randint(low, high, size) = Uniform integer tensor on [low, high). Default dtype int64 — matches what cross_entropy expects.
⬇ arg: low = 0 = Lower bound (inclusive).
⬇ arg: high = 3 = Upper bound (exclusive). Combined with low=0 ⇒ values in {0, 1, 2}.
⬇ arg: size = (64,) = Output shape — 1-D tensor of 64 labels. Note the trailing comma: (64,) is a tuple, (64) is just an int.
hp_target = shape (64,), int64, values in {0, 1, 2}
for step in range(5):

Five training-style iterations. Printed steps 0-1 (step_count 1-2) are warmup with equal weights; steps 2-4 (step_count 3-5) use adaptive GABA. Note: this loop has NO optimizer.step(), so model params and losses don't change between iterations; the only thing drifting is the EMA buffer.

LOOP TRACE · 5 iterations
step = 0 (warmup, step_count=1)
lambda after step = (0.5000, 0.5000) - equal weighting
rul_loss = ≈ 4521.15
health_loss = ≈ 1.1936
step = 1 (warmup, step_count=2)
lambda after step = (0.5000, 0.5000) - last warmup step
step = 2 (adaptive starts, step_count=3)
raw_weights = [0.0048, 0.9952] - from grad_norms ≈ [76.5, 0.37]
lambda after step = (0.4950, 0.5050) - first EMA update toward inverse-gradient
step = 3 (step_count=4)
lambda after step = (0.4901, 0.5099) - drift continues toward [~0.005, ~0.995]
step = 4 (step_count=5)
lambda after step = (0.4853, 0.5147) - slow convergence, β=0.99 needs ~500 steps to fully stabilise
feat = backbone(x)

Forward through the shared backbone: (64, 14) → (64, 32). Calling the module (nn.Linear's __call__) runs the forward pass and records the autograd graph.

EXECUTION STATE
feat = shape (64, 32), float32, fully autograd-tracked
rul_pred = rul_head(feat)

RUL head: (64, 32) → (64, 1). Raw scalar predictions — no activation.

EXECUTION STATE
rul_pred = shape (64, 1), float32
hp_logits = hp_head(feat)

Health head: (64, 32) → (64, 3). Raw 3-class logits; softmax happens inside cross_entropy.

EXECUTION STATE
hp_logits = shape (64, 3), float32
rul_loss = ((rul_pred - rul_target) ** 2).mean()

Plain MSE on RUL. Element-wise residual, square, mean over the 64-sample batch.

EXECUTION STATE
rul_pred - rul_target = Element-wise residual, shape (64, 1).
** 2 = Element-wise square, shape (64, 1).
.mean() = Reduce-all to a scalar. Equivalent to .sum() / 64.
rul_loss = 0-dim tensor ≈ 4521.15 (huge because targets ∈ [0, 125] and untrained predictions are random)
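The hand-written MSE matches PyTorch's built-in F.mse_loss with its default mean reduction. This sketch uses random stand-in predictions; the 125²/3 line is the expected squared target under the uniform [0, 125) draw, which gives a rough sense of why an untrained model's loss lands in the thousands.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
rul_pred = torch.randn(64, 1)                 # stand-in for untrained predictions
rul_target = torch.rand(64, 1) * 125.0

manual = ((rul_pred - rul_target) ** 2).mean()
builtin = F.mse_loss(rul_pred, rul_target)    # reduction='mean' by default
print(manual.item(), builtin.item())
print(125.0 ** 2 / 3)                         # E[t^2] for t ~ U[0, 125): about 5208
```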
health_loss = nn.functional.cross_entropy(hp_logits, hp_target)

Standard 3-class cross-entropy. Combines log_softmax + nll_loss in one numerically-stable call.

EXECUTION STATE
📚 nn.functional.cross_entropy(input, target) = PyTorch's functional cross-entropy loss. Internally: −Σ_i log(softmax(input)[target_i]) / N. Numerically stable via log-sum-exp.
⬇ arg: input = hp_logits (64, 3) = Raw logits. DON'T apply softmax yourself; cross_entropy does it internally.
⬇ arg: target = hp_target (64,) = Integer class indices, dtype int64. NOT one-hot.
health_loss = 0-dim tensor ≈ 1.1936 (close to ln(3) = 1.0986, the loss of a perfectly uniform 3-class prediction, which is roughly what an untrained head produces)
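Two quick checks of the cross-entropy claims, as a sketch with random labels: all-zero logits (a perfectly uniform prediction) give ln(3), and F.cross_entropy equals the mean negative log-softmax probability of the true class.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
targets = torch.randint(0, 3, (64,))
uniform_logits = torch.zeros(64, 3)                 # every class equally likely
loss_uniform = F.cross_entropy(uniform_logits, targets)

# Manual version: mean negative log-softmax probability of the true class.
logits = torch.randn(64, 3)
manual = -F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).mean()
print(loss_uniform.item(), math.log(3), manual.item(), F.cross_entropy(logits, targets).item())
```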
total = gaba(rul_loss, health_loss, shared_params=list(backbone.parameters()))

Apply the GABA-weighted combination. shared_params is the explicit list of nn.Parameters in the BACKBONE only (not the heads) — these are the parameters whose gradient norms drive the per-task weighting.

EXECUTION STATE
📚 .parameters() = nn.Module method. Returns an iterator over all nn.Parameter objects under this module (including submodules). For nn.Linear: 2 params (weight, bias).
📚 list(...) = Materialise the iterator into a list. Necessary because GABA needs to consume shared_params TWICE (one autograd.grad() call per task) — an iterator would be exhausted after the first task.
⬇ kwarg: shared_params = [backbone.weight, backbone.bias] = List of the two nn.Parameter objects in the backbone. The heads' parameters are intentionally NOT included; they don't need GABA balancing.
total = 0-dim tensor ≈ 0.5 · 4521.15 + 0.5 · 1.1936 ≈ 2261.2 (during warmup, equal weighting)
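Why list(...) matters: Module.parameters() returns a generator, which is empty after one pass, while a list can be iterated once per task. A minimal check on a layer with the smoke-test backbone's shape:

```python
import torch.nn as nn

backbone = nn.Linear(14, 32)

params_iter = backbone.parameters()     # a generator: can be consumed only once
first_pass = list(params_iter)
second_pass = list(params_iter)         # already exhausted -> empty
print(len(first_pass), len(second_pass))                          # 2 0

params_list = list(backbone.parameters())                         # what the smoke test passes
print(sum(1 for _ in params_list), sum(1 for _ in params_list))   # 2 2
```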
print(f"step {step} | rul_loss = {rul_loss.item():7.2f} | health_loss = {health_loss.item():.4f}"

Pretty-print the per-step metrics. .item() extracts Python scalars; :7.2f means width-7, 2 decimals.

EXECUTION STATE
📚 .item() = Tensor method - extract a Python scalar from a 0-dim tensor. Required for use inside f-string format specs.
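:7.2f format spec = Width 7, 2 decimal places, fixed-point float. Example: 4521.15 is already 7 characters, so it prints as '4521.15' with no padding; a shorter value such as 12.5 prints as '  12.50' (left-padded to width 7).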
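The format specs can be checked in isolation. The last line also shows that a 0-dim tensor, like gaba.ema_weights[0], accepts a float format spec directly, which is why the listing's print works without .item() for the λ values:

```python
import torch

print(repr(f"{4521.15:7.2f}"))              # '4521.15': already 7 characters, no padding
print(repr(f"{12.5:7.2f}"))                 # '  12.50': left-padded to width 7
print(repr(f"{1.19364:.4f}"))               # '1.1936'
print(repr(f"{torch.tensor(0.4950):.4f}"))  # '0.4950': 0-dim tensors format like floats
```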
f" | lambda = ({gaba.ema_weights[0]:.4f}, {gaba.ema_weights[1]:.4f})")

Continuation of the print statement (Python implicit string concatenation). Reads the EMA buffer directly to expose the current λ values.

EXECUTION STATE
gaba.ema_weights[0] = Current EMA weight for RUL. After step 4: ≈ 0.4853.
gaba.ema_weights[1] = Current EMA weight for health. After step 4: ≈ 0.5147.
Final printed output =
step 0 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.5000, 0.5000)
step 1 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.5000, 0.5000)
step 2 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.4950, 0.5050)
step 3 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.4901, 0.5099)
step 4 | rul_loss = 4521.15 | health_loss = 1.1936 | lambda = (0.4853, 0.5147)
→ why losses don't change = We never call optimizer.step() in this smoke test, so model params (and therefore predictions and losses) are constant across all 5 iterations. The ONLY thing changing is the EMA buffer inside GABALoss.
import torch
import torch.nn as nn
from typing import List, Optional


class GABALoss(nn.Module):
    """Gradient-Aware Balanced Adaptation - paper code (grace/core/gaba.py).

    EMA-smoothed inverse-gradient weights with warmup + min_weight floor.
    """

    def __init__(self, beta: float = 0.99, warmup_steps: int = 100,
                 min_weight: float = 0.05, n_tasks: int = 2) -> None:
        super().__init__()
        self.beta         = beta
        self.warmup_steps = warmup_steps
        self.min_weight   = min_weight
        self.n_tasks      = n_tasks

        # EMA-smoothed weights start at uniform 1/K each.
        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
        self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))

    def forward(self, rul_loss: torch.Tensor, health_loss: torch.Tensor,
                shared_params: Optional[List[nn.Parameter]] = None) -> torch.Tensor:
        K = 2  # two tasks (RUL + health); n_tasks only sizes the EMA buffer
        device = rul_loss.device
        self.step_count += 1
        losses = [rul_loss, health_loss]

        # During warmup -> equal weighting.
        if shared_params is None or self.step_count.item() <= self.warmup_steps:
            weights = torch.ones(K, device=device) / K
        else:
            # Per-task gradient norms on shared params.
            grad_norms = torch.zeros(K, device=device)
            for i, loss_i in enumerate(losses):
                grads = torch.autograd.grad(loss_i, shared_params, retain_graph=True,
                                             create_graph=False, allow_unused=True)
                grad_norms[i] = sum((g.detach().norm() ** 2 for g in grads if g is not None)).sqrt()

            total_norm  = grad_norms.sum() + 1e-12
            raw_weights = (total_norm - grad_norms) / ((K - 1) * total_norm)

            # EMA smoothing.
            ema_w = self.beta * self.ema_weights + (1.0 - self.beta) * raw_weights
            self.ema_weights = ema_w.detach()

            # Floor + re-normalise.
            weights = ema_w.clamp(min=self.min_weight)
            weights = weights / weights.sum()

        # Combine.
        total = torch.zeros((), device=device)
        for w, l in zip(weights, losses):
            total = total + w * l
        return total


# ---------- Smoke test ----------
torch.manual_seed(0)
backbone = nn.Linear(14, 32)              # tiny shared backbone
rul_head = nn.Linear(32, 1)               # task 1
hp_head  = nn.Linear(32, 3)               # task 2
gaba     = GABALoss(beta=0.99, warmup_steps=2, min_weight=0.05, n_tasks=2)

x          = torch.randn(64, 14)
rul_target = torch.rand(64, 1) * 125.0
hp_target  = torch.randint(0, 3, (64,))

for step in range(5):
    feat        = backbone(x)
    rul_pred    = rul_head(feat)
    hp_logits   = hp_head(feat)
    rul_loss    = ((rul_pred - rul_target) ** 2).mean()
    health_loss = nn.functional.cross_entropy(hp_logits, hp_target)

    total = gaba(rul_loss, health_loss, shared_params=list(backbone.parameters()))
    print(f"step {step} | rul_loss = {rul_loss.item():7.2f} | health_loss = {health_loss.item():.4f}"
          f" | lambda = ({gaba.ema_weights[0]:.4f}, {gaba.ema_weights[1]:.4f})")

Inverse-Gradient In Other Domains

The inverse-gradient idea generalises. Anywhere two or more objectives share parameters and have unequal gradient magnitudes, GABA-style weighting applies:

| Domain | Tasks sharing parameters | Why GABA helps |
|---|---|---|
| RUL prediction (this book) | RUL regression + health classification | 500× gradient imbalance at init |
| Multi-modal speech (Whisper-style) | ASR transcription + language detection + speaker ID | Cross-entropy heads have unequal label vocabularies |
| Object detection (Detectron2) | Bounding-box regression + class CE + objectness logit | Bbox L1 loss dwarfs CE when boxes are big |
| Recommender systems (multi-head DLRM) | Click prediction + dwell time + share/save | Click logit gradient dwarfs dwell-time MSE |
| Robotics (multi-task imitation) | End-effector position + gripper open/close + force control | Force MSE gradient is tiny vs position MSE |
| Drug discovery (molecular property) | logP regression + toxicity classification + binding score | Each property has a different label scale |

In every row, the larger-gradient task otherwise dominates the shared trunk; inverse-gradient weighting equalises the backbone update so all tasks contribute proportionally.
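A sketch of the same measurement on a hypothetical three-task model (the trunk, heads, and target scales are all made up for illustration, not taken from any row of the table). It applies the K-task formula from GABALoss to real per-task gradient norms on the shared trunk; whichever task pulls hardest on the trunk ends up with the smallest weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical shared trunk with three heads whose losses live on different scales.
trunk = nn.Linear(16, 32)
reg_head = nn.Linear(32, 1)    # regression on large-valued targets (MSE)
cls_head = nn.Linear(32, 5)    # 5-class classification (cross-entropy)
bin_head = nn.Linear(32, 1)    # binary flag (BCE with logits)

x = torch.randn(128, 16)
reg_t = torch.rand(128, 1) * 100.0
cls_t = torch.randint(0, 5, (128,))
bin_t = torch.randint(0, 2, (128, 1)).float()

feat = trunk(x)
losses = [
    F.mse_loss(reg_head(feat), reg_t),
    F.cross_entropy(cls_head(feat), cls_t),
    F.binary_cross_entropy_with_logits(bin_head(feat), bin_t),
]

# Per-task gradient norm on the shared trunk only (same recipe as GABALoss).
shared = list(trunk.parameters())
norms = []
for loss in losses:
    grads = torch.autograd.grad(loss, shared, retain_graph=True)
    norms.append(sum(g.pow(2).sum() for g in grads).sqrt())
norms = torch.stack(norms)

K = len(losses)
total = norms.sum() + 1e-12
weights = (total - norms) / ((K - 1) * total)   # K-task inverse-gradient formula
print(norms, weights)
```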

Three Pitfalls Of Naive Implementations

Pitfall 1: Forgetting retain_graph=True on torch.autograd.grad. We compute K per-task gradients on the same forward graph, then backward through the weighted combination. Without retain_graph=True, the first per-task gradient call frees the shared graph and the second call crashes with 'Trying to backward through the graph a second time'.
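Pitfall 2: Forgetting to .detach() the EMA buffer. If self.ema_weights retains autograd history across steps, each step's graph stays chained to the previous ones, so every backward() tries to walk back through earlier iterations. Memory grows step by step until OOM (or backward() fails on an already-freed graph). Always store self.ema_weights = ema_w.detach().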
Pitfall 3: Skipping the warmup window. During the first 100 steps the model has no useful features yet. Gradient norms are noisy and unreliable, so GABA weights computed from them start out as garbage. The 100-step warmup plus EMA (β=0.99) is the paper-canonical stabiliser.
The paper's zero-tuning claim. The raw inverse-gradient formula has zero hyperparameters; β, warmup, and min_weight only stabilise it. So you can deploy GABA with the paper defaults (beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2) and almost never need to retune. That is what makes the section's title accurate: GABA EQUALISES task contributions automatically.

Takeaway

  • Two voices, one microphone. Multi-task learning is dominated by the loudest gradient unless you deliberately equalise.
  • K-task GABA formula: $\lambda_i = \left(\sum_j g_j - g_i\right) / \left((K-1) \sum_j g_j\right)$. Larger gradient ⇒ smaller weight.
  • Closed-form K=2: $\lambda_{\text{rul}} = g_{\text{health}} / (g_{\text{rul}} + g_{\text{health}})$. Drops out of the K-task formula by setting K=2.
  • Implementation: 3 lines of NumPy or 50 lines of PyTorch. NumPy is for understanding; PyTorch adds EMA + warmup + min_weight + buffer state.
  • Zero hyperparameters in the raw formula. Stabilisers (β, warmup, min_weight) have safe defaults; you almost never need to tune them per dataset.
  • Generalises beyond RUL. Anywhere shared parameters serve unequal-gradient objectives, GABA-style inverse weighting applies.