Introduction
This section provides the complete training script that brings together all components to train the German-English translation model on Multi30k.
2.1 Training Loop Implementation
The Trainer Class
🐍python
1import torch
2import torch.nn as nn
3from torch.utils.data import DataLoader
4from typing import Dict, Any, Optional
5from pathlib import Path
6import time
7import json
8from tqdm import tqdm
9
10
class Trainer:
    """
    Complete trainer for the translation model.

    Handles:
    - Training loop with gradient accumulation
    - Validation
    - Checkpointing
    - Logging (learning rate every ``config.log_every`` optimizer steps)
    - Mixed precision training

    Args:
        model: Transformer model, called as
            ``model(source, target_input, src_mask=source_mask)``
        optimizer: Optimizer
        scheduler: LR scheduler, stepped once per optimizer update
        criterion: Loss function applied to flattened logits / targets
        config: TrainingConfig; this class reads ``device``,
            ``checkpoint_dir``, ``log_dir``, ``accumulation_steps``,
            ``gradient_clip``, ``pad_id`` and ``log_every``
        scaler: GradScaler for mixed precision; ``None`` disables AMP
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler,
        criterion: nn.Module,
        config,
        scaler: Optional[torch.cuda.amp.GradScaler] = None
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.config = config
        self.scaler = scaler

        self.device = torch.device(config.device)
        # Number of optimizer updates performed so far (not batches).
        self.global_step = 0
        self.best_val_loss = float('inf')

        # Create directories
        Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        Path(config.log_dir).mkdir(parents=True, exist_ok=True)

        # Metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.learning_rates = []

    def _move_batch(self, batch):
        """Return (source, source_mask, target_input, target_output) on ``self.device``."""
        keys = ('source', 'source_mask', 'target_input', 'target_output')
        return tuple(batch[key].to(self.device) for key in keys)

    def _optimizer_step(self) -> None:
        """
        Apply one optimizer update from the accumulated gradients.

        Unscales (under AMP), clips, steps the optimizer and scheduler,
        clears gradients, increments ``global_step`` and records the current
        learning rate every ``config.log_every`` optimizer steps.
        """
        if self.scaler is not None:
            # Unscale first so clipping sees the true gradient norm.
            self.scaler.unscale_(self.optimizer)

        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config.gradient_clip
        )

        # Update weights
        if self.scaler is not None:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

        # Log only after a real update, so accumulation micro-batches do not
        # append the same learning rate several times.
        if self.global_step % self.config.log_every == 0:
            self.learning_rates.append(self.optimizer.param_groups[0]['lr'])

    def train_epoch(
        self,
        train_loader: DataLoader,
        epoch: int
    ) -> Dict[str, float]:
        """
        Train for one epoch.

        Args:
            train_loader: Training data loader
            epoch: Epoch number (used only for the progress-bar label)

        Returns:
            Dictionary with token-weighted ``train_loss``, ``perplexity``,
            ``num_batches`` and the current ``global_step``.
        """
        self.model.train()
        total_loss = 0.0
        total_tokens = 0
        num_batches = 0

        # Progress bar
        pbar = tqdm(
            train_loader,
            desc=f"Epoch {epoch}",
            leave=True
        )

        # Gradient accumulation. Trailing micro-batches that do not fill a
        # whole window are never applied: the zero_grad() below clears them
        # at the start of the next call.
        accumulation_counter = 0
        self.optimizer.zero_grad()

        for batch in pbar:
            source, source_mask, target_input, target_output = self._move_batch(batch)

            # Forward pass and loss under autocast (per the AMP recipe)
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                logits = self.model(
                    source,
                    target_input,
                    src_mask=source_mask
                )
                loss = self.criterion(
                    logits.reshape(-1, logits.size(-1)),
                    target_output.reshape(-1)
                )

            # Divide so the gradients summed over the window equal their mean.
            loss = loss / self.config.accumulation_steps

            # Backward pass
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            accumulation_counter += 1
            if accumulation_counter >= self.config.accumulation_steps:
                self._optimizer_step()
                accumulation_counter = 0

            # Track metrics. Undo the accumulation scaling, then weight by the
            # number of non-pad tokens. NOTE(review): assumes the criterion
            # returns a per-token mean over non-pad targets — confirm.
            batch_loss = loss.item() * self.config.accumulation_steps
            num_tokens = (target_output != self.config.pad_id).sum().item()
            total_loss += batch_loss * num_tokens
            total_tokens += num_tokens
            num_batches += 1

            # Update progress bar
            current_lr = self.optimizer.param_groups[0]['lr']
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'lr': f'{current_lr:.2e}',
            })

        # Epoch metrics
        avg_loss = total_loss / total_tokens if total_tokens > 0 else 0

        return {
            'train_loss': avg_loss,
            'perplexity': torch.exp(torch.tensor(avg_loss)).item(),
            'num_batches': num_batches,
            'global_step': self.global_step,
        }

    @torch.no_grad()
    def validate(
        self,
        val_loader: DataLoader
    ) -> Dict[str, float]:
        """
        Validate the model.

        Args:
            val_loader: Validation data loader

        Returns:
            Dictionary with token-weighted ``val_loss``, ``val_perplexity``
            and ``val_accuracy`` (argmax accuracy over non-pad target tokens).
        """
        self.model.eval()
        total_loss = 0.0
        total_tokens = 0
        total_correct = 0

        for batch in val_loader:
            source, source_mask, target_input, target_output = self._move_batch(batch)

            # Forward pass and loss
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                logits = self.model(
                    source,
                    target_input,
                    src_mask=source_mask
                )
                loss = self.criterion(
                    logits.reshape(-1, logits.size(-1)),
                    target_output.reshape(-1)
                )

            # Track metrics over non-pad target positions only
            mask = target_output != self.config.pad_id
            num_tokens = mask.sum().item()
            total_loss += loss.item() * num_tokens
            total_tokens += num_tokens

            # Accuracy
            predictions = logits.argmax(dim=-1)
            total_correct += ((predictions == target_output) & mask).sum().item()

        # Compute metrics
        avg_loss = total_loss / total_tokens if total_tokens > 0 else 0
        accuracy = total_correct / total_tokens if total_tokens > 0 else 0

        return {
            'val_loss': avg_loss,
            'val_perplexity': torch.exp(torch.tensor(avg_loss)).item(),
            'val_accuracy': accuracy,
        }

    def save_checkpoint(
        self,
        epoch: int,
        metrics: Dict[str, float],
        is_best: bool = False
    ):
        """
        Save training checkpoint.

        Always writes ``checkpoint_epoch{epoch}.pt``. When ``is_best`` is
        true, also writes ``best_model.pt``; ``metrics`` must then contain
        ``'val_loss'``, which is printed.
        """
        checkpoint = {
            'epoch': epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'metrics': metrics,
            'best_val_loss': self.best_val_loss,
        }

        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        # Save regular checkpoint
        checkpoint_path = Path(self.config.checkpoint_dir) / f"checkpoint_epoch{epoch}.pt"
        torch.save(checkpoint, checkpoint_path)

        # Save best model
        if is_best:
            best_path = Path(self.config.checkpoint_dir) / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"  New best model saved! val_loss: {metrics['val_loss']:.4f}")

    def load_checkpoint(self, checkpoint_path: str) -> int:
        """
        Restore model/optimizer/scheduler (and scaler, if present) state.

        Returns:
            The checkpoint's stored ``epoch`` value plus one.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.scaler is not None and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']

        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")

        return checkpoint['epoch'] + 1

    def save_metrics(self):
        """Write loss and learning-rate histories to ``log_dir/training_metrics.json``."""
        metrics = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'learning_rates': self.learning_rates,
        }

        metrics_path = Path(self.config.log_dir) / "training_metrics.json"
        with open(metrics_path, 'w') as f:
            json.dump(metrics, f, indent=2)


