Introduction
This section covers how to use the trained model for translation. We'll implement efficient inference with beam search and create a complete translation pipeline.
Loading the Trained Model
Model Loading Utilities
python
1import torch
2import torch.nn as nn
3from typing import Dict, Any, Optional, List
4from pathlib import Path
5import json
6
7
class TranslationModel:
    """
    Wrapper for loading and using a trained translation model.

    Handles:
    - Loading model and tokenizer
    - Preprocessing input
    - Generating translations (beam search)
    - Postprocessing output

    Args:
        checkpoint_path: Path to model checkpoint
        tokenizer_path: Path to tokenizer
        device: Device to run on. A CUDA request falls back to CPU when
            CUDA is not available; any other device is used as given.

    Example:
        >>> model = TranslationModel('best_model.pt', 'tokenizer.json')
        >>> translation = model.translate("Der Hund läuft.")
    """

    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        device: str = "cuda"
    ):
        # Only downgrade to CPU when CUDA was actually requested; an explicit
        # non-CUDA device must not be overridden by the CUDA availability check.
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            requested = torch.device("cpu")
        self.device = requested

        # Load tokenizer
        self.tokenizer = self._load_tokenizer(tokenizer_path)

        # Load model and switch to inference mode
        self.model, self.config = self._load_model(checkpoint_path)
        self.model.eval()

        print(f"Model loaded on {self.device}")
        print(f"Vocabulary size: {len(self.tokenizer.token_to_id)}")

    def _load_tokenizer(self, path: str):
        """Load the JointBPETokenizer stored at ``path``."""
        return JointBPETokenizer.load(path)

    def _load_model(self, checkpoint_path: str):
        """
        Load model weights and configuration from a checkpoint.

        Returns:
            Tuple ``(model, config)`` with the model moved to ``self.device``.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Use the stored config when present, otherwise the default config
        if 'model_config' in checkpoint:
            config = ModelConfig(**checkpoint['model_config'])
        else:
            config = ModelConfig()

        # Build model and restore weights
        model = build_model(config)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(self.device)

        return model, config

    @torch.no_grad()
    def translate(
        self,
        text: str,
        beam_size: int = 5,
        max_length: int = 128,
        length_penalty: float = 1.0
    ) -> str:
        """
        Translate a single sentence.

        Args:
            text: German text to translate
            beam_size: Beam search width
            max_length: Maximum output length
            length_penalty: Length normalization factor

        Returns:
            English translation. An input that tokenizes to nothing yields
            an empty string.
        """
        # Preprocess
        source_ids = self.tokenizer.encode(text, add_special_tokens=False)
        if not source_ids:
            # An empty id list would become a float tensor of shape (1, 0),
            # which is not a valid model input.
            return ""
        source_tensor = torch.tensor(
            [source_ids],
            dtype=torch.long,
            device=self.device
        )

        # Generate
        output_ids = self._generate(
            source_tensor,
            beam_size=beam_size,
            max_length=max_length,
            length_penalty=length_penalty
        )

        # Decode
        return self.tokenizer.decode(output_ids)

    @staticmethod
    def _length_normalized(
        score: float,
        length: int,
        length_penalty: float
    ) -> float:
        """Divide a cumulative log-probability by ((5 + length) / 6) ** length_penalty."""
        return score / (((5 + length) / 6) ** length_penalty)

    def _generate(
        self,
        source: torch.Tensor,
        beam_size: int,
        max_length: int,
        length_penalty: float
    ) -> List[int]:
        """
        Generate a translation using beam search.

        Each hypothesis carries its raw cumulative log-probability. The
        length penalty is applied only when hypotheses are ranked, so it is
        never compounded from one step to the next.

        Returns:
            Token ids of the best hypothesis, without BOS and without a
            trailing EOS.

        Raises:
            ValueError: If ``beam_size`` is smaller than 1.
        """
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")

        # Encode source once; it is reused at every decoding step
        encoder_output = self.model.encode(source)

        bos_id = self.tokenizer.bos_id
        eos_id = self.tokenizer.eos_id

        def rank(hypothesis):
            tokens, score = hypothesis
            return self._length_normalized(score, len(tokens), length_penalty)

        beams = [([], 0.0)]  # (tokens, raw cumulative log-probability)

        for _ in range(max_length):
            candidates = []

            for tokens, score in beams:
                # Finished hypotheses are carried over unchanged
                if tokens and tokens[-1] == eos_id:
                    candidates.append((tokens, score))
                    continue

                # Decoder input is BOS followed by everything generated so far
                decoder_tensor = torch.tensor(
                    [[bos_id] + tokens],
                    dtype=torch.long,
                    device=self.device
                )
                logits = self.model.decode(decoder_tensor, encoder_output)

                # Log-probabilities for the last position
                log_probs = torch.log_softmax(logits[0, -1], dim=-1)

                # topk raises when k exceeds the vocabulary size
                k = min(beam_size, log_probs.size(-1))
                top_log_probs, top_ids = log_probs.topk(k)

                for log_prob, token_id in zip(
                    top_log_probs.tolist(), top_ids.tolist()
                ):
                    candidates.append((tokens + [token_id], score + log_prob))

            # Keep the best hypotheses by length-normalized score
            candidates.sort(key=rank, reverse=True)
            beams = candidates[:beam_size]

            # Stop once every surviving hypothesis has emitted EOS
            if all(tokens and tokens[-1] == eos_id for tokens, _ in beams):
                break

        # Return best beam (without EOS)
        best_tokens = beams[0][0]
        if best_tokens and best_tokens[-1] == eos_id:
            best_tokens = best_tokens[:-1]

        return best_tokens

    @torch.no_grad()
    def translate_batch(
        self,
        texts: List[str],
        beam_size: int = 5,
        max_length: int = 128,
        length_penalty: float = 1.0
    ) -> List[str]:
        """
        Translate multiple sentences, one at a time.

        Args:
            texts: List of German sentences
            beam_size: Beam search width
            max_length: Maximum output length
            length_penalty: Length normalization factor

        Returns:
            List of English translations, in input order
        """
        return [
            self.translate(
                text,
                beam_size=beam_size,
                max_length=max_length,
                length_penalty=length_penalty
            )
            for text in texts
        ]
208
209
def demonstrate_loading():
    """
    Print a usage example for TranslationModel.

    Only prints text; no model is loaded.
    """
    print("Model Loading Demonstration")
    print("=" * 60)

    print("""
    USAGE:
    ──────

    # Load model
    model = TranslationModel(
        checkpoint_path='checkpoints/best_model.pt',
        tokenizer_path='data/tokenizer/tokenizer.json',
        device='cuda'
    )

    # Single translation
    german = "Der Hund läuft im Park."
    english = model.translate(german)
    print(f"DE: {german}")
    print(f"EN: {english}")

    # Batch translation
    sentences = [
        "Die Katze schläft auf dem Sofa.",
        "Ein Mann liest ein Buch.",
        "Kinder spielen im Garten.",
    ]
    translations = model.translate_batch(sentences)
    """)
242
243
demonstrate_loading()
Optimized Inference
KV Caching for Speed
python
class FastTranslationModel(TranslationModel):
    """
    Translation model with KV caching for faster inference.

    Uses cached key-value pairs to avoid recomputing
    attention for previous tokens.
    """

    @torch.no_grad()
    def translate_with_cache(
        self,
        text: str,
        beam_size: int = 1,  # Greedy for simplicity with cache
        max_length: int = 128
    ) -> str:
        """
        Translate using greedy decoding with KV caching.

        Args:
            text: Text to translate
            beam_size: Kept for signature compatibility; decoding is always
                greedy, and a warning is issued for any value other than 1.
            max_length: Maximum number of generated tokens

        Returns:
            The decoded translation. An input that tokenizes to nothing
            yields an empty string.
        """
        if beam_size != 1:
            import warnings

            # The parameter used to be ignored silently
            warnings.warn(
                "translate_with_cache only supports greedy decoding; "
                f"beam_size={beam_size} is ignored"
            )

        # Preprocess
        source_ids = self.tokenizer.encode(text, add_special_tokens=False)
        if not source_ids:
            # An empty id list would become a float tensor of shape (1, 0)
            return ""
        source_tensor = torch.tensor(
            [source_ids],
            dtype=torch.long,
            device=self.device
        )

        # Encode source (once)
        encoder_output = self.model.encode(source_tensor)

        # Generate with caching
        eos_id = self.tokenizer.eos_id
        generated = [self.tokenizer.bos_id]
        cache = None

        for _ in range(max_length):
            # The first step feeds the whole prefix; once a cache exists,
            # only the newest token is passed.
            step_tokens = generated if cache is None else generated[-1:]
            decoder_input = torch.tensor(
                [step_tokens],
                dtype=torch.long,
                device=self.device
            )

            # Forward with cache
            logits, cache = self.model.decode_with_cache(
                decoder_input,
                encoder_output,
                cache=cache
            )

            # Get next token (greedy)
            next_token = logits[0, -1].argmax().item()
            if next_token == eos_id:
                break

            generated.append(next_token)

        # Decode (skip BOS)
        return self.tokenizer.decode(generated[1:])
64
65
def compare_inference_speed():
    """
    Print a static comparison of decoding with and without KV caching.

    Only prints text; nothing is measured.
    """
    print("Inference Speed Comparison")
    print("=" * 60)

    print("""
    METHOD COMPARISON:
    ──────────────────

    1. WITHOUT CACHING:
       - Recomputes all previous tokens each step
       - Time complexity: O(n²) where n = output length
       - Memory: Lower (no cache)

    2. WITH KV CACHING:
       - Caches key-value pairs from previous steps
       - Time complexity: O(n)
       - Memory: Higher (stores cache)

    TYPICAL SPEEDUP:
    ────────────────

    Output length 10: ~2x faster with cache
    Output length 20: ~5x faster
    Output length 50: ~10x faster
    Output length 100: ~20x faster

    RECOMMENDATION:
    ───────────────

    - Short outputs (<20 tokens): Either method
    - Long outputs (>20 tokens): Use caching
    - Batch processing: Consider parallel decoding
    """)
102
103
compare_inference_speed()
Complete Inference Script
Ready-to-Run Script
python
1#!/usr/bin/env python3
2"""
3translate.py - Translate German text to English
4
5Usage:
6 python translate.py "Der Hund läuft im Park."
7 python translate.py --file input.txt --output translations.txt
8 python translate.py --interactive
9"""
10
11import argparse
12import torch
13from pathlib import Path
14import sys
15
16
def parse_args(argv=None):
    """
    Parse command-line options for the translation script.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            the process arguments.

    Returns:
        The argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="German-English Translation")

    # Input modes
    parser.add_argument("text", nargs="?", help="Text to translate")
    parser.add_argument("--file", "-f", help="Input file (one sentence per line)")
    parser.add_argument("--output", "-o", help="Output file")
    parser.add_argument("--interactive", "-i", action="store_true",
                        help="Interactive mode")

    # Model settings
    parser.add_argument("--checkpoint", default="checkpoints/best_model.pt")
    parser.add_argument("--tokenizer", default="data/tokenizer/tokenizer.json")

    # Generation settings
    parser.add_argument("--beam-size", type=int, default=5)
    parser.add_argument("--max-length", type=int, default=128)
    parser.add_argument("--length-penalty", type=float, default=1.0)

    # Device
    parser.add_argument("--device", default="cuda")

    return parser.parse_args(argv)
