Chapter 15

Optimizer & Scheduler (AdamW + Warmup + Plateau)

AMNL Training Pipeline

The Three-Layer Optimiser Stack

AMNL ships with a three-piece optimiser stack: an AdamW base with decoupled weight decay, a linear warmup over the first 10 epochs, and a ReduceLROnPlateau scheduler that cuts lr by 50% every time the validation RMSE has not improved for 30 epochs. Each piece does one job; together they keep training stable from epoch 0 to convergence.

One-line summary. Optimiser is AdamW(lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999)); warmup is linear from 0.1 → 1.0 over 10 epochs; plateau scheduler is ReduceLROnPlateau(mode='min', factor=0.5, patience=30, min_lr=5e-6). All from paper_ieee_tii/experiments/train_amnl_v7.py.

AdamW: Adam with Decoupled Weight Decay

Adam's update rule is $\theta \leftarrow \theta - \eta \cdot \hat{m}/\sqrt{\hat{v}}$, where $\hat{m}, \hat{v}$ are bias-corrected EMAs of the first and second gradient moments. Plain L2 regularisation in Adam produces a regulariser strength that is divided by $\sqrt{\hat{v}}$ - so parameters with large running gradients get effectively LESS regularisation. AdamW decouples the wd term:

$$\theta \leftarrow \theta - \eta \cdot \hat{m}/\sqrt{\hat{v}} - \eta \cdot w_d \cdot \theta$$

The second update term, $\eta \cdot w_d \cdot \theta$, is INDEPENDENT of $\sqrt{\hat{v}}$. Every parameter gets the same regularisation strength $\eta \cdot w_d$ regardless of its gradient history.

Why AdamW for AMNL. AMNL's sample weighting (§14.1) inflates near-failure gradients up to 2×. With plain Adam + L2, those parameters would get proportionally LESS regularisation - a hidden bias. AdamW keeps regularisation uniform, which matters more in multi-task than in single-task because the per-task gradient ratio (§12) is large.
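
A rough numerical sketch of that difference, under the simplifying assumption (an illustration, not taken from the paper's code) that training is in steady state and $\sqrt{\hat{v}}$ is approximately the RMS of a parameter's loss gradient. The helper names `l2_in_adam_shrink` and `adamw_shrink` are hypothetical; they only compare the per-step pull toward zero that each scheme applies to a weight.

```python
def l2_in_adam_shrink(theta, grad_rms, lr=1e-3, wd=1e-4, eps=1e-8):
    """Approximate per-step pull toward zero when wd * theta is folded into the gradient.

    Assumes sqrt(v_hat) ~= grad_rms, i.e. the loss gradient dominates the
    second-moment estimate (a simplifying assumption for illustration).
    """
    return lr * wd * theta / (grad_rms + eps)


def adamw_shrink(theta, lr=1e-3, wd=1e-4):
    """Per-step pull toward zero from AdamW's decoupled term lr * wd * theta."""
    return lr * wd * theta


theta = 1.0
for grad_rms in (0.01, 1.0):  # a quiet parameter vs. one with large running gradients
    coupled = l2_in_adam_shrink(theta, grad_rms)
    decoupled = adamw_shrink(theta)
    print(f"grad_rms={grad_rms}: L2-in-Adam={coupled:.1e}  AdamW={decoupled:.1e}")
```

With these illustrative numbers the L2-in-Adam pull differs by roughly 100× between the two parameters, while the AdamW pull is the same for both.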

Warmup: First 10 Epochs

Linear warmup multiplies the base lr by:

$$\text{wf}(e) = \begin{cases} 0.1 + 0.9 \cdot e/E & e < E \\ 1.0 & e \geq E \end{cases}$$

with $E = 10$. So lr rises from $0.1 \times 10^{-3} = 10^{-4}$ at epoch 0 to $10^{-3}$ at epoch 10.

Why warmup helps Adam. Adam's $\hat{v}$ needs ~50 mini-batches to stabilise. Before then it underestimates the true gradient variance, making $\hat{m}/\sqrt{\hat{v}}$ artificially large. Warmup shrinks the lr exactly when $\sqrt{\hat{v}}$ is most unreliable.
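
One way to make "unreliable early on" concrete is to count how many squared gradients the bias-corrected $\hat{v}$ effectively averages at step $t$. The sketch below uses the standard effective-sample-size formula for a weighted mean; the function name `v_hat_effective_samples` and this framing are illustrative assumptions, not part of the paper's code.

```python
def v_hat_effective_samples(t: int, beta2: float = 0.999) -> float:
    """Effective number of squared gradients averaged by bias-corrected v_hat at step t.

    v_hat_t = sum_i w_i * g_i**2 with weights
        w_i = (1 - beta2) * beta2**(t - i) / (1 - beta2**t),  i = 1..t.
    The effective sample size 1 / sum_i w_i**2 simplifies to the closed form below.
    """
    b_t = beta2 ** t
    return (1 + beta2) * (1 - b_t) / ((1 - beta2) * (1 + b_t))


for t in (1, 10, 50, 1000, 100_000):
    print(f"step {t:>6d}: ~{v_hat_effective_samples(t):.0f} effective samples")
```

At step 1 the estimate rests on a single squared gradient, and at step 50 on about 50; only after thousands of steps does it approach the long-run value of $(1+\beta_2)/(1-\beta_2) \approx 2000$. If gradients were i.i.d. Gaussian (a strong assumption), an average of $n$ squared gradients would have a relative standard deviation of about $\sqrt{2/n}$, which is 20% at $n = 50$.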

ReduceLROnPlateau: Cuts on Stalls

After warmup, the lr is reactive: as long as validation RMSE keeps improving, lr stays at base_lr. PyTorch's ReduceLROnPlateau tolerates `patience` consecutive epochs without a new best and cuts on the next one, so with patience=30 the lr is halved on the 31st consecutive epoch without a new best RMSE. This continues until lr hits the floor min_lr = 5×10⁻⁶ (200× below base).

| Hyperparameter | Value | Why |
|---|---|---|
| mode | 'min' | RMSE is a metric that should DECREASE - watch for new minima |
| factor | 0.5 | 50% reduction per cut. Default in PyTorch is 0.1 (more aggressive); paper picked 0.5 to be gentler |
| patience | 30 | wait 30 epochs without improvement before cutting. Long enough to absorb random batch-to-batch jitter |
| min_lr | 5e-6 | floor; below this the scheduler stops cutting and training continues at fixed small lr |
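
To get a feel for how far away the floor is, the short sketch below counts the halvings needed to go from base_lr to min_lr with the values in the table. `cuts_to_floor` is a hypothetical helper written for this illustration.

```python
def cuts_to_floor(base_lr=1e-3, factor=0.5, min_lr=5e-6):
    """Count how many plateau cuts it takes for lr to land on min_lr."""
    lr, cuts = base_lr, 0
    while lr > min_lr:
        lr = max(lr * factor, min_lr)  # same clamp the scheduler applies
        cuts += 1
    return cuts, lr


cuts, final_lr = cuts_to_floor()
print(cuts, final_lr)  # 8 5e-06
```

Seven halvings bring the lr to about 7.8e-6 and the eighth is clamped at the floor. Since PyTorch fires only once the stall counter exceeds patience (at most one cut per 31 epochs with patience=30 and the default cooldown of 0), a 200-epoch run cannot reach the floor through plateau cuts alone.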

Interactive: Walk a 200-Epoch Schedule

Drag the patience slider to see how aggressive cuts feel. Drag the number-of-plateaus slider to add or remove synthetic stall events. Watch lr drop in lock-step with each plateau marker.

Try this. Set patience=10 and 4 plateaus - lr collapses to min_lr by epoch 100. Set patience=80 with 1 plateau - lr stays at base for ~50 epochs after the stall, a more conservative regime. The paper's patience=30 is the middle ground: aggressive enough to escape local minima, conservative enough not to cut prematurely.

Python: Schedule from Scratch

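Pure NumPy reimplementation that mirrors the firing logic of PyTorch's ReduceLROnPlateau(mode='min', factor=0.5, patience=30, min_lr=5e-6) on a synthetic validation curve. It simplifies two details: it cuts as soon as the stall counter reaches patience (PyTorch waits until the counter exceeds patience, one epoch later), and it compares with a strict < rather than PyTorch's default relative threshold of 1e-4.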

warmup_factor() + reduce_on_plateau() + build_schedule()
🐍lr_schedule_numpy.py
1import numpy as np

NumPy provides np.linspace for generating the synthetic validation curve and np.log10 for the decade-counting summary.

EXECUTION STATE
📚 numpy = Library: ndarray + math + random + statistics.
as np = Universal alias.
4def warmup_factor(epoch, warmup_epochs=10) -> float:

The exact paper warmup formula from paper_ieee_tii/grace/training/callbacks.py lines 103-104. Linear ramp from 0.1·base_lr at epoch 0 to 0.91·base_lr at epoch warmup_epochs - 1; the multiplier is 1.0 from epoch warmup_epochs onward.

EXECUTION STATE
⬇ input: epoch = Current 0-indexed epoch number.
⬇ input: warmup_epochs = 10 = Number of epochs to ramp over. Paper default 10.
⬆ returns = Multiplier in [0.1, 1.0]. Caller multiplies base_lr by this.
12if epoch >= warmup_epochs:

Guard - past the warmup window, return the full multiplier 1.0.

13return 1.0

Constant after warmup. The lr stays at base_lr until ReduceLROnPlateau kicks in.

14return 0.1 + 0.9 * epoch / warmup_epochs

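Linear interpolation from 0.1 (at epoch 0) toward 1.0. The last value this branch returns is 0.91 (at epoch warmup_epochs-1); the guard above supplies 1.0 from epoch warmup_epochs on.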

EXECUTION STATE
operator: + / * = Scalar arithmetic. The 0.9 stretches the [0, 1] interpolation up to the [0.1, 1.0] interval.
→ at epoch 0 = 0.1 + 0.9 · 0/10 = 0.10
→ at epoch 5 = 0.1 + 0.9 · 5/10 = 0.55
→ at epoch 9 = 0.1 + 0.9 · 9/10 = 0.91
→ at epoch 10 = guard returns 1.0 (skip this branch)
⬆ result = Float in [0.1, 1.0).
17def reduce_on_plateau(val_metric, patience=30, factor=0.5) -> list[bool]:

Approximates the firing rule of PyTorch's ReduceLROnPlateau(mode='min'). Returns one boolean per epoch: True ⇒ the lr is CUT at this epoch.

EXECUTION STATE
⬇ input: val_metric = List of per-epoch validation values (RMSE for AMNL).
⬇ input: patience = 30 = Number of epochs to wait without improvement before cutting. Paper default.
⬇ input: factor = 0.5 = Multiplier applied to current lr when triggered. Paper default - 50% reduction.
⬆ returns = List[bool] of length len(val_metric).
24best = float("inf")

Initialise the running best to +infinity so the first val_metric value always counts as an improvement.

EXECUTION STATE
📚 float("inf") = IEEE-754 +∞. Comparable but always larger than any finite float.
25bad_epochs = 0

Counter for consecutive non-improving epochs. Resets to 0 on every improvement OR after firing.

26triggers: list[bool] = []

Output list; one bool per epoch.

27for v in val_metric:

Iterate the validation curve.

EXECUTION STATE
iter var: v = Per-epoch validation value (RMSE, lower is better).
LOOP TRACE · 6 iterations
v = 20.0 (epoch 0)
best = 20.0 (improvement)
bad = 0
trigger = False
v = 8.0 (epoch 59, end of the linear descent)
best = 8.0 (improvement)
bad = 0
trigger = False
v = 8.0 (epochs 60-88, plateau)
best = 8.0 (no change)
bad = 1, 2, …, 29
trigger = False until bad reaches patience
v = 8.0 (epoch 89)
bad = 30 ⇒ trigger fires
trigger = True - cut lr
bad reset = → 0
v = 8.0 (epoch 119, second plateau window)
trigger = True - second cut
...
remaining = further triggers at epochs 149 and 179
28if v < best:

Strictly-less, not ≤. PyTorch's default additionally requires a relative improvement of threshold=1e-4; this sketch uses the exact strict comparison.

29best, bad_epochs = v, 0

Tuple unpacking - update both in one line. New best ⇒ reset the 'bad epochs' counter.

EXECUTION STATE
→ tuple unpacking = Right-hand side builds (v, 0); left-hand side has 2 names ⇒ each gets one element.
30else:

No improvement.

31bad_epochs += 1

Increment the counter.

EXECUTION STATE
operator: += = Augmented assignment. Equivalent to bad_epochs = bad_epochs + 1.
33if bad_epochs >= patience:

If we've waited patience epochs without improvement, FIRE.

34triggers.append(True)

Record the trigger.

EXECUTION STATE
📚 list.append(x) = In-place append. O(1) amortised.
35bad_epochs = 0

Reset the counter after firing - so we wait another full `patience` window before the next cut. Without this reset every subsequent epoch would also fire.

36else:

Not yet patience - keep waiting.

37triggers.append(False)

Record no-cut.

38return triggers

Hand back the trigger sequence.

EXECUTION STATE
⬆ returns = List[bool] of length len(val_metric). Each True marks a plateau cut.
41def build_schedule(base_lr=1e-3, min_lr=5e-6, factor=0.5, patience=30, n_epochs=200, val_metric=None) -> list[float]:

Compose warmup + ReduceLROnPlateau into one (n_epochs,) trajectory. Defaults match paper's train_amnl_v7.py.

EXECUTION STATE
⬇ input: base_lr = 1e-3 = Paper Adam initial lr.
⬇ input: min_lr = 5e-6 = Floor - the scheduler stops cutting here.
⬇ input: factor = 0.5 = Paper choice - 50% reduction per cut.
⬇ input: patience = 30 = Paper choice.
⬇ input: n_epochs = 200 = Maximum number of epochs.
⬇ input: val_metric = None = If None, defaults to a flat curve (always plateau).
49if val_metric is None:

None-sentinel default pattern - an identity check with `is None`, not a falsy check, so an explicitly passed empty list is not replaced.

50val_metric = [1.0] * n_epochs

Constant val_metric ⇒ no improvement ever ⇒ keep cutting until min_lr.

EXECUTION STATE
→ list multiplication = [x] * n produces a list of n copies of x. Dangerous for mutable x; safe for floats.
52triggers = reduce_on_plateau(val_metric, patience, factor)

Compute the trigger sequence first.

53lrs: list[float] = []

Output list.

54cur_lr = base_lr

Track the 'post-warmup' lr. We multiply by factor at every plateau trigger; warmup overrides this in early epochs.

55for e in range(n_epochs):

Per-epoch loop.

EXECUTION STATE
📚 range(stop) = Lazy iterator [0, stop).
LOOP TRACE · 6 iterations
e = 0
phase = warmup ⇒ lr = 0.1 · base_lr = 1.0e-04
e = 9
phase = warmup ⇒ lr = 0.91 · base_lr = 9.1e-04
e = 10
phase = post-warmup ⇒ lr = base_lr = 1.0e-03
e = 89
phase = post + first plateau cut ⇒ cur_lr = 5.0e-04
lr = 5.0e-04
e = 119
phase = post + second cut ⇒ cur_lr = 2.5e-04
e = 199
phase = after four cuts (89, 119, 149, 179) ⇒ lr = 6.25e-05, still above min_lr
56wf = warmup_factor(e)

Get the warmup multiplier for this epoch.

57if e < 10:

Warmup branch.

58lr = base_lr * wf

Apply the warmup multiplier to base_lr - this OVERRIDES any plateau bookkeeping during warmup.

59elif triggers[e]:

Plateau branch - cut the lr.

60cur_lr = max(cur_lr * factor, min_lr)

Halve the running lr but never below min_lr.

EXECUTION STATE
📚 max(a, b) = Built-in - returns the larger of two scalars. Acts as a floor here.
⬇ arg 1: cur_lr * factor = Halved lr.
⬇ arg 2: min_lr = Floor.
61lr = cur_lr

Use the freshly halved lr.

62else:

No cut, no warmup ⇒ keep current lr.

63lr = cur_lr

Steady-state lr.

64lrs.append(lr)

Record the per-epoch lr.

65return lrs

Hand back the schedule.

69np.random.seed(0)

Repro.

EXECUTION STATE
📚 np.random.seed(s) = Set NumPy's legacy global PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
70val_rmse = np.linspace(20, 8, 60).tolist() + [8.0] * 140

Synthetic validation curve. Linearly improves from RMSE=20 → 8 over 60 epochs, then plateaus at 8.0 for the remaining 140 epochs - triggering ReduceLROnPlateau cuts every patience=30 epochs.

EXECUTION STATE
📚 np.linspace(start, stop, num) = num evenly-spaced values between start and stop, inclusive.
⬇ arg 1: start = 20 = RMSE at epoch 0.
⬇ arg 2: stop = 8 = RMSE at epoch 59.
⬇ arg 3: num = 60 = 60 evenly-spaced points covering epochs 0-59.
📚 .tolist() = Convert ndarray to Python list for concat with another list.
→ list concat = [a, b] + [c, d] = [a, b, c, d]. Used here to splice the plateau onto the end of the descent.
⬆ result: val_rmse = (200,) list - improving for 60 epochs, flat for 140.
71schedule = build_schedule(val_metric=val_rmse)

Run the scheduler with paper defaults.

73print(f"epoch | lr")

Header row.

74for e in [0, 5, 9, 10, 50, 89, 90, 119, 120, 150, 199]:

Print key milestones - warmup start/mid/end, post-warmup, plateau triggers.

EXECUTION STATE
iter var: e = Specific epochs to inspect.
LOOP TRACE · 11 iterations
e = 0
lr = 1.00e-04 (warmup start)
e = 5
lr = 5.50e-04 (warmup mid)
e = 9
lr = 9.10e-04 (warmup end)
e = 10
lr = 1.00e-03 (post-warmup)
e = 50
lr = 1.00e-03 (still improving)
e = 89
lr = 5.00e-04 (first cut: 30th stalled epoch)
e = 90
lr = 5.00e-04 (waiting again)
e = 119
lr = 2.50e-04 (second cut)
e = 120
lr = 2.50e-04 (waiting again)
e = 150
lr = 1.25e-04 (third cut fired at epoch 149)
e = 199
lr = 6.25e-05 (fourth cut fired at epoch 179)
75print(f"{e:>5d} | {schedule[e]:.2e}")

Format-string row. :.2e is scientific with 2 decimals.

EXECUTION STATE
→ :>5d = Integer, right-aligned, min width 5.
→ :.2e = Float in scientific notation, 2 decimals (e.g. 1.00e-03).
77print(f"final lr : {schedule[-1]:.2e}")

Final-epoch lr.

EXECUTION STATE
Output = final lr : 6.25e-05
78print(f"decades dropped : {np.log10(1e-3 / max(schedule[-1], 1e-12)):.2f}")

Number of orders-of-magnitude reduction from base_lr to final.

EXECUTION STATE
📚 np.log10(arr) = Element-wise base-10 log.
📚 max(a, b) = Built-in - divide-by-zero guard.
Output = decades dropped : 1.20
→ reading = A little over one order of magnitude (a factor of 16 after four cuts). Final lr is 6.25e-5; min_lr is 5e-6, so the schedule is still 12.5× above the floor.
import numpy as np


def warmup_factor(epoch: int, warmup_epochs: int = 10) -> float:
    """Linear lr warmup over the first warmup_epochs.

    Source: paper_ieee_tii/grace/training/callbacks.py:103-104
        return base_lr * (0.1 + 0.9 * epoch / warmup_epochs)

    Returns the multiplier (0.1 → 1.0) applied to base_lr.
    """
    if epoch >= warmup_epochs:
        return 1.0
    return 0.1 + 0.9 * epoch / warmup_epochs


def reduce_on_plateau(val_metric:    list[float],
                      patience:      int   = 30,
                      factor:        float = 0.5) -> list[bool]:
    """Per-epoch "cut lr now" flags using a ReduceLROnPlateau-style rule.

    Triggers when the validation metric has not improved (i.e. dropped below
    its running min) for 'patience' epochs in a row.
    """
    best       = float("inf")
    bad_epochs = 0
    triggers:  list[bool] = []
    for v in val_metric:
        if v < best:
            best, bad_epochs = v, 0
        else:
            bad_epochs += 1

        if bad_epochs >= patience:
            triggers.append(True)
            bad_epochs = 0                       # reset counter after firing
        else:
            triggers.append(False)
    return triggers


def build_schedule(base_lr:        float = 1e-3,
                   min_lr:         float = 5e-6,
                   factor:         float = 0.5,
                   patience:       int   = 30,
                   n_epochs:       int   = 200,
                   val_metric:     list[float] | None = None) -> list[float]:
    """Combine warmup + ReduceLROnPlateau into one (n_epochs,) lr trajectory."""
    if val_metric is None:
        val_metric = [1.0] * n_epochs            # never improves ⇒ keeps cutting

    triggers = reduce_on_plateau(val_metric, patience, factor)
    lrs:     list[float] = []
    cur_lr   = base_lr
    for e in range(n_epochs):
        wf = warmup_factor(e)
        if e < 10:                                # warmup phase
            lr = base_lr * wf
        elif triggers[e]:
            cur_lr = max(cur_lr * factor, min_lr)
            lr = cur_lr
        else:
            lr = cur_lr
        lrs.append(lr)
    return lrs


# ---------- Worked example ----------
np.random.seed(0)
val_rmse  = np.linspace(20, 8, 60).tolist() + [8.0] * 140       # plateau after epoch 60
schedule  = build_schedule(val_metric=val_rmse)

print(f"epoch  | lr")
for e in [0, 5, 9, 10, 50, 89, 90, 119, 120, 150, 199]:
    print(f"{e:>5d}  | {schedule[e]:.2e}")
print()
print(f"final lr        : {schedule[-1]:.2e}")
print(f"decades dropped : {np.log10(1e-3 / max(schedule[-1], 1e-12)):.2f}")
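
The listing fires as soon as the stall counter reaches patience, whereas PyTorch's documented rule fires once the counter exceeds patience and, by default, only counts an epoch as an improvement if it beats the best value by a relative margin of 1e-4. The sketch below puts both rules side by side on the same synthetic curve; `cut_epochs` and its `fire_on_exceed` flag are hypothetical names introduced for this comparison, and cooldown is assumed to be 0.

```python
def cut_epochs(val_metric, patience=30, threshold=0.0, fire_on_exceed=False):
    """Return the epochs at which a plateau rule would cut the lr.

    fire_on_exceed=False: cut when the stall counter reaches patience.
    fire_on_exceed=True : cut when the counter exceeds patience.
    threshold is a relative improvement margin (0.0 means plain strict <).
    """
    best, bad, cuts = float("inf"), 0, []
    for epoch, v in enumerate(val_metric):
        if v < best * (1.0 - threshold):
            best, bad = v, 0
        else:
            bad += 1
        fired = bad > patience if fire_on_exceed else bad >= patience
        if fired:
            cuts.append(epoch)
            bad = 0  # a fresh patience window starts after every cut
    return cuts


# Same shape as the worked example: 20 -> 8 over epochs 0-59, then flat at 8.0.
val_rmse = [20 - 12 * i / 59 for i in range(60)] + [8.0] * 140

print(cut_epochs(val_rmse))                                         # [89, 119, 149, 179]
print(cut_epochs(val_rmse, threshold=1e-4, fire_on_exceed=True))    # [90, 121, 152, 183]
```

Each cut lands one epoch later under the exceed rule, and the gap grows by one epoch per cut because every patience window is 31 epochs long instead of 30.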

PyTorch: The Paper's Stack

Three factory functions plus a smoke test. Each function returns the exact paper-canonical config from paper_ieee_tii/experiments/train_amnl_v7.py.

build_optimizer() + build_scheduler() + warmup_lr()
🐍optim_stack_torch.py
1import torch

Top-level PyTorch.

EXECUTION STATE
📚 torch = Tensor library + autograd + nn modules + optim.
2import torch.optim as optim

Optimisers and learning-rate schedulers live in this submodule.

EXECUTION STATE
📚 torch.optim = Provides Adam, AdamW, SGD, RMSprop, plus optim.lr_scheduler.* schedulers.
3import torch.nn as nn

Module containers - we use nn.Linear for the smoke test.

6def build_optimizer(model, learning_rate=1e-3, weight_decay=1e-4) -> optim.Optimizer:

Paper-canonical AdamW factory from paper_ieee_tii/experiments/train_amnl_v7.py lines 480-486.

EXECUTION STATE
⬇ input: model = An nn.Module (any). The optimiser will read .parameters() off it.
⬇ input: learning_rate = 1e-3 = Paper default. After warmup and before any plateau cut, this is the steady-state value.
⬇ input: weight_decay = 1e-4 = Paper default. Decoupled weight decay - keeps weights small without distorting Adam's second-moment estimate.
⬆ returns = An AdamW instance.
9return optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)

AdamW = Adam with decoupled weight decay. The decoupling matters: standard L2 in Adam produces effective weight decay that depends on the per-parameter √v̂ - making the regularisation strength implicitly per-parameter. AdamW applies wd directly to the weights, independent of the second-moment estimate.

EXECUTION STATE
📚 optim.AdamW(params, lr, betas, eps, weight_decay, amsgrad) = Loshchilov & Hutter, 2019. Decouples weight decay from Adam's adaptive update. Available since PyTorch 1.2.
⬇ arg 1: params = model.parameters() = Iterator over all learnable params.
⬇ arg 2: lr = learning_rate = 1e-3 = Initial step size. Adam's adaptive scaling means this is more like a maximum than an actual step magnitude.
⬇ arg 3: weight_decay = 1e-4 = Decoupled wd. Update rule: θ ← θ - lr · m̂/√v̂ - lr · wd · θ. Note the SECOND lr·wd·θ term - this is what makes it 'decoupled' from m̂/√v̂.
⬇ arg 4: betas = (0.9, 0.999) = (β₁, β₂) - first and second-moment EMA decay rates. Canonical defaults.
⬇ arg 5: eps = 1e-8 = Numerical stabiliser added to √v̂. Without it, a parameter whose gradients have all been exactly zero (v̂ = 0) would divide by zero.
⬆ result = torch.optim.AdamW instance.
18def build_scheduler(optimizer) -> optim.lr_scheduler.ReduceLROnPlateau:

Paper-canonical scheduler factory from paper_ieee_tii/experiments/train_amnl_v7.py lines 489-496.

EXECUTION STATE
⬇ input: optimizer = An optim.Optimizer that the scheduler will modify (specifically its param_groups[i]['lr']).
⬆ returns = ReduceLROnPlateau instance.
23return optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=30, min_lr=5e-6)

Watches a validation metric; when it stops improving for `patience` epochs, multiplies the optimiser&apos;s lr by `factor`. Paper-canonical hyperparameters.

EXECUTION STATE
📚 optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps, verbose) = Watches a metric and reduces lr when it stops improving. Different from cosine / step schedulers - this one is REACTIVE to actual progress.
⬇ arg 1: optimizer = The optim.Optimizer to throttle.
⬇ arg 2: mode = 'min' = Watch a metric expected to DECREASE (validation RMSE). 'max' for accuracy-style metrics that should increase.
⬇ arg 3: factor = 0.5 = Halve lr each trigger. Paper choice. Default factor in PyTorch is 0.1 (more aggressive); paper picked 0.5 to be gentler.
⬇ arg 4: patience = 30 = Number of epochs without improvement before firing. Paper default for 200-epoch training.
⬇ arg 5: min_lr = 5e-6 = Floor - the scheduler stops cutting once lr reaches this value.
→ why min_lr? = Without it, after enough plateaus lr → 0 and training stalls. The floor lets training continue at a fixed (small) step.
⬆ result = ReduceLROnPlateau scheduler instance. The training loop calls scheduler.step(val_metric) once per epoch.
33def warmup_lr(epoch, base_lr=1e-3, warmup_epochs=10) -> float:

Linear-warmup helper. Same formula as the NumPy block. Paper's LRWarmup callback uses this directly.

EXECUTION STATE
⬇ input: epoch = Current epoch.
⬇ input: base_lr = 1e-3 = Target lr after warmup.
⬇ input: warmup_epochs = 10 = Number of warmup epochs.
⬆ returns = lr value for this epoch.
36if epoch >= warmup_epochs:

Past the warmup window.

37return base_lr

Constant after warmup.

38return base_lr * (0.1 + 0.9 * epoch / warmup_epochs)

Linear interpolation - same formula as the NumPy block.

EXECUTION STATE
→ at epoch 0 = 1e-3 · 0.10 = 1.0e-04
→ at epoch 9 = 1e-3 · 0.91 = 9.1e-04
⬆ returns = Float in [1.0e-4, 9.1e-4] during warmup.
42torch.manual_seed(0)

Repro.

EXECUTION STATE
📚 torch.manual_seed(s) = Set the global PyTorch PRNG.
⬇ arg: s = 0 = Conventional canonical seed.
43model = nn.Linear(64, 1)

Tiny stand-in for DualTaskModel. The optimiser stack is model-agnostic; same scheduler logic applies regardless of architecture.

EXECUTION STATE
📚 nn.Linear(in_features, out_features, bias) = A single fully-connected layer. We just need SOMETHING with parameters for the smoke test.
⬇ args = in_features=64, out_features=1.
44optimizer = build_optimizer(model, learning_rate=1e-3)

Paper AdamW.

45scheduler = build_scheduler(optimizer)

Paper ReduceLROnPlateau.

47print(f"{'epoch':>5s} | {'lr (post-step)':>15s} | {'phase':>10s}")

Header.

EXECUTION STATE
→ :>5s = String, right-aligned, min width 5.
→ :>15s = String, right-aligned, min width 15.
→ :>10s = String, right-aligned, min width 10.
48val_rmse = [20 - 0.2 * e for e in range(60)] + [8.0] * 140

Synthetic validation curve. List comprehension generates the descending segment; the [8.0] * 140 multiplies a 1-element list to get the plateau.

EXECUTION STATE
→ list comprehension = [expr for var in iter] - compact syntax that eagerly builds the list of transformed values.
→ list multiplication = [8.0] * 140 = list of 140 copies of 8.0.
⬆ result: val_rmse = 200-element list. Decreases linearly for 60 epochs then stays flat.
50for epoch in range(200):

Per-epoch training loop.

EXECUTION STATE
📚 range(stop) = Lazy iterator over [0, stop).
52if epoch < 10:

Warmup window.

53for pg in optimizer.param_groups:

An optimiser holds one or more 'parameter groups' - dicts with lr, weight_decay, etc. Modifying pg['lr'] in-place changes the lr the optimiser uses on the next step. We don't use .set_lr() because PyTorch doesn't expose one.

EXECUTION STATE
📚 optimizer.param_groups = List of dicts, one per parameter group. Each dict has 'params', 'lr', 'weight_decay', 'betas', etc.
→ why iterate? = Most models have one param group, but some pipelines have separate groups for 'backbone with low lr' and 'heads with high lr'. Iterating handles both cases.
54pg["lr"] = warmup_lr(epoch)

Set the lr for this group. Subsequent .step() calls use this value.

58if epoch >= 10:

Past warmup ⇒ let the plateau scheduler manage the lr.

59scheduler.step(val_rmse[epoch])

ReduceLROnPlateau watches the metric we feed it. .step(metric) records the metric and (once the count of epochs without improvement exceeds `patience`) cuts the optimiser's lr in place.

EXECUTION STATE
📚 scheduler.step(metrics) = PyTorch ReduceLROnPlateau-specific signature - takes a metric. OTHER schedulers (CosineAnnealingLR, StepLR) take no argument.
⬇ arg: metrics = val_rmse[epoch] = Validation metric for this epoch.
→ effect = Updates internal state. If patience exceeded, multiplies pg['lr'] by factor=0.5 in-place.
61cur_lr = optimizer.param_groups[0]["lr"]

Read the current lr for logging. Index [0] picks the first (and usually only) group.

62if epoch in (0, 5, 9, 10, 50, 89, 90, 120, 150, 199):

Print only at key milestones.

EXECUTION STATE
→ membership test = x in tuple is O(n) but fast for small n. Set membership would be O(1) but unnecessary here.
63phase = "warmup" if epoch < 10 else ("plateau" if cur_lr < 1e-3 else "steady")

Nested ternary - classifies the current epoch into one of three named phases.

EXECUTION STATE
→ ternary = Python: a if cond else b - inline if-expression. Equivalent to (cond ? a : b) in C-style languages.
64print(f"{epoch:>5d} | {cur_lr:>15.2e} | {phase:>10s}")

Format-string row.

EXECUTION STATE
→ :>5d = Integer, right-aligned, width 5.
→ :>15.2e = Float in scientific, width 15, 2 decimals.
Output (one realisation) =
epoch | lr (post-step) | phase
0 | 1.00e-04 | warmup
5 | 5.50e-04 | warmup
9 | 9.10e-04 | warmup
10 | 1.00e-03 | steady
50 | 1.00e-03 | steady
89 | 1.00e-03 | steady
90 | 1.00e-03 | steady
120 | 5.00e-04 | plateau
150 | 2.50e-04 | plateau
199 | 6.25e-05 | plateau
→ reading = Lr starts at 1e-4 (10% of base), ramps to 1e-3 over 10 epochs and holds there through epoch 90. PyTorch cuts once the stall counter exceeds patience, so with the last improvement at epoch 60 the halvings land at epochs 91, 122, 153 and 184. Final lr is 6.25e-5, still 12.5× above the 5e-6 floor. The listing restores base_lr at epoch 10 explicitly; without that hand-off the lr would stay at the last warmup value, 9.1e-4.
import torch
import torch.optim as optim
import torch.nn as nn

# Source: paper_ieee_tii/experiments/train_amnl_v7.py:480-496
def build_optimizer(model: nn.Module,
                    learning_rate: float = 1e-3,
                    weight_decay:  float = 1e-4) -> optim.Optimizer:
    """Build the paper-canonical AdamW optimiser."""
    return optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
        eps=1e-8,
    )


def build_scheduler(optimizer: optim.Optimizer) -> optim.lr_scheduler.ReduceLROnPlateau:
    """Build the paper-canonical ReduceLROnPlateau scheduler.

    factor=0.5, patience=30, min_lr=5e-6.
    """
    return optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=30,
        min_lr=5e-6,
    )


def warmup_lr(epoch: int, base_lr: float = 1e-3,
              warmup_epochs: int = 10) -> float:
    """Apply linear warmup. Source: paper_ieee_tii/grace/training/callbacks.py:103-104."""
    if epoch >= warmup_epochs:
        return base_lr
    return base_lr * (0.1 + 0.9 * epoch / warmup_epochs)


# ---------- Smoke test ----------
torch.manual_seed(0)
model     = nn.Linear(64, 1)                                     # tiny stand-in
optimizer = build_optimizer(model, learning_rate=1e-3)
scheduler = build_scheduler(optimizer)

print(f"{'epoch':>5s} | {'lr (post-step)':>15s} | {'phase':>10s}")
val_rmse = [20 - 0.2 * e for e in range(60)] + [8.0] * 140        # plateau

for epoch in range(200):
    # 1. warmup overrides during early epochs
    if epoch < 10:
        for pg in optimizer.param_groups:
            pg["lr"] = warmup_lr(epoch)
    elif epoch == 10:
        # hand-off: warmup last wrote 0.91 * base_lr, so restore base_lr once
        for pg in optimizer.param_groups:
            pg["lr"] = warmup_lr(epoch)

    # 2. (training would go here)
    # 3. plateau scheduler step (uses validation metric)
    if epoch >= 10:
        scheduler.step(val_rmse[epoch])

    cur_lr = optimizer.param_groups[0]["lr"]
    if epoch in (0, 5, 9, 10, 50, 89, 90, 120, 150, 199):
        phase = "warmup" if epoch < 10 else ("plateau" if cur_lr < 1e-3 else "steady")
        print(f"{epoch:>5d} | {cur_lr:>15.2e} | {phase:>10s}")
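
One detail worth checking in any loop of this shape is the hand-off epoch. If the warmup value is written only while `epoch < 10`, the last value written is 0.91·base_lr, and a plateau scheduler never raises the lr, so something has to restore base_lr at epoch 10. The torch-free sketch below simulates just the lr bookkeeping; `lr_after_warmup` and its `hand_off` flag are hypothetical names for this illustration.

```python
def warmup_lr(epoch, base_lr=1e-3, warmup_epochs=10):
    """Same linear warmup formula as in the listing."""
    if epoch >= warmup_epochs:
        return base_lr
    return base_lr * (0.1 + 0.9 * epoch / warmup_epochs)


def lr_after_warmup(hand_off: bool, n_epochs: int = 12) -> float:
    """Track the lr a loop would hold after warmup, with or without a hand-off."""
    lr = 1e-3  # value the optimiser was constructed with
    for epoch in range(n_epochs):
        if epoch < 10:
            lr = warmup_lr(epoch)          # warmup overwrites the lr each epoch
        elif epoch == 10 and hand_off:
            lr = warmup_lr(epoch)          # returns base_lr, completing the ramp
    return lr


print(f"{lr_after_warmup(hand_off=False):.2e}")  # 9.10e-04
print(f"{lr_after_warmup(hand_off=True):.2e}")   # 1.00e-03
```

Applying the warmup helper through epoch 10 inclusive (or setting base_lr once when warmup ends) is enough to close the gap.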

Same Stack, Other Domains

The (AdamW + warmup + plateau) recipe transfers wherever you train a deep model against a validation metric the scheduler can react to. Tune the patience to your epoch budget; everything else stays.

| Domain | lr | wd | warmup | patience / decay | Notes |
|---|---|---|---|---|---|
| RUL prediction (this book) | 1e-3 | 1e-4 | 10 epochs | 30 epochs | paper default |
| BERT-base fine-tune (NLP) | 5e-5 | 0.01 | 500 steps | linear decay | warmup steps not epochs |
| Vision Transformer (ImageNet) | 1e-4 | 0.05 | 5000 steps | cosine | uses cosine decay instead of plateau |
| Stable Diffusion fine-tune | 1e-5 | 1e-2 | 100 steps | constant | wd higher to prevent collapse |
| GAN training | 2e-4 | 0 | 0 | n/a | wd=0 for generators - typically |
| Reinforcement learning (PPO) | 3e-4 | 0 | 0 | n/a | no decay - throughput matters more |
One rule of thumb. Set warmup to the number of epochs/steps it takes for Adam's $\hat{v}$ to stabilise on your data - typically the equivalent of one full epoch worth of batches. AMNL's 10-epoch warmup is ten full passes through C-MAPSS, comfortably above that minimum.

Three Optimiser Pitfalls

Pitfall 1: Using Adam instead of AdamW. Plain Adam with weight_decay>0 silently couples the L2 term with $\sqrt{\hat{v}}$ ⇒ near-failure parameters with large gradients get effectively LESS regularisation. Always use AdamW for any AMNL-style sample weighting.
Pitfall 2: Calling scheduler.step() during warmup. ReduceLROnPlateau records the val metric on every call. If you call scheduler.step(val) during the warmup window (where lr is being externally set by the warmup callback), the scheduler's internal ‘best’ gets corrupted and the first post-warmup cut fires too early. Paper trainer skips scheduler.step() until epoch ≥ warmup_epochs.
Pitfall 3: Forgetting the min_lr floor. Without min_lr, after enough plateaus lr → 0 and training stalls completely. Paper sets min_lr=5e-6 (about 200× below base) so even the worst-case multi-plateau training never grinds to a complete halt.
The point. AdamW + warmup + plateau is a three-piece stack that handles the early-instability, steady-state, and late-stall regimes of training with minimal hyperparameter sensitivity. §15.3 adds the two remaining tricks: gradient clipping and weight EMA.
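
Pitfall 2 can be illustrated with a small counter that follows PyTorch's documented rule (cut once the stall counter exceeds patience). The curve below is hypothetical: validation RMSE is flat while the lr is still tiny, then starts improving at epoch 36. `first_cut_epoch` is an illustrative helper, not the paper's trainer.

```python
def first_cut_epoch(val_metric, start_epoch, patience=30):
    """Epoch of the first lr cut if scheduler stepping begins at start_epoch.

    Returns None if the stall counter never exceeds patience.
    """
    best, bad = float("inf"), 0
    for epoch in range(start_epoch, len(val_metric)):
        if val_metric[epoch] < best:
            best, bad = val_metric[epoch], 0
        else:
            bad += 1
        if bad > patience:
            return epoch
    return None


# Hypothetical curve: no progress for epochs 0-35, steady improvement afterwards.
val_rmse = [10.0] * 36 + [10.0 - 0.1 * (i + 1) for i in range(64)]

print(first_cut_epoch(val_rmse, start_epoch=0))   # 31
print(first_cut_epoch(val_rmse, start_epoch=10))  # None
```

When stepping starts at epoch 0, the stalled warmup epochs count against patience and the lr is halved at epoch 31, just before the model begins to improve. Starting at epoch 10 leaves only 25 stalled epochs on the counter, so no cut fires.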

Takeaway

  • AdamW(lr=1e-3, wd=1e-4). Decoupled weight decay - critical with sample-weighted losses.
  • Warmup 0.1 → 1.0 over 10 epochs. Lets Adam's $\hat{v}$ stabilise before letting it drive the step size.
  • ReduceLROnPlateau(factor=0.5, patience=30, min_lr=5e-6). Reactive scheduler. 50% cut after 30 stall epochs; floor at 5×10⁻⁶.
  • ~1-2 decades of lr drop over 200 epochs on typical AMNL FD002 training.
  • Skip scheduler.step() during warmup - paper trainer guards this explicitly.