Introduction
Training neural networks can fail in many ways. This section covers how to monitor training, diagnose problems, and debug common issues.
3.1 Metrics to Monitor
Essential Metrics
```python
import math
from typing import Dict, List, Optional

import torch


class TrainingMonitor:
    """
    Monitor training progress and detect issues.

    Tracks:
    - Loss trends (train and validation)
    - Gradient norm
    - Learning rate
    - Training speed (time per batch)
    """

    def __init__(self):
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'grad_norm': [],
            'learning_rate': [],
            'batch_time': [],
        }
        self.warnings = []

    def update(
        self,
        train_loss: float,
        val_loss: Optional[float] = None,
        grad_norm: Optional[float] = None,
        lr: Optional[float] = None,
        batch_time: Optional[float] = None
    ):
        """Record the latest metrics, then run the issue checks."""
        self.history['train_loss'].append(train_loss)
        if val_loss is not None:
            self.history['val_loss'].append(val_loss)
        if grad_norm is not None:
            self.history['grad_norm'].append(grad_norm)
        if lr is not None:
            self.history['learning_rate'].append(lr)
        if batch_time is not None:
            self.history['batch_time'].append(batch_time)

        # Run checks
        self._check_for_issues()

    def _check_for_issues(self):
        """Check for common training issues."""
        losses = self.history['train_loss']

        # Check for loss explosion (loss more than doubled since the last update)
        if len(losses) >= 2:
            recent = losses[-1]
            previous = losses[-2]

            if recent > previous * 2:
                self.warnings.append(
                    f"Loss spike detected: {previous:.4f} → {recent:.4f}"
                )

        # Check for NaN
        if losses and math.isnan(losses[-1]):
            self.warnings.append("NaN loss detected!")

        # Check for gradient issues (same thresholds as the gradient-norm guidance below)
        if self.history['grad_norm']:
            grad = self.history['grad_norm'][-1]
            if grad > 100:
                self.warnings.append(f"Large gradient norm: {grad:.2f}")
            elif grad < 1e-6:
                self.warnings.append(f"Vanishing gradients: {grad:.2e}")

    def get_summary(self) -> Dict:
        """Get training summary."""
        val_losses = self.history['val_loss']
        train_losses = self.history['train_loss']
        batch_times = self.history['batch_time']
        return {
            # Assumes validation runs once per epoch
            'epochs': len(val_losses),
            'best_val_loss': min(val_losses) if val_losses else None,
            'final_train_loss': train_losses[-1] if train_losses else None,
            'avg_batch_time': sum(batch_times) / len(batch_times) if batch_times else None,
            'warnings': self.warnings,
        }
```
```text
METRICS TO MONITOR:

1. TRAINING LOSS:
   --------------
   What: Average cross-entropy loss per token
   Good: Steadily decreasing
   Bad:  Spikes, plateaus, or NaN

   Typical values (Multi30k):
     Epoch 1:  ~6.0
     Epoch 10: ~2.5
     Epoch 30: ~2.0


2. VALIDATION LOSS:
   ----------------
   What: Loss on held-out data
   Good: Tracks training loss (slightly higher)
   Bad:  Diverges from training loss (overfitting)

   Gap analysis:
     train=2.0, val=2.1 → Good
     train=1.5, val=3.0 → Overfitting!


3. PERPLEXITY:
   -----------
   What: exp(loss)
   Good: Decreasing, eventually ~10-20
   Bad:  Very high (>100 after warmup)

   Interpretation:
     PPL=10:  Model is about as uncertain as a uniform choice among ~10 tokens
     PPL=100: Model is very uncertain


4. GRADIENT NORM:
   --------------
   What: Global L2 norm over all parameter gradients
   Good: Stable, typically 0.1-10
   Bad:  Exploding (>100) or vanishing (<1e-6)

   Code:
     # clip_grad_norm_ returns the total norm measured before clipping
     total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm)


5. LEARNING RATE:
   --------------
   What: Current learning rate
   Good: Follows the warmup schedule
   Bad:  Too high (unstable) or too low (slow)


6. TOKEN ACCURACY:
   ---------------
   What: % of tokens predicted correctly
   Good: Increasing, eventually ~60-70%
   Bad:  Stuck at low values


7. MEMORY USAGE:
   -------------
   What: GPU memory consumed
   Good: Stable
   Bad:  Growing (memory leak)

   Code:
     torch.cuda.memory_allocated() / 1e9  # GB
```
3.2 Common Training Issues
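Token accuracy (metric 6 in Section 3.1) is easy to get subtly wrong, because padding positions inflate or deflate the number if they are counted. The sketch below shows one way to compute it while masking padding. The function name `token_accuracy` and the default `pad_id=0` are illustrative assumptions, not part of the chapter's trainer; substitute your tokenizer's padding index.

```python
import torch


def token_accuracy(logits: torch.Tensor, targets: torch.Tensor, pad_id: int = 0) -> float:
    """Fraction of non-padding target tokens whose argmax prediction is correct.

    logits:  [batch, seq, vocab]
    targets: [batch, seq]
    pad_id:  hypothetical padding index (assumption; use your tokenizer's value)
    """
    predictions = logits.argmax(dim=-1)      # [batch, seq]
    mask = targets != pad_id                 # True where the target is a real token
    total = mask.sum().item()
    if total == 0:
        return 0.0                           # batch contains only padding
    correct = ((predictions == targets) & mask).sum().item()
    return correct / total


# Tiny example: one sentence, four positions, vocabulary of three.
# The last position is padding and must not count toward the score.
logits = torch.tensor([[[0.1, 2.0, 0.3],
                        [0.2, 0.1, 3.0],
                        [1.5, 0.2, 0.1],
                        [9.0, 0.0, 0.0]]])
targets = torch.tensor([[1, 2, 1, 0]])

accuracy = token_accuracy(logits, targets, pad_id=0)
print(f"Token accuracy: {accuracy:.2%}")
```

Here two of the three real tokens are predicted correctly, so the padded position has no effect on the result.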
Diagnosing Problems
```text
ISSUE 1: LOSS NOT DECREASING
----------------------------

Symptoms:
  - Loss stays flat or increases
  - Validation loss also flat

Causes & Solutions:
  - Learning rate too high → Reduce by 10x
  - Learning rate too low → Increase by 10x
  - Bad initialization → Check init_std
  - Bug in data → Inspect batches manually
  - Gradient clipping too aggressive → Increase max_norm
  - Model too small → Increase capacity

Debugging:
  # Check gradients are flowing (run after loss.backward())
  for name, param in model.named_parameters():
      if param.grad is None:
          print(f"{name}: no gradient")
      else:
          # Use the absolute mean: a signed mean can cancel to ~0
          print(f"{name}: grad_abs_mean={param.grad.abs().mean():.6f}")


ISSUE 2: NaN LOSS
-----------------

Symptoms:
  - Loss becomes NaN suddenly
  - Model outputs NaN

Causes & Solutions:
  - Learning rate too high → Reduce significantly
  - No gradient clipping → Add clipping (max_norm=1.0)
  - Log(0) in loss → Add small epsilon
  - Division by zero → Check normalization
  - Exploding activations → Add layer norm

Debugging:
  # Find which layer produces NaN.
  # A factory function binds each hook to its own layer name; defining the
  # hook directly in the loop would make every hook report the last name.
  def make_hook(name):
      def hook(module, inputs, output):
          # Some modules return tuples; only check plain tensor outputs
          if torch.is_tensor(output) and torch.isnan(output).any():
              print(f"NaN in {name}")
      return hook

  for name, module in model.named_modules():
      module.register_forward_hook(make_hook(name))


ISSUE 3: OVERFITTING
--------------------

Symptoms:
  - Train loss decreases
  - Val loss increases or plateaus
  - Large gap between train and val

Causes & Solutions:
  - Not enough regularization → Increase dropout
  - Too many epochs → Use early stopping
  - Model too large → Reduce capacity
  - Not enough data → Add augmentation

Debugging:
  # Plot train vs val loss
  # Look for divergence point


ISSUE 4: SLOW CONVERGENCE
-------------------------

Symptoms:
  - Loss decreases very slowly
  - Need many epochs

Causes & Solutions:
  - Learning rate too low → Increase
  - Warmup too long → Reduce warmup_steps
  - Batch size too small → Increase max_tokens
  - Poor initialization → Use better init

Debugging:
  # Check learning rate schedule.
  # Run this on a throwaway optimizer/scheduler pair: stepping advances the
  # real schedule. PyTorch may warn that optimizer.step() was not called first.
  for step in range(1, 10001):
      scheduler.step()
      if step % 1000 == 0:
          print(f"Step {step}: LR={optimizer.param_groups[0]['lr']:.2e}")


ISSUE 5: UNSTABLE TRAINING
--------------------------

Symptoms:
  - Loss oscillates wildly
  - Spikes in loss

Causes & Solutions:
  - Learning rate too high → Reduce
  - Batch size too small → Increase
  - No warmup → Add warmup phase
  - Gradient clipping threshold wrong → Adjust max_norm

Debugging:
  # Monitor gradient norms (global L2 norm across all parameters)
  total_norm = 0.0
  for p in model.parameters():
      if p.grad is not None:
          total_norm += p.grad.norm(2).item() ** 2
  total_norm = total_norm ** 0.5
  print(f"Gradient norm: {total_norm:.4f}")
```
3.3 Debugging Tools
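The learning-rate check under Issue 4 steps a live scheduler, which changes its state. A side-effect-free alternative is to evaluate the schedule formula directly. The sketch below assumes the inverse-square-root warmup schedule from "Attention Is All You Need"; that is an assumption, since the schedule used by this chapter's trainer is not shown here, and the names `inverse_sqrt_lr` and `preview_schedule` are illustrative.

```python
from typing import List, Tuple


def inverse_sqrt_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Learning rate at `step` (1-indexed): linear warmup, then 1/sqrt(step) decay.

    ASSUMPTION: this is the Transformer-paper schedule; replace the formula
    with whatever your scheduler actually implements.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


def preview_schedule(total_steps: int, every: int, **kwargs) -> List[Tuple[int, float]]:
    """Return (step, lr) pairs sampled every `every` steps, without touching an optimizer."""
    return [
        (step, inverse_sqrt_lr(step, **kwargs))
        for step in range(every, total_steps + 1, every)
    ]


schedule = preview_schedule(total_steps=10000, every=1000, warmup_steps=4000)
for step, lr in schedule:
    print(f"Step {step:>5}: LR={lr:.2e}")
```

With these settings the printed rate should rise until step 4000 and fall afterwards; a peak that arrives much later than expected is one sign that warmup is too long.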
Practical Debugging Functions
```python
from typing import Dict

import torch


def gradient_check(model: torch.nn.Module) -> Dict[str, Dict[str, float]]:
    """
    Check gradient statistics for each layer.

    Call after loss.backward(). Returns per-parameter statistics about
    gradient flow; parameters without a gradient are skipped.
    """
    stats = {}

    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad.detach()
            stats[name] = {
                'mean': grad.mean().item(),
                'std': grad.std().item(),
                'max': grad.max().item(),
                'min': grad.min().item(),
                'has_nan': torch.isnan(grad).any().item(),
            }

    return stats


def print_gradient_stats(model: torch.nn.Module, top_k: int = 10):
    """
    Print gradient statistics for the top_k parameters by absolute mean gradient.
    """
    print("\nGradient Statistics:")
    print("-" * 70)
    print(f"{'Layer':<40} {'Mean':>10} {'Std':>10} {'Max':>10}")
    print("-" * 70)

    stats = gradient_check(model)

    # Sort by absolute mean
    sorted_stats = sorted(
        stats.items(),
        key=lambda x: abs(x[1]['mean']),
        reverse=True
    )

    for name, s in sorted_stats[:top_k]:
        # Keep the tail of long names, which usually identifies the layer
        name_short = name[-38:] if len(name) > 40 else name
        print(f"{name_short:<40} {s['mean']:>10.2e} {s['std']:>10.2e} {s['max']:>10.2e}")


def check_for_dead_neurons(model: torch.nn.Module, threshold: float = 1e-6):
    """
    Count output units whose weight vector has near-zero norm.

    This inspects weights only, so it is a proxy: it does not observe
    activations and cannot detect units that are dead for other reasons.
    """
    dead_count = 0
    total_count = 0

    for name, param in model.named_parameters():
        if 'weight' in name and param.dim() >= 2:
            # One norm per output unit (row), over all remaining dimensions
            neuron_norms = param.detach().flatten(1).norm(dim=1)
            dead = (neuron_norms < threshold).sum().item()
            dead_count += dead
            total_count += neuron_norms.numel()

    if total_count > 0:
        print(f"Dead neurons: {dead_count}/{total_count} ({100*dead_count/total_count:.2f}%)")


def memory_report():
    """
    Print GPU memory usage.
    """
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        max_allocated = torch.cuda.max_memory_allocated() / 1e9

        print("\nGPU Memory:")
        print(f"  Allocated: {allocated:.2f} GB")
        print(f"  Reserved:  {reserved:.2f} GB")
        print(f"  Max alloc: {max_allocated:.2f} GB")
    else:
        print("No GPU available")


def debug_batch(batch: Dict[str, torch.Tensor]):
    """
    Print shape, dtype, device, and value range for each tensor in a batch.
    """
    print("\nBatch Debug Info:")
    print("-" * 40)

    for key, tensor in batch.items():
        if isinstance(tensor, torch.Tensor):
            print(f"{key}:")
            print(f"  Shape: {tensor.shape}")
            print(f"  Dtype: {tensor.dtype}")
            print(f"  Device: {tensor.device}")
            print(f"  Min/Max: {tensor.min().item():.4f} / {tensor.max().item():.4f}")
            # Covers float16, bfloat16, float32, and float64
            if tensor.is_floating_point():
                print(f"  Has NaN: {torch.isnan(tensor).any().item()}")
```
3.4 Training Visualization
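Per-step losses are usually noisy, which can hide the trend a plot is meant to reveal. One common remedy, sketched below, is to smooth the curve with a bias-corrected exponential moving average before plotting. The helper name `smooth_losses` and the smoothing factor `beta` are illustrative assumptions rather than part of the chapter's code.

```python
from typing import List


def smooth_losses(losses: List[float], beta: float = 0.9) -> List[float]:
    """Bias-corrected exponential moving average of a loss curve.

    beta is a hypothetical smoothing factor in [0, 1); higher means smoother.
    """
    if not 0.0 <= beta < 1.0:
        raise ValueError("beta must be in [0, 1)")

    smoothed = []
    average = 0.0
    for i, loss in enumerate(losses, start=1):
        average = beta * average + (1 - beta) * loss
        # Divide by (1 - beta^i) so early values are not biased toward zero
        smoothed.append(average / (1 - beta ** i))
    return smoothed


noisy = [6.0, 5.2, 5.9, 4.8, 5.1, 4.0, 4.6, 3.9]
smoothed = smooth_losses(noisy, beta=0.5)

for raw, smooth in zip(noisy, smoothed):
    print(f"raw={raw:.2f}  smoothed={smooth:.2f}")
```

The smoothed list has the same length as the input, so it can be passed to a plotting function in place of the raw losses.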
Plotting Training Progress
```python
import random
from typing import List


def plot_training_progress_ascii(
    train_losses: List[float],
    val_losses: List[float],
    height: int = 15,
    width: int = 60
):
    """
    ASCII plot of training progress.

    Markers: '*' = train, 'o' = val, '@' = both on the same cell.
    """
    print("\nTraining Progress")
    print("=" * (width + 10))

    if not train_losses:
        print("No data to plot")
        return

    # Combine for scale
    all_losses = train_losses + val_losses
    min_loss = min(all_losses)
    max_loss = max(all_losses)
    loss_range = max_loss - min_loss if max_loss > min_loss else 1

    # Create plot grid (row 0 is the top, i.e. the highest loss)
    grid = [[' ' for _ in range(width)] for _ in range(height)]

    # Plot train losses
    for i, loss in enumerate(train_losses):
        x = int(i / len(train_losses) * (width - 1))
        y = int((1 - (loss - min_loss) / loss_range) * (height - 1))
        y = max(0, min(height - 1, y))
        grid[y][x] = '*'  # Train point

    # Plot val losses
    for i, loss in enumerate(val_losses):
        x = int(i / len(val_losses) * (width - 1))
        y = int((1 - (loss - min_loss) / loss_range) * (height - 1))
        y = max(0, min(height - 1, y))
        if grid[y][x] == '*':
            grid[y][x] = '@'  # Both
        else:
            grid[y][x] = 'o'  # Val point

    # Print plot with a loss label on each row
    for i, row in enumerate(grid):
        loss_label = max_loss - (i / (height - 1)) * loss_range
        print(f"{loss_label:6.2f} |{''.join(row)}|")

    print(f"       +{'-' * width}+")
    print(f"        0{' ' * (width - 10)}epochs")
    print("\n* Train  o Val  @ Both")


def demonstrate_visualization():
    """
    Demonstrate training visualization on synthetic loss curves.
    """
    # Generate sample data: geometric decay plus a little noise
    train_losses = [6.0]
    val_losses = [5.5]

    for _ in range(29):
        train_losses.append(train_losses[-1] * 0.92 + random.random() * 0.1)
        val_losses.append(val_losses[-1] * 0.93 + random.random() * 0.15)

    plot_training_progress_ascii(train_losses, val_losses)


if __name__ == "__main__":
    demonstrate_visualization()
```
3.5 Diagnostic Checklist
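The quick sanity test in this section's checklist, overfitting a single batch, relies on a `train_step` function that is not defined here. The sketch below is a self-contained version that can be run on CPU. Everything in it is a stand-in chosen for illustration (the tiny embedding model, the synthetic batch whose target is the input token plus one, the Adam optimizer, and the learning rate); swap in the real model, batch, and trainer step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real model and data (assumptions)
vocab_size, seq_len, batch_size, hidden = 20, 5, 4, 32
model = nn.Sequential(
    nn.Embedding(vocab_size, hidden),
    nn.Linear(hidden, vocab_size),
)

# A learnable toy task: the target is the input token id plus one
inputs = torch.randint(1, vocab_size - 1, (batch_size, seq_len))
targets = inputs + 1

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()


def train_step(inputs: torch.Tensor, targets: torch.Tensor) -> float:
    """One optimization step on a fixed batch; returns the loss as a float."""
    model.train()
    optimizer.zero_grad()
    logits = model(inputs)  # [batch, seq, vocab]
    loss = loss_fn(logits.view(-1, vocab_size), targets.view(-1))
    loss.backward()
    optimizer.step()
    return loss.item()


# Repeat the same batch; a model that can learn should drive the loss down
losses = [train_step(inputs, targets) for _ in range(100)]
print(f"First loss: {losses[0]:.4f}")
print(f"Final loss: {losses[-1]:.4f}")
```

If the final loss is not far below the first one on a batch this small, the problem is more likely in the model, loss, or optimizer wiring than in the data volume or hyperparameters.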
Pre-Training Checklist
```text
PRE-TRAINING CHECKLIST:

DATA:
[ ] Tokenizer loaded correctly
[ ] Data shapes match expectations
[ ] No NaN/Inf in data
[ ] Special tokens correct (PAD, BOS, EOS)
[ ] Data loader produces batches

MODEL:
[ ] Model on correct device (GPU)
[ ] Model parameters initialized
[ ] Forward pass works without error
[ ] Output shape is [batch, seq, vocab]

LOSS:
[ ] Loss computes without error
[ ] Loss is finite (not NaN/Inf)
[ ] Padding properly ignored

OPTIMIZER:
[ ] All parameters registered
[ ] Learning rate is reasonable
[ ] Weight decay applied correctly

TRAINING:
[ ] Backward pass completes
[ ] Gradients are non-zero
[ ] Gradient clipping active
[ ] Loss decreases on first batch (sanity check)


QUICK SANITY TEST:
------------------

# Overfit one batch
model.train()
batch = next(iter(train_loader))

for _ in range(100):
    loss = train_step(batch)
    print(f"Loss: {loss:.4f}")

# Loss should decrease to ~0 if model can learn


POST-EPOCH CHECKLIST:

AFTER EACH EPOCH, CHECK:
------------------------

[ ] Training loss decreased
[ ] Validation loss reasonable
[ ] No NaN in losses
[ ] Learning rate schedule correct
[ ] Checkpoint saved
[ ] Memory usage stable

IF TRAIN LOSS NOT DECREASING:
- Check learning rate (too high/low?)
- Check gradients (vanishing/exploding?)
- Check data (correct format?)
- Try overfitting one batch

IF VAL LOSS DIVERGING FROM TRAIN:
- Increase dropout
- Add weight decay
- Reduce model size
- Early stopping

IF TRAINING TOO SLOW:
- Use mixed precision
- Increase batch size
- Profile for bottlenecks
- Check data loading
```
Summary
Monitoring Priorities
| Priority | Metric | Action if Bad |
|---|---|---|
| 1 | Loss (NaN) | Stop immediately, debug |
| 2 | Train loss trend | Adjust LR, check data |
| 3 | Val loss gap | Add regularization |
| 4 | Gradient norms | Clip gradients |
| 5 | Memory | Profile, reduce batch |
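Two of these responses lend themselves to automation: halting as soon as the loss stops being finite (priority 1), and early stopping, which Section 3.2 lists as a remedy for overfitting. The sketch below is one minimal way to do both. The class name `TrainingGuard` and the `patience` and `min_delta` defaults are illustrative assumptions, not part of the chapter's trainer.

```python
import math


class TrainingGuard:
    """Decide when training should stop.

    patience and min_delta are hypothetical defaults (assumptions).
    """

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_val_loss = math.inf
        self.epochs_without_improvement = 0

    def check_step(self, train_loss: float) -> None:
        """Raise immediately if the training loss is NaN or infinite."""
        if not math.isfinite(train_loss):
            raise FloatingPointError(f"Non-finite training loss: {train_loss}")

    def should_stop_early(self, val_loss: float) -> bool:
        """Return True once val loss has failed to improve for `patience` epochs."""
        if val_loss < self.best_val_loss - self.min_delta:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Usage on a made-up validation curve that bottoms out at epoch 3
guard = TrainingGuard(patience=2)
val_history = [3.0, 2.5, 2.4, 2.45, 2.6, 2.7]

stopped_at = None
for epoch, val_loss in enumerate(val_history, start=1):
    if guard.should_stop_early(val_loss):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best val loss {guard.best_val_loss}")
```

In a real loop, `check_step` would be called on every batch and `should_stop_early` once per validation pass, with the best checkpoint restored after stopping.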
Quick Debug Commands
```python
# Check for NaN
torch.isnan(loss).any()

# Gradient norm (max_norm=inf returns the total norm without clipping)
torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))

# Memory usage (GB)
torch.cuda.memory_allocated() / 1e9

# Learning rate
optimizer.param_groups[0]['lr']
```
Chapter Summary
In this chapter, we covered the complete training pipeline:
- Configuration: Model and training configs
- Training Script: Complete trainer implementation
- Monitoring: Metrics, debugging, visualization
Expected results after training:
- Val loss: ~2.2-2.5
- Val PPL: ~9-12
- Token accuracy: ~65-70%
- BLEU (next chapter): ~30-35
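The validation loss and perplexity ranges above describe the same quantity, since perplexity is exp(loss). The short check below makes the conversion explicit. It assumes the loss is a mean per-token cross-entropy measured in nats (natural logarithm), which is what PyTorch's cross-entropy loss produces; the helper name `perplexity` is illustrative.

```python
import math


def perplexity(mean_token_loss: float) -> float:
    """Perplexity for a mean per-token cross-entropy loss measured in nats."""
    return math.exp(mean_token_loss)


for loss in (2.2, 2.5):
    print(f"loss={loss:.1f} -> PPL={perplexity(loss):.1f}")
```

A loss of 2.2 corresponds to a perplexity of about 9.0 and a loss of 2.5 to about 12.2, matching the expected range.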
Exercises
Debugging
- Intentionally break training (e.g., high LR) and practice debugging.
- Implement automatic NaN detection with training halt.
- Create a training dashboard with live metrics.
Analysis
- Profile training to find the slowest components.
- Compare training curves with different hyperparameters.
Next Chapter Preview: In the next chapter, we'll cover Model Inference and Demo: generating translations with the trained model and creating an interactive demo.