Chapter 13

Training Monitoring and Debugging

Training Translation Model

Introduction

Training neural networks can fail in many ways. This section covers how to monitor training, diagnose problems, and debug common issues.


3.1 Metrics to Monitor

Essential Metrics

🐍python
from typing import Any, Dict, List, Optional
import math
import torch


class TrainingMonitor:
    """
    Monitor training progress and detect issues.

    Tracks:
    - Loss trends
    - Gradient statistics
    - Learning rate
    - Training speed (time per batch)
    """

    def __init__(self, max_grad_norm: float = 100.0, min_grad_norm: float = 1e-6):
        self.history: Dict[str, List[float]] = {
            'train_loss': [],
            'val_loss': [],
            'grad_norm': [],
            'learning_rate': [],
            'batch_time': [],
        }
        self.warnings: List[str] = []
        # Thresholds for the exploding / vanishing gradient warnings
        self.max_grad_norm = max_grad_norm
        self.min_grad_norm = min_grad_norm

    def update(
        self,
        train_loss: float,
        val_loss: Optional[float] = None,
        grad_norm: Optional[float] = None,
        lr: Optional[float] = None,
        batch_time: Optional[float] = None,
    ):
        """Record the latest metrics, then check them for issues."""
        self.history['train_loss'].append(train_loss)
        if val_loss is not None:
            self.history['val_loss'].append(val_loss)
        if grad_norm is not None:
            self.history['grad_norm'].append(grad_norm)
        if lr is not None:
            self.history['learning_rate'].append(lr)
        if batch_time is not None:
            self.history['batch_time'].append(batch_time)

        # Run checks (gradient checks only apply to a newly reported norm)
        self._check_for_issues(grad_norm)

    def _check_for_issues(self, grad_norm: Optional[float] = None):
        """Check the newest values for common training issues."""
        losses = self.history['train_loss']
        recent = losses[-1]

        # NaN/Inf loss: nothing else is meaningful after this
        if not math.isfinite(recent):
            self.warnings.append(f"Non-finite loss detected: {recent}")
            return

        # Loss spike: loss more than doubled since the previous update
        if len(losses) >= 2 and recent > 2 * losses[-2]:
            self.warnings.append(
                f"Loss spike detected: {losses[-2]:.4f} → {recent:.4f}"
            )

        # Gradient issues
        if grad_norm is not None:
            if grad_norm > self.max_grad_norm:
                self.warnings.append(f"Large gradient norm: {grad_norm:.2f}")
            elif grad_norm < self.min_grad_norm:
                self.warnings.append(f"Vanishing gradients: {grad_norm:.2e}")

    def get_summary(self) -> Dict[str, Any]:
        """Get training summary (assumes val_loss is recorded once per epoch)."""
        h = self.history
        return {
            'epochs': len(h['val_loss']),
            'best_val_loss': min(h['val_loss']) if h['val_loss'] else None,
            'final_train_loss': h['train_loss'][-1] if h['train_loss'] else None,
            'avg_batch_time': (
                sum(h['batch_time']) / len(h['batch_time']) if h['batch_time'] else None
            ),
            'warnings': list(self.warnings),
        }
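The spike check in `_check_for_issues` compares only two consecutive losses, so a single noisy batch that follows an unusually easy one can trigger a warning. A minimal sketch of a less jumpy alternative, assuming you want to compare each loss against the mean of a recent window; `SpikeDetector`, `window`, and `factor` are hypothetical names and illustrative defaults, not values from this chapter.

```python
from collections import deque
from typing import Deque, Optional


class SpikeDetector:
    """Hypothetical helper: flag a loss above `factor` x the mean of the last `window` losses."""

    def __init__(self, window: int = 50, factor: float = 2.0):
        self.window = window
        self.factor = factor
        self.recent: Deque[float] = deque(maxlen=window)  # oldest values drop out automatically

    def update(self, loss: float) -> Optional[str]:
        warning = None
        # Only compare once the window is full, so early-training noise is ignored
        if len(self.recent) == self.window:
            baseline = sum(self.recent) / len(self.recent)
            if loss > self.factor * baseline:
                warning = f"Loss spike: {loss:.4f} vs running mean {baseline:.4f}"
        self.recent.append(loss)
        return warning
```

Its return value could be appended to `TrainingMonitor.warnings` in place of, or alongside, the two-point check.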
📝text
METRICS TO MONITOR:

1. TRAINING LOSS:
   ──────────────
   What: Average cross-entropy loss per (non-padding) target token
   Good: Steadily decreasing
   Bad:  Spikes, plateaus, or NaN

   Typical values (Multi30k):
     Epoch 1:  ~6.0
     Epoch 10: ~2.5
     Epoch 30: ~2.0


2. VALIDATION LOSS:
   ─────────────────
   What: Loss on held-out data
   Good: Tracks training loss (slightly higher)
   Bad:  Diverges from training loss (overfitting)

   Gap analysis:
     train=2.0, val=2.1 → Good
     train=1.5, val=3.0 → Overfitting!


3. PERPLEXITY:
   ────────────
   What: exp(loss), with the loss measured in nats (natural log)
   Good: Decreasing, eventually ~10-20
   Bad:  Very high (>100 after warmup)

   Interpretation:
     PPL=10:  Model is as uncertain as a uniform choice among ~10 tokens
     PPL=100: Model is very uncertain


4. GRADIENT NORM:
   ───────────────
   What: Global L2 norm over all parameter gradients (before clipping)
   Good: Stable, typically 0.1-10
   Bad:  Exploding (>100) or vanishing (<1e-6)

   Code (returns the total norm measured before clipping):
     total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)


5. LEARNING RATE:
   ───────────────
   What: Current learning rate
   Good: Follows the warmup schedule
   Bad:  Too high (unstable) or too low (slow)


6. TOKEN ACCURACY:
   ────────────────
   What: % of non-padding target tokens whose argmax prediction is correct
   Good: Increasing, eventually ~60-70%
   Bad:  Stuck at low values


7. MEMORY USAGE:
   ──────────────
   What: GPU memory consumed
   Good: Stable across steps
   Bad:  Steadily growing (memory leak, e.g. accumulating loss tensors
         instead of loss.item(), which keeps their graphs alive)

   Code:
     torch.cuda.memory_allocated() / 1e9  # GB
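The perplexity and token-accuracy entries can be computed together from one batch of logits. A minimal sketch, assuming the model returns raw logits of shape [batch, seq, vocab] and that padding is marked by a single `pad_idx`; `token_metrics` is a hypothetical helper name.

```python
import math

import torch
import torch.nn.functional as F


def token_metrics(logits: torch.Tensor, targets: torch.Tensor, pad_idx: int):
    """Return (mean loss, perplexity, token accuracy) over non-padding tokens.

    logits:  [batch, seq, vocab] raw scores
    targets: [batch, seq] gold token ids; pad_idx marks padding
    """
    vocab = logits.size(-1)
    # Mean cross-entropy in nats, averaged over non-pad positions only
    loss = F.cross_entropy(
        logits.reshape(-1, vocab), targets.reshape(-1), ignore_index=pad_idx
    )
    mask = targets != pad_idx
    preds = logits.argmax(dim=-1)
    correct = ((preds == targets) & mask).sum().item()
    total = mask.sum().item()
    accuracy = correct / total if total > 0 else 0.0
    return loss.item(), math.exp(loss.item()), accuracy
```

With all-zero (uniform) logits over a 10-token vocabulary the loss is ln 10 and the perplexity is exactly 10, which is what "chooses from ~10 tokens" means. When aggregating over an epoch, take exp of the token-weighted mean loss rather than averaging per-batch perplexities; the two generally differ.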

3.2 Common Training Issues

Diagnosing Problems

📝text
ISSUE 1: LOSS NOT DECREASING
────────────────────────────

Symptoms:
  - Loss stays flat or increases
  - Validation loss also flat

Causes & Solutions:
  ✓ Learning rate too high → Reduce by 10x
  ✓ Learning rate too low → Increase by 10x
  ✓ Bad initialization → Check init_std
  ✓ Bug in data → Inspect batches manually
  ✓ Gradient clipping too aggressive → Increase max_norm
  ✓ Model too small → Increase capacity

Debugging:
  # Check that gradients reach every parameter (run after loss.backward())
  for name, param in model.named_parameters():
      if param.grad is None:
          print(f"{name}: NO GRADIENT (unused, detached, or frozen?)")
      else:
          print(f"{name}: grad_norm={param.grad.norm().item():.6f}")


ISSUE 2: NaN LOSS
─────────────────

Symptoms:
  - Loss suddenly becomes NaN
  - Model outputs NaN

Causes & Solutions:
  ✓ Learning rate too high → Reduce significantly
  ✓ No gradient clipping → Add clipping (max_norm=1.0)
  ✓ log(0) in a hand-written loss → Use log_softmax / F.cross_entropy,
    or add a small epsilon
  ✓ Division by zero → Check normalization
  ✓ Exploding activations → Add layer norm

Debugging:
  # Find which layer first produces NaN
  def make_hook(name):
      # Bind `name` per module; a hook defined directly in the loop
      # would see only the last module's name
      def hook(module, inp, out):
          if isinstance(out, torch.Tensor) and torch.isnan(out).any():
              print(f"NaN in {name}")
      return hook

  for name, module in model.named_modules():
      module.register_forward_hook(make_hook(name))


ISSUE 3: OVERFITTING
────────────────────

Symptoms:
  - Train loss decreases
  - Val loss increases or plateaus
  - Large gap between train and val

Causes & Solutions:
  ✓ Not enough regularization → Increase dropout
  ✓ Too many epochs → Use early stopping
  ✓ Model too large → Reduce capacity
  ✓ Not enough data → Add data augmentation (e.g., back-translation)

Debugging:
  # Plot train vs val loss
  # Look for the epoch where the curves start to diverge


ISSUE 4: SLOW CONVERGENCE
─────────────────────────

Symptoms:
  - Loss decreases very slowly
  - Need many epochs

Causes & Solutions:
  ✓ Learning rate too low → Increase
  ✓ Warmup too long → Reduce warmup_steps
  ✓ Batch size too small → Increase max_tokens
  ✓ Poor initialization → Revisit the init scheme (e.g., init_std)

Debugging:
  # Check the learning rate schedule. Use a throwaway optimizer/scheduler
  # copy: stepping the real scheduler advances its state.
  for step in range(1, 10001):
      scheduler.step()
      if step % 1000 == 0:
          print(f"Step {step}: LR={optimizer.param_groups[0]['lr']:.2e}")


ISSUE 5: UNSTABLE TRAINING
──────────────────────────

Symptoms:
  - Loss oscillates wildly
  - Spikes in loss

Causes & Solutions:
  ✓ Learning rate too high → Reduce
  ✓ Batch size too small → Increase
  ✓ No warmup → Add warmup phase
  ✓ Gradient clipping threshold wrong → Adjust max_norm

Debugging:
  # Monitor the global gradient norm (run after loss.backward())
  total_norm = 0.0
  for p in model.parameters():
      if p.grad is not None:
          total_norm += p.grad.detach().norm(2).item() ** 2
  total_norm = total_norm ** 0.5
  print(f"Gradient norm: {total_norm:.4f}")
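Issue 4 suggests printing the learning-rate schedule. A side-effect-free way is to evaluate the schedule as a plain function instead of stepping a live scheduler. A minimal sketch, assuming the chapter uses the inverse-square-root warmup schedule from "Attention Is All You Need"; `d_model=512` and `warmup=4000` are that paper's defaults, not necessarily this chapter's config, and `noam_lr` / `print_schedule` are hypothetical names.

```python
def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    """lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first call
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


def print_schedule(total_steps: int = 20000, every: int = 2000, **kwargs) -> None:
    # Pure function: nothing here touches a real optimizer or scheduler
    for step in range(every, total_steps + 1, every):
        print(f"Step {step:>6}: LR={noam_lr(step, **kwargs):.2e}")


print_schedule()
```

The LR rises linearly until `warmup` and decays as 1/sqrt(step) afterwards; if the values logged from `optimizer.param_groups[0]['lr']` during training do not follow this shape, the scheduler is probably being stepped at the wrong frequency.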

3.3 Debugging Tools

Practical Debugging Functions

🐍python
from typing import Any, Dict

import torch


def gradient_check(model: torch.nn.Module) -> Dict[str, Dict[str, Any]]:
    """
    Collect gradient statistics for each parameter that has a gradient.

    Call after loss.backward(). Returns {param_name: {stat_name: value}}.
    """
    stats = {}

    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad.detach()
            stats[name] = {
                'mean': grad.mean().item(),
                # std of a single element is undefined (NaN), so report 0.0
                'std': grad.std().item() if grad.numel() > 1 else 0.0,
                'max': grad.max().item(),
                'min': grad.min().item(),
                'norm': grad.norm().item(),
                'has_nan': torch.isnan(grad).any().item(),
            }

    return stats


def print_gradient_stats(model: torch.nn.Module, top_k: int = 10):
    """
    Print the top_k parameters ranked by absolute mean gradient.
    """
    print("\nGradient Statistics:")
    print("-" * 73)
    print(f"{'Layer':<40} {'Mean':>10} {'Std':>10} {'Max':>10}")
    print("-" * 73)

    stats = gradient_check(model)

    # Sort by absolute mean
    sorted_stats = sorted(
        stats.items(),
        key=lambda x: abs(x[1]['mean']),
        reverse=True
    )

    for name, s in sorted_stats[:top_k]:
        # Keep the end of long names (the most specific part) within 40 columns
        name_short = '..' + name[-38:] if len(name) > 40 else name
        print(f"{name_short:<40} {s['mean']:>10.2e} {s['std']:>10.2e} {s['max']:>10.2e}")


def check_for_dead_neurons(model: torch.nn.Module, threshold: float = 1e-6):
    """
    Count output units whose incoming weights have (near-)zero L2 norm.

    This is a weight-based proxy: it does not observe activations, so it
    misses ReLU units that are dead because of their bias or inputs.
    Embedding rows (e.g. the zero-initialized padding_idx row) are counted too.
    """
    dead_count = 0
    total_count = 0

    for name, param in model.named_parameters():
        if 'weight' in name and param.dim() >= 2:
            # One row per output unit; flatten(1) also handles conv kernels
            neuron_norms = param.detach().flatten(1).norm(dim=1)
            dead_count += (neuron_norms < threshold).sum().item()
            total_count += neuron_norms.numel()

    if total_count > 0:
        print(f"Dead neurons: {dead_count}/{total_count} ({100*dead_count/total_count:.2f}%)")
    return dead_count, total_count


def memory_report():
    """
    Print GPU memory usage in GB (1 GB = 1e9 bytes).
    """
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        max_allocated = torch.cuda.max_memory_allocated() / 1e9

        print("\nGPU Memory:")
        print(f"  Allocated: {allocated:.2f} GB")      # held by live tensors
        print(f"  Reserved:  {reserved:.2f} GB")       # held by the caching allocator
        print(f"  Max alloc: {max_allocated:.2f} GB")  # peak since start or last reset
    else:
        print("No GPU available")


def debug_batch(batch: Dict[str, torch.Tensor]):
    """
    Print shape, dtype, device, value range, and NaN/Inf status of each tensor in a batch.
    """
    print("\nBatch Debug Info:")
    print("-" * 40)

    for key, tensor in batch.items():
        if isinstance(tensor, torch.Tensor):
            print(f"{key}:")
            print(f"  Shape: {tuple(tensor.shape)}")
            print(f"  Dtype: {tensor.dtype}")
            print(f"  Device: {tensor.device}")
            if tensor.numel() > 0:
                # For token ids, max should be < vocab size
                print(f"  Min/Max: {tensor.min().item():.4f} / {tensor.max().item():.4f}")
            if tensor.is_floating_point():
                # Covers float16, bfloat16, float32, and float64
                print(f"  Has NaN/Inf: {not torch.isfinite(tensor).all().item()}")
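A reusable form of the NaN-locating hook from Issue 2 is handy alongside these tools. A minimal sketch, assuming you want to record (rather than print) every module whose output contains NaN/Inf and to remove the hooks afterwards; `register_nan_hooks` is a hypothetical helper. Each hook gets its name through a factory function (a hook defined directly inside the loop would capture only the last name), and tuple outputs such as those of `nn.MultiheadAttention` are handled.

```python
from typing import List

import torch
import torch.nn as nn


def register_nan_hooks(model: nn.Module, found: List[str]) -> list:
    """Append to `found`, in execution order, each module whose output has NaN/Inf.

    Returns the hook handles so they can be removed after debugging.
    """
    def make_hook(name: str):
        def hook(module, inputs, output):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outputs:
                if not isinstance(out, torch.Tensor) or not out.is_floating_point():
                    continue
                if not torch.isfinite(out).all():
                    found.append(name or "<root>")  # the root module has name ""
                    break
        return hook

    return [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
```

Forward hooks fire when a module finishes, so children are recorded before their parents and `found[0]` is the first module to produce a non-finite value. Call `handle.remove()` on every handle when done, since the hooks add overhead to each forward pass.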

3.4 Training Visualization

Plotting Training Progress

🐍python
import random
from typing import List


def plot_training_progress_ascii(
    train_losses: List[float],
    val_losses: List[float],
    height: int = 15,
    width: int = 60
):
    """
    ASCII plot of training progress (one point per epoch).
    """
    print("\nTraining Progress")
    print("=" * (width + 10))

    if not train_losses:
        print("No data to plot")
        return

    # Shared y-axis scale for both curves
    all_losses = train_losses + val_losses
    min_loss = min(all_losses)
    max_loss = max(all_losses)
    loss_range = max_loss - min_loss if max_loss > min_loss else 1

    def to_xy(i: int, n: int, loss: float):
        # Spread points over the full width; row 0 is the maximum loss
        x = round(i / max(n - 1, 1) * (width - 1))
        y = round((1 - (loss - min_loss) / loss_range) * (height - 1))
        return x, max(0, min(height - 1, y))

    # Create plot grid
    grid = [[' ' for _ in range(width)] for _ in range(height)]

    # Plot train losses
    for i, loss in enumerate(train_losses):
        x, y = to_xy(i, len(train_losses), loss)
        grid[y][x] = '●'  # Train point

    # Plot val losses
    for i, loss in enumerate(val_losses):
        x, y = to_xy(i, len(val_losses), loss)
        grid[y][x] = '◆' if grid[y][x] == '●' else '○'  # ◆ = both overlap

    # Print plot with y-axis labels
    for i, row in enumerate(grid):
        loss_label = max_loss - (i / (height - 1)) * loss_range
        print(f"{loss_label:6.2f} |{''.join(row)}|")

    print(f"       +{'-' * width}+")
    print(f"        0{' ' * (width - 10)}epochs")
    print("\n● Train  ○ Val  ◆ Both")


def demonstrate_visualization():
    """
    Demonstrate training visualization on synthetic loss curves.
    """
    random.seed(0)  # reproducible demo

    # Generate synthetic (not measured) losses that decay with noise
    train_losses = [6.0]
    val_losses = [5.5]

    for _ in range(29):
        train_losses.append(train_losses[-1] * 0.92 + random.random() * 0.1)
        val_losses.append(val_losses[-1] * 0.93 + random.random() * 0.15)

    plot_training_progress_ascii(train_losses, val_losses)


demonstrate_visualization()
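Per-batch losses are much noisier than per-epoch averages, so trends can be hard to see if you plot them directly. A minimal sketch of exponential-moving-average smoothing that could be applied before calling `plot_training_progress_ascii`; `ema_smooth` is a hypothetical helper and `alpha=0.9` an illustrative value. The division by `1 - alpha**t` is the same bias correction Adam applies to its moment estimates, so the first points are not pulled toward zero.

```python
from typing import List


def ema_smooth(values: List[float], alpha: float = 0.9) -> List[float]:
    """Bias-corrected exponential moving average of a loss curve."""
    smoothed, avg = [], 0.0
    for t, v in enumerate(values, start=1):
        avg = alpha * avg + (1 - alpha) * v
        smoothed.append(avg / (1 - alpha ** t))  # undo the bias from starting at 0
    return smoothed
```

For example, `plot_training_progress_ascii(ema_smooth(train_losses), ema_smooth(val_losses))`. A larger `alpha` smooths more but makes the curve lag behind real changes such as a sudden spike.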

3.5 Diagnostic Checklist

Pre-Training Checklist

📝text
PRE-TRAINING CHECKLIST:

DATA:
☐ Tokenizer loaded correctly
☐ Data shapes match expectations
☐ No NaN/Inf in data
☐ Special tokens correct (PAD, BOS, EOS)
☐ Data loader produces batches

MODEL:
☐ Model on correct device (GPU)
☐ Model parameters initialized
☐ Forward pass works without error
☐ Output shape is [batch, seq, vocab]

LOSS:
☐ Loss computes without error
☐ Loss is finite (not NaN/Inf)
☐ Padding properly ignored

OPTIMIZER:
☐ All parameters registered
☐ Learning rate is reasonable
☐ Weight decay applied correctly

TRAINING:
☐ Backward pass completes
☐ Gradients are non-zero
☐ Gradient clipping active
☐ Loss decreases when repeatedly training on the first batch (sanity check)


QUICK SANITY TEST:
──────────────────

# Overfit one batch
# (train_step: runs one optimization step on a batch and returns the loss as a float)
model.train()
batch = next(iter(train_loader))

for step in range(100):
    loss = train_step(batch)
    if step % 10 == 0:
        print(f"Step {step}: loss {loss:.4f}")

# If the model can learn, loss should fall close to 0
# (label smoothing, if enabled, keeps it above a nonzero floor)


POST-EPOCH CHECKLIST:

AFTER EACH EPOCH, CHECK:
────────────────────────

☐ Training loss decreased
☐ Validation loss reasonable
☐ No NaN in losses
☐ Learning rate schedule correct
☐ Checkpoint saved
☐ Memory usage stable

IF TRAIN LOSS NOT DECREASING:
☐ Check learning rate (too high/low?)
☐ Check gradients (vanishing/exploding?)
☐ Check data (correct format?)
☐ Try overfitting one batch

IF VAL LOSS DIVERGING FROM TRAIN:
☐ Increase dropout
☐ Add weight decay
☐ Reduce model size
☐ Early stopping

IF TRAINING TOO SLOW:
☐ Use mixed precision
☐ Increase batch size
☐ Profile for bottlenecks
☐ Check data loading
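The quick sanity test depends on a `train_step` and `train_loader` from the training script. To see what a passing test looks like in isolation, here is a self-contained toy version; the embedding-plus-linear model, the vocabulary size, and the learning rate are hypothetical stand-ins, not the chapter's Transformer. Each of the 8 source tokens is distinct, so the targets can be memorized perfectly, and no label smoothing is used, so the loss floor is about 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a real model and batch: 8 distinct inputs, random targets
vocab, d_model = 16, 32
src = torch.arange(8)
tgt = torch.randint(0, vocab, (8,))
model = nn.Sequential(nn.Embedding(vocab, d_model), nn.Linear(d_model, vocab))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()  # no label smoothing: the floor is ~0

losses = []
model.train()
for step in range(300):
    optimizer.zero_grad()
    loss = loss_fn(model(src), tgt)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if step % 50 == 0:
        print(f"step {step:3d}  loss {loss.item():.4f}")
```

If the same loop on the real model and a real batch stays flat, the bug is usually in the loss, masking, or optimizer wiring rather than in the hyperparameters.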

Summary

Monitoring Priorities

Priority | Metric           | Action if bad
---------+------------------+---------------------------
1        | Loss (NaN)       | Stop immediately, debug
2        | Train loss trend | Adjust LR, check data
3        | Val loss gap     | Add regularization
4        | Gradient norms   | Clip gradients
5        | Memory           | Profile, reduce batch size
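Besides regularization, the checklist lists early stopping as a response to a growing val-loss gap. A minimal sketch, assuming validation loss is computed once per epoch; `EarlyStopping`, `patience`, and `min_delta` are hypothetical names with illustrative values, not settings from this chapter.

```python
import math


class EarlyStopping:
    """Signal a stop after `patience` epochs without val-loss improvement of at least `min_delta`."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, val_loss: float, epoch: int) -> bool:
        """Return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best, self.best_epoch = val_loss, epoch
            self.bad_epochs = 0  # improvement: a good moment to save a checkpoint
        else:
            self.bad_epochs += 1  # a NaN val_loss also counts as "no improvement"
        return self.bad_epochs >= self.patience
```

After stopping, reload the checkpoint saved at `best_epoch` rather than keeping the final weights.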

Quick Debug Commands

🐍python
# Check for NaN or Inf in the loss (True means there is a problem)
not torch.isfinite(loss).all()

# Global gradient norm (max_norm=inf measures it without clipping)
torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))

# Memory usage (GB)
torch.cuda.memory_allocated() / 1e9

# Current learning rate
optimizer.param_groups[0]['lr']
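Passing `float('inf')` as `max_norm` to `clip_grad_norm_` is a convenient way to read the global gradient norm: the clipping coefficient is clamped to 1, so gradients are left unchanged. A small sketch checking that this matches the manual sum of squared norms used in Issue 5; the `nn.Linear` model and random data are toy stand-ins.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 4)  # toy stand-in for the real model
loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()

# Global L2 norm over all parameter gradients, computed by hand
grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
manual = torch.sqrt(sum(g.pow(2).sum() for g in grads)).item()

# max_norm=inf: returns the same total norm and leaves the gradients unscaled
reported = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf")).item()

print(f"manual={manual:.6f}  clip_grad_norm_={reported:.6f}")
```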

Chapter Summary

In this chapter, we covered the complete training pipeline:

  • Configuration: Model and training configs
  • Training Script: Complete trainer implementation
  • Monitoring: Metrics, debugging, visualization

Expected results after training:

  • Val loss: ~2.2-2.5
  • Val PPL: ~9-12
  • Token accuracy: ~65-70%
  • BLEU (next chapter): ~30-35
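The loss and perplexity targets above are two views of the same number, since PPL = exp(loss) for a natural-log cross-entropy. A quick arithmetic check (if your framework logs base-2 losses, use `2 ** loss` instead):

```python
import math

# PPL = exp(mean cross-entropy in nats)
for val_loss in (2.2, 2.5):
    print(f"val loss {val_loss:.1f} -> PPL {math.exp(val_loss):.1f}")
# val loss 2.2 -> PPL 9.0
# val loss 2.5 -> PPL 12.2
```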

Exercises

Debugging

  • Intentionally break training (e.g., high LR) and practice debugging.
  • Implement automatic NaN detection with training halt.
  • Create a training dashboard with live metrics.

Analysis

  • Profile training to find the slowest components.
  • Compare training curves with different hyperparameters.

Next Chapter Preview: In the next chapter, we'll cover Model Inference and Demo: generating translations with the trained model and creating an interactive demo.