300def create_trainer(components: Dict[str, Any]) -> Trainer:
301 """
302 Create trainer from setup components.
303 """
304 return Trainer(
305 model=components['model'],
306 optimizer=components['optimizer'],
307 scheduler=components['scheduler'],
308 criterion=components['criterion'],
309 config=components['config'].training,
310 scaler=components['scaler'],
311 )2.2 Main Training Script
Complete train.py
🐍python
1#!/usr/bin/env python3
2"""
3train.py - Train German-English translation model on Multi30k
4
5Usage:
6 python train.py --config configs/base.json
7 python train.py --model-size small --epochs 10
8"""
9
import argparse
import json
import random
from datetime import datetime
from pathlib import Path

import numpy as np
import torch

# NOTE(review): main() also relies on project helpers (TrainingConfig, DataConfig,
# Multi30kDataModule, get_model_config, build_model, count_parameters,
# setup_optimizer, setup_scheduler, setup_criterion, Trainer) whose imports are
# not shown in this listing — add them from the earlier sections.


def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs (plus all CUDA devices when available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def parse_args():
    """Define the training CLI and return the parsed ``argparse.Namespace``."""
    cli = argparse.ArgumentParser(description="Train translation model")

    # (flag, add_argument keyword options), in help-output order
    argument_specs = [
        # Configuration
        ("--config", dict(type=str, help="Path to config file")),
        ("--model-size", dict(type=str, default="base",
                              choices=["tiny", "small", "base", "large"])),
        # Overrides (None means "use the built-in default")
        ("--epochs", dict(type=int, help="Number of epochs")),
        ("--lr", dict(type=float, help="Learning rate")),
        ("--batch-tokens", dict(type=int, help="Max tokens per batch")),
        # Paths
        ("--data-dir", dict(type=str, default="data/multi30k")),
        ("--output-dir", dict(type=str, default="outputs")),
        # Resume
        ("--resume", dict(type=str, help="Path to checkpoint to resume")),
        # Device
        ("--device", dict(type=str, default="cuda")),
        ("--no-amp", dict(action="store_true", help="Disable mixed precision")),
    ]
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def main():
    """
    Train the model end to end: parse CLI args, build data, model and
    optimizer, optionally resume from a checkpoint, then run train/validate
    epochs, saving a checkpoint after each one.

    NOTE(review): relies on project helpers (get_model_config, TrainingConfig,
    DataConfig, Multi30kDataModule, build_model, count_parameters,
    setup_optimizer, setup_scheduler, setup_criterion, Trainer) and on ``json``
    being imported at module level — confirm.
    """
    args = parse_args()

    if args.config:
        # --config is accepted by the CLI but not wired up; say so rather
        # than silently ignoring it.
        print(f"Warning: --config is not supported yet; ignoring {args.config}")

    # Setup experiment name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    exp_name = f"{args.model_size}_{timestamp}"

    output_dir = Path(args.output_dir) / exp_name
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Experiment: {exp_name}")
    print(f"Output directory: {output_dir}")

    # Configuration. Compare against None so explicit values such as
    # `--lr 0` are honoured instead of falling back to the default.
    model_config = get_model_config(args.model_size)

    training_config = TrainingConfig(
        num_epochs=args.epochs if args.epochs is not None else 30,
        learning_rate=args.lr if args.lr is not None else 1e-4,
        max_tokens=args.batch_tokens if args.batch_tokens is not None else 4096,
        checkpoint_dir=str(output_dir / "checkpoints"),
        log_dir=str(output_dir / "logs"),
        device=args.device if torch.cuda.is_available() else "cpu",
        mixed_precision=not args.no_amp and torch.cuda.is_available(),
    )

    # Set seed
    seed = 42
    set_seed(seed)

    print(f"\nDevice: {training_config.device}")
    print(f"Mixed precision: {training_config.mixed_precision}")

    # Setup data
    print("\nLoading data...")
    data_config = DataConfig(
        data_dir=args.data_dir,
        tokenizer_path=Path(args.data_dir) / "tokenizer" / "tokenizer.json",
        max_tokens=training_config.max_tokens,
    )

    data_module = Multi30kDataModule(data_config)
    data_module.setup()

    # Update model config with tokenizer info
    model_config.vocab_size = data_module.vocab_size
    model_config.pad_id = data_module.pad_id

    # Build model
    print("\nBuilding model...")
    model = build_model(model_config)
    model = model.to(training_config.device)

    num_params = count_parameters(model)
    print(f"Model parameters: {num_params:,}")

    # Setup training components
    print("\nSetting up training...")
    optimizer = setup_optimizer(model, training_config)

    train_loader = data_module.train_dataloader()
    # Trainer steps the optimizer and scheduler once per `accumulation_steps`
    # batches and drops any trailing partial window, so the scheduler sees
    # floor(len / accumulation_steps) steps per epoch, not len(train_loader).
    steps_per_epoch = max(1, len(train_loader) // training_config.accumulation_steps)
    total_steps = steps_per_epoch * training_config.num_epochs

    scheduler = setup_scheduler(optimizer, training_config, total_steps)

    criterion = setup_criterion(
        training_config,
        model_config.vocab_size,
        model_config.pad_id
    )

    scaler = None
    if training_config.mixed_precision:
        scaler = torch.cuda.amp.GradScaler()

    # Create trainer
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        config=training_config,
        scaler=scaler,
    )

    # Resume if specified. Checkpoints written below store the 1-based number
    # of the epoch just finished (save_checkpoint(epoch + 1, ...)), and
    # load_checkpoint returns that value + 1. The next 0-based loop index is
    # therefore one less; without the -1 an epoch would be skipped.
    start_epoch = 0
    if args.resume:
        start_epoch = trainer.load_checkpoint(args.resume) - 1

    # Save config
    config_path = output_dir / "config.json"
    with open(config_path, 'w') as f:
        json.dump({
            'model': model_config.to_dict(),
            'training': training_config.to_dict(),
            'seed': seed,
            'num_params': num_params,
        }, f, indent=2)

    # Training loop
    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)

    val_loader = data_module.val_dataloader()

    for epoch in range(start_epoch, training_config.num_epochs):
        print(f"\n--- Epoch {epoch + 1}/{training_config.num_epochs} ---")

        # Train
        train_metrics = trainer.train_epoch(train_loader, epoch + 1)

        # Validate
        val_metrics = trainer.validate(val_loader)

        # Track metrics
        trainer.train_losses.append(train_metrics['train_loss'])
        trainer.val_losses.append(val_metrics['val_loss'])

        # Print metrics
        print(f"\nEpoch {epoch + 1} Results:")
        print(f"  Train loss: {train_metrics['train_loss']:.4f}")
        print(f"  Train PPL: {train_metrics['perplexity']:.2f}")
        print(f"  Val loss: {val_metrics['val_loss']:.4f}")
        print(f"  Val PPL: {val_metrics['val_perplexity']:.2f}")
        print(f"  Val Acc: {val_metrics['val_accuracy']:.4f}")

        # Check for best model
        is_best = val_metrics['val_loss'] < trainer.best_val_loss
        if is_best:
            trainer.best_val_loss = val_metrics['val_loss']

        # Save checkpoint (epoch stored 1-based)
        metrics = {**train_metrics, **val_metrics}
        trainer.save_checkpoint(epoch + 1, metrics, is_best)

    # Save final metrics
    trainer.save_metrics()

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f"Best validation loss: {trainer.best_val_loss:.4f}")
    print(f"Checkpoints saved to: {training_config.checkpoint_dir}")
    print("=" * 60)


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
2.3 Training Progress Monitoring
Live Metrics Visualization
📝text
1EXPECTED OUTPUT DURING TRAINING:
2─────────────────────────────────
3
4Experiment: base_20240115_143022
5Output directory: outputs/base_20240115_143022
6
7Device: cuda
8Mixed precision: True
9
10Loading data...
11Loaded 29000 sentence pairs
12Loaded 1014 sentence pairs
13Loaded 1000 sentence pairs
14
15Building model...
16Model parameters: 65,432,576
17
18Setting up training...
19
20============================================================
21Starting training...
22============================================================
23
24--- Epoch 1/30 ---
25Epoch 1: 100%|████████████████| 453/453 [02:15<00:00, loss=6.2341, lr=2.34e-06]
26
27Epoch 1 Results:
28 Train loss: 6.2341
29 Train PPL: 509.23
30 Val loss: 5.1234
31 Val PPL: 167.89
32 Val Acc: 0.2145
33 New best model saved! val_loss: 5.1234
34
35--- Epoch 2/30 ---
36Epoch 2: 100%|████████████████| 453/453 [02:12<00:00, loss=4.8765, lr=4.68e-06]
37
38Epoch 2 Results:
39 Train loss: 4.8765
40 Train PPL: 131.12
41 Val loss: 4.2341
42 Val PPL: 68.93
43 Val Acc: 0.3456
44 New best model saved! val_loss: 4.2341
45
46...
47
48--- Epoch 30/30 ---
49Epoch 30: 100%|████████████████| 453/453 [02:10<00:00, loss=2.1234, lr=1.23e-05]
50
51Epoch 30 Results:
52 Train loss: 2.1234
53 Train PPL: 8.36
54 Val loss: 2.3456
55 Val PPL: 10.44
56 Val Acc: 0.6789
57
58============================================================
59Training complete!
60Best validation loss: 2.2891
61Checkpoints saved to: outputs/base_20240115_143022/checkpoints
62============================================================
2.4 Early Stopping
Implementing Early Stopping
🐍python
class EarlyStopping:
    """
    Early stopping to prevent overfitting.

    Tracks a validation metric and reports that training should stop once
    the metric has failed to improve for ``patience`` consecutive calls.

    Args:
        patience: Number of non-improving calls tolerated before stopping
        min_delta: Minimum change that counts as an improvement
        mode: 'min' (lower is better, e.g. loss) or 'max' (higher is better)

    Raises:
        ValueError: If ``mode`` is not 'min' or 'max'.
    """

    def __init__(
        self,
        patience: int = 5,
        min_delta: float = 0.0,
        mode: str = 'min'
    ):
        if mode not in ('min', 'max'):
            # Without this check a typo such as 'mim' silently switched the
            # comparison to maximisation.
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")

        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode

        # Consecutive non-improving calls; reset to 0 on improvement.
        self.counter = 0
        # Best metric seen so far (None until the first call).
        self.best_score = None
        # Sticky: once True it stays True.
        self.should_stop = False

    def __call__(self, metric: float) -> bool:
        """
        Record a new metric value and report whether training should stop.

        Args:
            metric: Current metric value

        Returns:
            True if should stop, False otherwise
        """
        if self.best_score is None:
            # First observation just establishes the baseline.
            self.best_score = metric
            return False

        if self.mode == 'min':
            improved = metric < self.best_score - self.min_delta
        else:
            improved = metric > self.best_score + self.min_delta

        if improved:
            self.best_score = metric
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.should_stop = True

        return self.should_stop


