Chapter 14
12 min read
Section 66 of 75

Inference Pipeline

Inference and Demo

Introduction

This section covers how to use the trained model for translation. We'll implement efficient inference with beam search and create a complete translation pipeline.


Loading the Trained Model

Model Loading Utilities

python
1import torch
2import torch.nn as nn
3from typing import Dict, Any, Optional, List
4from pathlib import Path
5import json
6
7
class TranslationModel:
    """
    Wrapper for loading and using a trained translation model.

    Handles:
    - Loading model and tokenizer
    - Preprocessing input
    - Generating translations with beam search
    - Postprocessing output

    Args:
        checkpoint_path: Path to model checkpoint
        tokenizer_path: Path to tokenizer
        device: Device to run on. A CUDA device falls back to CPU when
            CUDA is not available; any other device is used as given.

    Example:
        >>> model = TranslationModel('best_model.pt', 'tokenizer.json')
        >>> translation = model.translate("Der Hund läuft.")
    """

    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        device: str = "cuda"
    ):
        target = torch.device(device)
        # Only a CUDA request needs a fallback; explicit "cpu"/"mps"/... are honoured.
        if target.type == "cuda" and not torch.cuda.is_available():
            target = torch.device("cpu")
        self.device = target

        # Load tokenizer
        self.tokenizer = self._load_tokenizer(tokenizer_path)

        # Load model (switched to eval mode: disables dropout etc.)
        self.model, self.config = self._load_model(checkpoint_path)
        self.model.eval()

        print(f"Model loaded on {self.device}")
        print(f"Vocabulary size: {len(self.tokenizer.token_to_id)}")

    def _load_tokenizer(self, path: str):
        """Load the tokenizer saved at *path*.

        NOTE(review): ``JointBPETokenizer`` is not defined in this section;
        presumably it comes from the tokenizer chapter — confirm the import.
        """
        return JointBPETokenizer.load(path)

    def _load_model(self, checkpoint_path: str):
        """
        Rebuild the model from a checkpoint and move it to ``self.device``.

        The checkpoint must contain ``'model_state_dict'``. If it also has
        ``'model_config'`` (keyword arguments for ``ModelConfig``), that
        config is used; otherwise ``ModelConfig()`` defaults are assumed.

        Returns:
            ``(model, config)`` tuple.
        """
        # SECURITY: torch.load unpickles the file unless weights_only=True is in
        # effect (the default only from newer PyTorch releases); load trusted
        # checkpoints only.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        if 'model_config' in checkpoint:
            config = ModelConfig(**checkpoint['model_config'])
        else:
            config = ModelConfig()

        model = build_model(config)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(self.device)

        return model, config

    @torch.no_grad()
    def translate(
        self,
        text: str,
        beam_size: int = 5,
        max_length: int = 128,
        length_penalty: float = 1.0
    ) -> str:
        """
        Translate a single sentence.

        Args:
            text: German text to translate
            beam_size: Beam search width (must be >= 1)
            max_length: Maximum number of generated tokens
            length_penalty: Exponent alpha of the length normalization
                ``((5 + len) / 6) ** alpha``; 0 disables normalization

        Returns:
            English translation ("" when the input encodes to no tokens).
        """
        # NOTE(review): the source is encoded without BOS/EOS; confirm this
        # matches how source sentences were fed to the encoder in training.
        source_ids = self.tokenizer.encode(text, add_special_tokens=False)
        if not source_ids:
            # torch.tensor([[]]) would give an empty *float* tensor, which an
            # embedding lookup cannot consume; there is nothing to translate.
            return ""
        source_tensor = torch.tensor(
            [source_ids], dtype=torch.long, device=self.device
        )

        output_ids = self._generate(
            source_tensor,
            beam_size=beam_size,
            max_length=max_length,
            length_penalty=length_penalty
        )

        return self.tokenizer.decode(output_ids)

    def _generate(
        self,
        source: torch.Tensor,
        beam_size: int,
        max_length: int,
        length_penalty: float
    ) -> List[int]:
        """
        Generate target token ids with beam search.

        Each hypothesis keeps its *raw* summed log-probability. The length
        penalty ``((5 + len) / 6) ** length_penalty`` is applied only when
        hypotheses are ranked, so it is never compounded across steps.

        Args:
            source: Source token ids, shape ``(1, src_len)``
            beam_size: Number of hypotheses kept per step (>= 1)
            max_length: Maximum number of decoding steps
            length_penalty: Length-normalization exponent

        Returns:
            Best hypothesis without BOS and without a trailing EOS.

        Raises:
            ValueError: if ``beam_size < 1``.
        """
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")

        encoder_output = self.model.encode(source)

        bos_id = self.tokenizer.bos_id
        eos_id = self.tokenizer.eos_id

        def ranking_score(hypothesis):
            tokens, raw_log_prob = hypothesis
            return raw_log_prob / (((5 + len(tokens)) / 6) ** length_penalty)

        # Each beam is (generated tokens, raw summed log-probability).
        beams = [([], 0.0)]

        for _ in range(max_length):
            candidates = []

            for tokens, raw_log_prob in beams:
                # Finished hypotheses compete unchanged with the expansions.
                if tokens and tokens[-1] == eos_id:
                    candidates.append((tokens, raw_log_prob))
                    continue

                decoder_tensor = torch.tensor(
                    [[bos_id] + tokens],
                    dtype=torch.long,
                    device=self.device
                )
                logits = self.model.decode(decoder_tensor, encoder_output)

                # Log-probabilities for the next token (last position).
                log_probs = torch.log_softmax(logits[0, -1], dim=-1)

                # topk fails if k exceeds the vocabulary size.
                k = min(beam_size, log_probs.size(-1))
                top_log_probs, top_ids = log_probs.topk(k)

                for step_log_prob, token_id in zip(
                    top_log_probs.tolist(), top_ids.tolist()
                ):
                    candidates.append(
                        (tokens + [token_id], raw_log_prob + step_log_prob)
                    )

            candidates.sort(key=ranking_score, reverse=True)
            beams = candidates[:beam_size]

            if all(tokens and tokens[-1] == eos_id for tokens, _ in beams):
                break

        best_tokens = beams[0][0]
        if best_tokens and best_tokens[-1] == eos_id:
            best_tokens = best_tokens[:-1]

        return best_tokens

    @torch.no_grad()
    def translate_batch(
        self,
        texts: List[str],
        beam_size: int = 5,
        max_length: int = 128,
        length_penalty: float = 1.0
    ) -> List[str]:
        """
        Translate multiple sentences, one at a time.

        Args:
            texts: List of German sentences
            beam_size: Beam search width
            max_length: Maximum output length
            length_penalty: Length normalization exponent (forwarded to
                ``translate``)

        Returns:
            List of English translations, in input order.
        """
        return [
            self.translate(
                text,
                beam_size=beam_size,
                max_length=max_length,
                length_penalty=length_penalty
            )
            for text in texts
        ]