40
41
def main():
    """
    Entry point of the translation script.

    Loads the model, then dispatches on the parsed options: interactive
    mode first, then file mode, then a single positional sentence. Exits
    with status 1 when no input was given.
    """
    args = parse_args()

    # Load model
    print("Loading model...")
    model = TranslationModel(
        checkpoint_path=args.checkpoint,
        tokenizer_path=args.tokenizer,
        device=args.device
    )
    # A real newline escape; the doubled backslash printed a literal "\n"
    print("Model loaded!\n")

    if args.interactive:
        # Interactive mode
        interactive_mode(model, args)

    elif args.file:
        # File mode
        translate_file(model, args)

    elif args.text:
        # Single sentence
        translation = model.translate(
            args.text,
            beam_size=args.beam_size,
            max_length=args.max_length,
            length_penalty=args.length_penalty
        )
        print(f"German: {args.text}")
        print(f"English: {translation}")

    else:
        print("Please provide text to translate or use --interactive mode")
        sys.exit(1)
76
77
def interactive_mode(model, args):
    """
    Interactive translation loop.

    Reads sentences from standard input until the user types 'quit',
    'exit' or 'q', presses Ctrl-C, or input ends (EOF). Blank lines are
    skipped.

    Args:
        model: Object providing ``translate(text, beam_size=, max_length=,
            length_penalty=)``
        args: Parsed options; reads ``beam_size``, ``max_length`` and, when
            present, ``length_penalty`` (defaults to 1.0)
    """
    print("=" * 50)
    print("German-English Translation (Interactive Mode)")
    print("Type 'quit' or 'exit' to stop")
    print("=" * 50)
    print()

    # Falls back to 1.0 so callers passing an args object without this option still work
    length_penalty = getattr(args, "length_penalty", 1.0)

    while True:
        try:
            text = input("DE > ").strip()

            if text.lower() in ('quit', 'exit', 'q'):
                print("Goodbye!")
                break

            if not text:
                continue

            translation = model.translate(
                text,
                beam_size=args.beam_size,
                max_length=args.max_length,
                length_penalty=length_penalty
            )

            print(f"EN > {translation}")
            print()

        except (KeyboardInterrupt, EOFError):
            # Ctrl-C, or end of input (Ctrl-D / piped stdin), ends the session
            print("\nGoodbye!")
            break