59# Add to training loop
60
61early_stopping = EarlyStopping(
62 patience=5,
63 min_delta=0.01,
64 mode='min'
65)
66
67for epoch in range(num_epochs):
68 # Train and validate
69 train_metrics = trainer.train_epoch(train_loader, epoch)
70 val_metrics = trainer.validate(val_loader)
71
72 # Check early stopping
73 if early_stopping(val_metrics['val_loss']):
74 print(f"Early stopping at epoch {epoch + 1}")
75 print(f"Best validation loss: {early_stopping.best_score:.4f}")
76 break
77
78 # Save if best
79 if val_metrics['val_loss'] == early_stopping.best_score:
80 trainer.save_checkpoint(epoch, val_metrics, is_best=True)2.5 Training Tips and Best Practices
Recommendations
📝text
1LEARNING RATE:
2──────────────
3- Start with 1e-4 for Adam/AdamW
4- Use warmup (4000 steps for base model)
5- Monitor for loss spikes (reduce LR if unstable)
6
7BATCH SIZE:
8───────────
9- Use max_tokens instead of fixed batch_size
10- 4096 tokens works well for base model
11- Larger batch → more stable gradients
12- Use gradient accumulation if GPU memory limited
13
14REGULARIZATION:
15───────────────
16- Label smoothing 0.1 (standard)
17- Dropout 0.1 (increase for smaller datasets)
18- Weight decay 0.01
19
20GRADIENT CLIPPING:
21──────────────────
22- Always clip (max_norm=1.0)
23- Prevents gradient explosion
24- Monitor gradient norms
25
26MIXED PRECISION:
27────────────────
28- Use fp16 on GPU (often up to ~2x faster on tensor-core GPUs)
29- Use GradScaler to prevent underflow
30- Some operations stay in fp32 automatically
31
32CHECKPOINTING:
33──────────────
34- Save every epoch (at minimum)
35- Keep best N models
36- Include optimizer state for resuming
37- Save metrics for analysis
38
39
40DEBUGGING CHECKLIST:
41────────────────────
42
43☐ Loss decreasing?
44 → If not: check LR, data, model architecture
45
46☐ Gradients flowing?
47 → Monitor gradient norms
48
49☐ Memory stable?
50 → Watch for memory leaks with long sequences
51
52☐ Training speed reasonable?
53 → Profile if slow
54
55☐ Validation improving?
56 → If train good but val bad: overfitting
57
58
59EXPECTED TIMELINE (Multi30k base model):
60────────────────────────────────────────
61
62Epoch 1: Val loss ~5.0, PPL ~150
63Epoch 5: Val loss ~3.0, PPL ~20
64Epoch 10: Val loss ~2.5, PPL ~12
65Epoch 20: Val loss ~2.3, PPL ~10
66Epoch 30: Val loss ~2.2, PPL ~9
67
68After convergence:
69 BLEU on test: ~30-35
Summary
Training Components
| Component | Purpose |
|---|---|
| Trainer class | Main training logic |
| train_epoch() | Single epoch training |
| validate() | Evaluation without gradients |
| save_checkpoint() | Model persistence |
| EarlyStopping | Prevent overfitting |
Key Settings for Multi30k
🐍python
TrainingConfig(
    # Schedule and optimisation
    num_epochs=30,
    learning_rate=1e-4,
    warmup_steps=4000,
    # Batching (token budget instead of a fixed batch size)
    max_tokens=4096,
    # Regularisation and stability
    label_smoothing=0.1,
    gradient_clip=1.0,
    # Hardware
    mixed_precision=True,
)
Exercises
Implementation
- Add TensorBoard logging for metrics visualization.
- Implement learning rate finder to determine optimal LR.
- Add distributed training support with PyTorch DDP.
Experimentation
- Train with different learning rates and compare convergence.
- Compare early stopping patience values.
Next Section Preview: In the next section, we'll cover Training Monitoring and Debugging—how to diagnose and fix common training issues.