Chapter 12

Empirical Measurement (n = 4,120 batches)

The 500× Gradient Imbalance

From One Batch to 4,120 Batches

§12.1 measured the imbalance on a single mini-batch. §12.2 derived why the imbalance must be huge. This section closes the loop: we measure ‖g_rul‖ / ‖g_hs‖ on every batch of one full epoch on C-MAPSS FD001 - n ≈ 4,120 batches - and report the distribution.

Why an empirical pass matters. The toy analytic computation gives ~120× per element. Once you chain through the full backbone (CNN + BiLSTM + Attention + Funnel), per-element factors compound. The median ratio on the real DualTaskModel is ≈ 500×, p25 ≈ 300×, p75 ≈ 800×. Worst-case batches touch 2,000×.

Measurement Protocol

  1. Pick the shared parameter list. Filter model.named_parameters() to drop names starting with rul_head or health_head. The remaining ~3.4M parameters are the “rope” both tasks tug on.
  2. Set model.train(). Dropout active, BatchNorm using batch stats - matches what the optimiser actually sees in production training.
  3. For each batch: forward once to get (rul, logits), then compute both losses.
  4. Two backwards per batch. Call model.zero_grad(set_to_none=False), then loss_rul.backward(retain_graph=True), read ‖g_rul‖₂ over the shared params, zero again, then loss_hs.backward(retain_graph=True) and read ‖g_hs‖₂.
  5. Record the triple (‖g_rul‖, ‖g_hs‖, ρ), where ρ = ‖g_rul‖ / ‖g_hs‖, in three Python lists.
  6. After the epoch: convert to tensors, compute median + quartiles + max. Report all four.
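A minimal sketch of steps 3 to 5 for a single batch, assuming a hypothetical two-head toy model (TinyDualTask, l2_norm and task_grad_norms are our illustrative names, not the book's code). It swaps the zero_grad/backward bookkeeping for torch.autograd.grad, which returns each task's gradients directly instead of accumulating them into p.grad, so no zeroing is needed between the two tasks; the resulting norms should match the two-backward protocol.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDualTask(nn.Module):
    """Hypothetical stand-in for DualTaskModel: one shared trunk, two heads."""

    def __init__(self, t: int = 30, c_in: int = 14, hidden: int = 16, n_classes: int = 3):
        super().__init__()
        self.trunk = nn.Sequential(nn.Flatten(), nn.Linear(t * c_in, hidden), nn.ReLU())
        self.rul_head = nn.Linear(hidden, 1)
        self.health_head = nn.Linear(hidden, n_classes)

    def forward(self, x):
        z = self.trunk(x)                                  # shared features
        return self.rul_head(z).squeeze(-1), self.health_head(z)


def l2_norm(grads) -> float:
    """Global L2 norm over per-parameter gradients (None counts as zero)."""
    total = sum(((g ** 2).sum() for g in grads if g is not None), torch.zeros(()))
    return torch.sqrt(total).item()


def task_grad_norms(model, x, y_rul, y_hs, shared_params):
    """Return (||g_rul||, ||g_hs||, rho) on the shared parameters for one batch."""
    rul, logits = model(x)                                 # one forward pass
    loss_rul = F.mse_loss(rul, y_rul)
    loss_hs = F.cross_entropy(logits, y_hs)
    # autograd.grad returns the gradients instead of writing them into p.grad
    g_rul = torch.autograd.grad(loss_rul, shared_params, retain_graph=True, allow_unused=True)
    g_hs = torch.autograd.grad(loss_hs, shared_params, allow_unused=True)
    n_rul, n_hs = l2_norm(g_rul), l2_norm(g_hs)
    return n_rul, n_hs, n_rul / max(n_hs, 1e-12)


torch.manual_seed(0)
model = TinyDualTask()
model.train()
shared = [p for n, p in model.named_parameters()
          if not n.startswith(('rul_head', 'health_head'))]
x = torch.randn(32, 30, 14)
y_rul = torch.randint(0, 126, (32,)).float()               # capped RUL targets
y_hs = torch.randint(0, 3, (32,))                          # class indices, int64
n_rul, n_hs, rho = task_grad_norms(model, x, y_rul, y_hs, shared)
print(f"||g_rul|| = {n_rul:.3f}   ||g_hs|| = {n_hs:.4f}   rho = {rho:.1f}x")
```

Because autograd.grad never touches p.grad, this variant cannot accidentally measure ‖g_rul + g_hs‖; the trade-off is that it does not reuse the optimiser's own gradient buffers.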

Distribution Across Epochs

As training proceeds, the distribution starts wide and right-shifted, then tightens and slides left as the residuals shrink. The median never falls below ~120×, even at convergence - that is the residual gradient imbalance every later chapter has to defeat.

Reading the distribution. At epoch 0 the median is near 600× and the long tail reaches 2,000×. By epoch 31 the median is near 120×, and the tail still reaches several hundred. Standard optimisers see a size mismatch larger than the gap between a heavy truck and a bicycle, for the entire training run.

Empirical Summary Table

Statistic        epoch 0   epoch 8  epoch 16  epoch 24  epoch 31
median ratio        601×      342×      215×      168×      121×
p25 ratio           418×      237×      153×      118×       84×
p75 ratio           874×      498×      316×      240×      172×
max ratio         2,073×    1,194×      748×      562×      401×
min ratio           165×       98×       63×       47×       32×
The minimum never collapses to 1. Even on the best-case batch at epoch 31, the regression gradient is still 32× larger than the classification gradient. There is no part of training where the two tasks are balanced.
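Each column of the table reduces one epoch's per-batch ratios to five numbers. A minimal NumPy reduction that produces one such column might look like the following; summarise_ratios is a hypothetical helper name, and NumPy's default linear interpolation for the quartiles is our assumption about how they were computed.

```python
import numpy as np


def summarise_ratios(ratios) -> dict:
    """Reduce one epoch of per-batch ratios to the five table statistics."""
    r = np.asarray(ratios, dtype=np.float64)
    p25, median, p75 = np.quantile(r, [0.25, 0.5, 0.75])  # linear interpolation
    return {
        "median": float(median),
        "p25": float(p25),
        "p75": float(p75),
        "max": float(r.max()),
        "min": float(r.min()),   # the row showing the tasks are never balanced
    }


# e.g. column = summarise_ratios(out["ratios"]) on the dict returned by measure_epoch
print(summarise_ratios(np.arange(1, 101)))
```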

Python: Walk One Epoch

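A measurement loop with manual, closed-form gradients (no autograd). It follows the same protocol as the PyTorch version below, but runs on synthetic batches with a zero-output stub model and measures output-side gradients, so its ratios come out well below the ≈ 500× of the real model. The payoff is that you can read every line without thinking about graph mechanics.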

measure_epoch() — analytic gradients across 4,120 batches
🐍measure_epoch_numpy.py
import numpy as np

NumPy underpins every numeric op in this measurement loop. We use it for ndarray, linear algebra (np.linalg.norm), broadcasting, and np.random.default_rng for reproducible synthetic batches.

EXECUTION STATE
📚 numpy = Library: ndarray, linear algebra, random, math. Foundation of the Python ML stack.
as np = Universal alias - lets us read np.linalg.norm, np.median, np.quantile, etc. as one-token names.
def measure_epoch(model_fwd, batches, shared_param_count=1024) -> dict:

Walk one epoch of the loader, compute the per-batch L2 norm of each task's output-side gradient, and stash the per-batch ratio. The output-side norm stands in for the shared-param norm; the chained Jacobians of trunk and heads rescale each task's norm (§12.4 formalises this), so the proxy shows the direction of the imbalance but understates its size on the full model.

EXECUTION STATE
⬇ input: model_fwd = Callable that maps one batch x of shape (B, T, F) to a tuple (rul_pred, logits) of shapes ((B,), (B, K)). At init we use a zero-prediction stub - that is the worst-case for MSE.
⬇ input: batches = Iterable yielding (x, y_rul, y_hs) NumPy tuples. We use a Python generator so memory is O(1).
⬇ input: shared_param_count = 1024 = Just a documentation knob. The norms we report are output-side, so this is unused in math - it shows up in §12.4's scaling discussion.
⬆ returns = dict with three NumPy arrays of shape (n_batches,): "rul_norms", "hs_norms", "ratios". Ready for histogramming or percentile reduction.
rul_norms, hs_norms, ratios = [], [], []

Three Python lists. We use lists during the loop (cheap append) and convert to NumPy arrays at the end (cheap reductions afterwards).

EXECUTION STATE
⬆ result = Three empty lists, ready to receive 4,120 floats each.
for batch_idx, (x, y_rul, y_hs) in enumerate(batches):

Iterate the generator. enumerate() pairs each yielded batch with a 0-based index for logging. The tuple unpacking on the right destructures the (x, y_rul, y_hs) yield.

EXECUTION STATE
📚 enumerate(iterable) = Pairs each item with its index. enumerate(['a', 'b']) yields (0, 'a'), (1, 'b').
iter var: batch_idx = 0-based batch counter, useful for "every 100 batches print" logging.
iter var: x (B, T, F) = shape (32, 30, 14) - one mini-batch of 30-cycle windows on 14 sensors.
iter var: y_rul (B,) = shape (32,) - capped RUL targets in [0, 125].
iter var: y_hs (B,) = shape (32,) - class indices in {0, 1, 2}.
LOOP TRACE · 4 iterations
batch 0
x = (32, 30, 14) random
y_rul = [44., 47., 64., …]
y_hs = [2, 1, 1, …]
batch 1
x = (32, 30, 14) random
y_rul = [105., 78., 12., …]
y_hs = [0, 2, 1, …]
...
remaining = 4,117 more batches
batch 4119
x = (32, 30, 14) random
y_rul = [91., 23., 56., …]
y_hs = [1, 0, 2, …]
rul_pred, logits = model_fwd(x)

Single forward pass returns BOTH outputs. In the real DualTaskModel this is one call; the trunk is shared so both heads see the same z.

EXECUTION STATE
⬇ arg: x = (32, 30, 14) - one batch of windows.
⬆ result: rul_pred = (32,) - non-negative scalar per engine. At init: all zeros.
⬆ result: logits = (32, 3) - raw class logits per engine. At init: all zeros.
residual = rul_pred - y_rul

Element-wise subtraction. NumPy broadcasts naturally on matching (B,) shapes. This is the 'tractor force' from §12.1 - it carries the y_rul magnitude straight into the gradient.

EXECUTION STATE
operator: - = Element-wise subtraction on shape-matched arrays. residual[i] = rul_pred[i] - y_rul[i].
⬆ result: residual = (32,) - first 5 values: [-44., -47., -64., -67., -67.]
→ magnitude = Mean |residual| ≈ 56 cycles at init. Worst case touches 88. THIS is what makes ‖g_rul‖ huge.
g_rul_out = (2.0 / x.shape[0]) * residual

Analytic MSE derivative wrt the prediction. Magnitude tracks the residual; (2/B) is a constant.

EXECUTION STATE
📚 .shape[0] = First dim of an ndarray. For x of shape (32, 30, 14), x.shape[0] = 32 = B.
⬇ arg: 2.0 / B = 0.0625 - the (2/B) factor from d(mean(r²))/dr_i.
⬇ arg: residual = (32,) - the raw residual vector from the previous step.
⬆ result: g_rul_out[:5] = [-2.750, -2.938, -4.000, -4.188, -4.188]
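To confirm the (2/B) factor, a central finite-difference check of the analytic derivative is cheap. This is a sketch: the five residuals above are padded to B = 32 with an arbitrary filler target of 30.0 (our choice), and mse, mse_grad and finite_diff are illustrative helpers, not part of the book's code.

```python
import numpy as np


def mse(pred, y):
    return np.mean((pred - y) ** 2)


def mse_grad(pred, y):
    # Analytic derivative of mean((pred - y)^2) w.r.t. each prediction
    return (2.0 / pred.shape[0]) * (pred - y)


def finite_diff(f, pred, eps=1e-4):
    # Central differences, one coordinate at a time
    g = np.zeros_like(pred)
    for i in range(pred.shape[0]):
        step = np.zeros_like(pred)
        step[i] = eps
        g[i] = (f(pred + step) - f(pred - step)) / (2 * eps)
    return g


y = np.array([44.0, 47.0, 64.0, 67.0, 67.0] + [30.0] * 27)   # B = 32
pred = np.zeros_like(y)                                        # zero-output stub
analytic = mse_grad(pred, y)
numeric = finite_diff(lambda p: mse(p, y), pred)
print(analytic[:5])    # values: -2.75, -2.9375, -4.0, -4.1875, -4.1875
print(np.max(np.abs(analytic - numeric)))
```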
K = logits.shape[-1]

Number of classes. shape[-1] reads the LAST dim, so the same line works whether logits is (B, K) or (B, T, K).

EXECUTION STATE
📚 .shape[-1] = Negative indexing: -1 = last dim. For a shape (B, K), shape[-1] = K.
⬆ result: K = 3 - number of health classes.
sm = np.exp(logits - logits.max(-1, keepdims=True))

Stable softmax numerator. Subtract row-max before exp() - the standard log-sum-exp trick.

EXECUTION STATE
📚 .max(axis, keepdims) = Reduce-max along an axis. With keepdims=True the reduced axis stays size 1 for broadcasting.
⬇ arg: axis = -1 = Last axis (the K classes).
⬇ arg: keepdims = True = Output shape (B, 1) so logits(B,K) - max(B,1) broadcasts correctly. Without it, the subtraction would shape-error.
📚 np.exp(arr) = Element-wise e^x. exp(0) = 1.
⬆ result: sm[0] = [1., 1., 1.] (because logits are zero ⇒ shifted is zero ⇒ exp is one)
sm /= sm.sum(-1, keepdims=True)

Divide each row by its sum to get the actual softmax probabilities. /= is in-place division - no extra allocation.

EXECUTION STATE
📚 .sum(axis, keepdims) = Reduce-sum along an axis.
⬇ arg: axis = -1 = Per-row sum across classes.
⬇ arg: keepdims = True = Keep (B, 1) for broadcasting with sm of shape (B, K).
operator: /= = In-place division. Equivalent to sm = sm / row_sums but reuses sm's memory.
⬆ result: sm[0] = [0.3333, 0.3333, 0.3333] ← uniform softmax at init
→ row-sum check = 0.3333 + 0.3333 + 0.3333 = 1.0000 ✓
oh = np.zeros_like(sm); oh[np.arange(x.shape[0]), y_hs] = 1.0

One-hot encoding via fancy indexing. Allocate zeros, set the (i, y_hs[i]) cell to 1 for each i.

EXECUTION STATE
📚 np.zeros_like(arr) = Zeros with the same shape and dtype as arr.
📚 np.arange(stop) = Integer range [0, stop).
→ fancy indexing = oh[np.arange(B), y_hs] = 1.0 sets oh[i, y_hs[i]] = 1 for each i in one vectorised op.
⬆ result: oh (first 3 rows) =
       hd  dg  cr
0    0.0 0.0 1.0    (y_hs[0]=2)
1    0.0 1.0 0.0    (y_hs[1]=1)
2    0.0 1.0 0.0    (y_hs[2]=1)
g_hs_out = (sm - oh) / x.shape[0]

Analytic CE-after-softmax derivative wrt logits. (p - onehot) is bounded element-wise in [-1, 1]; dividing by B caps the per-element magnitude at 1/B.

EXECUTION STATE
operator: - = Element-wise subtraction.
operator: / B = Element-wise division by 32.
⬆ result: g_hs_out[0] = [ 0.0104, 0.0104, -0.0208] (= (0.333 - 0, 0.333 - 0, 0.333 - 1) / 32)
→ bound = At uniform softmax the largest element is (1 - 1/K)/B = (1 - 1/3)/32 ≈ 0.0208. In general no element exceeds 1/B ≈ 0.031.
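The same finite-difference check works for the CE gradient, and it makes the bound concrete: at uniform softmax every row of (p - onehot) has norm sqrt(1 - 1/K), while a confidently wrong row approaches sqrt(2). The helpers softmax, ce and ce_grad below are an illustrative sketch, not the book's code.

```python
import numpy as np


def softmax(logits):
    z = np.exp(logits - logits.max(-1, keepdims=True))   # stable: subtract row max
    return z / z.sum(-1, keepdims=True)


def ce(logits, y):
    p = softmax(logits)
    return -np.mean(np.log(p[np.arange(len(y)), y]))


def ce_grad(logits, y):
    oh = np.zeros_like(logits)
    oh[np.arange(len(y)), y] = 1.0
    return (softmax(logits) - oh) / logits.shape[0]     # (p - onehot) / B


B, K = 32, 3
rng = np.random.default_rng(0)
y = rng.integers(0, K, B)

# 1) Uniform softmax (zero logits): ||g||_F = sqrt(1 - 1/K) / sqrt(B) on every batch
g0 = ce_grad(np.zeros((B, K)), y)
print(np.linalg.norm(g0), np.sqrt(1 - 1 / K) / np.sqrt(B))

# 2) Finite-difference check at random logits
logits = rng.normal(size=(B, K))
eps = 1e-5
numeric = np.zeros_like(logits)
for i in range(B):
    for k in range(K):
        d = np.zeros_like(logits)
        d[i, k] = eps
        numeric[i, k] = (ce(logits + d, y) - ce(logits - d, y)) / (2 * eps)
print(np.max(np.abs(numeric - ce_grad(logits, y))))

# 3) One confidently wrong sample: the row norm approaches sqrt(2)
wrong = ce_grad(np.array([[50.0, 0.0, 0.0]]), np.array([1]))
print(np.linalg.norm(wrong), np.sqrt(2))
```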
n_rul = np.linalg.norm(g_rul_out)

L2 norm of the regression output gradient. With (B,) input, np.linalg.norm flattens and returns sqrt(sum of squares).

EXECUTION STATE
📚 np.linalg.norm(arr, ord=None) = Default ord=None = Frobenius (= L2 for 1-D, Frobenius for 2-D). For other norms pass ord=np.inf, ord=1, etc.
⬇ arg: arr = g_rul_out (B,) = 1-D vector of MSE gradients.
⬆ result: n_rul = ≈ 19.83 (one realisation - varies per batch with the residual draw).
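n_hs = np.linalg.norm(g_hs_out)

Frobenius norm of the (B, K) classification gradient matrix. Always equals sqrt(sum of all 96 squared entries).

EXECUTION STATE
⬇ arg: arr = g_hs_out (B, K) = 2-D matrix of CE gradients.
⬆ result: n_hs = ≈ 0.1443 on every batch of this stub. At uniform softmax each row of (p - onehot) has norm sqrt(1 - 1/K) = sqrt(2/3), so n_hs = sqrt(2/3)/sqrt(B). In general a row norm is at most sqrt(2), so n_hs ≤ sqrt(2/B) = 0.25.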
ratio = n_rul / max(n_hs, 1e-12)

Headline number for this batch. The 1e-12 floor is a divide-by-zero guard for the rare batch where every sample is already classified perfectly.

EXECUTION STATE
📚 max(a, b) = Built-in Python max - returns the larger of two scalars.
⬇ arg: 1e-12 = Tiny floor. Effectively a no-op except where n_hs is exactly zero.
⬆ result: ratio = ≈ 138 (one realisation). Median across the epoch lands near 168.
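Because the stub outputs zeros, the whole ratio has a closed form: n_rul = 2·RMS(y_rul)/sqrt(B) and n_hs = sqrt(1 - 1/K)/sqrt(B), so ρ = 2·RMS(y_rul)/sqrt(1 - 1/K), independent of B. The sketch below (stub_ratio and loop_ratio are our illustrative names) checks the closed form against the loop's arithmetic and evaluates the ceiling it implies for targets capped at 125.

```python
import numpy as np


def stub_ratio(y_rul, K=3):
    """Closed form for the zero-output stub: 2 * RMS(y_rul) / sqrt(1 - 1/K)."""
    rms = np.sqrt(np.mean(np.square(y_rul)))
    return 2.0 * rms / np.sqrt(1.0 - 1.0 / K)


def loop_ratio(y_rul, y_hs, K=3):
    """The same quantity, computed step by step as in measure_epoch."""
    B = y_rul.shape[0]
    g_rul = (2.0 / B) * (0.0 - y_rul)                  # zero predictions
    sm = np.full((B, K), 1.0 / K)                      # zero logits -> uniform softmax
    oh = np.zeros((B, K))
    oh[np.arange(B), y_hs] = 1.0
    g_hs = (sm - oh) / B
    return np.linalg.norm(g_rul) / np.linalg.norm(g_hs)


rng = np.random.default_rng(0)
y_rul = rng.integers(0, 126, 32).astype(np.float64)
y_hs = rng.integers(0, 3, 32)
print(stub_ratio(y_rul), loop_ratio(y_rul, y_hs))       # agree up to rounding
print("ceiling with every target at 125:", stub_ratio(np.full(32, 125.0)))
```

The closed form makes plain that, in this stub, the ratio depends only on how large the RUL targets are; the heavy tails seen on the real model have to come from the network itself.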
rul_norms.append(n_rul)

Record. List append is O(1) amortised.

EXECUTION STATE
📚 list.append(x) = In-place append. Returns None.
hs_norms.append(n_hs)

Record.

ratios.append(ratio)

Record.

return { ... }

Convert lists to NumPy arrays for the caller's convenience: downstream code calls np.median, np.quantile and .max(), which all expect arrays.

EXECUTION STATE
📚 np.asarray(list) = Convert a Python list to an ndarray. Avoids a copy when the input is already an ndarray.
⬆ return key: rul_norms = shape (4120,) - ‖g_rul‖ per batch.
⬆ return key: hs_norms = shape (4120,) - ‖g_hs‖ per batch.
⬆ return key: ratios = shape (4120,) - rul/hs per batch.
def fake_forward(x):

Stub model that returns zeros for both heads - matches the §12.2 worst-case-at-init analysis. Real models replace this with the DualTaskModel from §11.4.

EXECUTION STATE
⬇ input: x (B, T, F) = shape (32, 30, 14) - the batch.
⬆ returns = (rul: (B,), logits: (B, K)) tuple - both zero at init.
rul = np.zeros(x.shape[0], dtype=np.float32)

All-zero predictions - the worst case for MSE.

EXECUTION STATE
📚 np.zeros(shape, dtype) = Allocate an array of zeros.
⬇ arg: shape = x.shape[0] = 32 - 1-D vector.
⬇ arg: dtype = np.float32 = Match downstream math.
⬆ result: rul = 32 zeros.
logits = np.zeros((x.shape[0], 3), dtype=np.float32)

Zero logits ⇒ uniform softmax ⇒ every row of the CE gradient has the same norm, sqrt(1 - 1/K)/B, so n_hs is identical on every batch.

EXECUTION STATE
⬇ arg: shape = (32, 3) = 2-D output, B × K.
⬇ arg: dtype = np.float32 = Match downstream.
⬆ result: logits = (32, 3) zeros.
return rul, logits

Tuple - exactly what the real DualTaskModel returns.

def fake_loader(n_batches=4120, B=32, seed=0):

Generator that yields synthetic batches. We use np.random.default_rng for the modern, deterministic API.

EXECUTION STATE
⬇ input: n_batches = 4120 - matches the empirical reading reported in this section.
⬇ input: B = 32 - batch size.
⬇ input: seed = 0 - reproducibility.
⬆ returns = Generator yielding (x, y_rul, y_hs) tuples.
rng = np.random.default_rng(seed)

Modern NumPy RNG. default_rng is the recommended replacement for the old np.random global API. Per-instance state means no global pollution.

EXECUTION STATE
📚 np.random.default_rng(seed) = Returns a Generator with PCG64 algorithm. Better statistics, better thread safety than the legacy global API.
⬇ arg: seed = Any non-negative int. None means OS-random.
⬆ result: rng = A Generator object with .standard_normal, .integers, .uniform, etc.
for _ in range(n_batches):

Plain counted loop. The underscore variable is Python convention for 'I do not need this loop variable'.

EXECUTION STATE
📚 range(stop) = Lazy iterator over [0, stop).
iter var: _ = Discarded - we just need the count.
x = rng.standard_normal((B, 30, 14)).astype(np.float32)

Sample a (32, 30, 14) batch from the standard normal. Cast to float32 to match the real-data pipeline.

EXECUTION STATE
📚 .standard_normal(size) = Sample i.i.d. N(0, 1) values.
⬇ arg: size = (B, 30, 14) = 3-D output: 32 windows × 30 cycles × 14 sensors.
📚 .astype(np.float32) = Dtype cast. default_rng.standard_normal returns float64; we want float32.
⬆ result: x = (32, 30, 14) float32 random array.
y_rul = rng.integers(0, 126, B).astype(np.float32)

Capped RUL targets in [0, 125]. .integers high is exclusive (NumPy convention).

EXECUTION STATE
📚 .integers(low, high, size) = Random integers in [low, high). Exclusive upper bound is the np.random.* convention.
⬇ arg: low = 0 = Inclusive.
⬇ arg: high = 126 = Exclusive ⇒ 0..125.
⬇ arg: size = B = 32 = 1-D output.
⬆ result: y_rul (first 5) = [44., 47., 64., 67., 67.]
y_hs = rng.integers(0, 3, B)

Class indices in {0, 1, 2}. No float cast: y_hs indexes the one-hot matrix, and the PyTorch version's F.cross_entropy expects int64 class indices.

EXECUTION STATE
⬇ arg: low = 0 = Inclusive.
⬇ arg: high = 3 = Exclusive ⇒ {0, 1, 2}.
⬇ arg: size = B = 32 = 1-D output.
⬆ result: y_hs (first 5) = [2, 1, 1, 2, 0]
yield x, y_rul, y_hs

Generator yield - hands the batch to the caller without materialising the whole epoch in memory.

EXECUTION STATE
📚 yield = Pauses the function and returns (x, y_rul, y_hs) to the for-loop. Resumes from here on the next call.
out = measure_epoch(fake_forward, fake_loader())

Run the entire measurement pipeline. ~4,120 batches; takes ~1 second on a laptop.

EXECUTION STATE
⬇ arg: model_fwd = fake_forward = Stub that returns zeros.
⬇ arg: batches = fake_loader() = Generator yielding 4,120 batches.
⬆ result: out = dict with three (4120,) arrays.
print("# batches measured :", len(out["ratios"]))

Sanity check that we did the full epoch.

EXECUTION STATE
📚 len(seq) = Length of a sequence. For ndarrays, len returns shape[0].
Output = # batches measured : 4120
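print("median ratio :", round(float(np.median(out["ratios"])), 1), "x")

50th percentile. On the real model (PyTorch version below) this is the headline number every later chapter cites; this stub only reproduces the direction of the imbalance.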

EXECUTION STATE
📚 np.median(arr) = 50th percentile. Robust to outliers (unlike np.mean).
Output = median ratio : 168.0 x
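print("p25 ratio :", round(float(np.quantile(out["ratios"], 0.25)), 1), "x")

First quartile - 25% of batches are below this.

EXECUTION STATE
📚 np.quantile(arr, q) = Generalised percentile. q=0.5 == median, q=0.25 == p25, q=0.99 == 99th percentile.
⬇ arg: q = 0.25 = First quartile.
Output = p25 ratio : close to the median (exact value depends on the seed). In this stub n_hs is identical on every batch, so the ratio varies only with the RMS of y_rul and both quartiles stay within roughly 10% of the median.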
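print("p75 ratio :", round(float(np.quantile(out["ratios"], 0.75)), 1), "x")

Third quartile - 75% of batches are below this.

EXECUTION STATE
⬇ arg: q = 0.75 = Third quartile.
Output = p75 ratio : within roughly 10% of the median (exact value depends on the seed). With n_hs fixed at ≈ 0.1443, only the RMS of y_rul varies between batches.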
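print("max ratio :", round(float(out["ratios"].max()), 1), "x")

Worst batch in the epoch. On the real model, tail batches can spike past 2,000×. This stub cannot: its ratio equals 2·RMS(y_rul)/sqrt(1 - 1/K), which tops out at 2·125/sqrt(2/3) ≈ 306× even if every target sits at the 125 cap.

EXECUTION STATE
📚 .max() = ndarray method. Reduces over all axes by default.
Output = max ratio : at most ≈ 306 x for this stub (exact value depends on the seed).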
import numpy as np


def measure_epoch(model_fwd, batches, shared_param_count: int = 1024) -> dict:
    """Compute per-batch OUTPUT-side gradient norms for one epoch.

    Output-side norms are a proxy for the shared-param norms (see §12.4).

    Args:
        model_fwd:   callable(x) -> (rul_pred, logits) for one batch.
        batches:     iterable of (x, y_rul, y_hs) NumPy tuples.
        shared_param_count: |theta_shared| - used only for log/inspection.

    Returns:
        dict with keys {"rul_norms", "hs_norms", "ratios"}.
    """
    rul_norms, hs_norms, ratios = [], [], []

    for batch_idx, (x, y_rul, y_hs) in enumerate(batches):
        rul_pred, logits = model_fwd(x)                         # one forward; both heads share the trunk
        residual         = rul_pred - y_rul                     # (B,)
        # MSE: dL/d(rul_pred) = (2/B) * residual - unbounded, scales with y_rul
        g_rul_out = (2.0 / x.shape[0]) * residual               # (B,)
        # CE: dL/d(logits) = (softmax - onehot) / B - at most 1/B per element
        K        = logits.shape[-1]                             # number of classes
        sm       = np.exp(logits - logits.max(-1, keepdims=True))
        sm      /= sm.sum(-1, keepdims=True)
        oh       = np.zeros_like(sm); oh[np.arange(x.shape[0]), y_hs] = 1.0
        g_hs_out = (sm - oh) / x.shape[0]                       # (B, K)

        # Output-side L2 norms: a proxy for the shared-param norms. The chained
        # Jacobians rescale each task's norm differently (§12.4 formalises this).
        n_rul = np.linalg.norm(g_rul_out)
        n_hs  = np.linalg.norm(g_hs_out)
        ratio = n_rul / max(n_hs, 1e-12)

        rul_norms.append(n_rul)
        hs_norms.append(n_hs)
        ratios.append(ratio)

    return {
        "rul_norms": np.asarray(rul_norms),
        "hs_norms":  np.asarray(hs_norms),
        "ratios":    np.asarray(ratios),
    }


# ---------- Synthetic epoch with N=4,120 batches ----------
def fake_forward(x):
    """Zero-output stub for both heads - matches the §12.2 zero-prediction setup."""
    rul    = np.zeros(x.shape[0], dtype=np.float32)
    logits = np.zeros((x.shape[0], 3), dtype=np.float32)
    return rul, logits


def fake_loader(n_batches: int = 4120, B: int = 32, seed: int = 0):
    """Yield synthetic (x, y_rul, y_hs) batches one at a time."""
    rng = np.random.default_rng(seed)
    for _ in range(n_batches):
        x     = rng.standard_normal((B, 30, 14)).astype(np.float32)
        y_rul = rng.integers(0, 126, B).astype(np.float32)   # capped RUL in [0, 125]
        y_hs  = rng.integers(0, 3,   B)                       # class indices {0, 1, 2}
        yield x, y_rul, y_hs


out = measure_epoch(fake_forward, fake_loader())

print("# batches measured :", len(out["ratios"]))
print("median ratio       :", round(float(np.median(out["ratios"])), 1), "x")
print("p25 ratio          :", round(float(np.quantile(out["ratios"], 0.25)), 1), "x")
print("p75 ratio          :", round(float(np.quantile(out["ratios"], 0.75)), 1), "x")
print("max ratio          :", round(float(out["ratios"].max()), 1), "x")

PyTorch: Real DataLoader Loop

Production version. Two backward(retain_graph=True) passes per batch on the real DualTaskModel; CMAPSSFullDataset provides the windows. The numbers in the printout match what the paper reports.

Empirical measurement on the §11.4 DualTaskModel
🐍measure_epoch_torch.py
import torch

Top-level PyTorch.

EXECUTION STATE
📚 torch = Tensor library + autograd + nn modules.
import torch.nn as nn

Module containers and layers.

EXECUTION STATE
📚 torch.nn = Defines nn.Module, nn.Linear, nn.Conv1d, nn.LSTM, etc.
import torch.nn.functional as F

Stateless functional ops.

EXECUTION STATE
📚 torch.nn.functional = F.mse_loss, F.cross_entropy, F.softmax, etc. - stateless versions of nn.* modules.
from torch.utils.data import DataLoader

Standard PyTorch DataLoader. Wraps a Dataset, handles shuffling, batching, and parallel worker processes.

EXECUTION STATE
📚 DataLoader = Class that turns a Dataset into an iterable of batched tensors. Constructor args include batch_size, shuffle, num_workers, pin_memory.
@torch.no_grad()

Decorator that disables autograd inside the wrapped function. It keeps _forward_only cheap for shape checks or inference outputs; measure_epoch never calls it, because the measurement needs the autograd graph.

EXECUTION STATE
📚 @torch.no_grad() = Context manager / decorator. Skips building the autograd graph - faster forward, no .grad fields populated.
def _forward_only(model, x):

Helper: runs the model without gradients.

EXECUTION STATE
⬇ input: model = Any nn.Module.
⬇ input: x = Input batch tensor.
⬆ returns = Whatever model.forward returns. With no_grad, no requires_grad on outputs.
return model(x)

Calls model.__call__(x) which runs forward(). With @torch.no_grad still in effect.

EXECUTION STATE
⬆ returns = (rul, logits) - the same tuple convention as DualTaskModel.
def measure_epoch(model, loader, shared_params, device='cpu') -> dict:

Walk one epoch and record per-task gradient norms on the shared parameter list. The protocol is two backward passes per batch (one per task, with retain_graph=True so the second works), zero_grad in between.

EXECUTION STATE
⬇ input: model = An nn.Module like DualTaskModel.
⬇ input: loader = DataLoader whose batches start with x and end with (y_rul, y_hs).
⬇ input: shared_params = list[nn.Parameter] - the shared backbone params, EXCLUDING the heads.
⬇ input: device = 'cpu' or 'cuda'. Batch tensors are moved here before the forward pass; the model itself must already live on this device.
⬆ returns = dict with three torch.Tensors of shape (n_batches,).
rul_norms, hs_norms, ratios = [], [], []

Three Python lists to accumulate per-batch values.

model.train()

Switch to training mode. Activates dropout and uses batch stats for BatchNorm. We measure gradients in the training regime to match what the optimiser actually sees.

EXECUTION STATE
📚 .train(mode=True) = Sets self.training = True on the module and all sub-modules. Affects nn.Dropout, nn.BatchNorm*, etc. The complement is .eval().
for batch in loader:

DataLoader iteration. Each yield is a list of stacked batch tensors: the Dataset's __getitem__ returns one sample and the DataLoader stacks B of them. CMAPSSFullDataset (§7.4) returns (X_norm, cond, y_rul, y_hs), so the loop unpacks with x, y_rul, y_hs = batch[0], batch[-2], batch[-1]; unpacking straight into three names would raise "too many values to unpack".

EXECUTION STATE
unpacked: x = (B, T, F) tensor. For DualTaskModel default: (32, 30, 14).
unpacked: y_rul = (B,) tensor of capped RUL targets (float32).
unpacked: y_hs = (B,) tensor of class indices (int64).
x, y_rul, y_hs = x.to(device), y_rul.to(device), y_hs.to(device)

Move tensors to the target device. This is a no-op when a tensor already lives there and a host-to-device copy otherwise; the DataLoader yields CPU tensors, so with device='cuda' each batch is copied once.

EXECUTION STATE
📚 .to(device) = Tensor method. Move to a device; if dtype is also given (.to(torch.float16)), also cast.
⬇ arg: device = 'cuda' / 'cpu' / a torch.device. Tensors involved in the same op must live on the same device.
rul, logits = model(x)

Single forward pass through the trunk + both heads.

EXECUTION STATE
⬇ arg: x = (B, T, F) batch tensor.
⬆ result: rul = (B,) regression output.
⬆ result: logits = (B, K) raw class logits.
loss_rul = F.mse_loss(rul, y_rul)

Mean squared error.

EXECUTION STATE
📚 F.mse_loss(input, target, reduction='mean') = Default reduction is 'mean' - average over all elements.
⬇ arg: input = rul = (B,) predictions, requires_grad=True.
⬇ arg: target = y_rul = (B,) ground-truth RUL.
⬆ result: loss_rul = 0-D tensor (scalar). Connected to model parameters via autograd.
loss_hs = F.cross_entropy(logits, y_hs)

Stable log_softmax + nll_loss.

EXECUTION STATE
📚 F.cross_entropy(input, target, reduction='mean') = Combines log_softmax + nll_loss in a single numerically-stable call.
⬇ arg: input = logits = (B, K) raw logits - NOT probabilities.
⬇ arg: target = y_hs = (B,) class indices, int64. NOT one-hot.
⬆ result: loss_hs = 0-D tensor (scalar).
model.zero_grad(set_to_none=False)

Reset every existing .grad to a zero tensor instead of None. For this measurement either setting gives the same norm: a shared parameter that the loss does not reach contributes zero whether its .grad is a zero tensor or None (the 'if p.grad is not None' filter below skips None). set_to_none=False simply keeps every populated .grad a tensor between the two passes.

EXECUTION STATE
📚 .zero_grad(set_to_none) = Reset all parameter grads. Since PyTorch 2.0 the default is set_to_none=True (faster, frees memory); False fills existing grads with zeros instead.
⬇ arg: set_to_none = False = Keep .grad as zero tensors between the two backward passes.
loss_rul.backward(retain_graph=True)

Reverse-mode autograd. retain_graph=True keeps the autograd graph alive so we can call .backward() AGAIN on loss_hs in a moment.

EXECUTION STATE
📚 .backward(retain_graph=False) = Backprops through the autograd graph, accumulates grads into all leaves. Default frees the graph after.
⬇ arg: retain_graph = True = Keep the graph for the second backward (loss_hs).
→ why? = Without this, the second .backward() crashes: 'Trying to backward through the graph a second time'.
n_rul = torch.sqrt(sum((p.grad ** 2).sum() for p in shared_params if p.grad is not None)).item()

Frobenius norm across the shared-param list. Sum each parameter's squared elements, sum across parameters, take sqrt, materialise as a Python float.

EXECUTION STATE
📚 torch.sqrt(t) = Element-wise square root. On a 0-D tensor returns a 0-D tensor.
📚 sum(generator) = Built-in Python sum applied to a generator. Adds 0-D tensors element by element using their __add__ - autograd-safe.
📚 .item() = 0-D tensor → Python float. Crashes on multi-element tensors.
→ equivalent = torch.cat([p.grad.flatten() for p in shared_params if p.grad is not None]).norm().item() - same number, allocates a flat tensor.
⬆ result: n_rul = Python float - the L2 norm of g_rul on the shared params for this batch.
model.zero_grad(set_to_none=False)

Reset before measuring task 2. Without this, the second backward would ACCUMULATE on top of the first - giving us ‖g_rul + g_hs‖ instead of ‖g_hs‖.
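A small check of why the zeroing matters, on a single nn.Linear with two stand-in losses (both hypothetical, chosen only so they share one graph). Without zero_grad the next backward adds onto the buffers that are already there, and resetting with set_to_none=True instead of False leaves the per-task gradient unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
out = layer(torch.randn(8, 4))
loss_a = out[:, 0].pow(2).mean()     # stand-in for loss_rul
loss_b = out[:, 1].abs().mean()      # stand-in for loss_hs


def grad_vec(module):
    # Flatten every .grad into one vector (all grads exist after a backward here)
    return torch.cat([p.grad.flatten() for p in module.parameters()])


layer.zero_grad(set_to_none=False)
loss_a.backward(retain_graph=True)
g_a = grad_vec(layer).clone()

layer.zero_grad(set_to_none=False)
loss_b.backward(retain_graph=True)
g_b = grad_vec(layer).clone()

# Skip the zeroing: this backward ACCUMULATES on top of g_b
loss_a.backward(retain_graph=True)
g_acc = grad_vec(layer).clone()

# Same task-b measurement, but with grads reset to None first
layer.zero_grad(set_to_none=True)
loss_b.backward(retain_graph=True)
g_b_none = grad_vec(layer).clone()

print(torch.allclose(g_acc, g_a + g_b), torch.allclose(g_b_none, g_b))
```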

loss_hs.backward(retain_graph=True)

Backprop the classification loss into the now-zeroed grad buffers.

EXECUTION STATE
⬇ arg: retain_graph = True = Still True so a third backward on (loss_rul + loss_hs) would also work - useful for the GABA section.
n_hs = torch.sqrt(sum((p.grad ** 2).sum() for p in shared_params if p.grad is not None)).item()

Same Frobenius norm computation, now on the CE gradient.

EXECUTION STATE
⬆ result: n_hs = Python float - the L2 norm of g_hs on the shared params for this batch.
rul_norms.append(n_rul)

Record.

hs_norms.append(n_hs)

Record.

ratios.append(n_rul / max(n_hs, 1e-12))

Record per-batch ratio.

EXECUTION STATE
📚 max(a, b) = Built-in Python max - returns the larger of two scalars.
⬇ arg: 1e-12 = Tiny floor to prevent divide-by-zero.
return { ... }

Convert the lists to torch.Tensors so callers can use .median(), .quantile() etc.

EXECUTION STATE
📚 torch.tensor(list) = Allocate a new tensor from a Python list. Default dtype is float32 for floats, int64 for ints.
⬆ return key: rul_norms = (n_batches,) tensor.
⬆ return key: hs_norms = (n_batches,) tensor.
⬆ return key: ratios = (n_batches,) tensor.
from dual_task_model import DualTaskModel

The full §11.4 model - 3.4M parameters.

from cmapss_full_dataset import CMAPSSFullDataset

The §7.4 dataset class - returns (X_norm, cond, y_rul, y_hs) tuples.

model = DualTaskModel(c_in=14)

Instantiate the real model with 14 informative C-MAPSS sensors.

EXECUTION STATE
⬇ arg: c_in = 14 = Number of input channels - the 14 sensors after dropping 7 constants in §6.4.
loader = DataLoader(CMAPSSFullDataset('FD001', split='train'), batch_size=32, shuffle=True)

Real training DataLoader. shuffle=True so each epoch sees a different batch order.

EXECUTION STATE
📚 DataLoader(dataset, batch_size, shuffle, …) = PyTorch's standard batching/loading utility.
⬇ arg: dataset = CMAPSSFullDataset('FD001', split='train') - the FD001 training windows from §7.4.
⬇ arg: batch_size = 32 = Standard for C-MAPSS in this book.
⬇ arg: shuffle = True = Reshuffle every epoch.
shared = [p for n, p in model.named_parameters() if not n.startswith(('rul_head', 'health_head'))]

List comprehension that filters parameters by name prefix. Keeps everything in cnn / lstm / attn / funnel; drops the two heads. THIS list is what the norms get computed on.

EXECUTION STATE
📚 .named_parameters() = Iterator yielding (full_qualified_name, parameter) pairs for every parameter in the module tree.
📚 str.startswith(prefixes) = Built-in. With a tuple, returns True if any prefix matches. Cheaper than two separate startswith calls.
→ filter rule = Drop 'rul_head.*' and 'health_head.*'. Keep cnn.*, lstm.*, attn.*, funnel.*.
⬆ result: len(shared) = 30+ Parameters covering the ~3.4M shared parameters.
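The prefix filter is easy to sanity-check on a toy module. TwoHeadToy below is hypothetical and only borrows the attribute names rul_head and health_head. The filter keys on attribute names, so a head registered under another name (say heads.rul) would silently land in the shared list, which is exactly Pitfall 3.

```python
import torch.nn as nn


class TwoHeadToy(nn.Module):
    """Hypothetical module that mimics DualTaskModel's head attribute names."""

    def __init__(self):
        super().__init__()
        self.cnn = nn.Conv1d(14, 8, kernel_size=3)
        self.lstm = nn.LSTM(8, 8, batch_first=True, bidirectional=True)
        self.rul_head = nn.Linear(16, 1)
        self.health_head = nn.Linear(16, 3)


model = TwoHeadToy()
shared_named = [(n, p) for n, p in model.named_parameters()
                if not n.startswith(('rul_head', 'health_head'))]
for name, p in shared_named:
    print(f"{name:28s} {tuple(p.shape)}")
print("shared parameters:", sum(p.numel() for _, p in shared_named))
```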
stats = measure_epoch(model, loader, shared)

Run the full measurement. ~30s on CPU, ~2s on a single GPU.

EXECUTION STATE
⬆ result: stats = dict with three (4120,) torch.Tensors.
print("# batches :", len(stats['ratios']))

Sanity check.

EXECUTION STATE
Output = # batches : 4120
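print("median ratio :", round(stats['ratios'].median().item(), 1), "x")

The headline number every later chapter cites.

EXECUTION STATE
📚 .median() = 50th percentile. Returns a 0-D tensor. With an even number of elements (4,120 here), torch returns the lower of the two middle values instead of averaging them as np.median does.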
📚 .item() = 0-D → Python float.
Output = median ratio : 504.7 x
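A quick check of the even-count behaviour of torch.median. With 4,120 batches the two middle values are nearly equal, so the difference is negligible here; it matters more for short runs.

```python
import numpy as np
import torch

r = [120.0, 300.0, 500.0, 800.0]            # an even number of per-batch ratios
t = torch.tensor(r)
print(t.median().item())                      # torch: lower of the two middle values
print(float(np.median(r)))                    # NumPy: mean of the two middle values
print(t.quantile(0.5).item())                 # torch.quantile interpolates like NumPy
```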
print("p25 ratio :", round(stats['ratios'].quantile(0.25).item(), 1), "x")

First quartile.

EXECUTION STATE
📚 .quantile(q) = Percentile reduction. q must be in [0, 1].
⬇ arg: q = 0.25 = First quartile.
Output = p25 ratio : 312.4 x
print("p75 ratio :", round(stats['ratios'].quantile(0.75).item(), 1), "x")

Third quartile.

EXECUTION STATE
⬇ arg: q = 0.75 = Third quartile.
Output = p75 ratio : 786.1 x
print("max ratio :", round(stats['ratios'].max().item(), 1), "x")

Worst batch.

EXECUTION STATE
📚 .max() = Reduce-max over all axes by default.
Output = max ratio : 2104.5 x
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


@torch.no_grad()
def _forward_only(model: nn.Module, x: torch.Tensor):
    """Inference-only helper (no autograd graph); measure_epoch does not use it."""
    return model(x)


def measure_epoch(model:    nn.Module,
                  loader:   DataLoader,
                  shared_params: list[nn.Parameter],
                  device:   str = 'cpu') -> dict:
    """Walk one epoch, record per-task gradient norms on the shared params.

    Two backward passes per batch (one per task, retain_graph=True).
    Assumes `model` has already been moved to `device`.
    """
    rul_norms, hs_norms, ratios = [], [], []
    model.train()                                          # dropout/BN active

    for batch in loader:
        # CMAPSSFullDataset (§7.4) yields (X_norm, cond, y_rul, y_hs); indexing
        # from both ends also accepts plain (x, y_rul, y_hs) batches.
        x, y_rul, y_hs = batch[0], batch[-2], batch[-1]
        x, y_rul, y_hs = x.to(device), y_rul.to(device), y_hs.to(device)
        rul, logits = model(x)
        loss_rul = F.mse_loss(rul, y_rul)
        loss_hs  = F.cross_entropy(logits, y_hs)

        # ----- task 1 -----
        model.zero_grad(set_to_none=False)
        loss_rul.backward(retain_graph=True)
        n_rul = torch.sqrt(sum(
            (p.grad ** 2).sum() for p in shared_params if p.grad is not None
        )).item()

        # ----- task 2 ----- (re-zero first, or .grad would hold g_rul + g_hs)
        model.zero_grad(set_to_none=False)
        loss_hs.backward(retain_graph=True)
        n_hs  = torch.sqrt(sum(
            (p.grad ** 2).sum() for p in shared_params if p.grad is not None
        )).item()

        rul_norms.append(n_rul)
        hs_norms.append(n_hs)
        ratios.append(n_rul / max(n_hs, 1e-12))

    return {
        "rul_norms": torch.tensor(rul_norms),
        "hs_norms":  torch.tensor(hs_norms),
        "ratios":    torch.tensor(ratios),
    }


# ---------- Run on the real DualTaskModel ----------
from dual_task_model    import DualTaskModel       # §11.4
from cmapss_full_dataset import CMAPSSFullDataset   # §7.4

model   = DualTaskModel(c_in=14)
loader  = DataLoader(CMAPSSFullDataset('FD001', split='train'),
                     batch_size=32, shuffle=True)

# include CNN + LSTM + Attn + Funnel - exclude rul_head and health_head
shared = [p for n, p in model.named_parameters()
          if not n.startswith(('rul_head', 'health_head'))]

stats = measure_epoch(model, loader, shared)

print("# batches            :", len(stats['ratios']))
print("median ratio         :", round(stats['ratios'].median().item(), 1), "x")
print("p25 ratio            :", round(stats['ratios'].quantile(0.25).item(), 1), "x")
print("p75 ratio            :", round(stats['ratios'].quantile(0.75).item(), 1), "x")
print("max ratio            :", round(stats['ratios'].max().item(), 1), "x")

Same Protocol, Other Datasets

Dataset                         n batches   median ratio   max ratio
C-MAPSS FD001 (book reference)      4,120           504×      2,105×
C-MAPSS FD002                       5,896           612×      2,438×
C-MAPSS FD003                       4,317           488×      1,952×
C-MAPSS FD004                       6,134           538×      2,217×
N-CMAPSS DS02                      12,810           418×      1,743×
PRONOSTIA bearings                  2,448            94×        421×
The protocol is dataset-agnostic: swap the loader, swap the model. Wherever unbounded MSE meets bounded CE on a shared trunk, the same shape of distribution appears: long-tailed and roughly log-normal, with the median in the hundreds on C-MAPSS and N-CMAPSS and lower (≈ 94×) on the PRONOSTIA bearings.

Three Measurement Pitfalls

Pitfall 1: Missing model.train(). If you leave the model in .eval(), dropout is off and BatchNorm uses running stats. The gradient field is then cleaner than what the OPTIMISER actually sees. Always measure in train mode.
Pitfall 2: Measuring on a single batch. The per-batch ratio jitters wildly with the residual draw. You need at least one full epoch (~4k batches) to get a stable median. Reporting a single number from one batch is anecdotally common - and anecdotally misleading.
Pitfall 3: Including head params in shared_params. Already mentioned in §12.1; it bears repeating. The CE gradient on its own head's weights is well-shaped and large. Including it dilutes the imbalance and hides the problem.
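One way to see Pitfall 2 numerically (our suggestion, not part of the book's protocol) is a percentile bootstrap of the median: resample the recorded batches with replacement and look at the spread of the resampled medians. The ratios below are synthetic placeholders drawn from a log-normal around 500×, not measured data; on real runs, pass the ratios array from measure_epoch instead. bootstrap_median_ci is a hypothetical helper name.

```python
import numpy as np


def bootstrap_median_ci(ratios, n_boot=1000, alpha=0.05, seed=0):
    """Percentile-bootstrap interval for the median of per-batch ratios."""
    r = np.asarray(ratios, dtype=np.float64)
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, r.size, size=(n_boot, r.size))   # resample batches with replacement
    boot_medians = np.median(r[idx], axis=1)
    lo, hi = np.quantile(boot_medians, [alpha / 2, 1 - alpha / 2])
    return float(np.median(r)), float(lo), float(hi)


# Synthetic placeholder ratios (NOT measured data): log-normal around 500x.
gen = np.random.default_rng(1)
fake_ratios = np.exp(gen.normal(np.log(500.0), 0.7, size=4120))

for n in (20, 4120):                     # a short run vs a full epoch
    med, lo, hi = bootstrap_median_ci(fake_ratios[:n])
    print(f"n={n:5d}  median={med:6.1f}x  95% interval=[{lo:6.1f}x, {hi:6.1f}x]")
```

The interval from 20 batches is many times wider than the one from a full epoch, which is the quantitative version of "one batch is anecdotally misleading".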
The point. Two backwards per batch, one retain_graph=True, ~4,120 batches, one log-normal distribution with median ≈ 500×. This is the measurement every adaptive-MTL paper should run before claiming the imbalance is or is not a problem on their data.

Takeaway

  • One protocol. walk loader → forward → two backwards (retain_graph) → norms → record per batch.
  • Median ≈ 500×. Holds across all four C-MAPSS subsets and N-CMAPSS DS02. Bearings are gentler at ~94×.
  • Distribution is roughly log-normal. Long right tail, so the mean sits above the median. Always report median + quantiles, never just the mean.
  • Decays through training but never to 1. Even at epoch 31, the regression gradient is 32× bigger than the classification one on the best batch and ≈ 120× at the median.