Chapter 20

Common Training Problems

Debugging and Improving Networks

Overview

Most failed training runs are not mysterious. They usually reduce to a short list of failure modes: the model is not learning at all, it is memorising the training set, gradients are numerically unstable, the data pipeline is wrong, or the system is spending most of its time waiting for memory rather than doing useful work.

The important habit is to debug in the right order. Start with the cheapest questions first: can the model overfit a tiny batch, do the labels line up with the inputs, do gradients have sane magnitudes, and is the validation split genuinely held out? Only after those basics are clean should you reach for bigger models, fancier optimizers, or architectural changes.

The core idea. Debugging a neural network is not guessing. It is a measurement problem. Every symptom should map to one of a few measurable quantities: loss curves, gradient norms, activation statistics, parameter updates, per-class metrics, and throughput or memory usage.

The Debugging Loop

A reliable workflow is more valuable than any single trick. In practice, you want the same loop every time a model misbehaves:

  1. Reproduce the failure. Fix the random seed, keep one code path, and save the exact config. A bug you cannot reproduce is almost impossible to isolate.
  2. Shrink the problem. Use a tiny dataset, a tiny model, and a short run. If the model cannot learn 32 examples, scaling up only hides the bug.
  3. Instrument the run. Log training loss, validation loss, gradient norm, learning rate, parameter norm, and a few example predictions every epoch.
  4. Change one thing at a time. If you alter the optimiser, batch size, and augmentation policy together, you learn nothing from the outcome.
  5. Promote fixes from toy to full scale. Once the model behaves on a tiny slice, move back to the real dataset and hardware.
The fastest sanity check in deep learning. Ask whether the model can overfit a single minibatch. If it cannot drive that batch loss near zero, the issue is usually not regularisation or dataset size. It is almost always a bug in the objective, architecture, gradients, or input pipeline.

Symptom Triage Table

| Symptom | Most likely causes | First thing to inspect |
| --- | --- | --- |
| Training loss is flat | Learning rate too small, labels wrong, bug in loss, gradients blocked | Single-batch overfit test and gradient norms |
| Training loss falls, validation loss rises | Overfitting, leakage in train metrics, weak regularisation | Validation split, data augmentation, weight decay, dropout |
| Both training and validation are poor | Under-capacity model, optimiser mismatch, features too weak | Model size, learning-rate schedule, feature or data quality |
| Loss suddenly becomes NaN or Inf | Exploding activations, overflow in softmax or log, bad mixed precision | Gradient norm, input scale, epsilon terms, AMP config |
| Accuracy looks high but deployment fails | Class imbalance, leakage, bad metric choice, calibration issue | Confusion matrix, per-class recall, precision, calibration |
| GPU is busy but throughput is poor | Small kernels, dataloader stalls, memory-bound ops | Profiler trace, batch size, host-to-device copy, fused kernels |

Track ratios, not just raw numbers. A parameter norm of 20 or a gradient norm of 0.8 means little in isolation. More useful is the update-to-weight ratio. If an optimizer step changes a parameter by $10^{-6}$ relative to a parameter norm of $1$, learning will be glacial. If it changes it by $10^{-1}$, training is likely unstable.
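
The update-to-weight ratio can be measured directly around an optimizer step. The following is a minimal sketch, assuming plain PyTorch; the helper name `update_to_weight_ratio` is hypothetical, and the toy parameter with a hand-set gradient exists only so the arithmetic can be checked by hand.

```python
import torch

def update_to_weight_ratio(params, step_fn):
    """Return ||delta theta|| / ||theta|| for one optimizer step.

    `params` is a list of tensors; `step_fn` is a zero-argument callable
    that applies the update (for example, `opt.step`).
    """
    before = [p.detach().clone() for p in params]
    step_fn()
    update_sq = sum((p.detach() - b).pow(2).sum() for p, b in zip(params, before))
    weight_sq = sum(b.pow(2).sum() for b in before)
    return (update_sq.sqrt() / (weight_sq.sqrt() + 1e-12)).item()

# Toy check: one parameter vector, a hand-set gradient, and plain SGD.
w = torch.nn.Parameter(torch.ones(4))        # ||w|| = 2
opt = torch.optim.SGD([w], lr=0.1)
w.grad = torch.ones(4)                       # ||lr * grad|| = 0.2
ratio = update_to_weight_ratio([w], opt.step)
print(f"update-to-weight ratio: {ratio:.3f}")
```

Here the ratio is 0.2 / 2 = 0.1, which sits in the unstable regime described above. In a real run you would log this per layer every few hundred steps; a value near $10^{-3}$ is a commonly quoted rule of thumb, but treat it as a heuristic rather than a target.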

Before we dive into individual failure modes, train your eye. The same vertical axis (loss) and horizontal axis (epochs) can produce wildly different shapes — and each shape points at a different bug. Toggle between scenarios in the visualizer below: a healthy run, a learning rate that is too large, one that is too small, an overfitting model, an exploding-loss spike, and a stalled run that is not learning at all.

[Interactive visualizer: loss-curve gallery]

What to look for. Healthy curves decay smoothly and the validation curve tracks training within a small gap. Oscillation means the optimizer is bouncing across the basin (lower the LR or clip gradients). A widening train/val gap is overfitting. A sudden upward spike usually means a single bad minibatch produced a very large gradient or a numerical blow-up. A flat horizontal line is the most ambiguous of all: it can mean an LR that is too small, a frozen layer, or a mislabelled target.

Problem 1: Loss Does Not Go Down

This is the most common failure mode. The usual temptation is to blame the model architecture, but architecture is rarely the first culprit. Start with the simplest possibilities:

  • The learning rate is too small. Updates are so tiny that the parameters barely move.
  • The learning rate is too large. The optimizer bounces around the basin and never settles.
  • The target or loss is wrong. For example, feeding probabilities into CrossEntropyLoss instead of logits, or mixing up one-hot labels and class indices.
  • Gradients are blocked. A bad detach(), an in-place operation, or a frozen layer keeps parameters from receiving signal.
  • The data are misaligned. Inputs and labels have been shuffled differently, or a train-time transform changes the input without updating the label (for example, flipping an image but not its bounding boxes).

The tiny-batch overfit test

Pick 16 or 32 examples. Train on them alone. If the model and loss are implemented correctly, a sufficiently expressive network should drive the loss down dramatically and often to near zero. Failure here means something upstream is broken.

Mathematically, if the parameter update is $\Delta \theta_t = -\eta \nabla_{\theta} L(\theta_t)$, then two things must both be true: the gradient $\nabla_{\theta} L$ must be nonzero and the step size $\eta$ must be large enough to move parameters by a visible amount. A flat loss curve usually means one of those two quantities is effectively zero in practice.
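
The first condition, a nonzero gradient, is easy to test after a single backward pass. The sketch below is one possible check, assuming PyTorch; `find_blocked_parameters` is a hypothetical helper name, and the frozen first layer is a deliberately planted bug.

```python
import torch
import torch.nn as nn

def find_blocked_parameters(model: nn.Module, tol: float = 0.0) -> list:
    """Names of parameters whose gradient is missing or (near) zero."""
    blocked = []
    for name, p in model.named_parameters():
        if p.grad is None or p.grad.abs().max().item() <= tol:
            blocked.append(name)
    return blocked

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
for p in model[0].parameters():
    p.requires_grad_(False)            # simulate an accidentally frozen layer

x = torch.randn(4, 8)
y = torch.randint(0, 3, (4,))
nn.CrossEntropyLoss()(model(x), y).backward()

blocked = find_blocked_parameters(model)
print("parameters receiving no gradient:", blocked)
```

Any name in the returned list is a parameter the optimizer cannot move, whatever the learning rate. An empty list does not prove the gradients are useful, only that they exist.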


The Tiny-Batch Overfit Test

The single most useful 30 lines of debugging code in deep learning. Pin the data, pin the model, train for a few hundred steps, and watch the loss. If it does not approach zero, you have a bug — not a capacity problem, not a regularisation problem, an actual bug.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# 32 random "examples": replace with one mini-batch of your real data
x = torch.randn(32, 784, device=device)
y = torch.randint(0, 10, (32,), device=device)

model = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, 10),
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for step in range(300):
    logits = model(x)                  # (32, 10)
    loss = loss_fn(logits, y)          # logits, NOT softmax(logits)!
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 50 == 0:
        print(f"step {step:>3}  loss {loss.item():.4f}")

# Expected: loss falls from ~2.30 to < 0.01 in <300 steps.
# If it stalls above 1.0, something is wrong upstream; debug there first.
```

Interpreting the result. A 2-layer MLP with 256 hidden units has more than enough capacity to memorise 32 random examples. If it can't, look at: gradient flow (is requires_grad True everywhere?), the loss (CrossEntropyLoss wants raw logits), the device (does x live on the same device as the model?), and the learning rate (try $10^{-2}$ through $10^{-4}$).

Worked Bug: Softmax-then-CrossEntropy

This is one of the most common classification bugs. PyTorch's nn.CrossEntropyLoss already applies a numerically stable log-softmax internally. Feeding it pre-softmaxed probabilities is a silent error: training still "works", but the extra softmax squashes the inputs into [0, 1], so the loss can never approach zero, and its Jacobian shrinks the gradient whenever the model is confident, including when it is confidently wrong.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
target = torch.tensor([0])
loss_fn = nn.CrossEntropyLoss()

# ❌ WRONG: softmax-then-CE
probs = torch.softmax(logits, dim=-1)
loss_wrong = loss_fn(probs, target)          # treats probabilities AS logits
loss_wrong.backward()
print("wrong  loss =", loss_wrong.item())    # ≈ 0.70
print("wrong  grad =", logits.grad.clone())  # ≈ [-0.13, 0.11, 0.02], attenuated
logits.grad = None

# ✅ CORRECT: raw logits straight into CrossEntropyLoss
loss_right = loss_fn(logits, target)
loss_right.backward()
print("right  loss =", loss_right.item())    # ≈ 0.24
print("right  grad =", logits.grad)          # ≈ [-0.21, 0.18, 0.04] = softmax(logits) - one_hot
```

The two losses are not comparable. The wrong version applies log-softmax a second time, to values that are already probabilities, so it computes $-\log \mathrm{softmax}(\mathrm{softmax}(z))_y$. Because those inputs lie in $[0, 1]$, the loss cannot fall below $\log(1 + (K-1)/e)$ for $K$ classes (about 0.55 for the three classes here), no matter how good the model is. The gradient is also attenuated: in this example it is roughly 0.6× the correct magnitude, and the attenuation becomes far more severe as the softmax saturates. That is why the symptom is "loss decreases, but very slowly, and then plateaus" rather than an obvious crash.

Use the right loss for the right output head. CrossEntropyLoss ⇒ raw logits. NLLLoss ⇒ log_softmax output. BCEWithLogitsLoss ⇒ raw logits for binary/multi-label. BCELoss ⇒ post-sigmoid probabilities (usually a code smell; prefer BCEWithLogitsLoss).
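
These pairings can be confirmed numerically. The snippet below is a small sanity check, assuming PyTorch and random toy tensors (the shapes are arbitrary): each "with logits" loss should match its two-step counterpart when the matching activation is applied first.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Multi-class head: raw scores plus integer class indices (not one-hot).
logits = torch.randn(5, 3)
target = torch.randint(0, 3, (5,))
ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=-1), target)

# Multi-label head: one logit per label plus float 0/1 targets.
bin_logits = torch.randn(5, 4)
bin_target = torch.randint(0, 2, (5, 4)).float()
bce_logits = nn.BCEWithLogitsLoss()(bin_logits, bin_target)
bce_probs = nn.BCELoss()(torch.sigmoid(bin_logits), bin_target)

print("CE == NLL(log_softmax):", torch.allclose(ce, nll))
print("BCEWithLogits == BCE(sigmoid):", torch.allclose(bce_logits, bce_probs, atol=1e-6))
```

The two-step forms agree on well-scaled inputs like these, but the fused versions are the safer default because they stay stable when logits are large.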

Problem 2: Training Improves, Validation Does Not

This is classical overfitting: the model is learning patterns that reduce training loss but do not generalise. The signature is a widening gap between training and validation curves after an initially healthy phase.

The fix is not always "add more dropout." First ask whether the validation process is trustworthy. A bad split, leakage, or changing preprocessing between train and validation can produce a fake generalisation gap.
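
One cheap trust check is to look for exact duplicates shared between the splits. The following sketch uses only the standard library; the function names are hypothetical, and hashing `repr(example)` is an assumption that suits small tuples or strings (for images or tensors you would hash the raw bytes instead).

```python
import hashlib

def fingerprint(example) -> str:
    """Hash the content of one example (assumed to have a stable repr)."""
    return hashlib.sha256(repr(example).encode("utf-8")).hexdigest()

def count_overlap(train_examples, val_examples) -> int:
    """Number of validation examples whose content also appears in training."""
    train_hashes = {fingerprint(x) for x in train_examples}
    return sum(fingerprint(x) in train_hashes for x in val_examples)

# Made-up rows: the first validation row duplicates a training row.
train = [(0.1, 0.2), (0.3, 0.4), (0.5, 0.6)]
val = [(0.3, 0.4), (0.9, 1.0)]
overlap = count_overlap(train, val)
print(f"{overlap} of {len(val)} validation examples also appear in training")
```

This only catches exact copies. Near-duplicates (resized images, re-tokenised text, the same patient across visits) need a domain-specific key, such as grouping by source ID before splitting.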

| If you observe | Ask this | Likely remedy |
| --- | --- | --- |
| Training accuracy ~99%, validation much lower | Is the model memorising? | Weight decay, dropout, stronger augmentation, early stopping |
| Validation jumps around wildly | Is the validation set too small or distributionally odd? | Larger validation split, stratified sampling, repeated evaluation |
| Validation beats training | Are train-time augmentations much harder? | Check augmentation severity and metric definitions |

Two remedies matter more than the rest in practice:

  • Data augmentation. This increases the effective diversity of the training set without changing labels.
  • Early stopping. Stop when validation loss stops improving, rather than waiting until the model memorises noise.
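
Early stopping needs very little machinery. Below is a minimal sketch in plain Python; the class name, the `patience` and `min_delta` parameters, and the validation losses are all illustrative assumptions rather than a fixed recipe.

```python
class EarlyStopping:
    """Signal a stop after `patience` epochs without a meaningful improvement."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up validation losses: improvement stalls after the third epoch.
val_losses = [0.90, 0.70, 0.60, 0.61, 0.62, 0.63, 0.64]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(f"stopped at epoch {stopped_at}, best validation loss {stopper.best}")
```

In a real loop you would also save a checkpoint whenever `best` improves, so that stopping restores the best weights rather than the last ones.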

Problem 3: Both Training and Validation Stay Bad

If both curves plateau at poor values, the model is underfitting. That can happen for two very different reasons:

  1. The optimisation process is weak. Bad learning rate, optimizer choice, or gradient flow prevents the network from reaching a good solution.
  2. The model or features are weak. The architecture does not have enough capacity, or the representation does not expose the useful signal.

A useful distinction is this: if training loss begins to fall but plateaus too high, you may need more capacity. If training loss hardly moves from the start, suspect the optimisation setup first.

Regularisation can also cause underfitting. Excessive dropout, too much weight decay, heavy label smoothing, or overly aggressive augmentation can all prevent the model from fitting even the core signal.

Problem 4: Exploding or Vanishing Gradients

Backpropagation multiplies many Jacobians together. If the typical singular value along that chain is larger than 1, gradients grow exponentially with depth; if it is smaller than 1, they shrink. That is the intuition behind exploding and vanishing gradients.

In a simplified recurrence, the gradient can scale like $\left\lVert \prod_{t=1}^{T} J_t \right\rVert$. If the average norm of $J_t$ is $1.1$, then over 100 steps the scale is roughly $1.1^{100}$. If it is $0.9$, the scale is roughly $0.9^{100}$. Both are disastrous.
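
To make those magnitudes concrete, the arithmetic can be run directly. This is only the scalar caricature from the paragraph above (a constant per-step gain), not a simulation of a real network; `chain_scale` is a hypothetical helper name.

```python
def chain_scale(per_step_gain: float, depth: int) -> float:
    """Scale of a product of `depth` factors that each have the same gain."""
    return per_step_gain ** depth

for gain in (1.1, 1.0, 0.9):
    print(f"gain {gain:.1f} over 100 steps -> scale {chain_scale(gain, 100):.3e}")
```

A gain of 1.1 amplifies the gradient by roughly $1.4 \times 10^{4}$, while a gain of 0.9 shrinks it to roughly $2.7 \times 10^{-5}$. A deviation of only ten percent per step is enough to move the gradient by four to five orders of magnitude.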

What exploding gradients look like

  • Loss suddenly spikes after a stable phase.
  • Gradient norm jumps by orders of magnitude.
  • Weights become NaN after an optimizer step.

What vanishing gradients look like

  • Early layers learn almost nothing.
  • Gradient histograms collapse near zero.
  • Deep or recurrent models stall despite a correct loss setup.

The standard fixes are by now well understood:

| Failure | Most effective fixes |
| --- | --- |
| Exploding gradients | Gradient clipping, lower learning rate, better normalisation, residual paths |
| Vanishing gradients | ReLU-family activations, residual connections, normalisation, better initialization |

Gradient clipping replaces a gradient $g$ with norm $\lVert g \rVert > \tau$ by $g \cdot \tau / \lVert g \rVert$. It does not fix a bad model, but it prevents one unstable minibatch from destroying the run.

Gradient-Health Monitor (Interactive)

Pick an activation function, an initialisation scheme, and a depth. The visualizer simulates the per-layer gradient norms and shows the canonical pathologies in real time: ReLU + zero init kills gradients instantly; sigmoid + uniform init vanishes them by layer 4; deeper networks need He or Xavier init plus a non-saturating nonlinearity.

[Interactive visualizer: gradient-health monitor]

Why this picture matters. A single scalar gradient norm — the kind most loggers print — averages all layers together and can look healthy while the bottom layers are starved. Always plot the per-layer norm before concluding gradients are fine.
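
Per-layer norms take only a few lines to collect. The sketch below assumes PyTorch; `per_layer_grad_norms` is a hypothetical helper, and the deep sigmoid stack is chosen deliberately because it tends to starve its early layers.

```python
import torch
import torch.nn as nn

def per_layer_grad_norms(model: nn.Module) -> dict:
    """Map each parameter name to the L2 norm of its gradient (after backward)."""
    return {
        name: p.grad.detach().norm().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }

torch.manual_seed(0)
layers = []
for _ in range(8):                       # eight Linear + Sigmoid blocks
    layers += [nn.Linear(32, 32), nn.Sigmoid()]
model = nn.Sequential(*layers, nn.Linear(32, 10))

x = torch.randn(16, 32)
y = torch.randint(0, 10, (16,))
nn.CrossEntropyLoss()(model(x), y).backward()

norms = per_layer_grad_norms(model)
for name, value in norms.items():
    if name.endswith("weight"):
        print(f"{name:>10}  grad norm {value:.2e}")
```

Reading the printout from top (first layer) to bottom (output layer) shows the imbalance that a single global norm would hide: the global value is dominated by the last few layers.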

Implementing Gradient Clipping

Two flavours of gradient clipping are common. Norm clipping (the canonical formula above) preserves gradient direction and only rescales magnitude, which makes it almost always the right default. Value clipping simply clamps each coordinate of the gradient to $[-\tau, \tau]$ and can distort direction when only a few coordinates exceed the threshold.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(8, 64)
y = torch.randint(0, 10, (8,))

logits = model(x)
loss = loss_fn(logits, y)
opt.zero_grad()
loss.backward()

# ✅ Norm clipping: preserves gradient direction.
# clip_grad_norm_ returns the total norm measured BEFORE clipping.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"pre-clip norm: {grad_norm.item():.4f}")

# Optionally log the post-clip norm too (useful for spotting regime shifts)
post = sum(p.grad.detach().norm() ** 2 for p in model.parameters() if p.grad is not None)
print(f"post-clip norm: {post.sqrt().item():.4f}")

opt.step()
```

Pick the threshold from data. Run a few hundred unclipped steps and plot the gradient-norm distribution. Pick $\tau$ at roughly the 95th percentile; clipping should be rare, not constant. If you find yourself clipping every step, the real bug is upstream (LR too high, bad init, or a loss-scale problem in mixed precision).
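
The percentile rule can be sketched as follows, assuming PyTorch. The toy model, the random stand-in minibatches, and the choice of 200 warm-up steps are all assumptions for illustration; in practice you would record the norms from your real training loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

norms = []
for _ in range(200):                       # unclipped warm-up steps
    x = torch.randn(8, 64)                 # stand-in for a real minibatch
    y = torch.randint(0, 10, (8,))
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    total_sq = sum(p.grad.pow(2).sum() for p in model.parameters())
    norms.append(total_sq.sqrt().item())   # global L2 gradient norm
    opt.step()

norms_t = torch.tensor(norms)
tau = torch.quantile(norms_t, 0.95).item()
clip_rate = (norms_t > tau).float().mean().item()
print(f"tau = {tau:.3f}, fraction of steps above tau = {clip_rate:.3f}")
```

With this threshold, about one step in twenty would have been clipped on the recorded data. It is worth re-checking the clip rate later in training, since gradient norms often drift as the loss falls.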

Problem 5: NaNs, Infs, and Numerical Instability

NaNs are not random. They are arithmetic failures. Somewhere in the graph you divided by zero, overflowed an exponential, took the log of zero, normalised an all-zero variance, or let mixed-precision arithmetic leave the safe numerical range.

The fastest way to debug NaNs is to find the first invalid tensor, not the last place the loss becomes NaN. Once NaNs are in the graph, they spread everywhere.
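
For the forward pass, "find the first invalid tensor" can be automated with forward hooks. The following is a minimal sketch, assuming PyTorch; `first_nonfinite_module` and the `UnsafeLog` layer are hypothetical names, and the bug is planted on purpose. It does not cover NaNs that first appear in the backward pass, for which `torch.autograd.set_detect_anomaly(True)` is the usual tool.

```python
import torch
import torch.nn as nn

def first_nonfinite_module(model: nn.Module, x: torch.Tensor):
    """Run one forward pass and return the name of the first submodule whose
    output contains NaN or Inf, or None if every output is finite."""
    offenders = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                offenders.append(name)
        return hook

    handles = [m.register_forward_hook(make_hook(n))
               for n, m in model.named_modules() if n]   # skip the root module
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()
    return offenders[0] if offenders else None

class UnsafeLog(nn.Module):
    """Toy layer with a deliberate bug: log(0) = -inf."""
    def forward(self, x):
        return torch.log(x)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), UnsafeLog(), nn.Linear(4, 2))
nn.init.constant_(model[0].bias, -100.0)   # force the ReLU outputs to exactly zero
culprit = first_nonfinite_module(model, torch.randn(3, 4))
print("first non-finite output comes from submodule:", culprit)
```

The later layers also produce non-finite values, but only the first offender matters: everything downstream is contamination.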

  • Add epsilon terms to denominators and logs, such as $\log(x + \varepsilon)$.
  • Use numerically stable formulations like log-sum-exp and fused cross-entropy kernels.
  • Check input scale. A preprocessing bug that multiplies pixel values by 255 twice can destabilise the whole run.
  • If using AMP, verify loss scaling or prefer BF16 when hardware allows.

Softmax and cross-entropy are a classic source of overflow. Never compute $\log(\mathrm{softmax}(x))$ by hand from raw exponentials if the framework provides a fused stable version.
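
The difference between the naive and the stable formulation is easy to see in plain Python. This sketch uses only the `math` module and hypothetical function names; it illustrates the log-sum-exp trick, not how any particular framework implements its fused kernel.

```python
import math

def naive_log_softmax(xs):
    """Direct formula; exp() overflows for large inputs."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

def stable_log_softmax(xs):
    """Log-sum-exp trick: subtract the max before exponentiating."""
    m = max(xs)
    log_sum = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_sum for x in xs]

scores = [1000.0, 0.0, -1000.0]
try:
    naive_log_softmax(scores)
    naive_ok = True
except (OverflowError, ValueError):
    naive_ok = False          # exp(1000) is out of range for a float

print("naive version survived:", naive_ok)
print("stable log-softmax:", stable_log_softmax(scores))
```

Subtracting the maximum guarantees that the largest exponent is zero, so the sum is at least one and the logarithm is always defined. On small inputs the two functions agree; only the stable one survives large ones.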

Dying ReLUs: Activation Flow

ReLU's gradient is zero for any negative pre-activation. If a neuron drifts negative for every example in your training set, its weights will never update again — that neuron is dead. This is a silent failure: loss still decreases because the surviving neurons compensate, but you are paying for capacity you do not have.

[Interactive visualizer: activation flow]

The fix is one of: LeakyReLU or GELU (a nonzero gradient for negative inputs, so dead neurons can recover); He initialisation (matched to the fact that ReLU zeroes roughly half of its inputs); a smaller learning rate at the start of training; or BatchNorm / LayerNorm upstream of the ReLU (which keeps pre-activations centred near zero).
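
Before applying a fix, it helps to measure how many units are actually dead. The sketch below assumes PyTorch; `dead_relu_fraction` is a hypothetical helper, and the layer's parameters are set by hand so the answer is known in advance.

```python
import torch
import torch.nn as nn

def dead_relu_fraction(layer: nn.Module, relu: nn.Module, x: torch.Tensor) -> float:
    """Fraction of units that output zero for every example in the batch `x`."""
    with torch.no_grad():
        active = relu(layer(x)) > 0          # (batch, units) boolean mask
    never_active = ~active.any(dim=0)        # True where a unit never fires
    return never_active.float().mean().item()

# Toy layer: zero weights mean each unit's pre-activation equals its bias,
# so the three units with a negative bias can never fire.
layer = nn.Linear(4, 6)
with torch.no_grad():
    layer.weight.zero_()
    layer.bias.copy_(torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]))

frac = dead_relu_fraction(layer, nn.ReLU(), torch.randn(32, 4))
print(f"dead fraction: {frac:.2f}")
```

A unit that is silent on one batch is only a suspect. Run the same measurement over a large, representative sample before concluding that it is dead.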


Weight-Distribution Monitor

A healthy training run produces a recognisable weight signature: roughly Gaussian, mean near zero, standard deviation that grows slowly and then plateaus. Pathological signatures are equally recognisable: a collapsed distribution (most weights ≈ 0, a sign of under-training), an exploded distribution (heavy tails, which calls for clipping or weight decay), or a bimodal one (often a sign that two halves of the network are decoupled).

[Interactive visualizer: weight-distribution monitor]

Problem 6: Data and Evaluation Bugs

Many "model" problems are really data problems. A neural network will happily optimise the wrong task if the pipeline asks it to.

Common examples include:

  • Label leakage. The target leaks into an input feature or filename convention, inflating apparent validation performance.
  • Train and validation mismatch. Different normalisation, tokenisation, resize rules, or missing-value handling across splits.
  • Class imbalance hidden by aggregate accuracy. A model that predicts the majority class can still report high accuracy.
  • Wrong metric. Accuracy is often insufficient for retrieval, ranking, detection, imbalanced classification, and calibrated decision-making.

Error analysis beats hunches

Once the model trains at all, the next tool is not another optimizer. It is an error table. Break mistakes down by class, subgroup, sequence length, brightness, prompt length, accent, or whatever natural axes your domain exposes. A confusion matrix and a small notebook of misclassified examples often explain more than another day of hyperparameter search.

Always look at raw examples. If ten failed predictions all share the same artifact, the dataset is telling you what to fix. Metrics summarize the failure; examples reveal its mechanism.
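
A per-class breakdown is often the first error table worth building. The sketch below uses only the standard library; `per_class_recall` is a hypothetical helper, and the labels are made up to mimic an 88/12 class imbalance.

```python
from collections import Counter

def per_class_recall(y_true, y_pred):
    """Recall for each class: correct predictions / true examples of that class."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {c: hits[c] / totals[c] for c in sorted(totals)}

# Made-up labels: 88 majority-class examples and 12 minority-class examples.
y_true = [0] * 88 + [1] * 12
y_pred = [0] * 100                      # a constant majority-class predictor

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
recall = per_class_recall(y_true, y_pred)
print(f"accuracy: {accuracy:.2f}")
print(f"per-class recall: {recall}")
```

The aggregate accuracy looks respectable while the minority-class recall is zero, which is exactly the failure that a single headline number hides.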

Problem 7: The Model Works but Is Too Slow

Slowness is also a training problem. A model that trains correctly but wastes 90% of available hardware turns iteration into a bottleneck, and slow iteration means slower research. There are two common regimes:

  1. Input-pipeline bound. The GPU waits for the CPU, dataloader, tokenizer, or storage system.
  2. Kernel or memory bound. The GPU is active, but the arithmetic intensity is low and the run is limited by memory traffic.

This chapter's later sections cover both sides in detail. Section 2 gives you visual tools for understanding activations, gradients, and attention. Section 3 explains the roofline model, memory hierarchy, fusion, mixed precision, checkpointing, and KV-cache tradeoffs. The important debugging lesson here is just the ordering: first verify the model is correct, then optimise the system.

A bad profiler pattern to recognise. If GPU utilisation oscillates between bursts of work and idle gaps, the bottleneck is often the dataloader or host-to-device transfer. If utilisation is high but tokens or samples per second are still poor, suspect memory-bound kernels, tiny batch sizes, or unnecessary tensor copies.
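
A rough first measurement of the input-pipeline regime needs only wall-clock timers. This sketch uses the standard library; `data_wait_fraction`, `slow_loader`, and `fast_step` are hypothetical stand-ins. One caveat is an assumption about your setup: GPU work is launched asynchronously, so on a real CUDA run the step function should end with `torch.cuda.synchronize()` for the timings to be meaningful.

```python
import time

def data_wait_fraction(loader, step_fn):
    """Fraction of wall-clock time spent waiting for the next batch."""
    wait = work = 0.0
    it = iter(loader)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(it)
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)
        t2 = time.perf_counter()
        wait += t1 - t0
        work += t2 - t1
    return wait / (wait + work) if (wait + work) > 0 else 0.0

def slow_loader(n=10, delay=0.02):
    """Stand-in for a dataloader that stalls on I/O."""
    for i in range(n):
        time.sleep(delay)
        yield i

def fast_step(batch):
    """Stand-in for a training step that is much faster than the loader."""
    time.sleep(0.001)

frac = data_wait_fraction(slow_loader(), fast_step)
print(f"time waiting on data: {frac:.0%}")
```

If most of the time goes to waiting, more workers, prefetching, or cheaper preprocessing will help more than any kernel-level optimisation. A proper profiler trace is still the better tool once this quick check points you in a direction.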

Reproducibility and Determinism

"Fix the random seed" is necessary but not sufficient. PyTorch on the GPU has multiple sources of non-determinism that ignore your seed unless explicitly disabled: cuDNN's autotuner picks different kernels per run, atomic reductions in scatter / gather kernels permute floating-point sums, and DataLoader workers shuffle non-deterministically when their own seeds aren't pinned.

```python
import os
import random

import numpy as np
import torch

def set_full_determinism(seed: int = 0) -> None:
    """Seed the common RNGs and request deterministic kernels."""
    # Recorded for subprocesses only: hash randomisation of the current
    # interpreter is fixed at startup, so export PYTHONHASHSEED before launch.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Needed for deterministic cuBLAS (CUDA >= 10.2); set before any CUDA work.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False              # disables autotune
    # warn_only=True warns instead of raising when an op has no deterministic version.
    torch.use_deterministic_algorithms(True, warn_only=True)

# Pin DataLoader worker seeds too:
def seed_worker(worker_id: int) -> None:
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

set_full_determinism(0)
g = torch.Generator()
g.manual_seed(0)
# loader = DataLoader(..., worker_init_fn=seed_worker, generator=g)
```
The price of determinism. Disabling cudnn.benchmark can cost 5–20% throughput, and deterministic algorithms are sometimes slower or use more memory. Use the recipe above when reproducing a bug; turn it off when you ship.

Picking a Learning Rate: The LR Finder

Asking "is my learning rate too small or too large?" in the abstract is hard to answer. A more useful framing is empirical: sweep the learning rate across several orders of magnitude, taking one optimizer step per minibatch, and plot loss against LR on a log axis. The curve has a characteristic shape: a flat region at small LR, a steep descent in the useful range, and an explosion when the LR becomes too large. Pick a value about an order of magnitude smaller than the explosion point. This is the LR range test introduced in the Cyclical Learning Rates paper (Smith, 2017), and it is a widely used alternative to grid search over the LR.

```python
import copy
import math

import torch

def find_lr(model, loader, loss_fn, lr_min=1e-7, lr_max=1.0, num=100):
    """LR range test: one SGD step per minibatch at an exponentially growing LR."""
    initial_state = copy.deepcopy(model.state_dict())   # the sweep changes the weights
    opt = torch.optim.SGD(model.parameters(), lr=lr_min)
    mult = (lr_max / lr_min) ** (1 / max(num - 1, 1))   # last step lands on lr_max
    losses, lrs = [], []
    best = float("inf")
    for i, (x, y) in enumerate(loader):
        if i >= num:
            break
        lr = lr_min * (mult ** i)
        for group in opt.param_groups:
            group["lr"] = lr
        loss = loss_fn(model(x), y)
        value = loss.item()
        if not math.isfinite(value) or value > 4 * best:
            break                                       # diverged: stop the sweep
        opt.zero_grad()
        loss.backward()
        opt.step()
        best = min(best, value)
        losses.append(value)
        lrs.append(lr)
    model.load_state_dict(initial_state)                # restore the original weights
    return lrs, losses

# Plot lrs (log x) vs losses. Pick LR at the steepest point of the descent,
# typically ~1 order of magnitude below where the curve turns up.
```

A Practical Checklist

When a run goes wrong, walk this list from top to bottom:

  1. Can the model overfit one minibatch?
  2. Are the labels, loss function, and output activations compatible?
  3. Are gradient norms finite and nontrivial?
  4. Do activations saturate, collapse to zero, or explode layer by layer?
  5. Is the validation split clean and processed identically to training?
  6. Do per-class metrics reveal a hidden failure that aggregate accuracy hides?
  7. Is the bottleneck model quality, data quality, or system throughput?
A useful mental rule. Do not debug deep learning from the final scalar alone. One number called "loss" collapses the entire run. You need at least a handful of observables: loss, accuracy or task metric, gradient norm, activation distribution, and a few raw predictions.

Quick Check

Q1. Your training loss starts at $\ln 10 \approx 2.30$ for a 10-class problem and stays there for 100 epochs. The validation loss is identical. What three things would you check, in order?
Answer: (1) Can the model overfit a single minibatch of 32 examples? If not, suspect a bug in the loss / output head / gradient flow. (2) Are the labels actually paired with the right inputs (a DataLoader collate or shuffle bug)? (3) Is the learning rate too small? The LR finder should show a steep descent somewhere; if it doesn't, gradients aren't flowing.
Q2. A teammate reports: "loss spikes from 0.3 to NaN at step 4,200 every time we restart from this checkpoint." What is the single fastest experiment to localise the bug?
Answer: Save the inputs and the model state from step 4,199, then replay step 4,200 with torch.autograd.set_detect_anomaly(True). The resulting traceback points at the forward op whose backward pass first produced a NaN, which is almost always a log, sqrt, or division with a zero argument, or a mixed-precision loss-scale event.
Q3. Your validation accuracy is 92% but the model ranks classes badly when actually deployed. The training set is 88% of one majority class. What metric did you optimise, and what should you have looked at?
Answer: Plain accuracy is dominated by the majority class: a constant predictor would score 88%. You needed a confusion matrix, per-class recall / precision, and (for ranking) area-under-ROC or precision-at-k. The fix is rarely a new model; it is a better metric and class-balanced sampling or focal loss.

Summary

Most training failures fall into a manageable taxonomy. If loss is flat, suspect the learning setup or a bug in the objective. If training improves but validation does not, suspect overfitting or a bad evaluation protocol. If both are poor, suspect underfitting or weak optimisation. If NaNs appear, look for the first unstable arithmetic operation. If the model is correct but iteration is slow, profile the pipeline and the memory behaviour of the kernels.

The next two sections extend this workflow. Section 2 shows how to look inside a network with activations, saliency maps, loss landscapes, and attention visualisations. Section 3 shows how to make the same network run efficiently by reasoning about arithmetic intensity, memory hierarchy, fusion, precision, and caching.


References

  • Karpathy, A. (2019). "A Recipe for Training Neural Networks". karpathy.github.io/2019/04/25/recipe/.
  • Smith, L. N. (2017). "Cyclical Learning Rates for Training Neural Networks". WACV.
  • Smith, L. N. (2018). "A disciplined approach to neural network hyper-parameters". arXiv:1803.09820.
  • Pascanu, R., Mikolov, T., & Bengio, Y. (2013). "On the difficulty of training recurrent neural networks" (gradient clipping). ICML.
  • Glorot, X. & Bengio, Y. (2010). "Understanding the difficulty of training deep feedforward neural networks" (Xavier init). AISTATS.
  • He, K., Zhang, X., Ren, S., & Sun, J. (2015). "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" (He init). ICCV.
  • Maas, A., Hannun, A., & Ng, A. (2013). "Rectifier Nonlinearities Improve Neural Network Acoustic Models" (LeakyReLU). ICML Workshop.
  • PyTorch documentation. "Reproducibility". pytorch.org/docs/stable/notes/randomness.html.