Chapter 7

Complete Encoder Forward Pass Walkthrough

Transformer Encoder

Introduction

Understanding exact tensor shapes at every stage is crucial for debugging and optimizing transformers. This section traces through a complete encoder forward pass with concrete dimensions.


Configuration

Our Example Setup

πŸ“text
1Batch size:      B = 32
2Sequence length: S = 50
3Model dimension: D = 512
4Heads:           H = 8
5d_k (per head):  d = 64 (512 / 8)
6FFN dimension:   F = 2048
7Number of layers: N = 6

Input Data

πŸ“text
1German sentences (tokenized):
2[
3  "Der schnelle braune Fuchs springt ...",  # 45 tokens
4  "Die Katze sitzt auf der Matte",          # 38 tokens
5  ...                                        # 32 sentences total
6]
7
8After padding to length 50:
9[
10  [tok1, tok2, ..., tok45, PAD, PAD, PAD, PAD, PAD],
11  [tok1, tok2, ..., tok38, PAD, PAD, ..., PAD],
12  ...
13]

Before the Encoder

Embedding + Positional Encoding

πŸ“text
1β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
2β”‚  Token IDs: [32, 50]                                        β”‚
3β”‚       ↓                                                     β”‚
4β”‚  Token Embedding: [32, 50, 512]                             β”‚
5β”‚       +                                                     β”‚
6β”‚  Positional Encoding: [1, 50, 512] (broadcasted)           β”‚
7β”‚       ↓                                                     β”‚
8β”‚  Embedding Output: [32, 50, 512]                            β”‚
9β”‚       ↓                                                     β”‚
10β”‚  Dropout: [32, 50, 512]                                     β”‚
11β”‚       ↓                                                     β”‚
12β”‚  Encoder Input: x = [32, 50, 512]                           β”‚
13β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
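The stage above can be sketched in a few lines of PyTorch. The vocabulary size (10,000) is a made-up value for illustration; the positional encoding uses the standard sinusoidal formulation:

```python
import math
import torch
import torch.nn as nn

vocab_size = 10000  # hypothetical value for illustration
B, S, D = 32, 50, 512

embedding = nn.Embedding(vocab_size, D)

# Sinusoidal positional encoding with shape [1, 50, 512],
# so it broadcasts across the batch dimension on addition
pos = torch.arange(S).unsqueeze(1)                            # [50, 1]
div = torch.exp(torch.arange(0, D, 2).float() * (-math.log(10000.0) / D))
pe = torch.zeros(1, S, D)
pe[0, :, 0::2] = torch.sin(pos * div)
pe[0, :, 1::2] = torch.cos(pos * div)

token_ids = torch.randint(0, vocab_size, (B, S))              # [32, 50]
x = embedding(token_ids) + pe                                 # [32, 50, 512]
x = nn.Dropout(0.1)(x)                                        # shape unchanged
print(x.shape)  # torch.Size([32, 50, 512])
```
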

Padding Mask Creation

🐍python
import torch

# Original token counts per sentence
token_lengths = [45, 38, 50, 42, ...]  # 32 values

# key_padding_mask convention: True marks a padding position to
# ignore, False marks a real token to attend to. Shape: [32, 50]
mask = torch.zeros(32, 50, dtype=torch.bool)
for i, length in enumerate(token_lengths):
    mask[i, length:] = True  # Padding positions = True (ignore)

# Already in the [32, 50] key_padding_mask format that
# nn.MultiheadAttention expects

Encoder Layer 1 Walkthrough

Step 1: Pre-LayerNorm (Attention)

πŸ“text
1Input x:           [32, 50, 512]
2                        ↓
3     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4     β”‚ LayerNorm                            β”‚
5     β”‚   Normalize across dim=-1 (512)      β”‚
6     β”‚   Learnable Ξ³, Ξ²: [512] each         β”‚
7     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
8                        ↓
9Normalized:        [32, 50, 512]

Step 2: Q, K, V Projections

πŸ“text
1Normalized x:      [32, 50, 512]
2                        ↓
3     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4     β”‚ Linear Projections                   β”‚
5     β”‚   W_q: [512, 512]                    β”‚
6     β”‚   W_k: [512, 512]                    β”‚
7     β”‚   W_v: [512, 512]                    β”‚
8     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
9                        ↓
10Q = x @ W_q:       [32, 50, 512]
11K = x @ W_k:       [32, 50, 512]
12V = x @ W_v:       [32, 50, 512]
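A minimal sketch of the three projections (bias terms omitted for brevity, random input standing in for the LayerNorm output):

```python
import torch
import torch.nn as nn

B, S, D = 32, 50, 512
x_norm = torch.randn(B, S, D)  # stand-in for the normalized input

# Each projection carries a [512, 512] weight matrix
W_q = nn.Linear(D, D, bias=False)
W_k = nn.Linear(D, D, bias=False)
W_v = nn.Linear(D, D, bias=False)

Q, K, V = W_q(x_norm), W_k(x_norm), W_v(x_norm)
print(Q.shape, K.shape, V.shape)  # each torch.Size([32, 50, 512])
```
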

Step 3: Reshape for Multi-Head

πŸ“text
1Q before reshape:  [32, 50, 512]
2                        ↓
3     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4     β”‚ view(32, 50, 8, 64)                  β”‚
5     β”‚ transpose(1, 2)                      β”‚
6     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
7                        ↓
8Q after reshape:   [32, 8, 50, 64]  (batch, heads, seq, d_k)
9K after reshape:   [32, 8, 50, 64]
10V after reshape:   [32, 8, 50, 64]

Step 4: Attention Scores

πŸ“text
1Q:                 [32, 8, 50, 64]
2K^T:               [32, 8, 64, 50]
3                        ↓
4     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
5     β”‚ matmul(Q, K^T) / sqrt(64)            β”‚
6     β”‚ scores = QK^T / 8.0                  β”‚
7     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
8                        ↓
9scores:            [32, 8, 50, 50]
10
11# Memory for scores: 32 Γ— 8 Γ— 50 Γ— 50 Γ— 4 bytes = 2.56 MB
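Steps 3 and 4 can be checked together with random tensors standing in for the real activations:

```python
import torch

B, H, S, d_k = 32, 8, 50, 64

# Reshape [32, 50, 512] -> [32, 8, 50, 64] as in Step 3
Q = torch.randn(B, S, H * d_k).view(B, S, H, d_k).transpose(1, 2)
K = torch.randn(B, S, H * d_k).view(B, S, H, d_k).transpose(1, 2)

# Scaled dot-product scores: [32, 8, 50, 64] @ [32, 8, 64, 50]
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # [32, 8, 50, 50]
print(scores.shape)

# 4 bytes per float32 element; matches the ~2.44 MB figure above
mb = scores.numel() * 4 / 1024**2
print(f"{mb:.2f} MB")  # 2.44 MB
```
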

Step 5: Apply Mask and Softmax

πŸ“text
1scores:            [32, 8, 50, 50]
2mask:              [32, 1, 1, 50] (broadcasted)
3                        ↓
4     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
5     β”‚ scores.masked_fill(mask, -inf)       β”‚
6     β”‚ softmax(scores, dim=-1)              β”‚
7     β”‚ dropout(weights)                     β”‚
8     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
9                        ↓
10attention_weights: [32, 8, 50, 50]
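A small demonstration of the masking step, pretending every sentence has 45 real tokens: masked key positions receive exactly zero attention weight, and each row still sums to 1.

```python
import torch
import torch.nn.functional as F

B, H, S = 32, 8, 50
scores = torch.randn(B, H, S, S)

# key_padding_mask: True = padding position to ignore, shape [32, 50]
key_padding_mask = torch.zeros(B, S, dtype=torch.bool)
key_padding_mask[:, 45:] = True  # last 5 positions are PAD

# Broadcast to [32, 1, 1, 50]: applies to every head and every query
mask = key_padding_mask[:, None, None, :]
scores = scores.masked_fill(mask, float('-inf'))
weights = F.softmax(scores, dim=-1)      # [32, 8, 50, 50]

# exp(-inf) = 0, so padded keys get exactly zero weight
print(weights[..., 45:].sum().item())                      # 0.0
print(weights.sum(dim=-1).allclose(torch.ones(B, H, S)))   # True
```
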

Step 6: Apply Attention to Values

πŸ“text
1attention_weights: [32, 8, 50, 50]
2V:                 [32, 8, 50, 64]
3                        ↓
4     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
5     β”‚ matmul(attention_weights, V)         β”‚
6     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
7                        ↓
8context:           [32, 8, 50, 64]

Step 7: Combine Heads

πŸ“text
1context:           [32, 8, 50, 64]
2                        ↓
3     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4     β”‚ transpose(1, 2)                      β”‚
5     β”‚ contiguous()                         β”‚
6     β”‚ view(32, 50, 512)                    β”‚
7     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
8                        ↓
9combined:          [32, 50, 512]
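The head-combining step can be checked in isolation; note that contiguous() is what makes the final view() legal, since transpose() returns a non-contiguous tensor:

```python
import torch

B, H, S, d_k = 32, 8, 50, 64
context = torch.randn(B, H, S, d_k)            # [32, 8, 50, 64]

# Move heads next to d_k, then flatten them back into d_model = 8 * 64
combined = context.transpose(1, 2).contiguous().view(B, S, H * d_k)
print(combined.shape)  # torch.Size([32, 50, 512])
```
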

Step 8: Output Projection

πŸ“text
1combined:          [32, 50, 512]
2                        ↓
3     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4     β”‚ W_o: [512, 512]                      β”‚
5     β”‚ combined @ W_o                       β”‚
6     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
7                        ↓
8attention_output:  [32, 50, 512]

Step 9: First Residual

πŸ“text
1Original x:        [32, 50, 512]
2attention_output:  [32, 50, 512]
3                        ↓
4     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
5     β”‚ dropout(attention_output)            β”‚
6     β”‚ x + dropout_output                   β”‚
7     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
8                        ↓
9x (updated):       [32, 50, 512]

Step 10: Feed-Forward Network

πŸ“text
1x:                 [32, 50, 512]
2                        ↓
3     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4     β”‚ LayerNorm: [32, 50, 512]             β”‚
5     β”‚                                      β”‚
6     β”‚ Linear1: 512 β†’ 2048                  β”‚
7     β”‚   β†’ [32, 50, 2048]                   β”‚
8     β”‚                                      β”‚
9     β”‚ GELU activation                      β”‚
10     β”‚   β†’ [32, 50, 2048]                   β”‚
11     β”‚                                      β”‚
12     β”‚ Dropout                              β”‚
13     β”‚   β†’ [32, 50, 2048]                   β”‚
14     β”‚                                      β”‚
15     β”‚ Linear2: 2048 β†’ 512                  β”‚
16     β”‚   β†’ [32, 50, 512]                    β”‚
17     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
18                        ↓
19ffn_output:        [32, 50, 512]
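The whole FFN sub-block maps [32, 50, 512] back to itself, as this nn.Sequential sketch shows (the dropout rate of 0.1 is an assumed value):

```python
import torch
import torch.nn as nn

B, S, D, D_FF = 32, 50, 512, 2048
x = torch.randn(B, S, D)

ffn = nn.Sequential(
    nn.LayerNorm(D),         # Pre-LN variant, as in the walkthrough
    nn.Linear(D, D_FF),      # 512 -> 2048 (expansion)
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(D_FF, D),      # 2048 -> 512 (contraction)
)
out = ffn(x)
print(out.shape)  # torch.Size([32, 50, 512])
```
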

Step 11: Second Residual

πŸ“text
1x (before FFN):    [32, 50, 512]
2ffn_output:        [32, 50, 512]
3                        ↓
4     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
5     β”‚ dropout(ffn_output)                  β”‚
6     β”‚ x + dropout_output                   β”‚
7     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
8                        ↓
9Layer 1 Output:    [32, 50, 512]

Layers 2-6

Each Layer Follows Same Pattern

πŸ“text
1Layer 1 Output: [32, 50, 512]
2       ↓
3Layer 2: Self-Attention β†’ Add&Norm β†’ FFN β†’ Add&Norm
4       ↓
5Layer 2 Output: [32, 50, 512]
6       ↓
7Layer 3: ...
8       ↓
9Layer 3 Output: [32, 50, 512]
10       ↓
11...
12       ↓
13Layer 6 Output: [32, 50, 512]

Shape Consistency

πŸ“text
1All layers maintain: [batch=32, seq_len=50, d_model=512]
2
3Layer 1: [32, 50, 512] β†’ [32, 50, 512]
4Layer 2: [32, 50, 512] β†’ [32, 50, 512]
5Layer 3: [32, 50, 512] β†’ [32, 50, 512]
6Layer 4: [32, 50, 512] β†’ [32, 50, 512]
7Layer 5: [32, 50, 512] β†’ [32, 50, 512]
8Layer 6: [32, 50, 512] β†’ [32, 50, 512]
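This shape-preservation property is easy to confirm with PyTorch's built-in encoder, used here only as a stand-in for the implementation built in this chapter:

```python
import torch
import torch.nn as nn

# Six stacked layers, same configuration as the walkthrough
layer = nn.TransformerEncoderLayer(
    d_model=512, nhead=8, dim_feedforward=2048,
    batch_first=True, norm_first=True,  # Pre-LN
)
encoder = nn.TransformerEncoder(layer, num_layers=6)

x = torch.randn(32, 50, 512)
out = encoder(x)
print(out.shape)  # torch.Size([32, 50, 512])
```
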

Final LayerNorm

For Pre-LN Architecture

πŸ“text
1Layer 6 Output:    [32, 50, 512]
2                        ↓
3     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4     β”‚ Final LayerNorm                      β”‚
5     β”‚   Normalize across dim=-1            β”‚
6     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
7                        ↓
8Encoder Output:    [32, 50, 512]

Complete Memory Analysis

Per-Layer Memory

🐍python
def calculate_encoder_memory(batch, seq, d_model, heads, d_ff):
    """Calculate memory usage for encoder forward pass."""

    d_k = d_model // heads
    bytes_per_float = 4

    memory = {}

    # Activations saved for backward pass
    memory['input'] = batch * seq * d_model * bytes_per_float

    # Q, K, V (before and after reshape)
    memory['qkv'] = 3 * batch * seq * d_model * bytes_per_float

    # Attention scores [batch, heads, seq, seq]
    memory['attn_scores'] = batch * heads * seq * seq * bytes_per_float

    # Attention output
    memory['attn_out'] = batch * seq * d_model * bytes_per_float

    # FFN hidden [batch, seq, d_ff]
    memory['ffn_hidden'] = batch * seq * d_ff * bytes_per_float

    return memory


# Calculate for our example
mem = calculate_encoder_memory(
    batch=32, seq=50, d_model=512, heads=8, d_ff=2048
)

print("Memory per encoder layer:")
total = 0
for name, size in mem.items():
    mb = size / 1024**2
    total += mb
    print(f"  {name:15s}: {mb:.2f} MB")
print(f"  {'TOTAL':15s}: {total:.2f} MB")
print(f"\n6 layers total: {total * 6:.2f} MB")

Output:

πŸ“text
1Memory per encoder layer:
2  input          : 3.12 MB
3  qkv            : 9.38 MB
4  attn_scores    : 2.44 MB
5  attn_out       : 3.12 MB
6  ffn_hidden     : 12.50 MB
7  TOTAL          : 30.56 MB
8
96 layers total: 183.38 MB

Scaling with Sequence Length

πŸ“text
1Attention scores scale as O(seq_lenΒ²):
2
3seq_len=50:   2.44 MB
4seq_len=100:  9.77 MB   (4Γ— more)
5seq_len=512:  256 MB    (104Γ— more!)
6seq_len=1024: 1024 MB   (419Γ— more!)
7
8Long sequences require significant memory!
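These figures all come from one formula; a small helper reproduces the table above (ratios rounded to whole multiples):

```python
# Attention-score memory grows with the square of the sequence length
def scores_mb(batch, heads, seq, bytes_per_float=4):
    return batch * heads * seq * seq * bytes_per_float / 1024**2

for s in (50, 100, 512, 1024):
    print(f"seq_len={s:4d}: {scores_mb(32, 8, s):8.2f} MB "
          f"({(s / 50) ** 2:.0f}x vs seq_len=50)")
```
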

Code Verification

Trace Script

🐍python
import torch
import torch.nn as nn


def trace_encoder_shapes():
    """Trace shapes through entire encoder."""

    # Configuration
    batch_size = 32
    seq_len = 50
    d_model = 512
    num_heads = 8
    d_ff = 2048
    num_layers = 6

    print("Encoder Shape Trace")
    print("=" * 60)
    print(f"Config: B={batch_size}, S={seq_len}, D={d_model}, H={num_heads}")
    print("=" * 60)

    # Create encoder (the TransformerEncoder class built earlier
    # in this chapter)
    encoder = TransformerEncoder(
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        d_ff=d_ff,
        dropout=0.0  # Disable for cleaner trace
    )
    encoder.eval()

    # Input
    x = torch.randn(batch_size, seq_len, d_model)
    print(f"\nInput: {x.shape}")

    # Hook to capture intermediate shapes
    shapes = {}

    def make_hook(name):
        def hook(module, input, output):
            if isinstance(output, tuple):
                shapes[name] = output[0].shape
            else:
                shapes[name] = output.shape
        return hook

    # Register hooks on key modules
    for i, layer in enumerate(encoder.layers):
        layer.self_attention.register_forward_hook(make_hook(f'layer_{i}_attn'))
        layer.norm2.register_forward_hook(make_hook(f'layer_{i}_norm2'))

    encoder.norm.register_forward_hook(make_hook('final_norm'))

    # Forward pass
    output = encoder(x)

    # Print captured shapes
    print("\nIntermediate shapes:")
    for name, shape in shapes.items():
        print(f"  {name:20s}: {shape}")

    print(f"\nFinal output: {output.shape}")

    # Verify
    assert output.shape == x.shape, "Shape mismatch!"
    print("\nβœ“ All shapes verified correct!")


trace_encoder_shapes()

Summary Table

Complete Shape Summary

πŸ“text
Stage              | Shape               | Notes
-------------------|---------------------|----------------------
Input tokens       | [32, 50]            | Token IDs
After embedding    | [32, 50, 512]       | + positional encoding
--- Encoder layer (Γ—6) ---
Pre-LN 1           | [32, 50, 512]       | Normalize
Q, K, V            | [32, 50, 512] each  | Linear projection
Q, K, V reshaped   | [32, 8, 50, 64]     | Multi-head format
Attention scores   | [32, 8, 50, 50]     | QK^T / √d_k
Attention weights  | [32, 8, 50, 50]     | After mask + softmax
Context            | [32, 8, 50, 64]     | weights @ V
Combined heads     | [32, 50, 512]       | Concatenate
Attention output   | [32, 50, 512]       | Output projection
After residual 1   | [32, 50, 512]       | x + attn_out
Pre-LN 2           | [32, 50, 512]       | Normalize
FFN hidden         | [32, 50, 2048]      | Expansion
FFN output         | [32, 50, 512]       | Contraction
After residual 2   | [32, 50, 512]       | x + ffn_out
--- After all 6 layers ---
Final LN           | [32, 50, 512]       | Pre-LN requires this
Output             | [32, 50, 512]       | Encoder memory

Exercises

Shape Tracing

1. Trace shapes for batch_size=1, seq_len=100, d_model=768, heads=12.

2. Calculate memory usage for GPT-2 medium (24 layers, 1024 dim, 16 heads).

Analysis

3. Why does attention memory scale quadratically with sequence length?

4. What is the maximum sequence length you can use with 8GB GPU memory?


In the next section, we'll focus on how the encoder is specifically used for translation: processing German source sentences and preparing the encoder output (memory) for the decoder's cross-attention.