Sections 18.1 through 18.4 dissected the four mechanisms of GABA in isolation: gradient-norm measurement, EMA smoothing, the minimum floor, and the warmup gate. Each was a clean idea with a clean implementation. This section assembles them into the single class the paper actually uses in production: grace/core/gaba.py:GABALoss.
The integration is non-trivial in three ways. First, all four mechanisms share state (the EMA buffer, the step counter, the logging snapshots), and the persistent part of that state has to live in registered nn.Module buffers so it survives device moves and checkpoint round-trips. Second, the autograd flow has to thread through both the per-task gradient measurements (which use torch.autograd.grad) and the eventual weighted-combination backward (which uses loss.backward()) without conflicts. Third, the public API needs an inspection surface so monitoring code can pull live $g_i$, $\hat{\lambda}_i$, $\lambda_i^*$ values without breaking encapsulation.
The full listing is about 80 lines. Those 80 lines of Python implement an algorithm that drops the paper's FD002 NASA score from 498 (DKAMFormer) to 224, a 55% operational-safety improvement. Two ideas (closed form + EMA) plus two stabilisers (floor + warmup) plus careful autograd hygiene plus two logging slots and two inspection helpers. That is the complete recipe.
The Class Anatomy
Four methods (besides __init__), two registered buffers, two logging slots, four hyperparameters:
| Component | Type | Purpose | Defined in § |
|---|---|---|---|
| beta | Python float | EMA smoothing coefficient. Paper: 0.99. | §18.2 |
| warmup_steps | Python int | Steps with uniform 1/K weighting before adaptive logic kicks in. Paper: 100. | §18.4 |
| min_weight | Python float | Floor on per-task weight. Paper: 0.05. | §18.3 |
| n_tasks | Python int | K. RUL + health for this paper. | §17 (model anatomy) |
| ema_weights | Registered buffer (K,) | EMA-smoothed task weights. Persistent across steps and checkpoints. | §18.2 |
| step_count | Registered buffer (long) | Step counter. Used by warmup gate. Persistent. | §18.4 |
| _last_grad_norms | Optional Tensor | Logging snapshot of per-task gradient norms. | §18.5 (this section) |
| _last_raw_weights | Optional Tensor | Logging snapshot of pre-EMA raw weights. | §18.5 |
| forward(rul_loss, health_loss, shared_params) | Method | K=2 convenience wrapper. | - |
| forward_k(losses, shared_params) | Method (workhorse) | Full pipeline: gate → grad norms → closed form → EMA → floor → renorm → combine. | this section |
| get_weights() | Method | Inspection: return current EMA weights as dict. | this section |
| get_gradient_stats() | Method | Inspection: return last gradient norms and raw weights as dict. | this section |
__init__: Hyperparameters and Buffers
Construction stores four hyperparameters as plain Python attributes (they are not learnable) and registers two buffers via nn.Module.register_buffer:
ema_weights initialised to $\mathbf{1}_K / K$, i.e. uniform 1/K. Paper Algorithm 1 line 2.
step_count initialised to 0 as a 0-dim long tensor.
The two logging slots _last_grad_norms and _last_raw_weights are NOT registered as buffers. They are debugging snapshots, not state we want to checkpoint. Saving them in state_dict would inflate checkpoints unnecessarily and re-introduce a fictitious dependency between training resumption and the most recent gradient measurement.
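To make the buffer-versus-slot split concrete, here is a minimal sketch. TinyState is a hypothetical stand-in, not the paper class: it registers one buffer and keeps one plain attribute, then checks which of the two reaches state_dict and survives a reload.

```python
import torch
import torch.nn as nn


class TinyState(nn.Module):
    """Hypothetical stand-in: one registered buffer, one plain logging slot."""

    def __init__(self, n_tasks=2):
        super().__init__()
        # Registered buffer: saved in state_dict and moved by .to(device).
        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
        # Plain attribute: invisible to state_dict, like the logging slots.
        self._last_grad_norms = None


src = TinyState()
src.ema_weights[:] = torch.tensor([0.3, 0.7])        # pretend training moved the EMA
src._last_grad_norms = torch.tensor([449.0, 0.22])   # pretend a snapshot was taken

saved = src.state_dict()
saved_keys = sorted(saved.keys())

dst = TinyState()
dst.load_state_dict(saved)

print(saved_keys)             # only the buffer is checkpointed
print(dst.ema_weights)        # restored from the checkpoint
print(dst._last_grad_norms)   # still None: snapshots are not persisted
```

The same check can be run against the real class after a few training steps to confirm that only ema_weights and step_count appear in its state_dict.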
forward(): K=2 Convenience Wrapper
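The K=2 convenience method is two lines in the listing: a signature and a single return. It packages $L_{\text{rul}}$ and $L_{\text{health}}$ into a list and delegates to the K-task workhorse via return self.forward_k([rul_loss, health_loss], shared_params).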
The wrapper exists for backward compatibility with the paper's baseline trainer signature, which expects a (rul_loss, health_loss, shared_params=None, **kwargs) contract. The **kwargs catches stray trainer-provided kwargs (e.g. model for GradNorm baselines) so all loss classes can share the same call site.
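A rough sketch of why the shared contract is convenient. Both loss classes and the training_step function below are hypothetical illustrations (not the paper's baselines); they only show how one call site can serve loss classes with different needs because unused keyword arguments are swallowed by **kwargs.

```python
class FixedWeightLoss:
    """Hypothetical baseline: ignores shared_params and any extra kwargs."""

    def __init__(self, w_rul=0.5):
        self.w_rul = w_rul

    def __call__(self, rul_loss, health_loss, shared_params=None, **kwargs):
        return self.w_rul * rul_loss + (1.0 - self.w_rul) * health_loss


class NeedsModelLoss:
    """Hypothetical baseline that consumes a `model` kwarg (GradNorm-style)."""

    def __call__(self, rul_loss, health_loss, shared_params=None, model=None, **kwargs):
        if model is None:
            raise ValueError("this baseline needs the model")
        return rul_loss + health_loss


def training_step(criterion, rul_loss, health_loss, shared_params, model):
    # One call site for every loss class; kwargs a class does not use are ignored.
    return criterion(rul_loss, health_loss, shared_params=shared_params, model=model)


fixed_total = training_step(FixedWeightLoss(0.5), 4.0, 2.0, shared_params=[], model="m")
sum_total = training_step(NeedsModelLoss(), 4.0, 2.0, shared_params=[], model="m")
print(fixed_total, sum_total)
```

Plain floats stand in for loss tensors here so the sketch runs without PyTorch.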
forward_k(): The Workhorse
The general K-task method runs the full pipeline. Pseudocode:
1. Increment step_count.
2. Warmup gate: if $t \le W$ OR shared_params is None, use uniform 1/K weights and jump straight to the combine step.
3. Per-task gradient norms: loop K times; each call to compute_task_grad_norm uses torch.autograd.grad with retain_graph=True.
4. Closed form (§17.3 + §18.1): in vector form, $\lambda_i = (S - g_i) / ((K-1)\,S)$ with $S = \sum_j g_j$.
5. EMA update (§18.2): $\hat{\lambda} \leftarrow \beta \hat{\lambda} + (1-\beta)\lambda$, saved back to the buffer with .detach().
6. Floor + renorm (§18.3): clamp(min=...) followed by division by .sum().
7. Combine: weighted sum of the K losses → scalar.
8. Return the scalar; the caller calls .backward() and optimiser.step().
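The numeric core of the pipeline (closed form, EMA, floor, renorm) can be traced by hand. The sketch below uses NumPy and made-up gradient norms for illustration; only beta and min_weight come from the paper defaults quoted in this section.

```python
import numpy as np

beta, min_weight = 0.99, 0.05            # paper defaults quoted in this section

# --- Closed form for K = 3 (gradient norms are made-up illustration values) ---
g = np.array([8.0, 1.5, 0.5])
K = len(g)
S = g.sum()                              # 10.0
raw = (S - g) / ((K - 1) * S)            # [0.1, 0.425, 0.475], sums to 1

# --- EMA update from the uniform initial state ---
ema = np.full(K, 1.0 / K)
ema = beta * ema + (1.0 - beta) * raw    # moves 1% of the way toward raw

# --- Floor + renorm on a drifted EMA state (K = 2, illustrative) ---
drifted = np.array([0.01, 0.99])
clamped = np.maximum(drifted, min_weight)    # [0.05, 0.99]
weights = clamped / clamped.sum()            # divide by 1.04

print("raw     :", raw)
print("ema     :", ema)
print("weights :", weights)
```

One detail worth noticing: because the floor is applied before the division, the floored weight ends up at 0.05 / 1.04, slightly below min_weight.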
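The detach is the trickiest line: self.ema_weights[:K] = ema_w.detach(), line 53 of the paper code. With create_graph=False the measured norms carry no graph, so in the listing as written the call acts as a safeguard. If the raw weights ever do carry a graph (for example after switching to create_graph=True), then without .detach() the autograd graph through the EMA buffer accumulates ACROSS STEPS. After 1,000 steps, backward() walks all the way back to step 0 and exhausts GPU memory. This is the most-cited GABA implementation bug.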
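A small sketch of the failure mode, under the explicit assumption that the raw weights carry autograd history (here forced through a scalar named scale that requires grad; this is an illustration device, not something the paper code does). Without the detach the EMA tensor keeps a grad_fn chain that grows with every step; with it, the history is cut each time.

```python
import torch

beta = 0.99
# Assumption for illustration: the raw weights depend on a tensor that
# requires grad, as they would if norms were built with create_graph=True.
scale = torch.tensor(1.0, requires_grad=True)


def run(steps, detach):
    ema = torch.tensor([0.5, 0.5])
    for _ in range(steps):
        raw = torch.tensor([0.1, 0.9]) * scale   # carries a graph via `scale`
        ema = beta * ema + (1.0 - beta) * raw
        if detach:
            ema = ema.detach()                   # cut the history at every step
    return ema


leaky = run(3, detach=False)
clean = run(3, detach=True)

print(leaky.grad_fn is not None)   # True: history chained across all steps
print(clean.grad_fn is None)       # True: nothing retained
```

The numeric values of the two results are identical; only the retained graph differs.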
Logging Buffers and Inspection Helpers
Two methods expose internal state for monitoring without breaking encapsulation:
get_weights() returns the current EMA weights as a Dict[str, float]. For K=2 the keys are rul_weight and health_weight; for general K they are task_0_weight, etc.
get_gradient_stats() returns the last gradient norms, the gradient ratio (most useful single number), and the un-smoothed raw weights. Empty dict during warmup; populated keys during active steps.
Both helpers are read-only and call .detach().cpu() on every tensor before exposing it, so the trainer can log to W&B / TensorBoard without risk of breaking the autograd graph.
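One possible way to consume the two dicts, sketched without any tracking library. build_log_record and the gaba/ prefix are hypothetical; the input dicts are hand-written in the shape the helpers return for K=2, with illustrative numbers.

```python
def build_log_record(step, weights, grad_stats, prefix="gaba/"):
    """Flatten the two helper dicts into one prefixed record (hypothetical helper)."""
    record = {"step": step}
    for source in (weights, grad_stats):
        for key, value in source.items():
            record[prefix + key] = float(value)
    return record


# Dicts shaped like the K=2 helper outputs; the numbers are illustrative.
weights = {"rul_weight": 0.30, "health_weight": 0.70}
warmup_stats = {}  # get_gradient_stats() is empty until the first active step
active_stats = {
    "grad_norm_rul": 449.0,
    "grad_norm_health": 0.22,
    "grad_ratio_rul_over_health": 449.0 / 0.22,
    "raw_weight_rul": 0.00049,
    "raw_weight_health": 0.99951,
}

warmup_record = build_log_record(50, {"rul_weight": 0.5, "health_weight": 0.5}, warmup_stats)
active_record = build_log_record(150, weights, active_stats)
print(sorted(warmup_record))
print(sorted(active_record))
```

Because the stats dict is simply empty during warmup, the same logging call works in both phases without a special case.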
Interactive: Live GABA Dashboard
Four panels share the same x-axis (training step). The amber band is warmup. Drag the red ‘current step’ slider to see all four pipeline state variables at the same moment in training. The third panel (smoothed task weights) is the headline output; the first two show the inputs the closed form consumes; the fourth shows the resulting combined loss.
Try this. Set W = 0 and watch panel 3 immediately drop from 0.5 to a noisy trajectory. The first 100 steps now run on cold-start gradients — observe how panel 2 (gradient ratio) is wildly oscillating during this period, justifying §18.4's warmup gate. Restore W = 100 and watch panel 3 stay flat through warmup, then begin its smooth descent toward the steady-state weight after step 100.
Python: Full Class From Scratch (NumPy)
A pedagogical NumPy mirror of the paper class. Same hyperparameters, same buffers, same per-step pipeline, same bounded-weight guarantee — without autograd, so the structure is fully visible. A 5-step smoke test with warmup_steps=2 exercises both branches of the gate.
Full GABA pipeline in pure NumPy (gaba_loss_numpy.py). Selected line annotations:
Line 1 (docstring): Module docstring. The class below mirrors the paper's grace/core/gaba.py:GABALoss line for line, but in pure NumPy so the structure is visible without PyTorch.
Line 3 (import numpy as np): NumPy supplies the ndarray and all array math used here: np.full, np.maximum, np.dot, np.asarray.
Line 6 (class GABALossNumPy:): Plain Python class (no nn.Module subclass needed since there is no autograd). Same four hyperparameters, same two state variables, same two logging variables as the paper class.
Line 7 (docstring): Records the relationship to the paper class. Drop-in NumPy version for understanding.
Line 49 (gaba = GABALossNumPy(...)): Instantiate with warmup_steps=2 (small) so the smoke test exits warmup quickly. This is a test-only override: paper warmup is 100; 2 is used here to demo the active branch within 5 steps.
Line 51 (trace = [...]): Hand-built test sequence: 5 (step, L_rul, L_hp, g_rul, g_health) tuples. Steps 1-2 are warmup (g=None); steps 3-5 are active with realistic gradient magnitudes.
    """GABA from scratch in pure NumPy - pedagogical mirror of paper class."""

    import numpy as np


    class GABALossNumPy:
        """NumPy version of grace/core/gaba.py:GABALoss for clarity (no autograd).

        Same hyperparameters, same buffers, same per-step pipeline, same
        bounded-weight guarantee. Drop-in replacement for the paper class on
        any platform that doesn't have PyTorch.
        """

        def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2):
            self.beta = beta
            self.warmup_steps = warmup_steps
            self.min_weight = min_weight
            self.n_tasks = n_tasks
            # State (mirrors paper register_buffer)
            self.ema_weights = np.full(n_tasks, 1.0 / n_tasks)
            self.step_count = 0
            # Logging buffers (mirrors paper _last_grad_norms, _last_raw_weights)
            self.last_grad_norms = None
            self.last_raw_weights = None

        def step(self, losses, grad_norms):
            """One full GABA step. losses and grad_norms are length-K lists."""
            K = len(losses)
            self.step_count += 1
            if self.step_count <= self.warmup_steps:
                # Warmup: uniform weights, no gradient measurement
                weights = np.full(K, 1.0 / K)
            else:
                g = np.asarray(grad_norms, dtype=np.float64)
                self.last_grad_norms = g.copy()
                total_norm = g.sum() + 1e-12
                # Closed form (paper eq. 4)
                raw = (total_norm - g) / ((K - 1) * total_norm)
                self.last_raw_weights = raw.copy()
                # EMA (paper eq. 5)
                self.ema_weights = self.beta * self.ema_weights + (1 - self.beta) * raw
                # Floor + renorm (paper eq. 6)
                clamped = np.maximum(self.ema_weights, self.min_weight)
                weights = clamped / clamped.sum()
            return float(np.dot(weights, losses)), weights


    # ---------- 5-step smoke test (warmup=2 so we exit early) ----------
    gaba = GABALossNumPy(beta=0.99, warmup_steps=2, min_weight=0.05, n_tasks=2)

    trace = [
        (1, 5000.0, 1.10, None, None),
        (2, 4900.0, 1.08, None, None),
        (3, 4800.0, 1.06, 250.0, 0.20),
        (4, 4700.0, 1.04, 300.0, 0.22),
        (5, 4600.0, 1.02, 350.0, 0.24),
    ]

    for step, L_rul, L_hp, g_r, g_h in trace:
        losses = [L_rul, L_hp]
        grads = [g_r, g_h] if g_r is not None else None
        total, w = gaba.step(losses, grads)
        print(f"step {step} | L_rul={L_rul:.1f} L_hp={L_hp:.2f} | "
              f"w=({w[0]:.6f}, {w[1]:.6f}) | total={total:.2f}")

    print(f"\nfinal ema_weights = {gaba.ema_weights}")
    print(f"final step_count  = {gaba.step_count}")
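As a hand check of the first active step (step 3 of the trace, with $K = 2$ and $\beta = 0.99$), the arithmetic below follows the three equations in the listing. The $10^{-12}$ stabiliser is dropped because it does not affect the digits shown, and the floor is inactive because both smoothed weights exceed 0.05 and already sum to 1.

```latex
\begin{aligned}
S &= g_{\text{rul}} + g_{\text{health}} = 250 + 0.20 = 250.2 \\
\lambda_{\text{rul}} &= \frac{S - g_{\text{rul}}}{(K-1)\,S} = \frac{0.20}{250.2} \approx 0.000799,
\qquad
\lambda_{\text{health}} = \frac{S - g_{\text{health}}}{(K-1)\,S} = \frac{250}{250.2} \approx 0.999201 \\
\hat{\lambda}_{\text{rul}} &= 0.99 \cdot 0.5 + 0.01 \cdot 0.000799 \approx 0.495008,
\qquad
\hat{\lambda}_{\text{health}} = 0.99 \cdot 0.5 + 0.01 \cdot 0.999201 \approx 0.504992 \\
L_{\text{total}} &\approx 0.495008 \cdot 4800 + 0.504992 \cdot 1.06 \approx 2376.57
\end{aligned}
```

Even though the raw weight for RUL is already below 0.001, a single EMA step only moves the smoothed weight from 0.5 to about 0.495.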
PyTorch: The Paper's GABALoss Verbatim
The actual paper code from paper_ieee_tii/grace/core/gaba.py. Every line annotated. The compute_task_grad_norm helper at the top is also paper code (from grace/core/gradient_utils.py); the remaining methods are the full GABALoss class as it appears in the public release.
grace/core/gaba.py, verbatim (gaba_loss_paper.py). Selected line annotations:
Line 1 (docstring): Module docstring. The class below is the actual paper code from grace/core/gaba.py, a line-for-line copy. The compute_task_grad_norm helper is also paper code from grace/core/gradient_utils.py.
Line 3 (from __future__ import annotations): Python feature flag that defers evaluation of type hints. Annotations are stored as strings, so forward references need no quoting.
Line 4 (import torch): Core PyTorch, the tensor library with autograd. Used for tensors, register_buffer, autograd.grad, .clamp, .detach.
Line 5 (import torch.nn as nn): Module primitives. Provides the nn.Module base class.
Line 6 (from typing import Dict, List, Optional): Type-hint aliases from the standard library typing module (PEP 484): Dict[str, float], List[Tensor], Optional[X] = X | None.
Lines 11-12 (torch.autograd.grad call): allow_unused=True tolerates parameters disconnected from this loss.
Line 13 (total = torch.tensor(0.0, device=loss.device)): Accumulator for the squared-norm sum. Built on the same device as the loss.
Line 14 (for g in grads:): Iterate per-parameter gradient tensors and accumulate squared L2 norms. Skip None entries from allow_unused.
Line 15 (if g is not None:): Filter out the None entries.
Line 16 (total = total + g.pow(2).sum()): Out-of-place add to keep autograd happy. Sum of squared elements per parameter.
Line 17 (return total.sqrt()): Final square root. $\|g\|_2 = \sqrt{\sum_p \|g_p\|_2^2}$.
Line 20 (class GABALoss(nn.Module):): The full paper class. nn.Module is the base class for stateful PyTorch components; it tracks parameters, buffers, and submodules. Subclassing it gives buffer persistence and checkpoint compatibility. ALL persistent state (ema_weights, step_count) lives in registered buffers.
Line 31 (self._last_grad_norms): Logging slot for the per-task gradient norms from the last active step. NOT registered as a buffer: it is a debugging aid, not state we want to checkpoint. The Optional[torch.Tensor] hint means tensor or None: None during warmup; tensor (K,) during active steps.
Line 34 (def forward): Two-task convenience wrapper. Just packages the two losses into a list and calls forward_k. Backward-compatible with baselines that have a fixed (rul, health) signature.
Line 51 (ema_w = self.ema_weights[:K].to(device)): Slice the EMA buffer to length K and move it to the loss device. The slice supports K-flexible deployments where n_tasks could be larger than the K used in this call; for K=2 with n_tasks=2 it is a no-op. The .to(device) call is also a no-op if the tensor is already there.
    """Paper code: grace/core/gaba.py:GABALoss verbatim."""

    from __future__ import annotations
    import torch
    import torch.nn as nn
    from typing import Dict, List, Optional


    def compute_task_grad_norm(loss, shared_params, retain_graph=True):
        """L2 norm of grad(loss) on shared_params (paper grace/core/gradient_utils.py)."""
        grads = torch.autograd.grad(loss, shared_params, retain_graph=retain_graph,
                                    create_graph=False, allow_unused=True)
        total = torch.tensor(0.0, device=loss.device)
        for g in grads:
            if g is not None:
                total = total + g.pow(2).sum()
        return total.sqrt()


    class GABALoss(nn.Module):
        """Gradient-Aware Balanced Adaptation loss for multi-task learning."""

        def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2):
            super().__init__()
            self.beta = beta
            self.warmup_steps = warmup_steps
            self.min_weight = min_weight
            self.n_tasks = n_tasks
            self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
            self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))
            self._last_grad_norms: Optional[torch.Tensor] = None
            self._last_raw_weights: Optional[torch.Tensor] = None

        def forward(self, rul_loss, health_loss, shared_params=None, **kwargs):
            return self.forward_k([rul_loss, health_loss], shared_params)

        def forward_k(self, losses, shared_params=None):
            K = len(losses)
            device = losses[0].device
            self.step_count += 1
            if shared_params is None or self.step_count.item() <= self.warmup_steps:
                weights = torch.ones(K, device=device) / K
            else:
                grad_norms = torch.zeros(K, device=device)
                for i, loss_i in enumerate(losses):
                    grad_norms[i] = compute_task_grad_norm(loss_i, shared_params, retain_graph=True)
                self._last_grad_norms = grad_norms.detach().clone()
                total_norm = grad_norms.sum() + 1e-12
                raw_weights = (total_norm - grad_norms) / ((K - 1) * total_norm)
                self._last_raw_weights = raw_weights.detach().clone()
                ema_w = self.ema_weights[:K].to(device)
                ema_w = self.beta * ema_w + (1.0 - self.beta) * raw_weights
                self.ema_weights[:K] = ema_w.detach()
                weights = ema_w.clamp(min=self.min_weight)
                weights = weights / weights.sum()
            total_loss = torch.tensor(0.0, device=device)
            for w, l in zip(weights, losses):
                total_loss = total_loss + w * l
            return total_loss

        def get_weights(self):
            w = self.ema_weights.detach().cpu()
            if self.n_tasks == 2:
                return {"rul_weight": w[0].item(), "health_weight": w[1].item()}
            return {f"task_{i}_weight": w[i].item() for i in range(self.n_tasks)}

        def get_gradient_stats(self):
            stats = {}
            if self._last_grad_norms is not None:
                n = self._last_grad_norms.cpu()
                stats["grad_norm_rul"] = n[0].item()
                stats["grad_norm_health"] = n[1].item()
                stats["grad_ratio_rul_over_health"] = n[0].item() / (n[1].item() + 1e-12)
            if self._last_raw_weights is not None:
                r = self._last_raw_weights.cpu()
                stats["raw_weight_rul"] = r[0].item()
                stats["raw_weight_health"] = r[1].item()
            return stats
Wiring Into A Real Trainer
The paper's training loop calls GABALoss like any other multi-task loss. Sketch:
Setup (once). Build the model, build shared = get_shared_params(model), build gaba = GABALoss(beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2), build the optimiser over model.parameters() (NOT over gaba.parameters() — GABA has no learnable parameters). Move all to device with .to(device).
Per step. Forward through the model ONCE; compute per-task losses; call total = gaba(rul_loss, health_loss, shared_params=shared); opt.zero_grad(); total.backward(); opt.step().
Logging (every N steps). Read gaba.get_weights() and gaba.get_gradient_stats() and forward both dicts to your tracking system.
Checkpoint save / load. Save gaba.state_dict() (which contains ema_weights and step_count) alongside the model's state_dict. On resume, load_state_dict restores both, so the warmup gate and the smoothed weights continue from where they stopped.
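The four wiring steps above can be sketched end to end on a toy problem. Everything model-related here is an assumption for illustration: TwoHeadNet, this get_shared_params, the random data, and MiniGABALoss, which is a condensed K=2 copy of the paper listing kept short so the sketch stays self-contained. The warmup is shortened to 2 steps so both branches of the gate run.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadNet(nn.Module):
    """Hypothetical toy model: shared backbone plus RUL and health heads."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(14, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU())
        self.rul_head = nn.Linear(16, 1)
        self.health_head = nn.Linear(16, 3)

    def forward(self, x):
        z = self.backbone(x)
        return self.rul_head(z).squeeze(-1), self.health_head(z)


def get_shared_params(model):
    # Assumption: "shared" means backbone only; the paper's helper may differ.
    return list(model.backbone.parameters())


class MiniGABALoss(nn.Module):
    """Condensed K=2 copy of the paper listing (illustration only)."""

    def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05):
        super().__init__()
        self.beta, self.warmup_steps, self.min_weight = beta, warmup_steps, min_weight
        self.register_buffer("ema_weights", torch.ones(2) / 2)
        self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))

    def forward(self, rul_loss, health_loss, shared_params=None, **kwargs):
        self.step_count += 1
        if shared_params is None or self.step_count.item() <= self.warmup_steps:
            weights = torch.ones(2, device=rul_loss.device) / 2
        else:
            norms = []
            for loss in (rul_loss, health_loss):
                grads = torch.autograd.grad(loss, shared_params, retain_graph=True, allow_unused=True)
                squares = [g.pow(2).sum() for g in grads if g is not None]
                norms.append(torch.stack(squares).sum().sqrt())
            g = torch.stack(norms)
            total = g.sum() + 1e-12
            raw = (total - g) / total                      # K = 2, so (K - 1) = 1
            ema = self.beta * self.ema_weights + (1.0 - self.beta) * raw
            self.ema_weights.copy_(ema.detach())           # EMA write-back, detached
            weights = ema.clamp(min=self.min_weight)
            weights = weights / weights.sum()
        return weights[0] * rul_loss + weights[1] * health_loss


torch.manual_seed(0)
model = TwoHeadNet()
shared = get_shared_params(model)
gaba = MiniGABALoss(warmup_steps=2)                        # short warmup for the demo
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)       # model parameters only

x = torch.randn(8, 14)
rul_target = torch.rand(8) * 100.0
health_target = torch.randint(0, 3, (8,))

for step in range(5):
    rul_pred, health_logits = model(x)                     # one forward pass
    rul_loss = F.mse_loss(rul_pred, rul_target)
    health_loss = F.cross_entropy(health_logits, health_target)
    total = gaba(rul_loss, health_loss, shared_params=shared)
    opt.zero_grad()
    total.backward()                                       # the only backward call
    opt.step()

# Checkpoint round-trip: GABA state is saved next to the model state.
ckpt = {"model": model.state_dict(), "gaba": gaba.state_dict()}
resumed = MiniGABALoss(warmup_steps=2)
resumed.load_state_dict(ckpt["gaba"])
print(int(resumed.step_count), resumed.ema_weights)
```

After the reload, the resumed instance reports the same step count and smoothed weights as the original, which is what lets training restart without re-entering warmup.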
Empirical scale. On the paper's actual 3.5M-parameter CNN-BiLSTM-Attention backbone, GABALoss.forward_k adds about 5–10 ms per training step (two extra autograd.grad calls plus a few tensor ops). Total wall-clock training time on FD002 is unchanged within noise — the compute overhead is dwarfed by the per-batch forward + backward.
The Pattern In Other Multi-Task Pipelines
The same four-mechanism architecture appears in many adaptive controllers, just with different names and different formulas at each stage:
| System | Measure | Compute | Stabilise | Output |
|---|---|---|---|---|
| Predictive maintenance (this paper) | Per-task gradient norms ‖g_i‖ | Inverse-proportional weights, λ_i = g_j / Σg for K=2 | EMA β=0.99, floor 0.05, warmup 100 steps | Combined multi-task loss |
| Adam optimiser (Kingma & Ba 2015) | Per-parameter gradient g_t | First / second moment m_t, v_t | EMA β₁=0.9, β₂=0.999, bias correction | Per-parameter step direction & scale |
| BatchNorm (Ioffe & Szegedy 2015) | Per-channel batch mean / variance | Normalise activations | EMA running stats; training uses batch stats | Normalised activations |
| Self-supervised target network (BYOL, Grill et al. 2020; MoCo, He et al. 2020) | Online network parameters θ | Target = EMA(online) | EMA β=0.99 to 0.9999 | Target encoder for the self-supervised loss |
| RL Q-target (DQN, Mnih et al. 2015; soft-update variants) | Online Q-network | Target = soft-copy of online | Polyak averaging τ=0.995 | Bootstrapped target value |
| Federated learning (FedAvg + secure agg.) | Client gradient updates Δ_i | Inverse-norm weights w_i ∝ 1/‖Δ_i‖ | Server-side outlier filtering | Aggregated global update |
| PID control (industrial) | Plant output y_t | PID error → control u_t | Anti-windup integrator clamp | Actuator command |
The recipe ‘measure → compute → stabilise → output’ is universal. GABA's contribution is not the recipe itself but the specific measurement (per-task gradient norm on shared parameters), the specific computation (inverse-proportional closed form), and the specific stabilisers (β=0.99 EMA, λ_min=0.05 floor, W=100 warmup) that empirically work for the 500×-imbalance regime characterised in §12.3.
Pitfalls In Wiring The Full Module
Pitfall 1: Adding gaba.parameters() to the optimiser. GABA has no nn.Parameter objects (only buffers), so gaba.parameters() is empty and appending it to the optimiser's parameter list changes nothing today: no harm but also no help. The risk is latent. If you write optimizer = AdamW(list(model.parameters()) + list(gaba.parameters())) and a later change turns a buffer into a parameter, the optimiser would silently start updating ema_weights as if it were a learnable weight. Always pass model.parameters() only.
Pitfall 2: Forgetting retain_graph=True in compute_task_grad_norm. First gradient norm computes; second one crashes with RuntimeError: Trying to backward through the graph a second time. The paper helper hard-codes retain_graph=True for exactly this reason.
Pitfall 3: Calling backward TWICE. Some users instinctively call rul_loss.backward(retain_graph=True) and health_loss.backward(retain_graph=True) themselves to populate .grad, then also call total.backward(). This DOUBLE-COUNTS the per-task gradients into p.grad. Use the functional autograd.grad (which the paper helper does) and then call backward ONLY on the combined total.
Pitfall 4: Wrong shared_params list. If you pass list(model.parameters()) instead of get_shared_params(model), the head parameters contaminate the gradient norms (§18.1 viz). The paper's 500× imbalance becomes a ~700× imbalance with different downstream λ∗ and slightly different convergence behaviour. Always filter to backbone-only.
Pitfall 5: Skipping shared_params on validation. Validation runs under with torch.no_grad():, so gradients are unavailable. If you call gaba(rul_loss, health_loss) there without passing shared_params, the gate takes the warmup branch and returns 1/K uniform weights, which is the right behaviour for validation logging. The paper code defaults shared_params=None precisely so this fallback is automatic. Note that the call still increments step_count before the gate is evaluated, so validation calls made during the first W training steps shorten the effective warmup.
The full integration test. If you instantiate GABALoss and run 250 training steps on a random 14→32→16 backbone with the paper's defaults, you should see (verified on the seed-0 run): steps 1–100 hold $\lambda^*_{\text{rul}} = 0.5$; step 101 first reads gradient norms $g \approx (449, 0.22)$; at step 150, $\lambda^*_{\text{rul}} \approx 0.30$; at step 250, $\lambda^*_{\text{rul}} \approx 0.11$ and still settling toward the floor-bound regime. If your numbers don't match, you've missed one of the four mechanisms.
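Pitfall 3 is easy to verify on a two-element tensor. In this sketch w stands in for the shared parameters and the two toy losses are chosen so the gradients can be checked by hand; none of it is paper code. The functional torch.autograd.grad leaves .grad untouched, while the extra .backward() calls add the per-task gradients on top of the combined one.

```python
import torch


def grads_after(double_count):
    w = torch.tensor([1.0, 2.0], requires_grad=True)   # stands in for shared params
    loss_a = (w ** 2).sum()                             # d/dw = 2w
    loss_b = (3.0 * w).sum()                            # d/dw = 3
    if double_count:
        # The mistake: populate .grad per task, then backward the total as well.
        loss_a.backward(retain_graph=True)
        loss_b.backward(retain_graph=True)
    else:
        # What the paper helper does: functional grads, .grad stays untouched.
        torch.autograd.grad(loss_a, [w], retain_graph=True)
        torch.autograd.grad(loss_b, [w], retain_graph=True)
    total = 0.5 * loss_a + 0.5 * loss_b
    total.backward()
    return w.grad.clone()


clean = grads_after(double_count=False)     # 0.5*2w + 0.5*3 = w + 1.5
polluted = grads_after(double_count=True)   # (2w + 3) + (w + 1.5) = 3w + 4.5
print(clean, polluted)
```

With w = (1, 2) the clean gradient is (2.5, 3.5), while the double-counted one is (7.5, 10.5), three times larger on both terms.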
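The quoted checkpoints can be cross-checked without a model. The sketch below makes one simplifying assumption: the gradient norms stay fixed at the quoted (449, 0.22) for every active step, which a real run will not do exactly. Under that assumption the trajectory is driven by the EMA alone.

```python
beta, warmup_steps, min_weight = 0.99, 100, 0.05   # paper defaults
g_rul, g_health = 449.0, 0.22                      # quoted norms, held constant (assumption)

# K = 2 closed form: each task's raw weight is the other task's share of the norm.
total = g_rul + g_health
raw = (g_health / total, g_rul / total)

ema = [0.5, 0.5]
snapshots = {}
for step in range(1, 251):
    if step > warmup_steps:                        # the gate opens at step 101
        ema = [beta * e + (1.0 - beta) * r for e, r in zip(ema, raw)]
    clamped = [max(e, min_weight) for e in ema]
    weights = [c / sum(clamped) for c in clamped]
    if step in (100, 150, 250):
        snapshots[step] = weights[0]               # lambda*_rul at this step

for step, value in snapshots.items():
    print(f"step {step}: lambda_rul = {value:.3f}")
```

The RUL weight sits at 0.5 through step 100 and decays roughly as 0.5 · 0.99^n over n active steps, which lands near 0.30 after 50 steps and near 0.11 after 150, in line with the figures quoted for the seed-0 run.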
Takeaway
GABALoss is about 80 lines. Four hyperparameters, two buffers, two methods that do real work (forward, forward_k), two introspection helpers (get_weights, get_gradient_stats).
The forward_k method runs the full pipeline. Increment counter → gate → measure gradient norms → closed form → EMA → floor → renorm → combine. Mirrors paper Algorithm 1 line by line.
State lives in registered buffers. ema_weights and step_count survive .to(device) and checkpoint save/load. Logging snapshots live in unregistered slots so they don't inflate checkpoints.
Autograd hygiene is critical. retain_graph=True on every per-task grad call; create_graph=False to keep memory bounded; .detach() on the EMA write-back to prevent cross-step accumulation.
The K=2 wrapper exists for trainer compatibility. All baselines (Fixed, DWA, GradNorm, Uncertainty, PCGrad, CAGrad) share the (rul_loss, health_loss, shared_params=None, **kwargs) signature so the trainer can swap loss classes without changing the call site.
The pattern generalises. Adam moments, BatchNorm running stats, BYOL targets, Polyak Q-target averaging, Federated Averaging, PID anti-windup — all are instances of ‘measure → compute → stabilise → output’. GABA is just the gradient-balancing instance with the closed-form inverse rule.
Chapter 19 next. §19 reframes the entire pipeline as a closed-loop control system — GABA as a proportional controller, EMA as a first-order IIR low-pass filter, floor as anti-windup — and formally proves the bounded-weight guarantee that GradNorm cannot match (the property paper main.tex:387 calls ‘absent from loss-based approaches’).