Chapter 18
Section 75 of 121

Warmup (First 100 Steps)

The GABA Algorithm

The Cold-Start Problem

On a cold morning your car's engine doesn't go straight to peak efficiency. It runs a richer mixture for a minute or two while the catalytic converter is still below its working temperature. The engine computer knows that the oxygen sensor feeding its closed-loop fuel-injection controller does not give valid readings until it is hot, so it runs open-loop and gates the controller off until the engine reaches a sensible operating point.

GABA has the same cold-start problem. The closed form $\lambda^*_i = g_j / (g_i + g_j)$ assumes the gradient norm $g_i$ reflects the steady-state task structure of the loss landscape. At step 0 the model is at random initialisation: the gradients reflect the errors of randomly initialised predictions, not the 500× imbalance the paper characterises in §12.3. Feeding those transient gradients into the closed form, and into the EMA on top of it, produces garbage $\lambda^*$ for the first few hundred steps.
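To make the cold-start problem concrete, here is the K = 2 closed form evaluated on two illustrative sets of gradient norms: 2.0 / 0.5 for a freshly initialised model and 5.0 / 0.01 for a trained one. These numbers are assumptions borrowed from the synthetic generator used later in this section, not measurements from the paper.

```python
def closed_form_k2(g_rul, g_health):
    """K = 2 closed form: lambda_i = g_j / (g_i + g_j)."""
    s = g_rul + g_health
    # The task with the larger gradient norm gets the smaller weight.
    return g_health / s, g_rul / s


# Illustrative gradient norms (assumed, matching this section's synthetic generator).
init_lam = closed_form_k2(2.0, 0.5)      # ~4x imbalance at random init
steady_lam = closed_form_k2(5.0, 0.01)   # ~500x imbalance once trained

print(f"init:   lam_rul={init_lam[0]:.4f}  lam_health={init_lam[1]:.4f}")
print(f"steady: lam_rul={steady_lam[0]:.4f}  lam_health={steady_lam[1]:.4f}")
# init:   lam_rul=0.2000  lam_health=0.8000
# steady: lam_rul=0.0020  lam_health=0.9980
```

The init-time RUL weight (0.2) is about 100× the steady-state one (≈ 0.002), so a controller fed step-0 gradients starts out chasing a target two orders of magnitude away from where training ends up.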

Paper Algorithm 1 lines 4-6: for the first $W = 100$ steps, set $\lambda^*_i = 1/K$ (uniform weighting). Only after step $W$ does the GABA pipeline (closed form → EMA → floor → renormalise) take over. Same controller, gated on the time index.

Why Step-0 Gradients Are Untrustworthy

Three reasons the first ~50 training steps look nothing like the steady-state regime characterised in §12.3:

  • Random init dominates. At step 0 the backbone is at Kaiming-init values and the logits are essentially random. Cross-entropy on near-uniform predictions is $\approx \ln C$ for $C$ health classes, regardless of the true labels. MSE on random RUL predictions is dominated by random-prediction variance, not by useful regression structure.
  • The 500× imbalance is a steady-state property. Paper main.tex:319 measures it on $n = 4{,}120$ epoch-level samples FROM TRAINED MODELS, not at init. At step 0 the ratio can be anywhere from 1× to 100×, depending on which random seed was drawn.
  • Adam is in its own transient. Adam's moment estimates start at zero, so before bias correction they are biased toward zero for roughly $t \lesssim 1/(1-\beta_1) = 10$ steps (first moment) and $t \lesssim 1/(1-\beta_2) = 1{,}000$ steps (second moment, $\beta_2 = 0.999$). Bias correction removes the bias, but the corrected estimates still rest on only a handful of gradients. Combining adaptive loss weighting with adaptive optimiser estimates during THEIR transient phase amplifies oscillation.
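A minimal numerical check of the first and third bullets. It assumes C = 4 health classes (the text does not give the class count) and Adam's usual defaults β₁ = 0.9, β₂ = 0.999; the multipliers printed are the standard bias-correction factors 1/(1 − βᵗ).

```python
import numpy as np

# 1) Cross-entropy for a classifier whose logits are identical across classes.
C = 4                                   # hypothetical number of health classes
logits = np.zeros(C)
probs = np.exp(logits) / np.exp(logits).sum()
ce_uniform = -np.log(probs[0])          # same value for every true label
print(f"CE at uniform logits = {ce_uniform:.4f}, ln C = {np.log(C):.4f}")

# 2) Adam bias-correction multipliers 1 / (1 - beta**t) for the default betas.
beta1, beta2 = 0.9, 0.999
for t in (1, 10, 100, 1000):
    corr1 = 1.0 / (1.0 - beta1 ** t)    # applied to the first-moment estimate
    corr2 = 1.0 / (1.0 - beta2 ** t)    # applied to the second-moment estimate
    print(f"t={t:>4}: first-moment x{corr1:7.3f}, second-moment x{corr2:8.3f}")
```

Even at t = 1,000 the second-moment multiplier is still about 1.58, so the optimiser is still compensating for its zero initialisation well past GABA's 100-step window.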

Trying to use GABA from step 1 produces wild oscillations in $\lambda^*$ as the EMA chases transient noise. The warmup-boundary visualization later in this section shows this directly: with $W = 0$ (no warmup) the trajectory starts dropping from 0.5 immediately and overshoots before settling. With the paper's $W = 100$ the trajectory stays at 0.5 until the model has had a chance to find a stable operating point.

The Warmup Gate (Paper Algorithm 1)

Paper main.tex:362-374 specifies the gate:

$$\lambda^*_i = \begin{cases} 1/K & \text{if } t \leq W \\ \text{(GABA pipeline)} & \text{if } t > W \end{cases}$$

where $W = 100$ is the warmup duration and $t$ is a 1-indexed step counter. Two important details:

  • The comparison is inclusive. $t \leq W$ means step 100 is still in warmup; step 101 is the first active step. Paper code: self.step_count.item() <= self.warmup_steps.
  • The EMA buffer is NOT updated during warmup. The closed form is not computed, so there is nothing to feed into the EMA. The buffer stays at its initial value $\hat{\lambda}^{(0)}_i = 1/K$ (paper Algorithm 1 line 2). When the gate flips at step 101, the EMA starts from 1/K, which matches the warmup output, so $\lambda^*$ moves away from 1/K by only one ordinary EMA step instead of jumping.
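Both details can be exercised in a few lines. This sketch keeps only the gate and the EMA (floor and renormalisation are left out because they are no-ops near 0.5) and assumes a constant closed-form output of [0.002, 0.998].

```python
import numpy as np

W, K, beta = 100, 2, 0.99


def gate_step(step_count, ema_w, raw):
    """Warmup gate: uniform weights and an untouched EMA while step_count <= W."""
    if step_count <= W:                          # inclusive: step W is still warmup
        return np.full(K, 1.0 / K), ema_w
    ema_w = beta * ema_w + (1.0 - beta) * raw    # EMA only runs once the gate flips
    return ema_w.copy(), ema_w                   # floor + renorm omitted in this sketch


ema = np.full(K, 1.0 / K)                        # initial buffer = 1/K
raw = np.array([0.002, 0.998])                   # assumed constant closed-form output
uniform_calls = 0
for step_count in range(1, W + 2):               # steps 1 .. W+1 (1-indexed)
    lam, ema = gate_step(step_count, ema, raw)
    if np.allclose(lam, 1.0 / K) and np.allclose(ema, 1.0 / K):
        uniform_calls += 1                       # warmup output, buffer still at 1/K

print(uniform_calls, lam)
```

Exactly 100 calls return uniform weights with the buffer still at 1/K; call 101 returns approximately [0.49502, 0.50498], a single ordinary EMA step away from uniform.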

What Uniform 1/K Means In Practice

For $K = 2$ (RUL + health), uniform weighting means $\lambda^*_{\text{rul}} = \lambda^*_{\text{health}} = 0.5$. The combined loss during warmup is exactly:

$$\mathcal{L} = 0.5 \cdot \mathcal{L}_{\text{rul}} + 0.5 \cdot \mathcal{L}_{\text{health}}$$

In other words, GABA falls back to the ‘Fixed Baseline’ method (paper §3.5) for the first 100 steps. This is intentional. Fixed Baseline is the simplest and most widely used multi-task scheme, its optimisation behaviour with Adam is well understood, and it contains no feedback loop that transient gradients could destabilise.

Why warmup doesn't hurt in the 500× imbalance regime. One could worry that 100 steps of equal weighting let the RUL gradient dominate and train backbone features that are useless for health. Empirically, the paper's ablations show no measurable harm: 100 steps is a tiny fraction of the 500-epoch training horizon, and the active GABA controller has nearly all of the remaining training to re-shape the backbone.

Why W = 100 Steps

The choice $W = 100$ is paper canonical (main.tex:362). Four justifications:

| Reason | Quantitative tie | Implication |
| --- | --- | --- |
| Match the EMA time constant | τ = 1/(1−β) = 1/0.01 = 100 steps for β = 0.99 | Warmup spans one EMA memory length, so the gradients get one τ to settle before the EMA (frozen during warmup) starts absorbing them |
| Cover Adam's first-moment bias correction | 1/(1−β₁) = 10 steps for β₁ = 0.9 | 10× safety margin over Adam's first-moment transient |
| Long enough that gradient norms reflect data, not init | Empirically ~50 steps suffice for gradient magnitudes to stabilise | 100 steps gives margin even for harder datasets |
| Short enough not to waste training | 100 / (500 epochs × 100 steps/epoch) ≈ 0.2% of training | Negligible cost for the safety margin |

The robustness ablation in paper §5.8 shows results are statistically indistinguishable for $W \in [50, 200]$; the canonical 100 sits comfortably in the middle.
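The quantitative column of the table reduces to a few lines of arithmetic. The 500-epoch × 100-steps-per-epoch run length is the table's own assumption, not a universal figure.

```python
beta, beta1, W = 0.99, 0.9, 100     # EMA, Adam first-moment, warmup length
epochs, steps_per_epoch = 500, 100  # run length assumed in the table

tau = 1.0 / (1.0 - beta)            # EMA time constant in steps
adam_transient = 1.0 / (1.0 - beta1)
warmup_fraction = W / (epochs * steps_per_epoch)

print(f"tau = {tau:.0f} steps")                              # tau = 100 steps
print(f"W / Adam transient = {W / adam_transient:.0f}x")     # W / Adam transient = 10x
print(f"warmup share of training = {warmup_fraction:.1%}")   # warmup share of training = 0.2%
```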

What Happens At Step W + 1

At step $W + 1$ the gate flips and three things happen in one update:

  • Closed form computes. For the paper-realistic 500× imbalance the raw $\lambda$ is $[0.002, 0.998]$.
  • EMA absorbs 1% of the new value. $\hat{\lambda} = 0.99 \cdot [0.5, 0.5] + 0.01 \cdot [0.002, 0.998] = [0.49502, 0.50498]$. Tiny shift: the EMA has a long memory.
  • Floor is inactive. Both EMA values are above $\lambda_{\min} = 0.05$, so the clamp is a no-op and renormalisation divides by 1.

So the FIRST active step changes $\lambda^*_{\text{rul}}$ from 0.5 to 0.49502, an absolute change of about 0.005 (roughly 1% relative), which is invisible to the optimiser. Settling to the steady state, about 0.0477 for an exact 500× imbalance (just above the floor-bound limit $\lambda_{\min}/(1+\lambda_{\min}) \approx 0.0476$), takes another $\sim 3\tau = 300$ steps. Total: warmup + post-warmup settling ≈ 400 steps to reach the floor-bound regime, out of typically 50,000+ training steps.

The transition is mathematically smooth. Because the EMA buffer was initialised at 1/K and absorbs only $(1-\beta)$ of the new measurement per step, switching the gate at step W+1 changes $\lambda^*$ by $(1-\beta) \cdot |1/K - \lambda^*_{\text{closed}}|$ in a single step (while the floor is inactive). For β = 0.99 and a 500×-imbalance target that is $0.01 \times 0.498 \approx 0.005$, a step small enough that Adam's second-moment estimator easily absorbs it.
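The settling numbers can be checked with a deterministic, noise-free version of the post-warmup pipeline. As a sketch, it assumes the closed form outputs exactly 1/501 for RUL (a 500× imbalance) on every active step and starts from the 1/K buffer that warmup leaves behind.

```python
import numpy as np

beta, lam_min = 0.99, 0.05
r = 1.0 / 501.0                        # lam_rul for an exact 500x imbalance: 1 / (500 + 1)
raw = np.array([r, 1.0 - r])
ema = np.array([0.5, 0.5])             # buffer is still 1/K when the gate flips

trace, floor_step = [], None
for n in range(1, 501):                # n = active steps since the gate flipped
    ema = beta * ema + (1.0 - beta) * raw
    if floor_step is None and ema[0] < lam_min:
        floor_step = n                 # first step where the clamp actually bites
    clamped = np.maximum(ema, lam_min)
    trace.append(float((clamped / clamped.sum())[0]))

steady = lam_min / (lam_min + 1.0 - r)  # fixed point once the EMA has converged
print(f"first active step: {trace[0]:.5f}  (jump = {0.5 - trace[0]:.5f})")
print(f"after one tau:     {trace[99]:.5f}")
print(f"floor engages at active step {floor_step}")
print(f"after 500 steps:   {trace[-1]:.5f}  steady state {steady:.5f}  "
      f"limit {lam_min / (1.0 + lam_min):.5f}")
```

In this noise-free setting the first active step moves $\lambda^*_{\text{rul}}$ by about 0.00498, the floor starts to bite 233 active steps (≈ 2.3τ) after the flip, and the output then approaches 0.05/(1.05 − 1/501) ≈ 0.0477, just above the limiting bound λ_min/(1 + λ_min) ≈ 0.0476 that is reached only as the imbalance grows without bound.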

Interactive: Slide The Warmup Boundary

Drag the W slider. The amber band marks the warmup region; inside it $\lambda^*_{\text{rul}}$ is pinned at 0.5. Outside the band the GABA pipeline runs. The dashed grey line is the same simulation with $W = 0$ for comparison. Watch how W = 0 immediately starts dropping while W = 100 stays flat through the window.

Try this. Set W = 0 and watch the blue and grey traces collapse onto each other: no warmup, immediate GABA. Set W = 300 and watch the trace stay at 0.5 for the entire visible window; warmup eats the whole simulation. The paper's W = 100 is the sweet spot: enough warmup to dodge the cold-start transient, short enough to leave the rest of training in active mode.
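The same W = 0 versus W = 100 comparison can be reproduced as text. This sketch reuses the synthetic drift from the NumPy walkthrough in this section but draws noise from NumPy's `default_rng` generator, so individual values will not match that walkthrough's printout.

```python
import numpy as np


def run(W, steps=400, beta=0.99, lam_min=0.05, seed=0):
    """lam_rul trace for warmup length W under the section's synthetic gradient drift."""
    rng = np.random.default_rng(seed)
    ema = np.array([0.5, 0.5])
    trace = []
    for t in range(steps):
        blend = min(1.0, t / 50.0)
        # Sample gradients BEFORE the gate so every W sees the same sequence.
        g_rul = max(2.0 + 3.0 * blend + 0.5 * rng.standard_normal(), 0.01)
        g_health = max(0.5 - 0.49 * blend + 0.05 * rng.standard_normal(), 1e-6)
        if t + 1 <= W:                              # 1-indexed warmup gate
            trace.append(0.5)
            continue
        raw = np.array([g_health, g_rul]) / (g_rul + g_health)
        ema = beta * ema + (1.0 - beta) * raw
        clamped = np.maximum(ema, lam_min)
        trace.append(float((clamped / clamped.sum())[0]))
    return trace


no_warmup, paper = run(W=0), run(W=100)
for t in (0, 50, 99, 100, 200, 399):
    print(f"t={t:>3}  W=0: {no_warmup[t]:.4f}  W=100: {paper[t]:.4f}")
```

Because gradients are sampled before the gate check, both runs see identical gradient sequences and differ only in W.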

Python: Warmup Gate From Scratch

Implement the full GABA per-step update with the warmup gate in pure NumPy. Run for 500 steps and print the $\lambda_{\text{rul}}$ trajectory at milestone steps so the warmup plateau and the post-warmup settling are visible.

Paper Algorithm 1 with the warmup branch — full pipeline
🐍warmup_gate_from_scratch.py
1docstring

Module docstring. Implements the warmup branch from paper Algorithm 1 lines 4-6 verbatim — the IF gate that returns uniform weights during the first W steps.

3import numpy as np

NumPy supplies ndarray, np.full (uniform weights), np.array, np.maximum, np.random.randn for synthetic gradients.

EXECUTION STATE
📚 numpy = Numerical computing library. Used for ndarray and the floor + renorm aggregation.
5np.random.seed(0)

Deterministic PRNG so the milestones below are reproducible.

7# Paper canonical hyperparameters

Section header.

8W = 100

Warmup duration. Paper Algorithm 1 line 1: REQUIRE warmup W=100. Equals one EMA time constant τ = 1/(1−β) = 1/0.01 = 100 steps for β=0.99.

EXECUTION STATE
W = 100 = Integer. Number of training steps during which GABA returns uniform 1/K weights instead of running the closed-form pipeline.
→ why this number? = Matches the EMA time constant: warmup lasts one EMA memory length (τ steps), giving the gradients time to move past the init transient before the EMA starts absorbing them. The buffer itself is frozen during warmup. Smaller W: not enough time to escape the transient. Larger W: wastes the first chunk of training on uniform weights.
9beta = 0.99

EMA smoothing coefficient (§18.2).

EXECUTION STATE
beta = 0.99 = Float. EMA coefficient. Time constant τ = 1/(1-β) = 100 steps.
10lam_min = 0.05

Floor for the post-EMA clamp (§18.3).

EXECUTION STATE
lam_min = 0.05 = Float. Minimum allowed per-task weight. Bounds output in approximately [λ_min, 1−λ_min] for K=2.
11K = 2

Number of tasks. K=2 for RUL + health.

EXECUTION STATE
K = 2 = Integer. RUL regression + health classification.
14def gaba_step(step_count, ema_w, g_rul, g_health)

One full GABA per-step update with the warmup gate. Returns the smoothed weights and the (possibly updated) EMA buffer.

EXECUTION STATE
⬇ input: step_count = Integer ≥ 1. Current training step (1-indexed). Used by the warmup gate.
⬇ input: ema_w = ndarray (K,). Persistent EMA-smoothed weight buffer from §18.2.
⬇ input: g_rul = Float. Current RUL gradient norm.
⬇ input: g_health = Float. Current health gradient norm.
⬆ returns = Tuple (weights, ema_w). weights is the lambda* used for THIS step; ema_w is the buffer for the NEXT step.
15docstring

Records that this is the full GABA per-step update including the warmup gate.

16if step_count <= W:

The warmup gate itself. Paper Algorithm 1 line 4. Note the comparison is INCLUSIVE: at step 100 we're still in warmup; at step 101 we're active.

EXECUTION STATE
step_count <= W = Boolean. True for steps 1..100, False for steps 101+.
→ off-by-one = step_count starts at 1 (the loop passes t + 1) and the gate is <=. So warmup covers exactly W = 100 steps, not 99 or 101.
17return np.full(K, 1.0 / K), ema_w

Warmup branch: return uniform 1/K weights and DO NOT update the EMA. Paper Algorithm 1 line 5.

EXECUTION STATE
📚 np.full(shape, fill) = Build an ndarray of given shape filled with the given scalar. np.full(2, 0.5) = [0.5, 0.5].
⬇ arg 1: K = 2 = Output shape — a 1-D vector of length K.
⬇ arg 2: 1.0 / K = 0.5 = Fill value. For K=2 this is uniform 0.5.
→ why ema_w unchanged? = During warmup we do NOT compute the closed form, so there is nothing to feed into the EMA. The buffer stays at its initial value 1/K (or whatever it was on the previous warmup step).
→ why uniform, not zero? = 1/K keeps the simplex constraint sum=1 satisfied. The combined loss has its expected scale; the optimiser sees a stable objective. Zero weights would collapse the loss to nothing.
18# Closed form (eq. 4 specialised to K=2)

Inline comment marking the paper-equation reference for the next block.

19S = g_rul + g_health

Sum of gradient norms. K=2 normaliser.

EXECUTION STATE
S = Float. Sum of the two task gradient norms.
20raw = np.array([g_health / S, g_rul / S])

K=2 closed form (paper eq. 4). Larger gradient → smaller weight.

EXECUTION STATE
📚 np.array(list) = Build an ndarray from a Python list.
raw = ndarray (2,). Per-step inverse-proportional weights.
21# EMA (eq. 5)

Inline comment.

22ema_w = beta * ema_w + (1.0 - beta) * raw

EMA smoothing (§18.2 paper eq. 5). Convex combination of history and new measurement.

EXECUTION STATE
ema_w (after) = Updated EMA buffer. Will be returned and stored for next step.
23# Floor + renormalise (eq. 6)

Inline comment.

24clamped = np.maximum(ema_w, lam_min)

Anti-windup floor (§18.3). Element-wise max against lam_min.

EXECUTION STATE
📚 np.maximum(a, b) = Element-wise max — floors values at b.
clamped = ndarray (2,). Each element ≥ lam_min.
25return clamped / clamped.sum(), ema_w

Renormalise to the simplex and return both the lambda* AND the updated EMA buffer.

EXECUTION STATE
⬆ return = Tuple. Element 0: lambda* on the simplex. Element 1: updated ema_w buffer for the next step.
29def synth(t):

Synthetic per-step gradient generator. Mimics what a real backbone produces: at step 0 the gradients are different from the steady-state 5.0/0.01 because the model is at random init; over the first 50 steps the gradients drift toward the paper-realistic regime.

EXECUTION STATE
⬇ input: t = Step number (0-indexed).
⬆ returns = Tuple (g_rul, g_health) with realistic per-step magnitudes.
30blend = min(1.0, t / 50.0)

Linearly interpolate from 0 (step 0) to 1.0 (step 50+). Used to blend init-time gradients with steady-state gradients.

EXECUTION STATE
blend = Float in [0, 1]. 0 at step 0, ramps to 1 by step 50, stays at 1 thereafter.
→ why this matters for warmup = The first 50 steps' gradient magnitudes don't match the long-run pattern. Computing GABA λ from THESE gradients gives a transient junk signal. That's the whole reason warmup exists.
31g_rul = max((2.0 + blend * 3.0) + 0.5 * np.random.randn(), 0.01)

RUL gradient norm: starts at ~2.0 (random init), drifts to ~5.0 by step 50, with std 0.5 per-batch noise.

EXECUTION STATE
📚 np.random.randn() = Sample one value from N(0, 1).
g_rul = Float ≥ 0.01. Synthetic RUL gradient norm at step t.
32g_health = max((0.5 - blend * 0.49) + 0.05 * np.random.randn(), 1e-6)

Health gradient norm: starts at ~0.5 (random init), drifts DOWN to ~0.01 by step 50.

EXECUTION STATE
g_health = Float ≥ 1e-6. Synthetic health gradient norm at step t.
→ ratio = At step 0: g_rul/g_health ≈ 2.0/0.5 = 4x. At step 50+: ≈ 5.0/0.01 = 500x. The 500x imbalance EMERGES during training, not at init.
33return g_rul, g_health

Return the two scalars.

37ema_w = np.array([0.5, 0.5])

Initial EMA buffer = uniform 1/K. Paper Algorithm 1 line 2.

EXECUTION STATE
ema_w (init) = ndarray (2,) = [0.5, 0.5]. Same value uniform weights produce — graceful continuity at the warmup boundary.
38trace = []

Empty list to record λ_rul at each step.

39for t in range(500):

Run 500 simulated training steps. step_count = t + 1 inside the loop so the gate uses 1-indexing.

LOOP TRACE · 6 iterations
t=0 (step_count=1, warmup)
weights = [0.5, 0.5]
trace[0] = 0.500000
t=50 (step_count=51, warmup)
weights = [0.5, 0.5]
trace[50] = 0.500000
t=99 (step_count=100, last warmup)
weights = [0.5, 0.5] — t=99 is the 100th call, still ≤ W
trace[99] = 0.500000
t=100 (step_count=101, FIRST active)
raw = ≈ [0.002, 0.998] (steady-state imbalance now)
ema_w = 0.99 · [0.5, 0.5] + 0.01 · [0.002, 0.998] ≈ [0.495, 0.505]
trace[100] = 0.495051 — barely budged from 0.5 yet
t=200 (active for 100 steps post-warmup)
trace[200] = 0.185016 (about 63% of the way to steady state, one τ after the flip)
t=499 (steady state)
trace[499] = 0.048225 (in the floor-bound regime, near the limit λ_min/(1+λ_min) ≈ 0.0476)
40step_count = t + 1

Convert 0-indexed loop variable to 1-indexed step counter so the gate uses standard 1-indexing.

41g_rul, g_health = synth(t)

Generate the per-step synthetic gradients.

42weights, ema_w = gaba_step(step_count, ema_w, g_rul, g_health)

Apply one full GABA step. The returned ema_w replaces the previous buffer.

EXECUTION STATE
→ tuple unpack = Python lets us assign a 2-tuple return to two variables in one line.
43trace.append(weights[0])

Record λ_rul (the first element of the weight vector) for plotting.

EXECUTION STATE
trace[t] = Float in [0, 1]. λ_rul at step t.
47print header

Pretty-print header.

EXECUTION STATE
Output = step | lam_rul | regime
48print separator

Visual separator line.

EXECUTION STATE
Output = ----------------------------------------
49for t in [0, 50, 99, 100, 101, 150, 200, 300, 499]:

Iterate selected milestones (0-indexed t). Includes t = 99 (step_count 100, last warmup) and t = 100 (step_count 101, first active) so the transition is visible.

LOOP TRACE · 9 iterations
t=0
row = 0 | 0.500000 | warmup
t=50
row = 50 | 0.500000 | warmup
t=99
row = 99 | 0.500000 | warmup
t=100
row = 100 | 0.495051 | active
t=101
row = 101 | 0.490150 | active
t=150
row = 150 | 0.301712 | active
t=200
row = 200 | 0.185016 | active
t=300
row = 300 | 0.070340 | active
t=499
row = 499 | 0.048225 | active
50regime = ...

Tag each row with its regime label (warmup vs active) for the printout.

51print row

f-string with width specs; final printout shows the warmup plateau and the gradual EMA settling.

EXECUTION STATE
Final output =
 step |    lam_rul | regime
----------------------------------------
    0 |   0.500000 | warmup
   50 |   0.500000 | warmup
   99 |   0.500000 | warmup
  100 |   0.495051 | active
  101 |   0.490150 | active
  150 |   0.301712 | active
  200 |   0.185016 | active
  300 |   0.070340 | active
  499 |   0.048225 | active
→ reading = Indices t = 0-99 (steps 1-100): rock steady at 0.5. t = 100 (step 101): gate flips, EMA absorbs the first 1% of the new signal, λ moves to 0.495. t = 100-499: EMA settles toward the floor-bound regime (≈ 0.048) over ~3τ = 300 steps after the gate flip.
1"""Warmup gate from scratch: paper Algorithm 1, lines 4-6."""
2
3import numpy as np
4
5np.random.seed(0)
6
7# Paper canonical hyperparameters
8W       = 100        # warmup steps
9beta    = 0.99       # EMA coefficient
10lam_min = 0.05       # floor
11K       = 2          # tasks
12
13
14def gaba_step(step_count, ema_w, g_rul, g_health):
15    """One full GABA step with warmup gate."""
16    if step_count <= W:
17        return np.full(K, 1.0 / K), ema_w   # warmup: uniform weights
18    # Closed form (eq. 4 specialised to K=2)
19    S = g_rul + g_health
20    raw = np.array([g_health / S, g_rul / S])
21    # EMA (eq. 5)
22    ema_w = beta * ema_w + (1.0 - beta) * raw
23    # Floor + renormalise (eq. 6)
24    clamped = np.maximum(ema_w, lam_min)
25    return clamped / clamped.sum(), ema_w
26
27
28# ---------- Synthetic gradients that drift toward the 500x regime ----------
29def synth(t):
30    blend = min(1.0, t / 50.0)
31    g_rul    = max((2.0 + blend * 3.0) + 0.5  * np.random.randn(), 0.01)
32    g_health = max((0.5 - blend * 0.49) + 0.05 * np.random.randn(), 1e-6)
33    return g_rul, g_health
34
35
36# ---------- Run 500 steps with the paper's warmup ----------
37ema_w = np.array([0.5, 0.5])
38trace = []
39for t in range(500):
40    step_count = t + 1
41    g_rul, g_health = synth(t)
42    weights, ema_w = gaba_step(step_count, ema_w, g_rul, g_health)
43    trace.append(weights[0])
44
45
46# ---------- Print milestones ----------
47print(f"{'step':>5} | {'lam_rul':>10} | regime")
48print("-" * 40)
49for t in [0, 50, 99, 100, 101, 150, 200, 300, 499]:
50    regime = "warmup" if (t + 1) <= W else "active"
51    print(f"{t:>5} | {trace[t]:>10.6f} | {regime}")

PyTorch: The Paper's Branching Code

The actual paper code lives at grace/core/gaba.py:107-108 and is two lines: if shared_params is None or self.step_count.item() <= self.warmup_steps: followed by weights = torch.ones(K, device=device) / K. We extract the gate (and its surrounding state-management machinery) into a standalone class for clarity.

Paper code: warmup gate + EMA + floor (minimal extraction from grace/core/gaba.py)
🐍warmup_gate_torch.py
1docstring

Module docstring. The class below extracts the warmup-gate portion of grace/core/gaba.py for clarity, with the closed-form math factored out to focus on the gate itself.

3import torch

Core PyTorch.

EXECUTION STATE
📚 torch = Tensor library. Used for tensors, register_buffer, .item, .clamp, torch.ones.
4import torch.nn as nn

Module primitives.

EXECUTION STATE
📚 torch.nn = PyTorch nn package. nn.Module base class.
7class WarmupGabaLoss(nn.Module):

nn.Module subclass so step_count and ema_weights persist across batches and survive checkpoint save/load.

EXECUTION STATE
📚 nn.Module = Base class for stateful PyTorch components. Subclasses register buffers / parameters and override forward().
8docstring

Records that this is a faithful reproduction of grace/core/gaba.py:107-108.

18def __init__(self, beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2):

Constructor with all paper-canonical defaults.

EXECUTION STATE
⬇ beta = 0.99 = EMA coefficient (§18.2).
⬇ warmup_steps = 100 = Warmup duration W. Paper Algorithm 1 line 1.
⬇ min_weight = 0.05 = Floor (§18.3).
⬇ n_tasks = 2 = K. RUL + health.
20super().__init__()

Required first line of every nn.Module __init__.

21self.beta = beta

Store as plain instance attribute.

22self.warmup_steps = warmup_steps

Store the warmup duration. The gate compares step_count against this on every forward call.

23self.min_weight = min_weight

Floor. Used by clamp.

24self.n_tasks = n_tasks

Cached K.

25self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)

EMA buffer initialised at uniform 1/K. CRUCIAL: this matches the warmup output, so when the gate flips at step W+1 the EMA starts from a known sensible state.

EXECUTION STATE
📚 register_buffer(name, tensor) = Register a non-learnable persistent buffer on the module. Tracked in state_dict, moves with .to(device), survives checkpoint save/load.
→ init = 1/K not 0 = If we initialised at 0, the first active step would compute 0.99·0 + 0.01·raw = small junk. Initialising at 1/K gives a sensible warm start that matches the warmup output.
26self.register_buffer("step_count", torch.tensor(0, dtype=torch.long))

Step counter buffer. Used by the warmup gate. dtype=torch.long because step counts are integers and a 64-bit counter will not overflow in any realistic training run.

EXECUTION STATE
→ why a buffer? = Survives checkpoint save/load. Resuming training at step 5,000 must NOT re-enter warmup because the checkpoint forgot the counter.
28def forward(self, raw_weights, shared_params=None):

Per-step entry point. Called once per training step.

EXECUTION STATE
⬇ input: raw_weights = Tensor (K,). The closed-form λ from §17.3 (computed elsewhere).
⬇ input: shared_params = Optional list of nn.Parameter. Default None means ‘no params provided’, which the gate treats the same as warmup.
⬆ returns = Tensor (K,). The lambda* used by the optimiser THIS step.
29K = self.n_tasks

Cache K locally for brevity.

30device = raw_weights.device

Read the input's device so we build new tensors on the same hardware (CPU / GPU / MPS).

EXECUTION STATE
📚 .device = Tensor attribute. The hardware the tensor lives on.
31self.step_count += 1

Increment the counter. PyTorch supports in-place add on 0-dim tensors, including long-dtype.

EXECUTION STATE
→ BEFORE the gate = Increment FIRST so the gate compares the post-increment value. step_count starts at 0; first call sets it to 1; gate compares 1 ≤ 100 → warmup.
33# WARMUP GATE — paper Algorithm 1 lines 4-6

Section header marking the gate.

34if shared_params is None or self.step_count.item() <= self.warmup_steps:

The gate. TWO triggers for warmup behaviour: (a) the caller didn't pass shared_params (the gradient-norm computation is unavailable, so we have no signal to compute λ from), OR (b) we're still in the W=100 window.

EXECUTION STATE
shared_params is None = True when the caller is in ‘dry-run’ mode, e.g. evaluating without backprop. We can't compute the closed form, so we fall back to uniform.
self.step_count.item() = Convert 0-dim tensor to Python int so we can compare with the Python int self.warmup_steps.
📚 .item() = Tensor method. Pull out a 0-dim tensor as a Python scalar. Detaches from autograd.
<= self.warmup_steps = Inclusive: at step 100 the condition is True (last warmup step); at step 101 it's False (first active step).
35return torch.ones(K, device=device) / K

Warmup branch: return uniform 1/K weights and bypass the rest of the pipeline. CRITICAL: the EMA buffer is NOT updated during warmup, so it stays at its initial 1/K value.

EXECUTION STATE
📚 torch.ones(*size, device) = Build a tensor of all-ones with the given shape and device.
⬇ arg 1: K = Output shape — 1-D vector of length K.
⬇ arg 2: device=device = Match the input device. Without this the returned tensor lives on CPU even if raw_weights is on GPU, and combining it with GPU loss tensors can raise a device-mismatch error.
/ K = Scalar division. torch.ones(2) / 2 = [0.5, 0.5].
37# ACTIVE branch — closed form is already in raw_weights; do EMA + floor.

Comment. The closed form (§17.3) was computed by the caller and passed in as raw_weights; this module handles only EMA + floor.

38ema_w = self.beta * self.ema_weights + (1.0 - self.beta) * raw_weights

EMA update (§18.2 paper eq. 5). At step 101 (first active), this combines the still-1/K ema_weights with the realistic raw_weights.

EXECUTION STATE
→ step 101 first compute = 0.99 · [0.5, 0.5] + 0.01 · [0.002, 0.998] = [0.49502, 0.50498]. Tiny shift from 0.5 toward the new target — exactly the slow EMA absorption we expect.
39self.ema_weights = ema_w.detach()

Save back to the buffer with .detach() to prevent autograd-history accumulation (§18.2 pitfall).

EXECUTION STATE
📚 .detach() = Strip autograd history. Critical here to keep memory constant across steps.
40clamped = ema_w.clamp(min=self.min_weight)

Anti-windup floor (§18.3). Element-wise clamp at min_weight.

EXECUTION STATE
📚 .clamp(min=v) = Tensor method. Element-wise floor: max(x, v) per element.
41return clamped / clamped.sum()

Renormalise to the simplex.

EXECUTION STATE
⬆ return = Tensor (K,). The lambda* used by the optimiser this step. Bounded in [≈ λ_min, ≈ 1−λ_min].
45torch.manual_seed(0)

Reproducible smoke test.

46gaba = WarmupGabaLoss(beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2)

Instantiate with paper defaults.

EXECUTION STATE
gaba.ema_weights = [0.5, 0.5] (initial buffer).
gaba.step_count = tensor(0) (initial buffer).
47raw_lambda = torch.tensor([0.002, 0.998])

Constant raw input across all 150 steps so the convergence is purely about the gate flipping and the EMA settling, not about input changes.

EXECUTION STATE
raw_lambda = Tensor (2,). Realistic 500x-imbalance closed-form weight from §17.3.
48shared = [torch.zeros(1)]

Placeholder list for shared_params. It only needs to be non-None so the gate's ‘shared_params is None’ branch doesn't fire. The actual values aren't used by this stripped-down module.

EXECUTION STATE
shared = List with one zero-tensor. Just keeps the gate from defaulting to warmup on the missing-params branch.
50print header

Pretty-print header.

EXECUTION STATE
Output = step | lam_rul | lam_health | regime
51print separator

Visual separator line.

52for step in range(1, 151):

150 simulated steps. Covers warmup (1..100) and the start of active mode (101..150).

LOOP TRACE · 9 iterations
step = 1 (warmup)
out = [0.5, 0.5] — uniform, gate active
step = 50 (warmup)
out = [0.5, 0.5]
step = 99 (warmup)
out = [0.5, 0.5]
step = 100 (last warmup)
out = [0.5, 0.5] (gate condition is True: 100 ≤ 100)
step = 101 (FIRST active)
ema_w (raw) = 0.99·[0.5,0.5] + 0.01·[0.002,0.998] = [0.49502, 0.50498]
clamp = [0.49502, 0.50498] (no floor needed)
out = [0.49502, 0.50498] — first active step shows the EMA inertia
step = 105
out = [0.47559, 0.52441] (5 steps of slow drift)
step = 110
out = [0.45238, 0.54762]
step = 130
out = [0.37037, 0.62963]
step = 150 (50 steps post-warmup)
out = [0.30329, 0.69671] (half-way through one EMA time constant)
53out = gaba(raw_lambda, shared_params=shared)

One forward call. Mutates step_count (and ema_weights when active).

54if step in (1, 50, 99, 100, 101, 105, 110, 130, 150):

Only print at milestone steps so the log stays compact.

55compute regime label

Tag the row as warmup or active based on the current step relative to the gate.

56print row

f-string with width specs; output shows the warmup plateau and the slow EMA absorption after the gate flip.

EXECUTION STATE
Final output =
step |    lam_rul | lam_health | regime
--------------------------------------------------
   1 |   0.500000 |   0.500000 | warmup
  50 |   0.500000 |   0.500000 | warmup
  99 |   0.500000 |   0.500000 | warmup
 100 |   0.500000 |   0.500000 | warmup
 101 |   0.495020 |   0.504980 | active
 105 |   0.475593 |   0.524407 | active
 110 |   0.452382 |   0.547618 | active
 130 |   0.370371 |   0.629629 | active
 150 |   0.303293 |   0.696707 | active
→ reading = Steps 1-100: rock steady at [0.5, 0.5]. Step 101 is the first active step; the EMA absorbs only 1% of the new input per step. After 50 active steps (step 150), λ_rul has dropped from 0.5 to about 0.30, a little over 40% of the way to the steady-state value ≈ 0.048.
1"""Paper code: warmup gate as in grace/core/gaba.py:107-108."""
2
3import torch
4import torch.nn as nn
5
6
7class WarmupGabaLoss(nn.Module):
8    """Minimal GABA module showing the warmup gate from paper code.
9
10    Mirrors lines 107-108 of grace/core/gaba.py:
11
12        if shared_params is None or self.step_count.item() <= self.warmup_steps:
13            weights = torch.ones(K, device=device) / K
14        else:
15            ...full GABA pipeline...
16    """
17
18    def __init__(self, beta: float = 0.99, warmup_steps: int = 100,
19                 min_weight: float = 0.05, n_tasks: int = 2) -> None:
20        super().__init__()
21        self.beta         = beta
22        self.warmup_steps = warmup_steps
23        self.min_weight   = min_weight
24        self.n_tasks      = n_tasks
25        self.register_buffer("ema_weights", torch.ones(n_tasks) / n_tasks)
26        self.register_buffer("step_count",  torch.tensor(0, dtype=torch.long))
27
28    def forward(self, raw_weights: torch.Tensor, shared_params=None) -> torch.Tensor:
29        K = self.n_tasks
30        device = raw_weights.device
31        self.step_count += 1
32
33        # WARMUP GATE — paper Algorithm 1 lines 4-6
34        if shared_params is None or self.step_count.item() <= self.warmup_steps:
35            return torch.ones(K, device=device) / K
36
37        # ACTIVE branch — closed form is already in raw_weights; do EMA + floor.
38        ema_w = self.beta * self.ema_weights + (1.0 - self.beta) * raw_weights
39        self.ema_weights = ema_w.detach()
40        clamped = ema_w.clamp(min=self.min_weight)
41        return clamped / clamped.sum()
42
43
44# ---------- Smoke test: walk steps 1..150 with constant raw input ----------
45torch.manual_seed(0)
46gaba = WarmupGabaLoss(beta=0.99, warmup_steps=100, min_weight=0.05, n_tasks=2)
47raw_lambda = torch.tensor([0.002, 0.998])
48shared = [torch.zeros(1)]      # placeholder so the gate doesn't fall through
49
50print(f"{'step':>4} | {'lam_rul':>10} | {'lam_health':>10} | regime")
51print("-" * 50)
52for step in range(1, 151):
53    out = gaba(raw_lambda, shared_params=shared)
54    if step in (1, 50, 99, 100, 101, 105, 110, 130, 150):
55        regime = "warmup" if step <= gaba.warmup_steps else "active"
56        print(f"{step:>4} | {out[0].item():>10.6f} | {out[1].item():>10.6f} | {regime}")

Warmup Patterns In Other Pipelines

| Field | Warmup mechanism | Typical duration | Why it's needed |
| --- | --- | --- | --- |
| Predictive maintenance (this paper) | GABA gate (uniform 1/K for the first W steps) | 100 steps | Avoid the cold-start transient in gradient norms |
| Optimisation: linear LR warmup | Linear ramp from 0 to peak LR | Thousands of steps (2,000 in GPT; 10,000 in BERT pre-training) | Adam's early second-moment estimates are noisy; a small LR avoids divergence |
| Computer vision: BatchNorm running stats | Training normalises with per-batch statistics while running averages accumulate for evaluation | Until the running averages converge (set by BN momentum) | Running-mean estimates need data before they are meaningful |
| Reinforcement learning: replay buffer fill | No gradient updates until the buffer holds N samples | 10K-50K steps | Off-policy methods need a non-trivial buffer to sample from |
| Generative models: classifier-free guidance | First few steps use unconditional generation | 10% of denoising steps | Avoids over-conditioning on weak class signal early in sampling |
| Training schedules: gradual unfreezing | Train the head first, then unfreeze backbone layers progressively | 1 epoch per layer | Prevents catastrophic forgetting in fine-tuning |
| Continual learning: rehearsal warmup | Train on old + new data with fixed mixing for the first epoch | 1-3 epochs | Lets the new task settle before adaptive task-balancing kicks in |

The pattern — gate the adaptive controller off until the system has reached a sensible operating point — recurs across nearly every adaptive learning pipeline. GABA inherits the pattern; the implementation is just a step counter and an inclusive comparison.
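Because the gate is only a counter and a comparison, it can wrap any adaptive controller. `WarmupGate` below is a hypothetical helper written for illustration, not part of the paper's code; the controller it wraps here is the K = 2 closed form.

```python
from typing import Callable, List, Sequence


class WarmupGate:
    """Hypothetical wrapper: uniform weights for the first W calls, then defer to `controller`.

    The wrapped controller is not called at all during warmup, so any state it keeps
    (an EMA buffer, for example) stays at its initial value.
    """

    def __init__(self, controller: Callable[[Sequence[float]], List[float]],
                 n_tasks: int, warmup_steps: int = 100) -> None:
        self.controller = controller
        self.n_tasks = n_tasks
        self.warmup_steps = warmup_steps
        self.step_count = 0               # must be saved in checkpoints

    def __call__(self, signals: Sequence[float]) -> List[float]:
        self.step_count += 1              # increment BEFORE the check: 1-indexed
        if self.step_count <= self.warmup_steps:
            return [1.0 / self.n_tasks] * self.n_tasks
        return self.controller(signals)


calls = []


def closed_form_k2(g):
    """Example controller: K = 2 closed form lambda_i = g_j / (g_i + g_j)."""
    calls.append(g)                       # record that the controller actually ran
    s = g[0] + g[1]
    return [g[1] / s, g[0] / s]


gate = WarmupGate(closed_form_k2, n_tasks=2, warmup_steps=3)
outputs = [gate([5.0, 0.01]) for _ in range(5)]
print(len(calls), outputs[2], outputs[3])
```

With `warmup_steps=3` the controller runs only on calls 4 and 5, so `calls` has two entries. Keeping the controller entirely out of the warmup path is what guarantees its internal state stays untouched.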

Pitfalls In Warmup Gating

Pitfall 1: Off-by-one on the step counter. If you increment AFTER the gate check but keep ≤, the first call compares 0 ≤ 100 and you get 101 warmup steps; if you increment first but switch to a strict <, you get 99. Neither matches the paper's 100 or its alignment with the EMA time constant. Paper code increments BEFORE the check and uses ≤.
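The counting behind Pitfall 1 can be checked directly. `count_warmup` is a hypothetical helper that simulates each counter convention and counts how many calls take the warmup branch for W = 100.

```python
W = 100


def count_warmup(increment_first: bool, inclusive: bool, n_calls: int = 300) -> int:
    """Hypothetical helper: count calls that take the warmup branch under one convention."""
    step_count, warm_calls = 0, 0
    for _ in range(n_calls):
        if increment_first:
            step_count += 1                       # paper style: bump, then compare
        in_warmup = step_count <= W if inclusive else step_count < W
        warm_calls += int(in_warmup)
        if not increment_first:
            step_count += 1                       # bump after the comparison
    return warm_calls


print("increment first,  <= :", count_warmup(True, True))    # 100 (paper convention)
print("increment after,  <= :", count_warmup(False, True))   # 101
print("increment first,  <  :", count_warmup(True, False))   # 99
```

The fourth combination (increment after the check with a strict <) also yields 100, which is the usual choice in 0-indexed code; the paper's code uses the pre-increment, ≤ form.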
Pitfall 2: Updating the EMA during warmup. Some implementations compute the closed form during warmup ‘just to keep the EMA tracking’ and only override $\lambda^*$ at the output. This corrupts the EMA buffer with transient junk. Paper code only enters the EMA branch AFTER the gate flips; the buffer stays at 1/K through the entire warmup window.
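A sketch of what Pitfall 2 does to the buffer, using the noise-free mean of this section's synthetic drift (an assumption; real gradients are noisier). Both variants output 0.5 during warmup; the difference is the buffer each one hands to step W + 1.

```python
import numpy as np

beta, W = 0.99, 100


def raw_lambda(t):
    """Noise-free version of this section's synthetic drift: K = 2 closed form at step t."""
    blend = min(1.0, t / 50.0)
    g_rul, g_health = 2.0 + 3.0 * blend, 0.5 - 0.49 * blend
    return np.array([g_health, g_rul]) / (g_rul + g_health)


frozen = np.array([0.5, 0.5])      # paper behaviour: EMA untouched during warmup
leaky = np.array([0.5, 0.5])       # pitfall: EMA updated during warmup anyway
for t in range(W):                 # the W warmup steps (0-indexed)
    leaky = beta * leaky + (1.0 - beta) * raw_lambda(t)

# Both variants output 0.5 during warmup; these buffers seed step W+1.
print("buffer entering step W+1 (frozen):", frozen)
print("buffer entering step W+1 (leaky): ", leaky.round(4))
```

The leaky buffer enters step W + 1 with λ_rul near 0.2, so the output drops from 0.5 to about 0.2 in a single step instead of the roughly 0.005 step the frozen buffer produces.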
Pitfall 3: Resetting the step counter on checkpoint resume. If you restart training from a saved checkpoint at step 5,000 and forget to restore step_count, GABA re-enters warmup for another 100 steps with uniform weights. The model spends 100 steps un-doing the adaptive weighting it learned before. Paper code stores step_count as a registered buffer so state_dict() captures it.
Pitfall 4: Setting W too small (e.g. W = 10). This skips most of the cold-start period and feeds init-time gradients through the closed form: for the first ~40 active steps the gradient ratio is still drifting from roughly 4× toward 500×, so $\lambda^*$ chases a moving, noisy target while Adam's own moment estimates are also still settling. Setting W = 100 avoids this for ~0.2% of training cost.
Why warmup is so cheap. 100 steps out of a typical 500-epoch / 50,000-step training run is ~0.2% of compute. Skipping it risks a fragile first few hundred steps that can derail an entire run, so the trade is heavily in favour of warming up. That is why paper main.tex:362 lists W = 100 as a ‘robust default’ alongside β = 0.99 and λ_min = 0.05.

Takeaway

  • Warmup gates the adaptive controller off for the first W = 100 steps. Paper Algorithm 1 lines 4-6: if $t \leq W$ return $1/K$ uniform weights; otherwise run the GABA pipeline.
  • W = 100 matches the EMA time constant. One τ = 1/(1−β) = 100 steps. The EMA is frozen during warmup, so the match is one of scale: the gradients get one full EMA memory length to settle before the EMA starts absorbing them.
  • The EMA buffer is not updated during warmup. It stays at the initial $1/K$, which matches the warmup output, so $\lambda^*$ does not jump at step W+1.
  • The transition is smooth. The first active step shifts $\lambda^*$ by about 0.005 (roughly 1% relative) for the realistic 500× imbalance, well below Adam's noise floor.
  • step_count must be a buffer. So checkpoint resume doesn't restart the warmup counter mid-training.
  • The pattern is widespread. Linear LR warmup, BatchNorm running stats, replay-buffer fill, gradual unfreezing: many adaptive controllers benefit from a startup gate.