Chapter 12

Why MSE Gradients Dominate Cross-Entropy

The 500× Gradient Imbalance

Two Derivatives, Two Universes

§12.1 measured the imbalance. This section explains it. The gap is not an artefact of our model size or our optimiser - it falls straight out of the derivatives of MSE and CE.

Read the next four lines slowly. They are the entire chapter in compressed form.

| Loss | Derivative w.r.t. logit / prediction | Element bound |
|---|---|---|
| MSE: L = (1/B) Σ (ŷ - y)² | ∂L/∂ŷ_i = (2/B)(ŷ_i - y_i) | unbounded, grows with the residual (magnitude ≈ y at init) |
| CE: L = -(1/B) Σ log p_y | ∂L/∂z_ik = (1/B)(p_ik - 1[k=y_i]) | ≤ 1/B (always) |
This is the entire mystery. One derivative is bounded; the other is not. With C-MAPSS RUL targets up to 125, the unbounded one is already about two orders of magnitude bigger at the heads, before backprop even starts climbing the backbone.

MSE: Derivative Grows With Residual

Let $r_i = \hat{y}_i - y_i$ be the per-sample residual. Then

$$\mathcal{L}_{\text{mse}} = \dfrac{1}{B} \sum_{i=1}^{B} r_i^2 \;\;\Longrightarrow\;\; \dfrac{\partial \mathcal{L}_{\text{mse}}}{\partial \hat{y}_i} = \dfrac{2}{B} r_i.$$

The gradient's magnitude is linear in the residual. With our capped RUL target ($R_{\max} = 125$) and a freshly initialised network predicting near zero, the per-sample residual is on the order of $\mathbb{E}|y_{\text{rul}}| \approx 62$ cycles. With $B = 32$ the typical per-element gradient magnitude is $(2/32) \cdot 62 \approx 3.9$.
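A finite-difference check makes the linear-in-residual derivative concrete. This is a minimal sketch, assuming uniformly drawn capped targets (via NumPy's `np.random.default_rng`, not the legacy seeding used later in this section) and a head that predicts zero; `mse`, `mse_grad` and `numeric_grad` are illustrative helpers, not taken from the book's code.

```python
import numpy as np


def mse(y_hat, y):
    """Mean-squared error with a mean over the batch."""
    return np.mean((y_hat - y) ** 2)


def mse_grad(y_hat, y):
    """Analytic gradient: dL/dy_hat_i = (2 / B) * (y_hat_i - y_i)."""
    return (2.0 / y.shape[0]) * (y_hat - y)


def numeric_grad(f, x, eps=1e-4):
    """Central finite differences, one coordinate at a time."""
    g = np.zeros_like(x)
    for i in range(x.size):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[i] += eps
        x_minus[i] -= eps
        g[i] = (f(x_plus) - f(x_minus)) / (2 * eps)
    return g


rng = np.random.default_rng(0)
B = 32
y = rng.integers(0, 126, B).astype(np.float64)  # capped RUL targets in [0, 125]
y_hat = np.zeros(B)                              # fresh head predicting zero

g_analytic = mse_grad(y_hat, y)
g_numeric = numeric_grad(lambda p: mse(p, y), y_hat)

print(np.allclose(g_analytic, g_numeric, atol=1e-6))  # True: the formula holds
print(np.mean(np.abs(g_analytic)))                     # (2/B) * mean|y|: a few units
```

Because the loss is quadratic, the central difference is exact up to float rounding, so any disagreement would point at a wrong formula rather than at the step size.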

CE: Derivative Bounded By 1

Cross-entropy after softmax has a famously clean derivative:

$$\mathcal{L}_{\text{ce}} = -\dfrac{1}{B} \sum_{i=1}^{B} \log p_{i, y_i} \;\;\Longrightarrow\;\; \dfrac{\partial \mathcal{L}_{\text{ce}}}{\partial z_{i,k}} = \dfrac{1}{B} \bigl(p_{i,k} - \mathbb{1}[k = y_i]\bigr).$$
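The derivative follows in one line from writing the log-probability in log-sum-exp form; the mean over the batch then contributes the factor 1/B:

$$-\log p_{i,y_i} = -z_{i,y_i} + \log \sum_{j=1}^{K} e^{z_{i,j}} \;\;\Longrightarrow\;\; \dfrac{\partial\,(-\log p_{i,y_i})}{\partial z_{i,k}} = -\mathbb{1}[k = y_i] + \dfrac{e^{z_{i,k}}}{\sum_{j=1}^{K} e^{z_{i,j}}} = p_{i,k} - \mathbb{1}[k = y_i].$$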

Because $p_{i,k} \in [0, 1]$ and $\mathbb{1}[\cdot] \in \{0, 1\}$, every element of $(p - \text{onehot})$ lies in $[-1, 1]$. Dividing by $B$ gives a per-element bound of $1/B$.

Symmetric upper bound, asymmetric reality. At init the softmax is roughly uniform (1/K per class), so the true-class entry of $(p - \text{onehot})/B$ has magnitude $(1 - 1/K)/B$: already two-thirds of the bound for K = 3. As the model becomes confidently right, the true-class probability approaches 1 and the gradient on the logits collapses to zero. Only a confidently wrong prediction pushes an entry toward the $1/B$ ceiling, and never past it.
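A quick numerical check of the bound across three regimes: uniform init, confidently right, and confidently wrong. This is a minimal sketch, assuming a toy batch in which every true label is class 0 and a logit of 10 stands in for "confident"; `softmax` and `ce_grad` are illustrative helpers.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax with the max-subtraction trick for stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def ce_grad(z, y):
    """Analytic gradient of mean cross-entropy w.r.t. logits: (p - onehot) / B."""
    B, K = z.shape
    onehot = np.eye(K)[y]
    return (softmax(z) - onehot) / B


B, K = 32, 3
y = np.zeros(B, dtype=int)                  # every sample's true class is 0

# Uniform init (all-zero logits): true-class entry has magnitude (1 - 1/K) / B.
g_init = ce_grad(np.zeros((B, K)), y)

# Confidently right: large logit on the true class, gradient close to 0.
z_right = np.zeros((B, K)); z_right[:, 0] = 10.0
g_right = ce_grad(z_right, y)

# Confidently wrong: large logit on a wrong class, approaches the 1/B ceiling.
z_wrong = np.zeros((B, K)); z_wrong[:, 1] = 10.0
g_wrong = ce_grad(z_wrong, y)

for name, g in [("init", g_init), ("right", g_right), ("wrong", g_wrong)]:
    print(f"{name:>5}: max |g| = {np.abs(g).max():.6f}   (1/B = {1.0 / B:.6f})")
```

However large the wrong logit gets, the largest entry stays just under 1/B, which is the whole contrast with the MSE column of the table.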

Interactive: Live Comparison

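Drag the residual and probability sliders. Even with the CE gradient at its 1/B ceiling, the ratio crosses 100× once the residual goes above ~50 cycles; at a uniform softmax it crosses earlier, near 33 cycles. A freshly seeded model on C-MAPSS has residuals that large on a large share of its samples.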

Try this. Set residual to 0. The MSE bar disappears. Now set residual to ±125. The MSE bar saturates the chart. The CE bar barely moves either way. The MSE gradient is the thing growing without limit; everything else about the imbalance is bookkeeping.
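The same comparison fits in a few lines without the slider widget. This is a minimal sketch, assuming K = 3 with the non-true probability split evenly between the two other classes; `grad_ratio` is a hypothetical helper.

```python
def grad_ratio(residual, p_true, B=32):
    """|dL_mse/dy_hat| divided by the largest |dL_ce/dz_k| for one sample.

    Assumes K = 3 with the non-true probability split evenly between the two
    other classes, so the largest CE entry is the true-class one, (1 - p_true)/B.
    """
    g_mse = 2.0 * abs(residual) / B
    g_ce = (1.0 - p_true) / B
    return g_mse / max(g_ce, 1e-12)


for r in (0, 25, 50, 125):
    print(f"residual={r:>3}  uniform p: {grad_ratio(r, 1 / 3):6.1f}x  "
          f"CE at its ceiling: {grad_ratio(r, 0.0):6.1f}x")
```

Changing `B` leaves every printed ratio unchanged: numerator and denominator both carry the same 1/B factor.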

Python: NumPy Verification

Compute both gradients analytically (no autograd), then print the ratio per sample. The point of doing this from scratch is to make the bound visible: at init every row's largest CE entry comes out at the same number, while the MSE entries spread with the residual, up to the 125-cycle cap.

Analytic gradients from first principles
🐍 mse_vs_ce_grad_numpy.py
Line 1: import numpy as np

NumPy is Python's numerical computing library. We use its ndarray (fast vectorised arrays) and broadcasting, and build a numerically stable softmax from its primitives. The array math runs as compiled C, not interpreted Python.

EXECUTION STATE
📚 numpy = Library: ndarray + linear algebra + random numbers + math. Foundation of every Python ML stack.
as np = Universal alias. np.exp, np.sum, np.max, np.linalg.norm, np.random.* are all read as one-token names.
Line 8: np.random.seed(0)

Reseed NumPy's legacy global Mersenne-Twister (MT19937) generator so every run produces identical numbers. Without this, the residual values below would change between runs and the printed ratio would jitter.

EXECUTION STATE
📚 np.random.seed(seed) = Sets the global PRNG state. After seed(0) the next call to np.random.randint(...) returns a fixed deterministic stream.
⬇ arg: seed = 0 = Any non-negative int works. 0 is conventional for "canonical" runs.
Line 9: B = 32

Batch size. Both loss reductions divide by B, so this is the constant that scales every gradient below.

EXECUTION STATE
B = 32 = Number of (window, target) pairs in the batch. Real C-MAPSS training uses B=64; we pick 32 here for cleaner round numbers.
Line 10: K = 3

Number of health-state classes. With all-zero logits at init, K sets every softmax output to exactly 1/K.

EXECUTION STATE
K = 3 = {healthy, degrading, critical}. At init each class has probability 1/3 ≈ 0.333.
Line 14: y_rul = np.random.randint(0, 126, B).astype(np.float32)

Generate random RUL targets in [0, 125]. The cap (R_max=125) comes from §7.2. Cast to float32 because the gradient downstream is also float32 - mixing dtypes is a silent bug source.

EXECUTION STATE
📚 np.random.randint(low, high, size) = Returns ints uniformly drawn from [low, high). Note the exclusive upper bound - that is why we pass 126 to get values up to 125.
⬇ arg: low = 0 = Inclusive lower bound. Capped RUL never goes below 0.
⬇ arg: high = 126 = Exclusive upper bound. Use 126 to get inclusive 125.
⬇ arg: size = B = 32 = Output shape - one target per batch element.
📚 .astype(np.float32) = Dtype cast. randint returns int64 by default; the rest of the gradient math runs in float32.
⬆ result: y_rul = [ 44., 47., 64., 67., 67., 9., 83., 21., 36., 87., 70., 88., 88., 12., 58., 65., 39., 87., 46., 88., 81., 37., 25., 77., 72., 9., 20., 80., 69., 79., 47., 64.]
→ typical magnitude = E[|y_rul|] ≈ 62. This is the constant that the MSE gradient carries through.
Line 15: y_pred = np.zeros(B, dtype=np.float32)

Pretend the model is at INIT - it predicts 0 for every engine. This is a pessimistic but realistic starting point for the MSE gradient: the residual equals the full target magnitude. (After warmup the residual shrinks; near convergence it is tiny.)

EXECUTION STATE
📚 np.zeros(shape, dtype) = Allocate an array of zeros. shape can be a scalar (1-D) or a tuple (N-D).
⬇ arg: shape = B = 32 = 1-D vector with 32 entries.
⬇ arg: dtype = np.float32 = Match y_rul's dtype to avoid silent upcast.
⬆ result: y_pred = [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
Line 16: residual = y_pred - y_rul

Element-wise subtraction. Both operands have shape (B,), so no broadcasting is needed. This is the "tractor force" in the rope metaphor of §12.1 - it carries the y_rul magnitude straight into the gradient.

EXECUTION STATE
operator: - = Element-wise subtraction. residual[i] = y_pred[i] - y_rul[i].
⬆ result: residual = [-44., -47., -64., -67., -67., -9., -83., -21., -36., -87., -70., -88., -88., -12., -58., -65., -39., -87., -46., -88., -81., -37., -25., -77., -72., -9., -20., -80., -69., -79., -47., -64.]
→ magnitude = Mean |residual| ≈ 56. Worst case at init touches 88. THIS is what makes ‖g_mse‖ large.
Line 17: loss_mse = np.mean(residual ** 2)

Mean-squared error. Square every residual, take the mean over the batch. Scalar output.

EXECUTION STATE
📚 np.mean(arr) = Arithmetic mean. With no axis, reduces to a single scalar.
operator: ** 2 = Element-wise squaring. residual[i] ** 2.
⬆ result: loss_mse = 3,748.6 (scalar - huge because residuals are huge)
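→ why huge? = loss_mse = mean(y_rul²) = mean(y_rul)² + Var(y_rul). For targets uniform on 0..125 that is about 62.5² + 1,323 ≈ 5,200 in expectation; a single batch of 32 can land well away from it. Compare with loss_ce ≈ 1.10 below.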
Line 20: g_mse_pred = (2.0 / B) * residual

Analytic derivative of MSE. d(mean(r²))/dr_i = (2/B) · r_i. The factor (2/B) is constant - the per-element magnitude tracks |residual|.

EXECUTION STATE
📚 derivation = loss = (1/B) Σ r_i² ⇒ ∂loss/∂r_i = (2/B) r_i. With B=32: scale factor = 0.0625.
⬇ arg: 2.0 / B = 0.0625 - the constant outside.
⬇ arg: residual = (B,) the residual vector from line 16.
⬆ result: g_mse_pred[:5] = [-2.750, -2.938, -4.000, -4.188, -4.188]
→ typical |g_mse| = ≈ 0.0625 · 56 ≈ 3.5. Per element. Already ~100x bigger than what CE will produce.
Line 24: y_hs = np.random.randint(0, K, B)

Random class indices in {0, 1, 2}. Shape (32,). Same RNG stream as y_rul (drawn after).

EXECUTION STATE
📚 np.random.randint(0, K, B) = K=3 classes, B=32 samples. Output ints in {0, 1, 2}.
⬇ arg: low = 0 = Inclusive.
⬇ arg: high = K = 3 = Exclusive ⇒ values in {0, 1, 2}.
⬇ arg: size = B = 32 = One label per sample.
⬆ result: y_hs = [2, 1, 1, 2, 0, 1, 1, 0, 0, 1, 2, 1, 2, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 0, 1, 0, 0, 1, 0, 2, 0]
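Line 25: logits = np.zeros((B, K), dtype=np.float32)

Pretend the logits at INIT are all zero, the usual idealisation of a freshly initialised head. The softmax is then exactly uniform, and the true-class entry of (p - y_onehot) has magnitude 1 - 1/K: two-thirds of the ceiling for K=3, not the ceiling itself. Only a confidently wrong prediction pushes it toward 1.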

EXECUTION STATE
📚 np.zeros(shape, dtype) = Allocate zeros - here a 2-D matrix.
⬇ arg: shape = (B, K) = (32, 3) = Tuple - 2-D output, 32 samples × 3 logits each.
⬇ arg: dtype = np.float32 = Match downstream math.
⬆ result: logits (first 3 rows) =
      hd   dg   cr
0   0.00 0.00 0.00
1   0.00 0.00 0.00
2   0.00 0.00 0.00
...
Line 26: shift = logits - logits.max(-1, keepdims=True)

Log-sum-exp stabilisation step. Subtract the row max before exp() to prevent overflow. Because softmax(x - c) = softmax(x) for any per-row constant c, the shift does not change the result.

EXECUTION STATE
📚 .max(axis, keepdims) = Reduces over an axis to find the max. With keepdims=True the reduced axis stays as size 1 so broadcasting works in the subtraction.
⬇ arg: axis = -1 = Last axis (the K=3 class axis). axis=0 would reduce across samples - wrong.
⬇ arg: keepdims = True = Output shape (B, 1) instead of (B,). Needed so logits(32,3) - max(32,1) broadcasts correctly. Without keepdims you'd get a shape error.
⬆ result: shift =
All zeros (because logits is all zeros - max minus max = 0). Shape (32, 3).
Line 27: log_p = shift - np.log(np.exp(shift).sum(-1, keepdims=True))

log-softmax via the stable identity log p_k = z_k - log Σ_j exp(z_j). For the all-zero shift the subtracted term is log Σ exp(0) = log K = log 3 ≈ 1.0986.

EXECUTION STATE
📚 np.exp(arr) = Element-wise e^x. exp(0) = 1.
📚 .sum(axis, keepdims) = Reduce-sum along an axis.
⬇ arg: axis = -1 = Sum over class axis ⇒ per-row scalar.
⬇ arg: keepdims = True = Keep (B, 1) shape for broadcasting with (B, K).
📚 np.log(arr) = Element-wise natural logarithm. log(K) = log(3) ≈ 1.0986 here.
⬆ result: log_p[0] = [-1.0986, -1.0986, -1.0986] ← uniform log-prob
Line 28: p = np.exp(log_p)

Softmax probabilities recovered from the stable log-form. Each row sums to 1.

EXECUTION STATE
📚 np.exp(log_p) = Element-wise exp. exp(-1.0986) = 1/3 ≈ 0.3333.
⬆ result: p[0] = [0.3333, 0.3333, 0.3333] ← uniform softmax at init
→ row-sum check = 0.3333 + 0.3333 + 0.3333 = 1.0000 ✓
Line 29: loss_ce = -np.mean(log_p[np.arange(B), y_hs])

Cross-entropy reduces to picking the log-prob of the TRUE class, negating, and averaging. Because all log-probs are -log K at init, this is exactly log K.

EXECUTION STATE
📚 np.arange(B) = Integer range [0, B). Used with fancy indexing to pick row indices.
→ fancy indexing = log_p[np.arange(B), y_hs] picks log_p[i, y_hs[i]] for each i. Shape (B,).
⬇ values picked (first 5) = [-1.0986, -1.0986, -1.0986, -1.0986, -1.0986]
📚 np.mean(arr) = Average. With no axis, reduces to scalar.
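⬆ result: loss_ce = 1.0986 = log 3, the CE of a uniform prediction over K=3 classes. (The CE loss itself has no upper bound: a confidently wrong prediction makes it arbitrarily large. Only its gradient is bounded.)
→ contrast = loss_mse ≈ 3,748 vs. loss_ce ≈ 1.10: more than three orders of magnitude in loss value. The per-element gradient ratio is smaller (3|r| per sample at this init) but typically still in the hundreds.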
Line 32: oh = np.zeros_like(p); oh[np.arange(B), y_hs] = 1.0

One-hot encoding of the true class. Used to express the CE gradient (p - one_hot) cleanly.

EXECUTION STATE
📚 np.zeros_like(arr) = Allocate zeros with the same shape and dtype as arr. Same as np.zeros(p.shape, dtype=p.dtype).
→ fancy assignment = oh[np.arange(B), y_hs] = 1.0 sets oh[i, y_hs[i]] = 1 for every i.
⬆ result: oh (first 3 rows) =
       hd  dg  cr
0    0.0 0.0 1.0    (y_hs[0]=2)
1    0.0 1.0 0.0    (y_hs[1]=1)
2    0.0 1.0 0.0    (y_hs[2]=1)
Line 33: g_ce_logits = (p - oh) / B

Analytic derivative of CE-after-softmax: ∂L/∂z_k = (p_k - 1[k=y]) / B. Because (p - oh) is bounded in [-1, 1] element-wise, the per-element gradient is bounded by 1/B.

EXECUTION STATE
📚 derivation = log p_k = z_k - log Σ exp(z_j) ⇒ ∂(-log p_y)/∂z_k = p_k - 1[k=y]. Divide by B because of the mean reduction.
operator: - = Element-wise (p - oh).
operator: / B = Divide every element by 32.
⬆ result: g_ce_logits[0] = [ 0.0104, 0.0104, -0.0208] (= (0.333, 0.333, 0.333 - 1) / 32)
→ bound = At init max |element| = (1 - 1/K)/B = (1 - 1/3)/32 ≈ 0.0208. It can NEVER exceed 1/B = 0.03125.
Line 37: abs_g_mse = np.abs(g_mse_pred)

Take element-wise absolute values of the MSE gradient for magnitude comparison.

EXECUTION STATE
📚 np.abs(arr) = Element-wise absolute value.
⬆ result: abs_g_mse[:5] = [2.750, 2.938, 4.000, 4.188, 4.188]
→ mean = ≈ 3.5 (sample-typical per-element magnitude)
Line 38: abs_g_ce = np.abs(g_ce_logits).max(-1)

Per-row WORST-case magnitude of the CE gradient. We take the max along the class axis so we compare the strongest CE pull to the MSE pull, not an averaged-down number.

EXECUTION STATE
📚 .max(axis) = Reduce-max over an axis.
⬇ arg: axis = -1 = Per-row max across the K classes.
⬆ result: abs_g_ce[:5] = [0.0208, 0.0208, 0.0208, 0.0208, 0.0208]
→ all equal! = At init every row has the same maximum, (1 - 1/K)/B ≈ 0.0208, whatever its label. The structural ceiling is 1/B = 0.03125.
Line 39: ratio_per_sample = abs_g_mse / np.maximum(abs_g_ce, 1e-12)

Per-sample ratio. The 1e-12 floor is a numerical safety net against dividing by zero, which can happen when a prediction is confidently correct and p_true rounds to exactly 1.

EXECUTION STATE
📚 np.maximum(a, b) = Element-wise max of two arrays. Used as a clamp / floor here.
⬇ arg: a = abs_g_ce = The CE magnitudes - might be very small.
⬇ arg: b = 1e-12 = Floor to prevent inf.
⬆ result: ratio_per_sample[:5] = [132., 141., 192., 201., 201.]
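→ median = Median ≈ 168 on this run. At uniform init the per-sample ratio is exactly 2|r|/(1 - 1/K) = 3|r| for K=3. Both gradients carry the same 1/B factor, so changing B (e.g. to 64) leaves the ratio unchanged. With the chained backbone factors of the full model, the empirical median lands near 500x.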
Line 41: print("residuals :", residual[:5].round(1).tolist(), "...")

Sample residuals for human reading.

EXECUTION STATE
Output = residuals : [-44.0, -47.0, -64.0, -67.0, -67.0] ...
Line 42: print("|g_mse| (per-row) :", abs_g_mse[:5].round(3).tolist(), "...")

Per-element MSE gradient magnitude.

EXECUTION STATE
Output = |g_mse| (per-row) : [2.75, 2.938, 4.0, 4.188, 4.188] ...
Line 43: print("|g_ce| (per-row) :", abs_g_ce[:5].round(4).tolist(), "...")

Per-element CE gradient magnitude. All equal at init, all bounded by 1/B.

EXECUTION STATE
Output = |g_ce| (per-row) : [0.0208, 0.0208, 0.0208, 0.0208, 0.0208] ...
Line 44: print("ratio (per-row) :", ratio_per_sample[:5].round(0).tolist(), "...")

Per-sample ratio. Already 100×–200× without any chained backbone factors.

EXECUTION STATE
Output = ratio (per-row) : [132.0, 141.0, 192.0, 201.0, 201.0] ...
Line 45: print("median ratio :", round(float(np.median(ratio_per_sample)), 0), "x")

Robust per-batch summary. We use the median, not the mean, so that a few unusually large or small residuals cannot drag the summary around.

EXECUTION STATE
📚 np.median(arr) = 50th percentile - more robust than mean for skewed distributions.
Output = median ratio : 168.0 x
→ on the real model = Chaining through CNN+BiLSTM+Attention+Funnel multiplies this by another ~3x. Section §12.3 measures the full empirical ratio across 4,120 batches.
```python
import numpy as np


# ----- two heads, gradients taken at their outputs -----
# We differentiate each loss w.r.t. its head's output (the regression
# prediction and the class logits). That is the signal each head sends back
# toward the shared backbone, before any chain-rule factors through it.
np.random.seed(0)
B = 32                                          # batch size
K = 3                                           # number of health classes


# ----- regression branch -----
y_rul    = np.random.randint(0, 126, B).astype(np.float32)   # capped target
y_pred   = np.zeros(B, dtype=np.float32)                      # init: 0 prediction
residual = y_pred - y_rul                                     # (B,)
loss_mse = np.mean(residual ** 2)                             # scalar

# d(loss_mse) / d(y_pred_i) = (2 / B) * (y_pred_i - y_rul_i)
g_mse_pred = (2.0 / B) * residual                             # (B,)


# ----- classification branch -----
y_hs    = np.random.randint(0, K, B)                          # class indices
logits  = np.zeros((B, K), dtype=np.float32)                  # init: zeros
shift   = logits - logits.max(-1, keepdims=True)
log_p   = shift - np.log(np.exp(shift).sum(-1, keepdims=True))
p       = np.exp(log_p)                                       # softmax (B, K)
loss_ce = -np.mean(log_p[np.arange(B), y_hs])                 # scalar

# d(loss_ce) / d(logits_i_k) = (1 / B) * (p_i_k - 1[k == y_hs_i])
oh         = np.zeros_like(p); oh[np.arange(B), y_hs] = 1.0
g_ce_logits = (p - oh) / B                                    # (B, K)


# ----- summarise per-element magnitudes -----
abs_g_mse = np.abs(g_mse_pred)                                # (B,)
abs_g_ce  = np.abs(g_ce_logits).max(-1)                       # (B,)  per-row max
ratio_per_sample = abs_g_mse / np.maximum(abs_g_ce, 1e-12)

print("residuals          :", residual[:5].round(1).tolist(), "...")
print("|g_mse| (per-row)  :", abs_g_mse[:5].round(3).tolist(), "...")
print("|g_ce|  (per-row)  :", abs_g_ce[:5].round(4).tolist(),  "...")
print("ratio   (per-row)  :", ratio_per_sample[:5].round(0).tolist(), "...")
print("median ratio       :", round(float(np.median(ratio_per_sample)), 0), "x")
```
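Because both gradients carry the same 1/B factor, the per-sample ratio at uniform init reduces to 2|r_i|/(1 - 1/K), with no B in it. A short check, assuming NumPy's Generator API (`np.random.default_rng`), so the targets are not the ones drawn by the legacy seed in the listing; `init_grad_ratio` is a hypothetical helper. Stacking the batch twice doubles B without changing any individual sample.

```python
import numpy as np


def init_grad_ratio(y_rul, y_hs, K):
    """Per-sample |g_mse| / max_k |g_ce| at init: y_pred = 0 and all-zero logits."""
    B = y_rul.shape[0]
    residual = 0.0 - y_rul                                   # y_pred = 0
    g_mse = np.abs(2.0 / B * residual)                       # (B,)
    p = np.full((B, K), 1.0 / K)                             # softmax(0) is uniform
    oh = np.zeros((B, K)); oh[np.arange(B), y_hs] = 1.0
    g_ce = np.abs((p - oh) / B).max(axis=-1)                 # (B,) = (1 - 1/K) / B
    return g_mse / g_ce


rng = np.random.default_rng(0)
K = 3
y_rul = rng.integers(0, 126, 32).astype(np.float64)
y_hs = rng.integers(0, K, 32)

r32 = init_grad_ratio(y_rul, y_hs, K)
# Doubling B by stacking the same samples twice leaves each ratio unchanged...
r64 = init_grad_ratio(np.tile(y_rul, 2), np.tile(y_hs, 2), K)
print(np.allclose(r32, r64[:32]))                            # True: B cancels
# ...and every ratio matches the closed form 2|r| / (1 - 1/K) = 3|y_rul| for K = 3.
print(np.allclose(r32, 3 * y_rul))                           # True
```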

PyTorch: autograd Cross-Check

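Same problem, autograd-verified. F.mse_loss and F.cross_entropy with .backward() reproduce the analytic formulas: given the same inputs, they return the same gradients up to float32 rounding. The script below draws its own targets with torch's RNG, so its printed values are not the NumPy block's values.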

Same gradients via F.mse_loss / F.cross_entropy + .backward()
🐍 mse_vs_ce_grad_torch.py
Line 1: import torch

Top-level PyTorch. Gives us torch.tensor, torch.zeros, torch.randint, autograd, and the .backward() machinery we use to verify the analytic NumPy gradients above.

EXECUTION STATE
📚 torch = Tensor library + autograd engine + nn modules + JIT. Foundation of PyTorch.
Line 2: import torch.nn.functional as F

Stateless functional ops: F.mse_loss, F.cross_entropy, F.softmax. These are the "standard recipe" loss functions every PyTorch training loop uses.

EXECUTION STATE
📚 torch.nn.functional = Submodule with stateless functions. F.cross_entropy fuses log_softmax + nll_loss for numerical stability.
alias F = Universal convention.
Line 6: torch.manual_seed(0)

Reproducibility for the torch RNG. It is separate from NumPy's, so the same seed does not produce the same numbers in both libraries.

EXECUTION STATE
📚 torch.manual_seed(seed) = Sets PyTorch's global PRNG. Affects torch.randint, torch.randn, weight init, etc.
⬇ arg: seed = 0 = Conventional canonical seed.
Line 7: B, K = 32, 3

Same constants as the NumPy block.

EXECUTION STATE
B = 32 = Batch size.
K = 3 = Number of health classes.
Line 10: y_rul = torch.randint(0, 126, (B,)).float()

Random capped RUL targets in [0, 125].

EXECUTION STATE
📚 torch.randint(low, high, size) = Returns random integers from [low, high), the same half-open convention as NumPy. The values themselves differ from NumPy's, because the generators differ.
⬇ arg: low = 0 = Inclusive.
⬇ arg: high = 126 = Exclusive ⇒ values 0..125.
⬇ arg: size = (B,) = (32,) = 1-D tensor with 32 entries. Note the trailing comma - it is a tuple, not a scalar.
📚 .float() = Cast to float32. Equivalent to .to(torch.float32). Required because mse_loss needs floats.
⬆ result: y_rul shape = (32,) - int64 cast to float32
Line 11: y_pred = torch.zeros(B, requires_grad=True)

Learnable predictions, all zero at init. requires_grad=True is what lets autograd compute .grad on this tensor when we call .backward() below.

EXECUTION STATE
📚 torch.zeros(*size, requires_grad) = Allocate a zero tensor optionally tracked by autograd.
⬇ arg: size = B = 32 = 1-D tensor with 32 entries.
⬇ arg: requires_grad = True = Tells autograd to track operations on y_pred so .backward() can populate y_pred.grad. Without this, y_pred.grad stays None.
⬆ result: y_pred = tensor([0., 0., 0., …, 0.], requires_grad=True)
Line 12: loss_mse = F.mse_loss(y_pred, y_rul)

Mean-squared-error: (1/B) Σ (ŷ - y)². Returns a scalar Tensor connected to y_pred via autograd.

EXECUTION STATE
📚 F.mse_loss(input, target, reduction='mean') = Default reduction is 'mean' - average over all elements. Use reduction='sum' to get the sum or 'none' for the per-element vector.
⬇ arg: input = y_pred = (B,) predictions, requires_grad=True so the gradient flows back.
⬇ arg: target = y_rul = (B,) ground-truth RUL targets. Targets do NOT need requires_grad - they are constants from the data.
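⬆ result: loss_mse = tensor(<mean of y_rul²>, grad_fn=<MseLossBackward0>). The value is in the thousands, like the NumPy loss, but not equal to it: torch.manual_seed(0) draws a different batch than np.random.seed(0).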
Line 13: loss_mse.backward()

Run reverse-mode autograd. Populates y_pred.grad with d(loss_mse)/d(y_pred). After this call y_pred.grad is a tensor of shape (B,).

EXECUTION STATE
📚 .backward(retain_graph=False) = Backprops through the autograd graph and accumulates grads into every leaf tensor with requires_grad=True. Default frees the graph after.
⬇ arg: gradient = None (default) = For scalar loss, no upstream gradient is needed. For non-scalar tensors you must pass an external gradient of matching shape.
Line 14: g_mse_pred = y_pred.grad.clone()

Snapshot y_pred.grad. .clone() is essential because we are about to call zero_() in-place on y_pred.grad, which would erase the values otherwise.

EXECUTION STATE
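📚 .clone() = Returns a copy with its own storage. clone is itself recorded by autograd, so on a tensor with history use .detach().clone() to cut the link; y_pred.grad has no history here, so plain .clone() is enough.
⬆ result: g_mse_pred = (2/B) · (y_pred - y_rul) = -y_rul / 16, shape (32,)
→ vs. NumPy = Same formula as the NumPy block, not the same numbers: torch.manual_seed and np.random.seed drive separate generators, so y_rul differs. What autograd confirms is the analytic derivative itself.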
Line 15: y_pred.grad.zero_()

In-place clear of the gradient buffer. Necessary if you reuse y_pred for another backward pass - PyTorch ACCUMULATES into .grad by default, so a stale buffer would corrupt the next computation.

EXECUTION STATE
📚 .zero_() = The trailing underscore marks an in-place op: the existing buffer is filled with zeros, with no new allocation (unlike t = torch.zeros_like(t), which binds t to a new tensor).
Line 18: y_hs = torch.randint(0, K, (B,))

Class indices in {0, 1, 2}. NO .float() cast - cross_entropy expects int64 labels.

EXECUTION STATE
📚 torch.randint(low, high, size) = Same as above.
⬇ arg: low = 0 = Inclusive.
⬇ arg: high = K = 3 = Exclusive.
⬇ arg: size = (B,) = Per-sample label.
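→ dtype = int64. F.cross_entropy needs int64 class indices here. Float targets are interpreted as class probabilities and would need shape (B, K), so a float (B,) tensor raises an error.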
Line 19: logits = torch.zeros((B, K), requires_grad=True)

Per-class logits at init - all zero. requires_grad=True so autograd tracks them.

EXECUTION STATE
📚 torch.zeros(size, requires_grad) = Allocate. With size as a tuple we get a 2-D tensor.
⬇ arg: size = (B, K) = (32, 3) = 2-D tensor.
⬇ arg: requires_grad = True = Track for autograd.
⬆ result: logits = 32×3 zero leaf tensor (requires_grad=True, grad_fn=None)
Line 20: loss_ce = F.cross_entropy(logits, y_hs)

Stable log_softmax + nll_loss in one call. Internally uses the log-sum-exp trick (subtracting row-max before exp) - we never have to think about it.

EXECUTION STATE
📚 F.cross_entropy(input, target, reduction='mean') = Computes (1/B) Σ -log p_target where p = softmax(input). reduction='mean' is default.
⬇ arg: input = logits (B, K) = Raw logits - NOT probabilities. Passing already-softmaxed values applies the softmax twice.
⬇ arg: target = y_hs (B,) = Class indices, int64. NOT one-hot.
⬆ result: loss_ce = tensor(1.0986, grad_fn=<NllLossBackward0>) = log K
Line 21: loss_ce.backward()

Populates logits.grad with d(loss_ce)/d(logits). Shape (B, K).

EXECUTION STATE
📚 .backward() = Reverse-mode autograd into all requires_grad leaves.
Line 22: g_ce_logits = logits.grad.clone()

Snapshot.

EXECUTION STATE
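📚 .clone() = Copy with its own storage. logits.grad carries no autograd history here, so the copy is a plain tensor.
⬆ result: g_ce_logits[i] = -0.0208 at the true class y_hs[i] and +0.0104 at the other two (= (p - oh) / B at uniform softmax)
→ matches NumPy = Same per-row pattern as the analytic derivation; which column holds -0.0208 depends on the drawn label.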
Line 25: abs_g_mse = g_mse_pred.abs()

Element-wise |·|.

EXECUTION STATE
📚 .abs() = Tensor method, same semantics as torch.abs(t).
⬆ result: abs_g_mse = |y_rul| / 16 per sample (y_pred = 0, B = 32)
Line 26: abs_g_ce = g_ce_logits.abs().max(dim=-1).values

Per-row worst-case magnitude. PyTorch's .max(dim=...) returns a named tuple with .values and .indices fields - we want only the values.

EXECUTION STATE
📚 .abs() = Element-wise abs.
📚 .max(dim) = Reduce-max along an axis. Returns a (values, indices) named tuple, as do .min, .sort and .topk.
⬇ arg: dim = -1 = Last axis (the K classes). dim=0 would pick max across the batch - wrong.
.values = The actual max values. .indices would give the argmax positions.
⬆ result: abs_g_ce[:5] = tensor([0.0208, 0.0208, 0.0208, 0.0208, 0.0208])
Line 27: ratio = abs_g_mse / abs_g_ce.clamp_min(1e-12)

Per-sample ratio with a numerical floor. .clamp_min(c) is the canonical PyTorch way to avoid divide-by-zero without writing torch.maximum(t, scalar).

EXECUTION STATE
📚 .clamp_min(min) = Element-wise max(t, min). In-place version is .clamp_min_(...).
⬇ arg: min = 1e-12 = Tiny floor; effectively a no-op except where abs_g_ce is exactly zero.
⬆ result: ratio = 3 · |y_rul| per sample (K = 3, uniform init); the 1/B factors cancel.
Line 29: print("loss_mse :", round(loss_mse.item(), 2))

.item() pulls the Python float out of a 0-D tensor.

EXECUTION STATE
📚 .item() = Convert a 0-D tensor to a Python scalar. Errors if the tensor has more than one element.
Output = loss_mse : <mean of y_rul² over the torch-drawn batch, a value in the thousands>
Line 30: print("loss_ce :", round(loss_ce.item(), 4))

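At init this is log K ≈ 1.0986. The CE loss itself is not bounded above (it grows without limit for a confidently wrong prediction); only its gradient is bounded.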

EXECUTION STATE
Output = loss_ce : 1.0986
Line 31: print("|g_mse| sample :", g_mse_pred.abs()[:5].round(decimals=3).tolist())

.round(decimals=) is the PyTorch-1.12+ rounding API. .tolist() converts to a Python list of floats for clean printing.

EXECUTION STATE
📚 .round(decimals=) = Rounds to a given number of decimal places. PyTorch &gt;= 1.12.
📚 .tolist() = Materialise tensor as a (possibly nested) Python list.
Output = |g_mse| sample : <|y_rul[:5]| / 16, rounded to 3 decimals>
Line 32: print("|g_ce| sample :", abs_g_ce[:5].round(decimals=4).tolist())

All equal at init: (1 - 1/K)/B = 0.0208.

EXECUTION STATE
Output = |g_ce| sample : [0.0208, 0.0208, 0.0208, 0.0208, 0.0208]
Line 33: print("median ratio :", round(ratio.median().item(), 1), "x")

Robust per-batch summary.

EXECUTION STATE
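📚 .median() = Middle value, robust to outliers. Unlike np.median, for an even element count (B = 32 here) torch.median returns the lower of the two middle values instead of their average.
Output = median ratio : <3 × the lower middle value of |y_rul| in the torch-drawn batch> x
→ consistency check = Same formula as the NumPy block, so autograd reproduces the analytic gradient. The printed median differs from the NumPy one because the targets differ and the two median conventions differ.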
```python
import torch
import torch.nn.functional as F


# Same setup, autograd-verified (torch's RNG draws different targets than NumPy's).
torch.manual_seed(0)
B, K = 32, 3

# regression branch
y_rul   = torch.randint(0, 126, (B,)).float()
y_pred  = torch.zeros(B, requires_grad=True)            # learnable
loss_mse = F.mse_loss(y_pred, y_rul)                     # mean over batch
loss_mse.backward()
g_mse_pred = y_pred.grad.clone()                         # (B,)
y_pred.grad.zero_()

# classification branch
y_hs    = torch.randint(0, K, (B,))
logits  = torch.zeros((B, K), requires_grad=True)
loss_ce = F.cross_entropy(logits, y_hs)
loss_ce.backward()
g_ce_logits = logits.grad.clone()                        # (B, K)

# magnitudes
abs_g_mse = g_mse_pred.abs()
abs_g_ce  = g_ce_logits.abs().max(dim=-1).values
ratio     = abs_g_mse / abs_g_ce.clamp_min(1e-12)

print("loss_mse        :", round(loss_mse.item(), 2))
print("loss_ce         :", round(loss_ce.item(),  4))
print("|g_mse| sample  :", g_mse_pred.abs()[:5].round(decimals=3).tolist())
print("|g_ce|  sample  :", abs_g_ce[:5].round(decimals=4).tolist())
print("median ratio    :", round(ratio.median().item(), 1), "x")  # lower middle value for even B
```
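For a value-for-value comparison the two frameworks need the same inputs, because `torch.manual_seed` and `np.random.seed` drive different generators. This is a minimal sketch, assuming NumPy's Generator API draws the targets once and `torch.from_numpy` hands the same arrays to PyTorch; it checks gradients at init only (zero prediction, zero logits).

```python
import numpy as np
import torch
import torch.nn.functional as F

# Draw ONE set of targets with NumPy and share it with PyTorch, so both
# gradient computations see identical inputs.
rng = np.random.default_rng(0)
B, K = 32, 3
y_rul_np = rng.integers(0, 126, B).astype(np.float32)
y_hs_np = rng.integers(0, K, B)

# Analytic (NumPy) gradients at init: zero prediction, zero logits.
g_mse_np = (2.0 / B) * (0.0 - y_rul_np)
p_np = np.full((B, K), 1.0 / K, dtype=np.float32)
oh_np = np.zeros((B, K), dtype=np.float32)
oh_np[np.arange(B), y_hs_np] = 1.0
g_ce_np = (p_np - oh_np) / B

# Autograd (PyTorch) gradients on the same targets.
y_pred = torch.zeros(B, requires_grad=True)
F.mse_loss(y_pred, torch.from_numpy(y_rul_np)).backward()
logits = torch.zeros((B, K), requires_grad=True)
F.cross_entropy(logits, torch.from_numpy(y_hs_np).long()).backward()

print(np.allclose(g_mse_np, y_pred.grad.numpy(), atol=1e-6))   # True
print(np.allclose(g_ce_np, logits.grad.numpy(), atol=1e-6))    # True
```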

Same Math, Different Domains

Anywhere unbounded MSE meets bounded CE on a shared backbone, the same imbalance shows up. The fixes (AMNL / GABA / GRACE) generalise without modification.

| Domain | Regression scale | Classification cap | Source of gap |
|---|---|---|---|
| RUL prediction (C-MAPSS) | y up to 125 cycles | K=3 (per-element bound 1/B; 2/(3B) at init) | RUL cap |
| Battery capacity vs failure type | y in [0.7, 1.0] capacity ratio | K=4 fault types | scaled but still >5× |
| Power-grid load vs anomaly tag | y in MW, can hit 10⁴ | K=2 (anomaly/normal) | MW magnitude |
| Wind speed vs gear-fault tag | y in [0, 25] m/s | K=3 fault levels | moderate, ~10× |
| Astronomy: redshift vs galaxy class | y in [0, 7] | K=10 morphology classes | small, ~3-5× |
| NLP: sentiment score vs topic | y in [-5, 5] | K=20 topics | small, ~1-3× |
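Predictor of trouble. At uniform init the per-sample ratio is 2|r|/(1 - 1/K), never less than 2|r|. So if the typical initial residual exceeds about 25 target units, expect more than 50× gradient imbalance at init. K matters much less than the target scale: the 1/(1 - 1/K) factor only ranges from 2 (K = 2) down toward 1 (many classes). C-MAPSS sits at the high end because RUL runs up to 125 and the 3-class head adds a factor of 1.5.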

Three Loss-Scaling Pitfalls

Pitfall 1: Normalising y to [0, 1] hides the problem. A common "fix" is to divide y_rul by R_max so targets live in [0, 1]. That shrinks the residual by 125× and the per-element MSE gradient with it - but Adam's denominator $\sqrt{v}$ shrinks proportionally. The EFFECTIVE update size is unchanged. The imbalance is a property of the SHAPE of the gradient field, not the scale.
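The Adam part of Pitfall 1 can be checked in isolation: rescaling a gradient stream by a constant leaves Adam's updates essentially unchanged. This is a minimal sketch, assuming a hand-written Adam step with the usual defaults and a hypothetical stream of MSE gradients reaching a parameter that sees only the regression loss (a regression-head weight, say). It does not model a shared parameter that also receives the CE gradient.

```python
import numpy as np


def adam_updates(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Parameter updates Adam would apply for a given sequence of gradients."""
    m = np.zeros_like(grads[0])
    v = np.zeros_like(grads[0])
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g          # first-moment EMA
        v = beta2 * v + (1 - beta2) * g * g      # second-moment EMA
        m_hat = m / (1 - beta1 ** t)             # bias corrections
        v_hat = v / (1 - beta2 ** t)
        updates.append(-lr * m_hat / (np.sqrt(v_hat) + eps))
    return np.array(updates)


rng = np.random.default_rng(0)
R_MAX = 125.0
# Hypothetical MSE gradients reaching one regression-head weight vector,
# in raw RUL units (20 steps, 8 weights).
g_raw = [rng.normal(0.0, 4.0, size=8) for _ in range(20)]
# Dividing the targets by R_MAX divides the residuals, hence these gradients, by R_MAX.
g_scaled = [g / R_MAX for g in g_raw]

u_raw = adam_updates(g_raw)
u_scaled = adam_updates(g_scaled)
# Largest relative difference between the two update sequences (eps effects only).
print(np.max(np.abs(u_raw - u_scaled) / np.abs(u_raw)))
```

The only place the scale survives is the `eps` term, which is negligible next to √v for gradients of this size.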
Pitfall 2: Picking a fixed loss weight. Multiplying loss_hs by 100 looks reasonable on paper - but the residual shrinks during training while p_true grows. The ratio that needed weight 100 at epoch 0 needs weight 5 at epoch 50 and weight 1 at convergence. Fixed weights only balance the gradients in one snapshot; everywhere else they over- or under-correct.
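Why one weight fits only one snapshot can be sketched with the per-sample ratio 2|r|/(1 - p_true). This is a minimal sketch, assuming three made-up (residual, p_true) snapshots and K = 3 with the non-true probability split evenly; `grad_ratio` is a hypothetical helper, and the weight is chosen to balance the first snapshot exactly.

```python
def grad_ratio(residual, p_true):
    """Per-sample MSE/CE gradient ratio: 2|r| / (1 - p_true).

    Assumes the largest CE entry is the true-class one (true for K = 3 with the
    remaining probability split evenly); the batch size B cancels.
    """
    return 2.0 * abs(residual) / max(1.0 - p_true, 1e-12)


# Hypothetical training snapshots: (residual in cycles, true-class probability).
snapshots = {"init": (60.0, 1 / 3), "mid": (10.0, 0.8), "late": (2.0, 0.95)}
w_ce = grad_ratio(*snapshots["init"])      # fixed CE weight tuned on the init snapshot

for name, (r, p) in snapshots.items():
    raw = grad_ratio(r, p)
    # Effective imbalance once the CE loss is multiplied by the fixed weight.
    print(f"{name:>4}: raw ratio {raw:6.1f}, after fixed weight {raw / w_ce:5.2f}")
```

With these numbers the fixed weight is exact at init and then over-corrects, leaving CE dominant later on; different snapshots would move the error in a different direction, which is the point of the pitfall.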
Pitfall 3: Using L1 instead of MSE. L1-style losses (smooth-L1, Huber, etc.) cap the gradient: at the Huber delta for Huber, at 1 for plain L1. That helps - but L1 has its own drawbacks (its gradient does not shrink as the residual approaches zero, and convergence on RUL is less stable). AMNL / GABA / GRACE keep MSE and rebalance the GRADIENT instead of changing the LOSS. Read §13 before swapping the loss function.
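Written out per element, the three gradients show what Pitfall 3 trades off. This is a minimal sketch, assuming the Huber loss 0.5·r² for |r| ≤ δ and δ·(|r| - δ/2) beyond, with δ = 1 chosen for illustration; the batch-mean factor 1/B is left out, and `loss_grads` is a hypothetical helper.

```python
import numpy as np


def loss_grads(residual, delta=1.0):
    """d(loss)/d(y_hat) per element for three regression losses (no batch mean)."""
    mse = 2.0 * residual                       # grows without limit
    l1 = np.sign(residual)                     # magnitude 1 everywhere except r = 0
    huber = np.clip(residual, -delta, delta)   # linear inside delta, capped at delta
    return mse, l1, huber


r = np.array([-125.0, -50.0, -1.0, -0.1, 0.0, 0.1, 1.0, 50.0, 125.0])
mse, l1, huber = loss_grads(r, delta=1.0)
print("mse  :", mse)
print("l1   :", l1)
print("huber:", huber)
```

The L1 row keeps full magnitude at r = ±0.1, which is the "does not shrink near zero" drawback; Huber avoids that inside δ but still caps large residuals.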
The point. One bounded gradient, one unbounded gradient, on the same shared parameters. Standard optimisers cannot tell the two apart - they only see the summed gradient, which the bigger one dominates. Section §12.3 measures this on the real DualTaskModel. Section §12.4 spells out the consequence. Then we fix it.

Takeaway

  • MSE's derivative is unbounded. $\partial \mathcal{L}_{\text{mse}} / \partial \hat{y} = (2/B)(\hat{y} - y)$ - linear in the residual.
  • CE's derivative is bounded. $\partial \mathcal{L}_{\text{ce}} / \partial z = (1/B)(p - \text{onehot})$, at most $1/B$ in magnitude element-wise.
  • The ratio is structural. It is not a bug, a hyperparameter, or an artefact of model size. Re-normalising targets does not fix it; rebalancing gradients does.
  • NumPy and autograd agree. Given the same inputs, both compute the same per-sample gradients. Use either to diagnose your own MTL setup.