Introduction
Now we'll implement Byte-Pair Encoding from scratch in Python. This educational implementation helps you understand exactly how BPE works before using production libraries like SentencePiece.
We'll build each component step by step, with tests at each stage.
3.1 Data Structures
Word Representation
Each word is represented as a tuple of symbols (initially characters):
```python
from typing import Dict, List, Tuple, Set
from collections import Counter, defaultdict


# Word: tuple of symbols
# "low" → ("l", "o", "w", "</w>")
# After merging: ("lo", "w", "</w>") → ("low", "</w>")

WordRepr = Tuple[str, ...]  # Type alias


def get_word_repr(word: str) -> WordRepr:
    """Convert word to tuple of characters with end-of-word marker."""
    return tuple(list(word) + ["</w>"])


# Test
print(get_word_repr("low"))    # ('l', 'o', 'w', '</w>')
print(get_word_repr("hello"))  # ('h', 'e', 'l', 'l', 'o', '</w>')
```
Vocabulary Structure
```python
# Vocabulary: maps word representation to frequency
# {"low": 5, "lower": 2} becomes:
# {("l", "o", "w", "</w>"): 5, ("l", "o", "w", "e", "r", "</w>"): 2}

VocabDict = Dict[WordRepr, int]


def build_vocab(word_freqs: Dict[str, int]) -> VocabDict:
    """Convert word frequencies to vocabulary with character-level representations."""
    vocab = {}
    for word, freq in word_freqs.items():
        word_repr = get_word_repr(word)
        vocab[word_repr] = freq
    return vocab


# Test
word_freqs = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
vocab = build_vocab(word_freqs)

print("Vocabulary:")
for word_repr, freq in vocab.items():
    print(f"  {word_repr}: {freq}")
```
Output:
```
Vocabulary:
  ('l', 'o', 'w', '</w>'): 5
  ('l', 'o', 'w', 'e', 'r', '</w>'): 2
  ('n', 'e', 'w', 'e', 's', 't', '</w>'): 6
  ('w', 'i', 'd', 'e', 's', 't', '</w>'): 3
```
3.2 Counting Pair Frequencies
The Core Function
```python
def get_pair_frequencies(vocab: VocabDict) -> Counter:
    """
    Count frequencies of all adjacent symbol pairs in vocabulary.

    Args:
        vocab: Dictionary mapping word tuples to frequencies

    Returns:
        Counter of (symbol1, symbol2) pairs with their total frequencies
    """
    pairs = Counter()

    for word_repr, freq in vocab.items():
        # Get all adjacent pairs in this word
        for i in range(len(word_repr) - 1):
            pair = (word_repr[i], word_repr[i + 1])
            pairs[pair] += freq

    return pairs


# Test
vocab = build_vocab({"low": 5, "lower": 2, "newest": 6, "widest": 3})
pairs = get_pair_frequencies(vocab)

print("Pair frequencies:")
for pair, freq in pairs.most_common(10):
    print(f"  {pair}: {freq}")
```
Output:
```
Pair frequencies:
  ('e', 's'): 9
  ('s', 't'): 9
  ('t', '</w>'): 9
  ('w', 'e'): 8
  ('l', 'o'): 7
  ('o', 'w'): 7
  ('n', 'e'): 6
  ('e', 'w'): 6
  ('w', '</w>'): 5
  ('w', 'i'): 3
```
3.3 Merging Pairs
Update Vocabulary with Merged Pair
```python
def merge_pair(
    vocab: VocabDict,
    pair: Tuple[str, str]
) -> VocabDict:
    """
    Merge all occurrences of a pair in the vocabulary.

    Args:
        vocab: Current vocabulary
        pair: (symbol1, symbol2) to merge

    Returns:
        New vocabulary with merged pair
    """
    new_vocab = {}

    # Create the merged symbol
    merged = pair[0] + pair[1]

    for word_repr, freq in vocab.items():
        # Convert tuple to list for modification
        new_word = list(word_repr)

        # Find and merge all occurrences of the pair
        i = 0
        while i < len(new_word) - 1:
            if new_word[i] == pair[0] and new_word[i + 1] == pair[1]:
                # Replace pair with merged symbol
                new_word[i] = merged
                del new_word[i + 1]
            else:
                i += 1

        # Convert back to tuple
        new_vocab[tuple(new_word)] = freq

    return new_vocab


# Test
vocab = build_vocab({"low": 5, "lower": 2})
print("Before merge:")
for w, f in vocab.items():
    print(f"  {w}: {f}")

vocab = merge_pair(vocab, ("l", "o"))
print("\nAfter merging ('l', 'o'):")
for w, f in vocab.items():
    print(f"  {w}: {f}")

vocab = merge_pair(vocab, ("lo", "w"))
print("\nAfter merging ('lo', 'w'):")
for w, f in vocab.items():
    print(f"  {w}: {f}")
```
Output:
```
Before merge:
  ('l', 'o', 'w', '</w>'): 5
  ('l', 'o', 'w', 'e', 'r', '</w>'): 2

After merging ('l', 'o'):
  ('lo', 'w', '</w>'): 5
  ('lo', 'w', 'e', 'r', '</w>'): 2

After merging ('lo', 'w'):
  ('low', '</w>'): 5
  ('low', 'e', 'r', '</w>'): 2
```
3.4 Complete BPE Training
Main Training Loop
```python
def train_bpe(
    word_freqs: Dict[str, int],
    num_merges: int,
    verbose: bool = False
) -> Tuple[List[Tuple[str, str]], Set[str]]:
    """
    Train BPE by learning merge operations.

    Args:
        word_freqs: Dictionary of word -> frequency
        num_merges: Number of merge operations to learn
        verbose: Print progress if True

    Returns:
        Tuple of (list of merge rules, final vocabulary set)
    """
    # Initialize vocabulary with character-level representations
    vocab = build_vocab(word_freqs)

    # Initialize token vocabulary with all characters
    tokens = set()
    for word_repr in vocab.keys():
        tokens.update(word_repr)

    # Store merge rules
    merges = []

    for i in range(num_merges):
        # Count pair frequencies
        pairs = get_pair_frequencies(vocab)

        if not pairs:
            print(f"No more pairs to merge at iteration {i}")
            break

        # Find most frequent pair
        best_pair, best_freq = pairs.most_common(1)[0]

        if verbose:
            print(f"Merge {i+1}: {best_pair} -> '{best_pair[0]}{best_pair[1]}' (freq: {best_freq})")

        # Merge the pair
        vocab = merge_pair(vocab, best_pair)

        # Add new token and merge rule
        new_token = best_pair[0] + best_pair[1]
        tokens.add(new_token)
        merges.append(best_pair)

    return merges, tokens


# Test training
word_freqs = {
    "low": 5,
    "lower": 2,
    "newest": 6,
    "widest": 3,
    "new": 4
}

merges, tokens = train_bpe(word_freqs, num_merges=10, verbose=True)

print(f"\nLearned {len(merges)} merge rules:")
for i, merge in enumerate(merges):
    print(f"  {i+1}. {merge} -> '{merge[0]}{merge[1]}'")

print(f"\nFinal vocabulary ({len(tokens)} tokens):")
print(sorted(tokens, key=lambda x: (-len(x), x)))
```
Output:
```
Merge 1: ('n', 'e') -> 'ne' (freq: 10)
Merge 2: ('ne', 'w') -> 'new' (freq: 10)
Merge 3: ('e', 's') -> 'es' (freq: 9)
Merge 4: ('es', 't') -> 'est' (freq: 9)
Merge 5: ('est', '</w>') -> 'est</w>' (freq: 9)
Merge 6: ('l', 'o') -> 'lo' (freq: 7)
Merge 7: ('lo', 'w') -> 'low' (freq: 7)
Merge 8: ('new', 'est</w>') -> 'newest</w>' (freq: 6)
Merge 9: ('low', '</w>') -> 'low</w>' (freq: 5)
Merge 10: ('new', '</w>') -> 'new</w>' (freq: 4)

Learned 10 merge rules:
  1. ('n', 'e') -> 'ne'
  2. ('ne', 'w') -> 'new'
  3. ('e', 's') -> 'es'
  4. ('es', 't') -> 'est'
  5. ('est', '</w>') -> 'est</w>'
  6. ('l', 'o') -> 'lo'
  7. ('lo', 'w') -> 'low'
  8. ('new', 'est</w>') -> 'newest</w>'
  9. ('low', '</w>') -> 'low</w>'
  10. ('new', '</w>') -> 'new</w>'

Final vocabulary (21 tokens):
['newest</w>', 'est</w>', 'low</w>', 'new</w>', '</w>', 'est', 'low', 'new', 'es', 'lo', 'ne', 'd', 'e', 'i', 'l', 'n', 'o', 'r', 's', 't', 'w']
```
3.5 Encoding Text with BPE
Apply Merge Rules to New Text
```python
def encode_word(word: str, merges: List[Tuple[str, str]]) -> List[str]:
    """
    Encode a single word using learned BPE merges.

    Args:
        word: Word to encode
        merges: List of merge rules (in order)

    Returns:
        List of BPE tokens
    """
    # Start with character-level representation
    tokens = list(word) + ["</w>"]

    # Apply each merge rule in order
    for pair in merges:
        i = 0
        while i < len(tokens) - 1:
            if tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
                # Merge the pair
                tokens[i] = pair[0] + pair[1]
                del tokens[i + 1]
            else:
                i += 1

    return tokens


def encode_text(text: str, merges: List[Tuple[str, str]]) -> List[str]:
    """
    Encode text using learned BPE merges.

    Args:
        text: Text to encode
        merges: List of merge rules

    Returns:
        List of BPE tokens
    """
    # Simple whitespace tokenization
    words = text.split()

    all_tokens = []
    for word in words:
        tokens = encode_word(word, merges)
        all_tokens.extend(tokens)

    return all_tokens


# Test encoding
merges, _ = train_bpe(word_freqs, num_merges=10, verbose=False)

test_words = ["low", "lower", "lowest", "newest", "newer", "wide"]

print("Encoding test words:")
for word in test_words:
    tokens = encode_word(word, merges)
    print(f"  '{word}' -> {tokens}")

# Test full sentence
sentence = "the newest low price is lower"
tokens = encode_text(sentence, merges)
print(f"\nSentence: '{sentence}'")
print(f"Tokens: {tokens}")
```
Output:
```
Encoding test words:
  'low' -> ['low</w>']
  'lower' -> ['low', 'e', 'r', '</w>']
  'lowest' -> ['low', 'est</w>']
  'newest' -> ['newest</w>']
  'newer' -> ['new', 'e', 'r', '</w>']
  'wide' -> ['w', 'i', 'd', 'e', '</w>']

Sentence: 'the newest low price is lower'
Tokens: ['t', 'h', 'e', '</w>', 'newest</w>', 'low</w>', 'p', 'r', 'i', 'c', 'e', '</w>', 'i', 's', '</w>', 'low', 'e', 'r', '</w>']
```
3.6 Decoding BPE Tokens
Convert Tokens Back to Text
```python
def decode_tokens(tokens: List[str]) -> str:
    """
    Decode BPE tokens back to text.

    Args:
        tokens: List of BPE tokens

    Returns:
        Decoded text string
    """
    # Join tokens and handle end-of-word markers
    text = ""
    for token in tokens:
        if token.endswith("</w>"):
            # Remove </w> and add space
            text += token[:-4] + " "
        else:
            text += token

    return text.strip()


# Test decoding
encoded = encode_text("the newest low", merges)
decoded = decode_tokens(encoded)

print(f"Original: 'the newest low'")
print(f"Encoded: {encoded}")
print(f"Decoded: '{decoded}'")
print(f"Match: {decoded == 'the newest low'}")
```
Output:
```
Original: 'the newest low'
Encoded: ['t', 'h', 'e', '</w>', 'newest</w>', 'low</w>']
Decoded: 'the newest low'
Match: True
```
3.7 Complete BPE Tokenizer Class
Full Implementation
```python
import json
from typing import Optional
from pathlib import Path


class BPETokenizer:
    """
    Complete BPE tokenizer with training, encoding, and decoding.

    Example:
        >>> tokenizer = BPETokenizer()
        >>> tokenizer.train(corpus, vocab_size=1000)
        >>> tokens = tokenizer.encode("Hello world")
        >>> text = tokenizer.decode(tokens)
    """

    # Special tokens
    PAD_TOKEN = "<pad>"
    UNK_TOKEN = "<unk>"
    BOS_TOKEN = "<bos>"
    EOS_TOKEN = "<eos>"
    END_OF_WORD = "</w>"

    def __init__(self):
        self.merges: List[Tuple[str, str]] = []
        self.vocab: Dict[str, int] = {}        # token -> id
        self.id_to_token: Dict[int, str] = {}  # id -> token

    def train(
        self,
        texts: List[str],
        vocab_size: int = 10000,
        min_freq: int = 2,
        verbose: bool = False
    ) -> None:
        """
        Train BPE tokenizer on corpus.

        Args:
            texts: List of training texts
            vocab_size: Target vocabulary size
            min_freq: Minimum word frequency to include
            verbose: Print progress
        """
        # Count word frequencies
        word_freqs = Counter()
        for text in texts:
            words = text.lower().split()
            word_freqs.update(words)

        # Filter by minimum frequency
        word_freqs = {w: f for w, f in word_freqs.items() if f >= min_freq}

        if verbose:
            print(f"Training on {len(word_freqs)} unique words")

        # Get initial character vocabulary
        char_vocab = set()
        for word in word_freqs.keys():
            char_vocab.update(word)
        char_vocab.add(self.END_OF_WORD)

        # Calculate number of merges needed
        initial_size = len(char_vocab) + 4  # +4 for special tokens
        num_merges = vocab_size - initial_size

        if num_merges <= 0:
            print(f"Warning: vocab_size {vocab_size} too small, using character-level")
            num_merges = 0

        # Train BPE
        self.merges, _ = train_bpe(word_freqs, num_merges, verbose=verbose)

        # Build vocabulary
        self._build_vocab(char_vocab)

        if verbose:
            print(f"Final vocabulary size: {len(self.vocab)}")

    def _build_vocab(self, char_vocab: Set[str]) -> None:
        """Build token to id mappings."""
        self.vocab = {}

        # Add special tokens first
        special_tokens = [self.PAD_TOKEN, self.UNK_TOKEN,
                          self.BOS_TOKEN, self.EOS_TOKEN]
        for i, token in enumerate(special_tokens):
            self.vocab[token] = i

        # Add characters
        idx = len(special_tokens)
        for char in sorted(char_vocab):
            if char not in self.vocab:
                self.vocab[char] = idx
                idx += 1

        # Add merged tokens
        for pair in self.merges:
            merged = pair[0] + pair[1]
            if merged not in self.vocab:
                self.vocab[merged] = idx
                idx += 1

        # Create reverse mapping
        self.id_to_token = {v: k for k, v in self.vocab.items()}

    def encode(
        self,
        text: str,
        add_special_tokens: bool = False
    ) -> List[int]:
        """
        Encode text to token IDs.

        Args:
            text: Input text
            add_special_tokens: Add BOS/EOS tokens

        Returns:
            List of token IDs
        """
        # Get tokens
        tokens = encode_text(text.lower(), self.merges)

        # Convert to IDs
        ids = []
        if add_special_tokens:
            ids.append(self.vocab[self.BOS_TOKEN])

        for token in tokens:
            if token in self.vocab:
                ids.append(self.vocab[token])
            else:
                ids.append(self.vocab[self.UNK_TOKEN])

        if add_special_tokens:
            ids.append(self.vocab[self.EOS_TOKEN])

        return ids

    def decode(self, ids: List[int], skip_special: bool = True) -> str:
        """
        Decode token IDs back to text.

        Args:
            ids: List of token IDs
            skip_special: Skip special tokens in output

        Returns:
            Decoded text
        """
        special = {self.PAD_TOKEN, self.UNK_TOKEN,
                   self.BOS_TOKEN, self.EOS_TOKEN}

        tokens = []
        for id_ in ids:
            token = self.id_to_token.get(id_, self.UNK_TOKEN)
            if skip_special and token in special:
                continue
            tokens.append(token)

        return decode_tokens(tokens)

    def get_vocab_size(self) -> int:
        """Return vocabulary size."""
        return len(self.vocab)

    @property
    def pad_token_id(self) -> int:
        return self.vocab[self.PAD_TOKEN]

    @property
    def unk_token_id(self) -> int:
        return self.vocab[self.UNK_TOKEN]

    @property
    def bos_token_id(self) -> int:
        return self.vocab[self.BOS_TOKEN]

    @property
    def eos_token_id(self) -> int:
        return self.vocab[self.EOS_TOKEN]

    def save(self, path: str) -> None:
        """Save tokenizer to file."""
        data = {
            "merges": self.merges,
            "vocab": self.vocab
        }
        with open(path, "w") as f:
            json.dump(data, f)

    @classmethod
    def load(cls, path: str) -> "BPETokenizer":
        """Load tokenizer from file."""
        with open(path, "r") as f:
            data = json.load(f)

        tokenizer = cls()
        tokenizer.merges = [tuple(m) for m in data["merges"]]
        tokenizer.vocab = data["vocab"]
        tokenizer.id_to_token = {v: k for k, v in tokenizer.vocab.items()}

        return tokenizer

    def __repr__(self) -> str:
        return f"BPETokenizer(vocab_size={len(self.vocab)}, num_merges={len(self.merges)})"
```
3.8 Performance Considerations
Efficiency of Our Implementation
Our educational implementation has:
- Time complexity: O(N × M × L), where N = number of unique words in the vocabulary, M = number of merges, and L = average word length
- Space complexity: O(V) where V = vocabulary size
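The time bound can be made concrete by counting the work one frequency-counting pass performs. The sketch below is illustrative only: `pair_scan_cost` and the synthetic vocabulary are not part of the implementation above.

```python
def pair_scan_cost(vocab):
    """Number of adjacent-pair inspections one full frequency count performs."""
    # A word of length L contributes L - 1 adjacent pairs per scan.
    return sum(len(word) - 1 for word in vocab)

# Synthetic vocabulary: N words of L symbols each (frequencies don't matter here).
N, L, M = 100, 8, 20
vocab = {tuple(f"w{i}_{j}" for j in range(L)): 1 for i in range(N)}

cost_per_merge = pair_scan_cost(vocab)  # N * (L - 1) = 700
total_cost = M * cost_per_merge         # ~N * M * L = 14000 inspections
print(cost_per_merge, total_cost)
```

Because the naive loop repeats this full scan for every merge, the total work grows linearly in each of N, M, and L, which is exactly what the production optimizations below avoid.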
Production Optimizations
Real implementations use:
```python
# 1. Caching pair positions for faster updates
class OptimizedBPE:
    def __init__(self):
        self.pair_positions = defaultdict(set)  # pair -> {(word_idx, position)}

    def get_pair_frequencies(self):
        # O(1) lookup per pair instead of rescanning all words
        return {pair: len(positions)
                for pair, positions in self.pair_positions.items()}


# 2. Priority queue for finding the most frequent pair
import heapq

class PriorityQueueBPE:
    def __init__(self):
        self.heap = []  # (-freq, pair)

    def get_most_frequent(self):
        # Note: a real implementation must also lazily discard stale
        # heap entries whose counts changed after earlier merges.
        return heapq.heappop(self.heap)[1]


# 3. Byte-level BPE (GPT-2 style)
# Works on bytes instead of characters
# Guarantees all UTF-8 text can be encoded
```
When to Use Our Implementation
Use the educational implementation for:
- Understanding the algorithm
- Small experiments
- Learning/teaching
Use production libraries for:
- Real training data (millions of sentences)
- Production models
- Multilingual tokenization
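The byte-level BPE mentioned in the optimizations above is easy to see in isolation: any string reduces to byte values 0-255, so a base vocabulary of 256 byte tokens covers all UTF-8 text and never needs an `<unk>` token. A minimal illustration of the principle, not the GPT-2 implementation:

```python
# Byte-level view: every UTF-8 string maps to integers in 0..255.
text = "café"
byte_ids = list(text.encode("utf-8"))
print(byte_ids)  # 'é' becomes the two bytes 195, 169

# Decoding is lossless: the base vocabulary of 256 bytes covers all text.
assert bytes(byte_ids).decode("utf-8") == text
```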
Summary
Components Built
| Component | Purpose |
|---|---|
| get_word_repr() | Convert word to character tuple |
| build_vocab() | Initialize vocabulary |
| get_pair_frequencies() | Count adjacent pairs |
| merge_pair() | Apply merge to vocabulary |
| train_bpe() | Learn merge rules |
| encode_word() | Encode single word |
| encode_text() | Encode full text |
| decode_tokens() | Convert tokens to text |
| BPETokenizer | Complete tokenizer class |
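The components in the table compose into one short pipeline. The following condensed restatement (same logic as the functions above, renamed and shrunk, on a deliberately tiny two-word corpus) can serve as a self-check that training, encoding, and decoding fit together:

```python
from collections import Counter

def word_repr(w):
    """Word -> tuple of characters plus end-of-word marker."""
    return tuple(w) + ("</w>",)

def pair_freqs(vocab):
    """Count adjacent symbol pairs, weighted by word frequency."""
    p = Counter()
    for wr, f in vocab.items():
        for a, b in zip(wr, wr[1:]):
            p[(a, b)] += f
    return p

def merge(vocab, pair):
    """Apply one merge to every word in the vocabulary."""
    a, b = pair
    out = {}
    for wr, f in vocab.items():
        w, i = [], 0
        while i < len(wr):
            if i < len(wr) - 1 and (wr[i], wr[i + 1]) == (a, b):
                w.append(a + b)
                i += 2
            else:
                w.append(wr[i])
                i += 1
        out[tuple(w)] = f
    return out

def encode(word, merges):
    """Re-apply learned merges, in order, to a new word."""
    tokens = list(word) + ["</w>"]
    for a, b in merges:
        i = 0
        while i < len(tokens) - 1:
            if (tokens[i], tokens[i + 1]) == (a, b):
                tokens[i:i + 2] = [a + b]
            else:
                i += 1
    return tokens

def decode(tokens):
    """Turn end-of-word markers back into spaces."""
    return "".join(t.replace("</w>", " ") for t in tokens).strip()

# Train on a two-word corpus, then round-trip the words.
vocab = {word_repr(w): f for w, f in {"low": 5, "lowest": 3}.items()}
merges = []
for _ in range(4):
    p = pair_freqs(vocab)
    if not p:
        break
    best = p.most_common(1)[0][0]
    vocab = merge(vocab, best)
    merges.append(best)

print(merges)                            # starts with ('l', 'o'), then ('lo', 'w'), ...
print(encode("low", merges))             # ['low</w>']
print(decode(encode("lowest", merges)))  # 'lowest'
```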
Key Takeaways
- BPE starts with characters and iteratively merges
- Merge order matters - apply in learned order
- End-of-word marker preserves word boundaries
- Our implementation is educational, not optimized
- Real models use libraries like SentencePiece
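The second takeaway can be demonstrated directly: replaying the same rules in a different order changes the result, because later merges consume tokens produced by earlier ones. A small sketch, re-stating the encode_word logic with a hypothetical two-rule merge list:

```python
def apply_merges(word, merges):
    """Greedily apply merge rules in the given order (same logic as encode_word)."""
    tokens = list(word) + ["</w>"]
    for a, b in merges:
        i = 0
        while i < len(tokens) - 1:
            if tokens[i] == a and tokens[i + 1] == b:
                tokens[i:i + 2] = [a + b]
            else:
                i += 1
    return tokens

learned_order = [("e", "s"), ("es", "t")]

print(apply_merges("best", learned_order))        # ['b', 'est', '</w>']
print(apply_merges("best", learned_order[::-1]))  # ['b', 'es', 't', '</w>']
# Reversed order misses 'est': ('es', 't') runs before any 'es' token exists.
```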
Exercises
Implementation Exercises
- Modify BPETokenizer to support case-sensitive encoding.
- Add a tokenize() method that returns tokens instead of IDs.
- Implement BPE-dropout: randomly skip some merges during encoding for robustness.
Analysis Exercises
- Train BPE with different vocabulary sizes (100, 500, 1000, 5000) on the same corpus. Compare average tokens per word.
- Create a visualization showing how a word gets progressively merged during encoding.
- Compare encoding of "playing", "played", "plays" - are the root morphemes shared?
Extension Exercises
- Implement WordPiece: use likelihood ratio instead of frequency for merge selection.
- Add support for special tokens that should never be split (like URLs or emails).
Next Section Preview
In the next section, we'll transition from our educational implementation to SentencePiece, a production-grade library. We'll learn how to train tokenizers for our German-English translation project and integrate them with PyTorch datasets.