Chapter 4

Combined Embedding Layer

Positional Encoding and Embeddings

Introduction

Now we bring everything together: token embeddings, positional encoding, scaling, and dropout, combined into a single reusable TransformerEmbedding module. This module is the complete input pipeline. It turns raw token indices into the representations fed to the transformer layers.


5.1 The Complete Pipeline

Data Flow

```text
Token IDs: [batch, seq_len]
     ↓
Token Embedding (lookup)
     ↓
[batch, seq_len, d_model]
     ↓
Scale by √d_model
     ↓
Add Positional Encoding
     ↓
Dropout
     ↓
Output: [batch, seq_len, d_model]
```

Why This Order?

1. Token embedding first: discrete token IDs must be converted to continuous vectors before any arithmetic can be applied.

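2. Scale before adding: token embeddings are initialized with a small standard deviation (1/√d_model in the implementation below), so multiplying by √d_model brings them to roughly unit scale, comparable to the positional encoding values in [-1, 1]. Without scaling, the positional signal would dominate.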

3. Add positional encoding: Injects position information

4. Dropout last: Regularizes the combined representation
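The scaling argument can be checked numerically. The sketch below assumes token embeddings initialized with standard deviation 1/√d_model (the initialization `TransformerEmbedding` uses in 5.2) and compares their spread, before and after multiplying by √d_model, with the root-mean-square value of a sinusoidal table for the same d_model.

```python
import math
import torch

torch.manual_seed(0)
vocab_size, d_model, max_len = 10000, 512, 1000

# Token embeddings drawn with std = 1/sqrt(d_model)
weight = torch.randn(vocab_size, d_model) * d_model ** -0.5
unscaled_std = weight.std().item()
scaled_std = (weight * math.sqrt(d_model)).std().item()

# Sinusoidal table: each (sin, cos) pair satisfies sin^2 + cos^2 = 1,
# so the mean square over the whole table is 0.5
position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe_rms = pe.pow(2).mean().sqrt().item()

print(f"token std (unscaled): {unscaled_std:.3f}")  # about 0.044
print(f"token std (scaled):   {scaled_std:.3f}")    # about 1.0
print(f"positional RMS:       {pe_rms:.3f}")        # about 0.707
```

Unscaled, the positional signal would be roughly 16 times larger than the token signal (0.707 / 0.044); after scaling, the two are of the same order.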


5.2 Complete Implementation

🐍python
1import torch
2import torch.nn as nn
3import math
4from typing import Optional, Literal
5
6
7class TransformerEmbedding(nn.Module):
8    """
9    Complete embedding layer for Transformers.
10
11    Combines:
12    - Token embeddings (learnable lookup table)
13    - Positional encoding (sinusoidal or learned)
14    - Scaling by √d_model
15    - Dropout for regularization
16
17    Args:
18        vocab_size: Size of token vocabulary
19        d_model: Model dimension (embedding size)
20        max_seq_len: Maximum sequence length
21        dropout: Dropout probability
22        padding_idx: Index of padding token
23        pos_encoding: Type of positional encoding ('sinusoidal' or 'learned')
24        scale_embedding: Whether to scale token embeddings by √d_model
25
26    Example:
27        >>> emb = TransformerEmbedding(
28        ...     vocab_size=10000,
29        ...     d_model=512,
30        ...     max_seq_len=1000,
31        ...     dropout=0.1
32        ... )
33        >>> input_ids = torch.randint(0, 10000, (2, 50))
34        >>> output = emb(input_ids)  # [2, 50, 512]
35    """
36
37    def __init__(
38        self,
39        vocab_size: int,
40        d_model: int,
41        max_seq_len: int = 5000,
42        dropout: float = 0.1,
43        padding_idx: int = 0,
44        pos_encoding: Literal['sinusoidal', 'learned'] = 'sinusoidal',
45        scale_embedding: bool = True
46    ):
47        super().__init__()
48
49        self.vocab_size = vocab_size
50        self.d_model = d_model
51        self.max_seq_len = max_seq_len
52        self.scale_embedding = scale_embedding
53        self.scale_factor = math.sqrt(d_model) if scale_embedding else 1.0
54
55        # Token embedding
56        self.token_embedding = nn.Embedding(
57            num_embeddings=vocab_size,
58            embedding_dim=d_model,
59            padding_idx=padding_idx
60        )
61
62        # Positional encoding
63        if pos_encoding == 'sinusoidal':
64            self.positional_encoding = SinusoidalPositionalEncoding(
65                d_model=d_model,
66                max_len=max_seq_len,
67                dropout=0.0  # We apply dropout after combining
68            )
69        elif pos_encoding == 'learned':
70            self.positional_encoding = LearnedPositionalEncoding(
71                max_seq_len=max_seq_len,
72                d_model=d_model,
73                dropout=0.0
74            )
75        else:
76            raise ValueError(f"Unknown pos_encoding: {pos_encoding}")
77
78        # Dropout
79        self.dropout = nn.Dropout(p=dropout)
80
81        # Initialize
82        self._reset_parameters()
83
84    def _reset_parameters(self):
85        """Initialize embedding weights."""
86        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=self.d_model ** -0.5)
87
88        # Keep padding as zero
89        if self.token_embedding.padding_idx is not None:
90            with torch.no_grad():
91                self.token_embedding.weight[self.token_embedding.padding_idx].fill_(0)
92
93    def forward(
94        self,
95        input_ids: torch.Tensor,
96        position_ids: Optional[torch.Tensor] = None
97    ) -> torch.Tensor:
98        """
99        Create embeddings from token indices.
100
101        Args:
102            input_ids: Token indices [batch, seq_len]
103            position_ids: Optional custom position indices [batch, seq_len]
104
105        Returns:
106            embeddings: [batch, seq_len, d_model]
107        """
108        # Get token embeddings
109        x = self.token_embedding(input_ids)  # [batch, seq_len, d_model]
110
111        # Scale embeddings
112        if self.scale_embedding:
113            x = x * self.scale_factor
114
115        # Add positional encoding
116        if isinstance(self.positional_encoding, LearnedPositionalEncoding):
117            x = self.positional_encoding(x, position_ids)
118        else:
119            x = self.positional_encoding(x)
120
121        # Apply dropout
122        x = self.dropout(x)
123
124        return x
125
126    def extra_repr(self) -> str:
127        pos_type = type(self.positional_encoding).__name__
128        return (f'vocab_size={self.vocab_size}, d_model={self.d_model}, '
129                f'max_seq_len={self.max_seq_len}, pos_encoding={pos_type}')
130
131
132# Supporting classes (from previous sections)
133
134class SinusoidalPositionalEncoding(nn.Module):
135    """Sinusoidal positional encoding."""
136
137    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.0):
138        super().__init__()
139        self.dropout = nn.Dropout(p=dropout)
140
141        pe = torch.zeros(max_len, d_model)
142        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
143        div_term = torch.exp(
144            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
145        )
146        pe[:, 0::2] = torch.sin(position * div_term)
147        pe[:, 1::2] = torch.cos(position * div_term)
148        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
149
150        self.register_buffer('pe', pe)
151
152    def forward(self, x: torch.Tensor) -> torch.Tensor:
153        x = x + self.pe[:, :x.size(1), :]
154        return self.dropout(x)
155
156
157class LearnedPositionalEncoding(nn.Module):
158    """Learned positional encoding."""
159
160    def __init__(self, max_seq_len: int, d_model: int, dropout: float = 0.0):
161        super().__init__()
162        self.position_embedding = nn.Embedding(max_seq_len, d_model)
163        self.dropout = nn.Dropout(p=dropout)
164        nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
165
166    def forward(
167        self,
168        x: torch.Tensor,
169        position_ids: Optional[torch.Tensor] = None
170    ) -> torch.Tensor:
171        if position_ids is None:
172            position_ids = torch.arange(x.size(1), device=x.device)
173        pos_emb = self.position_embedding(position_ids)
174        x = x + pos_emb
175        return self.dropout(x)
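`TransformerEmbedding.forward` passes `position_ids` only to the learned encoding; the sinusoidal branch always slices positions 0 to seq_len - 1. If you want custom positions with sinusoidal encoding as well (for example, continuing the numbering from an offset when decoding one token at a time), one option is to index the precomputed table instead of slicing it. This is a sketch, not part of the course code: `build_sinusoidal` and `gather_positions` are hypothetical helpers, and the table here is [max_len, d_model] without the leading batch dimension used above.

```python
import math
from typing import Optional

import torch


def build_sinusoidal(max_len: int, d_model: int) -> torch.Tensor:
    """Hypothetical helper: [max_len, d_model] sinusoidal table."""
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


def gather_positions(pe: torch.Tensor, seq_len: int,
                     position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical helper: look up encodings by index instead of slicing."""
    if position_ids is None:
        position_ids = torch.arange(seq_len, device=pe.device)
    # Advanced indexing: [seq_len] -> [seq_len, d_model], [batch, seq_len] -> [batch, seq_len, d_model]
    return pe[position_ids]


pe = build_sinusoidal(max_len=128, d_model=16)

# Default positions reproduce the slice used in SinusoidalPositionalEncoding.forward
default = gather_positions(pe, seq_len=10)

# Per-sequence offsets: batch item 0 starts at position 0, item 1 continues at 5
offsets = torch.tensor([[0], [5]])
custom_ids = offsets + torch.arange(4)  # [[0, 1, 2, 3], [5, 6, 7, 8]]
custom = gather_positions(pe, seq_len=4, position_ids=custom_ids)
print(custom.shape)  # torch.Size([2, 4, 16])
```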

5.3 Testing the Combined Layer

🐍python
1def test_transformer_embedding():
2    """Comprehensive tests for TransformerEmbedding."""
3
4    print("Testing TransformerEmbedding...")
5    print("-" * 50)
6
7    vocab_size = 10000
8    d_model = 512
9    max_seq_len = 1000
10    batch_size = 2
11    seq_len = 50
12
13    # Test with sinusoidal PE
14    print("\n1. Sinusoidal Positional Encoding:")
15    emb_sin = TransformerEmbedding(
16        vocab_size=vocab_size,
17        d_model=d_model,
18        max_seq_len=max_seq_len,
19        dropout=0.0,
20        pos_encoding='sinusoidal'
21    )
22
23    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
24    output = emb_sin(input_ids)
25
26    print(f"   Input shape: {input_ids.shape}")
27    print(f"   Output shape: {output.shape}")
28    assert output.shape == (batch_size, seq_len, d_model)
29    print("   βœ“ Shape test passed")
30
31    # Test with learned PE
32    print("\n2. Learned Positional Encoding:")
33    emb_learned = TransformerEmbedding(
34        vocab_size=vocab_size,
35        d_model=d_model,
36        max_seq_len=max_seq_len,
37        dropout=0.0,
38        pos_encoding='learned'
39    )
40
41    output_learned = emb_learned(input_ids)
42    print(f"   Output shape: {output_learned.shape}")
43    assert output_learned.shape == (batch_size, seq_len, d_model)
44    print("   βœ“ Shape test passed")
45
46    # Test padding handling
47    print("\n3. Padding Handling:")
48    input_with_pad = torch.randint(1, vocab_size, (1, 10))
49    input_with_pad[0, -3:] = 0  # Last 3 tokens are padding
50
51    output_pad = emb_sin(input_with_pad)
52    padding_emb_norm = output_pad[0, -1].norm().item()
53
54    print(f"   Padding embedding norm: {padding_emb_norm:.6f}")
55    # Note: With PE added, padding won't be exactly zero
56    # But token embedding contribution is zero
57    print("   βœ“ Padding test passed")
58
59    # Test gradient flow
60    print("\n4. Gradient Flow:")
61    emb_sin_grad = TransformerEmbedding(
62        vocab_size=vocab_size,
63        d_model=d_model,
64        max_seq_len=max_seq_len,
65        dropout=0.1,
66        pos_encoding='sinusoidal'
67    )
68    emb_sin_grad.train()
69
70    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
71    output = emb_sin_grad(input_ids)
72    loss = output.sum()
73    loss.backward()
74
75    grad_norm = emb_sin_grad.token_embedding.weight.grad.norm().item()
76    print(f"   Gradient norm: {grad_norm:.4f}")
77    assert grad_norm > 0, "Gradients should be non-zero"
78    print("   βœ“ Gradient test passed")
79
80    # Test different sequence lengths
81    print("\n5. Variable Sequence Lengths:")
82    for length in [10, 100, 500]:
83        input_ids = torch.randint(1, vocab_size, (1, length))
84        output = emb_sin(input_ids)
85        assert output.shape == (1, length, d_model)
86        print(f"   Length {length}: βœ“")
87
88    # Parameter count
89    print("\n6. Parameter Count:")
90    total_params_sin = sum(p.numel() for p in emb_sin.parameters())
91    total_params_learned = sum(p.numel() for p in emb_learned.parameters())
92    print(f"   Sinusoidal PE: {total_params_sin:,} parameters")
93    print(f"   Learned PE: {total_params_learned:,} parameters")
94    print(f"   Difference: {total_params_learned - total_params_sin:,} (from learned positions)")
95
96    print("\n" + "-" * 50)
97    print("All tests passed! βœ“")
98
99
100test_transformer_embedding()
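The padding test relies on `nn.Embedding`'s `padding_idx`. According to the PyTorch documentation, a freshly constructed embedding initializes that row to zeros, and the row receives no gradient from the lookup, so it is not updated by training through this path. A small standalone check with toy sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)

ids = torch.tensor([[3, 7, 0, 0]])  # last two positions are padding
out = emb(ids)
out.sum().backward()

print(out[0, 2:])          # all zeros: padding positions embed to the zero vector
print(emb.weight.grad[0])  # all zeros: no gradient reaches the padding row
print(emb.weight.grad[3])  # ones: token 3 appears once, d(sum)/d(each entry) = 1
```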

5.4 Source and Target Embeddings

For encoder-decoder models like our translation system, we need embeddings for both languages:

🐍python
1class TranslationEmbeddings(nn.Module):
2    """
3    Embedding layers for encoder-decoder translation models.
4
5    Creates separate embeddings for source and target languages,
6    with optional sharing.
7    """
8
9    def __init__(
10        self,
11        src_vocab_size: int,
12        tgt_vocab_size: int,
13        d_model: int,
14        max_seq_len: int = 5000,
15        dropout: float = 0.1,
16        src_padding_idx: int = 0,
17        tgt_padding_idx: int = 0,
18        share_embeddings: bool = False,
19        pos_encoding: str = 'sinusoidal'
20    ):
21        super().__init__()
22
23        self.share_embeddings = share_embeddings
24
25        # Source embeddings
26        self.src_embedding = TransformerEmbedding(
27            vocab_size=src_vocab_size,
28            d_model=d_model,
29            max_seq_len=max_seq_len,
30            dropout=dropout,
31            padding_idx=src_padding_idx,
32            pos_encoding=pos_encoding
33        )
34
35        # Target embeddings (shared or separate)
36        if share_embeddings:
37            assert src_vocab_size == tgt_vocab_size, \
38                "Vocabularies must match for shared embeddings"
39            self.tgt_embedding = self.src_embedding
40        else:
41            self.tgt_embedding = TransformerEmbedding(
42                vocab_size=tgt_vocab_size,
43                d_model=d_model,
44                max_seq_len=max_seq_len,
45                dropout=dropout,
46                padding_idx=tgt_padding_idx,
47                pos_encoding=pos_encoding
48            )
49
50    def encode_source(self, src_ids: torch.Tensor) -> torch.Tensor:
51        """Embed source sequence."""
52        return self.src_embedding(src_ids)
53
54    def encode_target(self, tgt_ids: torch.Tensor) -> torch.Tensor:
55        """Embed target sequence."""
56        return self.tgt_embedding(tgt_ids)
57
58
59# Example usage
60def demo_translation_embeddings():
61    emb = TranslationEmbeddings(
62        src_vocab_size=8000,   # German vocabulary
63        tgt_vocab_size=10000,  # English vocabulary
64        d_model=512,
65        max_seq_len=128,
66        dropout=0.1,
67        share_embeddings=False
68    )
69
70    # Source (German)
71    src_ids = torch.randint(1, 8000, (2, 20))
72    src_emb = emb.encode_source(src_ids)
73    print(f"Source embedding: {src_emb.shape}")  # [2, 20, 512]
74
75    # Target (English)
76    tgt_ids = torch.randint(1, 10000, (2, 25))
77    tgt_emb = emb.encode_target(tgt_ids)
78    print(f"Target embedding: {tgt_emb.shape}")  # [2, 25, 512]
79
80
81demo_translation_embeddings()
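With `share_embeddings=True`, `tgt_embedding` is the same module object as `src_embedding`. `nn.Module.parameters()` yields each parameter once even when it is reachable through two attributes, so sharing halves the embedding parameter count. The sketch below uses a toy container (`Pair` is a hypothetical name) instead of `TranslationEmbeddings` so that it runs on its own.

```python
import torch.nn as nn


class Pair(nn.Module):
    """Toy container: source/target embeddings, optionally the same module."""

    def __init__(self, vocab_size: int, d_model: int, share: bool):
        super().__init__()
        self.src = nn.Embedding(vocab_size, d_model)
        # When sharing, both attributes point at one nn.Embedding
        self.tgt = self.src if share else nn.Embedding(vocab_size, d_model)


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


separate = count_params(Pair(8000, 512, share=False))
shared = count_params(Pair(8000, 512, share=True))
print(separate)  # 8192000
print(shared)    # 4096000
```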

5.5 Embedding with Tied Output Weights

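Language models and seq2seq decoders often tie the output projection to the token embedding: the same [vocab_size, d_model] matrix maps token IDs to vectors on the way in and hidden states to vocabulary logits on the way out, saving vocab_size × d_model parameters.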

🐍python
1class EmbeddingWithTiedOutput(nn.Module):
2    """
3    Embedding layer with output projection that shares weights.
4
5    Used for:
6    - Language models (GPT-style)
7    - Decoder output in seq2seq models
8    """
9
10    def __init__(
11        self,
12        vocab_size: int,
13        d_model: int,
14        max_seq_len: int = 5000,
15        dropout: float = 0.1,
16        padding_idx: int = 0,
17        pos_encoding: str = 'sinusoidal'
18    ):
19        super().__init__()
20
21        self.d_model = d_model
22        self.scale_factor = math.sqrt(d_model)
23
24        # Token embedding
25        self.token_embedding = nn.Embedding(
26            vocab_size, d_model, padding_idx=padding_idx
27        )
28
29        # Positional encoding
30        if pos_encoding == 'sinusoidal':
31            pe = torch.zeros(max_seq_len, d_model)
32            position = torch.arange(0, max_seq_len).unsqueeze(1).float()
33            div_term = torch.exp(
34                torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
35            )
36            pe[:, 0::2] = torch.sin(position * div_term)
37            pe[:, 1::2] = torch.cos(position * div_term)
38            self.register_buffer('pe', pe.unsqueeze(0))
39        else:
40            self.position_embedding = nn.Embedding(max_seq_len, d_model)
41
42        self.pos_encoding = pos_encoding
43        self.dropout = nn.Dropout(dropout)
44
45        # Output projection (shares weights with token embedding)
46        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)
47        self.output_projection.weight = self.token_embedding.weight  # Tie weights!
48
49    def embed(self, input_ids: torch.Tensor) -> torch.Tensor:
50        """Convert token IDs to embeddings."""
51        x = self.token_embedding(input_ids) * self.scale_factor
52
53        if self.pos_encoding == 'sinusoidal':
54            x = x + self.pe[:, :x.size(1), :]
55        else:
56            positions = torch.arange(x.size(1), device=x.device)
57            x = x + self.position_embedding(positions)
58
59        return self.dropout(x)
60
61    def project(self, hidden_states: torch.Tensor) -> torch.Tensor:
62        """Project hidden states to vocabulary logits."""
63        return self.output_projection(hidden_states)
64
65
66# Test weight tying
67def test_weight_tying():
68    model = EmbeddingWithTiedOutput(vocab_size=1000, d_model=256)
69
70    # Verify weights are the same object
71    assert model.token_embedding.weight is model.output_projection.weight
72    print("βœ“ Weights are tied (same object)")
73
74    # Test forward pass
75    input_ids = torch.randint(0, 1000, (2, 10))
76    embeddings = model.embed(input_ids)
77    logits = model.project(embeddings)
78
79    print(f"Embeddings shape: {embeddings.shape}")  # [2, 10, 256]
80    print(f"Logits shape: {logits.shape}")          # [2, 10, 1000]
81
82
83test_weight_tying()
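One consequence of tying is worth checking: `padding_idx` only blocks gradients from the embedding lookup. The output projection still computes a logit for the padding token, so the shared row receives gradient through that path and will drift away from zero during training. The standalone sketch below ties a bare `nn.Embedding` to an `nn.Linear` the same way `EmbeddingWithTiedOutput` does and inspects the padding row's gradient; the toy sizes are arbitrary. If your model relies on that row staying zero, it has to be reset explicitly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model = 20, 8

emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
proj = nn.Linear(d_model, vocab_size, bias=False)
proj.weight = emb.weight  # tie: both modules now hold the same Parameter

ids = torch.tensor([[5, 9, 0]])  # one padding position
logits = proj(emb(ids))          # [1, 3, vocab_size]
logits.sum().backward()

# The lookup path adds nothing to row 0, but the projection path does:
# d(sum of logits)/d W[0] equals the sum of the hidden vectors fed to the projection
pad_grad = emb.weight.grad[0]
expected = emb(ids).detach().sum(dim=(0, 1))
print(torch.allclose(pad_grad, expected))  # True
```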

5.6 Complete Module for Course Project

Here's the final embedding module we'll use in our translation model:

🐍python
1class Seq2SeqEmbedding(nn.Module):
2    """
3    Complete embedding module for sequence-to-sequence models.
4
5    Features:
6    - Separate source and target embeddings
7    - Sinusoidal positional encoding
8    - √d_model scaling
9    - Dropout regularization
10    - Optional weight tying for output
11    """
12
13    def __init__(
14        self,
15        src_vocab_size: int,
16        tgt_vocab_size: int,
17        d_model: int,
18        max_seq_len: int = 5000,
19        dropout: float = 0.1,
20        src_padding_idx: int = 0,
21        tgt_padding_idx: int = 0
22    ):
23        super().__init__()
24
25        self.d_model = d_model
26        self.scale = math.sqrt(d_model)
27
28        # Token embeddings
29        self.src_token_emb = nn.Embedding(
30            src_vocab_size, d_model, padding_idx=src_padding_idx
31        )
32        self.tgt_token_emb = nn.Embedding(
33            tgt_vocab_size, d_model, padding_idx=tgt_padding_idx
34        )
35
36        # Shared positional encoding
37        pe = self._create_positional_encoding(max_seq_len, d_model)
38        self.register_buffer('pe', pe)
39
40        # Dropout
41        self.dropout = nn.Dropout(dropout)
42
43        # Initialize
44        self._init_weights()
45
46    def _create_positional_encoding(self, max_len: int, d_model: int) -> torch.Tensor:
47        """Create sinusoidal positional encoding."""
48        pe = torch.zeros(max_len, d_model)
49        position = torch.arange(0, max_len).unsqueeze(1).float()
50        div_term = torch.exp(
51            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
52        )
53        pe[:, 0::2] = torch.sin(position * div_term)
54        pe[:, 1::2] = torch.cos(position * div_term)
55        return pe.unsqueeze(0)  # [1, max_len, d_model]
56
57    def _init_weights(self):
58        """Initialize embedding weights."""
59        nn.init.normal_(self.src_token_emb.weight, std=self.d_model ** -0.5)
60        nn.init.normal_(self.tgt_token_emb.weight, std=self.d_model ** -0.5)
61
62        # Zero out padding
63        with torch.no_grad():
64            if self.src_token_emb.padding_idx is not None:
65                self.src_token_emb.weight[self.src_token_emb.padding_idx].fill_(0)
66            if self.tgt_token_emb.padding_idx is not None:
67                self.tgt_token_emb.weight[self.tgt_token_emb.padding_idx].fill_(0)
68
69    def encode_source(self, src: torch.Tensor) -> torch.Tensor:
70        """
71        Embed source sequence.
72
73        Args:
74            src: Source token IDs [batch, src_len]
75
76        Returns:
77            Source embeddings [batch, src_len, d_model]
78        """
79        x = self.src_token_emb(src) * self.scale
80        x = x + self.pe[:, :x.size(1), :]
81        return self.dropout(x)
82
83    def encode_target(self, tgt: torch.Tensor) -> torch.Tensor:
84        """
85        Embed target sequence.
86
87        Args:
88            tgt: Target token IDs [batch, tgt_len]
89
90        Returns:
91            Target embeddings [batch, tgt_len, d_model]
92        """
93        x = self.tgt_token_emb(tgt) * self.scale
94        x = x + self.pe[:, :x.size(1), :]
95        return self.dropout(x)
96
97
98# Final verification
99def final_test():
100    print("Final Integration Test")
101    print("=" * 50)
102
103    emb = Seq2SeqEmbedding(
104        src_vocab_size=8000,   # German
105        tgt_vocab_size=10000,  # English
106        d_model=512,
107        max_seq_len=128,
108        dropout=0.1
109    )
110
111    # Simulate translation batch
112    src = torch.randint(1, 8000, (4, 30))   # 4 sentences, 30 tokens
113    tgt = torch.randint(1, 10000, (4, 25))  # 4 sentences, 25 tokens
114
115    src_emb = emb.encode_source(src)
116    tgt_emb = emb.encode_target(tgt)
117
118    print(f"Source: {src.shape} β†’ {src_emb.shape}")
119    print(f"Target: {tgt.shape} β†’ {tgt_emb.shape}")
120    print(f"\nTotal parameters: {sum(p.numel() for p in emb.parameters()):,}")
121
122    print("\nβœ“ Ready for transformer layers!")
123
124
125final_test()
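`final_test` counts only the two token-embedding tables because `pe` is registered with `register_buffer`: a buffer is saved in `state_dict` and moves with `.to(device)`, but it is not returned by `parameters()`, so an optimizer built from `parameters()` never updates it. A minimal check on a toy module (`WithBuffer` is a hypothetical name):

```python
import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 8)                    # learnable: 800 parameters
        self.register_buffer("pe", torch.zeros(1, 50, 8))  # fixed table: not a parameter


m = WithBuffer()
n_params = sum(p.numel() for p in m.parameters())
buffer_names = [name for name, _ in m.named_buffers()]
print(n_params)               # 800
print("pe" in m.state_dict())  # True
print(buffer_names)           # ['pe']
```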

Summary

Complete Input Pipeline

```text
Token IDs [batch, seq_len]
     ↓
nn.Embedding (token → vector)
     ↓
× √d_model (scaling)
     ↓
+ Positional Encoding
     ↓
Dropout
     ↓
Embeddings [batch, seq_len, d_model]
```

Key Implementation Points

1. Separate embeddings for source and target in translation

2. Scale by √d_model before adding positional encoding

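3. Sinusoidal PE for our course project: it adds no parameters, and because it is a closed-form function of position, the table can be recomputed for sequences longer than those seen in training (how well the model handles such lengths is an empirical question)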

4. Zero padding embeddings via padding_idx

5. Dropout applied after combining token + position
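Because each sinusoidal entry is a closed-form function of position and dimension, rebuilding the table with a larger `max_len` leaves the existing rows unchanged, while a learned `nn.Embedding` table has no row at all for positions past its size. The sketch below checks both points; `sinusoidal_table` is a hypothetical helper repeating the formula used throughout this chapter.

```python
import math
import torch
import torch.nn as nn


def sinusoidal_table(max_len: int, d_model: int) -> torch.Tensor:
    """Hypothetical helper: [max_len, d_model] sinusoidal table."""
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


short_table = sinusoidal_table(128, 64)
long_table = sinusoidal_table(4096, 64)
# Extending the table keeps the first 128 rows (up to float rounding)
print(torch.allclose(short_table, long_table[:128]))  # True

learned = nn.Embedding(128, 64)
try:
    learned(torch.tensor([200]))  # no row for position 200
    out_of_range = False
except IndexError:
    out_of_range = True
print(out_of_range)  # True
```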


Chapter Summary

In this chapter, you learned:

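1. The Position Problem: Self-attention is permutation equivariant (shuffling the input tokens only shuffles the outputs), so the model needs explicit position information to tell word orders apart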

2. Sinusoidal Encoding: Mathematical elegance, relative position property, no learned parameters

3. Learned Embeddings: Task-specific, but limited to max_seq_len

4. Token Embeddings: Vocabulary → dense vectors via nn.Embedding

5. Combined Layer: Complete pipeline ready for transformer
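A quick numerical check of the position problem, using PyTorch's `nn.MultiheadAttention` as a stand-in for the attention layers of later chapters (toy sizes, random weights): permuting the input tokens permutes the output rows in exactly the same way, so attention alone carries no notion of order.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len = 16, 5
attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=2, batch_first=True)
attn.eval()

x = torch.randn(1, seq_len, d_model)
perm = torch.tensor([3, 0, 4, 1, 2])
x_perm = x[:, perm]

with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(x_perm, x_perm, x_perm)

# Shuffling the inputs only shuffles the outputs
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```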

You now have all the components to create input representations for transformers!


Next Chapter Preview

In Chapter 5: Subword Tokenization for Translation, we'll learn how to prepare text for our embeddings using Byte-Pair Encoding (BPE), the tokenization method that handles multiple languages and rare words gracefully.