Introduction
This section brings together all components into a complete, production-ready training loop for translation: data loading, forward pass, loss computation, backpropagation, optimization, validation, and logging.
Training Components Overview
Required Components
```text
┌───────────────────────────────────────────────────────────────┐
│                       TRAINING PIPELINE                       │
├───────────────────────────────────────────────────────────────┤
│                                                               │
│  DATA              MODEL             OPTIMIZATION             │
│  ────              ─────             ────────────             │
│  DataLoader        Transformer       Adam Optimizer           │
│  Collator          Encoder           LR Scheduler             │
│  Tokenizer         Decoder           Grad Clipping            │
│                                                               │
│  LOSS              LOGGING           CHECKPOINTING            │
│  ────              ───────           ─────────────            │
│  Label Smoothing   TensorBoard       Save/Load                │
│  Padding Mask      Progress Bar      Best Model               │
│  Metrics           Metrics           Resume                   │
│                                                               │
└───────────────────────────────────────────────────────────────┘
```
Trainer Class
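Every training step in the trainer below relies on teacher forcing: the decoder reads the target sequence without its last token and is scored on predicting the target without its first token. The sketch below shows that alignment for a single sequence stored as a plain Python list; the token IDs and the BOS/EOS values (1 and 2) are hypothetical, chosen only for illustration.

```python
from typing import List, Tuple

BOS, EOS = 1, 2  # hypothetical special-token IDs, used only in this example


def shift_for_teacher_forcing(target: List[int]) -> Tuple[List[int], List[int]]:
    """Split one target sequence into decoder input and next-token labels.

    Mirrors target_ids[:, :-1] (decoder input) and target_ids[:, 1:] (labels).
    """
    decoder_input = target[:-1]  # drop the final token (EOS)
    labels = target[1:]          # drop the first token (BOS)
    return decoder_input, labels


target = [BOS, 45, 78, 12, EOS]
decoder_input, labels = shift_for_teacher_forcing(target)
for position, (seen, predict) in enumerate(zip(decoder_input, labels)):
    print(f"position {position}: decoder sees {seen:>2} -> must predict {predict}")
```

Both halves have length T-1, so the logits computed from the decoder input line up one-to-one with the labels. That is why the model receives target_ids[:, :-1] and why the loss has to be compared against the shifted labels.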
Complete Implementation
```python
import time
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader


class TranslationTrainer:
    """
    Complete trainer for translation models.

    Handles:
    - Training loop with gradient clipping
    - Validation with metrics
    - Learning rate scheduling
    - Checkpointing and resuming
    - Logging (console and TensorBoard)

    Args:
        model: Transformer model
        train_loader: Training data loader
        val_loader: Validation data loader
        optimizer: Optimizer instance
        scheduler: LR scheduler, stepped once per optimizer step
        criterion: Loss function returning a dict with 'loss',
            'num_tokens' and 'accuracy' tensors
        config: Training configuration dict
        device: Training device
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: optim.Optimizer,
        scheduler,
        criterion: nn.Module,
        config: Dict[str, Any],
        device: torch.device,
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.config = config
        self.device = device

        # Training state
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float('inf')

        # Logging
        self.log_interval = config.get('log_interval', 100)
        self.save_dir = Path(config.get('save_dir', 'checkpoints'))
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Optional TensorBoard
        self.writer = None
        if config.get('use_tensorboard', False):
            try:
                from torch.utils.tensorboard import SummaryWriter
                self.writer = SummaryWriter(self.save_dir / 'logs')
            except ImportError:
                print("TensorBoard not available")

    def train_epoch(self) -> Dict[str, float]:
        """
        Train for one epoch.

        Returns:
            Dictionary of training metrics
        """
        self.model.train()

        total_loss = 0.0
        total_tokens = 0
        total_correct = 0.0
        interval_tokens = 0  # tokens processed since the last log line

        epoch_start = time.time()
        interval_start = time.time()

        for batch in self.train_loader:
            # Move to device
            source_ids = batch['source_ids'].to(self.device)
            target_ids = batch['target_ids'].to(self.device)

            # Teacher forcing: the decoder reads target[:, :-1] and is scored
            # against the next tokens target[:, 1:]. If your criterion already
            # shifts the targets internally, pass target_ids instead.
            decoder_input = target_ids[:, :-1]
            labels = target_ids[:, 1:]

            self.optimizer.zero_grad()

            # Forward pass and loss
            logits = self.model(source_ids, decoder_input)
            loss_output = self.criterion(logits, labels)
            loss = loss_output['loss']

            # Backward pass
            loss.backward()

            # Gradient clipping (returns the total norm measured before clipping)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.get('max_grad_norm', 1.0)
            ).item()

            # Optimizer step, then the per-step LR schedule
            self.optimizer.step()
            self.scheduler.step()

            # Update stats, weighting each batch by its token count
            self.global_step += 1
            num_tokens = loss_output['num_tokens'].item()
            total_loss += loss.item() * num_tokens
            total_tokens += num_tokens
            total_correct += loss_output['accuracy'].item() * num_tokens
            interval_tokens += num_tokens

            # Logging
            if self.global_step % self.log_interval == 0:
                # Throughput over all batches since the previous log line
                elapsed = time.time() - interval_start
                tokens_per_sec = interval_tokens / elapsed

                lr = self.optimizer.param_groups[0]['lr']

                print(
                    f"Step {self.global_step:6d} | "
                    f"Loss {loss.item():.4f} | "
                    f"LR {lr:.6f} | "
                    f"Grad {grad_norm:.2f} | "
                    f"Tok/s {tokens_per_sec:.0f}"
                )

                if self.writer:
                    self.writer.add_scalar('train/loss', loss.item(), self.global_step)
                    self.writer.add_scalar('train/lr', lr, self.global_step)
                    self.writer.add_scalar('train/grad_norm', grad_norm, self.global_step)

                interval_tokens = 0
                interval_start = time.time()

        # Epoch stats
        avg_loss = total_loss / total_tokens
        accuracy = total_correct / total_tokens
        epoch_time = time.time() - epoch_start

        return {
            'loss': avg_loss,
            'perplexity': torch.exp(torch.tensor(avg_loss)).item(),
            'accuracy': accuracy,
            'time': epoch_time,
            'tokens_per_sec': total_tokens / epoch_time
        }

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """
        Run validation.

        Returns:
            Dictionary of validation metrics
        """
        self.model.eval()

        total_loss = 0.0
        total_tokens = 0
        total_correct = 0.0

        for batch in self.val_loader:
            source_ids = batch['source_ids'].to(self.device)
            target_ids = batch['target_ids'].to(self.device)

            # Same teacher-forcing shift as in training
            logits = self.model(source_ids, target_ids[:, :-1])
            loss_output = self.criterion(logits, target_ids[:, 1:])

            num_tokens = loss_output['num_tokens'].item()
            total_loss += loss_output['loss'].item() * num_tokens
            total_tokens += num_tokens
            total_correct += loss_output['accuracy'].item() * num_tokens

        avg_loss = total_loss / total_tokens
        accuracy = total_correct / total_tokens

        metrics = {
            'loss': avg_loss,
            'perplexity': torch.exp(torch.tensor(avg_loss)).item(),
            'accuracy': accuracy
        }

        if self.writer:
            self.writer.add_scalar('val/loss', avg_loss, self.global_step)
            self.writer.add_scalar('val/perplexity', metrics['perplexity'], self.global_step)

        return metrics

    def save_checkpoint(self, path: Optional[str] = None, is_best: bool = False):
        """Save training checkpoint (called after an epoch completes)."""
        checkpoint = {
            'epoch': self.epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.config
        }

        if path is None:
            path = self.save_dir / f'checkpoint_epoch_{self.epoch}.pt'

        torch.save(checkpoint, path)

        if is_best:
            best_path = self.save_dir / 'best_model.pt'
            torch.save(checkpoint, best_path)
            print(f"Saved best model to {best_path}")

    def load_checkpoint(self, path: str):
        """Load training checkpoint and prepare to resume at the next epoch."""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # The saved epoch has already finished, so continue with the next one
        self.epoch = checkpoint['epoch'] + 1
        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']

        print(
            f"Loaded checkpoint saved after epoch {checkpoint['epoch'] + 1}; "
            f"resuming at epoch {self.epoch + 1}"
        )

    def train(self, num_epochs: int, resume_from: Optional[str] = None):
        """
        Full training loop.

        Args:
            num_epochs: Number of epochs to train
            resume_from: Path to checkpoint to resume from
        """
        if resume_from:
            self.load_checkpoint(resume_from)

        print(f"Training for {num_epochs} epochs")
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.val_loader.dataset)}")
        print(f"Batch size: {self.train_loader.batch_size}")
        print(f"Device: {self.device}")
        print("=" * 60)

        for epoch in range(self.epoch, num_epochs):
            self.epoch = epoch

            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 40)

            # Train
            train_metrics = self.train_epoch()
            print(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"PPL: {train_metrics['perplexity']:.2f}, "
                f"Acc: {train_metrics['accuracy']:.4f}"
            )

            # Validate
            val_metrics = self.validate()
            print(
                f"Val   - Loss: {val_metrics['loss']:.4f}, "
                f"PPL: {val_metrics['perplexity']:.2f}, "
                f"Acc: {val_metrics['accuracy']:.4f}"
            )

            # Check if best
            is_best = val_metrics['loss'] < self.best_val_loss
            if is_best:
                self.best_val_loss = val_metrics['loss']
                print(f"New best validation loss: {self.best_val_loss:.4f}")

            # Save checkpoint
            self.save_checkpoint(is_best=is_best)

        print("\nTraining complete!")
        print(f"Best validation loss: {self.best_val_loss:.4f}")

        if self.writer:
            self.writer.close()
```
Training Configuration
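The configuration below sets learning_rate to 1.0 alongside warmup_steps, which only makes sense if the scheduler multiplies the base rate by a warmup-then-decay factor. A common choice for this setup is the schedule from "Attention Is All You Need" (often called the Noam schedule); the sketch below assumes that schedule, which this section does not spell out, and computes the factor in plain Python.

```python
import math


def noam_factor(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Linear warmup, then inverse-square-root decay (assumed schedule).

    With torch.optim.lr_scheduler.LambdaLR and a base LR of 1.0, the
    effective learning rate at each step equals this factor.
    """
    step = max(step, 1)  # LambdaLR evaluates step 0 first; avoid 0 ** -0.5
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000, 100000):
    print(f"step {step:>6}: lr = {1.0 * noam_factor(step):.2e}")

# Wiring it up (requires PyTorch; shown for reference, not executed here):
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_factor)
```

Under this assumption the peak is reached exactly at warmup_steps, at about 7.0e-4 for d_model=512 and warmup_steps=4000. Because the trainer calls scheduler.step() once per batch, "step" here means optimizer steps, not epochs.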
Config Class
```python
from dataclasses import asdict, dataclass


@dataclass
class TrainingConfig:
    """
    Complete training configuration.
    """
    # Model
    src_vocab_size: int = 32000
    tgt_vocab_size: int = 32000
    d_model: int = 512
    num_heads: int = 8
    num_layers: int = 6
    d_ff: int = 2048
    dropout: float = 0.1

    # Data
    data_dir: str = "data/multi30k"
    max_source_len: int = 128
    max_target_len: int = 128
    batch_size: int = 32
    num_workers: int = 4

    # Training
    num_epochs: int = 30
    warmup_steps: int = 4000
    label_smoothing: float = 0.1
    max_grad_norm: float = 1.0

    # Optimizer
    learning_rate: float = 1.0  # scaled by the warmup LR schedule; not the absolute peak LR
    adam_beta1: float = 0.9
    adam_beta2: float = 0.98
    adam_eps: float = 1e-9
    weight_decay: float = 0.0

    # Logging
    log_interval: int = 100
    save_dir: str = "checkpoints"
    use_tensorboard: bool = True

    # Device
    device: str = "cuda"

    def to_dict(self) -> dict:
        """Convert to a plain dictionary (e.g. for TranslationTrainer's config argument)."""
        return asdict(self)
```
Training Script
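The script below copies each command-line argument onto the config by hand. A slightly more compact pattern is dataclasses.replace, which returns a new config with only the passed fields overridden. The sketch uses a hypothetical, cut-down stand-in for TrainingConfig (three fields) so that it runs on its own.

```python
import argparse
from dataclasses import asdict, dataclass, replace


@dataclass
class MiniConfig:
    """Hypothetical, cut-down stand-in for TrainingConfig."""
    data_dir: str = "data/multi30k"
    save_dir: str = "checkpoints"
    num_epochs: int = 30


def config_from_args(argv=None) -> MiniConfig:
    parser = argparse.ArgumentParser(description="Train translation model")
    parser.add_argument("--data-dir", type=str, default=None)
    parser.add_argument("--save-dir", type=str, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    args = parser.parse_args(argv)

    # Map CLI names onto config fields; flags that were not passed stay at the dataclass defaults
    overrides = {
        "data_dir": args.data_dir,
        "save_dir": args.save_dir,
        "num_epochs": args.epochs,
    }
    overrides = {key: value for key, value in overrides.items() if value is not None}
    return replace(MiniConfig(), **overrides)


config = config_from_args(["--epochs", "5", "--save-dir", "runs/exp1"])
print(asdict(config))
```

Defaulting every flag to None keeps the dataclass as the single source of truth for default values, so the CLI and the config cannot drift apart.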
Main Script Structure
```python
import argparse

import torch

# Assumes TrainingConfig (defined above) is available in this module.


def main():
    """
    Main training script.
    """
    parser = argparse.ArgumentParser(description='Train translation model')
    parser.add_argument('--config', type=str, default='base')  # preset name (unused in this skeleton)
    parser.add_argument('--data-dir', type=str, default='data/multi30k')
    parser.add_argument('--save-dir', type=str, default='checkpoints')
    parser.add_argument('--resume', type=str, default=None)  # checkpoint path for resuming
    parser.add_argument('--epochs', type=int, default=30)
    args = parser.parse_args()

    # Load config
    config = TrainingConfig()
    config.data_dir = args.data_dir
    config.save_dir = args.save_dir
    config.num_epochs = args.epochs

    # Set device (fall back to CPU when CUDA is unavailable)
    device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Training steps would go here:
    # 1. Load tokenizer
    # 2. Create data module
    # 3. Create model
    # 4. Create optimizer and scheduler
    # 5. Create criterion
    # 6. Create trainer
    # 7. Train, e.g. trainer.train(config.num_epochs, resume_from=args.resume)

    print("Training script structure shown above.")
    print("Run with: python train.py --config base --epochs 30")


if __name__ == '__main__':
    main()
```
Monitoring Training
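Both train_epoch and validate multiply each batch's mean loss by its token count before averaging, then report perplexity as exp of that average. The sketch below reproduces that arithmetic in plain Python with made-up batch numbers, to show how a naive mean over batches would differ when batches contain different numbers of tokens.

```python
import math
from typing import Iterable, Tuple


def token_weighted_metrics(batches: Iterable[Tuple[float, int]]) -> dict:
    """Aggregate (mean_loss_per_token, num_tokens) pairs the way the trainer does."""
    total_loss = 0.0
    total_tokens = 0
    for batch_loss, num_tokens in batches:
        total_loss += batch_loss * num_tokens  # undo the per-batch mean
        total_tokens += num_tokens
    avg_loss = total_loss / total_tokens
    return {"loss": avg_loss, "perplexity": math.exp(avg_loss)}


# Made-up numbers: a large batch with low loss and a small batch with high loss
batches = [(2.0, 900), (4.0, 100)]
weighted = token_weighted_metrics(batches)
naive_mean = sum(loss for loss, _ in batches) / len(batches)
print(f"token-weighted loss {weighted['loss']:.2f}, PPL {weighted['perplexity']:.2f}")
print(f"naive batch mean    {naive_mean:.2f}, PPL {math.exp(naive_mean):.2f}")
```

When the criterion applies label smoothing, its loss is typically higher than the plain cross-entropy of the model's predictions, so exp(loss) overstates the true perplexity. Treat it as a relative signal across runs that use the same smoothing value.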
Key Metrics
```text
KEY METRICS TO WATCH:
─────────────────────

1. LOSS
   - Should decrease steadily
   - Validation loss should track training loss
   - A large, growing gap suggests overfitting

2. PERPLEXITY
   - exp(cross-entropy loss)
   - Interpretable as the "average branching factor"
   - Rough guide for a good translation model: PPL < 10

3. LEARNING RATE
   - Should follow the expected schedule
   - Should peak at the end of warmup (around warmup_steps)

4. GRADIENT NORM
   - Should be stable (not exploding)
   - Clipping on most steps often means the LR is too high

5. TOKEN ACCURACY
   - Percentage of correctly predicted target tokens
   - Should increase over training

6. TOKENS PER SECOND
   - Training throughput
   - Should stay roughly constant across steps


WARNING SIGNS:
──────────────

✗ Loss not decreasing
  → LR too low or too high

✗ Loss exploding (NaN)
  → LR too high, gradients exploding
  → Reduce LR, increase warmup

✗ Val loss increasing while train loss decreases
  → Overfitting
  → Increase dropout, use more data

✗ Gradient norm very high
  → Clip more aggressively (lower max_grad_norm)
  → Reduce learning rate

✗ Tokens/sec very low
  → Check data loading (num_workers)
  → Check batch size


TENSORBOARD COMMANDS:
─────────────────────

# Launch TensorBoard
tensorboard --logdir=checkpoints/logs

# View in browser
http://localhost:6006
```
Common Training Issues
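Two of the warning signs listed above, a NaN loss and a sudden loss explosion, can be caught automatically rather than by watching the logs. The sketch below is a hypothetical helper, not part of TranslationTrainer: it flags non-finite losses and losses far above a running average. The window size and spike factor are arbitrary values chosen for illustration.

```python
import math
from collections import deque
from typing import Optional


class LossWatchdog:
    """Hypothetical helper that flags non-finite or spiking training losses."""

    def __init__(self, window: int = 100, spike_factor: float = 3.0):
        self.recent = deque(maxlen=window)  # last `window` finite losses
        self.spike_factor = spike_factor

    def check(self, loss: float) -> Optional[str]:
        """Return a warning message, or None if the loss looks healthy."""
        if not math.isfinite(loss):
            return "non-finite loss: reduce the LR or lengthen warmup"
        message = None
        if self.recent:
            running_mean = sum(self.recent) / len(self.recent)
            if loss > self.spike_factor * running_mean:
                message = f"loss spike: {loss:.3f} vs running mean {running_mean:.3f}"
        self.recent.append(loss)
        return message


watchdog = LossWatchdog(window=5)
for step, loss in enumerate([4.1, 3.9, 3.8, 3.7, 15.0, float("nan")]):
    warning = watchdog.check(loss)
    if warning:
        print(f"step {step}: {warning}")
```

In train_epoch, such a check would sit right after the loss is computed, e.g. `watchdog.check(loss.item())`, so a diverging run can be stopped before it overwrites a good checkpoint.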
Troubleshooting
Problem: Loss doesn't decrease
Possible causes:
- Learning rate too low or too high
- Model too small for task
- Data preprocessing error
Solutions:
- Try different peak learning rates (commonly 1e-5 to 1e-3 for Adam)
- Check data pipeline, inspect batches
- Verify tokenization is correct
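A quick way to "inspect batches" is to print summary statistics for a few batches before training starts. The sketch below works on plain lists of token IDs; the pad ID of 0 and the example batch are assumptions, and with the real data loader you would pass something like batch['target_ids'].tolist().

```python
from typing import List

PAD_ID = 0  # assumed padding ID; use your tokenizer's actual value


def describe_batch(sequences: List[List[int]], pad_id: int = PAD_ID) -> dict:
    """Summarize a padded batch of token-ID sequences."""
    lengths = [sum(1 for tok in seq if tok != pad_id) for seq in sequences]
    total_slots = sum(len(seq) for seq in sequences)
    return {
        "batch_size": len(sequences),
        "padded_len": max(len(seq) for seq in sequences),
        "min_len": min(lengths),
        "max_len": max(lengths),
        "pad_fraction": 1 - sum(lengths) / total_slots,
    }


batch = [
    [1, 57, 902, 33, 2, 0, 0, 0],
    [1, 14, 2, 0, 0, 0, 0, 0],
    [1, 8, 61, 77, 405, 19, 3, 2],
]
print(describe_batch(batch))
```

Empty sequences (min_len of 0) or a pad fraction close to 1 usually point to a collator or tokenizer problem. Decoding a few sequences back to text and comparing them with the raw data catches most remaining tokenization errors.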
Problem: NaN loss
Possible causes:
- Learning rate too high
- Gradient explosion
- Numerical instability
Solutions:
- Reduce learning rate
- Increase warmup steps
- Add gradient clipping (max_norm=1.0)
- If using FP16 mixed precision, scale the loss with a GradScaler (or train in BF16)
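The trainer clips with torch.nn.utils.clip_grad_norm_, which treats all parameter gradients as one long vector and rescales them when that vector's L2 norm exceeds max_norm. The pure-Python sketch below reproduces that arithmetic on made-up gradient values; it mirrors the small epsilon in the scaling coefficient but omits other details of the real function, such as its handling of non-finite norms.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float = 1.0):
    """Rescale gradients (flat lists, one per parameter) by their global L2 norm.

    Returns the clipped gradients and the norm measured before clipping,
    which is the value the trainer logs as 'Grad'.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = min(1.0, max_norm / (total_norm + 1e-6))  # never scales gradients up
    clipped = [[g * clip_coef for g in grad] for grad in grads]
    return clipped, total_norm


# Two made-up parameter gradients whose combined norm is 5.0 (a 3-4-5 triangle)
grads = [[3.0, 0.0], [0.0, 4.0]]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for grad in clipped for g in grad))
print(f"norm before {norm_before:.2f}, after {norm_after:.4f}")
```

Clipping shrinks the update's length but keeps its direction. If the norm before clipping exceeds max_norm on most steps, that is the "frequent clipping" symptom from the metrics list.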
Problem: Overfitting (train << val loss)
Possible causes:
- Model too large for data
- Not enough regularization
- Training too long
Solutions:
- Increase dropout
- Use label smoothing
- Add weight decay
- Early stopping
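Early stopping needs only a few lines of state. The sketch below is a hypothetical helper, not part of TranslationTrainer: it signals a stop once validation loss has failed to improve by at least min_delta for patience consecutive evaluations. The patience and min_delta values are illustrative.

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop when validation loss stops improving."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience      # evaluations to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best = float("inf")
        self.bad_evals = 0

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


stopper = EarlyStopping(patience=2, min_delta=0.01)
for epoch, val_loss in enumerate([3.2, 2.9, 2.895, 2.91, 2.7]):
    if stopper.step(val_loss):
        print(f"stopping after epoch {epoch}, best val loss {stopper.best:.3f}")
        break
```

In TranslationTrainer.train, one could call stopper.step(val_metrics['loss']) after validation and break out of the epoch loop when it returns True; the best weights are still on disk as best_model.pt.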
Problem: Training very slow
Possible causes:
- Data loading bottleneck
- Small batch size
- Not using GPU
Solutions:
- Increase num_workers
- Use pin_memory=True
- Dynamic batching
- Mixed precision training
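Dynamic batching groups sentences of similar length and caps each batch by token count instead of sentence count, so less compute goes to padding. The sketch below assumes the budget counts padded tokens (examples × longest example in the batch), which is one common convention; it returns lists of example indices that could be passed to a DataLoader through its batch_sampler argument.

```python
from typing import List


def token_budget_batches(lengths: List[int], max_tokens: int) -> List[List[int]]:
    """Group example indices so each batch's padded size fits within max_tokens.

    Padded size = (number of examples) * (longest example in the batch).
    An example longer than max_tokens on its own still gets a batch of its own.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])  # similar lengths together
    batches, current, current_max = [], [], 0
    for idx in order:
        new_max = max(current_max, lengths[idx])
        if current and new_max * (len(current) + 1) > max_tokens:
            batches.append(current)  # adding idx would exceed the budget
            current, new_max = [], lengths[idx]
        current.append(idx)
        current_max = new_max
    if current:
        batches.append(current)
    return batches


lengths = [12, 30, 7, 25, 9, 31, 8, 14]
for batch in token_budget_batches(lengths, max_tokens=64):
    print(batch, [lengths[i] for i in batch])
```

Fully sorting by length makes batch composition identical every epoch; in practice the batches are usually shuffled after being formed. Note also that a DataLoader built with batch_sampler reports batch_size as None, so the "Batch size" line printed by TranslationTrainer.train would show None.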
Summary
Training Pipeline Components
| Component | Purpose |
|---|---|
| DataLoader | Load and batch data |
| Model | Forward pass |
| Criterion | Compute loss |
| Optimizer | Update weights |
| Scheduler | Adjust learning rate |
| Gradient clipping | Stabilize training |
| Checkpointing | Save/resume training |
Training Loop Steps
- Load batch
- Move to device
- Zero gradients
- Forward pass
- Compute loss
- Backward pass
- Clip gradients
- Optimizer step
- Scheduler step
- Log metrics
- Validate periodically
- Save checkpoints
Exercises
Implementation
- Add early stopping based on validation loss.
- Implement mixed precision training (FP16).
- Add distributed training support.
Analysis
- Profile training to find bottlenecks.
- Compare different optimizers (Adam vs AdamW vs SGD).
In the final section, we'll cover Checkpointing and Model Selection: best practices for saving models, selecting the best checkpoint, and deploying trained models.