Chapter 3

Linear Projections for Q, K, V

Multi-Head Attention

Introduction

Before computing attention, we need to transform our input embeddings into Query, Key, and Value representations. In multi-head attention, this transformation happens through learnable linear projections.

This section explains what these projections are, why they're necessary, and how to implement them with proper shape tracking.


What Are Linear Projections?

The Basic Idea

A linear projection transforms input from one space to another:

$$\text{projection} = \text{input} \times W + b$$

Where:

  • input: shape $[\text{batch}, \text{seq\_len}, d_{\text{input}}]$
  • $W$: learnable weight matrix of shape $[d_{\text{input}}, d_{\text{output}}]$
  • $b$: learnable bias vector of shape $[d_{\text{output}}]$
  • projection: shape $[\text{batch}, \text{seq\_len}, d_{\text{output}}]$
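In PyTorch this is exactly what `nn.Linear` computes; a minimal sketch with example dimensions:

```python
import torch
import torch.nn as nn

batch, seq_len, d_input, d_output = 2, 10, 512, 512  # example sizes

# nn.Linear stores W as [d_output, d_input] and computes input @ W.T + b
proj = nn.Linear(d_input, d_output, bias=True)

x = torch.randn(batch, seq_len, d_input)  # [batch, seq_len, d_input]
y = proj(x)                               # [batch, seq_len, d_output]

print(y.shape)  # torch.Size([2, 10, 512])
```

Note that `nn.Linear` applies the projection to the last dimension only, so the batch and sequence dimensions pass through untouched.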

Why Project Q, K, V?

Without projections:

$$Q = K = V = \text{input} \quad \text{(same representation for all!)}$$

With projections:

$$\begin{aligned} Q &= \text{input} \times W_Q \quad & \text{("What am I looking for?")} \\ K &= \text{input} \times W_K \quad & \text{("What do I contain?")} \\ V &= \text{input} \times W_V \quad & \text{("What information can I give?")} \end{aligned}$$

Each projection learns a different view of the input.


Projection Matrices in Multi-Head Attention

The Transformation Flow

```text
Input X: [batch, seq_len, d_model]

    ┌──────┴──────┐
    ↓      ↓      ↓
   W_Q    W_K    W_V
    ↓      ↓      ↓
    Q      K      V
[batch, seq_len, d_model]

    Split into heads

   [batch, num_heads, seq_len, d_k]
```

Dimensions

For the original Transformer:

$$\begin{aligned} d_{\text{model}} &= 512 \quad & \text{(embedding dimension)} \\ n_{\text{heads}} &= 8 \quad & \text{(number of attention heads)} \\ d_k = d_v &= \frac{d_{\text{model}}}{n_{\text{heads}}} = \frac{512}{8} = 64 \quad & \text{(dimension per head)} \end{aligned}$$

Projection matrices:

$$\begin{aligned} W_Q &: [d_{\text{model}}, d_{\text{model}}] = [512, 512] \\ W_K &: [d_{\text{model}}, d_{\text{model}}] = [512, 512] \\ W_V &: [d_{\text{model}}, d_{\text{model}}] = [512, 512] \\ W_O &: [d_{\text{model}}, d_{\text{model}}] = [512, 512] \quad \text{(output projection)} \end{aligned}$$

Why Project to Same Dimension?

We could project to any dimension, but projecting to $d_{\text{model}}$:

  1. Keeps the overall embedding dimension consistent
  2. Allows splitting evenly across heads
  3. Maintains residual connection compatibility
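The third point can be sanity-checked directly: a residual connection is an element-wise add, so it only works because the projection preserves `d_model`. A minimal sketch, using a random tensor as a stand-in for the attention output:

```python
import torch
import torch.nn as nn

d_model = 512
x = torch.randn(2, 10, d_model)  # input to the attention block

# Output projection maps d_model -> d_model, so shapes stay aligned
W_O = nn.Linear(d_model, d_model)
attn_out = W_O(torch.randn(2, 10, d_model))  # stand-in for attention output

residual = x + attn_out  # element-wise add requires matching shapes
print(residual.shape)  # torch.Size([2, 10, 512])
```

If the projection changed the last dimension, this add would raise a shape error.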

Step-by-Step Shape Analysis

Example Configuration

$$\begin{aligned} \text{batch\_size} &= 2 \\ \text{seq\_len} &= 10 \\ d_{\text{model}} &= 512 \\ n_{\text{heads}} &= 8 \\ d_k &= \frac{d_{\text{model}}}{n_{\text{heads}}} = \frac{512}{8} = 64 \end{aligned}$$

Step 1: Input

$$\text{Input } X: [2, 10, 512] = [\text{batch}, \text{seq\_len}, d_{\text{model}}]$$

Step 2: Linear Projections

$$\begin{aligned} Q &= X \times W_Q \quad &[2, 10, 512] \times [512, 512] \rightarrow [2, 10, 512] \\ K &= X \times W_K \quad &[2, 10, 512] \times [512, 512] \rightarrow [2, 10, 512] \\ V &= X \times W_V \quad &[2, 10, 512] \times [512, 512] \rightarrow [2, 10, 512] \end{aligned}$$

Each token's 512-dim embedding is transformed into:

  • a 512-dim query representation
  • a 512-dim key representation
  • a 512-dim value representation

Step 3: Reshape for Multiple Heads

Reshape: $[\text{batch}, \text{seq\_len}, d_{\text{model}}] \rightarrow [\text{batch}, \text{seq\_len}, n_{\text{heads}}, d_k]$

$$Q = Q.\text{view}(\text{batch\_size}, \text{seq\_len}, n_{\text{heads}}, d_k)$$
$$[2, 10, 512] \rightarrow [2, 10, 8, 64]$$

Transpose: $[\text{batch}, \text{seq\_len}, n_{\text{heads}}, d_k] \rightarrow [\text{batch}, n_{\text{heads}}, \text{seq\_len}, d_k]$

$$Q = Q.\text{transpose}(1, 2)$$
$$[2, 10, 8, 64] \rightarrow [2, 8, 10, 64]$$

Now we have:

  • 2 batches
  • 8 heads
  • 10 positions
  • 64-dimensional queries per head
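The split-into-heads steps above can be verified directly in PyTorch:

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8
d_k = d_model // num_heads  # 64

Q = torch.randn(batch_size, seq_len, d_model)    # [2, 10, 512]

# Step 3a: reshape the last dimension into (heads, d_k)
Q = Q.view(batch_size, seq_len, num_heads, d_k)  # [2, 10, 8, 64]

# Step 3b: bring the head dimension forward
Q = Q.transpose(1, 2)                            # [2, 8, 10, 64]

print(Q.shape)  # torch.Size([2, 8, 10, 64])
```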

Implementation with nn.Linear

Basic Implementation

QKV Projection Module (`projection.py`):

```python
import torch
import torch.nn as nn

class QKVProjection(nn.Module):
    """
    Linear projections for Query, Key, Value in multi-head attention.
    """

    def __init__(self, d_model: int, num_heads: int, bias: bool = True):
        """
        Args:
            d_model: Model embedding dimension
            num_heads: Number of attention heads
            bias: Whether to include bias terms
        """
        super().__init__()

        assert d_model % num_heads == 0, \
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Three separate projection layers
        self.W_Q = nn.Linear(d_model, d_model, bias=bias)
        self.W_K = nn.Linear(d_model, d_model, bias=bias)
        self.W_V = nn.Linear(d_model, d_model, bias=bias)

    def forward(self, query_input, key_input, value_input):
        """
        Project inputs to Q, K, V.

        Args:
            query_input: [batch, seq_len_q, d_model]
            key_input: [batch, seq_len_k, d_model]
            value_input: [batch, seq_len_k, d_model]

        Returns:
            Q: [batch, seq_len_q, d_model]
            K: [batch, seq_len_k, d_model]
            V: [batch, seq_len_k, d_model]
        """
        Q = self.W_Q(query_input)
        K = self.W_K(key_input)
        V = self.W_V(value_input)

        return Q, K, V


# Example usage
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 10

projection = QKVProjection(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)

Q, K, V = projection(x, x, x)  # Self-attention: same input for Q, K, V

print(f"Input shape: {x.shape}")
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
```

Output:

```text
Input shape: torch.Size([2, 10, 512])
Q shape: torch.Size([2, 10, 512])
K shape: torch.Size([2, 10, 512])
V shape: torch.Size([2, 10, 512])
```

Alternative: Combined Projection

For efficiency, we can project Q, K, V with a single larger matrix:

Combined QKV Projection (`projection.py`):

```python
class CombinedQKVProjection(nn.Module):
    """
    Combined projection for Q, K, V using a single linear layer.
    More efficient as it can be parallelized.
    """

    def __init__(self, d_model: int, bias: bool = True):
        super().__init__()
        self.d_model = d_model
        # Single projection that outputs Q, K, V concatenated
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=bias)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]

        Returns:
            Q, K, V: each [batch, seq_len, d_model]
        """
        # [batch, seq_len, 3 * d_model]
        qkv = self.qkv_proj(x)

        # Split into three parts
        Q, K, V = qkv.chunk(3, dim=-1)

        return Q, K, V


# Example
combined_proj = CombinedQKVProjection(d_model)
x = torch.randn(batch_size, seq_len, d_model)

Q, K, V = combined_proj(x)
print(f"Q shape: {Q.shape}")  # [2, 10, 512]
```

This is more efficient because:

  • One matrix multiplication instead of three
  • Better memory access patterns
  • Common in production implementations
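The two formulations are mathematically identical. As a sanity check (a sketch, not production code), stacking three separate weight matrices row-wise into one combined layer reproduces the separate projections exactly:

```python
import torch
import torch.nn as nn

d_model = 64  # small size for the check
x = torch.randn(2, 5, d_model)

# Three separate projections (no bias, to keep the check simple)
W_Q = nn.Linear(d_model, d_model, bias=False)
W_K = nn.Linear(d_model, d_model, bias=False)
W_V = nn.Linear(d_model, d_model, bias=False)

# One combined projection; nn.Linear weight has shape [out, in],
# so stacking along dim 0 puts Q rows first, then K, then V
qkv = nn.Linear(d_model, 3 * d_model, bias=False)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([W_Q.weight, W_K.weight, W_V.weight], dim=0))

Q, K, V = qkv(x).chunk(3, dim=-1)
print(torch.allclose(Q, W_Q(x), atol=1e-5))  # True
```

One big matmul lets the GPU do the same work in a single kernel launch, which is where the efficiency claim comes from.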

The Output Projection (W_O)

After Attention

After computing multi-head attention, we concatenate head outputs:

$$\begin{aligned} \text{Head 1 output} &: [\text{batch}, \text{seq\_len}, d_k] = [2, 10, 64] \\ \text{Head 2 output} &: [\text{batch}, \text{seq\_len}, d_k] = [2, 10, 64] \\ &\;\;\vdots \\ \text{Head 8 output} &: [\text{batch}, \text{seq\_len}, d_k] = [2, 10, 64] \\ \text{Concatenated} &: [\text{batch}, \text{seq\_len}, d_{\text{model}}] = [2, 10, 512] \end{aligned}$$
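Merging the heads back is just the inverse of the earlier split; a minimal sketch:

```python
import torch

batch, num_heads, seq_len, d_k = 2, 8, 10, 64
d_model = num_heads * d_k  # 512

heads = torch.randn(batch, num_heads, seq_len, d_k)  # per-head outputs

# transpose the head dimension back, then merge (heads, d_k) -> d_model
concat = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
print(concat.shape)  # torch.Size([2, 10, 512])
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view, which `view` cannot reshape directly.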

The Output Projection

The concatenated output goes through W_O:

Output Projection Module (`projection.py`):

```python
class OutputProjection(nn.Module):
    """Output projection after multi-head attention."""

    def __init__(self, d_model: int, bias: bool = True):
        super().__init__()
        self.W_O = nn.Linear(d_model, d_model, bias=bias)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model] (concatenated heads)

        Returns:
            output: [batch, seq_len, d_model]
        """
        return self.W_O(x)
```

Why W_O?

The output projection serves several purposes:

  1. Mixing head outputs: Combines information from different heads
  2. Learning to weight heads: Some heads may be more important
  3. Dimensional consistency: Ensures output matches input dimension

Without $W_O$:

$$\text{Output} = \text{Concat}(\text{head}_1, \ldots, \text{head}_8) \quad \text{(just concatenation)}$$

With $W_O$:

$$\text{Output} = \text{Concat}(\text{head}_1, \ldots, \text{head}_8) \times W_O \quad \text{(learned combination)}$$

Parameter Count Analysis

Counting Parameters

For $d_{\text{model}} = 512$, $n_{\text{heads}} = 8$:

Without bias:

$$\begin{aligned} W_Q &: 512 \times 512 = 262{,}144 \\ W_K &: 512 \times 512 = 262{,}144 \\ W_V &: 512 \times 512 = 262{,}144 \\ W_O &: 512 \times 512 = 262{,}144 \\ \textbf{Total} &: 1{,}048{,}576 \text{ parameters} \quad (4 \times d_{\text{model}}^2) \end{aligned}$$

With bias:

$$\begin{aligned} W_Q &: 512 \times 512 + 512 = 262{,}656 \\ W_K &: 512 \times 512 + 512 = 262{,}656 \\ W_V &: 512 \times 512 + 512 = 262{,}656 \\ W_O &: 512 \times 512 + 512 = 262{,}656 \\ \textbf{Total} &: 1{,}050{,}624 \text{ parameters} \end{aligned}$$

Comparison to Single-Head

Single-head ($d_k = d_{\text{model}}$):

$$\begin{aligned} W_Q &: 512 \times 512 = 262{,}144 \\ W_K &: 512 \times 512 = 262{,}144 \\ W_V &: 512 \times 512 = 262{,}144 \\ & \text{(no } W_O \text{ needed if heads are not combined)} \\ \textbf{Total} &: 786{,}432 \text{ parameters} \end{aligned}$$

Multi-head adds $W_O$ → about 33% more parameters for a significant gain in expressiveness.
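These counts are easy to verify by summing parameters over four `nn.Linear` layers:

```python
import torch.nn as nn

d_model = 512

# W_Q, W_K, W_V, W_O: each is 512*512 weights + 512 biases
layers = [nn.Linear(d_model, d_model, bias=True) for _ in range(4)]

total = sum(p.numel() for layer in layers for p in layer.parameters())
print(total)  # 1050624  = 4 * (512*512 + 512)
```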


Initialization Strategies

Why Initialization Matters

Poor initialization can cause:

  • Vanishing/exploding attention scores
  • Heads learning identical patterns
  • Slow or failed training

Xavier/Glorot Initialization

Xavier initialization keeps activation variance roughly constant through linear (or tanh-like) layers:

$$\sigma = \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}$$

Note that PyTorch's `nn.Linear` default is actually a Kaiming-uniform variant; many attention implementations apply Xavier initialization explicitly.

Kaiming Initialization

Better for ReLU networks, but attention doesn't use ReLU:

$$\sigma = \sqrt{\frac{2}{\text{fan\_in}}}$$

Scaled Initialization (GPT-2 style)

For very deep models, scale down initialization:

Scaled Initialization (GPT-2 Style) (`initialization.py`):

```python
def init_weights(module, n_layer, d_model):
    """GPT-2 style initialization."""
    if isinstance(module, nn.Linear):
        std = 0.02  # Base std
        if hasattr(module, 'NANOGPT_SCALE_INIT'):
            # Scale by layer depth for residual connections
            std *= (2 * n_layer) ** -0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
```

Recommendation

For most cases, PyTorch defaults work well. For very deep models (>12 layers), consider scaled initialization.


Complete Projection Module

Complete MultiHeadProjection Module (`multi_head_projection.py`):

```python
import torch
import torch.nn as nn
from typing import Tuple


class MultiHeadProjection(nn.Module):
    """
    Complete projection module for multi-head attention.

    Includes W_Q, W_K, W_V for projections and W_O for output.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True
    ):
        """
        Args:
            d_model: Model embedding dimension
            num_heads: Number of attention heads
            dropout: Dropout rate for projections
            bias: Whether to include bias in linear layers
        """
        super().__init__()

        assert d_model % num_heads == 0, \
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Projection layers
        self.W_Q = nn.Linear(d_model, d_model, bias=bias)
        self.W_K = nn.Linear(d_model, d_model, bias=bias)
        self.W_V = nn.Linear(d_model, d_model, bias=bias)
        self.W_O = nn.Linear(d_model, d_model, bias=bias)

        # Dropout
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

        # Initialize weights
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize parameters using Xavier uniform."""
        for module in [self.W_Q, self.W_K, self.W_V, self.W_O]:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def project_qkv(
        self,
        query_input: torch.Tensor,
        key_input: torch.Tensor,
        value_input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Project inputs to Q, K, V.

        Args:
            query_input: [batch, seq_len_q, d_model]
            key_input: [batch, seq_len_k, d_model]
            value_input: [batch, seq_len_k, d_model]

        Returns:
            Q: [batch, seq_len_q, d_model]
            K: [batch, seq_len_k, d_model]
            V: [batch, seq_len_k, d_model]
        """
        Q = self.W_Q(query_input)
        K = self.W_K(key_input)
        V = self.W_V(value_input)

        if self.dropout is not None:
            Q = self.dropout(Q)
            K = self.dropout(K)
            V = self.dropout(V)

        return Q, K, V

    def project_output(self, attention_output: torch.Tensor) -> torch.Tensor:
        """
        Project concatenated head outputs.

        Args:
            attention_output: [batch, seq_len, d_model]

        Returns:
            output: [batch, seq_len, d_model]
        """
        output = self.W_O(attention_output)

        if self.dropout is not None:
            output = self.dropout(output)

        return output


# Test the module
def test_projections():
    d_model = 512
    num_heads = 8
    batch_size = 2
    seq_len = 10

    proj = MultiHeadProjection(d_model, num_heads)

    x = torch.randn(batch_size, seq_len, d_model)

    # Test QKV projection
    Q, K, V = proj.project_qkv(x, x, x)
    assert Q.shape == (batch_size, seq_len, d_model)
    assert K.shape == (batch_size, seq_len, d_model)
    assert V.shape == (batch_size, seq_len, d_model)

    # Test output projection
    output = proj.project_output(Q)  # Using Q as dummy input
    assert output.shape == (batch_size, seq_len, d_model)

    print("✓ All projection tests passed!")

    # Count parameters
    total_params = sum(p.numel() for p in proj.parameters())
    print(f"Total parameters: {total_params:,}")


test_projections()
```

Summary

Key Concepts

| Concept | Description |
| --- | --- |
| Linear projection | Transforms input into Q, K, V spaces |
| $W_Q, W_K, W_V$ | $[d_{\text{model}}, d_{\text{model}}]$ learnable matrices |
| $W_O$ | Output projection after attention |
| $d_k$ | $= \frac{d_{\text{model}}}{n_{\text{heads}}}$ (per-head dimension) |

Shape Flow

$$\begin{aligned} \text{Input} &: [\text{batch}, \text{seq\_len}, d_{\text{model}}] \\ &\downarrow \quad W_Q, W_K, W_V \\ Q, K, V &: [\text{batch}, \text{seq\_len}, d_{\text{model}}] \\ &\downarrow \quad \text{reshape + transpose} \\ \text{Per-head} &: [\text{batch}, n_{\text{heads}}, \text{seq\_len}, d_k] \\ &\downarrow \quad \text{attention} \\ \text{Head outputs} &: [\text{batch}, n_{\text{heads}}, \text{seq\_len}, d_k] \\ &\downarrow \quad \text{transpose + reshape} \\ \text{Concatenated} &: [\text{batch}, \text{seq\_len}, d_{\text{model}}] \\ &\downarrow \quad W_O \\ \text{Output} &: [\text{batch}, \text{seq\_len}, d_{\text{model}}] \end{aligned}$$

Implementation Notes

  1. Use nn.Linear for projections (handles batching automatically)
  2. Initialize carefully for deep models
  3. Consider combined QKV projection for efficiency
  4. Always verify shapes match expectations

Exercises

Implementation Exercises

  1. Implement a projection module that shares $W_K$ and $W_V$ (sometimes used for efficiency).
  2. Add a scaling factor to the projection: $Q = \frac{Q}{\sqrt{d_k}}$ applied after projection.
  3. Implement per-head projection matrices instead of full $d_{\text{model}}$ projections.

Analysis Exercises

  1. Calculate the memory usage for projections with $d_{\text{model}} = 1024$, $n_{\text{heads}} = 16$.
  2. Compare the FLOPs (floating-point operations) for separate vs. combined QKV projection.
  3. Experiment with different initialization schemes and measure their effect on attention weights.