## Introduction

Before computing attention, we need to transform our input embeddings into Query, Key, and Value representations. In multi-head attention, this transformation happens through learnable linear projections.
This section explains what these projections are, why they're necessary, and how to implement them with proper shape tracking.
## What Are Linear Projections?

### The Basic Idea

A linear projection transforms input from one space to another:

$$\text{projection} = \text{input} \times W + b$$

Where:

- **input**: shape $[\text{batch}, \text{seq\_len}, d_{\text{input}}]$
- $W$: learnable weight matrix $[d_{\text{input}}, d_{\text{output}}]$
- $b$: learnable bias vector $[d_{\text{output}}]$
- **projection**: shape $[\text{batch}, \text{seq\_len}, d_{\text{output}}]$

### Why Project Q, K, V?

Without projections:
$$Q = K = V = \text{input} \quad \text{(same representation for all!)}$$

With projections:

$$\begin{aligned}
Q &= \text{input} \times W_Q \quad && \text{("What am I looking for?")} \\
K &= \text{input} \times W_K \quad && \text{("What do I contain?")} \\
V &= \text{input} \times W_V \quad && \text{("What information can I give?")}
\end{aligned}$$

Each projection learns a different view of the input.
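As a quick sketch of the idea (shapes follow the running example; the layer names here are ours, not a fixed API), three independently initialized `nn.Linear` layers produce three different views of the same input:

```python
import torch
import torch.nn as nn

d_model = 512

# Three independently initialized projections: one per role.
W_Q = nn.Linear(d_model, d_model)
W_K = nn.Linear(d_model, d_model)
W_V = nn.Linear(d_model, d_model)

x = torch.randn(2, 10, d_model)   # [batch, seq_len, d_model]
Q, K, V = W_Q(x), W_K(x), W_V(x)  # same input, three different representations

print(Q.shape)  # torch.Size([2, 10, 512])
```

Because the three weight matrices start (and stay) different, the same token produces a different query, key, and value vector.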
## Projection Matrices in Multi-Head Attention

### The Transformation Flow

### Dimensions

For the original Transformer:

$$\begin{aligned}
d_{\text{model}} &= 512 \quad && \text{(embedding dimension)} \\
n_{\text{heads}} &= 8 \quad && \text{(number of attention heads)} \\
d_k = d_v &= \frac{d_{\text{model}}}{n_{\text{heads}}} = \frac{512}{8} = 64 \quad && \text{(dimension per head)}
\end{aligned}$$
Projection matrices:

$$\begin{aligned}
W_Q &: [d_{\text{model}}, d_{\text{model}}] = [512, 512] \\
W_K &: [d_{\text{model}}, d_{\text{model}}] = [512, 512] \\
W_V &: [d_{\text{model}}, d_{\text{model}}] = [512, 512] \\
W_O &: [d_{\text{model}}, d_{\text{model}}] = [512, 512] \quad \text{(output projection)}
\end{aligned}$$
### Why Project to the Same Dimension?

We could project to any dimension, but projecting to $d_{\text{model}}$:

- Keeps the overall embedding dimension consistent
- Allows splitting evenly across heads
- Maintains residual connection compatibility

## Step-by-Step Shape Analysis

### Example Configuration

$$\begin{aligned}
\text{batch\_size} &= 2 \\
\text{seq\_len} &= 10 \\
d_{\text{model}} &= 512 \\
n_{\text{heads}} &= 8 \\
d_k &= \frac{d_{\text{model}}}{n_{\text{heads}}} = \frac{512}{8} = 64
\end{aligned}$$
### Step 1: Input

$$\text{Input } X: [2, 10, 512] = [\text{batch}, \text{seq\_len}, d_{\text{model}}]$$

### Step 2: Linear Projections

$$\begin{aligned}
Q &= X \times W_Q \quad &[2, 10, 512] \times [512, 512] \rightarrow [2, 10, 512] \\
K &= X \times W_K \quad &[2, 10, 512] \times [512, 512] \rightarrow [2, 10, 512] \\
V &= X \times W_V \quad &[2, 10, 512] \times [512, 512] \rightarrow [2, 10, 512]
\end{aligned}$$
Each token's 512-dim embedding is transformed into:

- a 512-dim query representation
- a 512-dim key representation
- a 512-dim value representation

### Step 3: Reshape for Multiple Heads

Reshape:

$$[\text{batch}, \text{seq\_len}, d_{\text{model}}] \rightarrow [\text{batch}, \text{seq\_len}, n_{\text{heads}}, d_k]$$

$$Q = Q.\text{view}(\text{batch\_size}, \text{seq\_len}, n_{\text{heads}}, d_k) \qquad [2, 10, 512] \rightarrow [2, 10, 8, 64]$$

Transpose:

$$[\text{batch}, \text{seq\_len}, n_{\text{heads}}, d_k] \rightarrow [\text{batch}, n_{\text{heads}}, \text{seq\_len}, d_k]$$

$$Q = Q.\text{transpose}(1, 2) \qquad [2, 10, 8, 64] \rightarrow [2, 8, 10, 64]$$

Now we have:
- 2 batches
- 8 heads
- 10 positions
- 64-dimensional queries per head

## Implementation with nn.Linear

### Basic Implementation

```python
import torch
import torch.nn as nn


class QKVProjection(nn.Module):
    """
    Linear projections for Query, Key, Value in multi-head attention.
    """

    def __init__(self, d_model: int, num_heads: int, bias: bool = True):
        """
        Args:
            d_model: Model embedding dimension
            num_heads: Number of attention heads
            bias: Whether to include bias terms
        """
        super().__init__()

        assert d_model % num_heads == 0, \
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Three separate projection layers
        self.W_Q = nn.Linear(d_model, d_model, bias=bias)
        self.W_K = nn.Linear(d_model, d_model, bias=bias)
        self.W_V = nn.Linear(d_model, d_model, bias=bias)

    def forward(self, query_input, key_input, value_input):
        """
        Project inputs to Q, K, V.

        Args:
            query_input: [batch, seq_len_q, d_model]
            key_input: [batch, seq_len_k, d_model]
            value_input: [batch, seq_len_k, d_model]

        Returns:
            Q: [batch, seq_len_q, d_model]
            K: [batch, seq_len_k, d_model]
            V: [batch, seq_len_k, d_model]
        """
        Q = self.W_Q(query_input)
        K = self.W_K(key_input)
        V = self.W_V(value_input)

        return Q, K, V


# Example usage
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 10

projection = QKVProjection(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)

Q, K, V = projection(x, x, x)  # Self-attention: same input for Q, K, V

print(f"Input shape: {x.shape}")
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
```
Output:

    Input shape: torch.Size([2, 10, 512])
    Q shape: torch.Size([2, 10, 512])
    K shape: torch.Size([2, 10, 512])
    V shape: torch.Size([2, 10, 512])
### Alternative: Combined Projection

For efficiency, we can project Q, K, V with a single larger matrix.
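A minimal sketch of that approach (the name `qkv_proj` is our choice, not a fixed API):

```python
import torch
import torch.nn as nn

d_model = 512

# One projection of width 3 * d_model stands in for W_Q, W_K, W_V.
qkv_proj = nn.Linear(d_model, 3 * d_model)

x = torch.randn(2, 10, d_model)   # [batch, seq_len, d_model]
qkv = qkv_proj(x)                 # [2, 10, 1536]
Q, K, V = qkv.chunk(3, dim=-1)    # three [2, 10, 512] tensors

print(Q.shape, K.shape, V.shape)
```

Note that a fused projection like this only applies to self-attention, where Q, K, and V all come from the same input tensor.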
This is more efficient because:

- One matrix multiplication instead of three
- Better memory access patterns
- Common in production implementations

## The Output Projection (W_O)

### After Attention

After computing multi-head attention, we concatenate head outputs:
$$\begin{aligned}
\text{Head 1 output} &: [\text{batch}, \text{seq\_len}, d_k] = [2, 10, 64] \\
\text{Head 2 output} &: [\text{batch}, \text{seq\_len}, d_k] = [2, 10, 64] \\
&\vdots \\
\text{Head 8 output} &: [\text{batch}, \text{seq\_len}, d_k] = [2, 10, 64] \\[0.5em]
\hline \\[-0.5em]
\text{Concatenated} &: [\text{batch}, \text{seq\_len}, d_{\text{model}}] = [2, 10, 512]
\end{aligned}$$
### The Output Projection

The concatenated output goes through $W_O$.
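As a sketch (dimensions follow the running example; `concat` is a stand-in for the real concatenated head outputs):

```python
import torch
import torch.nn as nn

d_model = 512  # = n_heads * d_k = 8 * 64
W_O = nn.Linear(d_model, d_model)

# Concatenated head outputs: [batch, seq_len, n_heads * d_k] = [2, 10, 512]
concat = torch.randn(2, 10, d_model)
output = W_O(concat)  # [2, 10, 512], mixing information across heads

print(output.shape)  # torch.Size([2, 10, 512])
```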
### Why W_O?

The output projection serves several purposes:

- **Mixing head outputs**: combines information from different heads
- **Learning to weight heads**: some heads may be more important
- **Dimensional consistency**: ensures output matches input dimension

Without $W_O$:

$$\text{Output} = \text{Concat}(\text{head}_1, \ldots, \text{head}_8) \quad \text{(just concatenation)}$$

With $W_O$:

$$\text{Output} = \text{Concat}(\text{head}_1, \ldots, \text{head}_8) \times W_O \quad \text{(learned combination)}$$

## Parameter Count Analysis

### Counting Parameters

For $d_{\text{model}} = 512$, $n_{\text{heads}} = 8$:
Without bias:

$$\begin{aligned}
W_Q &: 512 \times 512 = 262{,}144 \\
W_K &: 512 \times 512 = 262{,}144 \\
W_V &: 512 \times 512 = 262{,}144 \\
W_O &: 512 \times 512 = 262{,}144 \\
\hline
\textbf{Total} &: 1{,}048{,}576 \text{ parameters} \quad (4 \times d_{\text{model}}^2)
\end{aligned}$$

With bias:

$$\begin{aligned}
W_Q &: 512 \times 512 + 512 = 262{,}656 \\
W_K &: 512 \times 512 + 512 = 262{,}656 \\
W_V &: 512 \times 512 + 512 = 262{,}656 \\
W_O &: 512 \times 512 + 512 = 262{,}656 \\
\hline
\textbf{Total} &: 1{,}050{,}624 \text{ parameters}
\end{aligned}$$
### Comparison to Single-Head

Single-head ($d_k = d_{\text{model}}$):

$$\begin{aligned}
W_Q &: 512 \times 512 = 262{,}144 \\
W_K &: 512 \times 512 = 262{,}144 \\
W_V &: 512 \times 512 = 262{,}144 \\
&\text{(no } W_O \text{ needed if no combining)} \\
\hline
\textbf{Total} &: 786{,}432 \text{ parameters}
\end{aligned}$$

Multi-head adds $W_O$ → ~33% more parameters for a significant gain in expressiveness.
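These counts can be checked directly against `nn.Linear` (a quick sanity check; `param_count` is our helper, not a library function):

```python
import torch.nn as nn

d_model = 512

def param_count(layer: nn.Module) -> int:
    """Total number of scalar parameters in a layer."""
    return sum(p.numel() for p in layer.parameters())

# Four d_model x d_model projections: W_Q, W_K, W_V, W_O.
no_bias = sum(param_count(nn.Linear(d_model, d_model, bias=False)) for _ in range(4))
with_bias = sum(param_count(nn.Linear(d_model, d_model, bias=True)) for _ in range(4))

print(no_bias)    # 1048576 = 4 * 512**2
print(with_bias)  # 1050624 = 4 * (512**2 + 512)
```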
## Initialization Strategies

### Why Initialization Matters

Poor initialization can cause:

- Vanishing/exploding attention scores
- Heads learning identical patterns
- Slow or failed training

### Xavier/Glorot Initialization

Widely used for attention projections (note that `nn.Linear`'s own default is actually a Kaiming-style uniform initialization, not Xavier). Good for tanh activations:

$$\sigma = \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}$$

### Kaiming Initialization

Better suited to ReLU networks; attention projections don't use ReLU:

$$\sigma = \sqrt{\frac{2}{\text{fan\_in}}}$$
### Scaled Initialization (GPT-2 Style)

For very deep models, scale down the initialization.
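A sketch of this style of initialization (the base std of 0.02 and the $1/\sqrt{n_{\text{layers}}}$ scaling follow the GPT-2 convention; `n_layers = 24` is an illustrative value):

```python
import math
import torch.nn as nn

d_model, n_layers = 512, 24  # n_layers: model depth (illustrative)

W_O = nn.Linear(d_model, d_model)

# Residual-path projections get a smaller std so activations don't grow
# with depth: the base std is scaled down by 1/sqrt(n_layers).
nn.init.normal_(W_O.weight, mean=0.0, std=0.02 / math.sqrt(n_layers))
nn.init.zeros_(W_O.bias)
```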
### Recommendation

For most cases, PyTorch defaults work well. For very deep models (>12 layers), consider scaled initialization.
## Complete Projection Module

🐍 multi_head_projection.py
```python
import torch
import torch.nn as nn
from typing import Tuple


class MultiHeadProjection(nn.Module):
    """
    Complete projection module for multi-head attention.

    Includes W_Q, W_K, W_V for projections and W_O for output.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True
    ):
        """
        Args:
            d_model: Model embedding dimension
            num_heads: Number of attention heads
            dropout: Dropout rate for projections
            bias: Whether to include bias in linear layers
        """
        super().__init__()

        assert d_model % num_heads == 0, \
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Projection layers
        self.W_Q = nn.Linear(d_model, d_model, bias=bias)
        self.W_K = nn.Linear(d_model, d_model, bias=bias)
        self.W_V = nn.Linear(d_model, d_model, bias=bias)
        self.W_O = nn.Linear(d_model, d_model, bias=bias)

        # Dropout
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

        # Initialize weights
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize parameters using Xavier uniform."""
        for module in [self.W_Q, self.W_K, self.W_V, self.W_O]:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def project_qkv(
        self,
        query_input: torch.Tensor,
        key_input: torch.Tensor,
        value_input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Project inputs to Q, K, V.

        Args:
            query_input: [batch, seq_len_q, d_model]
            key_input: [batch, seq_len_k, d_model]
            value_input: [batch, seq_len_k, d_model]

        Returns:
            Q: [batch, seq_len_q, d_model]
            K: [batch, seq_len_k, d_model]
            V: [batch, seq_len_k, d_model]
        """
        Q = self.W_Q(query_input)
        K = self.W_K(key_input)
        V = self.W_V(value_input)

        if self.dropout:
            Q = self.dropout(Q)
            K = self.dropout(K)
            V = self.dropout(V)

        return Q, K, V

    def project_output(self, attention_output: torch.Tensor) -> torch.Tensor:
        """
        Project concatenated head outputs.

        Args:
            attention_output: [batch, seq_len, d_model]

        Returns:
            output: [batch, seq_len, d_model]
        """
        output = self.W_O(attention_output)

        if self.dropout:
            output = self.dropout(output)

        return output


# Test the module
def test_projections():
    d_model = 512
    num_heads = 8
    batch_size = 2
    seq_len = 10

    proj = MultiHeadProjection(d_model, num_heads)

    x = torch.randn(batch_size, seq_len, d_model)

    # Test QKV projection
    Q, K, V = proj.project_qkv(x, x, x)
    assert Q.shape == (batch_size, seq_len, d_model)
    assert K.shape == (batch_size, seq_len, d_model)
    assert V.shape == (batch_size, seq_len, d_model)

    # Test output projection
    output = proj.project_output(Q)  # Using Q as dummy input
    assert output.shape == (batch_size, seq_len, d_model)

    print("✓ All projection tests passed!")

    # Count parameters
    total_params = sum(p.numel() for p in proj.parameters())
    print(f"Total parameters: {total_params:,}")


test_projections()
```
## Summary

### Key Concepts

| Concept | Description |
|---|---|
| Linear projection | Transforms input to Q, K, V spaces |
| $W_Q, W_K, W_V$ | $[d_{\text{model}}, d_{\text{model}}]$ learnable matrices |
| $W_O$ | Output projection after attention |
| $d_k$ | $\frac{d_{\text{model}}}{n_{\text{heads}}}$ (per-head dimension) |
### Shape Flow

$$\begin{aligned}
\text{Input} &: [\text{batch}, \text{seq\_len}, d_{\text{model}}] \\
&\downarrow \quad W_Q, W_K, W_V \\
Q, K, V &: [\text{batch}, \text{seq\_len}, d_{\text{model}}] \\
&\downarrow \quad \text{reshape + transpose} \\
\text{Per-head} &: [\text{batch}, n_{\text{heads}}, \text{seq\_len}, d_k] \\
&\downarrow \quad \text{attention} \\
\text{Head outputs} &: [\text{batch}, n_{\text{heads}}, \text{seq\_len}, d_k] \\
&\downarrow \quad \text{transpose + reshape} \\
\text{Concatenated} &: [\text{batch}, \text{seq\_len}, d_{\text{model}}] \\
&\downarrow \quad W_O \\
\text{Output} &: [\text{batch}, \text{seq\_len}, d_{\text{model}}]
\end{aligned}$$
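The flow above can be traced with a short script (a sketch; `W_Q` stands in for any of the projections, and the attention computation itself is omitted since it preserves the per-head shape):

```python
import torch
import torch.nn as nn

batch, seq_len, d_model, n_heads = 2, 10, 512, 8
d_k = d_model // n_heads

W_Q = nn.Linear(d_model, d_model)

x = torch.randn(batch, seq_len, d_model)      # input: [2, 10, 512]
Q = W_Q(x)                                    # after W_Q: [2, 10, 512]
Q = Q.view(batch, seq_len, n_heads, d_k)      # split heads: [2, 10, 8, 64]
Q = Q.transpose(1, 2)                         # per-head: [2, 8, 10, 64]

# ... attention runs here, keeping [batch, n_heads, seq_len, d_k] ...

out = Q.transpose(1, 2).reshape(batch, seq_len, d_model)  # concat: [2, 10, 512]
print(out.shape)  # torch.Size([2, 10, 512])
```

Note the use of `.reshape` (rather than `.view`) after the transpose: the transposed tensor is no longer contiguous, and `.reshape` handles that by copying if needed.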
### Implementation Notes

- Use nn.Linear for projections (handles batching automatically)
- Initialize carefully for deep models
- Consider a combined QKV projection for efficiency
- Always verify shapes match expectations

## Exercises

### Implementation Exercises

1. Implement a projection module that shares $W_K$ and $W_V$ (sometimes used for efficiency).
2. Add a scaling factor to the projection: $Q = \frac{Q}{\sqrt{d_k}}$, applied after projection.
3. Implement per-head projection matrices instead of full $d_{\text{model}}$ projections.

### Analysis Exercises

1. Calculate the memory usage for projections with $d_{\text{model}} = 1024$, $n_{\text{heads}} = 16$.
2. Compare the FLOPs (floating-point operations) for separate vs. combined QKV projections.
3. Experiment with different initialization schemes and measure their effect on attention weights.