Ainslie, Lee-Thorp, de Jong, Zemlyanskiy, Lebrón & Sanghai, "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints", Google Research, 2023
Learning Objectives
After completing this chapter, you will be able to:
- Explain why the KV-cache is the primary memory bottleneck during autoregressive inference and why MHA makes it worse than necessary
- Describe how GQA partitions query heads into groups where each group shares a single set of keys and values
- Compute the group assignment function $g(i)$ and understand how it generalizes MHA ($G = H$) and MQA ($G = 1$)
- Perform a full GQA forward pass by hand on the shared “The cat sat on mat” example, comparing $G = 2$ and $G = 1$
- Calculate KV-cache memory savings for production models like LLaMA 3 70B and Mistral 7B
- Understand the “uptrain” procedure that converts an existing MHA checkpoint to GQA with only 5% of original training compute
- Implement GQA from scratch in both NumPy and PyTorch
The Problem: The KV-Cache Dilemma
What Is the KV-Cache?
When a transformer generates text token by token (autoregressive decoding), each new token attends to every previous position. Without caching, the model would recompute the Key and Value vectors of all previous tokens at every step. To avoid this recomputation, modern systems cache the Key and Value tensors from all previous positions. This “KV-cache” trades memory for speed: once a token's $K$ and $V$ are computed, they are stored and reused for all subsequent tokens.
The problem is that the KV-cache grows with three multipliers: the number of attention heads $H$, the sequence length $N$, and the per-head dimension $d_k$. For each transformer layer, the cache stores $2 \times H \times N \times d_k$ floating-point values (the factor of 2 accounts for both K and V).
MHA's Memory Wall
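In standard Multi-Head Attention (Chapter 2), every head has its own independent $K$ and $V$ projections. For a model with the shape of LLaMA 3 70B ($H = 64$ heads, $d_k = 128$, 80 layers) running full MHA in FP16, the cache costs per token:

$$2 \times 64 \times 128 \times 2\ \text{bytes} \times 80\ \text{layers} = 2{,}621{,}440\ \text{bytes} = 2.5\ \text{MB per token}$$

At a 128K context window, that is 320 GB of KV-cache alone, far exceeding the memory of any single GPU. This memory wall is the reason that long-context inference requires multi-GPU setups or aggressive compression.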
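The 320 GB figure can be reproduced with a few lines of arithmetic. The sketch below is illustrative only: the function name is made up for this example, and it assumes $H = 64$, $d_k = 128$, 80 layers, 2 bytes per FP16 value, "128K" meaning 128 × 1024 tokens, and binary units (1 GB = 1024³ bytes).

```python
def kv_cache_bytes(n_kv_heads, seq_len, d_k, n_layers, bytes_per_value=2):
    """Bytes needed to cache K and V for one sequence across all layers."""
    per_layer = 2 * n_kv_heads * seq_len * d_k * bytes_per_value  # 2 = K and V
    return per_layer * n_layers


MIB = 1024 ** 2
GIB = 1024 ** 3

# Assumed shape: 64 KV heads (full MHA), d_k = 128, 80 layers, FP16.
per_token = kv_cache_bytes(64, 1, 128, 80)
at_128k = kv_cache_bytes(64, 128 * 1024, 128, 80)

print(per_token / MIB)  # 2.5   (MiB per token)
print(at_128k / GIB)    # 320.0 (GiB at 128K tokens)
```

Because the cost is linear in the sequence length, doubling the context doubles the cache.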
MQA's Quality Cliff
Chapter 5 introduced Multi-Query Attention (MQA, Shazeer 2019), which shares a single $(K, V)$ pair across all heads. This reduces the KV-cache by a factor of $H = 64$: from 320 GB to just 5 GB in our LLaMA example. The speed improvement is dramatic.
But the quality cost is real. When all heads are forced to attend through the same key-value lens, the model loses the representational diversity that makes multi-head attention powerful. Benchmarks show measurable degradation on reasoning-heavy tasks, summarization, and multi-hop question answering (Ainslie et al., 2023, Table 1).
The Core Tension: MHA gives maximum quality but is memory-hungry. MQA gives maximum speed but sacrifices quality. Is there a sweet spot between these extremes?
The Intuition Behind Grouped-Query Attention
The Library Analogy
Imagine a research institute with 64 researchers (query heads). Each researcher needs to look up information in reference books (keys and values) to answer questions.
- MHA: Every researcher has their own private library. Maximum research diversity, but the institute needs 64 separate library buildings. Expensive.
- MQA: Everyone shares a single library. Cheap to maintain, but when 64 researchers all need the same book simultaneously, the bottleneck is severe. Popular topics get congested; niche topics are underserved.
- GQA: Researchers are organized into departments of 8. Each department has its own library. Eight libraries instead of 64 is an $8\times$ cost reduction, but each department can still specialize its collection. The biology department keeps different books than the physics department.
This is exactly how GQA works. Instead of $H$ separate KV sets (MHA) or 1 shared KV set (MQA), GQA uses $G$ groups, where $1 \le G \le H$. Each group of $H/G$ query heads shares one set of keys and values.
The Paper: Ainslie et al. 2023
Grouped-Query Attention was introduced by Joshua Ainslie and colleagues at Google Research in 2023. Their paper, “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,” made two key contributions:
- The GQA mechanism itself: a clean generalization that places MHA and MQA as special cases on a single continuum controlled by $G$
- The uptrain procedure: a practical method to convert existing MHA models to GQA without training from scratch, requiring only ~5% of the original training compute
The Uptrain Insight
The uptrain procedure works by mean-pooling the KV projection weights within each group. If an MHA model has 64 separate $W^K$ and $W^V$ matrices and you want $G = 8$ groups, you average heads 0–7 into group 0's $W^K$ and $W^V$, heads 8–15 into group 1's, and so on. Then you fine-tune the model for a small fraction of the original training budget to recover any quality loss.
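The pooling step can be sketched in NumPy. This is a minimal illustration, not the paper's code: it assumes the per-head projection matrices are stacked into one array of shape `(H, d_model, d_k)`, and the function name and toy sizes are hypothetical.

```python
import numpy as np


def mean_pool_kv_heads(w, n_groups):
    """Average per-head projection matrices within each group.

    w: array of shape (H, d_model, d_k), one K (or V) projection per head.
    Returns an array of shape (n_groups, d_model, d_k).
    """
    n_heads, d_model, d_k = w.shape
    if n_heads % n_groups != 0:
        raise ValueError("n_groups must divide the number of heads")
    heads_per_group = n_heads // n_groups
    # Consecutive heads land in the same group: head h -> group h // heads_per_group.
    grouped = w.reshape(n_groups, heads_per_group, d_model, d_k)
    return grouped.mean(axis=1)


rng = np.random.default_rng(0)
w_k = rng.standard_normal((64, 16, 4))  # toy sizes: H=64, d_model=16, d_k=4
w_k_gqa = mean_pool_kv_heads(w_k, n_groups=8)
print(w_k_gqa.shape)  # (8, 16, 4)
```

The same function would be applied to the value projections. The pooled weights are only a starting point; the subsequent fine-tuning is what recovers quality.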
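The original paper demonstrated this procedure on T5 checkpoints, where the uptrained GQA models achieved near-MHA quality at near-MQA inference speed. LLaMA 2 70B subsequently adopted GQA (trained with grouped KV heads rather than converted from an MHA checkpoint), and GQA has since become the default attention layout for many production large language models.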
Mathematical Formulation
The Group Assignment Function
The heart of GQA is a simple function that maps each query head to its KV group:

$$g(i) = \left\lfloor \frac{i}{H/G} \right\rfloor$$

where $\lfloor \cdot \rfloor$ is the floor function (integer division). With $H$ query heads and $G$ KV groups, each group serves $H/G$ heads. Heads 0 through $H/G - 1$ map to group 0, heads $H/G$ through $2H/G - 1$ map to group 1, and so on.
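As a quick check of the mapping, the floor division can be written directly in Python. The function name `group_index` is illustrative; the head counts used below (64 heads, 8 groups, and a 4-head toy case for the endpoints) are example values.

```python
def group_index(head, n_heads, n_groups):
    """Map query head `head` to its KV group via floor division."""
    if n_heads % n_groups != 0:
        raise ValueError("n_groups must divide n_heads")
    return head // (n_heads // n_groups)


# 64 query heads, 8 KV groups: 8 consecutive heads per group.
mapping = [group_index(i, 64, 8) for i in range(64)]
print(mapping[:10])  # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]

# Endpoints of the spectrum on a 4-head toy model.
mha = [group_index(i, 4, 4) for i in range(4)]  # every head is its own group
mqa = [group_index(i, 4, 1) for i in range(4)]  # every head maps to group 0
print(mha, mqa)  # [0, 1, 2, 3] [0, 0, 0, 0]
```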
The GQA Equation
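Using the group assignment, each head computes standard scaled dot-product attention over the input $X$ but uses its group's K and V rather than its own:

$$\text{head}_i = \text{softmax}\left(\frac{(X W_i^Q)\,(X W_{g(i)}^K)^\top}{\sqrt{d_k}}\right) X W_{g(i)}^V$$

The subscript $g(i)$ on $W^K$ and $W^V$ is what makes GQA different: multiple heads share the same projection matrices for keys and values.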
Symbol-by-Symbol Breakdown
| Symbol | Meaning | Our Example |
|---|---|---|
| $H$ | Total number of query heads | 2 |
| $G$ | Number of KV groups (1 ≤ G ≤ H, must divide H) | 2 or 1 |
| $d_k$ | Per-head dimension = d_model / H | 4 / 2 = 2 |
| $H/G$ | Heads per group | 1 (G=2) or 2 (G=1) |
| $g(i)$ | Group index for head i | floor(i / (H/G)) |
| $W_i^Q$ | Query projection for head i (unique per head) | n/a |
| $W_{g(i)}^K$ | Key projection for group g(i) (shared within group) | n/a |
| $W_{g(i)}^V$ | Value projection for group g(i) (shared within group) | n/a |
The Spectrum: MQA ↔ GQA ↔ MHA
GQA is a single mechanism that continuously interpolates between MQA and MHA:
| Configuration | Groups (G) | Heads/Group | KV-Cache Size | Quality |
|---|---|---|---|---|
| MQA | 1 | H | 2 × 1 × N × d_k | Lowest |
| GQA | 1 < G < H | H / G | 2 × G × N × d_k | Tunable |
| MHA | H | 1 | 2 × H × N × d_k | Highest |
KV-Cache Memory Analysis
The Memory Formula
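The KV-cache size per layer in bytes is:

$$\text{bytes}_{\text{layer}} = 2 \times G \times N \times d_k \times b$$

where $b$ is the number of bytes per stored value (2 for FP16). The total cache across all layers is $L$ times this value, with $L$ the number of layers. The key insight: the cache scales with $G$, not $H$. By choosing $G < H$, we directly reduce memory consumption.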
| Model | H | G | Saving vs MHA | Cache at 4K (FP16) |
|---|---|---|---|---|
| Our Example | 2 | 2 (MHA) | 0% | 80 B (N = 5, 1 layer) |
| Our Example | 2 | 1 (MQA) | 50% | 40 B (N = 5, 1 layer) |
| LLaMA 3 8B | 32 | 8 | 75% | 512 MB |
| LLaMA 3 70B | 64 | 8 | 87.5% | 1.25 GB |
| Mistral 7B | 32 | 8 | 75% | 512 MB |
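The memory formula can be evaluated for each row with a short script. This is a sketch under stated assumptions: $d_k = 128$ for all three production models and 32 layers for LLaMA 3 8B and Mistral 7B are assumed values not given in the table, "4K" is taken as 4096 tokens, and sizes are reported in binary units (MiB = 1024² bytes). The toy rows use the five-token example on a single layer.

```python
def kv_cache_bytes(n_groups, seq_len, d_k, n_layers, bytes_per_value=2):
    """Total KV-cache bytes: 2 (K and V) x G x N x d_k x bytes, times layers."""
    return 2 * n_groups * seq_len * d_k * bytes_per_value * n_layers


def saving_vs_mha(n_heads, n_groups):
    """Fraction of KV-cache memory saved relative to G = H."""
    return 1 - n_groups / n_heads


MIB = 1024 ** 2
# (H, G, d_k, layers, N); layer counts and d_k for real models are assumptions.
configs = {
    "toy G=2": (2, 2, 2, 1, 5),
    "toy G=1": (2, 1, 2, 1, 5),
    "LLaMA 3 8B": (32, 8, 128, 32, 4096),
    "LLaMA 3 70B": (64, 8, 128, 80, 4096),
    "Mistral 7B": (32, 8, 128, 32, 4096),
}

results = {}
for name, (h, g, d_k, layers, n) in configs.items():
    size = kv_cache_bytes(g, n, d_k, layers)
    results[name] = (saving_vs_mha(h, g), size)
    print(f"{name:12s} saving={saving_vs_mha(h, g):.1%} bytes={size} ({size / MIB:.2f} MiB)")
```

Under these assumptions the two 32-head models come out at 512 MiB and the 70B shape at 1280 MiB (1.25 GiB) for a single 4K-token sequence.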
Step-by-Step Calculation
Setup: Shared Example
We use the same matrices as every chapter, with $d_{model} = 4$ and $H = 2$ query heads ($d_k = 2$ per head). We compare two configurations:
- G = 2 (one group per head = MHA behavior): Head 0 uses K[:,0:2] and V[:,0:2]; Head 1 uses K[:,2:4] and V[:,2:4]
- G = 1 (one group for all = MQA behavior): Both heads use K[:,0:2] and V[:,0:2]
GQA G = 2 (MHA Mode)
With $H = 2$ and $G = 2$, $H/G = 1$. Each head gets its own KV group. We trace the computation for “The” (row 0).
Head 0 → Group 0
$g(0) = \lfloor 0/1 \rfloor = 0$. Head 0 uses K[:,0:2], V[:,0:2].
Query: $q = [1, 0]$
Scaled scores ($\sqrt{d_k} = \sqrt{2} \approx 1.4142$):
| Key token | Dot product | Scaled |
|---|---|---|
| The: K=[0,1] | (1×0 + 0×1) = 0 | 0.0000 |
| cat: K=[1,0] | (1×1 + 0×0) = 1 | 0.7071 |
| sat: K=[1,1] | (1×1 + 0×1) = 1 | 0.7071 |
| on: K=[0,0] | (1×0 + 0×0) = 0 | 0.0000 |
| mat: K=[1,0] | (1×1 + 0×0) = 1 | 0.7071 |
Softmax: [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
Output: [0.2491, 0.3764]
Head 1 → Group 1
$g(1) = \lfloor 1/1 \rfloor = 1$. Head 1 uses K[:,2:4], V[:,2:4].
Query: $q = [1, 0]$
Scaled scores ($\sqrt{d_k} = \sqrt{2} \approx 1.4142$):
| Key token | Dot product | Scaled |
|---|---|---|
| The: K=[0,1] | (1×0 + 0×1) = 0 | 0.0000 |
| cat: K=[1,0] | (1×1 + 0×0) = 1 | 0.7071 |
| sat: K=[0,0] | (1×0 + 0×0) = 0 | 0.0000 |
| on: K=[1,1] | (1×1 + 0×1) = 1 | 0.7071 |
| mat: K=[0.5,0.5] | (1×0.5 + 0×0.5) = 0.5 | 0.3536 |
Softmax: [0.1337, 0.2711, 0.1337, 0.2711, 0.1904]
Output: [0.2289, 0.3663]
Concatenation
The final output for “The” concatenates both head outputs: [0.2491, 0.3764, 0.2289, 0.3663]
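The softmax step for both heads can be checked numerically. In this sketch the key rows and the query $q = [1, 0]$ are read off the dot-product columns of the two score tables above; the helper names are illustrative. The value matrices are not reproduced in this section, so the script stops at the attention weights.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_scores(q, keys, d_k):
    """Dot product of q with each key row, divided by sqrt(d_k)."""
    return [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]


q = [1, 0]  # query for "The", as used in both score tables
k_group0 = [[0, 1], [1, 0], [1, 1], [0, 0], [1, 0]]      # K[:,0:2]
k_group1 = [[0, 1], [1, 0], [0, 0], [1, 1], [0.5, 0.5]]  # K[:,2:4]

w_head0 = softmax(scaled_scores(q, k_group0, d_k=2))
w_head1 = softmax(scaled_scores(q, k_group1, d_k=2))
print([round(w, 4) for w in w_head0])  # [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
print([round(w, 4) for w in w_head1])  # [0.1337, 0.2711, 0.1337, 0.2711, 0.1904]
```

Averaging the two weight vectors element by element gives the first row of the averaged-weights table in the "Full Attention Weight Matrices" section.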
GQA G = 1 (MQA Mode)
Now set $G = 1$. With $H/G = 2$, both heads share the same KV group. Head 0 is unchanged (it already used group 0). But Head 1 now also uses group 0: K[:,0:2] and V[:,0:2] instead of K[:,2:4] and V[:,2:4].
Head 1 Changes KV Source
Query: $q = [1, 0]$ (unchanged; queries are always head-specific)
Scaled scores (now against K[:,0:2] instead of K[:,2:4]):
| Key token | G=2 score (K[:,2:4]) | G=1 score (K[:,0:2]) | Changed? |
|---|---|---|---|
| The | 0.0000 | 0.0000 | — |
| cat | 0.7071 | 0.7071 | — |
| sat | 0.0000 | 0.7071 | ✅ changed |
| on | 0.7071 | 0.0000 | ✅ changed |
| mat | 0.3536 | 0.7071 | ✅ changed |
New softmax: [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
New output: [0.2491, 0.3764]
Critical observation: Head 1's output for “The” is now [0.2491, 0.3764], exactly the same as Head 0's output. Both heads share the same K and V, and their queries for “The” happen to coincide ($q = [1, 0]$ in both), so they produce identical attention patterns and the heads become redundant. This is the quality cost of MQA that GQA exists to mitigate.
Output Comparison: G = 2 vs G = 1
| Token | G=2 output (MHA) | G=1 output (MQA) | Max Δ |
|---|---|---|---|
| The | [0.2491, 0.3764, 0.2289, 0.3663] | [0.2491, 0.3764, 0.2491, 0.3764] | 0.020 |
| cat | [0.4110, 0.1337, 0.2289, 0.3663] | [0.4110, 0.1337, 0.3583, 0.2126] | 0.154 |
| sat | [0.2718, 0.2718, 0.2289, 0.3663] | [0.2718, 0.2718, 0.2491, 0.3764] | 0.020 |
| on | [0.3000, 0.3000, 0.1799, 0.4579] | [0.3000, 0.3000, 0.2718, 0.2718] | 0.186 |
| mat | [0.2491, 0.3764, 0.2289, 0.3663] | [0.2491, 0.3764, 0.3583, 0.2126] | 0.154 |
Dims 0–1 (Head 0) are identical in both configurations. Dims 2–3 (Head 1) differ because Head 1 sees different keys and values. The maximum change is 0.186 for “on” in dimension 3. In our toy example with $H = 2$, there is no middle ground: $G$ can only be 1 or 2. The real power of GQA emerges with $H = 64$, where $G$ can be 2, 4, 8, 16, or 32.
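The Max Δ column can be recomputed from the two output columns of the comparison table. The sketch below copies the table values verbatim; the variable names are illustrative.

```python
out_g2 = {  # G = 2 (MHA) outputs, copied from the comparison table
    "The": [0.2491, 0.3764, 0.2289, 0.3663],
    "cat": [0.4110, 0.1337, 0.2289, 0.3663],
    "sat": [0.2718, 0.2718, 0.2289, 0.3663],
    "on": [0.3000, 0.3000, 0.1799, 0.4579],
    "mat": [0.2491, 0.3764, 0.2289, 0.3663],
}
out_g1 = {  # G = 1 (MQA) outputs, copied from the comparison table
    "The": [0.2491, 0.3764, 0.2491, 0.3764],
    "cat": [0.4110, 0.1337, 0.3583, 0.2126],
    "sat": [0.2718, 0.2718, 0.2491, 0.3764],
    "on": [0.3000, 0.3000, 0.2718, 0.2718],
    "mat": [0.2491, 0.3764, 0.3583, 0.2126],
}

max_delta = {}
for token in out_g2:
    diffs = [abs(a - b) for a, b in zip(out_g2[token], out_g1[token])]
    worst_dim = max(range(len(diffs)), key=diffs.__getitem__)
    max_delta[token] = (round(max(diffs), 3), worst_dim)
    print(token, max_delta[token])
```

Every maximum falls in dimension 2 or 3, which confirms that only Head 1's half of the output moves when the KV groups are merged.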
Full Attention Weight Matrices
Averaged Weights — GQA G = 2 (MHA)
| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1287 | 0.2610 | 0.1923 | 0.1974 | 0.2206 |
| cat | 0.3188 | 0.1114 | 0.2501 | 0.1801 | 0.1397 |
| sat | 0.1575 | 0.2262 | 0.2505 | 0.1802 | 0.1858 |
| on | 0.1906 | 0.1906 | 0.1447 | 0.2837 | 0.1906 |
| mat | 0.1974 | 0.1923 | 0.1923 | 0.1974 | 0.2206 |
Full Output Matrices
Output — GQA G = 2 (MHA)
| | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.2491 | 0.3764 | 0.2289 | 0.3663 |
| cat | 0.4110 | 0.1337 | 0.2289 | 0.3663 |
| sat | 0.2718 | 0.2718 | 0.2289 | 0.3663 |
| on | 0.3000 | 0.3000 | 0.1799 | 0.4579 |
| mat | 0.2491 | 0.3764 | 0.2289 | 0.3663 |
Output — GQA G = 1 (MQA)
| | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.2491 | 0.3764 | 0.2491 | 0.3764 |
| cat | 0.4110 | 0.1337 | 0.3583 | 0.2126 |
| sat | 0.2718 | 0.2718 | 0.2491 | 0.3764 |
| on | 0.3000 | 0.3000 | 0.2718 | 0.2718 |
| mat | 0.2491 | 0.3764 | 0.3583 | 0.2126 |
Applications Across Domains
| Domain | Model | Config | Why GQA? |
|---|---|---|---|
| NLP | LLaMA 3 70B | H=64, G=8 | 128K context with manageable KV-cache |
| NLP | Mistral 7B | H=32, G=8 | Fast inference for 7B-class model |
| NLP | Gemma 2 27B | H=32, G=16 | High quality with moderate compression |
| Code | StarCoder 2 3B | H=24, G=2 | Long code context (16K+) with fast completion |
| Vision | PaliGemma | H=8, G=1 (Gemma 2B backbone, the MQA endpoint) | High-resolution image understanding |
| Multi-modal | LLaVA-Next | GQA backbone | Efficient vision-language inference |
GQA has become the de facto standard for most new large open-weight models. The main exceptions are very small models (under roughly 1B parameters), which often keep MHA because the absolute KV-cache savings are small at that scale.
Connection to Modern Systems
- Flash Attention + GQA: Flash Attention (Chapter 13) is an IO-aware implementation trick that does not change the mathematical result. It composes perfectly with GQA — the reduced number of KV heads means fewer memory reads during tiled computation, amplifying Flash Attention's speed benefits. All major inference engines (vLLM, TGI, TensorRT-LLM) implement this combination.
- Paged Attention + GQA: vLLM's PagedAttention manages KV-cache as virtual memory pages. With GQA, the page tables are smaller, enabling higher batch sizes and better GPU utilization for serving.
- KV-Cache Quantization: GQA reduces the number of KV heads that need to be stored, and each KV head can be further compressed via INT8 or INT4 quantization. The two savings multiply: a factor of $H/G$ from GQA times $2\times$ (INT8) or $4\times$ (INT4) from quantization, relative to MHA in FP16.
- Speculative Decoding: Draft models in speculative decoding pipelines often use aggressive GQA (a very small $G$) because inference speed matters more than quality for the draft; the verifier catches errors.
- Positional Encodings (Chapters 7–9): RoPE rotates the Q and K vectors, while ALiBi adds a per-head bias to the attention scores. In GQA, the RoPE rotation is applied to each query head independently but to each KV group only once, so fewer key rotations are needed; the ALiBi bias stays per query head. This is a minor efficiency gain but matters at scale.
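The way GQA and KV-cache quantization compound can be made concrete with a small calculation. This is a back-of-the-envelope sketch: the function name is hypothetical, the baseline is MHA stored in 16-bit values, and the extra storage for quantization scales and zero-points is ignored.

```python
def kv_compression_vs_mha_fp16(n_heads, n_groups, kv_bits):
    """Combined KV-cache shrink factor relative to MHA stored in FP16."""
    head_factor = n_heads / n_groups  # fewer KV heads under GQA
    bit_factor = 16 / kv_bits         # fewer bits per stored value
    return head_factor * bit_factor


factors = {bits: kv_compression_vs_mha_fp16(64, 8, bits) for bits in (16, 8, 4)}
print(factors)  # {16: 8.0, 8: 16.0, 4: 32.0}
```

For the $H = 64$, $G = 8$ configuration this gives 8× from GQA alone, 16× with INT8, and 32× with INT4.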
Complexity Analysis
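| Metric | MHA | GQA | MQA |
|---|---|---|---|
| Compute (FLOPs) | $O(N^2 \cdot d_{model})$ | $O(N^2 \cdot d_{model})$ | $O(N^2 \cdot d_{model})$ |
| KV-Cache Memory | $2 \cdot H \cdot N \cdot d_k$ | $2 \cdot G \cdot N \cdot d_k$ | $2 \cdot N \cdot d_k$ |
| KV Projection Params | $2 \cdot d_{model} \cdot H \cdot d_k$ | $2 \cdot d_{model} \cdot G \cdot d_k$ | $2 \cdot d_{model} \cdot d_k$ |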
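Note that the attention compute (scores and weighted sums over all $H$ query heads) is identical across all three: GQA does not reduce attention FLOPs, and only the comparatively small K/V projection cost shrinks. The savings come almost entirely from memory bandwidth: less data needs to be loaded from GPU memory during each attention operation, which translates to higher throughput during inference.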
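One way to see the bandwidth argument is to compare FLOPs with KV-cache bytes read for a single decode step in one layer. The sketch below is a simplified cost model, not a profiler measurement: it counts only the $QK^\top$ and weights-times-$V$ multiply-adds (2 FLOPs each), counts only KV-cache reads, and assumes FP16 storage; the function name and the chosen shape are illustrative.

```python
def decode_step_cost(n_heads, n_groups, seq_len, d_k, bytes_per_value=2):
    """FLOPs and KV-cache bytes read to attend one new token in one layer."""
    # Scores (q . k) and weighted sum (w . v): 2 FLOPs per multiply-add, per head.
    flops = 4 * n_heads * seq_len * d_k
    # Each KV group's keys and values are read once.
    kv_bytes = 2 * n_groups * seq_len * d_k * bytes_per_value
    return flops, kv_bytes


intensity = {}
for name, g in (("MHA", 64), ("GQA", 8), ("MQA", 1)):
    flops, kv_bytes = decode_step_cost(64, g, seq_len=4096, d_k=128)
    intensity[name] = flops / kv_bytes
    print(f"{name}: flops={flops} kv_bytes={kv_bytes} flops/byte={intensity[name]}")
```

The FLOP count is the same in all three rows while the bytes read fall with $G$, so the FLOPs-per-byte ratio rises from 1 (MHA) to 8 (GQA with $G = 8$) to 64 (MQA) in this model.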
Python Implementation
The full GQA class with the shared example. The _group_index method is the only addition beyond standard multi-head attention — one line of integer division controls the entire KV sharing behavior.
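The following is a minimal NumPy sketch of such a class, written for this section rather than taken from the original listing. It assumes Q, K, and V have already been projected, slices K and V by group in the same K[:,0:2] / K[:,2:4] style as the worked example, and omits the output projection. Because the shared matrices are not reproduced in this section, the usage lines fall back on random inputs.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along one axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


class GroupedQueryAttention:
    """GQA over pre-projected Q, K, V (sketch; no output projection)."""

    def __init__(self, n_heads, n_groups, d_k):
        if n_heads % n_groups != 0:
            raise ValueError("n_groups must divide n_heads")
        self.n_heads, self.n_groups, self.d_k = n_heads, n_groups, d_k

    def _group_index(self, head):
        # The one line that controls KV sharing.
        return head // (self.n_heads // self.n_groups)

    def forward(self, Q, K, V):
        # Q: (N, n_heads * d_k); K, V: at least n_groups * d_k columns.
        d = self.d_k
        outputs = []
        for i in range(self.n_heads):
            g = self._group_index(i)
            q = Q[:, i * d:(i + 1) * d]  # head-specific query slice
            k = K[:, g * d:(g + 1) * d]  # group-shared key slice
            v = V[:, g * d:(g + 1) * d]  # group-shared value slice
            weights = softmax(q @ k.T / np.sqrt(d))
            outputs.append(weights @ v)
        return np.concatenate(outputs, axis=-1)


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((5, 4)) for _ in range(3))
out_g2 = GroupedQueryAttention(n_heads=2, n_groups=2, d_k=2).forward(Q, K, V)
out_g1 = GroupedQueryAttention(n_heads=2, n_groups=1, d_k=2).forward(Q, K, V)
print(np.allclose(out_g2[:, :2], out_g1[:, :2]))  # True: Head 0 is unaffected
```

The explicit per-head loop keeps the correspondence with the hand calculation visible; a production version would batch the heads instead.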
PyTorch Implementation
The PyTorch version is structurally identical to the NumPy version. Key differences: torch.matmul replaces @ for clarity, F.softmax replaces our manual implementation, and an optional mask parameter supports causal attention (Chapter 3).
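A PyTorch counterpart might look like the sketch below. It is an illustration under assumptions, not the original listing: the class name `GQA`, the bias-free linear layers, and the boolean mask convention (True = may attend) are choices made for this example. K and V are projected to $G$ heads and expanded with `repeat_interleave` so that head $i$ reads group $\lfloor i / (H/G) \rfloor$.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class GQA(nn.Module):
    """Grouped-query attention layer (sketch with learned projections)."""

    def __init__(self, d_model, n_heads, n_groups):
        super().__init__()
        if d_model % n_heads != 0 or n_heads % n_groups != 0:
            raise ValueError("need n_heads | d_model and n_groups | n_heads")
        self.n_heads, self.n_groups = n_heads, n_groups
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        self.w_k = nn.Linear(d_model, n_groups * self.d_k, bias=False)  # G KV heads only
        self.w_v = nn.Linear(d_model, n_groups * self.d_k, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        b, n, _ = x.shape
        q = self.w_q(x).view(b, n, self.n_heads, self.d_k).transpose(1, 2)   # (b, H, n, d_k)
        k = self.w_k(x).view(b, n, self.n_groups, self.d_k).transpose(1, 2)  # (b, G, n, d_k)
        v = self.w_v(x).view(b, n, self.n_groups, self.d_k).transpose(1, 2)
        rep = self.n_heads // self.n_groups
        k = k.repeat_interleave(rep, dim=1)  # head i now reads group i // rep
        v = v.repeat_interleave(rep, dim=1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask: bool tensor broadcastable to (b, H, n, n); True = may attend.
            scores = scores.masked_fill(~mask, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        out = torch.matmul(weights, v)  # (b, H, n, d_k)
        out = out.transpose(1, 2).reshape(b, n, self.n_heads * self.d_k)
        return self.w_o(out)


torch.manual_seed(0)
layer = GQA(d_model=16, n_heads=4, n_groups=2)
x = torch.randn(2, 5, 16)
causal = torch.tril(torch.ones(5, 5, dtype=torch.bool))
y = layer(x, mask=causal)
print(y.shape)  # torch.Size([2, 5, 16])
```

Expanding K and V with `repeat_interleave` is the simplest way to show the sharing, although it materializes copies; only the $G$ unexpanded heads would need to be stored in a KV-cache.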
Key Takeaways
- GQA is a generalization. It places MHA ($G = H$) and MQA ($G = 1$) as endpoints of a spectrum controlled by a single parameter $G$.
- The math is one line. The entire mechanism is controlled by $g(i) = \lfloor i / (H/G) \rfloor$, a floor division that maps each query head to its KV group.
- KV-cache scales with G, not H. Choosing $G = 8$ with $H = 64$ gives 87.5% memory savings with little reported quality loss in practice.
- Uptrain, don't retrain. Existing MHA models can be converted to GQA by mean-pooling KV heads within groups and fine-tuning with ~5% of original compute.
- It's the industry standard. LLaMA 3, Mistral, Gemma 2, StarCoder 2, and many other recent production models use GQA.
- Composes with everything. GQA works seamlessly with Flash Attention, RoPE, ALiBi, PagedAttention, and KV-cache quantization.
Exercises
Exercise 1: Compute for “cat”
Using the shared example with $G = 2$, compute the full GQA output for “cat” (row 1). Then repeat with $G = 1$ and compare the two outputs. Which dimensions change and why?
Exercise 2: KV-Cache Budget
You have a 128K context model with heads, , 40 layers, FP16. Your GPU has 24 GB of memory and 8 GB is available for KV-cache. What is the maximum you can use? What is the minimum?
Exercise 3: Prove the Spectrum
Prove formally that GQA with $G = H$ is mathematically identical to MHA, and that GQA with $G = 1$ is identical to MQA. Start from the group assignment function and show that the KV slicing reduces to the expected behavior in each case.
Exercise 4: Uptrain Design
You have a trained MHA model with and want to convert it to GQA with . (a) How many heads are in each group? (b) Write pseudocode for mean-pooling the matrices. (c) If the original training took 1000 GPU-hours, how many GPU-hours would the uptrain take (approximately)?
Exercise 5: Flash Attention Synergy
Explain why GQA improves Flash Attention's performance beyond just reducing KV-cache size. Hint: think about memory bandwidth and the ratio of computation to memory access in the tiled attention kernel.
References
- Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023. arXiv:2305.13245
- Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). Attention Is All You Need. NeurIPS 2017. arXiv:1706.03762
- Touvron, H., et al. (2023). LLaMA 2: Open Foundation and Fine-Tuned Chat Models. arXiv:2307.09288
- Meta AI (2024). The LLaMA 3 Herd of Models. arXiv:2407.21783
- Jiang, A. Q., et al. (2023). Mistral 7B. arXiv:2310.06825
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691
- Kwon, W., et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023. arXiv:2309.06180