Chapter 4

Learned Positional Embeddings


Introduction

While sinusoidal positional encoding is mathematically elegant, many modern models (BERT, GPT, etc.) use learned positional embeddings instead. This approach treats positions like vocabulary tokens: each position gets its own learnable vector.

This section covers implementation, trade-offs, and when to choose learned over sinusoidal positions.


3.1 The Concept

Positions as Vocabulary

Just as we look up word embeddings by token ID:

```python
token_embedding = embedding_table[token_id]
```

We can look up position embeddings by position:

```python
position_embedding = position_table[position]
```

Visual Comparison

πŸ“text
1Sinusoidal:
2Position 0 β†’ sin/cos formula β†’ [0.00, 1.00, 0.00, 1.00, ...]
3Position 1 β†’ sin/cos formula β†’ [0.84, 0.54, 0.01, 1.00, ...]
4Position 2 β†’ sin/cos formula β†’ [0.91, -0.42, 0.02, 1.00, ...]
5
6Learned:
7Position 0 β†’ lookup table β†’ [p0_0, p0_1, p0_2, p0_3, ...]  (learned)
8Position 1 β†’ lookup table β†’ [p1_0, p1_1, p1_2, p1_3, ...]  (learned)
9Position 2 β†’ lookup table β†’ [p2_0, p2_1, p2_2, p2_3, ...]  (learned)
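The parallel between the two lookups can be demonstrated directly with `nn.Embedding` (a minimal standalone sketch; the table sizes and seed are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 8
vocab_size, max_seq_len = 100, 16

# Two lookup tables with identical mechanics: one indexed by token ID,
# one indexed by position.
token_table = nn.Embedding(vocab_size, d_model)
position_table = nn.Embedding(max_seq_len, d_model)

token_ids = torch.tensor([42, 7, 42])   # the same token appears twice
positions = torch.tensor([0, 1, 2])     # but at different positions

token_emb = token_table(token_ids)      # [3, d_model]
pos_emb = position_table(positions)     # [3, d_model]
combined = token_emb + pos_emb          # position breaks the tie

# The two occurrences of token 42 share a token embedding...
print(torch.equal(token_emb[0], token_emb[2]))   # True
# ...but differ once position embeddings are added.
print(torch.equal(combined[0], combined[2]))     # False
```

The two occurrences of token 42 start out identical; it is the position embedding that makes them distinguishable to the model.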

3.2 Implementation

Basic Implementation

```python
import torch
import torch.nn as nn
from typing import Optional


class LearnedPositionalEmbedding(nn.Module):
    """
    Learned positional embeddings.

    Each position has a learnable embedding vector that is
    added to the token embeddings.

    Args:
        max_seq_len: Maximum sequence length
        d_model: Embedding dimension
        dropout: Dropout probability

    Example:
        >>> pe = LearnedPositionalEmbedding(max_seq_len=512, d_model=768)
        >>> x = torch.randn(2, 100, 768)
        >>> output = pe(x)  # [2, 100, 768]
    """

    def __init__(
        self,
        max_seq_len: int,
        d_model: int,
        dropout: float = 0.1
    ):
        super().__init__()

        self.max_seq_len = max_seq_len
        self.d_model = d_model

        # Learnable position embedding table
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Dropout
        self.dropout = nn.Dropout(p=dropout)

        # Initialize embeddings
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize position embeddings."""
        # Normal initialization (common choice)
        nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Add positional embeddings to input.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            position_ids: Optional position indices [batch, seq_len]
                          If None, uses [0, 1, 2, ..., seq_len-1]

        Returns:
            Output tensor [batch, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.shape

        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}"
            )

        # Create position indices if not provided
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=x.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Look up position embeddings
        pos_emb = self.position_embedding(position_ids)  # [batch, seq_len, d_model]

        # Add to input
        output = x + pos_emb

        return self.dropout(output)

    def extra_repr(self) -> str:
        return f'max_seq_len={self.max_seq_len}, d_model={self.d_model}'


# Test
def test_learned_pe():
    max_seq_len = 512
    d_model = 768
    batch_size = 2
    seq_len = 100

    pe = LearnedPositionalEmbedding(max_seq_len, d_model, dropout=0.0)

    x = torch.randn(batch_size, seq_len, d_model)
    output = pe(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Position embedding table shape: {pe.position_embedding.weight.shape}")
    print(f"Number of parameters: {sum(p.numel() for p in pe.parameters()):,}")

    # Test with custom position_ids
    custom_pos = torch.tensor([[0, 2, 4, 6, 8]])  # Non-contiguous positions
    x_small = torch.randn(1, 5, d_model)
    output_custom = pe(x_small, position_ids=custom_pos)
    print(f"\nCustom positions output shape: {output_custom.shape}")

    print("\n✓ Learned PE test passed!")


test_learned_pe()
```

3.3 Initialization Strategies

Normal Initialization (Common)

```python
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
```

Used by: GPT-2, BERT

Xavier Initialization

```python
nn.init.xavier_uniform_(self.position_embedding.weight)
```

Alternative for different variance scaling.

Zero Initialization

```python
nn.init.zeros_(self.position_embedding.weight)
```

Start neutral, let model learn from scratch.

Sinusoidal Initialization

Initialize learned embeddings with sinusoidal values, then fine-tune:

```python
import math  # needed for math.log below


def init_with_sinusoidal(self):
    """Initialize with sinusoidal values, then make learnable."""
    position = torch.arange(self.max_seq_len).unsqueeze(1).float()
    div_term = torch.exp(
        torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model)
    )

    with torch.no_grad():
        self.position_embedding.weight[:, 0::2] = torch.sin(position * div_term)
        self.position_embedding.weight[:, 1::2] = torch.cos(position * div_term)
```

This gives the benefits of sinusoidal encoding as a starting point, with the ability to adapt during training.
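As a quick sanity check of this strategy (a standalone sketch with small, arbitrary dimensions): row 0 should come out as the classic [0, 1, 0, 1, ...] pattern, and the table should remain trainable afterwards.

```python
import math
import torch
import torch.nn as nn

max_seq_len, d_model = 32, 8
emb = nn.Embedding(max_seq_len, d_model)

# Overwrite the random weights with sinusoidal values in-place.
position = torch.arange(max_seq_len).unsqueeze(1).float()
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
with torch.no_grad():
    emb.weight[:, 0::2] = torch.sin(position * div_term)
    emb.weight[:, 1::2] = torch.cos(position * div_term)

# Position 0: sin(0) = 0 in even dims, cos(0) = 1 in odd dims.
print(emb.weight[0])
# The table is still a learnable parameter.
print(emb.weight.requires_grad)  # True
```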


3.4 Handling Sequence Length

The Max Length Constraint

Learned embeddings have a fixed vocabulary:

```python
position_embedding = nn.Embedding(max_seq_len, d_model)
# Can only handle positions 0 to max_seq_len-1
```

Problem: Can't handle sequences longer than max_seq_len!

Solutions

1. Choose Large max_seq_len

```python
# BERT uses 512
# GPT-2 uses 1024
# GPT-3 uses 2048
# Modern models: 4096, 8192, or more
```

2. Position Interpolation

Scale position indices to fit within max_seq_len:

```python
def interpolate_positions(self, seq_len):
    """Interpolate for longer sequences."""
    if seq_len <= self.max_seq_len:
        return torch.arange(seq_len)

    # Scale positions to [0, max_seq_len-1]
    scale = self.max_seq_len / seq_len
    positions = torch.arange(seq_len) * scale
    return positions.long().clamp(0, self.max_seq_len - 1)
```
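A usage sketch of the same idea as a standalone function (a hypothetical free-function variant, applied to a sequence twice as long as the table):

```python
import torch


def interpolate_positions(seq_len: int, max_seq_len: int) -> torch.Tensor:
    """Squeeze seq_len position indices into the range [0, max_seq_len - 1]."""
    if seq_len <= max_seq_len:
        return torch.arange(seq_len)
    scale = max_seq_len / seq_len
    return (torch.arange(seq_len) * scale).long().clamp(0, max_seq_len - 1)


# A 10-token sequence with a table trained for only 5 positions:
print(interpolate_positions(10, 5).tolist())  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
```

Note that neighboring tokens end up sharing a position id; that is the price of squeezing a longer sequence into a fixed table.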

3. Rotary Position Embedding (RoPE)

Modern alternative that naturally extends to longer sequences.
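For intuition only, here is a minimal sketch of the per-position rotation that RoPE applies (this is not a full attention integration, and `rope_rotate` is our name for the helper):

```python
import torch


def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate each (even, odd) dimension pair of x by a position-dependent angle.

    x: [seq_len, d] with d even. Returns a tensor of the same shape.
    """
    seq_len, d = x.shape
    pos = torch.arange(seq_len).float().unsqueeze(1)       # [seq_len, 1]
    freqs = base ** (-torch.arange(0, d, 2).float() / d)   # [d/2] frequencies
    angles = pos * freqs                                   # [seq_len, d/2]
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin                     # 2-D rotation per pair
    out[:, 1::2] = x1 * sin + x2 * cos
    return out


x = torch.randn(6, 8)
y = rope_rotate(x)
# Rotations preserve vector length, and position 0 is rotated by angle 0:
print(torch.allclose(y.norm(dim=1), x.norm(dim=1), atol=1e-5))  # True
print(torch.allclose(y[0], x[0]))                               # True
```

Because the encoding is a rotation rather than a lookup, attention scores between rotated queries and keys depend on relative offsets, which is what lets RoPE generalize to longer sequences.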


3.5 BERT-Style Position Embedding

BERT uses learned positions with specific features:

```python
class BERTPositionEmbedding(nn.Module):
    """
    BERT-style position embedding with segment embeddings.
    """

    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        d_model: int,
        num_segments: int = 2,
        dropout: float = 0.1,
        padding_idx: int = 0
    ):
        super().__init__()

        # Token embeddings
        self.token_embedding = nn.Embedding(
            vocab_size, d_model, padding_idx=padding_idx
        )

        # Position embeddings
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Segment embeddings (sentence A vs B)
        self.segment_embedding = nn.Embedding(num_segments, d_model)

        # Layer normalization
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-12)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, segment_ids=None, position_ids=None):
        """
        Args:
            input_ids: [batch, seq_len]
            segment_ids: [batch, seq_len], 0 for sentence A, 1 for sentence B
            position_ids: [batch, seq_len], optional
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Default positions
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

        # Default segments (all 0)
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids)

        # Get embeddings
        token_emb = self.token_embedding(input_ids)
        position_emb = self.position_embedding(position_ids)
        segment_emb = self.segment_embedding(segment_ids)

        # Combine
        embeddings = token_emb + position_emb + segment_emb

        # Normalize and dropout
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
```

3.6 GPT-Style Position Embedding

GPT models use simpler position embeddings:

```python
class GPTPositionEmbedding(nn.Module):
    """
    GPT-style embeddings (tokens + positions only).
    """

    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        d_model: int,
        dropout: float = 0.1
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        # Initialize
        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.position_embedding.weight, std=0.02)

    def forward(self, input_ids, position_ids=None):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device)

        token_emb = self.token_embedding(input_ids)
        position_emb = self.position_embedding(position_ids)

        # GPT adds embeddings (no layer norm here)
        embeddings = token_emb + position_emb
        embeddings = self.dropout(embeddings)

        return embeddings
```
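One detail worth noting about this forward pass: when `position_ids` stays 1-D, the final addition relies on PyTorch broadcasting. A standalone sketch (shapes are arbitrary):

```python
import torch

torch.manual_seed(0)
batch, seq_len, d_model = 2, 4, 8
token_emb = torch.randn(batch, seq_len, d_model)  # [batch, seq_len, d_model]
pos_emb = torch.randn(seq_len, d_model)           # [seq_len, d_model], no batch dim

# [seq_len, d_model] broadcasts against [batch, seq_len, d_model]:
combined = token_emb + pos_emb
print(combined.shape)  # torch.Size([2, 4, 8])

# Every batch element receives the same positional offset:
print(torch.allclose(combined - token_emb, pos_emb.expand(batch, -1, -1),
                     atol=1e-5))  # True
```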

3.7 Comparison: Sinusoidal vs Learned

Parameter Count

| Approach | Parameters |
|----------|------------|
| Sinusoidal | 0 (computed) |
| Learned (512 pos, 768 dim) | 393,216 |
| Learned (2048 pos, 1024 dim) | 2,097,152 |
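These parameter counts are simply `max_seq_len * d_model`; a quick check against `nn.Embedding`'s actual parameter count:

```python
import torch.nn as nn

for max_seq_len, d_model in [(512, 768), (2048, 1024)]:
    table = nn.Embedding(max_seq_len, d_model)
    n_params = sum(p.numel() for p in table.parameters())
    print(f"{max_seq_len} positions x {d_model} dims -> {n_params:,} parameters")
```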

Performance Comparison

Research findings (varies by task):

| Task | Sinusoidal | Learned | Winner |
|------|-----------|---------|--------|
| Machine Translation | Good | Good | Tie |
| Language Modeling | Good | Slightly better | Learned |
| Classification | Good | Good | Tie |
| Long sequences | Good | Struggles | Sinusoidal |

When to Use Each

Use Sinusoidal When:

- Need to handle variable/long sequences

- Want fewer parameters

- Mathematical properties matter (relative position)

- Extrapolation to unseen lengths needed

Use Learned When:

- Fixed maximum sequence length is acceptable

- Task-specific position patterns might help

- Following established architecture (BERT, GPT)

- Have enough data to learn positions


3.8 Visualization

Visualize Learned Embeddings

```python
import matplotlib.pyplot as plt


def visualize_learned_embeddings(pe_module, num_positions=50):
    """Visualize learned position embeddings."""

    weights = pe_module.position_embedding.weight.detach()[:num_positions]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Heatmap
    ax1 = axes[0]
    im = ax1.imshow(weights.numpy(), aspect='auto', cmap='RdBu')
    ax1.set_xlabel('Dimension')
    ax1.set_ylabel('Position')
    ax1.set_title('Learned Position Embeddings')
    plt.colorbar(im, ax=ax1)

    # Position similarity (cosine)
    ax2 = axes[1]
    weights_norm = weights / weights.norm(dim=1, keepdim=True)
    similarity = weights_norm @ weights_norm.T
    im2 = ax2.imshow(similarity.numpy(), cmap='viridis')
    ax2.set_xlabel('Position')
    ax2.set_ylabel('Position')
    ax2.set_title('Position Similarity (Cosine)')
    plt.colorbar(im2, ax=ax2)

    plt.tight_layout()
    plt.savefig('learned_pe_visualization.png', dpi=150)
    plt.close()
    print("Saved learned_pe_visualization.png")


# Create and visualize
pe = LearnedPositionalEmbedding(max_seq_len=100, d_model=128, dropout=0.0)
visualize_learned_embeddings(pe)
```

Summary

Implementation

```python
class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, max_seq_len, d_model):
        super().__init__()
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

    def forward(self, x):
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.position_embedding(positions)
```

Key Differences from Sinusoidal

| Aspect | Sinusoidal | Learned |
|--------|-----------|---------|
| Parameters | None | max_len × d_model |
| Extrapolation | Natural | Limited |
| Adaptation | Fixed | Task-specific |
| Initialization | Deterministic | Random/chosen |

Next Section Preview

In the next section, we'll implement token embeddings: the layer that converts vocabulary indices to dense vectors. We'll cover vocabulary construction, special tokens, and initialization strategies.