209
def demonstrate_loading():
    """
    Print a usage walkthrough for ``TranslationModel``.

    Only prints example code; no model or tokenizer is loaded.
    """
    print("Model Loading Demonstration")
    print("=" * 60)

    print("""
    USAGE:
    ──────

    # Load model
    model = TranslationModel(
        checkpoint_path='checkpoints/best_model.pt',
        tokenizer_path='data/tokenizer/tokenizer.json',
        device='cuda'
    )

    # Single translation
    german = "Der Hund läuft im Park."
    english = model.translate(german)
    print(f"DE: {german}")
    print(f"EN: {english}")

    # Batch translation
    sentences = [
        "Die Katze schläft auf dem Sofa.",
        "Ein Mann liest ein Buch.",
        "Kinder spielen im Garten.",
    ]
    translations = model.translate_batch(sentences)
    """)


demonstrate_loading()

Optimized Inference

KV Caching for Speed

python
class FastTranslationModel(TranslationModel):
    """
    Translation model with KV caching for faster greedy inference.

    After the first step only the newest token is fed to the decoder; the
    cache returned by the previous step stands in for the earlier tokens.

    NOTE(review): relies on ``self.model.decode_with_cache(decoder_input,
    encoder_output, cache=...)`` returning ``(logits, new_cache)``. That
    method is not defined in this section — confirm the model provides it.
    """

    @torch.no_grad()
    def translate_with_cache(
        self,
        text: str,
        beam_size: int = 1,
        max_length: int = 128
    ) -> str:
        """
        Translate *text* with greedy decoding and a KV cache.

        The cached path is greedy only. When ``beam_size > 1`` the request is
        honoured by falling back to the uncached beam search in
        ``translate`` instead of silently ignoring the argument.

        Args:
            text: German text to translate
            beam_size: 1 (or less) for cached greedy decoding; larger values
                use ``translate``'s beam search
            max_length: Maximum number of generated tokens

        Returns:
            English translation ("" when the input encodes to no tokens).
        """
        if beam_size > 1:
            return self.translate(text, beam_size=beam_size, max_length=max_length)

        source_ids = self.tokenizer.encode(text, add_special_tokens=False)
        if not source_ids:
            # torch.tensor([[]]) would be an empty float tensor; nothing to do.
            return ""
        source_tensor = torch.tensor(
            [source_ids], dtype=torch.long, device=self.device
        )

        # The source is encoded once and reused at every decoding step.
        encoder_output = self.model.encode(source_tensor)

        eos_id = self.tokenizer.eos_id
        generated = [self.tokenizer.bos_id]
        cache = None

        for _ in range(max_length):
            # First step: the full prefix (just BOS). Later steps: only the
            # newest token, because the cache covers the rest.
            step_tokens = generated if cache is None else generated[-1:]
            decoder_input = torch.tensor(
                [step_tokens], dtype=torch.long, device=self.device
            )

            logits, cache = self.model.decode_with_cache(
                decoder_input,
                encoder_output,
                cache=cache
            )

            next_token = logits[0, -1].argmax().item()
            if next_token == eos_id:
                break
            generated.append(next_token)

        # Drop the leading BOS before decoding to text.
        return self.tokenizer.decode(generated[1:])