111
112
def translate_file(model, args):
    """
    Translate every non-blank line of a text file.

    Writes one translation per line to ``args.output`` when it is set;
    otherwise prints source/translation pairs. Exits with status 1 when
    the input file does not exist.

    Args:
        model: Object providing ``translate(text, beam_size=, max_length=,
            length_penalty=)``
        args: Parsed options; reads ``file``, ``output``, ``beam_size``,
            ``max_length`` and, when present, ``length_penalty``
            (defaults to 1.0)
    """
    input_path = Path(args.file)

    if not input_path.exists():
        print(f"File not found: {args.file}")
        sys.exit(1)

    # Read input, dropping blank lines
    with open(input_path, 'r', encoding='utf-8') as f:
        sentences = [line.strip() for line in f if line.strip()]

    print(f"Translating {len(sentences)} sentences...")

    # Falls back to 1.0 so callers passing an args object without this option still work
    length_penalty = getattr(args, "length_penalty", 1.0)

    # Translate
    translations = []
    for done, sentence in enumerate(sentences, start=1):
        translations.append(
            model.translate(
                sentence,
                beam_size=args.beam_size,
                max_length=args.max_length,
                length_penalty=length_penalty
            )
        )

        if done % 100 == 0:
            print(f"  Processed {done}/{len(sentences)}")

    # Output
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            # A real newline; the doubled backslash wrote a literal "\n"
            # and put every translation on a single line.
            f.writelines(f"{trans}\n" for trans in translations)
        print(f"Translations saved to {args.output}")
    else:
        for src, tgt in zip(sentences, translations):
            print(f"DE: {src}")
            print(f"EN: {tgt}")
            print()
