Chapter 10
Section 53 of 75

Complete Training Loop

Training Pipeline

Introduction

This section brings together the training components into a complete training loop for translation: data loading, forward pass, loss computation, backpropagation, optimization (with gradient clipping and learning-rate scheduling), validation, logging, and checkpointing.


Training Components Overview

Required Components

πŸ“text
1β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
2β”‚                     TRAINING PIPELINE                            β”‚
3β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
4β”‚                                                                  β”‚
5β”‚  DATA                    MODEL                    OPTIMIZATION   β”‚
6β”‚  ────                    ─────                    ────────────   β”‚
7β”‚  DataLoader              Transformer             Adam Optimizer  β”‚
8β”‚  Collator                Encoder                 LR Scheduler    β”‚
9β”‚  Tokenizer               Decoder                 Grad Clipping   β”‚
10β”‚                                                                  β”‚
11β”‚  LOSS                    LOGGING                 CHECKPOINTING   β”‚
12β”‚  ────                    ───────                 ─────────────   β”‚
13β”‚  Label Smoothing         TensorBoard             Save/Load       β”‚
14β”‚  Padding Mask            Progress Bar            Best Model      β”‚
15β”‚  Metrics                 Metrics                 Resume          β”‚
16β”‚                                                                  β”‚
17β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Trainer Class

Complete Implementation

🐍python
import time
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader


class TranslationTrainer:
    """
    Complete trainer for translation models.

    Handles:
    - Training loop with gradient clipping
    - Validation with metrics
    - Learning rate scheduling (stepped once per batch)
    - Checkpointing and resuming
    - Logging (console and TensorBoard)

    Args:
        model: Transformer model, called as model(source_ids, decoder_input_ids)
        train_loader: Training data loader yielding dicts with
            'source_ids' and 'target_ids'
        val_loader: Validation data loader (same batch format)
        optimizer: Optimizer instance
        scheduler: LR scheduler, stepped after every optimizer step
        criterion: Loss function returning a dict with 'loss' (mean over
            non-padding tokens), 'num_tokens', and 'accuracy'
        config: Training configuration dict
        device: Training device
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: optim.Optimizer,
        scheduler,
        criterion: nn.Module,
        config: Dict[str, Any],
        device: torch.device
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.config = config
        self.device = device

        # Training state
        self.global_step = 0
        self.epoch = 0  # 0-indexed current epoch
        self.best_val_loss = float('inf')

        # Logging
        self.log_interval = config.get('log_interval', 100)
        self.save_dir = Path(config.get('save_dir', 'checkpoints'))
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Optional TensorBoard
        self.writer = None
        if config.get('use_tensorboard', False):
            try:
                from torch.utils.tensorboard import SummaryWriter
                self.writer = SummaryWriter(self.save_dir / 'logs')
            except ImportError:
                print("TensorBoard not available")

    def train_epoch(self) -> Dict[str, float]:
        """
        Train for one epoch.

        Returns:
            Dictionary of training metrics
        """
        self.model.train()

        total_loss = 0.0
        total_tokens = 0
        total_correct = 0.0

        epoch_start = time.time()
        # Throughput is measured over all batches since the last log line
        interval_start = time.time()
        interval_tokens = 0

        for batch in self.train_loader:
            # Move to device
            source_ids = batch['source_ids'].to(self.device)
            target_ids = batch['target_ids'].to(self.device)

            # Teacher forcing: the decoder reads target[:, :-1] and is
            # trained to predict the next token, target[:, 1:].
            # (If your criterion shifts targets itself, pass target_ids.)
            decoder_input = target_ids[:, :-1]
            labels = target_ids[:, 1:]

            # Forward pass
            self.optimizer.zero_grad()
            logits = self.model(source_ids, decoder_input)

            # Compute loss
            loss_output = self.criterion(logits, labels)
            loss = loss_output['loss']

            # Backward pass
            loss.backward()

            # Gradient clipping (returns the total norm before clipping)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.get('max_grad_norm', 1.0)
            ).item()

            # Optimizer step, then the per-step LR schedule
            self.optimizer.step()
            self.scheduler.step()

            # Update stats (token-weighted, so each batch counts in
            # proportion to its number of target tokens)
            self.global_step += 1
            loss_value = loss.item()
            num_tokens = int(loss_output['num_tokens'])
            total_loss += loss_value * num_tokens
            total_tokens += num_tokens
            total_correct += float(loss_output['accuracy']) * num_tokens
            interval_tokens += num_tokens

            # Logging
            if self.global_step % self.log_interval == 0:
                elapsed = time.time() - interval_start
                tokens_per_sec = interval_tokens / elapsed
                lr = self.optimizer.param_groups[0]['lr']

                print(
                    f"Step {self.global_step:6d} | "
                    f"Loss {loss_value:.4f} | "
                    f"LR {lr:.6f} | "
                    f"Grad {grad_norm:.2f} | "
                    f"Tok/s {tokens_per_sec:.0f}"
                )

                if self.writer:
                    self.writer.add_scalar('train/loss', loss_value, self.global_step)
                    self.writer.add_scalar('train/lr', lr, self.global_step)
                    self.writer.add_scalar('train/grad_norm', grad_norm, self.global_step)

                # Reset the throughput window
                interval_start = time.time()
                interval_tokens = 0

        # Epoch stats
        avg_loss = total_loss / total_tokens
        accuracy = total_correct / total_tokens
        epoch_time = time.time() - epoch_start

        return {
            'loss': avg_loss,
            'perplexity': torch.exp(torch.tensor(avg_loss)).item(),
            'accuracy': accuracy,
            'time': epoch_time,
            'tokens_per_sec': total_tokens / epoch_time
        }

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """
        Run validation.

        Returns:
            Dictionary of validation metrics
        """
        self.model.eval()

        total_loss = 0.0
        total_tokens = 0
        total_correct = 0.0

        for batch in self.val_loader:
            source_ids = batch['source_ids'].to(self.device)
            target_ids = batch['target_ids'].to(self.device)

            # Same teacher-forcing shift as in training
            logits = self.model(source_ids, target_ids[:, :-1])
            loss_output = self.criterion(logits, target_ids[:, 1:])

            num_tokens = int(loss_output['num_tokens'])
            total_loss += loss_output['loss'].item() * num_tokens
            total_tokens += num_tokens
            total_correct += float(loss_output['accuracy']) * num_tokens

        avg_loss = total_loss / total_tokens
        accuracy = total_correct / total_tokens

        metrics = {
            'loss': avg_loss,
            'perplexity': torch.exp(torch.tensor(avg_loss)).item(),
            'accuracy': accuracy
        }

        if self.writer:
            self.writer.add_scalar('val/loss', avg_loss, self.global_step)
            self.writer.add_scalar('val/perplexity', metrics['perplexity'], self.global_step)
            self.writer.add_scalar('val/accuracy', accuracy, self.global_step)

        return metrics

    def save_checkpoint(self, path: Optional[str] = None, is_best: bool = False):
        """Save a training checkpoint (plus best_model.pt if is_best)."""
        checkpoint = {
            'epoch': self.epoch,  # 0-indexed epoch that just finished
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.config
        }

        if path is None:
            # 1-indexed in the filename to match the "Epoch N/M" log lines
            path = self.save_dir / f'checkpoint_epoch_{self.epoch + 1}.pt'

        torch.save(checkpoint, path)

        if is_best:
            best_path = self.save_dir / 'best_model.pt'
            torch.save(checkpoint, best_path)
            print(f"Saved best model to {best_path}")

    def load_checkpoint(self, path: str):
        """Load a checkpoint and resume at the epoch after the saved one."""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # The saved epoch already finished, so continue with the next one
        self.epoch = checkpoint['epoch'] + 1
        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']

        print(f"Loaded checkpoint; resuming at epoch {self.epoch + 1}")

    def train(self, num_epochs: int, resume_from: Optional[str] = None):
        """
        Full training loop.

        Args:
            num_epochs: Total number of epochs to train
            resume_from: Path to checkpoint to resume from
        """
        if resume_from:
            self.load_checkpoint(resume_from)

        print(f"Training for {num_epochs} epochs")
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.val_loader.dataset)}")
        print(f"Batch size: {self.train_loader.batch_size}")
        print(f"Device: {self.device}")
        print("=" * 60)

        for epoch in range(self.epoch, num_epochs):
            self.epoch = epoch

            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 40)

            # Train
            train_metrics = self.train_epoch()
            print(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"PPL: {train_metrics['perplexity']:.2f}, "
                f"Acc: {train_metrics['accuracy']:.4f}"
            )

            # Validate
            val_metrics = self.validate()
            print(
                f"Val   - Loss: {val_metrics['loss']:.4f}, "
                f"PPL: {val_metrics['perplexity']:.2f}, "
                f"Acc: {val_metrics['accuracy']:.4f}"
            )

            # Check if best
            is_best = val_metrics['loss'] < self.best_val_loss
            if is_best:
                self.best_val_loss = val_metrics['loss']
                print(f"New best validation loss: {self.best_val_loss:.4f}")

            # Save checkpoint
            self.save_checkpoint(is_best=is_best)

        print("\nTraining complete!")
        print(f"Best validation loss: {self.best_val_loss:.4f}")

        if self.writer:
            self.writer.close()
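The trainer weights each batch's loss by its number of target tokens before averaging, instead of averaging the per-batch losses. Below is a minimal pure-Python sketch of why that matters. It assumes, as the trainer does, that the criterion's `loss` is the mean over the non-padding tokens of a batch. The batch numbers are made up for illustration.

```python
import math


def token_weighted_mean(batch_losses, batch_tokens):
    """Average per-token losses, weighting each batch by its token count.

    Mirrors the accumulation in train_epoch()/validate():
    total_loss += loss * num_tokens; avg = total_loss / total_tokens.
    """
    total_loss = sum(loss * n for loss, n in zip(batch_losses, batch_tokens))
    total_tokens = sum(batch_tokens)
    return total_loss / total_tokens


def batch_mean(batch_losses):
    """Naive average that treats every batch as equally important."""
    return sum(batch_losses) / len(batch_losses)


# Illustrative batches: a short one with high loss, a long one with low loss
losses = [4.0, 2.0]   # mean loss per token in each batch
tokens = [100, 900]   # non-padding target tokens in each batch

weighted = token_weighted_mean(losses, tokens)  # (400 + 1800) / 1000 = 2.2
naive = batch_mean(losses)                      # (4.0 + 2.0) / 2 = 3.0
print(f"token-weighted: {weighted:.2f}, PPL {math.exp(weighted):.2f}")
print(f"batch mean:     {naive:.2f}, PPL {math.exp(naive):.2f}")
```

The token-weighted value is the true per-token average over the epoch, and perplexity should be computed from it. When batch token counts vary widely (for example with dynamic batching), the two averages can differ noticeably.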

Training Configuration

Config Class

🐍python
from dataclasses import asdict, dataclass


@dataclass
class TrainingConfig:
    """
    Complete training configuration.
    """
    # Model
    src_vocab_size: int = 32000
    tgt_vocab_size: int = 32000
    d_model: int = 512
    num_heads: int = 8
    num_layers: int = 6
    d_ff: int = 2048
    dropout: float = 0.1

    # Data
    data_dir: str = "data/multi30k"
    max_source_len: int = 128
    max_target_len: int = 128
    batch_size: int = 32
    num_workers: int = 4

    # Training
    num_epochs: int = 30
    warmup_steps: int = 4000
    label_smoothing: float = 0.1
    max_grad_norm: float = 1.0

    # Optimizer
    # Assumes a warmup + inverse-sqrt ("Noam") schedule: this value scales
    # that schedule and is not the absolute learning rate.
    learning_rate: float = 1.0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.98
    adam_eps: float = 1e-9
    weight_decay: float = 0.0

    # Logging
    log_interval: int = 100
    save_dir: str = "checkpoints"
    use_tensorboard: bool = True

    # Device
    device: str = "cuda"

    def to_dict(self) -> dict:
        """Convert to a plain dict (e.g. for the trainer's config argument)."""
        return asdict(self)
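Several config values match the warmup + inverse-square-root schedule from the original Transformer paper: `warmup_steps = 4000` and Adam with β₂ = 0.98 and ε = 1e-9. In that schedule the "learning rate" acts as a scale factor, which would explain `learning_rate = 1.0`. Below is a pure-Python sketch of that schedule. It assumes this is the scheduler paired with the config, and the name `noam_lr` is ours.

```python
def noam_lr(step: int, d_model: int = 512, warmup_steps: int = 4000,
            factor: float = 1.0) -> float:
    """Warmup + inverse-sqrt decay ("Noam") learning rate at a given step.

    lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    Rises linearly for warmup_steps steps, then decays as 1/sqrt(step).
    """
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
    return factor * d_model ** -0.5 * min(step ** -0.5,
                                          step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000, 100000):
    print(f"step {step:>6}: lr = {noam_lr(step):.2e}")
```

The peak is reached exactly at `step == warmup_steps`. Four times as many steps later, the rate has halved. In PyTorch, one way to wire this up is to create `torch.optim.Adam(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.98), eps=1e-9)` and wrap it in `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda s: noam_lr(s + 1))`. `LambdaLR` multiplies the optimizer's base LR by the lambda's value, so the scale factor is applied exactly once.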

Training Script

Main Script Structure

🐍python
import argparse

import torch

# Assumes TrainingConfig (above) is defined in, or imported into, this file


def main():
    """
    Main training script (skeleton).
    """
    parser = argparse.ArgumentParser(description='Train translation model')
    # Named preset; preset loading is not shown here, so it is unused below
    parser.add_argument('--config', type=str, default='base')
    parser.add_argument('--data-dir', type=str, default='data/multi30k')
    parser.add_argument('--save-dir', type=str, default='checkpoints')
    parser.add_argument('--resume', type=str, default=None,
                        help='Checkpoint path to resume from')
    parser.add_argument('--epochs', type=int, default=30)
    args = parser.parse_args()

    # Start from the defaults and apply command-line overrides
    config = TrainingConfig()
    config.data_dir = args.data_dir
    config.save_dir = args.save_dir
    config.num_epochs = args.epochs

    # Fall back to CPU when CUDA is unavailable
    device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Training steps would go here:
    # 1. Load tokenizer
    # 2. Create data module
    # 3. Create model
    # 4. Create optimizer and scheduler
    # 5. Create criterion
    # 6. Create trainer (pass config.to_dict() as its config)
    # 7. Train: trainer.train(config.num_epochs, resume_from=args.resume)

    print("Training script structure shown above.")
    print("Run with: python train.py --config base --epochs 30")


if __name__ == '__main__':
    main()
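The script copies each command-line flag onto the config by hand. One way to keep that mapping in a single place is `dataclasses.replace`. The sketch below uses a trimmed, hypothetical `MiniConfig` as a stand-in for `TrainingConfig` so the block runs on its own. `build_config` is likewise a hypothetical name, and it takes `argv` explicitly so it can be tested.

```python
import argparse
from dataclasses import dataclass, replace
from typing import List, Optional


@dataclass(frozen=True)
class MiniConfig:
    """Hypothetical stand-in for TrainingConfig (CLI-visible fields only)."""
    data_dir: str = "data/multi30k"
    save_dir: str = "checkpoints"
    num_epochs: int = 30


def build_config(argv: Optional[List[str]] = None) -> MiniConfig:
    """Parse CLI flags and return a config with the overrides applied."""
    parser = argparse.ArgumentParser(description="Train translation model")
    parser.add_argument("--data-dir", dest="data_dir")
    parser.add_argument("--save-dir", dest="save_dir")
    parser.add_argument("--epochs", dest="num_epochs", type=int)
    args = parser.parse_args(argv)  # argv=None means "read sys.argv"

    # Only flags that were actually passed (non-None) override defaults
    overrides = {k: v for k, v in vars(args).items() if v is not None}
    return replace(MiniConfig(), **overrides)


config = build_config(["--epochs", "5", "--save-dir", "runs/exp1"])
print(config)
```

Leaving the argparse defaults as `None` keeps the dataclass as the single source of default values, so a default changed in the config cannot silently disagree with the CLI.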

Monitoring Training

Key Metrics

πŸ“text
1KEY METRICS TO WATCH:
2─────────────────────
3
41. LOSS
5   - Should decrease steadily
6   - Validation loss should follow training loss
7   - Large gap = overfitting
8
92. PERPLEXITY
10   - exp(cross-entropy loss)
11   - Interpretable as "average branching factor"
12   - Good translation: PPL < 10
13
143. LEARNING RATE
15   - Should follow expected schedule
16   - Peak around warmup_steps
17
184. GRADIENT NORM
19   - Should be stable (not exploding)
20   - Frequent clipping = LR too high
21
225. TOKEN ACCURACY
23   - Percentage of correct tokens
24   - Should increase over training
25
266. TOKENS PER SECOND
27   - Training speed metric
28   - Should be consistent
29
30
31WARNING SIGNS:
32──────────────
33
34❌ Loss not decreasing
35   β†’ LR too low or too high
36
37❌ Loss exploding (NaN)
38   β†’ LR too high, gradients exploding
39   β†’ Reduce LR, increase warmup
40
41❌ Val loss increasing while train decreases
42   β†’ Overfitting
43   β†’ Increase dropout, use more data
44
45❌ Gradient norm very high
46   β†’ Increase gradient clipping
47   β†’ Reduce learning rate
48
49❌ Tokens/sec very low
50   β†’ Check data loading (num_workers)
51   β†’ Check batch size
52
53
54TENSORBOARD COMMANDS:
55────────────────────
56
57# Launch TensorBoard
58tensorboard --logdir=checkpoints/logs
59
60# View in browser
61http://localhost:6006
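Several of the warning signs above can be checked automatically from logged values. Below is a minimal pure-Python sketch. The function name, the thresholds (`clip_fraction_limit`, `patience`) and the input format (aligned per-epoch loss lists, pre-clip gradient norms) are illustrative assumptions and are not part of the trainer.

```python
import math
from typing import List


def training_warnings(train_losses: List[float], val_losses: List[float],
                      grad_norms: List[float], max_grad_norm: float = 1.0,
                      clip_fraction_limit: float = 0.5,
                      patience: int = 3) -> List[str]:
    """Return human-readable warnings for some of the failure modes above."""
    warnings = []

    # Loss exploding: any NaN or inf loss
    if any(not math.isfinite(x) for x in train_losses + val_losses):
        warnings.append("non-finite loss: reduce LR / increase warmup")

    # Overfitting: val loss rose at each of the last `patience` checks
    # while train loss fell over the same window
    if len(val_losses) > patience and len(train_losses) > patience:
        recent_val = val_losses[-(patience + 1):]
        val_rising = all(b > a for a, b in zip(recent_val, recent_val[1:]))
        train_falling = train_losses[-1] < train_losses[-(patience + 1)]
        if val_rising and train_falling:
            warnings.append("val loss rising while train loss falls: overfitting")

    # Frequent clipping: clip_grad_norm_ reports the norm before clipping
    if grad_norms:
        clipped = sum(g > max_grad_norm for g in grad_norms) / len(grad_norms)
        if clipped > clip_fraction_limit:
            warnings.append(f"{clipped:.0%} of steps clipped: LR may be too high")

    return warnings


print(training_warnings(train_losses=[3.0, 2.5, 2.0, 1.8],
                        val_losses=[3.1, 3.2, 3.3, 3.4],
                        grad_norms=[0.4, 0.7, 0.9]))
```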

Common Training Issues

Troubleshooting

Problem: Loss doesn't decrease

Possible causes:

  • Learning rate too low or too high
  • Model too small for task
  • Data preprocessing error

Solutions:

  • Try different peak learning rates (roughly 1e-5 to 1e-3); with a warmup schedule, adjust its scale factor and warmup_steps instead
  • Check data pipeline, inspect batches
  • Verify tokenization is correct
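For "check data pipeline, inspect batches", a quick structural check on raw token IDs often catches preprocessing bugs before any training. Below is a hypothetical helper. It assumes target sequences are lists of IDs that start with a BOS ID, end with an EOS ID, and are right-padded with a PAD ID. The ID values in the example are made up, so substitute your tokenizer's.

```python
from typing import Dict, List


def inspect_target_batch(target_ids: List[List[int]], pad_id: int,
                         bos_id: int, eos_id: int) -> Dict[str, object]:
    """Summarize padding and flag malformed target sequences in one batch."""
    total = sum(len(row) for row in target_ids)
    num_pad = sum(tok == pad_id for row in target_ids for tok in row)
    problems = []
    for i, row in enumerate(target_ids):
        tokens = [t for t in row if t != pad_id]  # strip padding
        if not tokens or tokens[0] != bos_id:
            problems.append(f"row {i}: does not start with BOS")
        if not tokens or tokens[-1] != eos_id:
            problems.append(f"row {i}: does not end with EOS")
        # With right padding, real tokens must form an unbroken prefix
        if row[:len(tokens)] != tokens:
            problems.append(f"row {i}: padding before real tokens")
    return {"pad_fraction": num_pad / total if total else 0.0,
            "problems": problems}


# Hypothetical IDs: PAD=0, BOS=1, EOS=2
batch = [[1, 57, 98, 2, 0, 0],
         [1, 13, 44, 61, 2, 0],
         [57, 98, 2, 0, 0, 0]]   # missing BOS
print(inspect_target_batch(batch, pad_id=0, bos_id=1, eos_id=2))
```

A high `pad_fraction` also points at the "training very slow" problem below, since padded positions cost compute but carry no signal.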

Problem: NaN loss

Possible causes:

  • Learning rate too high
  • Gradient explosion
  • Numerical instability

Solutions:

  • Reduce learning rate
  • Increase warmup steps
  • Add gradient clipping (max_norm=1.0)
  • With FP16 mixed precision, use loss scaling (PyTorch's GradScaler) or switch to BF16
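The trainer calls `torch.nn.utils.clip_grad_norm_`, which clips by global norm: when the combined L2 norm of all gradients exceeds `max_norm`, every gradient is rescaled by the same factor. Below is a pure-Python sketch of that rule, with plain lists standing in for parameter tensors. The 1e-6 term mirrors PyTorch's guard against division by zero.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float) -> float:
    """Scale gradients in place so their global L2 norm is at most max_norm.

    Returns the norm measured *before* clipping, like clip_grad_norm_.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)
    if clip_coef < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]          # two "parameters"; global norm = 5
norm = clip_grad_norm(grads, max_norm=1.0)
print(norm, grads)              # same direction, length shrunk to ~1
```

Because every gradient is scaled by one shared factor, the update direction is preserved and only its length is capped. Clipping bounds large but finite gradients. It cannot repair a NaN or inf loss, which is why the fixes above also lower the LR. PyTorch's `clip_grad_norm_` can raise on a non-finite norm via `error_if_nonfinite=True`.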

Problem: Overfitting (train loss ≪ val loss)

Possible causes:

  • Model too large for data
  • Not enough regularization
  • Training too long

Solutions:

  • Increase dropout
  • Use label smoothing
  • Add weight decay
  • Early stopping
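Early stopping ends training once validation loss stops improving. Below is a minimal patience-based sketch. The class name and the `patience` and `min_delta` parameters are illustrative choices, and the trainer above would call `step()` after each `validate()`.

```python
class EarlyStopping:
    """Signal a stop after `patience` validations without improvement."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta   # improvement needed to count as "better"
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([3.0, 2.5, 2.6, 2.7, 2.4]):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch + 1}; best val loss {stopper.best}")
        break
```

Combined with `best_model.pt`, stopping late costs nothing in model quality, because the best checkpoint is already saved.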

Problem: Training very slow

Possible causes:

  • Data loading bottleneck
  • Small batch size
  • Not using GPU

Solutions:

  • Increase num_workers
  • Use pin_memory=True
  • Dynamic (token-based) batching to reduce padding
  • Mixed precision training
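Dynamic batching groups sentences of similar length and caps each batch by a token budget instead of a fixed sentence count, which reduces padding. Below is a minimal sketch. It assumes the budget counts padded tokens (examples × longest example in the batch). `token_budget_batches` and `max_tokens` are hypothetical names.

```python
from typing import List


def token_budget_batches(lengths: List[int], max_tokens: int) -> List[List[int]]:
    """Group example indices so each batch's padded size <= max_tokens.

    Padded size = (number of examples) * (longest example in the batch).
    A single example longer than max_tokens still gets its own batch.
    """
    # Sort by length so neighbours have similar lengths (less padding)
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, longest = [], [], 0
    for idx in order:
        new_longest = max(longest, lengths[idx])
        if current and new_longest * (len(current) + 1) > max_tokens:
            batches.append(current)
            current, new_longest = [], lengths[idx]
        current.append(idx)
        longest = new_longest
    if current:
        batches.append(current)
    return batches


lengths = [5, 30, 7, 28, 6, 31]
print(token_budget_batches(lengths, max_tokens=64))
```

The resulting list of index lists can be passed as `DataLoader(dataset, batch_sampler=batches, collate_fn=collator)`, where `collator` is a placeholder for your padding collator. `batch_sampler` is mutually exclusive with `batch_size`, `shuffle`, `sampler`, and `drop_last`, so shuffle the order of whole batches yourself each epoch.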

Summary

Training Pipeline Components

| Component | Purpose |
|---|---|
| DataLoader | Load and batch data |
| Model | Forward pass |
| Criterion | Compute loss |
| Optimizer | Update weights |
| Scheduler | Adjust learning rate |
| Gradient clipping | Stabilize training |
| Checkpointing | Save/resume training |

Training Loop Steps

  1. Load batch
  2. Move to device
  3. Zero gradients, then forward pass
  4. Compute loss
  5. Backward pass
  6. Clip gradients
  7. Optimizer step
  8. Scheduler step
  9. Log metrics
  10. Validate periodically
  11. Save checkpoints

Exercises

Implementation

  • Add early stopping based on validation loss.
  • Implement mixed precision training (FP16).
  • Add distributed training support.

Analysis

  • Profile training to find bottlenecks.
  • Compare different optimizers (Adam vs AdamW vs SGD).

In the final section of this chapter, we'll cover Checkpointing and Model Selection: best practices for saving models, selecting the best checkpoint, and deploying trained models.