64
65
def compare_inference_speed():
    """
    Print a text comparison of decoding with and without KV caching.

    Informational only: the figures are fixed text and nothing is timed.
    """
    print("Inference Speed Comparison")
    print("=" * 60)

    print("""
    METHOD COMPARISON:
    ──────────────────

    1. WITHOUT CACHING:
       - Recomputes all previous tokens each step
       - Time complexity: O(n²) token computations, n = output length
       - Memory: Lower (no cache)

    2. WITH KV CACHING:
       - Caches key-value pairs from previous steps
       - Time complexity: O(n) token computations
         (attention still reads the whole cache each step)
       - Memory: Higher (stores cache)

    TYPICAL SPEEDUP:
    ────────────────

    Output length 10:  ~2x faster with cache
    Output length 20:  ~5x faster
    Output length 50:  ~10x faster
    Output length 100: ~20x faster

    RECOMMENDATION:
    ───────────────

    - Short outputs (<20 tokens): Either method
    - Long outputs (>20 tokens): Use caching
    - Batch processing: Consider parallel decoding
    """)


compare_inference_speed()

Complete Inference Script

Ready-to-Run Script

python
1#!/usr/bin/env python3
2"""
3translate.py - Translate German text to English
4
5Usage:
6    python translate.py "Der Hund lรคuft im Park."
7    python translate.py --file input.txt --output translations.txt
8    python translate.py --interactive
9"""
10
11import argparse
12import torch
13from pathlib import Path
14import sys
15
16
def parse_args():
    """
    Define the translate.py command line and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the input mode (``text``, ``file``,
        ``output``, ``interactive``), artifact paths (``checkpoint``,
        ``tokenizer``), decoding settings (``beam_size``, ``max_length``,
        ``length_penalty``) and ``device``.
    """
    cli = argparse.ArgumentParser(description="German-English Translation")

    # What to translate: a positional sentence, a file, or a REPL.
    cli.add_argument("text", nargs="?", help="Text to translate")
    cli.add_argument("--file", "-f", help="Input file (one sentence per line)")
    cli.add_argument("--output", "-o", help="Output file")
    cli.add_argument("--interactive", "-i", action="store_true",
                     help="Interactive mode")

    # Where the trained artifacts live.
    artifact_defaults = (
        ("--checkpoint", "checkpoints/best_model.pt"),
        ("--tokenizer", "data/tokenizer/tokenizer.json"),
    )
    for flag, default_path in artifact_defaults:
        cli.add_argument(flag, default=default_path)

    # Decoding knobs as (flag, value type, default).
    decoding_options = (
        ("--beam-size", int, 5),
        ("--max-length", int, 128),
        ("--length-penalty", float, 1.0),
    )
    for flag, value_type, default_value in decoding_options:
        cli.add_argument(flag, type=value_type, default=default_value)

    cli.add_argument("--device", default="cuda")

    return cli.parse_args()
