Introduction
Now we bring everything together, combining token embeddings, positional encoding, scaling, and dropout into a single, reusable TransformerEmbedding module. This is the complete input-processing pipeline that transforms raw token indices into the representations fed to the transformer layers.
5.1 The Complete Pipeline
Data Flow
Token IDs: [batch, seq_len]
        ↓
Token Embedding (lookup)
        ↓
[batch, seq_len, d_model]
        ↓
Scale by √d_model
        ↓
Add Positional Encoding
        ↓
Dropout
        ↓
Output: [batch, seq_len, d_model]

Why This Order?
1. Token embedding first: discrete token IDs must become continuous vectors before anything else can operate on them
2. Scale before adding: multiplying by √d_model keeps the token embeddings comparable in magnitude to the positional encoding
3. Add positional encoding: injects position information into the otherwise order-blind representation
4. Dropout last: regularizes the combined token-plus-position representation
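The magnitude claim in point 2 can be checked numerically. This is a small sketch (the values and the init std of d_model ** -0.5 match the implementation in 5.2; the position index 10 is arbitrary) comparing the norm of one freshly initialized embedding row against one row of the sinusoidal positional encoding:

```python
import math
import torch

torch.manual_seed(0)
d_model = 512

# One token-embedding row, initialized with std = d_model ** -0.5
emb = torch.randn(d_model) * d_model ** -0.5

# Sinusoidal PE for a single position: each (sin, cos) frequency pair
# satisfies sin^2 + cos^2 = 1, so the norm is exactly sqrt(d_model / 2)
position = 10.0
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe = torch.empty(d_model)
pe[0::2] = torch.sin(position * div_term)
pe[1::2] = torch.cos(position * div_term)

print(f"embedding norm (unscaled): {emb.norm():.2f}")  # ~1: the PE would drown it out
print(f"positional encoding norm:  {pe.norm():.2f}")   # sqrt(512 / 2) = 16
print(f"embedding norm (scaled):   {(emb * math.sqrt(d_model)).norm():.2f}")
```

After scaling by √d_model the embedding norm is around √512 ≈ 22.6, the same order of magnitude as the positional encoding, so neither signal dominates the sum.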
5.2 Complete Implementation
import torch
import torch.nn as nn
import math
from typing import Optional, Literal


class TransformerEmbedding(nn.Module):
    """
    Complete embedding layer for Transformers.

    Combines:
    - Token embeddings (learnable lookup table)
    - Positional encoding (sinusoidal or learned)
    - Scaling by √d_model
    - Dropout for regularization

    Args:
        vocab_size: Size of token vocabulary
        d_model: Model dimension (embedding size)
        max_seq_len: Maximum sequence length
        dropout: Dropout probability
        padding_idx: Index of padding token
        pos_encoding: Type of positional encoding ('sinusoidal' or 'learned')
        scale_embedding: Whether to scale token embeddings by √d_model

    Example:
        >>> emb = TransformerEmbedding(
        ...     vocab_size=10000,
        ...     d_model=512,
        ...     max_seq_len=1000,
        ...     dropout=0.1
        ... )
        >>> input_ids = torch.randint(0, 10000, (2, 50))
        >>> output = emb(input_ids)  # [2, 50, 512]
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        max_seq_len: int = 5000,
        dropout: float = 0.1,
        padding_idx: int = 0,
        pos_encoding: Literal['sinusoidal', 'learned'] = 'sinusoidal',
        scale_embedding: bool = True
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.scale_embedding = scale_embedding
        self.scale_factor = math.sqrt(d_model) if scale_embedding else 1.0

        # Token embedding
        self.token_embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=padding_idx
        )

        # Positional encoding
        if pos_encoding == 'sinusoidal':
            self.positional_encoding = SinusoidalPositionalEncoding(
                d_model=d_model,
                max_len=max_seq_len,
                dropout=0.0  # We apply dropout after combining
            )
        elif pos_encoding == 'learned':
            self.positional_encoding = LearnedPositionalEncoding(
                max_seq_len=max_seq_len,
                d_model=d_model,
                dropout=0.0
            )
        else:
            raise ValueError(f"Unknown pos_encoding: {pos_encoding}")

        # Dropout
        self.dropout = nn.Dropout(p=dropout)

        # Initialize
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize embedding weights."""
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=self.d_model ** -0.5)

        # Keep padding as zero
        if self.token_embedding.padding_idx is not None:
            with torch.no_grad():
                self.token_embedding.weight[self.token_embedding.padding_idx].fill_(0)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Create embeddings from token indices.

        Args:
            input_ids: Token indices [batch, seq_len]
            position_ids: Optional custom position indices [batch, seq_len]

        Returns:
            embeddings: [batch, seq_len, d_model]
        """
        # Get token embeddings
        x = self.token_embedding(input_ids)  # [batch, seq_len, d_model]

        # Scale embeddings
        if self.scale_embedding:
            x = x * self.scale_factor

        # Add positional encoding
        if isinstance(self.positional_encoding, LearnedPositionalEncoding):
            x = self.positional_encoding(x, position_ids)
        else:
            x = self.positional_encoding(x)

        # Apply dropout
        x = self.dropout(x)

        return x

    def extra_repr(self) -> str:
        pos_type = type(self.positional_encoding).__name__
        return (f'vocab_size={self.vocab_size}, d_model={self.d_model}, '
                f'max_seq_len={self.max_seq_len}, pos_encoding={pos_type}')

# Supporting classes (from previous sections)

class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class LearnedPositionalEncoding(nn.Module):
    """Learned positional encoding."""

    def __init__(self, max_seq_len: int, d_model: int, dropout: float = 0.0):
        super().__init__()
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(p=dropout)
        nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if position_ids is None:
            position_ids = torch.arange(x.size(1), device=x.device)
        pos_emb = self.position_embedding(position_ids)
        x = x + pos_emb
        return self.dropout(x)

5.3 Testing the Combined Layer
def test_transformer_embedding():
    """Comprehensive tests for TransformerEmbedding."""

    print("Testing TransformerEmbedding...")
    print("-" * 50)

    vocab_size = 10000
    d_model = 512
    max_seq_len = 1000
    batch_size = 2
    seq_len = 50

    # Test with sinusoidal PE
    print("\n1. Sinusoidal Positional Encoding:")
    emb_sin = TransformerEmbedding(
        vocab_size=vocab_size,
        d_model=d_model,
        max_seq_len=max_seq_len,
        dropout=0.0,
        pos_encoding='sinusoidal'
    )

    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
    output = emb_sin(input_ids)

    print(f"   Input shape: {input_ids.shape}")
    print(f"   Output shape: {output.shape}")
    assert output.shape == (batch_size, seq_len, d_model)
    print("   ✓ Shape test passed")

    # Test with learned PE
    print("\n2. Learned Positional Encoding:")
    emb_learned = TransformerEmbedding(
        vocab_size=vocab_size,
        d_model=d_model,
        max_seq_len=max_seq_len,
        dropout=0.0,
        pos_encoding='learned'
    )

    output_learned = emb_learned(input_ids)
    print(f"   Output shape: {output_learned.shape}")
    assert output_learned.shape == (batch_size, seq_len, d_model)
    print("   ✓ Shape test passed")

    # Test padding handling
    print("\n3. Padding Handling:")
    input_with_pad = torch.randint(1, vocab_size, (1, 10))
    input_with_pad[0, -3:] = 0  # Last 3 tokens are padding

    output_pad = emb_sin(input_with_pad)
    padding_emb_norm = output_pad[0, -1].norm().item()

    print(f"   Padding embedding norm: {padding_emb_norm:.6f}")
    # Note: With PE added, padding won't be exactly zero,
    # but the token embedding contribution is zero
    print("   ✓ Padding test passed")

    # Test gradient flow
    print("\n4. Gradient Flow:")
    emb_sin_grad = TransformerEmbedding(
        vocab_size=vocab_size,
        d_model=d_model,
        max_seq_len=max_seq_len,
        dropout=0.1,
        pos_encoding='sinusoidal'
    )
    emb_sin_grad.train()

    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
    output = emb_sin_grad(input_ids)
    loss = output.sum()
    loss.backward()

    grad_norm = emb_sin_grad.token_embedding.weight.grad.norm().item()
    print(f"   Gradient norm: {grad_norm:.4f}")
    assert grad_norm > 0, "Gradients should be non-zero"
    print("   ✓ Gradient test passed")

    # Test different sequence lengths
    print("\n5. Variable Sequence Lengths:")
    for length in [10, 100, 500]:
        input_ids = torch.randint(1, vocab_size, (1, length))
        output = emb_sin(input_ids)
        assert output.shape == (1, length, d_model)
        print(f"   Length {length}: ✓")

    # Parameter count
    print("\n6. Parameter Count:")
    total_params_sin = sum(p.numel() for p in emb_sin.parameters())
    total_params_learned = sum(p.numel() for p in emb_learned.parameters())
    print(f"   Sinusoidal PE: {total_params_sin:,} parameters")
    print(f"   Learned PE: {total_params_learned:,} parameters")
    print(f"   Difference: {total_params_learned - total_params_sin:,} (from learned positions)")

    print("\n" + "-" * 50)
    print("All tests passed! ✓")

test_transformer_embedding()

5.4 Source and Target Embeddings
For encoder-decoder models like our translation system, we need embeddings for both languages:
class TranslationEmbeddings(nn.Module):
    """
    Embedding layers for encoder-decoder translation models.

    Creates separate embeddings for source and target languages,
    with optional sharing.
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int,
        max_seq_len: int = 5000,
        dropout: float = 0.1,
        src_padding_idx: int = 0,
        tgt_padding_idx: int = 0,
        share_embeddings: bool = False,
        pos_encoding: str = 'sinusoidal'
    ):
        super().__init__()

        self.share_embeddings = share_embeddings

        # Source embeddings
        self.src_embedding = TransformerEmbedding(
            vocab_size=src_vocab_size,
            d_model=d_model,
            max_seq_len=max_seq_len,
            dropout=dropout,
            padding_idx=src_padding_idx,
            pos_encoding=pos_encoding
        )

        # Target embeddings (shared or separate)
        if share_embeddings:
            assert src_vocab_size == tgt_vocab_size, \
                "Vocabularies must match for shared embeddings"
            self.tgt_embedding = self.src_embedding
        else:
            self.tgt_embedding = TransformerEmbedding(
                vocab_size=tgt_vocab_size,
                d_model=d_model,
                max_seq_len=max_seq_len,
                dropout=dropout,
                padding_idx=tgt_padding_idx,
                pos_encoding=pos_encoding
            )

    def encode_source(self, src_ids: torch.Tensor) -> torch.Tensor:
        """Embed source sequence."""
        return self.src_embedding(src_ids)

    def encode_target(self, tgt_ids: torch.Tensor) -> torch.Tensor:
        """Embed target sequence."""
        return self.tgt_embedding(tgt_ids)


# Example usage
def demo_translation_embeddings():
    emb = TranslationEmbeddings(
        src_vocab_size=8000,   # German vocabulary
        tgt_vocab_size=10000,  # English vocabulary
        d_model=512,
        max_seq_len=128,
        dropout=0.1,
        share_embeddings=False
    )

    # Source (German)
    src_ids = torch.randint(1, 8000, (2, 20))
    src_emb = emb.encode_source(src_ids)
    print(f"Source embedding: {src_emb.shape}")  # [2, 20, 512]

    # Target (English)
    tgt_ids = torch.randint(1, 10000, (2, 25))
    tgt_emb = emb.encode_target(tgt_ids)
    print(f"Target embedding: {tgt_emb.shape}")  # [2, 25, 512]

demo_translation_embeddings()

5.5 Embedding with Tied Output Weights
For language models and the decoder output in seq2seq models, the input embedding matrix can be tied to the output projection:
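A quick back-of-the-envelope count shows why tying is attractive (the vocabulary and model sizes here are just illustrative): the output projection would otherwise duplicate the largest matrix in the embedding stack.

```python
vocab_size, d_model = 10000, 512

untied = 2 * vocab_size * d_model  # separate embedding + output projection
tied = vocab_size * d_model        # one shared [vocab_size, d_model] matrix

print(f"untied: {untied:,} parameters")  # 10,240,000
print(f"tied:   {tied:,} parameters")    # 5,120,000
print(f"saved:  {untied - tied:,}")      # 5,120,000
```

Beyond the parameter savings, tying forces the model to use one consistent token representation for both input lookup and output scoring.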
class EmbeddingWithTiedOutput(nn.Module):
    """
    Embedding layer with output projection that shares weights.

    Used for:
    - Language models (GPT-style)
    - Decoder output in seq2seq models
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        max_seq_len: int = 5000,
        dropout: float = 0.1,
        padding_idx: int = 0,
        pos_encoding: str = 'sinusoidal'
    ):
        super().__init__()

        self.d_model = d_model
        self.scale_factor = math.sqrt(d_model)

        # Token embedding
        self.token_embedding = nn.Embedding(
            vocab_size, d_model, padding_idx=padding_idx
        )

        # Positional encoding
        if pos_encoding == 'sinusoidal':
            pe = torch.zeros(max_seq_len, d_model)
            position = torch.arange(0, max_seq_len).unsqueeze(1).float()
            div_term = torch.exp(
                torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
            )
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe.unsqueeze(0))
        else:
            self.position_embedding = nn.Embedding(max_seq_len, d_model)

        self.pos_encoding = pos_encoding
        self.dropout = nn.Dropout(dropout)

        # Output projection (shares weights with token embedding)
        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)
        self.output_projection.weight = self.token_embedding.weight  # Tie weights!

    def embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Convert token IDs to embeddings."""
        x = self.token_embedding(input_ids) * self.scale_factor

        if self.pos_encoding == 'sinusoidal':
            x = x + self.pe[:, :x.size(1), :]
        else:
            positions = torch.arange(x.size(1), device=x.device)
            x = x + self.position_embedding(positions)

        return self.dropout(x)

    def project(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits."""
        return self.output_projection(hidden_states)


# Test weight tying
def test_weight_tying():
    model = EmbeddingWithTiedOutput(vocab_size=1000, d_model=256)

    # Verify weights are the same object
    assert model.token_embedding.weight is model.output_projection.weight
    print("✓ Weights are tied (same object)")

    # Test forward pass
    input_ids = torch.randint(0, 1000, (2, 10))
    embeddings = model.embed(input_ids)
    logits = model.project(embeddings)

    print(f"Embeddings shape: {embeddings.shape}")  # [2, 10, 256]
    print(f"Logits shape: {logits.shape}")  # [2, 10, 1000]

test_weight_tying()

5.6 Complete Module for Course Project
Here's the final embedding module we'll use in our translation model:
class Seq2SeqEmbedding(nn.Module):
    """
    Complete embedding module for sequence-to-sequence models.

    Features:
    - Separate source and target embeddings
    - Sinusoidal positional encoding
    - √d_model scaling
    - Dropout regularization
    - Optional weight tying for output
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int,
        max_seq_len: int = 5000,
        dropout: float = 0.1,
        src_padding_idx: int = 0,
        tgt_padding_idx: int = 0
    ):
        super().__init__()

        self.d_model = d_model
        self.scale = math.sqrt(d_model)

        # Token embeddings
        self.src_token_emb = nn.Embedding(
            src_vocab_size, d_model, padding_idx=src_padding_idx
        )
        self.tgt_token_emb = nn.Embedding(
            tgt_vocab_size, d_model, padding_idx=tgt_padding_idx
        )

        # Shared positional encoding
        pe = self._create_positional_encoding(max_seq_len, d_model)
        self.register_buffer('pe', pe)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Initialize
        self._init_weights()

    def _create_positional_encoding(self, max_len: int, d_model: int) -> torch.Tensor:
        """Create sinusoidal positional encoding."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)  # [1, max_len, d_model]

    def _init_weights(self):
        """Initialize embedding weights."""
        nn.init.normal_(self.src_token_emb.weight, std=self.d_model ** -0.5)
        nn.init.normal_(self.tgt_token_emb.weight, std=self.d_model ** -0.5)

        # Zero out padding
        with torch.no_grad():
            if self.src_token_emb.padding_idx is not None:
                self.src_token_emb.weight[self.src_token_emb.padding_idx].fill_(0)
            if self.tgt_token_emb.padding_idx is not None:
                self.tgt_token_emb.weight[self.tgt_token_emb.padding_idx].fill_(0)

    def encode_source(self, src: torch.Tensor) -> torch.Tensor:
        """
        Embed source sequence.

        Args:
            src: Source token IDs [batch, src_len]

        Returns:
            Source embeddings [batch, src_len, d_model]
        """
        x = self.src_token_emb(src) * self.scale
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

    def encode_target(self, tgt: torch.Tensor) -> torch.Tensor:
        """
        Embed target sequence.

        Args:
            tgt: Target token IDs [batch, tgt_len]

        Returns:
            Target embeddings [batch, tgt_len, d_model]
        """
        x = self.tgt_token_emb(tgt) * self.scale
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


# Final verification
def final_test():
    print("Final Integration Test")
    print("=" * 50)

    emb = Seq2SeqEmbedding(
        src_vocab_size=8000,   # German
        tgt_vocab_size=10000,  # English
        d_model=512,
        max_seq_len=128,
        dropout=0.1
    )

    # Simulate translation batch
    src = torch.randint(1, 8000, (4, 30))   # 4 sentences, 30 tokens
    tgt = torch.randint(1, 10000, (4, 25))  # 4 sentences, 25 tokens

    src_emb = emb.encode_source(src)
    tgt_emb = emb.encode_target(tgt)

    print(f"Source: {src.shape} → {src_emb.shape}")
    print(f"Target: {tgt.shape} → {tgt_emb.shape}")
    print(f"\nTotal parameters: {sum(p.numel() for p in emb.parameters()):,}")

    print("\n✓ Ready for transformer layers!")

final_test()

Summary
Complete Input Pipeline
Token IDs [batch, seq_len]
        ↓
nn.Embedding (token → vector)
        ↓
× √d_model (scaling)
        ↓
+ Positional Encoding
        ↓
Dropout
        ↓
Embeddings [batch, seq_len, d_model]

Key Implementation Points
1. Separate embeddings for source and target in translation
2. Scale by √d_model before adding positional encoding
3. Sinusoidal PE for our course project (generalizes better)
4. Zero padding embeddings via padding_idx
5. Dropout applied after combining token + position
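Point 4 relies on a PyTorch behavior worth seeing in isolation: nn.Embedding with padding_idx both initializes that row to zero and masks its gradient, so the row stays zero through optimizer steps. A small sketch (sizes are arbitrary):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, padding_idx=0)
opt = torch.optim.SGD(emb.parameters(), lr=0.1)

ids = torch.tensor([[0, 3, 5]])  # batch containing the padding token 0
emb(ids).sum().backward()
opt.step()

# The padding row receives no gradient, so it remains all zeros
print(emb.weight[0])
```

This is also why _reset_parameters and _init_weights above re-zero the padding rows by hand: nn.init.normal_ overwrites the zero initialization, but the gradient masking still holds afterwards.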
Chapter Summary
In this chapter, you learned:
1. The Position Problem: Attention is permutation invariant, so position information must be injected explicitly
2. Sinusoidal Encoding: Mathematical elegance, relative position property, no learned parameters
3. Learned Embeddings: Task-specific, but limited to max_seq_len
4. Token Embeddings: Vocabulary β dense vectors via nn.Embedding
5. Combined Layer: Complete pipeline ready for transformer
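The position problem in point 1 is easy to reproduce. In this sketch, mean pooling stands in for any permutation-invariant aggregation of raw token embeddings (attention without positional information behaves the same way): reordering the tokens changes nothing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(100, 16)

ids = torch.tensor([[5, 9, 2, 7]])
shuffled = ids[:, [2, 0, 3, 1]]  # same tokens, different order

# A permutation-invariant aggregation cannot tell the two orderings apart
same = torch.allclose(emb(ids).mean(dim=1), emb(shuffled).mean(dim=1))
print(same)  # True
```

Adding a positional encoding before aggregation breaks this symmetry, which is exactly what the pipeline in this chapter does.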
You now have all the components to create input representations for transformers!
Next Chapter Preview
In Chapter 5: Subword Tokenization for Translation, we'll learn how to prepare text for our embeddings using Byte-Pair Encoding (BPE), the tokenization method that handles multiple languages and rare words gracefully.