Chapter 6

Grouped-Query Attention (GQA)

Ainslie, Lee-Thorp, de Jong, Zemlyanskiy, Lebrón & Sanghai, "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints", Google Research, 2023


Learning Objectives

After completing this chapter, you will be able to:

  1. Explain why the KV-cache is the primary memory bottleneck during autoregressive inference and why MHA makes it $H\times$ worse than necessary
  2. Describe how GQA partitions $H$ query heads into $G$ groups where each group shares a single set of keys and values
  3. Compute the group assignment function $g(i) = \lfloor i / (H/G) \rfloor$ and understand how it generalizes MHA ($G = H$) and MQA ($G = 1$)
  4. Perform a full GQA forward pass by hand on the shared “The cat sat on mat” example, comparing $G = 2$ and $G = 1$
  5. Calculate KV-cache memory savings for production models like LLaMA 3 70B and Mistral 7B
  6. Understand the “uptrain” procedure that converts an existing MHA checkpoint to GQA with only 5% of original training compute
  7. Implement GQA from scratch in both NumPy and PyTorch

The Problem: The KV-Cache Dilemma

What Is the KV-Cache?

When a transformer generates text token by token (autoregressive decoding), each new token's query must attend to the keys and values of every previous position. Recomputing the Key and Value projections for the whole prefix at every step would make the projection work grow as $O(N^2)$ over a sequence of $N$ tokens, so modern systems cache the Key and Value tensors from all previous positions instead. This “KV-cache” trades memory for speed: once a token's $K$ and $V$ are computed, they are stored and reused for all subsequent tokens.
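The caching idea can be sketched for a single attention head in NumPy. This is a toy illustration with assumed random weights (`W_q`, `W_k`, `W_v` and all sizes are hypothetical, not taken from any real model): each decoding step projects only the newest token, appends its key and value to the cache, and attends over everything cached so far. The final outputs match full causal attention recomputed from scratch.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k, N = 8, 4, 6
X = rng.normal(size=(N, d_model))                                    # toy token embeddings
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))  # hypothetical projections

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Incremental decoding: each token's K and V are computed once, then reused.
K_cache = np.empty((0, d_k))
V_cache = np.empty((0, d_k))
step_outputs = []
for t in range(N):
    q = X[t] @ W_q                               # query for the newest token only
    K_cache = np.vstack([K_cache, X[t] @ W_k])   # append this token's key
    V_cache = np.vstack([V_cache, X[t] @ W_v])   # append this token's value
    w = softmax(q @ K_cache.T / np.sqrt(d_k))    # attend over all cached positions
    step_outputs.append(w @ V_cache)
step_outputs = np.array(step_outputs)

# Reference: full causal attention recomputed from scratch.
Q, K, V = X @ W_q, X @ W_k, X @ W_v
scores = Q @ K.T / np.sqrt(d_k)
scores[np.triu_indices(N, k=1)] = -np.inf       # mask future positions
full = softmax(scores) @ V
print(np.allclose(step_outputs, full))  # True
```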

The problem is that the KV-cache grows with three multipliers: the number of attention heads $H$, the sequence length $N$, and the per-head dimension $d_k$. For each transformer layer (and for each sequence in a batch), the cache stores $2 \times H \times N \times d_k$ floating-point values (the factor of 2 accounts for both K and V).

MHA's Memory Wall

In standard Multi-Head Attention (Chapter 2), every head has its own independent $K$ and $V$ projections. For a hypothetical MHA version of LLaMA 3 70B (the released model uses GQA) with $H = 64$ heads, $d_k = 128$, and 80 layers in FP16:

$$\text{KV-cache} = 2 \times 64 \times N \times 128 \times 80 \times 2\;\text{bytes} = 2{,}621{,}440\,N\;\text{bytes} \approx 2.5\,\text{MB per token}$$

At a 128K context window, that is 320 GB of KV-cache alone — far exceeding the memory of any single GPU. This memory wall is the reason that long-context inference requires multi-GPU setups or aggressive compression.

MQA's Quality Cliff

Chapter 5 introduced Multi-Query Attention (MQA, Shazeer 2019), which shares a single $K, V$ pair across all $H$ heads. This reduces the KV-cache by a factor of $H$: from 320 GB to just 5 GB in our LLaMA example. The speed improvement is dramatic.
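The arithmetic above can be checked with a small helper, `kv_cache_bytes`, introduced here for illustration. It assumes FP16 (2 bytes per value), a single sequence, and the hypothetical MHA variant of the LLaMA 3 70B shape described above; the round figures come out exactly in binary units (MiB, GiB).

```python
def kv_cache_bytes(kv_heads, seq_len, d_k, n_layers, bytes_per_value=2):
    """KV-cache size for one sequence: 2 (K and V) x kv_heads x seq_len x d_k, per layer."""
    return 2 * kv_heads * seq_len * d_k * n_layers * bytes_per_value

H, d_k, L = 64, 128, 80                              # LLaMA 3 70B shape, MHA assumed
per_token_mha = kv_cache_bytes(H, 1, d_k, L)         # every head keeps its own K, V
full_mha = kv_cache_bytes(H, 128 * 1024, d_k, L)     # 128K-token context
full_mqa = kv_cache_bytes(1, 128 * 1024, d_k, L)     # one K, V shared by all heads

MiB, GiB = 2**20, 2**30
print(per_token_mha / MiB)   # 2.5
print(full_mha / GiB)        # 320.0
print(full_mqa / GiB)        # 5.0
```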

But the quality cost is real. When all heads are forced to attend through the same key-value lens, the model loses the representational diversity that makes multi-head attention powerful. In the GQA paper's experiments, the MQA model scores below its MHA counterpart on average across summarization, translation, and question-answering benchmarks (Ainslie et al., 2023, Table 1).

The Core Tension: MHA gives maximum quality but is memory-hungry. MQA gives maximum speed but sacrifices quality. Is there a sweet spot between these extremes?

The Intuition Behind Grouped-Query Attention

The Library Analogy

Imagine a research institute with 64 researchers (query heads). Each researcher needs to look up information in reference books (keys and values) to answer questions.

  • MHA: Every researcher has their own private library. Maximum research diversity, but the institute needs 64 separate library buildings. Expensive.
  • MQA: Everyone shares a single library. Cheap to maintain, but when 64 researchers all need the same book simultaneously, the bottleneck is severe. Popular topics get congested; niche topics are underserved.
  • GQA: Researchers are organized into departments of 8. Each department has its own library. Eight libraries instead of 64 (an $8\times$ cost reduction), but each department can specialize its collection. The biology department keeps different books than the physics department.

This is exactly how GQA works. Instead of $H$ separate KV sets (MHA) or 1 shared KV set (MQA), GQA uses $G$ groups, where $1 \le G \le H$ and $G$ divides $H$. Each group of $H/G$ query heads shares one set of keys and values.

The Paper: Ainslie et al. 2023

Grouped-Query Attention was introduced by Joshua Ainslie and colleagues at Google Research in 2023. Their paper, “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,” made two key contributions:

  1. The GQA mechanism itself: a clean generalization that places MHA and MQA as special cases on a single continuum controlled by $G$
  2. The uptrain procedure: a practical method to convert existing MHA models to GQA without training from scratch, requiring only ~5% of the original training compute

The Uptrain Insight

The uptrain procedure works by mean-pooling the key and value projection weights within each group. If an MHA model has 64 separate $W^K$ matrices and you want $G = 8$ groups, you average heads 0–7 into group 0's $W^K$, heads 8–15 into group 1's $W^K$, and so on, and do the same for $W^V$. Then you continue pre-training the converted model for a small fraction (about 5%) of the original training steps to recover any quality loss.
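The weight-conversion step of uptraining can be sketched in NumPy. This assumes the per-head key projections are stored as one array of shape `(H, d_model, d_k)` (real checkpoints often store a single fused matrix that would first need reshaping), and `mean_pool_kv_heads` is a hypothetical helper. The same pooling would be applied to $W^V$; query projections are left untouched, and the continued-training phase is not shown.

```python
import numpy as np

def mean_pool_kv_heads(W_heads, G):
    """Average per-head projection matrices into G group matrices.

    W_heads: (H, d_model, d_k), one K (or V) projection per head.
    Returns (G, d_model, d_k); group g averages heads g*(H//G) .. (g+1)*(H//G) - 1.
    """
    H = W_heads.shape[0]
    assert H % G == 0, "G must divide H"
    heads_per_group = H // G
    grouped = W_heads.reshape(G, heads_per_group, *W_heads.shape[1:])  # consecutive heads per group
    return grouped.mean(axis=1)

rng = np.random.default_rng(0)
H, G, d_model, d_k = 64, 8, 16, 4           # toy d_model and d_k; H and G as in the text
W_K = rng.normal(size=(H, d_model, d_k))    # stand-in for an MHA checkpoint's key projections
W_K_gqa = mean_pool_kv_heads(W_K, G)

print(W_K_gqa.shape)                                     # (8, 16, 4)
print(np.allclose(W_K_gqa[1], W_K[8:16].mean(axis=0)))   # True: heads 8-15 -> group 1
```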

LLaMA 2 70B also adopted GQA, although Meta trained it with GQA from the start rather than uptraining an MHA checkpoint. The GQA paper reports near-MHA quality at near-MQA inference speed for its uptrained models, and GQA has since become the default attention layout in most recent open large language models.


Mathematical Formulation

The Group Assignment Function

The heart of GQA is a simple function that maps each query head ii to its KV group:

$$g(i) = \left\lfloor \frac{i}{H/G} \right\rfloor = \left\lfloor \frac{i}{\text{heads\_per\_group}} \right\rfloor$$

where $\lfloor \cdot \rfloor$ is the floor function (integer division). With $H$ query heads and $G$ KV groups, each group serves $H/G$ heads. Heads $0$ through $H/G - 1$ map to group 0, heads $H/G$ through $2H/G - 1$ map to group 1, and so on.
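The group assignment is a single integer division. A minimal sketch (the `group_index` helper is illustrative) that also checks the two special cases:

```python
def group_index(i, H, G):
    """g(i) = floor(i / (H/G)): the KV group used by query head i."""
    assert H % G == 0, "G must divide H"
    return i // (H // G)

H = 64
print([group_index(i, H, 8) for i in (0, 7, 8, 15, 63)])  # [0, 0, 1, 1, 7]
print(all(group_index(i, H, H) == i for i in range(H)))   # True: G = H is MHA
print({group_index(i, H, 1) for i in range(H)})           # {0}: G = 1 is MQA
```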

The GQA Equation

Using the group assignment, each head computes standard scaled dot-product attention but uses its group's K and V rather than its own:

$$\text{head}_i = \text{Attention}\!\left(Q \cdot W_i^Q,\;\; K \cdot W_{g(i)}^K,\;\; V \cdot W_{g(i)}^V\right)$$
$$\text{GQA}(Q, K, V) = \text{Concat}\!\left(\text{head}_0,\;\text{head}_1,\;\ldots,\;\text{head}_{H-1}\right) \cdot W^O$$

The subscript $g(i)$ on $W^K$ and $W^V$ is what makes GQA different: multiple heads share the same projection matrices for keys and values.
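Because each KV group serves a contiguous block of $H/G$ heads, the per-head equation can also be evaluated in one batched operation by repeating every group $H/G$ times along the head axis, a pattern many implementations use (this sketch is illustrative, not taken from the paper). It assumes the projections have already been applied, so `Q` has shape `(H, N, d_k)` while `K` and `V` have shape `(G, N, d_k)`; the concatenation and $W^O$ are omitted. The check compares it against an explicit loop over $g(i)$.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def gqa_heads(Q, K, V):
    """Q: (H, N, d_k) projected queries; K, V: (G, N, d_k) projected keys/values."""
    H, G, d_k = Q.shape[0], K.shape[0], Q.shape[-1]
    K_rep = np.repeat(K, H // G, axis=0)    # group g copied to heads g*(H/G) .. (g+1)*(H/G)-1
    V_rep = np.repeat(V, H // G, axis=0)
    weights = softmax(Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d_k))  # (H, N, N)
    return weights @ V_rep                                          # (H, N, d_k)

rng = np.random.default_rng(0)
H, G, N, d_k = 8, 2, 5, 4
Q = rng.normal(size=(H, N, d_k))
K = rng.normal(size=(G, N, d_k))
V = rng.normal(size=(G, N, d_k))

batched = gqa_heads(Q, K, V)
# Reference: one head at a time, indexing group g(i) = i // (H/G).
looped = np.stack([
    softmax(Q[i] @ K[i // (H // G)].T / np.sqrt(d_k)) @ V[i // (H // G)]
    for i in range(H)
])
print(np.allclose(batched, looped))  # True
```

Repeating K and V materializes $H/G$ copies in memory for the computation itself; the cache still only needs to store the $G$ originals.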

Symbol-by-Symbol Breakdown

| Symbol | Meaning | Our Example |
|---|---|---|
| $H$ | Total number of query heads | 2 |
| $G$ | Number of KV groups ($1 \le G \le H$, must divide $H$) | 2 or 1 |
| $d_k$ | Per-head dimension $= d_{\text{model}} / H$ | 4 / 2 = 2 |
| $H/G$ | Heads per group | 1 (G = 2) or 2 (G = 1) |
| $g(i)$ | Group index for head $i$ | floor(i / (H/G)) |
| $W_i^Q$ | Query projection for head $i$ (unique per head) | |
| $W_{g(i)}^K$ | Key projection for group $g(i)$ (shared within group) | |
| $W_{g(i)}^V$ | Value projection for group $g(i)$ (shared within group) | |

The Spectrum: MQA ↔ GQA ↔ MHA

GQA is a single mechanism that continuously interpolates between MQA and MHA:

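| Configuration | Groups (G) | Heads/Group | KV-Cache Size (per layer) | Quality |
|---|---|---|---|---|
| MQA | 1 | $H$ | $2 \times 1 \times N \times d_k$ | Lowest |
| GQA | $1 < G < H$ | $H / G$ | $2 \times G \times N \times d_k$ | Tunable |
| MHA | $H$ | 1 | $2 \times H \times N \times d_k$ | Highest |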
Design rule of thumb: Production models often use $G = H/8$ or $G = H/4$. This gives 87.5% or 75% KV-cache savings, respectively, with a small reported quality loss. LLaMA 3 70B uses $H = 64, G = 8$ (87.5% saving). Mistral 7B uses $H = 32, G = 8$ (75% saving).


KV-Cache Memory Analysis

The Memory Formula

The KV-cache size per layer in bytes is:

$$\text{KV-cache} = 2 \times G \times N \times d_k \times \text{bytes\_per\_param}$$

The total cache across all $L$ layers is $L$ times this value, and it also scales linearly with batch size. The key insight: the cache scales with $G$, not $H$. By choosing $G < H$, we directly reduce memory consumption.

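| Model | H | G | Saving vs MHA | KV-cache (FP16, all layers) |
|---|---|---|---|---|
| Our Example (1 layer, N = 5) | 2 | 2 (MHA) | 0% | 80 B |
| Our Example (1 layer, N = 5) | 2 | 1 (MQA) | 50% | 40 B |
| LLaMA 3 8B (4K context) | 32 | 8 | 75% | 512 MiB |
| LLaMA 3 70B (4K context) | 64 | 8 | 87.5% | 1.25 GiB |
| Mistral 7B (4K context) | 32 | 8 | 75% | 512 MiB |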
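The formula can be checked against the table rows with a small helper (`kv_cache_bytes` is illustrative). It assumes FP16, one sequence, $d_k = 128$ for the production models, and layer counts of 32 (LLaMA 3 8B, Mistral 7B) and 80 (LLaMA 3 70B); production results are printed in MiB.

```python
def kv_cache_bytes(G, N, d_k, n_layers, bytes_per_value=2):
    """2 (K and V) x G KV heads x N tokens x d_k, summed over n_layers (FP16 by default)."""
    return 2 * G * N * d_k * n_layers * bytes_per_value

MiB = 2**20
# Toy example: one layer, N = 5 tokens, d_k = 2.
print(kv_cache_bytes(G=2, N=5, d_k=2, n_layers=1))               # 80
print(kv_cache_bytes(G=1, N=5, d_k=2, n_layers=1))               # 40
# 4K context, 8 KV groups, d_k = 128.
print(kv_cache_bytes(G=8, N=4096, d_k=128, n_layers=32) / MiB)   # 512.0 (LLaMA 3 8B, Mistral 7B)
print(kv_cache_bytes(G=8, N=4096, d_k=128, n_layers=80) / MiB)   # 1280.0 (LLaMA 3 70B)
```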


Step-by-Step Calculation

Setup: Shared Example

We use the same $Q, K, V \in \mathbb{R}^{5 \times 4}$ matrices as every chapter, with $d_{\text{model}} = 4$ and $H = 2$ query heads ($d_k = 2$ per head). We compare two configurations:

  • G = 2 (one group per head = MHA behavior): Head 0 uses $K[:,0{:}2], V[:,0{:}2]$; Head 1 uses $K[:,2{:}4], V[:,2{:}4]$
  • G = 1 (one group for all = MQA behavior): Both heads use $K[:,0{:}2], V[:,0{:}2]$

GQA G = 2 (MHA Mode)

With $G = 2$ and $H = 2$, $\text{heads\_per\_group} = 1$. Each head gets its own KV group. We trace the computation for “The” (row 0).

Head 0 → Group 0

$g(0) = \lfloor 0/1 \rfloor = 0$. Head 0 uses $K_{g0} = K[:,0{:}2]$, $V_{g0} = V[:,0{:}2]$.

Query: $Q_{h0}[0] = Q[0, 0{:}2] = [1, 0]$

Scaled scores ($Q_{h0}[0] \cdot K_{g0}^T / \sqrt{2}$):

| Key token | Dot product | Scaled |
|---|---|---|
| The: K = [0, 1] | 1×0 + 0×1 = 0 | 0.0000 |
| cat: K = [1, 0] | 1×1 + 0×0 = 1 | 0.7071 |
| sat: K = [1, 1] | 1×1 + 0×1 = 1 | 0.7071 |
| on: K = [0, 0] | 1×0 + 0×0 = 0 | 0.0000 |
| mat: K = [1, 0] | 1×1 + 0×0 = 1 | 0.7071 |

Softmax: $[0.1237,\; 0.2509,\; 0.2509,\; 0.1237,\; 0.2509]$

Output: $O_{h0}[0] = \sum_j w_j \cdot V_{g0}[j] = [0.2491,\; 0.3763]$ (computed from the unrounded weights)
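The head 0 trace for “The” can be reproduced in a few lines of NumPy, using the column slices of the shared example. Computed without intermediate rounding, the second output component is 0.37630, which rounds to 0.3763.

```python
import numpy as np

q = np.array([1.0, 0.0])                                               # Q[0, 0:2]: "The", head 0
K_g0 = np.array([[0, 1], [1, 0], [1, 1], [0, 0], [1, 0]], dtype=float)  # K[:, 0:2]
V_g0 = np.array([[1, 0], [0, 1], [0, 0], [0, 0], [0.5, 0.5]])           # V[:, 0:2]

scores = K_g0 @ q / np.sqrt(2)                    # one scaled score per key token
weights = np.exp(scores) / np.exp(scores).sum()   # softmax over the 5 keys
output = weights @ V_g0                           # weighted sum of the value rows

print(np.round(weights, 4))  # [0.1237 0.2509 0.2509 0.1237 0.2509]
print(np.round(output, 4))   # [0.2491 0.3763]
```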

Head 1 → Group 1

$g(1) = \lfloor 1/1 \rfloor = 1$. Head 1 uses $K_{g1} = K[:,2{:}4]$, $V_{g1} = V[:,2{:}4]$.

Query: $Q_{h1}[0] = Q[0, 2{:}4] = [1, 0]$

Scaled scores ($Q_{h1}[0] \cdot K_{g1}^T / \sqrt{2}$):

| Key token | Dot product | Scaled |
|---|---|---|
| The: K = [0, 1] | 1×0 + 0×1 = 0 | 0.0000 |
| cat: K = [1, 0] | 1×1 + 0×0 = 1 | 0.7071 |
| sat: K = [0, 0] | 1×0 + 0×0 = 0 | 0.0000 |
| on: K = [1, 1] | 1×1 + 0×1 = 1 | 0.7071 |
| mat: K = [0.5, 0.5] | 1×0.5 + 0×0.5 = 0.5 | 0.3536 |

Softmax: $[0.1337,\; 0.2711,\; 0.1337,\; 0.2711,\; 0.1904]$

Output: $O_{h1}[0] = [0.2289,\; 0.3663]$

Head 0 attends most strongly to “cat”, “sat”, and “mat” (keys with high dim-0). Head 1 attends most strongly to “cat” and “on” (keys with high dim-2). The two heads capture different relationships because they use different K, V subspaces.

Concatenation

The final output for “The” concatenates both head outputs:

$$O[0] = [O_{h0}[0],\; O_{h1}[0]] = [0.2491,\; 0.3763,\; 0.2289,\; 0.3663]$$
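Putting both heads together for “The” gives the concatenated row. The small `attend` helper below is an illustrative sketch for a single query vector; like the chapter's softmax, it subtracts the maximum score before exponentiating.

```python
import numpy as np

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])

def attend(q, K_g, V_g):
    """Scaled dot-product attention for one query vector against one KV group."""
    s = K_g @ q / np.sqrt(len(q))   # scaled scores, one per key token
    w = np.exp(s - s.max())         # subtract the max before exp for stability
    return (w / w.sum()) @ V_g      # softmax weights times value rows

row = 0                                              # "The"
head0 = attend(Q[row, 0:2], K[:, 0:2], V[:, 0:2])    # g(0) = 0
head1 = attend(Q[row, 2:4], K[:, 2:4], V[:, 2:4])    # g(1) = 1 when G = 2
out = np.concatenate([head0, head1])
print(np.round(out, 4))  # [0.2491 0.3763 0.2289 0.3663]
```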

GQA G = 1 (MQA Mode)

Now set $G = 1$. With $\text{heads\_per\_group} = 2$, both heads share the same KV group. Head 0 is unchanged (it already used group 0). But Head 1 now also uses group 0: $K[:,0{:}2]$ and $V[:,0{:}2]$ instead of $K[:,2{:}4]$ and $V[:,2{:}4]$.

Head 1 Changes KV Source

Query: $Q_{h1}[0] = [1, 0]$ (unchanged; queries are always head-specific)

Scaled scores (now against $K_{g0}$ instead of $K_{g1}$):

| Key token | G=2 score (K[:, 2:4]) | G=1 score (K[:, 0:2]) | Changed? |
|---|---|---|---|
| The | 0.0000 | 0.0000 | no |
| cat | 0.7071 | 0.7071 | no |
| sat | 0.0000 | 0.7071 | ✅ changed |
| on | 0.7071 | 0.0000 | ✅ changed |
| mat | 0.3536 | 0.7071 | ✅ changed |

New softmax: $[0.1237,\; 0.2509,\; 0.2509,\; 0.1237,\; 0.2509]$

New output: $O_{h1}^{\text{MQA}}[0] = [0.2491,\; 0.3763]$

Critical observation: Head 1's output for “The” is now $[0.2491, 0.3763]$, exactly the same as Head 0's output. For this token the two heads also have the same query vector ($[1, 0]$), so once they share $K$ and $V$ they produce identical attention patterns and become redundant. This loss of head diversity is the quality cost of MQA that GQA exists to mitigate.
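Whether two heads collapse under $G = 1$ also depends on their queries. This sketch (with an illustrative single-query `attend` helper) gives both heads group 0's keys and values and checks, token by token, whether their outputs coincide. Only “The”, whose two query slices are both $[1, 0]$, produces identical outputs; the other tokens keep distinct head outputs even with shared $K$ and $V$.

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "mat"]
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])

def attend(q, K_g, V_g):
    """Scaled dot-product attention for one query vector against one KV group."""
    s = K_g @ q / np.sqrt(len(q))
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V_g

# G = 1: both heads read group 0 (columns 0:2 of K and V); only the query slices differ.
same = [
    np.allclose(attend(Q[r, 0:2], K[:, 0:2], V[:, 0:2]),
                attend(Q[r, 2:4], K[:, 0:2], V[:, 0:2]))
    for r in range(len(tokens))
]
print(dict(zip(tokens, same)))  # {'The': True, 'cat': False, 'sat': False, 'on': False, 'mat': False}
```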

Output Comparison: G = 2 vs G = 1

| Token | G=2 output (MHA) | G=1 output (MQA) | Max Δ |
|---|---|---|---|
| The | [0.2491, 0.3763, 0.2289, 0.3663] | [0.2491, 0.3763, 0.2491, 0.3763] | 0.020 |
| cat | [0.4109, 0.1336, 0.2289, 0.3663] | [0.4109, 0.1336, 0.3583, 0.2126] | 0.154 |
| sat | [0.2717, 0.2717, 0.2289, 0.3663] | [0.2717, 0.2717, 0.2491, 0.3763] | 0.020 |
| on | [0.3000, 0.3000, 0.1799, 0.4579] | [0.3000, 0.3000, 0.2717, 0.2717] | 0.186 |
| mat | [0.2491, 0.3763, 0.2289, 0.3663] | [0.2491, 0.3763, 0.3583, 0.2126] | 0.154 |

Dims 0–1 (Head 0) are identical in both configurations. Dims 2–3 (Head 1) differ because Head 1 sees different keys and values. The maximum change is 0.186 for “on” in dimension 3. In our toy example with $H = 2$, there is no middle ground: $G$ can only be 1 or 2. The real power of GQA emerges with $H = 64$, where $G$ can be 2, 4, 8, 16, or 32.
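Since $G$ must divide $H$, the available settings are exactly the divisors of $H$. A one-line sketch (the helper name is illustrative):

```python
def valid_group_counts(H):
    """All G that divide H evenly (G = 1 is MQA, G = H is MHA)."""
    return [G for G in range(1, H + 1) if H % G == 0]

print(valid_group_counts(2))   # [1, 2]: no intermediate setting
print(valid_group_counts(64))  # [1, 2, 4, 8, 16, 32, 64]
```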


Full Attention Weight Matrices

Averaged Weights — GQA G = 2 (MHA)

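| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1287 | 0.2610 | 0.1923 | 0.1974 | 0.2206 |
| cat | 0.3188 | 0.1114 | 0.2500 | 0.1801 | 0.1397 |
| sat | 0.1574 | 0.2261 | 0.2505 | 0.1802 | 0.1858 |
| on | 0.1906 | 0.1906 | 0.1447 | 0.2837 | 0.1906 |
| mat | 0.1974 | 0.1923 | 0.1923 | 0.1974 | 0.2206 |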
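The averaged matrix is the element-wise mean of the two heads' softmax weights under $G = 2$. A sketch that recomputes it from the shared example (only $Q$ and $K$ are needed) and prints the “cat” row:

```python
import numpy as np

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # row-wise max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

d_k = 2
w_head0 = softmax(Q[:, 0:2] @ K[:, 0:2].T / np.sqrt(d_k))  # head 0 -> group 0
w_head1 = softmax(Q[:, 2:4] @ K[:, 2:4].T / np.sqrt(d_k))  # head 1 -> group 1 (G = 2)
avg = (w_head0 + w_head1) / 2                               # mean over the H = 2 heads

print(" ".join(f"{w:.4f}" for w in avg[1]))  # 0.3188 0.1114 0.2500 0.1801 0.1397
```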

Full Output Matrices

Output — GQA G = 2 (MHA)

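| | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.2491 | 0.3763 | 0.2289 | 0.3663 |
| cat | 0.4109 | 0.1336 | 0.2289 | 0.3663 |
| sat | 0.2717 | 0.2717 | 0.2289 | 0.3663 |
| on | 0.3000 | 0.3000 | 0.1799 | 0.4579 |
| mat | 0.2491 | 0.3763 | 0.2289 | 0.3663 |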

Output — GQA G = 1 (MQA)

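| | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.2491 | 0.3763 | 0.2491 | 0.3763 |
| cat | 0.4109 | 0.1336 | 0.3583 | 0.2126 |
| sat | 0.2717 | 0.2717 | 0.2491 | 0.3763 |
| on | 0.3000 | 0.3000 | 0.2717 | 0.2717 |
| mat | 0.2491 | 0.3763 | 0.3583 | 0.2126 |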


Applications Across Domains

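| Domain | Model | Config | Why GQA? |
|---|---|---|---|
| NLP | LLaMA 3 70B | H=64, G=8 | Long context (128K in LLaMA 3.1) with a manageable KV-cache |
| NLP | Mistral 7B | H=32, G=8 | Fast inference for a 7B-class model |
| NLP | Gemma 2 27B | H=32, G=16 | High quality with moderate compression |
| Code | StarCoder 2 | H=24, G=8 | Long code context (16K+) with fast completion |
| Vision | PaliGemma | Gemma 2B decoder: H=8, G=1 (MQA) | Uses the G = 1 end of the spectrum in a vision-language model |
| Multi-modal | LLaVA-NeXT | GQA in some language backbones (e.g., Mistral 7B) | Efficient vision-language inference |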

GQA has become the de facto standard for most new large language models. Some very small models (under roughly 1B parameters) keep MHA because the absolute KV-cache savings are small at that scale.


Connection to Modern Systems

  • Flash Attention + GQA: Flash Attention (Chapter 13) is an IO-aware implementation trick that does not change the mathematical result. It composes perfectly with GQA — the reduced number of KV heads means fewer memory reads during tiled computation, amplifying Flash Attention's speed benefits. All major inference engines (vLLM, TGI, TensorRT-LLM) implement this combination.
  • Paged Attention + GQA: vLLM's PagedAttention manages KV-cache as virtual memory pages. With GQA, each cached token occupies $H/G$ times less memory, so the same GPU memory holds $H/G$ times more tokens' worth of KV blocks, enabling higher batch sizes and better GPU utilization for serving.
  • KV-Cache Quantization: GQA reduces the number of KV heads that need to be stored, and each KV head can be further compressed via INT8 or INT4 quantization. Relative to MHA in FP16, the combination gives $(H/G) \times 2$ (INT8) to $(H/G) \times 4$ (INT4) total compression.
  • Speculative Decoding: Draft models in speculative decoding pipelines can use aggressive KV sharing ($G = 1$ or $G = 2$) because draft speed matters more than draft quality: the target model verifies every proposed token.
  • Positional Encodings (Chapters 7–9): RoPE rotates the Q and K vectors, so in GQA the rotation is applied to each query head independently but only once per KV group (before the keys are cached). This is a minor efficiency gain but matters at scale. ALiBi instead adds a per-head bias to the attention scores, so it is applied per query head and leaves the cached keys untouched.
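The compression factor in the quantization bullet is simple arithmetic. A sketch assuming the LLaMA 3 70B shape ($H = 64$, $G = 8$) and ignoring the small overhead of storing quantization scales (the helper name is illustrative):

```python
def kv_compression_vs_mha_fp16(H, G, kv_bits):
    """Ratio of MHA FP16 cache size to GQA cache size with kv_bits per stored value."""
    fewer_kv_heads = H / G         # GQA stores G KV heads instead of H
    smaller_values = 16 / kv_bits  # FP16 uses 16 bits per value
    return fewer_kv_heads * smaller_values

for bits in (16, 8, 4):  # FP16, INT8, INT4
    print(bits, kv_compression_vs_mha_fp16(H=64, G=8, kv_bits=bits))
```

This prints 8.0, 16.0, and 32.0 for FP16, INT8, and INT4 respectively, matching $(H/G)$, $(H/G) \times 2$, and $(H/G) \times 4$.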

Complexity Analysis

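| Metric | MHA | GQA | MQA |
|---|---|---|---|
| Attention FLOPs ($QK^T$ and weights × $V$, full sequence) | $O(HN^2d_k)$ | $O(HN^2d_k)$ | $O(HN^2d_k)$ |
| K/V projection FLOPs | $O(N d_{\text{model}} H d_k)$ | $O(N d_{\text{model}} G d_k)$ | $O(N d_{\text{model}} d_k)$ |
| KV-Cache Memory | $O(HNd_k)$ | $O(GNd_k)$ | $O(Nd_k)$ |
| KV Projection Params | $2Hd_{\text{model}}d_k$ | $2Gd_{\text{model}}d_k$ | $2d_{\text{model}}d_k$ |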

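Note that the attention core ($QK^T$ and the weighted sum over $V$) costs the same in all three, because every query head still computes its own scores; GQA only shrinks the K/V projections. The main savings come from memory: a smaller KV-cache leaves room for longer contexts and larger batches, and during decoding, which is memory-bandwidth-bound, less data has to be loaded from GPU memory at each step, which translates to higher throughput.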
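A rough per-layer count for a single decoding step makes the distinction concrete. This sketch assumes the LLaMA 3 70B shape ($H = 64$, $d_k = 128$), a 4K-token cache, and FP16, and counts a multiply-add as 2 FLOPs; `decode_step_cost` is a hypothetical helper. Attention FLOPs stay fixed as $G$ changes, while the bytes of K and V read per step shrink by $H/G$.

```python
def decode_step_cost(H, G, N, d_k, bytes_per_value=2):
    """Per layer, one new query token attending over N cached tokens.

    Returns (attention_flops, kv_bytes_read). With a multiply-add counted as 2 FLOPs,
    q @ K^T and weights @ V each cost 2 * N * d_k per query head.
    """
    attention_flops = H * 2 * (2 * N * d_k)             # every query head still does its own work
    kv_bytes_read = 2 * G * N * d_k * bytes_per_value   # only G KV heads are cached and loaded
    return attention_flops, kv_bytes_read

costs = {G: decode_step_cost(H=64, G=G, N=4096, d_k=128) for G in (64, 8, 1)}  # MHA, GQA, MQA
for G, (flops, kv_bytes) in costs.items():
    print(f"G={G:2d}: {flops:,} FLOPs, {kv_bytes:,} KV bytes")
```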


Python Implementation

The full GQA class with the shared example. The _group_index method is the only addition beyond standard multi-head attention — one line of integer division controls the entire KV sharing behavior.

Grouped-Query Attention — NumPy
grouped_query_attention.py
Line 1: import numpy as np

NumPy provides vectorized matrix operations. Q_h @ K_g.T runs as optimized C code. Same library used in every chapter.

Line 2: import math

Python standard library. We use math.sqrt() to precompute the scaling factor once.

Line 4: class GroupedQueryAttention

Wraps the GQA mechanism in a reusable class. The key difference from Chapters 1-5: this class takes G (number of KV groups) as a parameter. When G=H it behaves like MHA; when G=1 it behaves like MQA.

Line 14: def __init__(self, d_model, H, G)

Constructor takes three parameters. d_model is the total embedding dimension, H is the number of query heads, and G is the number of KV groups. G must divide H evenly. The key innovation: G controls where on the MQA-MHA spectrum the model sits.

EXECUTION STATE
⬇ input: d_model = 4 (total dimension)
⬇ input: H = 2 (query heads)
⬇ input: G = 2 (KV groups) — or 1 for MQA mode
Line 26: self.d_k = d_model // H

Per-head dimension. Each query head operates on d_k dimensions of the full d_model vector. With d_model=4, H=2, each head sees 2 dimensions.

EXECUTION STATE
d_model // H = 4 // 2 = 2
self.d_k = 2
Line 27: self.heads_per_group = H // G

How many query heads share each KV group. This is THE core GQA parameter. With H=2, G=2: 1 head per group (MHA). With H=2, G=1: 2 heads per group (MQA). With H=64, G=8: 8 heads per group (LLaMA 3 70B).

EXECUTION STATE
G=2: H // G = 2 // 2 = 1 (each head has its own KV → MHA)
G=1: H // G = 2 // 1 = 2 (both heads share one KV → MQA)
Line 28: self.scale = math.sqrt(self.d_k)

Scaling factor for dot-product attention. Same as Chapter 1: divide scores by √d_k to prevent softmax saturation.

EXECUTION STATE
math.sqrt(2) = 1.4142
self.scale = 1.4142
Line 30: def _softmax(self, x) -> np.ndarray

Numerically stable softmax. Subtracts row-wise max before exp() to prevent overflow. Identical to Chapter 1 implementation.

EXECUTION STATE
⬇ input: x = shape (5, 5) — scaled score matrix for one head
⬆ returns = np.ndarray (5, 5) — each row sums to 1.0
Line 36: def _group_index(self, head_idx) -> int

THE core GQA function. Maps a query head index to its KV group index using integer division. This single line determines the entire KV sharing pattern.

EXECUTION STATE
⬇ input: head_idx = 0 or 1 (for H=2)
⬆ returns = int — the KV group index for this head
Line 38: return head_idx // self.heads_per_group

Integer division maps head to group. With G=2 (heads_per_group=1): head 0→0, head 1→1 (separate groups = MHA). With G=1 (heads_per_group=2): head 0→0, head 1→0 (same group = MQA). For LLaMA 3 70B (H=64, G=8, heads_per_group=8): heads 0-7→group 0, heads 8-15→group 1, etc.

EXECUTION STATE
── G=2 (MHA) ──
head 0: 0 // 1 = 0 → KV Group 0
head 1: 1 // 1 = 1 → KV Group 1
── G=1 (MQA) ──
head 0: 0 // 2 = 0 → KV Group 0
head 1: 1 // 2 = 0 → KV Group 0 (SAME!)
Line 40: def forward(self, Q, K, V)

Main computation. Loops over H heads, maps each to its KV group, computes scaled dot-product attention per head, and concatenates all head outputs.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = (all_weights, output) — list of (5,5) + one (5,4)
Line 54: for h in range(self.H):

Loop over all H=2 query heads. For each head: (1) find its KV group via _group_index, (2) slice Q, K, V, (3) compute attention, (4) collect output.

LOOP TRACE · 2 iterations
h=0
g = _group_index(0) = 0 — Head 0 uses KV Group 0
Q_h0 = Q[:, 0:2] = The[1,0] cat[0,2] sat[1,1] on[0,0] mat[1,0]
K_g0 = K[:, 0:2] = The[0,1] cat[1,0] sat[1,1] on[0,0] mat[1,0]
V_g0 = V[:, 0:2] = The[1,0] cat[0,1] sat[0,0] on[0,0] mat[0.5,0.5]
h=1
g = _group_index(1) = 1 when G=2 (Head 1 uses KV Group 1), or 0 when G=1 (Head 1 shares Group 0)
Q_h1 = Q[:, 2:4] = The[1,0] cat[0,1] sat[1,0] on[1,1] mat[0,1]
G=2: K_g1 = K[:, 2:4] = The[0,1] cat[1,0] sat[0,0] on[1,1] mat[0.5,0.5]
G=1: K_g0 = K[:, 0:2] = The[0,1] cat[1,0] sat[1,1] on[0,0] mat[1,0] ← DIFFERENT!
Line 55: g = self._group_index(h)

Call _group_index to determine which KV group this head belongs to. This is where GQA differs from MHA: multiple heads can map to the same group.

EXECUTION STATE
h=0: g = 0 (always group 0 for head 0)
h=1, G=2: g = 1 (own group — MHA behavior)
h=1, G=1: g = 0 (shares with head 0 — MQA behavior)
Line 57: Q_h = Q[:, h*self.d_k : (h+1)*self.d_k]

Slice query matrix for this head. Head 0 gets dims [0:2], head 1 gets dims [2:4]. Query slicing is the same regardless of G — each head always has its own query.

EXECUTION STATE
h=0: Q[:, 0:2] =
      d0   d1
The  1.0  0.0
cat  0.0  2.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
h=1: Q[:, 2:4] =
      d2   d3
The  1.0  0.0
cat  0.0  1.0
sat  1.0  0.0
on   1.0  1.0
mat  0.0  1.0
Line 58: K_g = K[:, g*self.d_k : (g+1)*self.d_k]

Slice key matrix for this GROUP (not head). When G=2, each head has its own group so K slices differ. When G=1, both heads use g=0, so both see K[:, 0:2]. This is the KV sharing mechanism.

EXECUTION STATE
g=0: K[:, 0:2] =
      d0   d1
The  0.0  1.0
cat  1.0  0.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
g=1: K[:, 2:4] (only used when G=2) =
      d2   d3
The  0.0  1.0
cat  1.0  0.0
sat  0.0  0.0
on   1.0  1.0
mat  0.5  0.5
Line 61: scores = Q_h @ K_g.T / self.scale

Standard scaled dot-product: compute similarity between each query and all keys in the group, then divide by √d_k=1.414. Produces a 5×5 score matrix per head.

EXECUTION STATE
h=0, g=0: Q_h0 @ K_g0.T / 1.414 =
      The     cat     sat      on     mat
The  0.000   0.707   0.707   0.000   0.707
cat  1.414   0.000   1.414   0.000   0.000
sat  0.707   0.707   1.414   0.000   0.707
on   0.000   0.000   0.000   0.000   0.000
mat  0.000   0.707   0.707   0.000   0.707
Line 62: weights = self._softmax(scores)

Apply softmax row-wise to convert scores into attention probabilities. Each row sums to 1.0.

EXECUTION STATE
h=0: softmax(scores) =
       The      cat      sat       on      mat
The  0.1237   0.2509   0.2509   0.1237   0.2509
cat  0.3664   0.0891   0.3664   0.0891   0.0891
sat  0.1811   0.1811   0.3673   0.0893   0.1811
on   0.2000   0.2000   0.2000   0.2000   0.2000
mat  0.1237   0.2509   0.2509   0.1237   0.2509
Line 63: output = weights @ V_g

Weighted sum of value vectors. Each output row is a blend of all 5 value vectors from this group, weighted by attention. Produces a 5×2 matrix per head.

EXECUTION STATE
h=0: weights @ V_g0 =
        d0       d1
The  0.2491   0.3763
cat  0.4109   0.1336
sat  0.2717   0.2717
on   0.3000   0.3000
mat  0.2491   0.3763
Line 68: return all_weights, np.hstack(head_outputs)

Concatenate all head outputs along the feature dimension. With H=2 heads each producing (5,2), the final output is (5,4). Returns both weights (for visualization) and the concatenated output.

EXECUTION STATE
⬆ return: np.hstack([O_h0, O_h1]) =
       d0       d1       d2       d3
The  0.2491   0.3763   0.2289   0.3663
cat  0.4109   0.1336   0.2289   0.3663
sat  0.2717   0.2717   0.2289   0.3663
on   0.3000   0.3000   0.1799   0.4579
mat  0.2491   0.3763   0.2289   0.3663
Line 91: tokens = [...]

The 5 tokens used in every chapter. Five tokens give clean 5×5 attention matrices.

EXECUTION STATE
tokens = ['The', 'cat', 'sat', 'on', 'mat']
Line 93: Q = np.array([...])

Query matrix (5×4). Each row is what that token is looking for. Shared across every chapter.

EXECUTION STATE
Q =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
Line 101: K = np.array([...])

Key matrix (5×4). Each row is what that token advertises. In GQA, different groups see different slices of this matrix.

EXECUTION STATE
K =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
Line 109: V = np.array([...])

Value matrix (5×4). The content retrieved when a token is attended to. GQA groups share the same V slice, so multiple heads produce outputs from the same value pool.

EXECUTION STATE
V =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
Line 118: gqa_mha = GroupedQueryAttention(d_model=4, H=2, G=2)

Instantiate GQA with G=2 groups (one per head = MHA behavior). d_k=2, heads_per_group=1, scale=1.414.

EXECUTION STATE
d_k = 2
heads_per_group = 1 (each head has its own KV → MHA)
scale = 1.4142
Line 119: weights_g2, output_g2 = gqa_mha.forward(Q, K, V)

Run the full GQA pipeline with G=2. Head 0 uses K[:,0:2],V[:,0:2]; Head 1 uses K[:,2:4],V[:,2:4]. Each head sees different keys and values.

EXECUTION STATE
⬆ output_g2 (5×4) =
       d0       d1       d2       d3
The  0.2491   0.3763   0.2289   0.3663
cat  0.4109   0.1336   0.2289   0.3663
sat  0.2717   0.2717   0.2289   0.3663
on   0.3000   0.3000   0.1799   0.4579
mat  0.2491   0.3763   0.2289   0.3663
Line 125: gqa_mqa = GroupedQueryAttention(d_model=4, H=2, G=1)

Instantiate GQA with G=1 group (MQA behavior). Now heads_per_group=2, so BOTH heads share the same K,V slice (group 0).

EXECUTION STATE
heads_per_group = 2 (both heads share one KV → MQA)
Head 0 → group 0 = K[:, 0:2], V[:, 0:2]
Head 1 → group 0 = K[:, 0:2], V[:, 0:2] ← SAME as Head 0!
Line 126: weights_g1, output_g1 = gqa_mqa.forward(Q, K, V)

Run GQA with G=1. Head 1 now sees K[:,0:2] instead of K[:,2:4]. Dims 0-1 are identical (Head 0 unchanged), but dims 2-3 differ because Head 1 uses different K,V.

EXECUTION STATE
⬆ output_g1 (5×4) =
       d0       d1       d2       d3
The  0.2491   0.3763   0.2491   0.3763
cat  0.4109   0.1336   0.3583   0.2126
sat  0.2717   0.2717   0.2491   0.3763
on   0.3000   0.3000   0.2717   0.2717
mat  0.2491   0.3763   0.3583   0.2126
⚠ Notice = Row 0 (The): dims 2-3 = [0.2491, 0.3763] = SAME as dims 0-1! Head diversity lost.
Line 132: diff = np.abs(output_g2 - output_g1)

Element-wise absolute difference between G=2 and G=1 outputs. Dims 0-1 are always zero (Head 0 unchanged). Dims 2-3 show the quality impact of KV sharing.

EXECUTION STATE
diff (5×4) =
       d0     d1     d2     d3
The  0.000  0.000  0.020  0.010
cat  0.000  0.000  0.129  0.154
sat  0.000  0.000  0.020  0.010
on   0.000  0.000  0.092  0.186
mat  0.000  0.000  0.129  0.154
Max diff = 0.186 (token 'on', dim 3)
```python
import numpy as np
import math

class GroupedQueryAttention:
    """
    Grouped-Query Attention (Ainslie et al., 2023)

    Generalizes MHA (G=H) and MQA (G=1) via G KV groups.
    Each group of H/G query heads shares one set of K, V.
    Computes: head_i = Attention(Q_h, K_{g(h)}, V_{g(h)})
    where g(h) = floor(h / (H/G))
    """

    def __init__(self, d_model: int, H: int, G: int):
        """
        Args:
            d_model: Total model dimension
            H: Number of query heads
            G: Number of KV groups (1 <= G <= H)
        """
        assert d_model % H == 0
        assert H % G == 0
        self.d_model = d_model
        self.H = H
        self.G = G
        self.d_k = d_model // H
        self.heads_per_group = H // G
        self.scale = math.sqrt(self.d_k)

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax along last axis."""
        x_shifted = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def _group_index(self, head_idx: int) -> int:
        """Maps query head index to its KV group."""
        return head_idx // self.heads_per_group

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Args:
            Q: Query matrix  (N, d_model)
            K: Key matrix    (N, d_model)
            V: Value matrix  (N, d_model)
        Returns:
            all_weights: List of H weight matrices, each (N, N)
            output: Concatenated output (N, d_model)
        """
        N = Q.shape[0]
        head_outputs = []
        all_weights = []

        for h in range(self.H):
            g = self._group_index(h)

            Q_h = Q[:, h * self.d_k : (h + 1) * self.d_k]
            K_g = K[:, g * self.d_k : (g + 1) * self.d_k]
            V_g = V[:, g * self.d_k : (g + 1) * self.d_k]

            scores = Q_h @ K_g.T / self.scale
            weights = self._softmax(scores)
            output = weights @ V_g

            all_weights.append(weights)
            head_outputs.append(output)

        return all_weights, np.hstack(head_outputs)

    def explain(self, Q, K, V, tokens, query_idx=0):
        """Print step-by-step trace for one token."""
        all_weights, full_output = self.forward(Q, K, V)
        token = tokens[query_idx]
        mech = "MHA" if self.G == self.H else "MQA" if self.G == 1 else "GQA"
        print(f"\n=== GQA trace for '{token}' (H={self.H}, G={self.G}) ===")
        print(f"Mechanism: {mech}  |  Heads/group: {self.heads_per_group}")
        print(f"KV saving: {(1 - self.G / self.H) * 100:.0f}%\n")

        for h in range(self.H):
            g = self._group_index(h)
            print(f"--- Head {h} -> KV Group {g} ---")
            w = all_weights[h]
            for j, t in enumerate(tokens):
                bar = "#" * int(w[query_idx, j] * 40)
                print(f"  A[{token},{t}] = {w[query_idx, j]:.4f} |{bar}|")

        print(f"\nOutput[{token}] = {np.round(full_output[query_idx], 4)}")


# ── Shared Example (used in every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

# ── GQA with G=2 (same as MHA for H=2) ──
gqa_mha = GroupedQueryAttention(d_model=4, H=2, G=2)
weights_g2, output_g2 = gqa_mha.forward(Q, K, V)
print("=== GQA G=2 (MHA mode) ===")
print("Output:\n", np.round(output_g2, 4))
gqa_mha.explain(Q, K, V, tokens, query_idx=0)

# ── GQA with G=1 (same as MQA for H=2) ──
gqa_mqa = GroupedQueryAttention(d_model=4, H=2, G=1)
weights_g1, output_g1 = gqa_mqa.forward(Q, K, V)
print("\n=== GQA G=1 (MQA mode) ===")
print("Output:\n", np.round(output_g1, 4))
gqa_mqa.explain(Q, K, V, tokens, query_idx=0)

# ── Compare ──
diff = np.abs(output_g2 - output_g1)
print("\nDifference:\n", np.round(diff, 4))
print("Max diff:", round(diff.max(), 4))
```

PyTorch Implementation

The PyTorch version is structurally identical to the NumPy version. Key differences: torch.matmul replaces @ for clarity, F.softmax replaces our manual implementation, and an optional mask parameter supports causal attention (Chapter 3).

Grouped-Query Attention — PyTorch
grouped_query_attention_torch.py
import torch

PyTorch tensor library. Provides GPU (CUDA) acceleration and automatic differentiation.

import torch.nn as nn

Neural network module. nn.Module is the base class for all PyTorch models and handles parameter registration and device management.

import torch.nn.functional as F

Functional API. We use F.softmax() instead of implementing our own: it is numerically stable and optimized for GPU.

class GroupedQueryAttention(nn.Module)

PyTorch module version of GQA. Inheriting from nn.Module makes it compatible with PyTorch training loops, .to(device), .eval(), state_dict(), and so on.

def __init__(self, d_model, H, G)

Same parameters as the NumPy version. super().__init__() initializes nn.Module's internal state, which must happen before the module is used.

Execution state:
  d_model = 4
  H = 2 query heads
  G = 2 or 1 KV groups

g = h // self.heads_per_group

Group index computed inline (no separate method), using the same integer division as the NumPy version. With G=2: head 0→0, head 1→1. With G=1: both heads→0.

Execution state:
  G=2: h=0→g=0, h=1→g=1 (separate KV groups, MHA)
  G=1: h=0→g=0, h=1→g=0 (shared KV group, MQA)

scores = torch.matmul(Q_h, K_g.T) / self.scale

torch.matmul plays the role of NumPy's @ operator (PyTorch tensors also support @). For 2D tensors, K_g.T is the same as K_g.transpose(-2, -1). Functionally identical to the NumPy version.

Execution state:
  torch.matmul vs @: equivalent for 2D inputs; torch.matmul also supports batched inputs of shape (..., N, d_k).

scores = scores.masked_fill(mask, float("-inf"))

If a mask is provided, masked positions are set to -∞ before softmax. After softmax they become 0.0, so the token cannot attend to them. Used for causal attention (Chapter 3).

weights = F.softmax(scores, dim=-1)

PyTorch's built-in softmax. dim=-1 applies it along the last axis, so each row is normalized independently. Numerically stable internally.

Execution state:
  dim=-1: softmax along the last dimension, so each query token's attention weights sum to 1.0

return all_weights, torch.cat(head_outputs, dim=-1)

torch.cat concatenates along the feature dimension (dim=-1), equivalent to np.hstack for 2D arrays. The result is (5, 4), the same as the NumPy version.

Execution state:
  return: output shape = (5, 4), i.e. 2 heads × 2 dims each
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention (Ainslie et al., 2023) - PyTorch

    Raw version: takes pre-computed Q, K, V (no learned projections).
    Supports GPU, autograd, optional causal mask.
    """

    def __init__(self, d_model: int, H: int, G: int):
        super().__init__()
        # Heads must split d_model evenly, and groups must split heads evenly
        assert d_model % H == 0 and H % G == 0
        self.d_model = d_model
        self.H = H
        self.G = G
        self.d_k = d_model // H
        self.heads_per_group = H // G
        self.scale = math.sqrt(self.d_k)

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """
        Args:
            Q: (N, d_model), K: (N, d_model), V: (N, d_model)
               With G < H, only the first G * d_k columns of K and V are read.
            mask: Optional (N, N) boolean, True = masked
        Returns:
            all_weights: list of H tensors, each (N, N)
            output: (N, d_model) concatenated
        """
        head_outputs = []
        all_weights = []

        for h in range(self.H):
            # KV group shared by query head h: g(h) = h // (H // G)
            g = h // self.heads_per_group

            # Query slice for head h; key/value slices for its group g
            Q_h = Q[:, h * self.d_k : (h + 1) * self.d_k]
            K_g = K[:, g * self.d_k : (g + 1) * self.d_k]
            V_g = V[:, g * self.d_k : (g + 1) * self.d_k]

            scores = torch.matmul(Q_h, K_g.T) / self.scale

            if mask is not None:
                # Masked positions get -inf, so softmax gives them weight 0
                scores = scores.masked_fill(mask, float("-inf"))

            weights = F.softmax(scores, dim=-1)
            output = torch.matmul(weights, V_g)

            all_weights.append(weights)
            head_outputs.append(output)

        return all_weights, torch.cat(head_outputs, dim=-1)


# ── Shared Example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])
K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])
V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

# GQA G=2
gqa = GroupedQueryAttention(d_model=4, H=2, G=2)
with torch.no_grad():
    weights, output = gqa(Q, K, V)
print("GQA G=2 output:")
print(output.numpy().round(4))

# GQA G=1
gqa1 = GroupedQueryAttention(d_model=4, H=2, G=1)
with torch.no_grad():
    weights1, output1 = gqa1(Q, K, V)
print("\nGQA G=1 output:")
print(output1.numpy().round(4))
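The optional mask argument is not exercised in the demo above. As a minimal sketch, written in NumPy so it runs without PyTorch and following the same raw-slicing convention as the classes in this section, here is a causal variant; the helper name causal_gqa is hypothetical. With the PyTorch class, the equivalent boolean mask would be torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1), passed as mask=.

```python
import numpy as np


def causal_gqa(Q, K, V, H, G):
    """Loop-based GQA with a causal mask (illustrative NumPy sketch).

    Uses this chapter's raw convention: Q, K, V are (N, d_model); query head h
    reads Q columns [h*d_k, (h+1)*d_k) and the K/V columns of its group g.
    """
    N, d_model = Q.shape
    d_k = d_model // H
    heads_per_group = H // G
    # True above the diagonal: key position j > i is in the future of query i
    mask = np.triu(np.ones((N, N), dtype=bool), k=1)
    all_weights, head_outputs = [], []
    for h in range(H):
        g = h // heads_per_group  # group assignment g(h)
        Q_h = Q[:, h * d_k:(h + 1) * d_k]
        K_g = K[:, g * d_k:(g + 1) * d_k]
        V_g = V[:, g * d_k:(g + 1) * d_k]
        scores = Q_h @ K_g.T / np.sqrt(d_k)
        scores = np.where(mask, -np.inf, scores)  # same role as masked_fill
        # Stable row-wise softmax. The diagonal is never masked, so each row max
        # is finite, and exp(-inf) = 0 zeroes out the future positions.
        scores = scores - scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights = weights / weights.sum(axis=-1, keepdims=True)
        all_weights.append(weights)
        head_outputs.append(weights @ V_g)
    return all_weights, np.hstack(head_outputs)


# Shared example from this chapter
Q = np.array([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])
K = np.array([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])
V = np.array([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

weights_c, out_c = causal_gqa(Q, K, V, H=2, G=1)
print(np.round(weights_c[0], 4))  # upper triangle is all zeros
print(np.round(out_c, 4))  # "The" can only see itself, so row 0 equals [1, 0, 1, 0]
```

Because "The" attends only to itself, its output is just its own value slice: with $G = 1$ both heads copy V[0, 0:2], while with $G = 2$ the second head reads V[0, 2:4] instead.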
Production note: Real GQA implementations (e.g., in Hugging Face Transformers) use learned projection matrices ($W_i^Q$, $W_{g(i)}^K$, $W_{g(i)}^V$) and batched operations that expand the $G$ KV heads to match the $H$ query heads, via repeat_interleave or an equivalent expand-and-reshape. The loop-based version above is pedagogically clear but not optimal for GPU throughput.
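To make the batched layout concrete, here is a minimal NumPy sketch. It is illustrative only and not the Hugging Face code: the function name gqa_vectorized and the weight names are hypothetical, and the projection shapes are assumptions stated in its docstring. It projects the input to $H$ query heads and only $G$ key/value heads, expands the KV heads with np.repeat (which plays the role of torch.repeat_interleave(k, H // G, dim=0) in PyTorch), and computes all heads in one batched matmul. The last lines cross-check it against a per-head loop.

```python
import numpy as np


def gqa_vectorized(x, W_q, W_k, W_v, W_o, H, G):
    """Batched GQA with learned projections (illustrative NumPy sketch).

    x:   (N, d_model) input activations
    W_q: (d_model, H * d_k)  query projection, one d_k-wide slice per head
    W_k: (d_model, G * d_k)  key projection, one slice per KV group
    W_v: (d_model, G * d_k)  value projection, one slice per KV group
    W_o: (H * d_k, d_model)  output projection
    """
    N = x.shape[0]
    d_k = W_q.shape[1] // H
    q = (x @ W_q).reshape(N, H, d_k).transpose(1, 0, 2)  # (H, N, d_k)
    k = (x @ W_k).reshape(N, G, d_k).transpose(1, 0, 2)  # (G, N, d_k): what a KV-cache holds
    v = (x @ W_v).reshape(N, G, d_k).transpose(1, 0, 2)  # (G, N, d_k)
    # Expand G KV heads to H: each group is copied H // G times consecutively,
    # so query head i sees group i // (H // G), i.e. g(i)
    k = np.repeat(k, H // G, axis=0)  # (H, N, d_k)
    v = np.repeat(v, H // G, axis=0)  # (H, N, d_k)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)  # (H, N, N)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    heads = weights @ v  # (H, N, d_k)
    concat = heads.transpose(1, 0, 2).reshape(N, H * d_k)  # (N, H * d_k)
    return concat @ W_o


# Usage with random weights: N=5 tokens, d_model=8, H=4 query heads, G=2 KV groups
rng = np.random.default_rng(0)
N, d_model, H, G = 5, 8, 4, 2
d_k = d_model // H
x = rng.standard_normal((N, d_model))
W_q = rng.standard_normal((d_model, H * d_k))
W_k = rng.standard_normal((d_model, G * d_k))
W_v = rng.standard_normal((d_model, G * d_k))
W_o = rng.standard_normal((H * d_k, d_model))
out = gqa_vectorized(x, W_q, W_k, W_v, W_o, H, G)

# Cross-check against a per-head loop in the style of the classes above
Qp, Kp, Vp = x @ W_q, x @ W_k, x @ W_v
loop_heads = []
for h in range(H):
    g = h // (H // G)
    s = Qp[:, h * d_k:(h + 1) * d_k] @ Kp[:, g * d_k:(g + 1) * d_k].T / np.sqrt(d_k)
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    loop_heads.append((e / e.sum(axis=-1, keepdims=True)) @ Vp[:, g * d_k:(g + 1) * d_k])
ref = np.hstack(loop_heads) @ W_o
print(out.shape, np.allclose(out, ref))  # (5, 8) True
```

The design point is that only the unexpanded (G, N, d_k) key and value tensors would need to be cached; the repeat to H heads happens at attention time, so the cache stays G/H the size of an MHA cache.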

Key Takeaways

  1. GQA is a generalization. It places MHA ($G = H$) and MQA ($G = 1$) at the two ends of a spectrum controlled by a single parameter $G$, which can be any divisor of $H$.
  2. The math is one line. The entire mechanism is controlled by $g(i) = \lfloor i / (H/G) \rfloor$, a floor division that maps each query head to its KV group.
  3. KV-cache scales with G, not H. Choosing $G = H/8$ shrinks the KV-cache by 87.5% relative to MHA (only $1/8$ of the KV heads are stored), and the GQA paper reports quality close to MHA with 8 KV groups.
  4. Uptrain, don't retrain. Existing MHA models can be converted to GQA by mean-pooling KV heads within groups and fine-tuning with ~5% of original compute.
  5. It's widely adopted. LLaMA 3, Mistral 7B, Gemma 2, StarCoder 2, and many other recent open-weight models use GQA.
  6. Composes with other optimizations. GQA works alongside Flash Attention, RoPE, ALiBi, PagedAttention, and KV-cache quantization.
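Takeaways 2 and 3 are easy to check numerically. The sketch below is a minimal illustration, assuming batch size 1 and an illustrative sequence length of 8,192 tokens; the helper names group_of and kv_cache_bytes are hypothetical. It prints the group assignment for $H = 8$ at several values of $G$, then compares the MHA cache with $G = H/8$ for the LLaMA 3 70B shapes given earlier in the chapter ($H = 64$, $d_k = 128$, 80 layers, FP16). Here $G = H/8 = 8$, which matches the 8 KV heads LLaMA 3 70B uses.

```python
def group_of(i, H, G):
    """g(i) = floor(i / (H / G)): the KV group that query head i reads."""
    assert H % G == 0, "G must divide H"
    return i // (H // G)


def kv_cache_bytes(G, d_k, N, layers, bytes_per_value=2, batch=1):
    """KV-cache size: 2 (K and V) x G x N x d_k values per layer per sequence."""
    return 2 * G * N * d_k * layers * bytes_per_value * batch


# Takeaway 2: one formula covers MHA, GQA and MQA (H = 8 query heads)
for G in (8, 4, 2, 1):
    print(G, [group_of(i, 8, G) for i in range(8)])
# 8 [0, 1, 2, 3, 4, 5, 6, 7]   <- G = H: MHA
# 4 [0, 0, 1, 1, 2, 2, 3, 3]
# 2 [0, 0, 0, 0, 1, 1, 1, 1]
# 1 [0, 0, 0, 0, 0, 0, 0, 0]   <- G = 1: MQA

# Takeaway 3: LLaMA 3 70B shapes, FP16 (2 bytes per value),
# assumed sequence length N = 8192 and batch size 1
H, d_k, layers, N = 64, 128, 80, 8192
mha = kv_cache_bytes(G=H, d_k=d_k, N=N, layers=layers)
gqa = kv_cache_bytes(G=H // 8, d_k=d_k, N=N, layers=layers)
print(mha / 2**30, gqa / 2**30, 1 - gqa / mha)  # 20.0 2.5 0.875 (GiB, GiB, fraction saved)
```

The saving depends only on the ratio $G/H$, so the 87.5% figure holds for any sequence length or batch size; the absolute sizes scale linearly with both.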

Exercises

Exercise 1: Compute for “cat”

Using the shared example with $H = 2, G = 2$, compute the full GQA output for “cat” (row 1). Then repeat with $G = 1$ and compare the two outputs. Which dimensions change and why?

Exercise 2: KV-Cache Budget

You have a 128K context model with $H = 48$ heads, $d_k = 128$, 40 layers, FP16. Your GPU has 24 GB of memory and 8 GB is available for KV-cache. What is the maximum $G$ you can use? What is the minimum?

Exercise 3: Prove the Spectrum

Prove formally that GQA with $G = H$ is mathematically identical to MHA, and that GQA with $G = 1$ is identical to MQA. Start from the group assignment function and show that the KV slicing reduces to the expected behavior in each case.

Exercise 4: Uptrain Design

You have a trained MHA model with $H = 32$ and want to convert it to GQA with $G = 4$. (a) How many heads are in each group? (b) Write pseudocode for mean-pooling the $W^K$ matrices. (c) If the original training took 1000 GPU-hours, how many GPU-hours would the uptrain take (approximately)?

Exercise 5: Flash Attention Synergy

Explain why GQA improves Flash Attention's performance beyond just reducing KV-cache size. Hint: think about memory bandwidth and the ratio of computation to memory access in the tiled attention kernel.


References

  1. Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023. arXiv:2305.13245
  2. Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150
  3. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). Attention Is All You Need. NeurIPS 2017. arXiv:1706.03762
  4. Touvron, H., et al. (2023). LLaMA 2: Open Foundation and Fine-Tuned Chat Models. arXiv:2307.09288
  5. Meta AI (2024). The LLaMA 3 Herd of Models. arXiv:2407.21783
  6. Jiang, A. Q., et al. (2023). Mistral 7B. arXiv:2310.06825
  7. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691
  8. Kwon, W., et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023. arXiv:2309.06180