40
41
def main():
    """
    Entry point: parse arguments, load the model and run the chosen mode.

    Modes are checked in the order interactive, file, single sentence.
    Exits with status 1 (before loading the model) when none is selected.
    """
    args = parse_args()

    # Validate the input mode first so a usage error does not pay for the
    # (potentially slow) model load.
    if not (args.interactive or args.file or args.text):
        print(
            "Please provide text to translate, use --file, "
            "or use --interactive mode",
            file=sys.stderr,
        )
        sys.exit(1)

    # NOTE(review): TranslationModel is not imported in this script; it must
    # be importable from the inference module — confirm the import path.
    print("Loading model...")
    model = TranslationModel(
        checkpoint_path=args.checkpoint,
        tokenizer_path=args.tokenizer,
        device=args.device
    )
    print("Model loaded!\n")

    if args.interactive:
        interactive_mode(model, args)
    elif args.file:
        translate_file(model, args)
    else:
        translation = model.translate(
            args.text,
            beam_size=args.beam_size,
            max_length=args.max_length,
            length_penalty=args.length_penalty
        )
        print(f"German:  {args.text}")
        print(f"English: {translation}")
76
77
def interactive_mode(model, args):
    """
    Run a read-translate-print loop on standard input.

    Args:
        model: Object with ``translate(text, beam_size=..., max_length=...,
            length_penalty=...)``.
        args: Parsed CLI namespace; ``beam_size``, ``max_length`` and
            ``length_penalty`` are read.

    The loop ends on 'quit', 'exit' or 'q' (any case), on Ctrl-C, or at end
    of input (Ctrl-D or an exhausted pipe). Blank lines are ignored.
    """
    print("=" * 50)
    print("German-English Translation (Interactive Mode)")
    print("Type 'quit' or 'exit' to stop")
    print("=" * 50)
    print()

    while True:
        try:
            text = input("DE > ").strip()

            if text.lower() in ('quit', 'exit', 'q'):
                print("Goodbye!")
                break

            if not text:
                continue

            translation = model.translate(
                text,
                beam_size=args.beam_size,
                max_length=args.max_length,
                length_penalty=args.length_penalty
            )

            print(f"EN > {translation}")
            print()

        except (KeyboardInterrupt, EOFError):
            # EOFError: input() hit end of stream (Ctrl-D / piped input).
            print("\nGoodbye!")
            break
111
112
def translate_file(model, args):
    """
    Translate a UTF-8 text file with one German sentence per line.

    Blank input lines are skipped, so the output holds one translation per
    non-blank input line (it is not line-aligned with the input when blank
    lines are present). With ``args.output`` the translations are written
    one per line (UTF-8); otherwise source/translation pairs are printed.

    Args:
        model: Object with ``translate(text, beam_size=..., max_length=...,
            length_penalty=...)``.
        args: Parsed CLI namespace; ``file``, ``output``, ``beam_size``,
            ``max_length`` and ``length_penalty`` are read.

    Exits with status 1 when ``args.file`` is not an existing file.
    """
    input_path = Path(args.file)

    if not input_path.is_file():
        print(f"File not found: {args.file}", file=sys.stderr)
        sys.exit(1)

    with open(input_path, 'r', encoding='utf-8') as f:
        sentences = [line.strip() for line in f if line.strip()]

    total = len(sentences)
    print(f"Translating {total} sentences...")

    translations = []
    for done, sentence in enumerate(sentences, start=1):
        translations.append(
            model.translate(
                sentence,
                beam_size=args.beam_size,
                max_length=args.max_length,
                length_penalty=args.length_penalty
            )
        )
        if done % 100 == 0:
            print(f"  Processed {done}/{total}")

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            f.writelines(translation + '\n' for translation in translations)
        print(f"Translations saved to {args.output}")
    else:
        for src, tgt in zip(sentences, translations):
            print(f"DE: {src}")
            print(f"EN: {tgt}")
            print()


