Chapter 18

Full PyTorch Implementation

The GABA Algorithm

From Four Pieces To One Module

Sections 18.1 through 18.4 dissected the four mechanisms of GABA in isolation: gradient-norm measurement, EMA smoothing, the minimum floor, and the warmup gate. Each was a clean idea with a clean implementation. This section assembles them into the single class the paper actually uses in production: grace/core/gaba.py:GABALoss.

The integration is non-trivial in three ways. First, all four mechanisms share state (the EMA buffer, the step counter, the logging snapshots), and the persistent part of it, the EMA weights and the step counter, has to live in registered nn.Module buffers so it survives device moves and checkpoint round-trips. Second, the autograd flow has to thread through both the per-task gradient measurements (which use torch.autograd.grad) and the eventual weighted-combination backward (which uses loss.backward()) without conflicts. Third, the public API needs an inspection surface so monitoring code can pull live $g_i$, $\hat{\lambda}_i$, $\lambda^*_i$ values without breaking encapsulation.
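The second point can be illustrated with a toy network. This is a minimal sketch that assumes a shared layer with two heads; the names `backbone`, `head_a` and `head_b` are hypothetical stand-ins, not paper code. Each per-task norm is measured with torch.autograd.grad(..., retain_graph=True), which returns gradients without writing to .grad. The same graph is then reused for the single weighted backward.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: one shared layer feeding two task heads.
backbone = nn.Linear(4, 8)
head_a = nn.Linear(8, 1)
head_b = nn.Linear(8, 1)

x = torch.randn(16, 4)
h = torch.relu(backbone(x))
loss_a = head_a(h).pow(2).mean()      # stand-in for the RUL loss
loss_b = head_b(h).abs().mean()       # stand-in for the health loss
shared = list(backbone.parameters())

# Per-task gradient norms on the shared parameters. retain_graph=True keeps the
# graph alive for the next measurement and for the final backward().
norms = []
for loss in (loss_a, loss_b):
    grads = torch.autograd.grad(loss, shared, retain_graph=True, allow_unused=True)
    sq = sum(g.pow(2).sum() for g in grads if g is not None)
    norms.append(sq.sqrt())

# autograd.grad does not touch .grad, so nothing has accumulated yet.
print(all(p.grad is None for p in shared))   # True

# Weights built from the norms carry no autograd history (create_graph=False).
g = torch.stack(norms)
w = (g.sum() - g) / g.sum()                  # K=2 closed form
total = w[0] * loss_a + w[1] * loss_b
total.backward()                             # reuses the retained graph
print(all(p.grad is not None for p in shared))   # True
```

The measurement calls and the combined backward do not conflict because only the final backward() writes to .grad.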

The full class is about 80 lines of Python, and it implements an algorithm that lowers the FD002 NASA score from 498 (DKAMFormer) to 224, a 55% operational-safety improvement. Two ideas (closed form + EMA), two stabilisers (floor + warmup), careful autograd hygiene, and two logging slots with two inspection helpers: that is the complete recipe.

The Class Anatomy

Four public methods (plus __init__), two registered buffers, two logging slots, four hyperparameters:

| Component | Type | Purpose | Defined in § |
|---|---|---|---|
| beta | Python float | EMA smoothing coefficient. Paper: 0.99. | §18.2 |
| warmup_steps | Python int | Steps with uniform 1/K weighting before the adaptive logic kicks in. Paper: 100. | §18.4 |
| min_weight | Python float | Floor on each per-task weight. Paper: 0.05. | §18.3 |
| n_tasks | Python int | Number of tasks K (2 here: RUL + health). | §17 (model anatomy) |
| ema_weights | Registered buffer, shape (K,) | EMA-smoothed task weights. Persistent across steps and checkpoints. | §18.2 |
| step_count | Registered buffer (0-dim long) | Step counter used by the warmup gate. Persistent. | §18.4 |
| _last_grad_norms | Optional[Tensor] | Logging snapshot of per-task gradient norms. | §18.5 (this section) |
| _last_raw_weights | Optional[Tensor] | Logging snapshot of pre-EMA raw weights. | §18.5 |
| forward(rul_loss, health_loss, shared_params) | Method | K=2 convenience wrapper. | this section |
| forward_k(losses, shared_params) | Method (workhorse) | Full pipeline: gate → grad norms → closed form → EMA → floor → renorm → combine. | this section |
| get_weights() | Method | Inspection: return current EMA weights as a dict. | this section |
| get_gradient_stats() | Method | Inspection: return last gradient norms and raw weights as a dict. | this section |

__init__: Hyperparameters and Buffers

Construction stores four hyperparameters as plain Python attributes (they are not learnable) and registers two buffers via nn.Module.register_buffer:

  • ema_weights initialised to $\mathbf{1}_K / K$, i.e. uniform 1/K (paper Algorithm 1, line 2).
  • step_count initialised to 0 as a 0-dim long tensor.

The two logging slots _last_grad_norms and _last_raw_weights are NOT registered as buffers. They are debugging snapshots, not state we want to checkpoint. Saving them in state_dict would inflate checkpoints unnecessarily and re-introduce a fictitious dependency between training resumption and the most recent gradient measurement.
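The buffer-versus-attribute split can be checked directly. The sketch below uses a hypothetical `TinyState` module (not paper code) that registers the same two buffers as GABALoss and keeps one plain logging attribute. Only the buffers appear in state_dict(), so only they survive a save/load round trip.

```python
import io
import torch
import torch.nn as nn

class TinyState(nn.Module):
    """Hypothetical module mirroring GABALoss's state layout."""
    def __init__(self, n_tasks=2):
        super().__init__()
        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
        self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))
        self._last_grad_norms = None          # plain attribute: not checkpointed

m = TinyState()
m.step_count += 1
m.ema_weights[:] = torch.tensor([0.3, 0.7])
m._last_grad_norms = torch.tensor([76.5, 0.37])

print(sorted(m.state_dict().keys()))      # ['ema_weights', 'step_count']

# Round-trip through an in-memory checkpoint.
buf = io.BytesIO()
torch.save(m.state_dict(), buf)
buf.seek(0)
restored = TinyState()
restored.load_state_dict(torch.load(buf))
print(restored.step_count.item(), restored.ema_weights.tolist())
print(restored._last_grad_norms)          # None: the snapshot was not saved
```

The same registration is also what makes `.to(device)` carry ema_weights and step_count along with the rest of the module.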

forward(): K=2 Convenience Wrapper

The K=2 convenience method is a single delegating line. It packages $\mathcal{L}_{\text{rul}}$ and $\mathcal{L}_{\text{health}}$ into a list and hands them to the K-task workhorse:

return self.forward_k([rul_loss, health_loss], shared_params)

The wrapper exists for backward compatibility with the paper's baseline trainer signature, which expects a (rul_loss, health_loss, shared_params=None, **kwargs) contract. The **kwargs catches stray trainer-provided kwargs (e.g. model for GradNorm baselines) so all loss classes can share the same call site.
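Here is a sketch of why the shared signature helps. Only the `(rul_loss, health_loss, shared_params=None, **kwargs)` contract comes from the text. `FixedWeightLoss` and `ModelAwareLoss` are hypothetical baselines invented for illustration. Each class accepts and ignores the keyword arguments it does not need, so one trainer call site can pass `model=...` to all of them.

```python
import torch
import torch.nn as nn

class FixedWeightLoss(nn.Module):
    """Hypothetical baseline: constant weights, ignores extra kwargs."""
    def __init__(self, w_rul=0.5, w_health=0.5):
        super().__init__()
        self.w_rul, self.w_health = w_rul, w_health

    def forward(self, rul_loss, health_loss, shared_params=None, **kwargs):
        return self.w_rul * rul_loss + self.w_health * health_loss

class ModelAwareLoss(nn.Module):
    """Hypothetical baseline that needs the model (as GradNorm-style methods do)."""
    def forward(self, rul_loss, health_loss, shared_params=None, model=None, **kwargs):
        if model is None:
            raise ValueError("ModelAwareLoss needs model=...")
        return rul_loss + health_loss

model = nn.Linear(3, 1)
rul_loss, health_loss = torch.tensor(2.0), torch.tensor(0.5)

# One call site for every loss class; unused kwargs are simply absorbed.
for loss_fn in (FixedWeightLoss(), ModelAwareLoss()):
    total = loss_fn(rul_loss, health_loss,
                    shared_params=list(model.parameters()), model=model)
    print(type(loss_fn).__name__, float(total))
```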

forward_k(): The Workhorse

The general K-task method runs the full pipeline. Pseudocode:

  • Increment step_count.
  • Warmup gate: if $t \leq W$ or shared_params is None, use uniform $1/K$ weights, skip the measurement, closed-form, EMA and floor stages, and go straight to the combine step.
  • Per-task gradient norms: loop K times; each call to compute_task_grad_norm uses torch.autograd.grad with retain_graph=True.
  • Closed form (§17.3 + §18.1): vector form $\lambda_i = (S - g_i) / ((K-1)\,S)$, where $S = \sum_j g_j$.
  • EMA update (§18.2): $\hat{\lambda} \leftarrow \beta \hat{\lambda} + (1-\beta) \lambda$, saved back to the buffer with .detach().
  • Floor + renorm (§18.3): clamp(min=self.min_weight), then divide by .sum().
  • Combine: weighted sum of the K losses → scalar.
  • Return the scalar; caller calls .backward() and optimiser.step().
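The vector closed form sums to one for any K, because $\sum_i (S - g_i) = KS - S = (K-1)S$. The minimal NumPy sketch below checks that property and the floor + renorm stage. It skips the EMA, so the floor is applied directly to a hand-picked vector. The helper names `closed_form` and `floor_renorm` are hypothetical.

```python
import numpy as np

def closed_form(g):
    """lambda_i = (S - g_i) / ((K - 1) S), with S = sum_j g_j (K >= 2)."""
    g = np.asarray(g, dtype=np.float64)
    K = g.size
    S = g.sum() + 1e-12                  # same epsilon guard as the paper code
    return (S - g) / ((K - 1) * S)

def floor_renorm(w, min_weight=0.05):
    """Clamp each weight at min_weight, then rescale back onto the simplex."""
    clamped = np.maximum(w, min_weight)
    return clamped / clamped.sum()

g = [250.0, 0.2, 5.0]                    # hypothetical K=3 gradient norms
lam = closed_form(g)
print(lam.round(4), lam.sum())           # smallest norm gets the largest weight

w = floor_renorm(np.array([0.01, 0.99])) # a K=2 vector that hits the floor
print(w.round(4))                        # [0.0481 0.9519]
```

Two side effects follow from the formulas. For K > 2, no raw weight can exceed $1/(K-1)$. After renormalisation, a floored weight ends up at $0.05/1.04 \approx 0.048$, slightly below min_weight, which is why the post-floor bound is only approximately $[\text{min\_weight}, 1-\text{min\_weight}]$.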
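The detach is the line that needs the most care: self.ema_weights[:K] = ema_w.detach() on line 53 of the paper code. In this exact class it works as a guard, because compute_task_grad_norm uses create_graph=False, so the measured norms, and therefore ema_w, carry no autograd history. It becomes essential as soon as the weights do carry history (for example in a create_graph=True variant). Writing a graph-carrying tensor into the persistent buffer links each step's graph to the previous step's, so backward() tries to walk back through earlier steps. That either raises an error, because those graphs were already freed, or, if they were retained, keeps every past step's graph alive until GPU memory runs out. Keeping .detach() makes the buffer safe however the weights were produced.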
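A toy loop makes the effect visible. In this hypothetical example (not paper code), the raw weights come from torch.softmax over a parameter, so they deliberately do carry autograd history. `graph_depth` is a small hypothetical helper that follows the first live edge of each autograd node. With .detach() the stored EMA has no grad_fn at all (depth 0). Without it, the chain of grad_fn nodes grows with every step.

```python
import torch

def graph_depth(t):
    """Count grad_fn nodes along the first non-None edge of each node."""
    depth, fn = 0, t.grad_fn
    while fn is not None:
        depth += 1
        nxt = [f for f, _ in fn.next_functions if f is not None]
        fn = nxt[0] if nxt else None
    return depth

def run_ema(n_steps, detach, beta=0.99):
    p = torch.tensor([1.0, 2.0], requires_grad=True)   # hypothetical parameter
    ema = torch.full((2,), 0.5)                        # plays the ema_weights buffer
    for step in range(n_steps):
        raw = torch.softmax(p * (step + 1), dim=0)     # weights WITH autograd history
        new = beta * ema + (1 - beta) * raw
        ema = new.detach() if detach else new
    return ema

for n in (1, 10, 100):
    print(n, graph_depth(run_ema(n, detach=True)), graph_depth(run_ema(n, detach=False)))
```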

Logging Buffers and Inspection Helpers

Two methods expose internal state for monitoring without breaking encapsulation:

  • get_weights() returns the current EMA weights as a Dict[str, float]. For K=2 the keys are rul_weight and health_weight; for general K they are task_0_weight, etc.
  • get_gradient_stats() returns the last gradient norms, the gradient ratio (most useful single number), and the un-smoothed raw weights. Empty dict during warmup; populated keys during active steps.

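Both helpers are read-only and return plain Python floats via .item(). get_weights() calls .detach().cpu() on the EMA buffer. get_gradient_stats() moves to CPU snapshots that were already .detach().clone()d when they were written. The trainer can therefore log to W&B / TensorBoard without holding onto, or breaking, the autograd graph.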
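For K=2, the gradient ratio $r = g_{\text{rul}} / g_{\text{health}}$ alone determines the raw closed-form weights: $\lambda_{\text{rul}} = g_{\text{health}} / S = 1/(1+r)$ and $\lambda_{\text{health}} = r/(1+r)$. That is one reason the ratio is the most useful single number to log. The small sketch below uses a hypothetical helper name and ignores the 1e-12 epsilon. It recovers the raw weights from a logged ratio; the values match the typical get_gradient_stats() output listed with the paper code in this section.

```python
def raw_weights_from_ratio(r):
    """K=2 closed form expressed through r = g_rul / g_health (epsilon ignored)."""
    lam_rul = 1.0 / (1.0 + r)
    return lam_rul, 1.0 - lam_rul

lam_rul, lam_health = raw_weights_from_ratio(208.8)
print(f"{lam_rul:.6f} {lam_health:.6f}")   # 0.004766 0.995234
```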

Interactive: Live GABA Dashboard

Four panels share the same x-axis (training step). The amber band is warmup. Drag the red ‘current step’ slider to see all four pipeline state variables at the same moment in training. The third panel (smoothed task weights) is the headline output; the first two show the inputs the closed form consumes; the fourth shows the resulting combined loss.

Try this. Set W = 0 and watch panel 3 immediately drop from 0.5 to a noisy trajectory. The first 100 steps now run on cold-start gradients; panel 2 (gradient ratio) oscillates wildly during this period, which is what justifies §18.4's warmup gate. Restore W = 100 and panel 3 stays flat through warmup, then begins its smooth descent toward the steady-state weight after step 100.

Python: Full Class From Scratch (NumPy)

A pedagogical NumPy mirror of the paper class. Same hyperparameters, same buffers, same per-step pipeline, same bounded-weight guarantee — without autograd, so the structure is fully visible. A 5-step smoke test with warmup_steps=2 exercises both branches of the gate.

Full GABA pipeline in pure NumPy
🐍gaba_loss_numpy.py
1docstring

Module docstring. The class below mirrors the paper's grace/core/gaba.py:GABALoss line for line, but in pure NumPy so the structure is visible without PyTorch.

3import numpy as np

NumPy supplies the ndarray and all array math used here.

EXECUTION STATE
📚 numpy = Numerical computing library. Used for ndarray, np.full, np.maximum, np.dot, np.asarray.
6class GABALossNumPy:

Plain Python class (no nn.Module subclass needed since we don't have autograd). Same four hyperparameters, same two state variables, same two logging variables as paper.

7docstring

Records the relationship to the paper class: a NumPy mirror for understanding, not a drop-in replacement (there is no autograd, so the caller supplies the gradient norms).

13def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2):

Constructor. All four defaults are paper-canonical from main.tex:347-362.

EXECUTION STATE
⬇ beta = 0.99 = EMA coefficient. Time constant τ = 1/(1−β) = 100 steps. Paper eq. 5.
⬇ warmup_steps = 100 = Steps with uniform weighting before adaptive logic kicks in. Paper Algorithm 1 line 1.
⬇ min_weight = 0.05 = Floor for the post-EMA clamp. Paper main.tex:354.
⬇ n_tasks = 2 = Number of tasks (RUL + health for K=2).
14self.beta = beta

Store as instance attribute. Used inside step() for the EMA update.

15self.warmup_steps = warmup_steps

Store as instance attribute. Used by the warmup gate.

16self.min_weight = min_weight

Store as instance attribute. Used by np.maximum() in the floor step.

17self.n_tasks = n_tasks

Cached K. Used to size the initial ema_weights buffer.

18# State (mirrors paper register_buffer)

Comment marking the persistent state — these are the variables that would be torch buffers in the PyTorch version.

19self.ema_weights = np.full(n_tasks, 1.0 / n_tasks)

EMA-smoothed task weights, initialised at uniform 1/K. For K=2: [0.5, 0.5]. Paper Algorithm 1 line 2.

EXECUTION STATE
📚 np.full(shape, fill) = Build an ndarray with the given shape filled with the given scalar.
ema_weights (init) = ndarray (2,) = [0.5, 0.5]. Will be updated inside step() during the active branch.
20self.step_count = 0

Plain Python int counter. Incremented at the top of every step() call. Used by the warmup gate.

EXECUTION STATE
step_count = Integer ≥ 0. 0 at construction; first step() call increments it to 1.
21# Logging buffers (mirrors paper _last_grad_norms, _last_raw_weights)

Comment marking the introspection slots — these are the values exposed to monitoring code in the paper class.

22self.last_grad_norms = None

Snapshot of the per-task gradient norms from the last active step. None during warmup.

EXECUTION STATE
last_grad_norms = ndarray (K,) or None. Records the most recent g_i values for logging.
23self.last_raw_weights = None

Snapshot of the un-smoothed closed-form weights (before EMA + floor). None during warmup.

EXECUTION STATE
last_raw_weights = ndarray (K,) or None. Records the most recent un-smoothed lambda for logging.
25def step(self, losses, grad_norms):

One full GABA per-step update. Takes K task losses and K gradient norms; returns the combined scalar loss and the weight vector used.

EXECUTION STATE
⬇ input: losses = List of K floats. Per-task scalar losses for THIS batch.
⬇ input: grad_norms = List of K floats OR None. Per-task L2 gradient norms on shared backbone. None during warmup (we don't bother computing them).
⬆ returns = Tuple (total_loss: float, weights: ndarray). total_loss is the scalar combined loss; weights is the lambda* used.
26docstring

Records the contract: K losses + K gradient norms → scalar loss + K weights.

27K = len(losses)

Cache the number of tasks. Equal to self.n_tasks for normal use; making it dynamic supports K-changing variants.

EXECUTION STATE
K = Integer. 2 for our example.
28self.step_count += 1

Increment FIRST so the gate compares the post-increment value. step 0 → 1 on the first call.

EXECUTION STATE
step_count (after) = 1 on the first call, 2 on the second, etc.
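29if grad_norms is None or self.step_count <= self.warmup_steps:

Warmup gate. Inclusive ≤, so step 100 is still warmup and step 101 is the first active step. A missing grad_norms (None) also falls back to uniform weights, mirroring the paper's shared_params is None check. Same semantics as paper grace/core/gaba.py:107.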

30# Warmup: uniform weights, no gradient measurement

Inline comment marking the warmup branch.

31weights = np.full(K, 1.0 / K)

Uniform 1/K. EMA buffer is NOT updated during warmup (we skip past the entire active branch).

EXECUTION STATE
weights (warmup) = [0.5, 0.5] for K=2. Same value the EMA buffer was initialised at.
33g = np.asarray(grad_norms, dtype=np.float64)

Convert the list/tuple to an ndarray. np.asarray avoids a copy if grad_norms is already an ndarray.

EXECUTION STATE
📚 np.asarray(a, dtype) = Build an ndarray from a (no-copy if a is already ndarray with matching dtype). Compare to np.array which always copies.
⬇ dtype = np.float64 = 64-bit float. At least as accurate as the PyTorch version's float32 for the headline numbers, with extra headroom for products like total_norm * (K - 1).
g = ndarray (K,). Per-task gradient norms for this step.
34self.last_grad_norms = g.copy()

Snapshot for logging. .copy() so later mutations of g don't bleed into the snapshot.

EXECUTION STATE
📚 .copy() = ndarray method: return a deep copy with its own data buffer.
35total_norm = g.sum() + 1e-12

Sum of gradient norms plus an epsilon for safety against the all-zero case.

EXECUTION STATE
📚 .sum() = ndarray reduction → scalar sum.
1e-12 = Numerical guard. Prevents divide-by-zero when every gradient is exactly zero (rare but possible on stationary points).
36# Closed form (paper eq. 4)

Inline comment marking the closed-form computation.

37raw = (total_norm - g) / ((K - 1) * total_norm)

K-task GABA closed form. (total_norm - g) is element-wise scalar minus vector. The result is a (K,) vector that sums to 1.

EXECUTION STATE
(total_norm - g) = Element-wise. For K=2 with g=[250, 0.2]: total=250.2, total-g=[0.2, 250]. Notice the SWAP: smaller g gets larger numerator.
(K - 1) * total_norm = Normaliser. For K=2: 1 * 250.2 = 250.2. Picked so the resulting weights sum to 1.
raw = ndarray (K,). Un-smoothed closed-form weights. For K=2 with the above g: [0.0008, 0.9992].
38self.last_raw_weights = raw.copy()

Snapshot for logging.

39# EMA (paper eq. 5)

Inline comment marking the EMA update.

40self.ema_weights = self.beta * self.ema_weights + (1 - self.beta) * raw

Convex combination: 99% history + 1% new measurement. Rebinds self.ema_weights to a freshly computed array (not an in-place mutation), which is how the persistent state is updated.

EXECUTION STATE
self.beta * self.ema_weights = 0.99 · [0.5, 0.5] = [0.495, 0.495] (first active step starting from uniform init).
(1 - self.beta) * raw = 0.01 · [0.0008, 0.9992] = [8e-6, 0.009992].
ema_weights (after) = [0.495008, 0.504992]. Barely budged from uniform (slow EMA inertia).
41# Floor + renorm (paper eq. 6)

Inline comment marking the floor + renormalisation step.

42clamped = np.maximum(self.ema_weights, self.min_weight)

Element-wise floor at min_weight. For freshly-out-of-warmup ema_weights ~ [0.495, 0.505], the floor is a no-op (both above 0.05).

EXECUTION STATE
📚 np.maximum(a, b) = Element-wise max — floors values at b. Different from np.max which reduces.
clamped = ndarray (K,). Each element ≥ min_weight = 0.05.
43weights = clamped / clamped.sum()

Renormalise to the simplex. Sum = 1 by construction.

EXECUTION STATE
weights (active) = ndarray (K,). On the simplex; bounded in approximately [min_weight, 1−min_weight].
44return float(np.dot(weights, losses)), weights

Combine: weighted sum of losses → scalar. np.dot computes the weighted combination.

EXECUTION STATE
📚 np.dot(a, b) = 1-D · 1-D = inner product (scalar). Equivalent to (weights * losses).sum() but more idiomatic.
float(...) = Convert numpy scalar to Python float for clean printing.
⬆ return = Tuple (total_loss: float, weights: ndarray K). There is no backward pass here; in the PyTorch version the caller would call .backward() on the loss and then step the optimiser.
48gaba = GABALossNumPy(beta=0.99, warmup_steps=2, min_weight=0.05, n_tasks=2)

Instantiate with warmup_steps=2 (small) so the smoke test exits warmup quickly. Paper warmup is 100; we use 2 here to demo the active branch within 5 steps.

EXECUTION STATE
warmup_steps=2 = Test-only override. Real training uses 100.
50trace = [...]

Hand-built test sequence: 5 (step, L_rul, L_hp, g_rul, g_health) tuples. Steps 1-2 are warmup (g=None); steps 3-5 are active with realistic gradient magnitudes.

EXECUTION STATE
step 1-2 = Warmup. grad_norms = None, weights = [0.5, 0.5].
step 3-5 = Active. Strongly imbalanced gradients (250 vs 0.2 is a 1,250x ratio), growing slightly each step.
58for step, L_rul, L_hp, g_r, g_h in trace:

Iterate the trace. Tuple-unpack assigns five names per iteration.

LOOP TRACE · 5 iterations
step 1 (warmup)
weights = [0.500000, 0.500000]
total = 0.5*5000 + 0.5*1.10 = 2500.55
step 2 (warmup)
weights = [0.500000, 0.500000]
total = 0.5*4900 + 0.5*1.08 = 2450.54
step 3 (FIRST active)
g (raw) = [250.0, 0.20]
raw closed form = [0.000800, 0.999200]
ema (1st update) = 0.99·[0.5, 0.5] + 0.01·[0.000800, 0.999200] = [0.495008, 0.504992]
weights = [0.495008, 0.504992] (no clamping needed; ratio ~ 1.02x)
total = 0.495008·4800 + 0.504992·1.06 = 2376.57
step 4 (active)
g = [300.0, 0.22]
weights = [0.490065, 0.509935]
total = 2303.84
step 5 (active)
g = [350.0, 0.24]
weights = [0.485171, 0.514829]
total = 2232.31
59losses = [L_rul, L_hp]

Pack the per-task losses into a list for the step() call.

60grads = [g_r, g_h] if g_r is not None else None

Pack gradient norms or pass None. The latter signals warmup-only behaviour even if step_count > warmup_steps.

61total, w = gaba.step(losses, grads)

Apply one full GABA step. Mutates gaba.step_count, gaba.ema_weights, and gaba.last_grad_norms / last_raw_weights when active.

62print formatted row

f-string with width specs to align the table.

65print final ema_weights

After 5 steps, the EMA has barely budged from uniform (3 EMA updates × 1% absorption ≈ 3% drift toward the new target).

EXECUTION STATE
Output = (blank) final ema_weights = [0.48517144 0.51482856]
66print final step_count

Sanity check on the counter.

EXECUTION STATE
Final output =
step 1 | L_rul=5000.0 L_hp=1.10 | w=(0.500000, 0.500000) | total=2500.55
step 2 | L_rul=4900.0 L_hp=1.08 | w=(0.500000, 0.500000) | total=2450.54
step 3 | L_rul=4800.0 L_hp=1.06 | w=(0.495008, 0.504992) | total=2376.57
step 4 | L_rul=4700.0 L_hp=1.04 | w=(0.490065, 0.509935) | total=2303.84
step 5 | L_rul=4600.0 L_hp=1.02 | w=(0.485171, 0.514829) | total=2232.31

final ema_weights = [0.48517144 0.51482856]
final step_count  = 5
```python
"""GABA from scratch in pure NumPy - pedagogical mirror of paper class."""

import numpy as np


class GABALossNumPy:
    """NumPy version of grace/core/gaba.py:GABALoss for clarity (no autograd).

    Same hyperparameters, same buffers, same per-step pipeline, same
    bounded-weight guarantee. A pedagogical stand-in for the paper class on
    platforms without PyTorch; with no autograd, the caller supplies the
    gradient norms.
    """

    def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2):
        self.beta = beta
        self.warmup_steps = warmup_steps
        self.min_weight = min_weight
        self.n_tasks = n_tasks
        # State (mirrors paper register_buffer)
        self.ema_weights = np.full(n_tasks, 1.0 / n_tasks)
        self.step_count = 0
        # Logging buffers (mirrors paper _last_grad_norms, _last_raw_weights)
        self.last_grad_norms = None
        self.last_raw_weights = None

    def step(self, losses, grad_norms):
        """One full GABA step. losses and grad_norms are length-K lists."""
        K = len(losses)
        self.step_count += 1
        # Missing norms fall back to uniform, like the paper's shared_params=None check
        if grad_norms is None or self.step_count <= self.warmup_steps:
            # Warmup: uniform weights, no gradient measurement
            weights = np.full(K, 1.0 / K)
        else:
            g = np.asarray(grad_norms, dtype=np.float64)
            self.last_grad_norms = g.copy()
            total_norm = g.sum() + 1e-12
            # Closed form (paper eq. 4)
            raw = (total_norm - g) / ((K - 1) * total_norm)
            self.last_raw_weights = raw.copy()
            # EMA (paper eq. 5)
            self.ema_weights = self.beta * self.ema_weights + (1 - self.beta) * raw
            # Floor + renorm (paper eq. 6)
            clamped = np.maximum(self.ema_weights, self.min_weight)
            weights = clamped / clamped.sum()
        return float(np.dot(weights, losses)), weights


# ---------- 5-step smoke test (warmup=2 so we exit early) ----------
gaba = GABALossNumPy(beta=0.99, warmup_steps=2, min_weight=0.05, n_tasks=2)

trace = [
    (1, 5000.0, 1.10, None, None),
    (2, 4900.0, 1.08, None, None),
    (3, 4800.0, 1.06, 250.0, 0.20),
    (4, 4700.0, 1.04, 300.0, 0.22),
    (5, 4600.0, 1.02, 350.0, 0.24),
]

for step, L_rul, L_hp, g_r, g_h in trace:
    losses = [L_rul, L_hp]
    grads = [g_r, g_h] if g_r is not None else None
    total, w = gaba.step(losses, grads)
    print(f"step {step} | L_rul={L_rul:.1f} L_hp={L_hp:.2f} | "
          f"w=({w[0]:.6f}, {w[1]:.6f}) | total={total:.2f}")

print(f"\nfinal ema_weights = {gaba.ema_weights}")
print(f"final step_count  = {gaba.step_count}")
```
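The walkthrough notes that three EMA updates move the weights only about 3% toward the new target. That follows from unrolling eq. 5 with a constant target $\lambda$: $\hat{\lambda}_n = \beta^n \hat{\lambda}_0 + (1-\beta^n)\lambda$. The short sketch below tabulates the absorbed fraction $1-\beta^n$ for the paper's β = 0.99; `absorbed_fraction` is a hypothetical helper name.

```python
beta = 0.99

def absorbed_fraction(n, beta=beta):
    """Share of a constant new target absorbed after n EMA updates: 1 - beta**n."""
    return 1.0 - beta ** n

for n in (1, 3, 100, 300):
    print(f"n={n:>3}  absorbed={absorbed_fraction(n):.4f}")
```

After one time constant, τ = 1/(1−β) = 100 steps, roughly 63% of a step change in the target has been absorbed. After three time constants, about 95% has been absorbed.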

PyTorch: The Paper's GABALoss Verbatim

The actual paper code from paper_ieee_tii/grace/core/gaba.py. Every line annotated. The compute_task_grad_norm helper at the top is also paper code (from grace/core/gradient_utils.py); the remaining methods are the full GABALoss class as it appears in the public release.

grace/core/gaba.py — verbatim, every line
🐍gaba_loss_paper.py
1docstring

Module docstring. The class below is the actual paper code from grace/core/gaba.py — line-for-line copy. The compute_task_grad_norm helper is also paper code from grace/core/gradient_utils.py.

3from __future__ import annotations

Postpones evaluation of type annotations (PEP 563). Annotations are stored as strings instead of being evaluated at definition time, so hints such as Optional[torch.Tensor] cost nothing at runtime and may refer to names defined later in the file.

EXECUTION STATE
📚 from __future__ import annotations = Python feature flag: defers type-hint evaluation. Enables string-style type hints without quoting.
4import torch

Core PyTorch.

EXECUTION STATE
📚 torch = Tensor library with autograd. Used for tensors, register_buffer, autograd.grad, .clamp, .detach.
5import torch.nn as nn

Module primitives.

EXECUTION STATE
📚 torch.nn = PyTorch nn package. Provides nn.Module base class.
6from typing import Dict, List, Optional

Type-hint aliases. Dict / List / Optional are standard library generics from PEP 484.

EXECUTION STATE
📚 typing = Python standard library for type hints. Dict[str, float], List[Tensor], Optional[X] = X | None.
9def compute_task_grad_norm(loss, shared_params, retain_graph=True):

Helper that returns ||grad(loss) on shared_params||_2. Same function used in §18.1; create_graph=False keeps it cheap.

EXECUTION STATE
⬇ input: loss = 0-dim tensor. Scalar to differentiate.
⬇ input: shared_params = List of nn.Parameter. Backbone parameters only (per get_shared_params from §18.1).
⬇ input: retain_graph = Bool. True so the graph survives for the next per-task gradient call AND the eventual combined-loss backward.
⬆ returns = 0-dim tensor. ||g||_2 on shared params.
10docstring

Records the function's purpose and references the paper file.

11grads = torch.autograd.grad(loss, shared_params, retain_graph=retain_graph, create_graph=False, allow_unused=True)

Functional autograd. Returns gradient tensors WITHOUT writing to .grad. The create_graph=False flag is the key cost-saving relative to GradNorm.

EXECUTION STATE
📚 torch.autograd.grad = Functional differentiation: ∂outputs/∂inputs as a tuple, no .grad mutation.
create_graph=False = Don't track gradient-of-gradient. ~1x memory.
allow_unused=True = Tolerate parameters disconnected from this loss.
13total = torch.tensor(0.0, device=loss.device)

Accumulator for the squared-norm sum. Built on the same device as the loss.

14for g in grads:

Iterate per-parameter gradient tensors and accumulate squared L2 norms. Skip None entries from allow_unused.

15if g is not None:

Filter out the None entries.

16total = total + g.pow(2).sum()

Out-of-place add to keep autograd happy. Sum of squared elements per parameter.

17return total.sqrt()

Final square root. ||g||_2 = sqrt(sum_p ||g_p||_2^2).
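The per-parameter accumulation gives the same value as the L2 norm of all shared gradients flattened into one vector. The quick check below assumes a toy shared layer plus one parameter the loss never touches; both are hypothetical. It re-implements the helper following the lines annotated above and compares the result with torch.linalg.vector_norm.

```python
import torch
import torch.nn as nn

def compute_task_grad_norm(loss, shared_params, retain_graph=True):
    """Same logic as the annotated helper: sqrt of summed squared grad norms."""
    grads = torch.autograd.grad(loss, shared_params, retain_graph=retain_graph,
                                create_graph=False, allow_unused=True)
    total = torch.tensor(0.0, device=loss.device)
    for g in grads:
        if g is not None:                # unused params come back as None
            total = total + g.pow(2).sum()
    return total.sqrt()

torch.manual_seed(0)
shared = nn.Linear(5, 3)
unused = nn.Parameter(torch.randn(4))    # hypothetical param this loss never uses
params = list(shared.parameters()) + [unused]

loss = shared(torch.randn(8, 5)).pow(2).mean()
norm = compute_task_grad_norm(loss, params)

grads = torch.autograd.grad(loss, list(shared.parameters()))
flat = torch.cat([g.reshape(-1) for g in grads])
print(torch.allclose(norm, torch.linalg.vector_norm(flat)))   # True
```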

20class GABALoss(nn.Module):

The full paper class. nn.Module subclass for buffer persistence and checkpoint compatibility. ALL state (ema_weights, step_count) lives in registered buffers.

EXECUTION STATE
📚 nn.Module = Base class for stateful PyTorch components. Tracks parameters, buffers, submodules.
21docstring

Class docstring from the paper file.

23def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2):

Constructor with all paper-canonical defaults. Match the NumPy class signature exactly.

EXECUTION STATE
⬇ beta = 0.99 = EMA coefficient (§18.2).
⬇ warmup_steps = 100 = Warmup duration (§18.4).
⬇ min_weight = 0.05 = Floor (§18.3).
⬇ n_tasks = 2 = K. RUL + health.
24super().__init__()

Initialise nn.Module base class. Required first line of every Module __init__.

25self.beta = beta

Plain Python attribute. NOT registered (it's a hyperparameter, not learnable state).

26self.warmup_steps = warmup_steps

Plain attribute.

27self.min_weight = min_weight

Plain attribute.

28self.n_tasks = n_tasks

Plain attribute.

29self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)

Persistent state: EMA-smoothed weights. Buffer means (a) tracked in state_dict, (b) moves with .to(device), (c) NOT updated by optimisers.

EXECUTION STATE
📚 register_buffer(name, tensor) = nn.Module method. Register a non-learnable persistent state tensor.
→ init = torch.ones(2) / 2 = [0.5, 0.5]. Uniform K-task weighting.
30self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))

Step counter buffer. dtype=torch.long because step counts are integers.

EXECUTION STATE
→ why a buffer? = Survives checkpoint save/load. Resuming training preserves the warmup-progress state.
31self._last_grad_norms: Optional[torch.Tensor] = None

Logging slot for the per-task gradient norms from the last active step. NOT registered as a buffer — it's a debugging aid, not state we want to checkpoint.

EXECUTION STATE
Optional[torch.Tensor] = Type hint: tensor or None. None during warmup; tensor (K,) during active steps.
32self._last_raw_weights: Optional[torch.Tensor] = None

Logging slot for the un-smoothed closed-form weights. Same idea.

34def forward(self, rul_loss, health_loss, shared_params=None, **kwargs):

Two-task convenience wrapper. Just packages the two losses into a list and calls forward_k. Backward-compatible with baselines that have a fixed (rul, health) signature.

EXECUTION STATE
⬇ input: rul_loss = 0-dim tensor. Scalar RUL regression loss.
⬇ input: health_loss = 0-dim tensor. Scalar health classification loss.
⬇ input: shared_params = Optional list of nn.Parameter. None ⇒ warmup-only behaviour.
⬇ **kwargs = Catch-all for extra trainer-provided kwargs (e.g. model). Ignored here but allowed for compatibility.
⬆ returns = 0-dim tensor. The combined weighted loss.
35return self.forward_k([rul_loss, health_loss], shared_params)

Delegate to the K-task workhorse. Wraps the two losses in a list and forwards.

EXECUTION STATE
[rul_loss, health_loss] = Python list of length 2. Order: rul (index 0), health (index 1).
37def forward_k(self, losses, shared_params=None):

The actual workhorse. Implements the full per-step pipeline: gate → closed form → EMA → floor → renorm → combine.

EXECUTION STATE
⬇ input: losses = List of K 0-dim tensors.
⬇ input: shared_params = Optional list of nn.Parameter. None ⇒ warmup.
⬆ returns = 0-dim tensor. The combined loss for backward().
38K = len(losses)

Number of tasks. For K=2 entry via forward(): K=2.

39device = losses[0].device

Read the device from the first loss so we build new tensors on the same device. Critical for multi-GPU.

EXECUTION STATE
📚 .device = Tensor attribute. The device the tensor lives on (cpu / cuda:0 / mps).
40self.step_count += 1

Increment the warmup counter. PyTorch supports in-place add on 0-dim long tensors.

41if shared_params is None or self.step_count.item() <= self.warmup_steps:

The warmup gate. Two triggers: (a) caller didn't pass shared_params, OR (b) we're inside the W=100 window.

EXECUTION STATE
📚 .item() = Tensor method. 0-dim tensor → Python scalar.
42weights = torch.ones(K, device=device) / K

Warmup branch: uniform 1/K weights on the right device.

EXECUTION STATE
📚 torch.ones(*size, device) = Build a tensor of all-ones with given shape and device.
44grad_norms = torch.zeros(K, device=device)

Pre-allocate the gradient-norm vector. We'll fill in entries one-by-one inside the for-loop.

45for i, loss_i in enumerate(losses):

Loop over the K tasks. enumerate gives (index, loss) pairs.

LOOP TRACE · 2 iterations
i = 0 (rul)
loss_i = rul_loss tensor.
grad_norms[0] = ≈ 76.5 (paper-realistic; depends on backbone state)
i = 1 (health)
loss_i = health_loss tensor.
grad_norms[1] = ≈ 0.37 (paper-realistic)
46grad_norms[i] = compute_task_grad_norm(loss_i, shared_params, retain_graph=True)

Measure the L2 norm of this task's gradient on the shared backbone. retain_graph=True so the autograd graph survives subsequent calls.

EXECUTION STATE
retain_graph=True = Required so the SAME forward pass can be differentiated K times (once per task) plus once more for the combined loss.
47self._last_grad_norms = grad_norms.detach().clone()

Snapshot for logging. .detach() strips autograd history; .clone() decouples from the working buffer.

EXECUTION STATE
📚 .detach() = Strip autograd history. Required so backward() can't accidentally walk back through this snapshot.
📚 .clone() = Copy the underlying data. Without it, future mutations of grad_norms would also change the snapshot.
48total_norm = grad_norms.sum() + 1e-12

Sum of gradient norms + numerical guard.

49raw_weights = (total_norm - grad_norms) / ((K - 1) * total_norm)

K-task GABA closed form (paper eq. 4). Element-wise.

EXECUTION STATE
raw_weights = Tensor (K,). Un-smoothed inverse-proportional weights.
50self._last_raw_weights = raw_weights.detach().clone()

Snapshot for logging.

51ema_w = self.ema_weights[:K].to(device)

Slice the EMA buffer to length K and move to the loss device. The slice supports K-flexible deployments where n_tasks could be larger than the K used in this call.

EXECUTION STATE
[:K] = Slice. For K=2 with n_tasks=2 this is a no-op; included for K-flexibility.
📚 .to(device) = Tensor method. Move to the given device. No-op if already there.
52ema_w = self.beta * ema_w + (1.0 - self.beta) * raw_weights

Paper eq. 5. Convex combination of history and the new measurement.

53self.ema_weights[:K] = ema_w.detach()

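Save back to the buffer with .detach() so the persistent buffer never stores an autograd graph. Because the norms come from create_graph=False gradients, ema_w has no history here and .detach() acts as a guard. In any variant where the weights do carry history, omitting it would chain each step's graph to the previous one through the buffer.

EXECUTION STATE
→ why keep it = Without .detach(), graph-carrying weights would link every step's graph through the buffer, causing backward() errors or memory growth with step count.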
54weights = ema_w.clamp(min=self.min_weight)

Floor at min_weight (paper eq. 6 part 1). Element-wise.

EXECUTION STATE
📚 .clamp(min=v) = Tensor method. Element-wise floor: max(x, v) per element.
55weights = weights / weights.sum()

Renormalise to the simplex (paper eq. 6 part 2). Sum=1 by construction.

56total_loss = torch.tensor(0.0, device=device)

Accumulator for the weighted combined loss. 0-dim tensor on the loss device.

57for w, l in zip(weights, losses):

Iterate (weight, loss) pairs and accumulate. zip yields tuples.

LOOP TRACE · 2 iterations
iter 0: w=lambda_rul, l=rul_loss
total_loss after = lambda_rul * rul_loss
iter 1: w=lambda_health, l=health_loss
total_loss after = lambda_rul * rul_loss + lambda_health * health_loss
58total_loss = total_loss + w * l

Out-of-place add. Builds the weighted combined loss term-by-term.

59return total_loss

The final scalar loss. Caller calls .backward() on this.

EXECUTION STATE
⬆ return = 0-dim tensor. The combined loss. Its autograd history runs through the K task losses only: the weights were built from create_graph=False gradient norms, so they enter backward() as constant coefficients.
61def get_weights(self):

Inspection helper. Returns the current EMA weights as a Python dict for monitoring code (TensorBoard, W&B, etc.).

EXECUTION STATE
⬆ returns = Dict[str, float]. Keys depend on n_tasks (named for K=2, indexed for K>2).
62w = self.ema_weights.detach().cpu()

Pull the buffer to CPU and detach for safe Python access.

EXECUTION STATE
📚 .cpu() = Tensor method. Move to CPU. No-op if already on CPU.
63if self.n_tasks == 2:

Special case for the K=2 RUL+health setup so the keys are human-readable.

64return {"rul_weight": w[0].item(), "health_weight": w[1].item()}

Dict with named keys for the two-task case.

EXECUTION STATE
📚 .item() = Tensor → Python scalar.
65return {f"task_{i}_weight": w[i].item() for i in range(self.n_tasks)}

Generic K-task dict comprehension. Keys: task_0_weight, task_1_weight, etc.

EXECUTION STATE
📚 dict comprehension = Python: {key_expr: value_expr for var in iterable}. Builds a dict in one expression.
def get_gradient_stats(self):

Logging helper. Exposes the _last_grad_norms and _last_raw_weights snapshots as a flat dict for monitoring tools.

EXECUTION STATE
⬆ returns = Dict[str, float]. Empty until the first active (post-warmup) step. After that it always holds the snapshot from the most recent active step.
stats = {}

Empty dict to accumulate keys.

if self._last_grad_norms is not None:

Skip the gradient-norm keys if we haven't run an active step yet (i.e. still in warmup).

n = self._last_grad_norms.cpu()

Pull to CPU for Python access.

stats["grad_norm_rul"] = n[0].item()

Logged g_rul from the last active step.

stats["grad_norm_health"] = n[1].item()

Logged g_health.

stats["grad_ratio_rul_over_health"] = n[0].item() / (n[1].item() + 1e-12)

The 500×-imbalance figure from §12.3. It is recomputed from the most recent active step's norms every time the helper is called, and the 1e-12 guards against division by zero. This is the most useful single number for GABA monitoring.

EXECUTION STATE
→ why log this = On C-MAPSS this should hover near 500 once training stabilises. A drift to 5,000 or 50 indicates a data or backbone problem.
if self._last_raw_weights is not None:

Same gate for the raw-weight keys.

r = self._last_raw_weights.cpu()

Pull to CPU.

stats["raw_weight_rul"] = r[0].item()

Un-smoothed lambda_rul (before EMA + floor). Useful to see how aggressive the closed form was BEFORE the stabilisers softened it.

stats["raw_weight_health"] = r[1].item()

Un-smoothed lambda_health.

return stats

Final dict. The trainer logs this every N steps to W&B / TensorBoard.

EXECUTION STATE
→ typical training-step output = {'grad_norm_rul': 76.5335, 'grad_norm_health': 0.3665, 'grad_ratio_rul_over_health': 208.8, 'raw_weight_rul': 0.004766, 'raw_weight_health': 0.995234}
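You can check that the typical output above is internally consistent by feeding the logged norms back through the closed form from forward_k. This sketch assumes the two logged norms are exact to four decimals. The last lines also check that the general-K formula always lands on the simplex, because Σ_i (S − g_i) = (K − 1)·S.

```python
import torch

# Logged gradient norms from the typical training-step output above
grad_norms = torch.tensor([76.5335, 0.3665], dtype=torch.float64)
K = grad_norms.numel()

total_norm = grad_norms.sum() + 1e-12
raw_weights = (total_norm - grad_norms) / ((K - 1) * total_norm)   # closed form from forward_k
ratio = (grad_norms[0] / (grad_norms[1] + 1e-12)).item()

print(f"ratio = {ratio:.1f}")                        # ratio = 208.8
print([round(v, 6) for v in raw_weights.tolist()])   # [0.004766, 0.995234]

# General K: the weights sum to 1 by construction
g = torch.tensor([3.0, 1.0, 0.5, 0.25], dtype=torch.float64)
w = (g.sum() - g) / ((g.numel() - 1) * g.sum())
print(round(w.sum().item(), 12))                     # 1.0
```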
"""Paper code: grace/core/gaba.py:GABALoss verbatim."""

from __future__ import annotations
import torch
import torch.nn as nn
from typing import Dict, List, Optional


def compute_task_grad_norm(loss, shared_params, retain_graph=True):
    """L2 norm of grad(loss) on shared_params (paper grace/core/gradient_utils.py)."""
    grads = torch.autograd.grad(loss, shared_params, retain_graph=retain_graph,
                                create_graph=False, allow_unused=True)
    total = torch.tensor(0.0, device=loss.device)
    for g in grads:
        if g is not None:
            total = total + g.pow(2).sum()
    return total.sqrt()


class GABALoss(nn.Module):
    """Gradient-Aware Balanced Adaptation loss for multi-task learning."""

    def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2):
        super().__init__()
        self.beta = beta
        self.warmup_steps = warmup_steps
        self.min_weight = min_weight
        self.n_tasks = n_tasks
        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
        self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))
        self._last_grad_norms: Optional[torch.Tensor] = None
        self._last_raw_weights: Optional[torch.Tensor] = None

    def forward(self, rul_loss, health_loss, shared_params=None, **kwargs):
        return self.forward_k([rul_loss, health_loss], shared_params)

    def forward_k(self, losses, shared_params=None):
        K = len(losses)
        device = losses[0].device
        self.step_count += 1
        if shared_params is None or self.step_count.item() <= self.warmup_steps:
            weights = torch.ones(K, device=device) / K
        else:
            grad_norms = torch.zeros(K, device=device)
            for i, loss_i in enumerate(losses):
                grad_norms[i] = compute_task_grad_norm(loss_i, shared_params, retain_graph=True)
            self._last_grad_norms = grad_norms.detach().clone()
            total_norm = grad_norms.sum() + 1e-12
            raw_weights = (total_norm - grad_norms) / ((K - 1) * total_norm)
            self._last_raw_weights = raw_weights.detach().clone()
            ema_w = self.ema_weights[:K].to(device)
            ema_w = self.beta * ema_w + (1.0 - self.beta) * raw_weights
            self.ema_weights[:K] = ema_w.detach()
            weights = ema_w.clamp(min=self.min_weight)
            weights = weights / weights.sum()
        total_loss = torch.tensor(0.0, device=device)
        for w, l in zip(weights, losses):
            total_loss = total_loss + w * l
        return total_loss

    def get_weights(self):
        w = self.ema_weights.detach().cpu()
        if self.n_tasks == 2:
            return {"rul_weight": w[0].item(), "health_weight": w[1].item()}
        return {f"task_{i}_weight": w[i].item() for i in range(self.n_tasks)}

    def get_gradient_stats(self):
        stats = {}
        if self._last_grad_norms is not None:
            n = self._last_grad_norms.cpu()
            stats["grad_norm_rul"]    = n[0].item()
            stats["grad_norm_health"] = n[1].item()
            stats["grad_ratio_rul_over_health"] = n[0].item() / (n[1].item() + 1e-12)
        if self._last_raw_weights is not None:
            r = self._last_raw_weights.cpu()
            stats["raw_weight_rul"]    = r[0].item()
            stats["raw_weight_health"] = r[1].item()
        return stats

Wiring Into A Real Trainer

The paper's training loop calls GABALoss like any other multi-task loss. Sketch:

  • Setup (once). Build the model and move it to the device with .to(device). Build shared = get_shared_params(model) and gaba = GABALoss(beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2).to(device). Then build the optimiser over model.parameters() only, NOT over gaba.parameters(), because GABA has no learnable parameters. Moving the model before constructing the optimiser follows the torch.optim documentation's recommendation.
  • Per step. Forward through the model ONCE; compute per-task losses; call total = gaba(rul_loss, health_loss, shared_params=shared); opt.zero_grad(); total.backward(); opt.step().
  • Logging (every N steps). Read gaba.get_weights() and gaba.get_gradient_stats() and forward both dicts to your tracking system.
  • Checkpoint save / load. Save gaba.state_dict() alongside model.state_dict(); the former contains ema_weights and step_count. On resume, call load_state_dict on both so the EMA state and the warmup counter pick up where they left off.
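The four wiring steps can be sketched end to end. Everything here is illustrative. TwoHeadNet, get_shared_params, task_grad_norm and MiniGABA are hypothetical stand-ins; MiniGABA is a condensed restatement of forward_k for K=2, not the paper file. The data are random, and warmup is shortened to 5 steps so the adaptive branch runs within 20 steps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadNet(nn.Module):
    """Hypothetical stand-in for the paper's backbone + RUL/health heads."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(14, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU())
        self.rul_head = nn.Linear(16, 1)
        self.health_head = nn.Linear(16, 3)

    def forward(self, x):
        h = self.backbone(x)
        return self.rul_head(h).squeeze(-1), self.health_head(h)


def get_shared_params(model):
    """Hypothetical helper: backbone parameters only (heads excluded)."""
    return [p for p in model.backbone.parameters() if p.requires_grad]


def task_grad_norm(loss, params):
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return torch.sqrt(sum(g.pow(2).sum() for g in grads if g is not None))


class MiniGABA(nn.Module):
    """Condensed restatement of GABALoss.forward_k for K=2 (not the paper file)."""
    def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2):
        super().__init__()
        self.beta, self.warmup_steps, self.min_weight = beta, warmup_steps, min_weight
        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
        self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))

    def forward(self, rul_loss, health_loss, shared_params=None):
        losses = [rul_loss, health_loss]
        K = len(losses)
        self.step_count += 1
        if shared_params is None or self.step_count.item() <= self.warmup_steps:
            weights = torch.ones(K, device=rul_loss.device) / K
        else:
            norms = torch.stack([task_grad_norm(l, shared_params) for l in losses])
            total = norms.sum() + 1e-12
            raw = (total - norms) / ((K - 1) * total)            # inverse-proportional closed form
            ema = self.beta * self.ema_weights + (1.0 - self.beta) * raw
            self.ema_weights.copy_(ema.detach())                # persistent EMA state
            weights = ema.clamp(min=self.min_weight)
            weights = weights / weights.sum()
        return sum(w * l for w, l in zip(weights, losses))


torch.manual_seed(0)
device = torch.device("cpu")
model = TwoHeadNet().to(device)                      # move to the device first ...
gaba = MiniGABA(warmup_steps=5).to(device)           # short warmup for the demo
shared = get_shared_params(model)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3) # ... then the optimiser, model params only

x = torch.randn(64, 14, device=device)
y_rul = torch.rand(64, device=device) * 125
y_health = torch.randint(0, 3, (64,), device=device)

for step in range(20):
    rul_pred, health_logits = model(x)               # one forward pass per step
    rul_loss = F.mse_loss(rul_pred, y_rul)
    health_loss = F.cross_entropy(health_logits, y_health)
    total = gaba(rul_loss, health_loss, shared_params=shared)
    opt.zero_grad()
    total.backward()
    opt.step()

ckpt = {"model": model.state_dict(), "gaba": gaba.state_dict()}
resumed = MiniGABA(warmup_steps=5)
resumed.load_state_dict(ckpt["gaba"])                # buffers round-trip through the checkpoint
print(int(resumed.step_count), resumed.ema_weights)
```

The EMA state and the step counter are registered buffers. As a result, the resumed module restarts from exactly the state of the saved one, and no extra bookkeeping is needed in the checkpoint dict.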
Empirical scale. On the paper's actual 3.5M-parameter CNN-BiLSTM-Attention backbone, GABALoss.forward_k adds about 5–10 ms per training step (two extra autograd.grad calls plus a few tensor ops). Total wall-clock training time on FD002 is unchanged within noise — the compute overhead is dwarfed by the per-batch forward + backward.

The Pattern In Other Multi-Task Pipelines

The same four-mechanism architecture appears in many adaptive controllers, just with different names and different formulas at each stage:

| System | Measure | Compute | Stabilise | Output |
|---|---|---|---|---|
| Predictive maintenance (this paper) | Per-task gradient norms ‖g_i‖ | Inverse-proportional weights λ_i = g_j / Σg (K=2, j ≠ i) | EMA β=0.99, floor 0.05, warmup 100 steps | Combined multi-task loss |
| Adam optimiser (Kingma & Ba 2015) | Per-parameter gradient g_t | First / second moments m_t, v_t | EMA β₁=0.9, β₂=0.999, bias correction | Per-parameter step direction and scale |
| BatchNorm (Ioffe & Szegedy 2015) | Per-channel batch mean / variance | Normalise activations | EMA running stats; training uses batch stats, inference uses running stats | Normalised activations |
| Self-supervised target network (BYOL, Grill et al. 2020; MoCo, He et al. 2020) | Online network parameters θ | Target = EMA(online) | EMA β=0.99–0.9999 | Target encoder for the self-supervised loss |
| RL Q-target (DQN, Mnih et al. 2015) | Online Q-network | Target = periodic hard copy (DQN) or soft copy (DDPG-style) of online | Polyak averaging, e.g. ρ=0.995 | Bootstrapped target value |
| Federated learning (FedAvg + secure agg.) | Client updates Δ_i | Per-client aggregation weights (sample-count weights in FedAvg; norm-based reweighting in robust variants) | Server-side outlier filtering | Aggregated global update |
| PID process control (industrial) | Plant output y_t | PID error → control u_t | Anti-windup integrator clamp | Actuator command |

The recipe ‘measure → compute → stabilise → output’ recurs in every row of the table. GABA's contribution is not the recipe itself. It lies in three specific choices: the measurement (per-task gradient norm on shared parameters), the computation (inverse-proportional closed form), and the stabilisers (β=0.99 EMA, λ_min=0.05 floor, W=100 warmup). Together these empirically work for the 500×-imbalance regime characterised in §12.3.
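In several rows of the table, the ‘stabilise’ column is the same exponential moving average. This minimal sketch uses a hypothetical in-place helper ema_ and toy tensors. It applies one update rule to two cases: a GABA-style weight vector (β=0.99) and a Polyak-style target network (ρ=0.995).

```python
import torch


def ema_(target: torch.Tensor, source: torch.Tensor, beta: float) -> None:
    """In-place EMA (hypothetical helper): target <- beta * target + (1 - beta) * source."""
    target.mul_(beta).add_(source, alpha=1.0 - beta)


# GABA row: smooth a task-weight vector with beta = 0.99
ema_weights = torch.tensor([0.5, 0.5])
raw_weights = torch.tensor([0.005, 0.995])
ema_(ema_weights, raw_weights, beta=0.99)
print(ema_weights)

# RL-target row: Polyak averaging of a target network with rho = 0.995
torch.manual_seed(0)
online, target = torch.nn.Linear(3, 2), torch.nn.Linear(3, 2)
target.load_state_dict(online.state_dict())
before = [p.detach().clone() for p in target.parameters()]
with torch.no_grad():                      # parameter updates must not be recorded by autograd
    for p in online.parameters():
        p.add_(1.0)                        # pretend the online network moved by +1
    for pt, po in zip(target.parameters(), online.parameters()):
        ema_(pt, po, beta=0.995)           # target moves 0.5% of the way, i.e. +0.005
```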

Pitfalls In Wiring The Full Module

Pitfall 1: Adding gaba.parameters() to the optimiser. GABA has no nn.Parameter objects (only buffers), so list(gaba.parameters()) is empty. Adding it to the optimiser's parameter list contributes nothing: no harm, but no help either. The real risk is later drift. Suppose you write optimizer = AdamW(list(model.parameters()) + list(gaba.parameters())), and a later revision of the loss class registers ema_weights (or any new tensor) as an nn.Parameter instead of a buffer. The optimiser would then silently start updating it as if it were a learnable weight. Always pass model.parameters() only.
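Here is a quick check of this pitfall's premise. It uses a hypothetical BufferOnly module that registers the same two buffers as GABALoss.__init__. Buffers appear in state_dict but not in parameters(), so an optimiser built from model.parameters() sees only the model's tensors.

```python
import torch
import torch.nn as nn


class BufferOnly(nn.Module):
    """Hypothetical module with the same buffer layout as GABALoss.__init__."""
    def __init__(self, n_tasks=2):
        super().__init__()
        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
        self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))


gaba_like = BufferOnly()
model = nn.Linear(14, 1)

print(len(list(gaba_like.parameters())))      # 0: nothing for an optimiser to update
print(sorted(gaba_like.state_dict().keys()))  # ['ema_weights', 'step_count']

opt = torch.optim.AdamW(model.parameters(), lr=1e-3)   # model parameters only
n_opt = sum(p.numel() for group in opt.param_groups for p in group["params"])
print(n_opt)                                  # 15 (14 weights + 1 bias)
```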
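Pitfall 2: Forgetting retain_graph=True in compute_task_grad_norm. Both task losses share the backbone's forward graph. The first gradient-norm call succeeds but frees that graph's saved tensors, so the second call crashes with RuntimeError: Trying to backward through the graph a second time. Even the last per-task call must retain the graph, because total.backward() still has to traverse it. For exactly this reason, the paper helper defaults to retain_graph=True and forward_k passes it explicitly.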
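This small reproduction assumes a hypothetical shared trunk with two heads. The first per-task autograd.grad call without retain_graph frees the shared subgraph, and the second call fails. If you rebuild the graph and retain it, both measurements and the final combined backward run.

```python
import torch

torch.manual_seed(0)
shared = torch.nn.Linear(4, 4)                      # hypothetical shared trunk
head_a, head_b = torch.nn.Linear(4, 1), torch.nn.Linear(4, 1)
params = list(shared.parameters())
x = torch.randn(8, 4)


def two_losses():
    h = torch.tanh(shared(x))                       # subgraph used by BOTH losses
    return head_a(h).pow(2).mean(), head_b(h).pow(2).mean()


# Without retain_graph: the first call frees the shared subgraph's saved tensors
loss_a, loss_b = two_losses()
torch.autograd.grad(loss_a, params)                 # retain_graph defaults to False
try:
    torch.autograd.grad(loss_b, params)
    crashed = False
except RuntimeError as err:
    crashed = True
    print("RuntimeError:", str(err).split("(")[0].strip())

# With retain_graph=True: both measurements succeed and the combined backward still works
loss_a, loss_b = two_losses()
g_a = torch.autograd.grad(loss_a, params, retain_graph=True)
g_b = torch.autograd.grad(loss_b, params, retain_graph=True)
(0.5 * loss_a + 0.5 * loss_b).backward()
print(crashed, shared.weight.grad is not None)
```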
Pitfall 3: Calling backward TWICE. Some users instinctively call rul_loss.backward(retain_graph=True) and health_loss.backward(retain_graph=True) themselves to populate .grad, then also call total.backward(). Because .backward() accumulates into p.grad, this DOUBLE-COUNTS the per-task gradients. Measure the norms with the functional torch.autograd.grad, which returns gradients without touching p.grad; this is what the paper helper uses. Then call backward ONLY on the combined total.
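Here is a scalar sketch of the double-counting effect. It uses one hypothetical shared parameter w=2 and two toy losses whose gradients are easy to read off (3 and 4). The weights 0.25 and 0.75 are arbitrary stand-ins for GABA's output.

```python
import torch


def make_losses(w):
    rul_loss = 3.0 * w        # d/dw = 3
    health_loss = w ** 2      # d/dw = 2w = 4 at w = 2
    return rul_loss, health_loss


lam_rul, lam_health = 0.25, 0.75

# Wrong: per-task .backward() calls, then the combined backward
w = torch.tensor(2.0, requires_grad=True)
rul_loss, health_loss = make_losses(w)
rul_loss.backward(retain_graph=True)
health_loss.backward(retain_graph=True)
(lam_rul * rul_loss + lam_health * health_loss).backward()
wrong = w.grad.item()         # 3 + 4 + 3.75 = 10.75

# Right: measure with autograd.grad (leaves .grad untouched), backward once
w = torch.tensor(2.0, requires_grad=True)
rul_loss, health_loss = make_losses(w)
g_rul, = torch.autograd.grad(rul_loss, [w], retain_graph=True)
g_health, = torch.autograd.grad(health_loss, [w], retain_graph=True)
(lam_rul * rul_loss + lam_health * health_loss).backward()
right = w.grad.item()         # 0.25 * 3 + 0.75 * 4 = 3.75
print(wrong, right)           # 10.75 3.75
```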
Pitfall 4: Wrong shared_params list. If you pass list(model.parameters()) instead of get_shared_params(model), the head parameters contaminate the gradient norms (§18.1 viz). The paper's 500× imbalance becomes a ~700× imbalance, with a different downstream $\lambda^*$ and slightly different convergence behaviour. Always filter to backbone-only.
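This sketch of the contamination assumes a hypothetical 14→32→16 backbone with a separate RUL head. Measuring the same loss over backbone plus head parameters always gives a larger norm than measuring over the backbone alone, because the head's squared gradients are added to the sum. It shows the direction of the bias only; it does not reproduce the 700× figure.

```python
import torch
import torch.nn as nn


def grad_norm(loss, params):
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return torch.sqrt(sum(g.pow(2).sum() for g in grads if g is not None)).item()


torch.manual_seed(0)
backbone = nn.Sequential(nn.Linear(14, 32), nn.ReLU(), nn.Linear(32, 16))
rul_head = nn.Linear(16, 1)
x, y = torch.randn(64, 14), torch.rand(64) * 125
rul_loss = nn.functional.mse_loss(rul_head(backbone(x)).squeeze(-1), y)

shared_only = list(backbone.parameters())                 # what get_shared_params(model) should return
everything = shared_only + list(rul_head.parameters())    # the mistake

g_shared = grad_norm(rul_loss, shared_only)
g_all = grad_norm(rul_loss, everything)
print(g_shared, g_all)   # g_all > g_shared: head gradients inflate the measured norm
```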
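Pitfall 5: Skipping shared_params on validation. Validation runs inside with torch.no_grad(), so no autograd graph is recorded and gradient norms cannot be measured. Once warmup is over, passing shared_params there would make torch.autograd.grad raise a RuntimeError. If you call gaba(rul_loss, health_loss) without shared_params, the gate falls through to the warmup branch and combines the losses with uniform $1/K$ weights, which is the right behaviour for validation logging. The paper code defaults shared_params=None precisely so this fallback is automatic. Note that forward_k increments step_count before the gate, so every validation call also advances the warmup counter stored in the checkpoint.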
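The shared_params=None fallback matters because a loss computed under torch.no_grad() has no graph, so any attempt to measure gradient norms fails. A minimal sketch with a hypothetical linear layer standing in for the backbone:

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 2)            # hypothetical stand-in for the shared backbone
x = torch.randn(8, 4)

with torch.no_grad():                    # how validation is usually run
    val_loss = layer(x).pow(2).mean()

print(val_loss.requires_grad)            # False: no graph was recorded
try:
    torch.autograd.grad(val_loss, list(layer.parameters()), allow_unused=True)
    raised = False
except RuntimeError:
    raised = True                        # nothing to differentiate, so no gradient norms
print(raised)                            # True
```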
The full integration test. Instantiate GABALoss and run 250 training steps on a random 14→32→16 backbone with the paper's defaults. You should see the following (verified on the 0-seed run):
  • Steps 1–100 hold $\lambda^*_{\text{rul}} = 0.5$.
  • Step 101 first reads gradient norms $g \approx (449, 0.22)$.
  • Step 150 gives $\lambda^*_{\text{rul}} \approx 0.30$.
  • Step 250 gives $\lambda^*_{\text{rul}} \approx 0.11$, still settling toward the floor-bound regime.
If your numbers don't match, you've missed one of the four mechanisms.
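The step-150 and step-250 values follow almost entirely from the EMA decay. The back-of-the-envelope sketch below assumes the raw RUL weight stays fixed at the step-101 value r = 0.22 / (449 + 0.22); in a real run it drifts. Under that assumption it reproduces both values and estimates when the EMA weight (as reported by get_weights, before floor and renormalisation) would reach the 0.05 floor.

```python
import math

# Assumption: after warmup the raw RUL weight stays fixed at its step-101 value
beta, warmup, lam0 = 0.99, 100, 0.5
r = 0.22 / (449 + 0.22)            # about 0.00049, from g ≈ (449, 0.22)


def ema_rul(step):
    """EMA RUL weight after `step` steps, assuming a constant raw weight r after warmup."""
    n = max(0, step - warmup)      # number of active (post-warmup) EMA updates
    return beta ** n * lam0 + (1 - beta ** n) * r


for step in (100, 150, 250):
    print(step, round(ema_rul(step), 2))
# 100 0.5
# 150 0.3
# 250 0.11

# Active updates until beta^n * (lam0 - r) + r drops to the 0.05 floor
n_floor = math.ceil(math.log((0.05 - r) / (lam0 - r)) / math.log(beta))
print(warmup + n_floor)            # about 330 under this assumption
```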

Takeaway

  • GABALoss is 80 lines. It has four hyperparameters and two buffers. One method does the real work (forward_k), with a K=2 wrapper (forward), and two introspection helpers (get_weights, get_gradient_stats) expose its state.
  • The forward_k method runs the full pipeline. Increment counter → gate → measure gradient norms → closed form → EMA → floor → renorm → combine. Mirrors paper Algorithm 1 line by line.
  • State lives in registered buffers. ema_weights and step_count survive .to(device) and checkpoint save/load. Logging snapshots live in unregistered slots so they don't inflate checkpoints.
  • Autograd hygiene is critical. retain_graph=True on every per-task grad call; create_graph=False to keep memory bounded; .detach() on the EMA write-back to prevent cross-step accumulation.
  • The K=2 wrapper exists for trainer compatibility. All baselines (Fixed, DWA, GradNorm, Uncertainty, PCGrad, CAGrad) share the (rul_loss, health_loss, shared_params=None, **kwargs) signature so the trainer can swap loss classes without changing the call site.
  • The pattern generalises. Adam moments, BatchNorm running stats, BYOL targets, Polyak Q-target averaging, Federated Averaging, PID anti-windup — all are instances of ‘measure → compute → stabilise → output’. GABA is just the gradient-balancing instance with the closed-form inverse rule.
  • Chapter 19 next. §19 reframes the entire pipeline as a closed-loop control system — GABA as a proportional controller, EMA as a first-order IIR low-pass filter, floor as anti-windup — and formally proves the bounded-weight guarantee that GradNorm cannot match (the property paper main.tex:387 calls ‘absent from loss-based approaches’).