Chapter 9: Training Neural Networks

Debugging Neural Networks

Learning Objectives

By the end of this section, you will be able to:

  1. Develop a systematic debugging mindset: Understand why neural networks fail silently and how to approach debugging methodically rather than randomly
  2. Perform essential sanity checks: Verify your model can overfit a single batch, check gradient flow, and validate loss starts at the expected value
  3. Diagnose common failure modes: Identify vanishing/exploding gradients, dead neurons, learning rate issues, and data problems from symptoms
  4. Use debugging tools in PyTorch: Apply hooks, gradient checking, activation visualization, and weight monitoring to find bugs
  5. Apply a systematic debugging checklist: Follow a structured approach to eliminate possibilities and isolate the root cause of training failures
Why This Matters: Neural networks fail silently. Unlike traditional software that crashes with stack traces, a buggy neural network simply produces bad predictions. The loss might decrease, accuracy might improve, but subtle bugs can prevent your model from reaching its full potential—or worse, cause it to learn the wrong things entirely. Mastering debugging is what separates practitioners who get models to work from those who don't.

The Big Picture

Why Neural Networks Are Hard to Debug

Debugging neural networks is fundamentally different from debugging traditional software for several reasons:

  1. No compiler errors: Your code can be syntactically correct yet mathematically wrong. A misplaced transpose or incorrect loss function won't raise an exception.
  2. Stochastic nature: Results vary between runs due to random initialization and data shuffling. Was that improvement real or just noise?
  3. Delayed feedback: You only discover problems after hours of training, when you finally look at the results.
  4. Coupled components: Architecture, data, loss function, optimizer, and hyperparameters all interact. A bug in one can mask problems in another.
  5. No ground truth: You often don't know what the "correct" loss or accuracy should be, making it hard to know if something is wrong.

The Cost of Bugs

In traditional software, bugs cause crashes. In deep learning, bugs cause silent failures:

| Bug Type | Traditional Software | Neural Networks |
|---|---|---|
| Shape mismatch | Runtime error | Broadcasting hides the bug |
| Wrong formula | Wrong output | Training proceeds, just poorly |
| Data issue | Validation error | Model memorizes noise |
| Gradient bug | N/A | Loss stagnates mysteriously |

The Most Dangerous Bugs

The scariest bugs are those that don't prevent training at all. Your loss decreases, accuracy increases, and everything looks fine—but the model has learned something subtly wrong. Always validate your model's behavior, not just its metrics.
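One cheap way to validate behavior rather than metrics is to look at the distribution of predicted classes. The sketch below (plain NumPy, with hypothetical prediction arrays standing in for real model output) computes the entropy of the predicted-class histogram: near-zero entropy means the model has collapsed to predicting one class, even if accuracy still looks acceptable on an imbalanced set.

```python
import numpy as np

def prediction_entropy(preds, num_classes):
    """Entropy (in nats) of the predicted-class histogram. Near zero
    means the model predicts (almost) one class for everything, even
    if accuracy still looks fine on an imbalanced dataset."""
    counts = np.bincount(np.asarray(preds), minlength=num_classes)
    p = counts / counts.sum()
    p = p[p > 0]                      # avoid log(0)
    return float(-(p * np.log(p)).sum())

collapsed = np.zeros(100, dtype=int)  # a model that always predicts class 0
diverse = np.arange(100) % 10         # predictions spread over 10 classes
print(f"collapsed: {prediction_entropy(collapsed, 10):.3f}")
print(f"diverse:   {prediction_entropy(diverse, 10):.3f}")
```

A healthy classifier's prediction entropy should be in the same ballpark as the label entropy of the dataset.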

The Debugging Mindset

Think Like a Scientist

Debugging neural networks requires the scientific method: form hypotheses, design experiments, observe results, and iterate. Random changes are the enemy.

  1. Observe the symptoms: What exactly is wrong? Loss not decreasing? Gradients exploding? Accuracy stuck at random chance?
  2. Form a hypothesis: Based on the symptoms, what could cause this? Learning rate too high? Data preprocessed incorrectly? Wrong loss function?
  3. Design a test: How can you confirm or reject this hypothesis? Change one thing and observe.
  4. Analyze results: Did the change help, hurt, or have no effect? What does that tell you?
  5. Iterate: Based on results, refine your hypothesis and repeat.

The Cardinal Rule

Change One Thing at a Time. If you change the learning rate, batch size, and architecture simultaneously, you won't know which change helped or hurt. Isolate variables rigorously.

Start Simple, Add Complexity

The most effective debugging strategy is simplification:

  • Smaller model: Can a tiny model learn anything? If not, the problem isn't model capacity.
  • Smaller dataset: Can you overfit 10 examples? If not, there's a fundamental bug.
  • Simpler task: Can your model solve an easier version of the problem?
  • Known baseline: Does a simple baseline (logistic regression, random forest) work? If not, the data might be the problem.

Sanity Checks Before Training

Before investing hours in training, run these quick sanity checks to catch obvious bugs:

Check 1: Initial Loss Value

Before any training, the loss should match the theoretical value for random predictions:

| Task | Expected Initial Loss | Formula |
|---|---|---|
| Binary classification | -ln(0.5) ≈ 0.693 | -ln(1/2) |
| 10-class classification | -ln(0.1) ≈ 2.303 | -ln(1/C) for C classes |
| 100-class classification | -ln(0.01) ≈ 4.605 | -ln(1/C) |
| Regression (normalized) | ~1.0 | Variance of targets |

If your initial loss is significantly different, something is wrong with your loss function, data labels, or model output format.

🐍check_initial_loss.py
# Check initial loss for 10-class classification
model.eval()
with torch.no_grad():
    logits = model(sample_batch)
    loss = F.cross_entropy(logits, labels)
    print(f"Initial loss: {loss.item():.4f}")
    print(f"Expected (random): {-np.log(1/10):.4f}")  # ~2.303

# If initial loss is much lower, you might be:
# - Accidentally using the labels in the forward pass
# - Predicting the same class for everything (check logits)
Check 2: Can You Overfit a Single Batch?

The most important sanity check: your model should be able to perfectly memorize a small batch of data. If it can't, something is fundamentally broken.

Overfitting Sanity Check

This is the most important test: if your model cannot memorize a tiny dataset, it cannot learn anything meaningful from a large one. It catches architecture bugs, loss function issues, and optimizer problems. A relatively high learning rate is fine here, since we only want to see whether learning is possible at all, not whether it generalizes. After 1000 steps on a single batch, loss should approach zero and accuracy should approach 100%. If not, you have a bug.

🐍overfit_test.py
def overfit_single_batch(model, batch, loss_fn, num_steps=1000):
    """Test if model can memorize a single batch."""
    model.train()
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for step in range(num_steps):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Step {step}: loss = {loss.item():.6f}")

    # Final evaluation
    model.eval()
    with torch.no_grad():
        outputs = model(inputs)
        final_loss = loss_fn(outputs, targets)

        # For classification, check accuracy
        if outputs.dim() > 1 and outputs.size(1) > 1:
            preds = outputs.argmax(dim=1)
            accuracy = (preds == targets).float().mean()
            print(f"Final accuracy: {accuracy.item():.2%}")

        print(f"Final loss: {final_loss.item():.6f}")

    # Should achieve near-zero loss and ~100% accuracy
    return final_loss.item()

If This Test Fails

If you can't overfit a single batch, do NOT proceed to full training. The possible causes are:
  • Model outputs wrong shape (check dimensions carefully)
  • Loss function incompatible with output format
  • Data labels don't match expected format
  • Severe gradient issues (vanishing or exploding)
  • Bug in the forward pass

Check 3: Verify Gradient Flow

Ensure gradients actually flow through all parameters:

🐍check_gradients.py
def check_gradient_flow(model):
    """Verify all parameters receive gradients."""
    print("Checking gradient flow...")

    for name, param in model.named_parameters():
        if param.grad is None:
            print(f"❌ {name}: NO GRADIENT")
        elif param.grad.abs().max() == 0:
            print(f"⚠️  {name}: gradient is all zeros")
        else:
            grad_mean = param.grad.abs().mean().item()
            grad_max = param.grad.abs().max().item()
            print(f"✓ {name}: mean={grad_mean:.2e}, max={grad_max:.2e}")

# After a backward pass:
loss.backward()
check_gradient_flow(model)

Check 4: Verify Data Loading

Data bugs are extremely common. Always visualize your data after all preprocessing:

🐍check_data.py
# For images
def visualize_batch(dataloader):
    batch = next(iter(dataloader))
    images, labels = batch

    print(f"Batch shape: {images.shape}")
    print(f"Labels: {labels[:10]}")
    print(f"Image range: [{images.min():.2f}, {images.max():.2f}]")

    # Denormalize if needed (assuming ImageNet normalization)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images_denorm = images * std + mean

    # Plot first few images with labels
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i, ax in enumerate(axes.flat):
        img = images_denorm[i].permute(1, 2, 0).clip(0, 1)
        ax.imshow(img)
        ax.set_title(f"Label: {labels[i].item()}")
        ax.axis('off')
    plt.show()

# Common data bugs to check:
# - Labels don't match images (shuffling bug)
# - Images are blank or corrupted
# - Normalization is wrong (values outside [-3, 3])
# - Class imbalance is extreme

Quick Check

Your model's initial loss for a 100-class classification task is 0.02. What does this suggest?


Common Failure Modes

Learn to recognize these common failure patterns and their symptoms:

1. Vanishing Gradients

Symptoms:

  • Loss barely changes or decreases extremely slowly
  • Early layers have much smaller gradients than later layers
  • Weights in early layers don't change during training
  • Deep networks fail while shallow ones succeed

Causes:

  • Sigmoid/tanh activations saturating
  • Very deep networks without skip connections
  • Poor weight initialization

Solutions:

  • Use ReLU or variants (LeakyReLU, GELU)
  • Add skip/residual connections
  • Use batch normalization
  • Proper initialization (He, Xavier)
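To make the activation-choice point concrete, the following toy sketch (plain NumPy, not part of a real training loop, using Xavier-style 1/sqrt(width) weight scaling) backpropagates a unit gradient through a stack of randomly initialized dense layers and reports the gradient norm that reaches the first layer. The sigmoid chain shrinks the gradient by orders of magnitude compared with ReLU:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def chain_gradient(activation, depth=10, width=64):
    """Push a signal through `depth` dense layers, then backprop a unit
    gradient to measure how much it shrinks on the way to layer 1."""
    x = rng.standard_normal(width)
    grad = np.ones(width)                 # dL/dh at the top of the stack
    caches = []
    for _ in range(depth):
        W = rng.standard_normal((width, width)) / np.sqrt(width)  # Xavier-ish
        z = W @ x
        if activation == "sigmoid":
            x = sigmoid(z)
            caches.append((W, x * (1 - x)))          # sigmoid'(z)
        else:  # relu
            x = np.maximum(z, 0)
            caches.append((W, (z > 0).astype(float)))  # relu'(z)
    for W, local_deriv in reversed(caches):
        grad = (grad * local_deriv) @ W   # backprop through one layer
    return float(np.linalg.norm(grad))    # gradient norm at the first layer

print(f"sigmoid, 10 layers: {chain_gradient('sigmoid'):.2e}")
print(f"relu,    10 layers: {chain_gradient('relu'):.2e}")
```

Because sigmoid's derivative never exceeds 0.25, each layer multiplies the gradient by a small factor; ReLU passes gradients through active units unchanged.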

2. Exploding Gradients

Symptoms:

  • Loss becomes NaN or Inf
  • Weights grow extremely large
  • Training becomes unstable with wild oscillations
  • Gradient norms increase exponentially over time

Causes:

  • Learning rate too high
  • Poor initialization
  • Deep RNNs without gradient clipping
  • Numerical instability in loss computation

Solutions:

  • Reduce learning rate
  • Apply gradient clipping
  • Use gradient norm monitoring
  • Better initialization
🐍gradient_clipping.py
# Gradient clipping to prevent explosion
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Monitor gradient norms
total_norm = 0
for p in model.parameters():
    if p.grad is not None:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm:.4f}")

3. Dead ReLU Neurons

Symptoms:

  • Many neurons output exactly zero for all inputs
  • Gradients for some weights are permanently zero
  • Model capacity seems lower than expected
  • Adding more neurons doesn't help

Causes:

  • Large negative biases after initialization
  • Learning rate too high, pushing neurons negative
  • All inputs to a neuron become negative

Solutions:

  • Use LeakyReLU, ELU, or GELU instead
  • Lower learning rate
  • Initialize biases to small positive values
🐍check_dead_neurons.py
def check_dead_neurons(model, dataloader, threshold=0.01):
    """Check for dead ReLU neurons."""
    activation_counts = {}

    def hook_fn(name):
        def hook(module, input, output):
            # Count how often each neuron fires (output > 0)
            active = (output > 0).float().mean(dim=0)
            if name not in activation_counts:
                activation_counts[name] = []
            activation_counts[name].append(active.cpu())
        return hook

    # Register hooks on ReLU layers
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    # Run through some data
    model.eval()
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            if i >= 10:  # Sample 10 batches
                break
            model(x.to(device))

    # Remove hooks
    for h in hooks:
        h.remove()

    # Analyze
    for name, counts in activation_counts.items():
        avg_activity = torch.stack(counts).mean(dim=0)
        dead_ratio = (avg_activity < threshold).float().mean()
        print(f"{name}: {dead_ratio:.1%} neurons dead")

4. Learning Rate Problems

| Symptom | Likely Cause | Fix |
|---|---|---|
| Loss doesn't decrease at all | LR too low | Increase LR by 10x |
| Loss oscillates wildly | LR too high | Decrease LR by 10x |
| Loss decreases then increases | LR too high for fine-tuning | Decrease LR, add warmup |
| Loss plateaus early | LR decay too aggressive | Slower decay schedule |
| Loss NaN after a few iterations | LR way too high | Decrease LR significantly |
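When in doubt, a quick learning-rate sweep makes these symptoms concrete. This toy sketch (a synthetic least-squares problem solved by plain gradient descent, not a real training run) shows the characteristic pattern: too low stalls, a moderate value converges, too high diverges to infinity.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 5))
true_w = rng.standard_normal(5)
y = X @ true_w

def final_loss(lr, steps=200):
    """Plain gradient descent on least squares; returns the final MSE."""
    w = np.zeros(5)
    for _ in range(steps):
        grad = 2 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
        if not np.isfinite(w).all():
            return float("inf")   # diverged
    return float(np.mean((X @ w - y) ** 2))

for lr in [1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0]:
    print(f"lr={lr:<6g}  final loss={final_loss(lr):.3e}")
```

The same sweep-by-powers-of-ten strategy applies to neural networks; only the scale of the "good" region changes.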

Interactive: Gradient Health Monitor

[Interactive visualization: gradient magnitudes across the layers of a network. Adjust network depth, switch between activation functions, and try different initialization schemes to see their effect on gradient flow and to spot vanishing or exploding gradients at each layer.]

Quick Check

When using sigmoid activation in a 10-layer network, gradients in early layers are 1000x smaller than in later layers. What is this called?


Debugging the Loss Curve

The loss curve is your primary diagnostic tool. Learn to read it like a doctor reads vital signs:

Healthy Loss Curve Characteristics

  • Initial drop: Loss should decrease noticeably in the first few epochs
  • Smooth decrease: Some noise is normal, but overall trend is down
  • Eventual plateau: Loss stabilizes as model converges
  • Small train-val gap: Validation loss tracks training loss
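Raw loss curves are noisy, so most dashboards smooth them before you read trends. A minimal sketch of bias-corrected exponential moving-average smoothing (the same idea behind the smoothing sliders in common training dashboards) looks like this:

```python
def smooth(values, beta=0.9):
    """Exponential moving average with bias correction. Higher beta
    means heavier smoothing; beta=0 returns the raw values."""
    avg, out = 0.0, []
    for t, v in enumerate(values, start=1):
        avg = beta * avg + (1 - beta) * v
        out.append(avg / (1 - beta ** t))   # bias-corrected estimate
    return out

noisy = [1.0, 0.8, 1.1, 0.7, 0.9, 0.5, 0.6, 0.4]
print(smooth(noisy))
```

Always look at both the raw and the smoothed curve: heavy smoothing can hide oscillations that are themselves a symptom.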

Pathological Loss Curves

Loss Doesn't Decrease

  • Check learning rate (try 10x higher and 10x lower)
  • Verify gradients are non-zero
  • Confirm loss function matches task
  • Check for data preprocessing bugs

Loss Oscillates Wildly

  • Learning rate too high
  • Batch size too small
  • Label noise in data
  • Try gradient clipping

Loss Increases After Initial Decrease

  • Learning rate too high
  • No learning rate decay
  • Training on corrupted data
  • Catastrophic forgetting (if fine-tuning)

Validation Loss Diverges from Training Loss

  • Overfitting: add regularization, get more data
  • Training/validation distribution mismatch
  • Data leakage in validation set

Loss Becomes NaN

  • Numerical instability: use log-sum-exp tricks
  • Learning rate explosion
  • Division by zero or log of zero
  • Invalid inputs (NaN in data)
🐍nan_debugging.py
# Debugging NaN losses
def check_for_nans(name):
    def hook(module, input, output):
        if torch.isnan(output).any():
            print(f"NaN detected in {name}")
            print(f"  Input range: [{input[0].min():.4f}, {input[0].max():.4f}]")
            print(f"  Input has NaN: {torch.isnan(input[0]).any()}")
    return hook

# Register on all layers
for name, module in model.named_modules():
    module.register_forward_hook(check_for_nans(name))

# Common NaN causes:
# 1. log(0) - add epsilon: log(x + 1e-10)
# 2. exp(large) - use log-sum-exp: torch.logsumexp()
# 3. Division by zero - add epsilon to denominator
# 4. sqrt of negative - use relu: sqrt(relu(x) + 1e-10)

Interactive: Loss Curve Debugger

[Interactive visualization: simulated training scenarios showing how common problems manifest in the loss curve: normal training (a healthy curve), learning rate too high (oscillation and potential divergence), learning rate too low (painfully slow progress), overfitting (validation loss diverging from training), and gradient explosion (sudden NaN).]

Weight and Activation Analysis

Monitoring Weight Statistics

Track how weights evolve during training. Healthy training shows stable weight distributions:

Weight Statistics Monitor

Store statistics for each parameter over time; this lets you visualize trends and spot when things go wrong. Mean, std, min, and max describe the weight distribution, while the gradient norm tells you how much the weights are about to change. Gradient norms can vary by orders of magnitude, so plot them on a log scale to make patterns visible.

🐍weight_monitor.py
class WeightMonitor:
    """Monitor weight statistics during training."""

    def __init__(self, model):
        self.model = model
        self.history = {name: [] for name, _ in model.named_parameters()}

    def record(self):
        """Record current weight statistics."""
        for name, param in self.model.named_parameters():
            stats = {
                'mean': param.data.mean().item(),
                'std': param.data.std().item(),
                'min': param.data.min().item(),
                'max': param.data.max().item(),
                'grad_norm': param.grad.norm().item() if param.grad is not None else 0,
            }
            self.history[name].append(stats)

    def plot(self, param_name):
        """Plot weight statistics over time."""
        history = self.history[param_name]

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Mean
        axes[0, 0].plot([h['mean'] for h in history])
        axes[0, 0].set_title('Weight Mean')

        # Std
        axes[0, 1].plot([h['std'] for h in history])
        axes[0, 1].set_title('Weight Std')

        # Range
        axes[1, 0].fill_between(
            range(len(history)),
            [h['min'] for h in history],
            [h['max'] for h in history],
            alpha=0.3
        )
        axes[1, 0].set_title('Weight Range')

        # Gradient norm
        axes[1, 1].semilogy([h['grad_norm'] for h in history])
        axes[1, 1].set_title('Gradient Norm (log scale)')

        plt.tight_layout()
        return fig

# Usage
monitor = WeightMonitor(model)
for epoch in range(num_epochs):
    train_one_epoch(...)
    monitor.record()  # Record after each epoch

monitor.plot('layer1.weight')

What to Look For

| Observation | Possible Problem | Action |
|---|---|---|
| Weight std → 0 | Weights collapsing | Check for gradient issues |
| Weight std → ∞ | Weights exploding | Add regularization, lower LR |
| Weights not changing | No gradients, LR too low | Check gradient flow |
| Gradient norm → 0 | Vanishing gradients | Fix architecture, initialization |
| Gradient norm → ∞ | Exploding gradients | Clip gradients, lower LR |
| Some layers frozen | Selective gradient issues | Check layer connectivity |

Activation Analysis

Activations (layer outputs) reveal what your model is actually computing. Healthy activations have these properties:

  • Non-zero: At least some neurons fire for most inputs
  • Bounded: Values don't grow unboundedly through layers
  • Diverse: Different neurons respond to different inputs
🐍activation_analysis.py
class ActivationAnalyzer:
    """Analyze activations to diagnose network health."""

    def __init__(self, model):
        self.model = model  # keep a reference for analyze()
        self.activations = {}
        self.hooks = []

        # Register hooks on all layers
        for name, module in model.named_modules():
            if len(list(module.children())) == 0:  # Leaf modules only
                hook = module.register_forward_hook(self._create_hook(name))
                self.hooks.append(hook)

    def _create_hook(self, name):
        def hook(module, input, output):
            if isinstance(output, torch.Tensor):
                self.activations[name] = {
                    'mean': output.mean().item(),
                    'std': output.std().item(),
                    'zero_fraction': (output == 0).float().mean().item(),
                    'negative_fraction': (output < 0).float().mean().item(),
                    'shape': output.shape,
                }
        return hook

    def analyze(self, x):
        """Run input through model and analyze activations."""
        with torch.no_grad():
            _ = self.model(x)
        return self.activations

    def print_summary(self):
        """Print activation summary."""
        for name, stats in self.activations.items():
            print(f"{name}:")
            print(f"  Shape: {stats['shape']}")
            print(f"  Mean: {stats['mean']:.4f}, Std: {stats['std']:.4f}")
            print(f"  Zero: {stats['zero_fraction']:.1%}, Negative: {stats['negative_fraction']:.1%}")
            print()

Gradient Debugging Techniques

Numerical Gradient Checking

When you suspect your gradients are wrong, verify them numerically using finite differences. This is slow but gives ground truth:

$$\frac{\partial L}{\partial \theta} \approx \frac{L(\theta + \epsilon) - L(\theta - \epsilon)}{2\epsilon}$$
Numerical Gradient Verification

Use gradient checking when implementing custom layers or loss functions, or when you suspect autograd bugs; it is far too slow for regular training. The procedure: first compute the analytical gradients via autograd (what you normally use), then, for each parameter, perturb it slightly and measure how the loss changes to get the numerical gradient. The relative error between the two should be below about 1e-5 for a correct implementation.

🐍gradient_check.py
def gradient_check(model, loss_fn, x, y, epsilon=1e-5):
    """Verify analytical gradients match numerical approximation."""

    # Compute analytical gradients
    model.zero_grad()
    output = model(x)
    loss = loss_fn(output, y)
    loss.backward()

    errors = []

    for name, param in model.named_parameters():
        if param.grad is None:
            continue

        analytical = param.grad.clone()
        numerical = torch.zeros_like(param)

        # Compute numerical gradient for each parameter
        # (perturb param.data so autograd doesn't object to in-place
        # edits of a leaf tensor that requires grad)
        param_flat = param.data.view(-1)
        with torch.no_grad():
            for i in range(param_flat.numel()):
                orig_val = param_flat[i].item()

                # Forward pass with +epsilon
                param_flat[i] = orig_val + epsilon
                loss_plus = loss_fn(model(x), y)

                # Forward pass with -epsilon
                param_flat[i] = orig_val - epsilon
                loss_minus = loss_fn(model(x), y)

                # Restore and compute numerical gradient
                param_flat[i] = orig_val
                numerical.view(-1)[i] = (loss_plus - loss_minus) / (2 * epsilon)

        # Compare
        diff = (analytical - numerical).abs()
        relative_error = diff / (analytical.abs() + numerical.abs() + 1e-10)
        max_error = relative_error.max().item()
        errors.append((name, max_error))

        if max_error > 1e-4:
            print(f"❌ {name}: relative error = {max_error:.2e}")
        else:
            print(f"✓ {name}: relative error = {max_error:.2e}")

    return errors

Gradient Checking Pitfalls

  • Gradient checking is extremely slow (one forward pass per parameter)
  • Don't use dropout or batch norm during checking (non-deterministic)
  • Use double precision (float64) for more accurate comparisons
  • Relative error > 1e-4 usually indicates a bug
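To see why double precision matters, compare the finite-difference error in float32 versus float64 on a function whose derivative is known exactly. This NumPy sketch uses exp, where the true derivative at any point equals the function value itself:

```python
import numpy as np

def central_diff(f, x, eps, dtype):
    """Central finite-difference derivative computed in a given dtype."""
    x = dtype(x)
    eps = dtype(eps)
    return (f(x + eps) - f(x - eps)) / (dtype(2) * eps)

f = np.exp                 # d/dx exp(x) = exp(x), so ground truth is known
x, eps = 1.0, 1e-5
exact = np.exp(x)
err32 = abs(central_diff(f, x, eps, np.float32) - exact) / exact
err64 = abs(central_diff(f, x, eps, np.float64) - exact) / exact
print(f"relative error, float32: {err32:.1e}")
print(f"relative error, float64: {err64:.1e}")
```

In float32, the rounding error of the two function evaluations is comparable to the tiny difference being measured, so the estimate is badly corrupted; in float64 the same formula is accurate to many digits. The same effect applies to gradient checking a model.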

Using Hooks for Gradient Analysis

🐍gradient_hooks.py
# Register hooks to analyze gradients during backprop
gradient_info = {}

def save_gradients(name):
    def hook(grad):
        gradient_info[name] = {
            'norm': grad.norm().item(),
            'mean': grad.mean().item(),
            'std': grad.std().item(),
            'has_nan': torch.isnan(grad).any().item(),
            'has_inf': torch.isinf(grad).any().item(),
        }
        return grad  # Return unchanged
    return hook

# Register on parameters
for name, param in model.named_parameters():
    param.register_hook(save_gradients(name))

# After backward:
loss.backward()
for name, info in gradient_info.items():
    if info['has_nan'] or info['has_inf']:
        print(f"❌ {name}: NaN={info['has_nan']}, Inf={info['has_inf']}")
    else:
        print(f"✓ {name}: norm={info['norm']:.2e}")

Common Bugs and Fixes

A catalog of bugs that have tripped up countless practitioners:

1. Forgetting model.train() / model.eval()

🐍train_eval_bug.py
# BUG: Running training in eval mode
model.eval()  # Oops, forgot to switch back
for batch in train_loader:
    loss = ...
    loss.backward()  # Dropout is disabled! BatchNorm uses wrong stats!

# FIX: Always set mode explicitly
def train_epoch(model, ...):
    model.train()  # ← Always do this first
    ...

def evaluate(model, ...):
    model.eval()   # ← Always do this first
    with torch.no_grad():
        ...

2. Not Zeroing Gradients

🐍zero_grad_bug.py
# BUG: Gradients accumulate across batches
for batch in dataloader:
    loss = model(batch)
    loss.backward()
    optimizer.step()  # Using accumulated gradients!

# FIX: Zero gradients before backward
for batch in dataloader:
    optimizer.zero_grad()  # ← Add this
    loss = model(batch)
    loss.backward()
    optimizer.step()

3. Inplace Operations Breaking Autograd

🐍inplace_bug.py
# BUG: Inplace operations can break gradient computation
x = model.layer1(input)
x += bias  # Inplace! Might break autograd

# Also problematic:
x[:, 0] = 0  # Inplace indexing assignment
x.add_(1)    # Inplace add

# FIX: Use out-of-place operations
x = model.layer1(input)
x = x + bias  # Create new tensor

# Or explicitly clone first:
x = x.clone()
x[:, 0] = 0  # Now safe

4. Data Leakage

🐍data_leakage_bug.py
# BUG: Normalizing before splitting
data = load_data()
data = (data - data.mean()) / data.std()  # Uses ALL data
train, val = split(data)  # Validation stats leaked into normalization!

# FIX: Compute statistics on training set only
train, val = split(data)
train_mean, train_std = train.mean(), train.std()
train = (train - train_mean) / train_std
val = (val - train_mean) / train_std  # Use TRAINING stats

5. Wrong Loss Function

🐍loss_bug.py
# BUG: Using wrong loss for the task
# For multi-class classification with logits:
loss = nn.BCELoss()(model(x), y)  # Wrong! BCE is for binary

# For binary classification with logits:
loss = nn.CrossEntropyLoss()(model(x), y)  # Wrong! CE expects class indices

# FIX: Match loss to output format
# Multi-class with logits → CrossEntropyLoss (expects raw logits, integer labels)
# Multi-class with probabilities → NLLLoss after log_softmax
# Binary with logits → BCEWithLogitsLoss
# Binary with probabilities → BCELoss
# Regression → MSELoss or L1Loss

6. Dimension Mismatch Hidden by Broadcasting

🐍broadcasting_bug.py
# BUG: Shapes look compatible but semantics are wrong
labels = torch.tensor([0, 1, 2])  # Shape: (3,)
predictions = model(x)  # Shape: (3, 10) - logits for 10 classes

# This doesn't crash but computes wrong loss!
loss = (predictions - labels) ** 2  # Broadcasting: (3, 10) - (3,) → (3, 10)

# FIX: Always verify shapes explicitly
print(f"Predictions: {predictions.shape}")  # Should be (batch, classes)
print(f"Labels: {labels.shape}")            # Should be (batch,)
loss = nn.CrossEntropyLoss()(predictions, labels)

7. Learning Rate Issues After Loading

🐍lr_loading_bug.py
# BUG: Creating new optimizer after loading, losing LR schedule state
model.load_state_dict(torch.load('model.pt'))
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Resets LR!

# FIX: Load optimizer state too
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

Quick Check

Your training loss decreases but validation loss stays flat from epoch 1. What's the most likely cause?


The Debugging Checklist

When your model isn't training properly, work through this checklist systematically:

Phase 1: Data Verification

  1. Visualize raw data: Can you see what you expect? Are images right-side up? Are labels correct?
  2. Check data statistics: Mean, std, min, max. Are they reasonable? Any NaN/Inf?
  3. Verify labels: Check class distribution. Are labels in the right format (indices vs one-hot)?
  4. Check data loading: Are batches shuffled? Is data augmentation working correctly?
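These data checks can be automated. The sketch below (plain NumPy, with hypothetical batch arrays standing in for a real dataloader's output) asserts the cheap invariants once, up front, instead of letting them fail silently hours into training:

```python
import numpy as np

def audit_batch(x, y, num_classes):
    """Cheap assertions to run on the first batch of a new pipeline.
    Raises immediately instead of failing silently during training."""
    assert np.isfinite(x).all(), "inputs contain NaN/Inf"
    assert y.min() >= 0 and y.max() < num_classes, "labels out of range"
    assert x.std() > 1e-6, "inputs are (nearly) constant"
    # Normalized inputs should have roughly zero mean and unit-ish std
    if abs(x.mean()) > 1.0 or x.std() > 10.0:
        print(f"warning: unusual input stats (mean={x.mean():.2f}, std={x.std():.2f})")
    counts = np.bincount(np.asarray(y), minlength=num_classes)
    print("class counts:", counts)   # eyeball for extreme imbalance
    return counts

rng = np.random.default_rng(0)
audit_batch(rng.standard_normal((32, 10)), rng.integers(0, 3, 32), num_classes=3)
```

Hook a call like this into your dataloader setup so every new experiment gets audited automatically.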

Phase 2: Model Sanity Checks

  1. Output shape: Does model output match expected shape for the task?
  2. Initial loss: Is it close to theoretical random baseline?
  3. Overfit single batch: Can model memorize 1-10 examples?
  4. Parameter count: Does the model have enough capacity?

Phase 3: Gradient Checks

  1. Gradient flow: Do all parameters have non-zero gradients?
  2. Gradient magnitude: Are gradients reasonable (not vanishing or exploding)?
  3. Gradient direction: Does numerical check match analytical gradients?

Phase 4: Training Dynamics

  1. Learning rate: Try 10x higher and 10x lower. Which works best?
  2. Loss curve: Is it decreasing? Oscillating? Plateauing?
  3. Weight updates: Are weights actually changing? In a reasonable range?
  4. Activation distributions: Are neurons dying? Saturating?

Phase 5: Comparison and Baselines

  1. Simple baseline: Does logistic regression or random forest work on this data?
  2. Known implementation: Does a reference implementation achieve expected results?
  3. Simpler model: Can a smaller/simpler model learn anything?
  4. Easier task: Can the model solve a simplified version of the problem?
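A baseline doesn't need a library: the simplest one is always predicting the majority training class. Any model worth training should clearly beat this number; if it doesn't, suspect the data or the pipeline. A minimal sketch:

```python
import numpy as np
from collections import Counter

def majority_baseline_accuracy(train_labels, val_labels):
    """Accuracy of always predicting the most common training class."""
    majority = Counter(train_labels).most_common(1)[0][0]
    return float(np.mean(np.asarray(val_labels) == majority))

# Hypothetical label lists for illustration
train_y = [0, 0, 1, 0, 2, 0, 1]
val_y = [0, 1, 0, 0, 2]
print(f"majority-class accuracy: {majority_baseline_accuracy(train_y, val_y):.2f}")
```

On an imbalanced dataset this baseline can be surprisingly high, which is exactly why "85% accuracy" is meaningless until you know what the majority class alone scores.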

Document Everything

Keep a debugging log. Write down what you tried, what you observed, and what you concluded. Future you will thank present you.

Related Topics

  • Chapter 8 Section 6: Gradient Flow Analysis - Theory behind vanishing/exploding gradients, skip connections, and gradient clipping
  • Section 4: Weight Initialization - Proper initialization prevents many gradient issues from the start
  • Section 5: Regularization - Early stopping, dropout, and other techniques to prevent overfitting

Summary

Debugging neural networks is challenging but systematic approaches make it manageable. Here's what we covered:

| Topic | Key Points |
|---|---|
| Debugging Mindset | Scientific method: hypothesize → test → analyze. Change one thing at a time. |
| Sanity Checks | Initial loss should match theory. Must be able to overfit a single batch. |
| Vanishing Gradients | Early layers stop learning. Fix: ReLU, skip connections, proper init. |
| Exploding Gradients | Loss becomes NaN/Inf. Fix: gradient clipping, lower LR, proper init. |
| Dead Neurons | ReLU outputs always zero. Fix: LeakyReLU, lower LR, positive bias init. |
| Loss Curve Analysis | Primary diagnostic tool. Learn the patterns of common problems. |
| Weight Monitoring | Track mean, std, gradient norms over time to spot issues early. |
| Gradient Checking | Numerical verification of analytical gradients for debugging custom layers. |
| Common Bugs | train/eval mode, zero_grad, inplace ops, data leakage, wrong loss. |

The Golden Rules of Neural Network Debugging

  1. Start simple: Get a tiny model working on a tiny dataset first
  2. Verify everything: Never assume something works—test it
  3. Change one thing: Isolate variables to identify causes
  4. Monitor constantly: Track losses, gradients, weights, activations
  5. Compare to baselines: Know what "working" looks like

Exercises

Conceptual Questions

  1. Your 50-layer CNN trains perfectly on MNIST but fails on CIFAR-10 with the same architecture. The loss stays flat. What are three possible causes and how would you diagnose each?
  2. Explain why optimizer.zero_grad() must be called before loss.backward() but optimizer.step() can be called multiple times without zeroing. What use case does gradient accumulation serve?
  3. Why might numerical gradient checking pass for a layer but the layer still has a bug in practice? Give two scenarios.
  4. A colleague says "my model isn't learning because the loss is stuck at 2.3 for a 10-class problem." What would you tell them?

Solution Hints

  1. CIFAR-10 is harder: consider vanishing gradients in deep networks, learning rate, and whether simple features suffice.
  2. Gradients accumulate by design. Multiple steps without zeroing enables gradient accumulation for larger effective batch sizes.
  3. Gradient check passes: 1) Bug only manifests with specific input patterns, 2) Bug is in training dynamics, not forward/backward math.
  4. 2.3 is the expected random loss (-ln(1/10)). The model hasn't learned anything yet—this is the starting point, not a plateau.

Coding Exercises

  1. Build a debugging dashboard: Create a class that tracks and visualizes training loss, validation loss, gradient norms, weight norms, and learning rate over time. Update it after each batch.
  2. Implement dead neuron detection: Write a function that identifies layers where more than 50% of neurons never fire (always output zero) across a validation set.
  3. Create a gradient anomaly detector: Implement a system using hooks that automatically pauses training if gradient norms exceed a threshold or become NaN, and logs the problematic layer.
  4. Build an overfitting test suite: Create a function that automatically runs the "overfit single batch" test, varying batch sizes and learning rates, and reports whether the model can memorize small amounts of data.

Coding Exercise Hints

  • Exercise 1: Use matplotlib with live updates (plt.ion()). Store history in lists and plot after each epoch.
  • Exercise 2: Use forward hooks to capture activations. Count zero outputs across many forward passes.
  • Exercise 3: Use backward hooks on parameters. Check grad.norm() against thresholds.
  • Exercise 4: Loop over batch_size in [1, 2, 4, 8] and lr in [0.1, 0.01, 0.001]. For each, run 1000 steps and check final loss.

Congratulations! You've completed Chapter 9 on Training Neural Networks. You now understand the complete training process: from the training loop itself, through optimizers and learning rate schedules, to weight initialization, regularization, hyperparameter tuning, and debugging. In the next chapter, we'll apply this knowledge to build and train Convolutional Neural Networks for image processing.