Chapter 5

Implementing BPE from Scratch

Subword Tokenization for Translation

Introduction

Now we'll implement Byte-Pair Encoding from scratch in Python. This educational implementation helps you understand exactly how BPE works before using production libraries like SentencePiece.

We'll build each component step by step, with tests at each stage.


3.1 Data Structures

Word Representation

Each word is represented as a tuple of symbols (initially characters):

```python
from typing import Dict, List, Tuple, Set
from collections import Counter, defaultdict


# Word: tuple of symbols
# "low" → ("l", "o", "w", "</w>")
# After merging: ("lo", "w", "</w>") → ("low", "</w>")

WordRepr = Tuple[str, ...]  # Type alias


def get_word_repr(word: str) -> WordRepr:
    """Convert word to tuple of characters with end-of-word marker."""
    return tuple(list(word) + ["</w>"])


# Test
print(get_word_repr("low"))      # ('l', 'o', 'w', '</w>')
print(get_word_repr("hello"))    # ('h', 'e', 'l', 'l', 'o', '</w>')
```

Vocabulary Structure

```python
# Vocabulary: maps word representation to frequency
# {"low": 5, "lower": 2} becomes:
# {("l", "o", "w", "</w>"): 5, ("l", "o", "w", "e", "r", "</w>"): 2}

VocabDict = Dict[WordRepr, int]


def build_vocab(word_freqs: Dict[str, int]) -> VocabDict:
    """Convert word frequencies to vocabulary with character-level representations."""
    vocab = {}
    for word, freq in word_freqs.items():
        word_repr = get_word_repr(word)
        vocab[word_repr] = freq
    return vocab


# Test
word_freqs = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
vocab = build_vocab(word_freqs)

print("Vocabulary:")
for word_repr, freq in vocab.items():
    print(f"  {word_repr}: {freq}")
```

Output:

```text
Vocabulary:
  ('l', 'o', 'w', '</w>'): 5
  ('l', 'o', 'w', 'e', 'r', '</w>'): 2
  ('n', 'e', 'w', 'e', 's', 't', '</w>'): 6
  ('w', 'i', 'd', 'e', 's', 't', '</w>'): 3
```
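
Why tuples rather than lists for the word representations? Dictionary keys must be hashable, and Python lists are not. A quick standalone check:

```python
# Tuples are hashable, so they can serve as dictionary keys; lists cannot.
vocab = {("l", "o", "w", "</w>"): 5}
print(("l", "o", "w", "</w>") in vocab)  # True

try:
    bad = {["l", "o"]: 1}
except TypeError as err:
    print(err)  # unhashable type: 'list'
```

The immutability is also convenient: a word's representation can be shared freely without fear of accidental in-place modification.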

3.2 Counting Pair Frequencies

The Core Function

```python
def get_pair_frequencies(vocab: VocabDict) -> Counter:
    """
    Count frequencies of all adjacent symbol pairs in vocabulary.

    Args:
        vocab: Dictionary mapping word tuples to frequencies

    Returns:
        Counter of (symbol1, symbol2) pairs with their total frequencies
    """
    pairs = Counter()

    for word_repr, freq in vocab.items():
        # Get all adjacent pairs in this word
        for i in range(len(word_repr) - 1):
            pair = (word_repr[i], word_repr[i + 1])
            pairs[pair] += freq

    return pairs


# Test
vocab = build_vocab({"low": 5, "lower": 2, "newest": 6, "widest": 3})
pairs = get_pair_frequencies(vocab)

print("Pair frequencies:")
for pair, freq in pairs.most_common(10):
    print(f"  {pair}: {freq}")
```

Output:

```text
Pair frequencies:
  ('e', 's'): 9
  ('s', 't'): 9
  ('t', '</w>'): 9
  ('w', 'e'): 8
  ('l', 'o'): 7
  ('o', 'w'): 7
  ('n', 'e'): 6
  ('e', 'w'): 6
  ('w', '</w>'): 5
  ('w', 'i'): 3
```
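
One subtlety that decides which pair wins a tie: `Counter.most_common` is documented to order elements with equal counts in the order they were first encountered, so the scan order of the vocabulary fixes the tie-break. A quick check:

```python
from collections import Counter

pairs = Counter()
pairs[("e", "s")] += 9   # encountered first
pairs[("s", "t")] += 9   # same count, encountered second

# Equal counts keep first-encountered order, so ('e', 's') wins the tie.
print(pairs.most_common(1))  # [(('e', 's'), 9)]
```

This matters for reproducibility: two implementations with different tie-breaking can learn different merge sequences from the same corpus.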

3.3 Merging Pairs

Update Vocabulary with Merged Pair

```python
def merge_pair(
    vocab: VocabDict,
    pair: Tuple[str, str]
) -> VocabDict:
    """
    Merge all occurrences of a pair in the vocabulary.

    Args:
        vocab: Current vocabulary
        pair: (symbol1, symbol2) to merge

    Returns:
        New vocabulary with merged pair
    """
    new_vocab = {}

    # Create the merged symbol
    merged = pair[0] + pair[1]

    for word_repr, freq in vocab.items():
        # Convert tuple to list for modification
        new_word = list(word_repr)

        # Find and merge all occurrences of the pair
        i = 0
        while i < len(new_word) - 1:
            if new_word[i] == pair[0] and new_word[i + 1] == pair[1]:
                # Replace pair with merged symbol
                new_word[i] = merged
                del new_word[i + 1]
            else:
                i += 1

        # Convert back to tuple
        new_vocab[tuple(new_word)] = freq

    return new_vocab


# Test
vocab = build_vocab({"low": 5, "lower": 2})
print("Before merge:")
for w, f in vocab.items():
    print(f"  {w}: {f}")

vocab = merge_pair(vocab, ("l", "o"))
print("\nAfter merging ('l', 'o'):")
for w, f in vocab.items():
    print(f"  {w}: {f}")

vocab = merge_pair(vocab, ("lo", "w"))
print("\nAfter merging ('lo', 'w'):")
for w, f in vocab.items():
    print(f"  {w}: {f}")
```

Output:

```text
Before merge:
  ('l', 'o', 'w', '</w>'): 5
  ('l', 'o', 'w', 'e', 'r', '</w>'): 2

After merging ('l', 'o'):
  ('lo', 'w', '</w>'): 5
  ('lo', 'w', 'e', 'r', '</w>'): 2

After merging ('lo', 'w'):
  ('low', '</w>'): 5
  ('low', 'e', 'r', '</w>'): 2
```

3.4 Complete BPE Training

Main Training Loop

```python
def train_bpe(
    word_freqs: Dict[str, int],
    num_merges: int,
    verbose: bool = False
) -> Tuple[List[Tuple[str, str]], Set[str]]:
    """
    Train BPE by learning merge operations.

    Args:
        word_freqs: Dictionary of word -> frequency
        num_merges: Number of merge operations to learn
        verbose: Print progress if True

    Returns:
        Tuple of (list of merge rules, final vocabulary set)
    """
    # Initialize vocabulary with character-level representations
    vocab = build_vocab(word_freqs)

    # Initialize token vocabulary with all characters
    tokens = set()
    for word_repr in vocab.keys():
        tokens.update(word_repr)

    # Store merge rules
    merges = []

    for i in range(num_merges):
        # Count pair frequencies
        pairs = get_pair_frequencies(vocab)

        if not pairs:
            print(f"No more pairs to merge at iteration {i}")
            break

        # Find most frequent pair (ties go to the pair encountered first)
        best_pair, best_freq = pairs.most_common(1)[0]

        if verbose:
            print(f"Merge {i+1}: {best_pair} -> '{best_pair[0]}{best_pair[1]}' (freq: {best_freq})")

        # Merge the pair
        vocab = merge_pair(vocab, best_pair)

        # Add new token and merge rule
        new_token = best_pair[0] + best_pair[1]
        tokens.add(new_token)
        merges.append(best_pair)

    return merges, tokens


# Test training
word_freqs = {
    "low": 5,
    "lower": 2,
    "newest": 6,
    "widest": 3,
    "new": 4
}

merges, tokens = train_bpe(word_freqs, num_merges=10, verbose=True)

print(f"\nLearned {len(merges)} merge rules:")
for i, merge in enumerate(merges):
    print(f"  {i+1}. {merge} -> '{merge[0]}{merge[1]}'")

print(f"\nFinal vocabulary ({len(tokens)} tokens):")
print(sorted(tokens, key=lambda x: (-len(x), x)))
```

Output:

```text
Merge 1: ('n', 'e') -> 'ne' (freq: 10)
Merge 2: ('ne', 'w') -> 'new' (freq: 10)
Merge 3: ('e', 's') -> 'es' (freq: 9)
Merge 4: ('es', 't') -> 'est' (freq: 9)
Merge 5: ('est', '</w>') -> 'est</w>' (freq: 9)
Merge 6: ('l', 'o') -> 'lo' (freq: 7)
Merge 7: ('lo', 'w') -> 'low' (freq: 7)
Merge 8: ('new', 'est</w>') -> 'newest</w>' (freq: 6)
Merge 9: ('low', '</w>') -> 'low</w>' (freq: 5)
Merge 10: ('new', '</w>') -> 'new</w>' (freq: 4)

Learned 10 merge rules:
  1. ('n', 'e') -> 'ne'
  2. ('ne', 'w') -> 'new'
  3. ('e', 's') -> 'es'
  4. ('es', 't') -> 'est'
  5. ('est', '</w>') -> 'est</w>'
  6. ('l', 'o') -> 'lo'
  7. ('lo', 'w') -> 'low'
  8. ('new', 'est</w>') -> 'newest</w>'
  9. ('low', '</w>') -> 'low</w>'
  10. ('new', '</w>') -> 'new</w>'

Final vocabulary (21 tokens):
['newest</w>', 'est</w>', 'low</w>', 'new</w>', '</w>', 'est', 'low', 'new', 'es', 'lo', 'ne', 'd', 'e', 'i', 'l', 'n', 'o', 'r', 's', 't', 'w']
```

3.5 Encoding Text with BPE

Apply Merge Rules to New Text

```python
def encode_word(word: str, merges: List[Tuple[str, str]]) -> List[str]:
    """
    Encode a single word using learned BPE merges.

    Args:
        word: Word to encode
        merges: List of merge rules (in order)

    Returns:
        List of BPE tokens
    """
    # Start with character-level representation
    tokens = list(word) + ["</w>"]

    # Apply each merge rule in order
    for pair in merges:
        i = 0
        while i < len(tokens) - 1:
            if tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
                # Merge the pair
                tokens[i] = pair[0] + pair[1]
                del tokens[i + 1]
            else:
                i += 1

    return tokens


def encode_text(text: str, merges: List[Tuple[str, str]]) -> List[str]:
    """
    Encode text using learned BPE merges.

    Args:
        text: Text to encode
        merges: List of merge rules

    Returns:
        List of BPE tokens
    """
    # Simple whitespace tokenization
    words = text.split()

    all_tokens = []
    for word in words:
        tokens = encode_word(word, merges)
        all_tokens.extend(tokens)

    return all_tokens


# Test encoding
merges, _ = train_bpe(word_freqs, num_merges=10, verbose=False)

test_words = ["low", "lower", "lowest", "newest", "newer", "wide"]

print("Encoding test words:")
for word in test_words:
    tokens = encode_word(word, merges)
    print(f"  '{word}' -> {tokens}")

# Test full sentence
sentence = "the newest low price is lower"
tokens = encode_text(sentence, merges)
print(f"\nSentence: '{sentence}'")
print(f"Tokens: {tokens}")
```

Output:

```text
Encoding test words:
  'low' -> ['low</w>']
  'lower' -> ['low', 'e', 'r', '</w>']
  'lowest' -> ['low', 'est</w>']
  'newest' -> ['newest</w>']
  'newer' -> ['new', 'e', 'r', '</w>']
  'wide' -> ['w', 'i', 'd', 'e', '</w>']

Sentence: 'the newest low price is lower'
Tokens: ['t', 'h', 'e', '</w>', 'newest</w>', 'low</w>', 'p', 'r', 'i', 'c', 'e', '</w>', 'i', 's', '</w>', 'low', 'e', 'r', '</w>']
```

3.6 Decoding BPE Tokens

Convert Tokens Back to Text

```python
def decode_tokens(tokens: List[str]) -> str:
    """
    Decode BPE tokens back to text.

    Args:
        tokens: List of BPE tokens

    Returns:
        Decoded text string
    """
    # Join tokens and handle end-of-word markers
    text = ""
    for token in tokens:
        if token.endswith("</w>"):
            # Remove </w> and add space
            text += token[:-4] + " "
        else:
            text += token

    return text.strip()


# Test decoding
encoded = encode_text("the newest low", merges)
decoded = decode_tokens(encoded)

print(f"Original: 'the newest low'")
print(f"Encoded: {encoded}")
print(f"Decoded: '{decoded}'")
print(f"Match: {decoded == 'the newest low'}")
```

Output:

```text
Original: 'the newest low'
Encoded: ['t', 'h', 'e', '</w>', 'newest</w>', 'low</w>']
Decoded: 'the newest low'
Match: True
```

3.7 Complete BPE Tokenizer Class

A Complete, Reusable Implementation

```python
import json


class BPETokenizer:
    """
    Complete BPE tokenizer with training, encoding, and decoding.

    Example:
        >>> tokenizer = BPETokenizer()
        >>> tokenizer.train(corpus, vocab_size=1000)
        >>> tokens = tokenizer.encode("Hello world")
        >>> text = tokenizer.decode(tokens)
    """

    # Special tokens
    PAD_TOKEN = "<pad>"
    UNK_TOKEN = "<unk>"
    BOS_TOKEN = "<bos>"
    EOS_TOKEN = "<eos>"
    END_OF_WORD = "</w>"

    def __init__(self):
        self.merges: List[Tuple[str, str]] = []
        self.vocab: Dict[str, int] = {}  # token -> id
        self.id_to_token: Dict[int, str] = {}  # id -> token

    def train(
        self,
        texts: List[str],
        vocab_size: int = 10000,
        min_freq: int = 2,
        verbose: bool = False
    ) -> None:
        """
        Train BPE tokenizer on corpus.

        Args:
            texts: List of training texts
            vocab_size: Target vocabulary size
            min_freq: Minimum word frequency to include
            verbose: Print progress
        """
        # Count word frequencies
        word_freqs = Counter()
        for text in texts:
            words = text.lower().split()
            word_freqs.update(words)

        # Filter by minimum frequency
        word_freqs = {w: f for w, f in word_freqs.items() if f >= min_freq}

        if verbose:
            print(f"Training on {len(word_freqs)} unique words")

        # Get initial character vocabulary
        char_vocab = set()
        for word in word_freqs.keys():
            char_vocab.update(word)
        char_vocab.add(self.END_OF_WORD)

        # Calculate number of merges needed
        initial_size = len(char_vocab) + 4  # +4 for special tokens
        num_merges = vocab_size - initial_size

        if num_merges <= 0:
            print(f"Warning: vocab_size {vocab_size} too small, using character-level")
            num_merges = 0

        # Train BPE
        self.merges, _ = train_bpe(word_freqs, num_merges, verbose=verbose)

        # Build vocabulary
        self._build_vocab(char_vocab)

        if verbose:
            print(f"Final vocabulary size: {len(self.vocab)}")

    def _build_vocab(self, char_vocab: Set[str]) -> None:
        """Build token to id mappings."""
        self.vocab = {}

        # Add special tokens first
        special_tokens = [self.PAD_TOKEN, self.UNK_TOKEN,
                          self.BOS_TOKEN, self.EOS_TOKEN]
        for i, token in enumerate(special_tokens):
            self.vocab[token] = i

        # Add characters
        idx = len(special_tokens)
        for char in sorted(char_vocab):
            if char not in self.vocab:
                self.vocab[char] = idx
                idx += 1

        # Add merged tokens
        for pair in self.merges:
            merged = pair[0] + pair[1]
            if merged not in self.vocab:
                self.vocab[merged] = idx
                idx += 1

        # Create reverse mapping
        self.id_to_token = {v: k for k, v in self.vocab.items()}

    def encode(
        self,
        text: str,
        add_special_tokens: bool = False
    ) -> List[int]:
        """
        Encode text to token IDs.

        Args:
            text: Input text
            add_special_tokens: Add BOS/EOS tokens

        Returns:
            List of token IDs
        """
        # Get tokens
        tokens = encode_text(text.lower(), self.merges)

        # Convert to IDs
        ids = []
        if add_special_tokens:
            ids.append(self.vocab[self.BOS_TOKEN])

        for token in tokens:
            if token in self.vocab:
                ids.append(self.vocab[token])
            else:
                ids.append(self.vocab[self.UNK_TOKEN])

        if add_special_tokens:
            ids.append(self.vocab[self.EOS_TOKEN])

        return ids

    def decode(self, ids: List[int], skip_special: bool = True) -> str:
        """
        Decode token IDs back to text.

        Args:
            ids: List of token IDs
            skip_special: Skip special tokens in output

        Returns:
            Decoded text
        """
        special = {self.PAD_TOKEN, self.UNK_TOKEN,
                   self.BOS_TOKEN, self.EOS_TOKEN}

        tokens = []
        for id_ in ids:
            token = self.id_to_token.get(id_, self.UNK_TOKEN)
            if skip_special and token in special:
                continue
            tokens.append(token)

        return decode_tokens(tokens)

    def get_vocab_size(self) -> int:
        """Return vocabulary size."""
        return len(self.vocab)

    @property
    def pad_token_id(self) -> int:
        return self.vocab[self.PAD_TOKEN]

    @property
    def unk_token_id(self) -> int:
        return self.vocab[self.UNK_TOKEN]

    @property
    def bos_token_id(self) -> int:
        return self.vocab[self.BOS_TOKEN]

    @property
    def eos_token_id(self) -> int:
        return self.vocab[self.EOS_TOKEN]

    def save(self, path: str) -> None:
        """Save tokenizer to file."""
        data = {
            "merges": self.merges,
            "vocab": self.vocab
        }
        with open(path, "w") as f:
            json.dump(data, f)

    @classmethod
    def load(cls, path: str) -> "BPETokenizer":
        """Load tokenizer from file."""
        with open(path, "r") as f:
            data = json.load(f)

        tokenizer = cls()
        tokenizer.merges = [tuple(m) for m in data["merges"]]
        tokenizer.vocab = data["vocab"]
        tokenizer.id_to_token = {v: k for k, v in tokenizer.vocab.items()}

        return tokenizer

    def __repr__(self) -> str:
        return f"BPETokenizer(vocab_size={len(self.vocab)}, num_merges={len(self.merges)})"
```
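
One detail in `save()`/`load()` deserves a note: JSON has no tuple type, so the merge pairs serialize as arrays and deserialize as lists, which is why `load()` rebuilds tuples before using them. A minimal round trip showing the behavior:

```python
import json

merges = [("l", "o"), ("lo", "w")]
blob = json.dumps({"merges": merges})

# JSON serializes tuples as arrays, so they come back as lists.
loaded = json.loads(blob)["merges"]
print(loaded)  # [['l', 'o'], ['lo', 'w']]

# Convert back to tuples, exactly as BPETokenizer.load does.
restored = [tuple(m) for m in loaded]
print(restored == merges)  # True
```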

3.8 Performance Considerations

Efficiency of Our Implementation

Our educational implementation has:

  • Time complexity: O(M × U × L), where M = number of merges, U = number of unique words, and L = average word length (every merge rescans the whole vocabulary)
  • Space complexity: O(U × L) for the word vocabulary, plus space for the pair counter
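
To make the per-merge cost concrete, we can count the pair positions one pass over the toy vocabulary examines (a rough upper-bound sketch; merged words shrink, so later passes do slightly less work):

```python
word_freqs = {"low": 5, "lower": 2, "newest": 6, "widest": 3, "new": 4}

# A word of length k becomes k+1 symbols (its characters plus "</w>"),
# which is k adjacent pairs to examine on each pass over the vocabulary.
positions_per_pass = sum(len(word) for word in word_freqs)
num_merges = 10

print(positions_per_pass)               # 23
print(positions_per_pass * num_merges)  # 230 pair checks, as an upper bound
```

On a real corpus with hundreds of thousands of unique words and tens of thousands of merges, this full rescan per merge is exactly what the optimizations below avoid.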

Production Optimizations

Real implementations use:

```python
# 1. Caching pair positions for faster updates
class OptimizedBPE:
    def __init__(self):
        self.pair_positions = defaultdict(set)  # pair -> {(word_idx, position)}

    def get_pair_frequencies(self):
        # Read counts from the index instead of rescanning every word;
        # the index must be updated incrementally after each merge.
        return {pair: len(positions)
                for pair, positions in self.pair_positions.items()}


# 2. Priority queue for finding the most frequent pair
import heapq

class PriorityQueueBPE:
    def __init__(self):
        self.heap = []  # (-freq, pair); stale entries must be skipped on pop

    def get_most_frequent(self):
        return heapq.heappop(self.heap)[1]


# 3. Byte-level BPE (GPT-2 style)
# Works on bytes instead of characters
# Guarantees all UTF-8 text can be encoded
```
When to Use Our Implementation

Use the educational implementation for:

  • Understanding the algorithm
  • Small experiments
  • Learning/teaching

Use production libraries for:

  • Real training data (millions of sentences)
  • Production models
  • Multilingual tokenization

Summary

Components Built

| Component | Purpose |
|---|---|
| `get_word_repr()` | Convert word to character tuple |
| `build_vocab()` | Initialize vocabulary |
| `get_pair_frequencies()` | Count adjacent pairs |
| `merge_pair()` | Apply merge to vocabulary |
| `train_bpe()` | Learn merge rules |
| `encode_word()` | Encode single word |
| `encode_text()` | Encode full text |
| `decode_tokens()` | Convert tokens to text |
| `BPETokenizer` | Complete tokenizer class |

Key Takeaways

  1. BPE starts with characters and iteratively merges the most frequent pair
  2. Merge order matters: apply merges in the order they were learned
  3. The end-of-word marker preserves word boundaries
  4. Our implementation is educational, not optimized
  5. Real models use libraries like SentencePiece
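
Takeaway 2 is easy to verify: replaying the same rules in a different order produces a different segmentation, because later merges depend on tokens that earlier merges created. A self-contained sketch using a simplified apply loop:

```python
def apply_merges(word, merges):
    """Apply merge rules to one word, left to right, one rule at a time."""
    tokens = list(word) + ["</w>"]
    for a, b in merges:
        i = 0
        while i < len(tokens) - 1:
            if tokens[i] == a and tokens[i + 1] == b:
                tokens[i:i + 2] = [a + b]
            else:
                i += 1
    return tokens

learned = [("l", "o"), ("lo", "w"), ("low", "</w>")]
print(apply_merges("low", learned))        # ['low</w>']

# Reversed order: ('low', '</w>') and ('lo', 'w') match nothing, because
# the tokens they need were never built.
print(apply_merges("low", learned[::-1]))  # ['lo', 'w', '</w>']
```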

Exercises

Implementation Exercises

  1. Modify BPETokenizer to support case-sensitive encoding.
  2. Add a tokenize() method that returns tokens instead of IDs.
  3. Implement dropout BPE: randomly skip some merges during training for robustness.

Analysis Exercises

  1. Train BPE with different vocabulary sizes (100, 500, 1000, 5000) on the same corpus. Compare average tokens per word.
  2. Create a visualization showing how a word gets progressively merged during encoding.
  3. Compare encoding of "playing", "played", "plays" - are the root morphemes shared?

Extension Exercises

  1. Implement WordPiece: use likelihood ratio instead of frequency for merge selection.
  2. Add support for special tokens that should never be split (like URLs or emails).

Next Section Preview

In the next section, we'll transition from our educational implementation to SentencePiece, a production-grade library. We'll learn how to train tokenizers for our German-English translation project and integrate them with PyTorch datasets.