if __name__ == "__main__":
    main()

Evaluation Script

Test Set Evaluation

python
1#!/usr/bin/env python3
2"""
3evaluate.py - Evaluate translation model on test set
4
5Usage:
6    python evaluate.py --checkpoint best_model.pt
7"""
8
9import argparse
10import torch
11from pathlib import Path
12import json
13
14
def _read_lines(path):
    """Return the stripped lines of a UTF-8 text file, one segment per line."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]


def main():
    """
    Translate a test set and report corpus BLEU and chrF.

    Reads parallel source/reference files (UTF-8, one segment per line),
    translates every source line with beam search, prints the scores, saves
    them as JSON to ``--output`` and shows a few sample translations.

    Raises:
        ValueError: if the source and reference files differ in line count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer", default="data/tokenizer/tokenizer.json")
    parser.add_argument("--source", default="data/multi30k/test_2016_flickr.de")
    parser.add_argument("--reference", default="data/multi30k/test_2016_flickr.en")
    parser.add_argument("--output", default="evaluation_results.json")
    parser.add_argument("--beam-size", type=int, default=5)
    args = parser.parse_args()

    # NOTE(review): TranslationModel, BLEUScore and ChrFScore are not imported
    # in this script; they come from earlier sections — confirm import paths.
    print("Loading model...")
    model = TranslationModel(
        checkpoint_path=args.checkpoint,
        tokenizer_path=args.tokenizer
    )

    print("Loading test data...")
    sources = _read_lines(args.source)
    references = _read_lines(args.reference)
    # zip() below would silently truncate mismatched files; fail loudly.
    if len(sources) != len(references):
        raise ValueError(
            f"Source/reference size mismatch: {len(sources)} lines in "
            f"{args.source} vs {len(references)} lines in {args.reference}"
        )

    print(f"Test sentences: {len(sources)}")

    print("\nTranslating...")
    hypotheses = []
    for done, source in enumerate(sources, start=1):
        hypotheses.append(model.translate(source, beam_size=args.beam_size))
        if done % 100 == 0:
            print(f"  {done}/{len(sources)}")

    print("\nComputing metrics...")
    bleu_scorer = BLEUScore()
    chrf_scorer = ChrFScore()
    for hyp, ref in zip(hypotheses, references):
        bleu_scorer.add(hyp, [ref])  # BLEU scorer takes a list of references
        chrf_scorer.add(hyp, ref)
    bleu_result = bleu_scorer.compute()
    chrf_result = chrf_scorer.corpus_score()

    # NOTE(review): scores are scaled by 100 here, i.e. assumed to be
    # fractions in [0, 1] — confirm against the scorer implementations.
    bleu = bleu_result['bleu'] * 100
    chrf = chrf_result['chrf'] * 100

    print("\n" + "=" * 50)
    print("EVALUATION RESULTS")
    print("=" * 50)
    print(f"BLEU:  {bleu:.2f}")
    print(f"ChrF:  {chrf:.2f}")
    print("=" * 50)

    results = {
        'bleu': bleu,
        'chrf': chrf,
        'precisions': [p * 100 for p in bleu_result['precisions']],
        'num_sentences': len(sources),
    }

    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to {args.output}")

    print("\nSample translations:")
    print("-" * 50)

    for i in range(min(5, len(sources))):
        print(f"\n[{i + 1}]")
        print(f"  Source: {sources[i]}")
        print(f"  Reference: {references[i]}")
        print(f"  Hypothesis: {hypotheses[i]}")


if __name__ == "__main__":
    main()

Summary

Inference Components

| Component | Purpose |
|-----------|---------|
| `TranslationModel` | Load and use the trained model |
| `translate()` | Single-sentence translation |
| `translate_batch()` | Translate multiple sentences |
| KV caching | Faster inference |

Expected Results

After training on Multi30k:

  • BLEU: 30-35 on test set
  • ChrF: ~55-60
  • Speed: ~10-20 sentences/second (with GPU)

Next Section Preview

In the next section, we'll create an Interactive Demo with examples and analysis.

Loading comments...