153
154
if __name__ == "__main__":
    main()
Evaluation Script
Test Set Evaluation
python
1#!/usr/bin/env python3
2"""
3evaluate.py - Evaluate translation model on test set
4
5Usage:
6 python evaluate.py --checkpoint best_model.pt
7"""
8
9import argparse
10import torch
11from pathlib import Path
12import json
13
14
def main():
    """
    Evaluate the translation model on a test set.

    Translates every source line, scores the hypotheses against the
    references with BLEU and ChrF, prints the scores, writes them to a
    JSON file, and shows the first few translations.

    Raises:
        SystemExit: If the source and reference files differ in line count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer", default="data/tokenizer/tokenizer.json")
    parser.add_argument("--source", default="data/multi30k/test_2016_flickr.de")
    parser.add_argument("--reference", default="data/multi30k/test_2016_flickr.en")
    parser.add_argument("--output", default="evaluation_results.json")
    parser.add_argument("--beam-size", type=int, default=5)
    args = parser.parse_args()

    # Load model
    print("Loading model...")
    model = TranslationModel(
        checkpoint_path=args.checkpoint,
        tokenizer_path=args.tokenizer
    )

    # Load test data. Explicit UTF-8 keeps non-ASCII text intact regardless
    # of the platform's default encoding.
    print("Loading test data...")
    with open(args.source, 'r', encoding='utf-8') as f:
        sources = [line.strip() for line in f]
    with open(args.reference, 'r', encoding='utf-8') as f:
        references = [line.strip() for line in f]

    # zip() below would silently drop the unmatched tail of the longer file
    if len(sources) != len(references):
        raise SystemExit(
            f"Line count mismatch: {len(sources)} source lines vs "
            f"{len(references)} reference lines"
        )

    print(f"Test sentences: {len(sources)}")

    # Translate
    print("\nTranslating...")
    hypotheses = []

    for done, source in enumerate(sources, start=1):
        hypotheses.append(model.translate(source, beam_size=args.beam_size))

        if done % 100 == 0:
            print(f"  {done}/{len(sources)}")

    # Compute metrics
    print("\nComputing metrics...")

    # BLEU
    bleu_scorer = BLEUScore()
    for hyp, ref in zip(hypotheses, references):
        bleu_scorer.add(hyp, [ref])
    bleu_result = bleu_scorer.compute()

    # ChrF
    chrf_scorer = ChrFScore()
    for hyp, ref in zip(hypotheses, references):
        chrf_scorer.add(hyp, ref)
    chrf_result = chrf_scorer.corpus_score()

    # Print results (scores are scaled by 100 for display)
    print("\n" + "=" * 50)
    print("EVALUATION RESULTS")
    print("=" * 50)
    print(f"BLEU: {bleu_result['bleu'] * 100:.2f}")
    print(f"ChrF: {chrf_result['chrf'] * 100:.2f}")
    print("=" * 50)

    # Save results
    results = {
        'bleu': bleu_result['bleu'] * 100,
        'chrf': chrf_result['chrf'] * 100,
        'precisions': [p * 100 for p in bleu_result['precisions']],
        'num_sentences': len(sources),
    }

    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to {args.output}")

    # Show examples
    print("\nSample translations:")
    print("-" * 50)

    for i in range(min(5, len(sources))):
        print(f"\n[{i+1}]")
        print(f"  Source: {sources[i]}")
        print(f"  Reference: {references[i]}")
        print(f"  Hypothesis: {hypotheses[i]}")
97
98
if __name__ == "__main__":
    main()
Summary
Inference Components
| Component | Purpose |
|---|---|
| TranslationModel | Load and use trained model |
| translate() | Single sentence translation |
| translate_batch() | Multiple sentences |
| KV caching | Faster inference |
Expected Results
After training on Multi30k:
- BLEU: 30-35 on test set
- ChrF: ~55-60
- Speed: ~10-20 sentences/second (with GPU)
Next Section Preview
In the next section, we'll create an Interactive Demo with examples and analysis.