Learning Objectives
By the end of this section, you will be able to:
- Develop a systematic debugging mindset: Understand why neural networks fail silently and how to approach debugging methodically rather than randomly
- Perform essential sanity checks: Verify your model can overfit a single batch, check gradient flow, and validate loss starts at the expected value
- Diagnose common failure modes: Identify vanishing/exploding gradients, dead neurons, learning rate issues, and data problems from symptoms
- Use debugging tools in PyTorch: Apply hooks, gradient checking, activation visualization, and weight monitoring to find bugs
- Apply a systematic debugging checklist: Follow a structured approach to eliminate possibilities and isolate the root cause of training failures
Why This Matters: Neural networks fail silently. Unlike traditional software that crashes with stack traces, a buggy neural network simply produces bad predictions. The loss might decrease, accuracy might improve, but subtle bugs can prevent your model from reaching its full potential—or worse, cause it to learn the wrong things entirely. Mastering debugging is what separates practitioners who get models to work from those who don't.
The Big Picture
Why Neural Networks Are Hard to Debug
Debugging neural networks is fundamentally different from debugging traditional software for several reasons:
- No compiler errors: Your code can be syntactically correct yet mathematically wrong. A misplaced transpose or incorrect loss function won't raise an exception.
- Stochastic nature: Results vary between runs due to random initialization and data shuffling. Was that improvement real or just noise?
- Delayed feedback: You only discover problems after hours of training, when you finally look at the results.
- Coupled components: Architecture, data, loss function, optimizer, and hyperparameters all interact. A bug in one can mask problems in another.
- No ground truth: You often don't know what the "correct" loss or accuracy should be, making it hard to know if something is wrong.
The Cost of Bugs
In traditional software, bugs cause crashes. In deep learning, bugs cause silent failures:
| Bug Type | Traditional Software | Neural Networks |
|---|---|---|
| Shape mismatch | Runtime error | Broadcasting hides the bug |
| Wrong formula | Wrong output | Training proceeds, just poorly |
| Data issue | Validation error | Model memorizes noise |
| Gradient bug | N/A | Loss stagnates mysteriously |
The Most Dangerous Bugs
The most dangerous bugs are the silent ones: training runs to completion and produces a plausible-looking model that is simply worse than it should be, and nothing ever tells you.
The Debugging Mindset
Think Like a Scientist
Debugging neural networks requires the scientific method: form hypotheses, design experiments, observe results, and iterate. Random changes are the enemy.
- Observe the symptoms: What exactly is wrong? Loss not decreasing? Gradients exploding? Accuracy stuck at random chance?
- Form a hypothesis: Based on the symptoms, what could cause this? Learning rate too high? Data preprocessed incorrectly? Wrong loss function?
- Design a test: How can you confirm or reject this hypothesis? Change one thing and observe.
- Analyze results: Did the change help, hurt, or have no effect? What does that tell you?
- Iterate: Based on results, refine your hypothesis and repeat.
The Cardinal Rule
Change One Thing at a Time. If you change the learning rate, batch size, and architecture simultaneously, you won't know which change helped or hurt. Isolate variables rigorously.
Start Simple, Add Complexity
The most effective debugging strategy is simplification:
- Smaller model: Can a tiny model learn anything? If not, the problem isn't model capacity.
- Smaller dataset: Can you overfit 10 examples? If not, there's a fundamental bug.
- Simpler task: Can your model solve an easier version of the problem?
- Known baseline: Does a simple baseline (logistic regression, random forest) work? If not, the data might be the problem.
Sanity Checks Before Training
Before investing hours in training, run these quick sanity checks to catch obvious bugs:
Check 1: Initial Loss Value
Before any training, the loss should match the theoretical value for random predictions:
| Task | Expected Initial Loss | Formula |
|---|---|---|
| Binary classification | -ln(0.5) ≈ 0.693 | -ln(1/2) |
| 10-class classification | -ln(0.1) ≈ 2.303 | -ln(1/C) for C classes |
| 100-class classification | -ln(0.01) ≈ 4.605 | -ln(1/C) |
| Regression (normalized) | ~1.0 | Variance of targets |
If your initial loss is significantly different, something is wrong with your loss function, data labels, or model output format.
```python
# Check initial loss for 10-class classification
model.eval()
with torch.no_grad():
    logits = model(sample_batch)
    loss = F.cross_entropy(logits, labels)
    print(f"Initial loss: {loss.item():.4f}")
    print(f"Expected (random): {-np.log(1/10):.4f}")  # ~2.303

# If initial loss is much lower, you might be:
# - Accidentally using the labels in the forward pass
# - Predicting the same class for everything (check logits)
```
Check 2: Can You Overfit a Single Batch?
The most important sanity check: your model should be able to perfectly memorize a small batch of data. If it can't, something is fundamentally broken.
If This Test Fails
- Model outputs wrong shape (check dimensions carefully)
- Loss function incompatible with output format
- Data labels don't match expected format
- Severe gradient issues (vanishing or exploding)
- Bug in the forward pass
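Here is a minimal sketch of the test itself. The names `model`, `batch_x`, and `batch_y` are placeholders for your own model and one small batch, and cross-entropy is assumed for a classification task; swap in your own loss as needed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def overfit_single_batch(model, batch_x, batch_y, steps=500, lr=1e-2):
    """Train on one small batch only; a healthy model drives this loss to ~0."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = F.cross_entropy(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
    return loss.item()
```

If the final loss plateaus far above zero, suspect the forward pass, the loss/output pairing, or the labels before anything else.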
Check 3: Verify Gradient Flow
Ensure gradients actually flow through all parameters:
```python
def check_gradient_flow(model):
    """Verify all parameters receive gradients."""
    print("Checking gradient flow...")

    for name, param in model.named_parameters():
        if param.grad is None:
            print(f"❌ {name}: NO GRADIENT")
        elif param.grad.abs().max() == 0:
            print(f"⚠️ {name}: gradient is all zeros")
        else:
            grad_mean = param.grad.abs().mean().item()
            grad_max = param.grad.abs().max().item()
            print(f"✓ {name}: mean={grad_mean:.2e}, max={grad_max:.2e}")

# After a backward pass:
loss.backward()
check_gradient_flow(model)
```
Check 4: Verify Data Loading
Data bugs are extremely common. Always visualize your data after all preprocessing:
```python
# For images
def visualize_batch(dataloader):
    batch = next(iter(dataloader))
    images, labels = batch

    print(f"Batch shape: {images.shape}")
    print(f"Labels: {labels[:10]}")
    print(f"Image range: [{images.min():.2f}, {images.max():.2f}]")

    # Denormalize if needed (assuming ImageNet normalization)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images_denorm = images * std + mean

    # Plot first few images with labels
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i, ax in enumerate(axes.flat):
        img = images_denorm[i].permute(1, 2, 0).clip(0, 1)
        ax.imshow(img)
        ax.set_title(f"Label: {labels[i].item()}")
        ax.axis('off')
    plt.show()

# Common data bugs to check:
# - Labels don't match images (shuffling bug)
# - Images are blank or corrupted
# - Normalization is wrong (values outside [-3, 3])
# - Class imbalance is extreme
```
Quick Check
Your model's initial loss for a 100-class classification task is 0.02. What does this suggest?
Common Failure Modes
Learn to recognize these common failure patterns and their symptoms:
1. Vanishing Gradients
Symptoms:
- Loss barely changes or decreases extremely slowly
- Early layers have much smaller gradients than later layers
- Weights in early layers don't change during training
- Deep networks fail while shallow ones succeed
Causes:
- Sigmoid/tanh activations saturating
- Very deep networks without skip connections
- Poor weight initialization
Solutions:
- Use ReLU or variants (LeakyReLU, GELU)
- Add skip/residual connections
- Use batch normalization
- Proper initialization (He, Xavier)
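To see the symptom numerically, compare per-layer gradient magnitudes after one backward pass. This sketch builds a deliberately bad deep sigmoid network; the helper name `layer_gradient_scale` is ours, not a library function.

```python
import torch
import torch.nn as nn

def layer_gradient_scale(model):
    """Mean |grad| per weight matrix, in registration (forward) order."""
    return {name: p.grad.abs().mean().item()
            for name, p in model.named_parameters()
            if p.grad is not None and p.dim() > 1}

torch.manual_seed(0)
layers = []
for _ in range(10):
    layers += [nn.Linear(32, 32), nn.Sigmoid()]
model = nn.Sequential(*layers, nn.Linear(32, 1))

model(torch.randn(16, 32)).sum().backward()
scales = list(layer_gradient_scale(model).values())
print(f"first layer: {scales[0]:.2e}, last layer: {scales[-1]:.2e}")
```

With sigmoid activations the first layer's gradients come out orders of magnitude smaller than the last layer's; swap `nn.Sigmoid()` for `nn.ReLU()` and the gap typically shrinks dramatically.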
2. Exploding Gradients
Symptoms:
- Loss becomes NaN or Inf
- Weights grow extremely large
- Training becomes unstable with wild oscillations
- Gradient norms increase exponentially over time
Causes:
- Learning rate too high
- Poor initialization
- Deep RNNs without gradient clipping
- Numerical instability in loss computation
Solutions:
- Reduce learning rate
- Apply gradient clipping
- Use gradient norm monitoring
- Better initialization
```python
# Gradient clipping to prevent explosion
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Monitor gradient norms
total_norm = 0
for p in model.parameters():
    if p.grad is not None:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm:.4f}")
```
3. Dead ReLU Neurons
Symptoms:
- Many neurons output exactly zero for all inputs
- Gradients for some weights are permanently zero
- Model capacity seems lower than expected
- Adding more neurons doesn't help
Causes:
- Large negative biases after initialization
- Learning rate too high, pushing neurons negative
- All inputs to a neuron become negative
Solutions:
- Use LeakyReLU, ELU, or GELU instead
- Lower learning rate
- Initialize biases to small positive values
```python
def check_dead_neurons(model, dataloader, threshold=0.01):
    """Check for dead ReLU neurons."""
    device = next(model.parameters()).device  # run data on the model's device
    activation_counts = {}

    def hook_fn(name):
        def hook(module, input, output):
            # Count how often each neuron fires (output > 0)
            active = (output > 0).float().mean(dim=0)
            if name not in activation_counts:
                activation_counts[name] = []
            activation_counts[name].append(active.cpu())
        return hook

    # Register hooks on ReLU layers
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    # Run through some data
    model.eval()
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            if i >= 10:  # Sample 10 batches
                break
            model(x.to(device))

    # Remove hooks
    for h in hooks:
        h.remove()

    # Analyze
    for name, counts in activation_counts.items():
        avg_activity = torch.stack(counts).mean(dim=0)
        dead_ratio = (avg_activity < threshold).float().mean()
        print(f"{name}: {dead_ratio:.1%} neurons dead")
```
4. Learning Rate Problems
| Symptom | Likely Cause | Fix |
|---|---|---|
| Loss doesn't decrease at all | LR too low | Increase LR by 10x |
| Loss oscillates wildly | LR too high | Decrease LR by 10x |
| Loss decreases then increases | LR too high for fine-tuning | Decrease LR, add warmup |
| Loss plateaus early | LR decay too aggressive | Slower decay schedule |
| Loss NaN after a few iterations | LR way too high | Decrease LR significantly |
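The "10x up, 10x down" advice above can be automated as a short sweep: train briefly at each candidate rate from the same initialization and compare the short-run losses. This is a sketch under simple assumptions; `make_model` is any zero-argument model factory you supply.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def lr_sweep(make_model, x, y, lrs=(1e-1, 1e-2, 1e-3, 1e-4), steps=100):
    """Short training run per learning rate, all from an identical init."""
    base = make_model()
    results = {}
    for lr in lrs:
        model = copy.deepcopy(base)   # identical starting weights each time
        opt = torch.optim.SGD(model.parameters(), lr=lr)
        for _ in range(steps):
            opt.zero_grad()
            loss = F.cross_entropy(model(x), y)
            loss.backward()
            opt.step()
        results[lr] = loss.item()
    return results
```

A rate whose short-run loss diverges or goes NaN is too high; one whose loss barely moves is too low.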
Interactive: Gradient Health Monitor
Explore how gradient health varies across a neural network. This visualization shows gradient magnitudes through different layers, helping you identify vanishing or exploding gradient problems:
Use the controls to:
- Adjust network depth to see how gradients change in deeper networks
- Toggle between different activation functions to see their effect on gradient flow
- Experiment with different initialization schemes
- Observe the gradient magnitude at each layer
Quick Check
When using sigmoid activation in a 10-layer network, gradients in early layers are 1000x smaller than in later layers. What is this called?
Debugging the Loss Curve
The loss curve is your primary diagnostic tool. Learn to read it like a doctor reads vital signs:
Healthy Loss Curve Characteristics
- Initial drop: Loss should decrease noticeably in the first few epochs
- Smooth decrease: Some noise is normal, but overall trend is down
- Eventual plateau: Loss stabilizes as model converges
- Small train-val gap: Validation loss tracks training loss
Pathological Loss Curves
Loss Doesn't Decrease
- Check learning rate (try 10x higher and 10x lower)
- Verify gradients are non-zero
- Confirm loss function matches task
- Check for data preprocessing bugs
Loss Oscillates Wildly
- Learning rate too high
- Batch size too small
- Label noise in data
- Try gradient clipping
Loss Increases After Initial Decrease
- Learning rate too high
- No learning rate decay
- Training on corrupted data
- Catastrophic forgetting (if fine-tuning)
Validation Loss Diverges from Training Loss
- Overfitting: add regularization, get more data
- Training/validation distribution mismatch
- Data leakage in validation set
Loss Becomes NaN
- Numerical instability: use log-sum-exp tricks
- Learning rate explosion
- Division by zero or log of zero
- Invalid inputs (NaN in data)
```python
# Debugging NaN losses
def check_for_nans(name):
    def hook(module, input, output):
        if torch.isnan(output).any():
            print(f"NaN detected in {name}")
            print(f"  Input range: [{input[0].min():.4f}, {input[0].max():.4f}]")
            print(f"  Input has NaN: {torch.isnan(input[0]).any()}")
    return hook

# Register on all layers
for name, module in model.named_modules():
    module.register_forward_hook(check_for_nans(name))

# Common NaN causes:
# 1. log(0) - add epsilon: log(x + 1e-10)
# 2. exp(large) - use log-sum-exp: torch.logsumexp()
# 3. Division by zero - add epsilon to denominator
# 4. sqrt of negative - use relu: sqrt(relu(x) + 1e-10)
```
Interactive: Loss Curve Debugger
This interactive visualization simulates different training scenarios. Observe how various problems manifest in the loss curve and experiment with fixes:
Try different scenarios:
- Normal training: See what a healthy loss curve looks like
- Learning rate too high: Observe oscillation and potential divergence
- Learning rate too low: See the painfully slow progress
- Overfitting: Watch validation loss diverge from training
- Gradient explosion: See the sudden NaN/explosion
Weight and Activation Analysis
Monitoring Weight Statistics
Track how weights evolve during training. Healthy training shows stable weight distributions:
What to Look For
| Observation | Possible Problem | Action |
|---|---|---|
| Weight std → 0 | Weights collapsing | Check for gradient issues |
| Weight std → ∞ | Weights exploding | Add regularization, lower LR |
| Weights not changing | No gradients, LR too low | Check gradient flow |
| Gradient norm → 0 | Vanishing gradients | Fix architecture, initialization |
| Gradient norm → ∞ | Exploding gradients | Clip gradients, lower LR |
| Some layers frozen | Selective gradient issues | Check layer connectivity |
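A minimal snapshot helper for the table above (our own utility, not a PyTorch API); call it every few hundred steps and log the numbers over time.

```python
import torch
import torch.nn as nn

def weight_stats(model):
    """Per-parameter mean, std, and grad L2 norm; std near 0 or huge is a red flag."""
    stats = {}
    for name, p in model.named_parameters():
        grad_norm = p.grad.norm(2).item() if p.grad is not None else float('nan')
        stats[name] = {'mean': p.mean().item(),
                       'std': p.std().item(),
                       'grad_norm': grad_norm}
    return stats
```

Plotting these three numbers per layer against training step makes collapsing weights, exploding weights, and frozen layers visible at a glance.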
Activation Analysis
Activations (layer outputs) reveal what your model is actually computing. Healthy activations have these properties:
- Non-zero: At least some neurons fire for most inputs
- Bounded: Values don't grow unboundedly through layers
- Diverse: Different neurons respond to different inputs
```python
class ActivationAnalyzer:
    """Analyze activations to diagnose network health."""

    def __init__(self, model):
        self.model = model
        self.activations = {}
        self.hooks = []

        # Register hooks on all layers
        for name, module in model.named_modules():
            if len(list(module.children())) == 0:  # Leaf modules only
                hook = module.register_forward_hook(self._create_hook(name))
                self.hooks.append(hook)

    def _create_hook(self, name):
        def hook(module, input, output):
            if isinstance(output, torch.Tensor):
                self.activations[name] = {
                    'mean': output.mean().item(),
                    'std': output.std().item(),
                    'zero_fraction': (output == 0).float().mean().item(),
                    'negative_fraction': (output < 0).float().mean().item(),
                    'shape': output.shape,
                }
        return hook

    def analyze(self, x):
        """Run input through model and analyze activations."""
        with torch.no_grad():
            _ = self.model(x)
        return self.activations

    def print_summary(self):
        """Print activation summary."""
        for name, stats in self.activations.items():
            print(f"{name}:")
            print(f"  Shape: {stats['shape']}")
            print(f"  Mean: {stats['mean']:.4f}, Std: {stats['std']:.4f}")
            print(f"  Zero: {stats['zero_fraction']:.1%}, Negative: {stats['negative_fraction']:.1%}")
            print()
```
Gradient Debugging Techniques
Numerical Gradient Checking
When you suspect your gradients are wrong, verify them numerically using finite differences. This is slow but gives ground truth:
Gradient Checking Pitfalls
- Gradient checking is extremely slow (two forward passes per parameter for central differences)
- Disable dropout and put batch norm in eval mode during checking; stochastic or state-dependent layers break the comparison
- Use double precision (float64) for more accurate comparisons
- Relative error > 1e-4 usually indicates a bug
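A minimal central-difference checker that follows the precautions above (double precision, deterministic function). Here `f` maps a parameter tensor to a scalar loss, and the relative-error formula is the usual symmetric one; this is a sketch, not a replacement for `torch.autograd.gradcheck`.

```python
import torch

def grad_check(f, x, eps=1e-6):
    """Relative error between autograd and central-difference gradients of f at x."""
    x = x.detach().double().requires_grad_(True)
    f(x).backward()
    analytic = x.grad.clone()

    numeric = torch.zeros_like(x)
    with torch.no_grad():
        flat, num = x.view(-1), numeric.view(-1)
        for i in range(flat.numel()):
            orig = flat[i].item()
            flat[i] = orig + eps   # perturb one coordinate up
            plus = f(x).item()
            flat[i] = orig - eps   # and down
            minus = f(x).item()
            flat[i] = orig         # restore
            num[i] = (plus - minus) / (2 * eps)

    return ((analytic - numeric).norm()
            / (analytic.norm() + numeric.norm() + 1e-12)).item()
```

For a real layer, wrap its forward pass in `f` over one parameter tensor at a time; anything above roughly 1e-4 relative error deserves investigation.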
Using Hooks for Gradient Analysis
```python
# Register hooks to analyze gradients during backprop
gradient_info = {}

def save_gradients(name):
    def hook(grad):
        gradient_info[name] = {
            'norm': grad.norm().item(),
            'mean': grad.mean().item(),
            'std': grad.std().item(),
            'has_nan': torch.isnan(grad).any().item(),
            'has_inf': torch.isinf(grad).any().item(),
        }
        return grad  # Return unchanged
    return hook

# Register on parameters
for name, param in model.named_parameters():
    param.register_hook(save_gradients(name))

# After backward:
loss.backward()
for name, info in gradient_info.items():
    if info['has_nan'] or info['has_inf']:
        print(f"❌ {name}: NaN={info['has_nan']}, Inf={info['has_inf']}")
    else:
        print(f"✓ {name}: norm={info['norm']:.2e}")
```
Common Bugs and Fixes
A catalog of bugs that have tripped up countless practitioners:
1. Forgetting model.train() / model.eval()
```python
# BUG: Running training in eval mode
model.eval()  # Oops, forgot to switch back
for batch in train_loader:
    loss = ...
    loss.backward()  # Dropout is disabled! BatchNorm uses wrong stats!

# FIX: Always set mode explicitly
def train_epoch(model, ...):
    model.train()  # ← Always do this first
    ...

def evaluate(model, ...):
    model.eval()  # ← Always do this first
    with torch.no_grad():
        ...
```
2. Not Zeroing Gradients
```python
# BUG: Gradients accumulate across batches
for batch in dataloader:
    loss = model(batch)
    loss.backward()
    optimizer.step()  # Using accumulated gradients!

# FIX: Zero gradients before backward
for batch in dataloader:
    optimizer.zero_grad()  # ← Add this
    loss = model(batch)
    loss.backward()
    optimizer.step()
```
3. Inplace Operations Breaking Autograd
```python
# BUG: Inplace operations can break gradient computation
x = model.layer1(input)
x += bias  # Inplace! Might break autograd

# Also problematic:
x[:, 0] = 0  # Inplace indexing assignment
x.add_(1)  # Inplace add

# FIX: Use out-of-place operations
x = model.layer1(input)
x = x + bias  # Create new tensor

# Or explicitly clone first:
x = x.clone()
x[:, 0] = 0  # Now safe
```
4. Data Leakage
```python
# BUG: Normalizing before splitting
data = load_data()
data = (data - data.mean()) / data.std()  # Uses ALL data
train, val = split(data)  # Validation stats leaked into normalization!

# FIX: Compute statistics on training set only
train, val = split(data)
train_mean, train_std = train.mean(), train.std()
train = (train - train_mean) / train_std
val = (val - train_mean) / train_std  # Use TRAINING stats
```
5. Wrong Loss Function
```python
# BUG: Using wrong loss for the task
# For multi-class classification with logits:
loss = nn.BCELoss()(model(x), y)  # Wrong! BCE is for binary probabilities

# For binary classification with a single-logit output:
loss = nn.CrossEntropyLoss()(model(x), y)  # Wrong! CE expects one logit per class

# FIX: Match loss to output format
# Multi-class with logits → CrossEntropyLoss (expects raw logits, integer labels)
# Multi-class with log-probabilities → NLLLoss after log_softmax
# Binary with logits → BCEWithLogitsLoss
# Binary with probabilities → BCELoss
# Regression → MSELoss or L1Loss
```
6. Dimension Mismatch Hidden by Broadcasting
```python
# BUG: Shapes look compatible but semantics are wrong
labels = torch.tensor([0, 1, 2])  # Shape: (3,)
predictions = model(x)  # Shape: (3, 10) - logits for 10 classes

# This doesn't crash but computes wrong loss!
loss = (predictions - labels) ** 2  # Broadcasting: (3, 10) - (3,) → (3, 10)

# FIX: Always verify shapes explicitly
print(f"Predictions: {predictions.shape}")  # Should be (batch, classes)
print(f"Labels: {labels.shape}")  # Should be (batch,)
loss = nn.CrossEntropyLoss()(predictions, labels)
```
7. Learning Rate Issues After Loading
```python
# BUG: Creating new optimizer after loading, losing LR schedule state
model.load_state_dict(torch.load('model.pt'))
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Resets LR!

# FIX: Load optimizer state too
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
```
Quick Check
Your training loss decreases but validation loss stays flat from epoch 1. What's the most likely cause?
The Debugging Checklist
When your model isn't training properly, work through this checklist systematically:
Phase 1: Data Verification
- Visualize raw data: Can you see what you expect? Are images right-side up? Are labels correct?
- Check data statistics: Mean, std, min, max. Are they reasonable? Any NaN/Inf?
- Verify labels: Check class distribution. Are labels in the right format (indices vs one-hot)?
- Check data loading: Are batches shuffled? Is data augmentation working correctly?
Phase 2: Model Sanity Checks
- Output shape: Does model output match expected shape for the task?
- Initial loss: Is it close to theoretical random baseline?
- Overfit single batch: Can model memorize 1-10 examples?
- Parameter count: Does the model have enough capacity?
Phase 3: Gradient Checks
- Gradient flow: Do all parameters have non-zero gradients?
- Gradient magnitude: Are gradients reasonable (not vanishing or exploding)?
- Gradient direction: Does numerical check match analytical gradients?
Phase 4: Training Dynamics
- Learning rate: Try 10x higher and 10x lower. Which works best?
- Loss curve: Is it decreasing? Oscillating? Plateauing?
- Weight updates: Are weights actually changing? In a reasonable range?
- Activation distributions: Are neurons dying? Saturating?
Phase 5: Comparison and Baselines
- Simple baseline: Does logistic regression or random forest work on this data?
- Known implementation: Does a reference implementation achieve expected results?
- Simpler model: Can a smaller/simpler model learn anything?
- Easier task: Can the model solve a simplified version of the problem?
Document Everything
Keep a debugging log: record each hypothesis, the single change you made, and the result. A log prevents re-running failed experiments and makes the eventual fix reproducible.
Related Topics
- Chapter 8 Section 6: Gradient Flow Analysis - Theory behind vanishing/exploding gradients, skip connections, and gradient clipping
- Section 4: Weight Initialization - Proper initialization prevents many gradient issues from the start
- Section 5: Regularization - Early stopping, dropout, and other techniques to prevent overfitting
Summary
Debugging neural networks is challenging but systematic approaches make it manageable. Here's what we covered:
| Topic | Key Points |
|---|---|
| Debugging Mindset | Scientific method: hypothesize → test → analyze. Change one thing at a time. |
| Sanity Checks | Initial loss should match theory. Must be able to overfit single batch. |
| Vanishing Gradients | Early layers stop learning. Fix: ReLU, skip connections, proper init. |
| Exploding Gradients | Loss becomes NaN/Inf. Fix: gradient clipping, lower LR, proper init. |
| Dead Neurons | ReLU outputs always zero. Fix: LeakyReLU, lower LR, positive bias init. |
| Loss Curve Analysis | Primary diagnostic tool. Learn the patterns of common problems. |
| Weight Monitoring | Track mean, std, gradient norms over time to spot issues early. |
| Gradient Checking | Numerical verification of analytical gradients for debugging custom layers. |
| Common Bugs | train/eval mode, zero_grad, inplace ops, data leakage, wrong loss. |
The Golden Rules of Neural Network Debugging
- Start simple: Get a tiny model working on a tiny dataset first
- Verify everything: Never assume something works—test it
- Change one thing: Isolate variables to identify causes
- Monitor constantly: Track losses, gradients, weights, activations
- Compare to baselines: Know what "working" looks like
Exercises
Conceptual Questions
- Your 50-layer CNN trains perfectly on MNIST but fails on CIFAR-10 with the same architecture. The loss stays flat. What are three possible causes and how would you diagnose each?
- Explain why `optimizer.zero_grad()` must be called before `loss.backward()` but `optimizer.step()` can be called multiple times without zeroing. What use case does gradient accumulation serve?
- Why might numerical gradient checking pass for a layer but the layer still has a bug in practice? Give two scenarios.
- A colleague says "my model isn't learning because the loss is stuck at 2.3 for a 10-class problem." What would you tell them?
Solution Hints
- CIFAR-10 is harder: consider vanishing gradients in deep networks, learning rate, and whether simple features suffice.
- Gradients accumulate by design. Multiple steps without zeroing enables gradient accumulation for larger effective batch sizes.
- Gradient check passes: 1) Bug only manifests with specific input patterns, 2) Bug is in training dynamics, not forward/backward math.
- 2.3 is the expected random loss (-ln(1/10)). The model hasn't learned anything yet—this is the starting point, not a plateau.
Coding Exercises
- Build a debugging dashboard: Create a class that tracks and visualizes training loss, validation loss, gradient norms, weight norms, and learning rate over time. Update it after each batch.
- Implement dead neuron detection: Write a function that identifies layers where more than 50% of neurons never fire (always output zero) across a validation set.
- Create a gradient anomaly detector: Implement a system using hooks that automatically pauses training if gradient norms exceed a threshold or become NaN, and logs the problematic layer.
- Build an overfitting test suite: Create a function that automatically runs the "overfit single batch" test, varying batch sizes and learning rates, and reports whether the model can memorize small amounts of data.
Coding Exercise Hints
- Exercise 1: Use matplotlib with live updates (plt.ion()). Store history in lists and plot after each epoch.
- Exercise 2: Use forward hooks to capture activations. Count zero outputs across many forward passes.
- Exercise 3: Use backward hooks on parameters. Check grad.norm() against thresholds.
- Exercise 4: Loop over batch_size in [1, 2, 4, 8] and lr in [0.1, 0.01, 0.001]. For each, run 1000 steps and check final loss.
Congratulations! You've completed Chapter 9 on Training Neural Networks. You now understand the complete training process: from the training loop itself, through optimizers and learning rate schedules, to weight initialization, regularization, hyperparameter tuning, and debugging. In the next chapter, we'll apply this knowledge to build and train Convolutional Neural Networks for image processing.