Chapter 13

Complete Training Script

Training Translation Model

Introduction

This section provides the complete training script that brings together all components to train the German-English translation model on Multi30k.


2.1 Training Loop Implementation

The Trainer Class

🐍python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, Any, Optional
from pathlib import Path
import time
import json
from tqdm import tqdm


class Trainer:
    """
    Complete trainer for the translation model.

    Handles:
    - Training loop with gradient accumulation
    - Validation
    - Checkpointing
    - Logging
    - Mixed precision training

    Args:
        model: Transformer model
        optimizer: Optimizer
        scheduler: LR scheduler
        criterion: Loss function
        config: TrainingConfig
        scaler: GradScaler for mixed precision
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler,
        criterion: nn.Module,
        config,
        scaler: Optional[torch.cuda.amp.GradScaler] = None
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.config = config
        self.scaler = scaler

        self.device = torch.device(config.device)
        self.global_step = 0
        self.best_val_loss = float('inf')

        # Create directories
        Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        Path(config.log_dir).mkdir(parents=True, exist_ok=True)

        # Metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.learning_rates = []

    def train_epoch(
        self,
        train_loader: DataLoader,
        epoch: int
    ) -> Dict[str, float]:
        """
        Train for one epoch.

        Args:
            train_loader: Training data loader
            epoch: Current epoch number

        Returns:
            Dictionary of epoch metrics
        """
        self.model.train()
        total_loss = 0.0
        total_tokens = 0
        num_batches = 0

        # Progress bar
        pbar = tqdm(
            train_loader,
            desc=f"Epoch {epoch}",
            leave=True
        )

        # Gradient accumulation. Gradients left over from a trailing partial
        # accumulation group are discarded by this zero_grad() next epoch.
        accumulation_counter = 0
        self.optimizer.zero_grad()

        for batch in pbar:
            # Move to device
            source = batch['source'].to(self.device)
            source_mask = batch['source_mask'].to(self.device)
            target_input = batch['target_input'].to(self.device)
            target_output = batch['target_output'].to(self.device)

            # Forward pass (recent PyTorch releases deprecate this spelling
            # in favor of torch.amp.autocast('cuda', ...))
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                logits = self.model(
                    source,
                    target_input,
                    src_mask=source_mask
                )

                # Compute loss
                loss = self.criterion(
                    logits.reshape(-1, logits.size(-1)),
                    target_output.reshape(-1)
                )

                # Scale for gradient accumulation
                loss = loss / self.config.accumulation_steps

            # Backward pass
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            accumulation_counter += 1

            # Optimizer step
            if accumulation_counter >= self.config.accumulation_steps:
                # Gradient clipping (unscale first so the norm is measured
                # on the true gradients)
                if self.scaler is not None:
                    self.scaler.unscale_(self.optimizer)

                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.gradient_clip
                )

                # Update weights
                if self.scaler is not None:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.scheduler.step()
                self.optimizer.zero_grad()

                self.global_step += 1
                accumulation_counter = 0

                # Log the LR once per optimizer step rather than once per
                # batch, so accumulation does not record duplicate entries
                if self.global_step % self.config.log_every == 0:
                    self.learning_rates.append(
                        self.optimizer.param_groups[0]['lr']
                    )

            # Track metrics (undo the accumulation scaling for reporting)
            batch_loss = loss.item() * self.config.accumulation_steps
            num_tokens = (target_output != self.config.pad_id).sum().item()
            total_loss += batch_loss * num_tokens
            total_tokens += num_tokens
            num_batches += 1

            # Update progress bar
            current_lr = self.optimizer.param_groups[0]['lr']
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'lr': f'{current_lr:.2e}',
            })

        # Epoch metrics: token-weighted average loss
        avg_loss = total_loss / total_tokens if total_tokens > 0 else 0.0

        return {
            'train_loss': avg_loss,
            'perplexity': torch.exp(torch.tensor(avg_loss)).item(),
            'num_batches': num_batches,
            'global_step': self.global_step,
        }

    @torch.no_grad()
    def validate(
        self,
        val_loader: DataLoader
    ) -> Dict[str, float]:
        """
        Validate the model.

        Args:
            val_loader: Validation data loader

        Returns:
            Dictionary of validation metrics
        """
        self.model.eval()
        total_loss = 0.0
        total_tokens = 0
        total_correct = 0

        for batch in val_loader:
            # Move to device
            source = batch['source'].to(self.device)
            source_mask = batch['source_mask'].to(self.device)
            target_input = batch['target_input'].to(self.device)
            target_output = batch['target_output'].to(self.device)

            # Forward pass
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                logits = self.model(
                    source,
                    target_input,
                    src_mask=source_mask
                )

                loss = self.criterion(
                    logits.reshape(-1, logits.size(-1)),
                    target_output.reshape(-1)
                )

            # Track metrics
            num_tokens = (target_output != self.config.pad_id).sum().item()
            total_loss += loss.item() * num_tokens
            total_tokens += num_tokens

            # Accuracy
            predictions = logits.argmax(dim=-1)
            mask = (target_output != self.config.pad_id)
            correct = ((predictions == target_output) & mask).sum().item()
            total_correct += correct

        # Compute metrics
        avg_loss = total_loss / total_tokens if total_tokens > 0 else 0
        accuracy = total_correct / total_tokens if total_tokens > 0 else 0

        return {
            'val_loss': avg_loss,
            'val_perplexity': torch.exp(torch.tensor(avg_loss)).item(),
            'val_accuracy': accuracy,
        }

    def save_checkpoint(
        self,
        epoch: int,
        metrics: Dict[str, float],
        is_best: bool = False
    ):
        """
        Save training checkpoint.
        """
        checkpoint = {
            'epoch': epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'metrics': metrics,
            'best_val_loss': self.best_val_loss,
        }

        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        # Save regular checkpoint
        checkpoint_path = Path(self.config.checkpoint_dir) / f"checkpoint_epoch{epoch}.pt"
        torch.save(checkpoint, checkpoint_path)

        # Save best model
        if is_best:
            best_path = Path(self.config.checkpoint_dir) / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"  New best model saved! val_loss: {metrics['val_loss']:.4f}")

    def load_checkpoint(self, checkpoint_path: str) -> int:
        """
        Load checkpoint and return the 0-based index of the next epoch to run.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.scaler is not None and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']

        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")

        # train.py saves the 1-based number of the last completed epoch,
        # which equals the 0-based index of the next epoch in its loop.
        # Returning epoch + 1 here would silently skip one epoch on resume.
        return checkpoint['epoch']

    def save_metrics(self):
        """Save training metrics to file."""
        metrics = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'learning_rates': self.learning_rates,
        }

        metrics_path = Path(self.config.log_dir) / "training_metrics.json"
        with open(metrics_path, 'w') as f:
            json.dump(metrics, f, indent=2)


def create_trainer(components: Dict[str, Any]) -> Trainer:
    """
    Create trainer from setup components.
    """
    return Trainer(
        model=components['model'],
        optimizer=components['optimizer'],
        scheduler=components['scheduler'],
        criterion=components['criterion'],
        config=components['config'].training,
        scaler=components['scaler'],
    )
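
The division by `accumulation_steps` in `train_epoch` is easiest to see on a toy problem. The sketch below is plain Python rather than PyTorch so it runs anywhere; the one-parameter model `y = w * x`, the hand-written gradient, and all names in it are illustrative assumptions, not part of the Trainer. It checks that summing per-micro-batch gradients, each scaled by `1 / accumulation_steps`, reproduces the gradient of the mean loss over the combined batch.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair


def mse_grad(w: float, batch: List[Sample]) -> float:
    """Gradient of mean((w*x - y)^2) over `batch` with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, micro_batches: List[List[Sample]]) -> float:
    """Sum of micro-batch gradients, each scaled by 1/accumulation_steps.

    Mirrors `loss = loss / accumulation_steps` followed by `backward()`,
    which adds the scaled gradient to whatever is already accumulated.
    """
    steps = len(micro_batches)
    grad = 0.0
    for micro_batch in micro_batches:
        grad += mse_grad(w, micro_batch) / steps
    return grad


data = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 8.0)]
micro = [data[:2], data[2:]]  # two equal-sized micro-batches

full = mse_grad(0.5, data)
accum = accumulated_grad(0.5, micro)
print(f"full-batch grad: {full:.6f}, accumulated grad: {accum:.6f}")
```

The match is exact only when every micro-batch contributes the same number of elements to its mean. With token-based batching the non-padding token count varies from batch to batch, so accumulation is a close approximation of one large batch rather than an identity.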

2.2 Main Training Script

Complete train.py

🐍python
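#!/usr/bin/env python3
"""
train.py - Train German-English translation model on Multi30k

Usage:
    python train.py --config configs/base.json
    python train.py --model-size small --epochs 10
"""

import argparse
import json
import torch
import random
import numpy as np
from pathlib import Path
from datetime import datetime

# Trainer, TrainingConfig, DataConfig, Multi30kDataModule, get_model_config,
# build_model, count_parameters, setup_optimizer, setup_scheduler, and
# setup_criterion are defined in earlier sections; import them from
# wherever they live in your project.

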
def set_seed(seed: int):
    """Set all random seeds."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Train translation model")

    # Configuration
    # Note: --config is accepted but not read anywhere in this listing;
    # settings come from --model-size and the override flags below.
    parser.add_argument("--config", type=str, help="Path to config file")
    parser.add_argument("--model-size", type=str, default="base",
                       choices=["tiny", "small", "base", "large"])

    # Overrides
    parser.add_argument("--epochs", type=int, help="Number of epochs")
    parser.add_argument("--lr", type=float, help="Learning rate")
    parser.add_argument("--batch-tokens", type=int, help="Max tokens per batch")

    # Paths
    parser.add_argument("--data-dir", type=str, default="data/multi30k")
    parser.add_argument("--output-dir", type=str, default="outputs")

    # Resume
    parser.add_argument("--resume", type=str, help="Path to checkpoint to resume")

    # Device
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--no-amp", action="store_true", help="Disable mixed precision")

    return parser.parse_args()


def main():
    args = parse_args()

    # Setup experiment name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    exp_name = f"{args.model_size}_{timestamp}"

    output_dir = Path(args.output_dir) / exp_name
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Experiment: {exp_name}")
    print(f"Output directory: {output_dir}")

    # Configuration
    model_config = get_model_config(args.model_size)

    training_config = TrainingConfig(
        num_epochs=args.epochs or 30,
        learning_rate=args.lr or 1e-4,
        max_tokens=args.batch_tokens or 4096,
        checkpoint_dir=str(output_dir / "checkpoints"),
        log_dir=str(output_dir / "logs"),
        device=args.device if torch.cuda.is_available() else "cpu",
        mixed_precision=not args.no_amp and torch.cuda.is_available(),
    )

    # Set seed
    seed = 42
    set_seed(seed)

    print(f"\nDevice: {training_config.device}")
    print(f"Mixed precision: {training_config.mixed_precision}")

    # Setup data
    print("\nLoading data...")
    data_config = DataConfig(
        data_dir=args.data_dir,
        tokenizer_path=Path(args.data_dir) / "tokenizer" / "tokenizer.json",
        max_tokens=training_config.max_tokens,
    )

    data_module = Multi30kDataModule(data_config)
    data_module.setup()

    # Update model config with tokenizer info
    model_config.vocab_size = data_module.vocab_size
    model_config.pad_id = data_module.pad_id
    # Trainer reads pad_id from the training config when counting tokens,
    # so keep it in sync with the tokenizer as well
    training_config.pad_id = data_module.pad_id

    # Build model
    print("\nBuilding model...")
    model = build_model(model_config)
    model = model.to(training_config.device)

    num_params = count_parameters(model)
    print(f"Model parameters: {num_params:,}")

    # Setup training components
    print("\nSetting up training...")
    optimizer = setup_optimizer(model, training_config)

    train_loader = data_module.train_dataloader()
    # The scheduler advances once per optimizer step, not once per batch,
    # so account for gradient accumulation when sizing the schedule
    steps_per_epoch = max(1, len(train_loader) // training_config.accumulation_steps)
    total_steps = steps_per_epoch * training_config.num_epochs

    scheduler = setup_scheduler(optimizer, training_config, total_steps)

    criterion = setup_criterion(
        training_config,
        model_config.vocab_size,
        model_config.pad_id
    )

    scaler = None
    if training_config.mixed_precision:
        scaler = torch.cuda.amp.GradScaler()

    # Create trainer
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        config=training_config,
        scaler=scaler,
    )

    # Resume if specified
    start_epoch = 0
    if args.resume:
        start_epoch = trainer.load_checkpoint(args.resume)

    # Save config
    config_path = output_dir / "config.json"
    with open(config_path, 'w') as f:
        json.dump({
            'model': model_config.to_dict(),
            'training': training_config.to_dict(),
            'seed': seed,
            'num_params': num_params,
        }, f, indent=2)

    # Training loop
    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)

    val_loader = data_module.val_dataloader()

    for epoch in range(start_epoch, training_config.num_epochs):
        print(f"\n--- Epoch {epoch + 1}/{training_config.num_epochs} ---")

        # Train
        train_metrics = trainer.train_epoch(train_loader, epoch + 1)

        # Validate
        val_metrics = trainer.validate(val_loader)

        # Track metrics
        trainer.train_losses.append(train_metrics['train_loss'])
        trainer.val_losses.append(val_metrics['val_loss'])

        # Print metrics
        print(f"\nEpoch {epoch + 1} Results:")
        print(f"  Train loss: {train_metrics['train_loss']:.4f}")
        print(f"  Train PPL:  {train_metrics['perplexity']:.2f}")
        print(f"  Val loss:   {val_metrics['val_loss']:.4f}")
        print(f"  Val PPL:    {val_metrics['val_perplexity']:.2f}")
        print(f"  Val Acc:    {val_metrics['val_accuracy']:.4f}")

        # Check for best model
        is_best = val_metrics['val_loss'] < trainer.best_val_loss
        if is_best:
            trainer.best_val_loss = val_metrics['val_loss']

        # Save checkpoint
        metrics = {**train_metrics, **val_metrics}
        trainer.save_checkpoint(epoch + 1, metrics, is_best)

    # Save final metrics
    trainer.save_metrics()

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f"Best validation loss: {trainer.best_val_loss:.4f}")
    print(f"Checkpoints saved to: {training_config.checkpoint_dir}")
    print("=" * 60)


if __name__ == "__main__":
    main()
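
When sizing the learning-rate schedule, the quantity that matters is the number of optimizer updates, because `Trainer.train_epoch` calls `scheduler.step()` once per update rather than once per batch. The helper below is a hypothetical sketch of that arithmetic; its name and signature are not part of the training script. It assumes, as in `train_epoch`, that the accumulation counter restarts every epoch and that a trailing partial group of batches triggers no update.

```python
def optimizer_steps(
    batches_per_epoch: int,
    accumulation_steps: int,
    num_epochs: int,
) -> int:
    """Number of optimizer (and scheduler) steps over a full run.

    A trailing group of fewer than `accumulation_steps` batches at the
    end of an epoch does not produce a step, hence the floor division.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1")
    return (batches_per_epoch // accumulation_steps) * num_epochs


# 453 batches per epoch is the figure shown in the sample training log
for accum in (1, 2, 4):
    steps = optimizer_steps(453, accum, num_epochs=30)
    print(f"accumulation_steps={accum}: {steps} scheduler steps")
```

With accumulation enabled, a schedule sized from the raw batch count would finish only a fraction of its planned warmup and decay by the end of training.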

2.3 Training Progress Monitoring

Live Metrics Visualization

📝text
EXPECTED OUTPUT DURING TRAINING (illustrative values; exact numbers vary by run):
─────────────────────────────────

Experiment: base_20240115_143022
Output directory: outputs/base_20240115_143022

Device: cuda
Mixed precision: True

Loading data...
Loaded 29000 sentence pairs
Loaded 1014 sentence pairs
Loaded 1000 sentence pairs

Building model...
Model parameters: 65,432,576

Setting up training...

============================================================
Starting training...
============================================================

--- Epoch 1/30 ---
Epoch 1: 100%|████████████████| 453/453 [02:15<00:00, loss=6.2341, lr=2.34e-06]

Epoch 1 Results:
  Train loss: 6.2341
  Train PPL:  509.23
  Val loss:   5.1234
  Val PPL:    167.89
  Val Acc:    0.2145
  New best model saved! val_loss: 5.1234

--- Epoch 2/30 ---
Epoch 2: 100%|████████████████| 453/453 [02:12<00:00, loss=4.8765, lr=4.68e-06]

Epoch 2 Results:
  Train loss: 4.8765
  Train PPL:  131.12
  Val loss:   4.2341
  Val PPL:    68.93
  Val Acc:    0.3456
  New best model saved! val_loss: 4.2341

...

--- Epoch 30/30 ---
Epoch 30: 100%|████████████████| 453/453 [02:10<00:00, loss=2.1234, lr=1.23e-05]

Epoch 30 Results:
  Train loss: 2.1234
  Train PPL:  8.36
  Val loss:   2.3456
  Val PPL:    10.44
  Val Acc:    0.6789

============================================================
Training complete!
Best validation loss: 2.2891
Checkpoints saved to: outputs/base_20240115_143022/checkpoints
============================================================
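
The loss and PPL columns in the log are tied together: the trainer averages the loss per non-padding token across the epoch and reports perplexity as the exponential of that average. The sketch below reproduces that bookkeeping in plain Python; the function names and batch figures are made up for illustration.

```python
import math
from typing import Iterable, Tuple


def token_weighted_loss(batches: Iterable[Tuple[float, int]]) -> float:
    """Average per-token loss from (mean_batch_loss, num_non_pad_tokens) pairs."""
    total_loss = 0.0
    total_tokens = 0
    for mean_loss, num_tokens in batches:
        total_loss += mean_loss * num_tokens
        total_tokens += num_tokens
    return total_loss / total_tokens if total_tokens > 0 else 0.0


def perplexity(avg_loss: float) -> float:
    """Perplexity is exp of the average per-token cross-entropy (in nats)."""
    return math.exp(avg_loss)


# Hypothetical batches: (mean loss, number of non-padding tokens)
batches = [(2.0, 3000), (3.0, 1000)]
avg = token_weighted_loss(batches)
naive = sum(loss for loss, _ in batches) / len(batches)

print(f"token-weighted loss: {avg:.4f}, unweighted batch mean: {naive:.4f}")
print(f"perplexity at val loss 2.3456: {perplexity(2.3456):.2f}")
```

Weighting by token count keeps a short batch from counting as much as a long one, which matters when batches are built by `max_tokens` rather than a fixed number of sentences. If the criterion applies label smoothing, the exponential of that loss will differ from a perplexity computed with plain cross-entropy, so the two should not be compared directly.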

2.4 Early Stopping

Implementing Early Stopping

🐍python
class EarlyStopping:
    """
    Early stopping to prevent overfitting.

    Monitors validation metric and stops training
    if no improvement for `patience` epochs.

    Args:
        patience: Number of epochs to wait
        min_delta: Minimum change to qualify as improvement
        mode: 'min' or 'max' (minimize or maximize metric)
    """

    def __init__(
        self,
        patience: int = 5,
        min_delta: float = 0.0,
        mode: str = 'min'
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode

        self.counter = 0
        self.best_score = None
        self.should_stop = False

    def __call__(self, metric: float) -> bool:
        """
        Check if training should stop.

        Args:
            metric: Current metric value

        Returns:
            True if should stop, False otherwise
        """
        if self.best_score is None:
            self.best_score = metric
            return False

        if self.mode == 'min':
            improved = metric < self.best_score - self.min_delta
        else:
            improved = metric > self.best_score + self.min_delta

        if improved:
            self.best_score = metric
            self.counter = 0
        else:
            self.counter += 1

            if self.counter >= self.patience:
                self.should_stop = True

        return self.should_stop


# Add to training loop (assumes trainer, train_loader, val_loader,
# and num_epochs are already defined as in train.py)

early_stopping = EarlyStopping(
    patience=5,
    min_delta=0.01,
    mode='min'
)

for epoch in range(num_epochs):
    # Train and validate (epochs are numbered from 1, as in train.py)
    train_metrics = trainer.train_epoch(train_loader, epoch + 1)
    val_metrics = trainer.validate(val_loader)

    # Check early stopping
    if early_stopping(val_metrics['val_loss']):
        print(f"Early stopping at epoch {epoch + 1}")
        print(f"Best validation loss: {early_stopping.best_score:.4f}")
        break

    # Save if best: best_score was just set to this epoch's loss
    if val_metrics['val_loss'] == early_stopping.best_score:
        trainer.save_checkpoint(epoch + 1, val_metrics, is_best=True)
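
To see how `patience` and `min_delta` interact without training a model, the stopping rule can be replayed over a list of validation losses. The function below is a hypothetical pure-function restatement of the `'min'` mode logic of `EarlyStopping`; the loss values are invented for the demonstration.

```python
from typing import List, Optional


def stopping_epoch(
    val_losses: List[float],
    patience: int = 5,
    min_delta: float = 0.0,
) -> Optional[int]:
    """Return the 1-based epoch at which early stopping fires, or None.

    The first loss always becomes the initial best score. After that, a
    loss counts as an improvement only if it beats the best score by
    more than `min_delta`.
    """
    best: Optional[float] = None
    counter = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best is None or loss < best - min_delta:
            best = loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                return epoch
    return None


losses = [5.0, 4.2, 3.6, 3.595, 3.593, 3.592, 3.2]

print(stopping_epoch(losses, patience=3, min_delta=0.01))
print(stopping_epoch(losses, patience=3, min_delta=0.0))
```

With `min_delta=0.01` the three tiny gains after epoch 3 do not count as improvements, so training halts before the larger drop in the final epoch is ever seen. With `min_delta=0.0` every epoch counts as progress and the run continues. A plateau followed by a late drop is the case in which a nonzero `min_delta` with short patience can stop too early.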

2.5 Training Tips and Best Practices

Recommendations

📝text
LEARNING RATE:
──────────────
- Start with 1e-4 for Adam/AdamW
- Use warmup (4000 steps for base model)
- Monitor for loss spikes (reduce LR if unstable)

BATCH SIZE:
───────────
- Use max_tokens instead of fixed batch_size
- 4096 tokens works well for base model
- Larger batch → more stable gradients
- Use gradient accumulation if GPU memory limited

REGULARIZATION:
───────────────
- Label smoothing 0.1 (standard)
- Dropout 0.1 (increase for smaller datasets)
- Weight decay 0.01

GRADIENT CLIPPING:
──────────────────
- Always clip (max_norm=1.0)
- Prevents gradient explosion
- Monitor gradient norms

MIXED PRECISION:
────────────────
- Use fp16 on GPU (often up to ~2x faster on GPUs with Tensor Cores)
- Use GradScaler to prevent gradient underflow
- Some operations stay in fp32 automatically under autocast

CHECKPOINTING:
──────────────
- Save every epoch (at minimum)
- Keep best N models
- Include optimizer state for resuming
- Save metrics for analysis


DEBUGGING CHECKLIST:
────────────────────

☐ Loss decreasing?
  → If not: check LR, data, model architecture

☐ Gradients flowing?
  → Monitor gradient norms

☐ Memory stable?
  → Watch for memory leaks with long sequences

☐ Training speed reasonable?
  → Profile if slow

☐ Validation improving?
  → If train good but val bad: overfitting


EXPECTED TIMELINE (Multi30k base model):
────────────────────────────────────────

Epoch 1:   Val loss ~5.0, PPL ~150
Epoch 5:   Val loss ~3.0, PPL ~20
Epoch 10:  Val loss ~2.5, PPL ~12
Epoch 20:  Val loss ~2.3, PPL ~10
Epoch 30:  Val loss ~2.2, PPL ~9

After convergence:
  BLEU on test: ~30-35
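
The "keep best N models" recommendation is not implemented by `save_checkpoint`, which writes one file per epoch plus `best_model.pt`. One possible way to prune the per-epoch files is sketched below. The `prune_checkpoints` helper and the dictionary it takes are hypothetical; the caller would record each checkpoint path and its validation loss while training.

```python
import tempfile
from pathlib import Path
from typing import Dict, List


def prune_checkpoints(val_loss_by_path: Dict[Path, float], keep: int) -> List[Path]:
    """Delete all but the `keep` lowest-loss checkpoint files.

    Returns the surviving paths, best first.
    """
    if keep < 0:
        raise ValueError("keep must be >= 0")
    ranked = sorted(val_loss_by_path, key=val_loss_by_path.get)
    for path in ranked[keep:]:
        path.unlink(missing_ok=True)  # tolerate files that are already gone
    return ranked[:keep]


# Demo on empty placeholder files in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    losses: Dict[Path, float] = {}
    for epoch, val_loss in enumerate([3.0, 2.5, 2.7, 2.4], start=1):
        path = Path(tmp) / f"checkpoint_epoch{epoch}.pt"
        path.touch()
        losses[path] = val_loss

    kept = prune_checkpoints(losses, keep=2)
    kept_names = [p.name for p in kept]
    remaining = sorted(p.name for p in Path(tmp).iterdir())
    print("kept:", kept_names)
```

If resuming from the most recent epoch matters, that checkpoint should be excluded from pruning even when it is not among the best N, since it carries the latest optimizer and scheduler state.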

Summary

Training Components

| Component | Purpose |
|---|---|
| Trainer class | Main training logic |
| train_epoch() | Single epoch training |
| validate() | Evaluation without gradients |
| save_checkpoint() | Model persistence |
| EarlyStopping | Prevent overfitting |

Key Settings for Multi30k

🐍python
TrainingConfig(
    num_epochs=30,
    max_tokens=4096,
    learning_rate=1e-4,
    warmup_steps=4000,
    gradient_clip=1.0,
    label_smoothing=0.1,
    mixed_precision=True,
)
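
To get a feel for what `learning_rate=1e-4` with `warmup_steps=4000` implies, here is a minimal sketch of one common schedule shape: linear warmup to the peak rate, then inverse-square-root decay. This shape is an assumption made for illustration. The `setup_scheduler` helper used by train.py is defined elsewhere and may behave differently, and `lr_at_step` is a hypothetical name.

```python
def lr_at_step(step: int, peak_lr: float = 1e-4, warmup_steps: int = 4000) -> float:
    """Linear warmup to `peak_lr`, then inverse-square-root decay.

    `step` is the 1-based optimizer step count.
    """
    if step < 1:
        raise ValueError("step counts from 1")
    if step <= warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * (warmup_steps / step) ** 0.5


for step in (1, 453, 4000, 16000):
    print(f"step {step:>6}: lr = {lr_at_step(step):.3e}")
```

At roughly 453 optimizer steps per epoch, as in the sample log, a 4000-step warmup spans close to nine of the 30 epochs under this assumed schedule. It may be worth checking that the warmup length fits the dataset size rather than copying a value tuned for much larger corpora.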

Exercises

Implementation

  • Add TensorBoard logging for metrics visualization.
  • Implement a learning rate finder to estimate a suitable starting LR.
  • Add distributed training support with PyTorch DDP.

Experimentation

  • Train with different learning rates and compare convergence.
  • Compare early stopping patience values.

Next Section Preview: In the next section, we'll cover Training Monitoring and Debugging—how to diagnose and fix common training issues.
