Introduction
Understanding exact tensor shapes at every stage is crucial for debugging and optimizing transformers. This section traces through a complete encoder forward pass with concrete dimensions.
Configuration
Our Example Setup
```text
Batch size:        B = 32
Sequence length:   S = 50
Model dimension:   D = 512
Heads:             H = 8
d_k (per head):    d = 64  (512 / 8)
FFN dimension:     F = 2048
Number of layers:  N = 6
```
Input Data
```text
German sentences (tokenized):
[
  "Der schnelle braune Fuchs springt ...",   # 45 tokens
  "Die Katze sitzt auf der Matte",           # 38 tokens
  ...                                        # 32 sentences total
]

After padding to length 50:
[
  [tok1, tok2, ..., tok45, PAD, PAD, PAD, PAD, PAD],
  [tok1, tok2, ..., tok38, PAD, PAD, ..., PAD],
  ...
]
```
Before the Encoder
Embedding + Positional Encoding
```text
Token IDs:            [32, 50]
          ↓
Token Embedding:      [32, 50, 512]
          +
Positional Encoding:  [1, 50, 512]   (broadcast over the batch)
          ↓
Embedding Output:     [32, 50, 512]
          ↓
Dropout:              [32, 50, 512]
          ↓
Encoder Input:  x =   [32, 50, 512]
```
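This stage can be sketched in a few lines of PyTorch. Note two assumptions not fixed by the text above: `vocab_size = 10000` is an arbitrary example value, and the √D embedding scaling follows the original Transformer convention.

```python
import math

import torch
import torch.nn as nn

B, S, D = 32, 50, 512
vocab_size = 10000  # hypothetical; the example does not fix a vocabulary size

embedding = nn.Embedding(vocab_size, D)
token_ids = torch.randint(0, vocab_size, (B, S))                  # [32, 50]

# Sinusoidal positional encoding with shape [1, 50, 512] so it
# broadcasts over the batch dimension
pos = torch.arange(S, dtype=torch.float32).unsqueeze(1)           # [50, 1]
div = torch.exp(torch.arange(0, D, 2, dtype=torch.float32)
                * (-math.log(10000.0) / D))                       # [256]
pe = torch.zeros(1, S, D)
pe[0, :, 0::2] = torch.sin(pos * div)
pe[0, :, 1::2] = torch.cos(pos * div)

# Embed, scale by sqrt(D) (original-Transformer convention), add PE, dropout
x = nn.Dropout(0.1)(embedding(token_ids) * math.sqrt(D) + pe)
assert x.shape == (B, S, D)                                       # [32, 50, 512]
```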
Padding Mask Creation
```python
# Original token counts per sentence
token_lengths = [45, 38, 50, 42, ...]  # 32 values

# Create mask: True at padding positions (to be ignored),
# False where the model should attend. Shape: [32, 50]
mask = torch.zeros(32, 50, dtype=torch.bool)
for i, length in enumerate(token_lengths):
    mask[i, length:] = True  # padding positions = True (ignore)

# This [32, 50] boolean tensor is exactly the key_padding_mask
# format expected by nn.MultiheadAttention.
```
Encoder Layer 1 Walkthrough
Step 1: Pre-LayerNorm (Attention)
```text
Input x: [32, 50, 512]
        ↓
┌────────────────────────────────────────┐
│ LayerNorm                              │
│ Normalize across dim=-1 (512)          │
│ Learnable γ, β: [512] each             │
└────────────────────────────────────────┘
        ↓
Normalized: [32, 50, 512]
```
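The LayerNorm step can be reproduced directly with PyTorch's `nn.LayerNorm`, as a quick check on random data:

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(512)   # learnable γ (weight) and β (bias), [512] each
x = torch.randn(32, 50, 512)
y = norm(x)

assert y.shape == (32, 50, 512)
# Each position is normalized independently over its 512 features
assert abs(y[0, 0].mean().item()) < 1e-3
```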
Step 2: Q, K, V Projections
```text
Normalized x: [32, 50, 512]
        ↓
┌────────────────────────────────────────┐
│ Linear Projections                     │
│   W_q: [512, 512]                      │
│   W_k: [512, 512]                      │
│   W_v: [512, 512]                      │
└────────────────────────────────────────┘
        ↓
Q = x @ W_q: [32, 50, 512]
K = x @ W_k: [32, 50, 512]
V = x @ W_v: [32, 50, 512]
```
Step 3: Reshape for Multi-Head
```text
Q before reshape: [32, 50, 512]
        ↓
┌────────────────────────────────────────┐
│ view(32, 50, 8, 64)                    │
│ transpose(1, 2)                        │
└────────────────────────────────────────┘
        ↓
Q after reshape: [32, 8, 50, 64]   (batch, heads, seq, d_k)
K after reshape: [32, 8, 50, 64]
V after reshape: [32, 8, 50, 64]
```
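A quick way to verify the reshape is to run it on a dummy tensor:

```python
import torch

B, S, H, d_k = 32, 50, 8, 64
Q = torch.randn(B, S, H * d_k)             # [32, 50, 512]

# Split the model dimension into heads, then move heads before seq
Q = Q.view(B, S, H, d_k).transpose(1, 2)   # [32, 50, 8, 64] → [32, 8, 50, 64]
assert Q.shape == (B, H, S, d_k)
```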
Step 4: Attention Scores
```text
Q:   [32, 8, 50, 64]
K^T: [32, 8, 64, 50]
        ↓
┌────────────────────────────────────────┐
│ matmul(Q, K^T) / sqrt(64)              │
│ scores = QK^T / 8.0                    │
└────────────────────────────────────────┘
        ↓
scores: [32, 8, 50, 50]

# Memory for scores: 32 × 8 × 50 × 50 × 4 bytes ≈ 2.44 MB
```
Step 5: Apply Mask and Softmax
```text
scores: [32, 8, 50, 50]
mask:   [32, 1, 1, 50]   (broadcast over heads and queries)
        ↓
┌────────────────────────────────────────┐
│ scores.masked_fill(mask, -inf)         │
│ softmax(scores, dim=-1)                │
│ dropout(weights)                       │
└────────────────────────────────────────┘
        ↓
attention_weights: [32, 8, 50, 50]
```
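The mask-then-softmax sequence can be sketched as follows. Dropout is omitted, and the padding boundary at position 45 is just an example value:

```python
import torch
import torch.nn.functional as F

B, H, S = 32, 8, 50
scores = torch.randn(B, H, S, S)            # stand-in for QK^T / sqrt(d_k)

# key_padding_mask: True at padded positions (here: last 5, as an example)
key_padding_mask = torch.zeros(B, S, dtype=torch.bool)
key_padding_mask[:, 45:] = True

mask = key_padding_mask[:, None, None, :]   # [32, 1, 1, 50], broadcasts
scores = scores.masked_fill(mask, float('-inf'))
weights = F.softmax(scores, dim=-1)

assert weights.shape == (B, H, S, S)
# No query attends to padded keys:
assert weights[..., 45:].max().item() == 0.0
```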
Step 6: Apply Attention to Values
```text
attention_weights: [32, 8, 50, 50]
V:                 [32, 8, 50, 64]
        ↓
┌────────────────────────────────────────┐
│ matmul(attention_weights, V)           │
└────────────────────────────────────────┘
        ↓
context: [32, 8, 50, 64]
```
Step 7: Combine Heads
```text
context: [32, 8, 50, 64]
        ↓
┌────────────────────────────────────────┐
│ transpose(1, 2)                        │
│ contiguous()                           │
│ view(32, 50, 512)                      │
└────────────────────────────────────────┘
        ↓
combined: [32, 50, 512]
```
Step 8: Output Projection
```text
combined: [32, 50, 512]
        ↓
┌────────────────────────────────────────┐
│ W_o: [512, 512]                        │
│ combined @ W_o                         │
└────────────────────────────────────────┘
        ↓
attention_output: [32, 50, 512]
```
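Steps 7 and 8 together can be sketched on a dummy context tensor; `W_o` here is a fresh, untrained `nn.Linear` used only for the shape check:

```python
import torch
import torch.nn as nn

B, H, S, d_k = 32, 8, 50, 64
context = torch.randn(B, H, S, d_k)          # stand-in for weights @ V

# Step 7: undo the multi-head split back to [32, 50, 512]
combined = context.transpose(1, 2).contiguous().view(B, S, H * d_k)

# Step 8: output projection
W_o = nn.Linear(H * d_k, H * d_k)            # [512, 512]
attention_output = W_o(combined)
assert attention_output.shape == (B, S, 512)
```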
Step 9: First Residual
```text
Original x:       [32, 50, 512]
attention_output: [32, 50, 512]
        ↓
┌────────────────────────────────────────┐
│ dropout(attention_output)              │
│ x + dropout_output                     │
└────────────────────────────────────────┘
        ↓
x (updated): [32, 50, 512]
```
Step 10: Feed-Forward Network
```text
x: [32, 50, 512]
        ↓
┌────────────────────────────────────────┐
│ LayerNorm: [32, 50, 512]               │
│                                        │
│ Linear1: 512 → 2048                    │
│   → [32, 50, 2048]                     │
│                                        │
│ GELU activation                        │
│   → [32, 50, 2048]                     │
│                                        │
│ Dropout                                │
│   → [32, 50, 2048]                     │
│                                        │
│ Linear2: 2048 → 512                    │
│   → [32, 50, 512]                      │
└────────────────────────────────────────┘
        ↓
ffn_output: [32, 50, 512]
```
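The whole FFN sub-block can be expressed as a small `nn.Sequential`. This is a sketch: the dropout rate of 0.1 is an assumed value, not one fixed by the text:

```python
import torch
import torch.nn as nn

# Pre-LN feed-forward sub-block (dropout p=0.1 is an assumption)
ffn = nn.Sequential(
    nn.LayerNorm(512),
    nn.Linear(512, 2048),   # expansion   → [32, 50, 2048]
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(2048, 512),   # contraction → [32, 50, 512]
)

x = torch.randn(32, 50, 512)
out = ffn(x)
assert out.shape == (32, 50, 512)
```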
Step 11: Second Residual
```text
x (before FFN): [32, 50, 512]
ffn_output:     [32, 50, 512]
        ↓
┌────────────────────────────────────────┐
│ dropout(ffn_output)                    │
│ x + dropout_output                     │
└────────────────────────────────────────┘
        ↓
Layer 1 Output: [32, 50, 512]
```
Layers 2-6
Each Layer Follows the Same Pattern
```text
Layer 1 Output: [32, 50, 512]
        ↓
Layer 2: Self-Attention → Add&Norm → FFN → Add&Norm
        ↓
Layer 2 Output: [32, 50, 512]
        ↓
Layer 3: ...
        ↓
Layer 3 Output: [32, 50, 512]
        ↓
...
        ↓
Layer 6 Output: [32, 50, 512]
```
Shape Consistency
```text
All layers maintain: [batch=32, seq_len=50, d_model=512]

Layer 1: [32, 50, 512] → [32, 50, 512]
Layer 2: [32, 50, 512] → [32, 50, 512]
Layer 3: [32, 50, 512] → [32, 50, 512]
Layer 4: [32, 50, 512] → [32, 50, 512]
Layer 5: [32, 50, 512] → [32, 50, 512]
Layer 6: [32, 50, 512] → [32, 50, 512]
```
Final LayerNorm
For Pre-LN Architecture
```text
Layer 6 Output: [32, 50, 512]
        ↓
┌────────────────────────────────────────┐
│ Final LayerNorm                        │
│ Normalize across dim=-1                │
└────────────────────────────────────────┘
        ↓
Encoder Output: [32, 50, 512]
```
Complete Memory Analysis
Per-Layer Memory
```python
def calculate_encoder_memory(batch, seq, d_model, heads, d_ff):
    """Activation memory (in bytes) for one encoder layer's forward pass."""

    bytes_per_float = 4

    memory = {}

    # Activations saved for the backward pass
    memory['input'] = batch * seq * d_model * bytes_per_float

    # Q, K, V projections
    memory['qkv'] = 3 * batch * seq * d_model * bytes_per_float

    # Attention scores [batch, heads, seq, seq]
    memory['attn_scores'] = batch * heads * seq * seq * bytes_per_float

    # Attention output
    memory['attn_out'] = batch * seq * d_model * bytes_per_float

    # FFN hidden [batch, seq, d_ff]
    memory['ffn_hidden'] = batch * seq * d_ff * bytes_per_float

    return memory


# Calculate for our example
mem = calculate_encoder_memory(
    batch=32, seq=50, d_model=512, heads=8, d_ff=2048
)

print("Memory per encoder layer:")
total = 0
for name, size in mem.items():
    mb = size / 1024**2
    total += mb
    print(f"  {name:15s}: {mb:.2f} MB")
print(f"  {'TOTAL':15s}: {total:.2f} MB")
print(f"\n6 layers total: {total * 6:.2f} MB")
```
Output:
```text
Memory per encoder layer:
  input          : 3.12 MB
  qkv            : 9.38 MB
  attn_scores    : 2.44 MB
  attn_out       : 3.12 MB
  ffn_hidden     : 12.50 MB
  TOTAL          : 30.57 MB

6 layers total: 183.40 MB
```
Scaling with Sequence Length
```text
Attention scores scale as O(seq_len²):

seq_len=50:      2.44 MB
seq_len=100:     9.77 MB    (4× more)
seq_len=512:    256 MB      (≈105× more!)
seq_len=1024:  1024 MB      (≈419× more!)

Long sequences require significant memory!
```
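The numbers above can be reproduced with a one-line helper (same 1024² convention as the memory calculator above):

```python
def attn_scores_mb(batch, heads, seq, bytes_per_float=4):
    """Size of the [batch, heads, seq, seq] score tensor, in MB (1024**2 bytes)."""
    return batch * heads * seq * seq * bytes_per_float / 1024**2

for seq in (50, 100, 512, 1024):
    print(f"seq_len={seq:5d}: {attn_scores_mb(32, 8, seq):8.2f} MB")
```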
Code Verification
Trace Script
```python
import torch


def trace_encoder_shapes():
    """Trace shapes through the entire encoder."""

    # Configuration
    batch_size = 32
    seq_len = 50
    d_model = 512
    num_heads = 8
    d_ff = 2048
    num_layers = 6

    print("Encoder Shape Trace")
    print("=" * 60)
    print(f"Config: B={batch_size}, S={seq_len}, D={d_model}, H={num_heads}")
    print("=" * 60)

    # Create encoder (assumes the custom TransformerEncoder class, with
    # .layers, .norm, and per-layer self_attention/norm2, is in scope)
    encoder = TransformerEncoder(
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        d_ff=d_ff,
        dropout=0.0  # Disable for a cleaner trace
    )
    encoder.eval()

    # Input
    x = torch.randn(batch_size, seq_len, d_model)
    print(f"\nInput: {x.shape}")

    # Hooks to capture intermediate shapes
    shapes = {}

    def make_hook(name):
        def hook(module, input, output):
            if isinstance(output, tuple):
                shapes[name] = output[0].shape
            else:
                shapes[name] = output.shape
        return hook

    # Register hooks on key modules
    for i, layer in enumerate(encoder.layers):
        layer.self_attention.register_forward_hook(make_hook(f'layer_{i}_attn'))
        layer.norm2.register_forward_hook(make_hook(f'layer_{i}_norm2'))

    encoder.norm.register_forward_hook(make_hook('final_norm'))

    # Forward pass
    output = encoder(x)

    # Print captured shapes
    print("\nIntermediate shapes:")
    for name, shape in shapes.items():
        print(f"  {name:20s}: {shape}")

    print(f"\nFinal output: {output.shape}")

    # Verify
    assert output.shape == x.shape, "Shape mismatch!"
    print("\n✓ All shapes verified correct!")


trace_encoder_shapes()
```
Summary Table
Complete Shape Summary
| Stage | Shape | Notes |
|---|---|---|
| Input tokens | [32, 50] | Token IDs |
| After embedding | [32, 50, 512] | + positional encoding |
| **Encoder Layer** | | |
| Pre-LN 1 | [32, 50, 512] | Normalize |
| Q, K, V | [32, 50, 512] each | Linear projection |
| Q, K, V reshaped | [32, 8, 50, 64] | Multi-head format |
| Attention scores | [32, 8, 50, 50] | QK^T / √d_k |
| Attention weights | [32, 8, 50, 50] | After mask + softmax |
| Context | [32, 8, 50, 64] | weights @ V |
| Combined heads | [32, 50, 512] | Concatenate |
| Attention output | [32, 50, 512] | Output projection |
| After residual 1 | [32, 50, 512] | x + attn_out |
| Pre-LN 2 | [32, 50, 512] | Normalize |
| FFN hidden | [32, 50, 2048] | Expansion |
| FFN output | [32, 50, 512] | Contraction |
| After residual 2 | [32, 50, 512] | x + ffn_out |
| **×6 layers** | | |
| Final LN | [32, 50, 512] | Pre-LN requires this |
| Output | [32, 50, 512] | Encoder memory |
Exercises
Shape Tracing
1. Trace shapes for batch_size=1, seq_len=100, d_model=768, heads=12.
2. Calculate memory usage for GPT-2 medium (24 layers, 1024 dim, 16 heads).
Analysis
3. Why does attention memory scale quadratically with sequence length?
4. What is the maximum sequence length you can use with 8GB GPU memory?
In the next section, we'll focus on how the encoder is specifically used for translation: processing German source sentences and preparing the encoder output (memory) for the decoder's cross-attention.