Chapter 4

The KV Cache Bottleneck

Multi-Head Latent Attention (MLA)

The Real Problem: Generation Is Sequential

A transformer is famously parallel during training. Feed it a sequence of length S, and every token attends to every earlier token (the causal mask hides the future) in one big matrix multiply. The whole forward pass happens at once on the GPU.

At inference, the picture inverts. To answer a user, the model must produce one token, then read what it just produced, then produce the next, and so on. Generation is autoregressive: token $t+1$ depends on tokens $1 \dots t$. The model cannot leap forward, because it does not yet know what it is going to say.

Now ask the naive question: at step $t+1$, what does attention actually need? For every past token $i \le t$, attention multiplies the new token's query $q_{t+1}$ against the past token's key $k_i$, then weights the past token's value $v_i$ by the resulting (softmax-normalised) score. The keys and values $k_1, \dots, k_t$ and $v_1, \dots, v_t$ were already computed on earlier steps, and they do not change. Yet a naive implementation, which runs a fresh forward pass over the whole sequence for each new token, recomputes all of them from scratch.

The waste: a naive decoder generating S tokens performs $1 + 2 + 3 + \dots + S = O(S^2)$ key/value projections, when $O(S)$ would have been enough. At S = 32k, that is the difference between billions of redundant projections and none.
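The count behind that claim is an arithmetic series. Within one layer, and separately for K and for V, a naive decoder projects $t$ rows at step $t$ (counting steps from 1):

$$
\sum_{t=1}^{S} t = \frac{S(S+1)}{2}, \qquad S = 32{,}768 \;\Rightarrow\; \frac{32{,}768 \cdot 32{,}769}{2} = 536{,}887{,}296.
$$

That comes to about $5.4 \times 10^{8}$ key projections per layer and the same number again for values. A cached decoder does $32{,}768$ of each. Counting K and V together, the naive total already exceeds a billion per layer before you multiply by the number of layers.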

The fix is brutally simple in one sentence: store the K and V vectors of every past token in a buffer, and reuse them at every future step. That buffer is the KV cache. It brings the projection work of generation back down to $O(S)$. It costs memory. And at the scale of frontier models, that memory cost becomes the dominant constraint on inference systems.

Why this matters: the entire Multi-Head Latent Attention architecture you will study in this chapter exists to compress this buffer. Before MLA can make sense, the bottleneck has to feel real. That is the job of this section.

The Intuition: A Notebook the Model Refuses to Reread

Picture a person taking an exam where each question depends on every previous answer. Two strategies:

  1. The naive student. Before answering question 100, they re-read questions 1 through 99 and re-solve them in their head, just to recompute the context they need. Question 1000 takes ten times longer than question 100.
  2. The student with a notebook. They write down their key and value for every past question as they go. Before question 100 they glance at the notebook. Question 1000 still takes one glance per past answer to attend over them, but they never re-solve.

The KV cache is the notebook. Two things are obvious from the analogy:

  • The notebook saves enormous amounts of redundant work.
  • The notebook grows with every question. By question 10,000 it is a small book.

Take a 70B model with 80 layers and 64 heads of dimension 128, using full multi-head attention and fp16. For a single 8k-token conversation, its "notebook" is about 20 GiB. A 32k-token conversation needs about 80 GiB. Eight 8k-token users in a batch turn that into 160 GiB, which is already more memory than a single GPU has, before the model weights even load. That is the bottleneck.
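Assume fp16 storage (2 bytes per scalar) and one K and one V vector per head per layer. The per-token cost of that configuration then works out as follows. The factors are, in order, K and V, layers, heads, head dimension, and bytes per scalar:

$$
2 \cdot 80 \cdot 64 \cdot 128 \cdot 2 = 2{,}621{,}440 \text{ bytes} = 2.5\ \text{MiB per token}
$$

$$
2.5\ \text{MiB} \times 8{,}192 = 20\ \text{GiB}, \qquad 2.5\ \text{MiB} \times 32{,}768 = 80\ \text{GiB}.
$$

The next section derives this product in general.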


The Mathematics of the Cache

Recall single-head causal self-attention. Given the input embedding $x_t \in \mathbb{R}^{d_{\text{model}}}$ for token $t$, we form $q_t = W_q x_t$, $k_t = W_k x_t$, and $v_t = W_v x_t$, each a vector in $\mathbb{R}^{d_h}$. The output for the new token $t$ is the familiar

$$
o_t = \sum_{i=1}^{t} \mathrm{softmax}_i\!\left( \frac{q_t \cdot k_i}{\sqrt{d_h}} \right) v_i .
$$

Read it slowly. The sum is over past indices $i = 1, \dots, t$. The pairs $(k_i, v_i)$ for $i < t$ were produced on earlier steps, and the model has no reason to recompute them. Caching them is not a hack: the math literally asks for them.

With multiple heads ($h = 1, \dots, n_h$) and multiple layers ($\ell = 1, \dots, L$), every token writes $2 \cdot L \cdot n_h$ vectors of dimension $d_h$ into the cache: one K for each (layer, head) and one V for each (layer, head). For a whole batch of $B$ sequences of length $S$, the cache stores

$$
N_{\text{cache}} = 2 \cdot L \cdot n_h \cdot d_h \cdot S \cdot B
$$

scalars in total, where:

  • $2$: one tensor for K, one for V.
  • $L$: number of transformer layers. Each layer has its own attention, so each has its own cache.
  • $n_h$: number of attention heads per layer.
  • $d_h$: dimension of each head, typically $d_{\text{model}} / n_h$.
  • $S$: sequence length (prompt plus tokens generated so far).
  • $B$: number of concurrent sequences in the batch.

The Memory Equation

Multiply by the byte size of one scalar ($b$ bytes: 4 for fp32, 2 for fp16/bf16, 1 for fp8) to get bytes:

$$
M_{\text{cache}} = 2 \cdot L \cdot n_h \cdot d_h \cdot S \cdot B \cdot b \quad \text{bytes}.
$$

Every factor matters. Make any of them bigger and the cache grows linearly with it. Make all of them bigger — which is exactly what the trend in frontier models does — and the cache explodes.
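The memory equation translates directly into a few lines of Python. This is a minimal sketch. The function name `kv_cache_bytes` is a hypothetical choice, and so is passing `bytes_per_scalar` as a float so that 0.5 can stand in for a packed 4-bit format. The usage lines evaluate the 70B-style configuration (80 layers, 64 heads of dimension 128, fp16) at 8k tokens.

```python
def kv_cache_bytes(n_layers: int, n_heads: int, d_head: int,
                   seq_len: int, batch: int, bytes_per_scalar: float) -> float:
    """M_cache = 2 * L * n_h * d_h * S * B * b.

    The leading 2 counts one K and one V tensor per (layer, head).
    bytes_per_scalar: 4 (fp32), 2 (fp16/bf16), 1 (fp8), 0.5 (4-bit, assumed packed).
    """
    return 2 * n_layers * n_heads * d_head * seq_len * batch * bytes_per_scalar


GIB = 1024 ** 3

if __name__ == "__main__":
    # 70B-style config: 80 layers, 64 heads of dimension 128, fp16.
    print(kv_cache_bytes(80, 64, 128, seq_len=1, batch=1, bytes_per_scalar=2))           # 2621440
    print(kv_cache_bytes(80, 64, 128, seq_len=8192, batch=1, bytes_per_scalar=2) / GIB)  # 20.0
    print(kv_cache_bytes(80, 64, 128, seq_len=8192, batch=8, bytes_per_scalar=2) / GIB)  # 160.0
```

Because every factor enters as a plain product, halving any one of them (for example `bytes_per_scalar` from 2 to 1) halves the result.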


Manual Numerical Walkthrough

Plug in numbers before you trust the formula. Use a Llama-2 7B style configuration and one long-context conversation, then scale to a real serving batch.
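Here is a worked pass. It assumes the Llama-2 7B shape ($L = 32$, $n_h = 32$, $d_h = 128$, i.e. $d_{\text{model}} = 4096$ with full multi-head attention) and fp16 storage ($b = 2$). One token costs:

$$
2 \cdot 32 \cdot 32 \cdot 128 \cdot 2 = 524{,}288 \text{ bytes} = 0.5\ \text{MiB per token}
$$

$$
\begin{aligned}
S = 4{,}096,\ B = 1 &: \quad 0.5\ \text{MiB} \times 4{,}096 = 2\ \text{GiB} \\
S = 32{,}768,\ B = 1 &: \quad 0.5\ \text{MiB} \times 32{,}768 = 16\ \text{GiB} \\
S = 4{,}096,\ B = 32 &: \quad 2\ \text{GiB} \times 32 = 64\ \text{GiB}
\end{aligned}
$$

The three rows cover:

  • one conversation at Llama-2's native 4k context;
  • a long-context conversation;
  • a modest serving batch.

For scale, 7 billion parameters in fp16 take about 14 GB, so at a batch of 32 the cache is already several times the size of the weights.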


KV Cache Memory Explorer

Drive the equation yourself: change one factor at a time and watch where each axis bites. Layers and heads are the multiplicative constants, and sequence length is the linear ramp. Batch size and precision are the brutal multipliers that determine whether you fit on one GPU.

Three observations worth pausing on:

  1. fp16 → fp8 halves the cache. This is why quantising the KV cache is one of the highest-leverage inference optimisations. Cache memory is the bottleneck; weights have already been compressed.
  2. Doubling batch size doubles the cache. Throughput-vs-latency tradeoffs in production serving are largely cache-memory tradeoffs.
  3. Long context is the dominant axis. 4k → 32k means 8× the cache. 32k → 128k means another 4×. The march toward million-token contexts is a direct collision course with this equation.
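The three observations can be checked numerically. This is a minimal sketch with an assumed Llama-2 7B-style baseline: 32 layers, 32 heads of dimension 128, 4k context, batch 1, fp16. `BASE` and `ratio` are hypothetical names chosen for this illustration. Each call overrides one or more factors and reports the cache size relative to the baseline.

```python
def kv_cache_bytes(L, n_h, d_h, S, B, b):
    # Same formula as M_cache: 2 * L * n_h * d_h * S * B * b
    return 2 * L * n_h * d_h * S * B * b


# Illustrative baseline (assumed): Llama-2 7B shape, 4k context, one sequence, fp16.
BASE = dict(L=32, n_h=32, d_h=128, S=4096, B=1, b=2)


def ratio(**changes):
    """Cache size after overriding some factors, relative to BASE."""
    return kv_cache_bytes(**{**BASE, **changes}) / kv_cache_bytes(**BASE)


if __name__ == "__main__":
    print("fp16 -> fp8:", ratio(b=1))                            # 0.5
    print("batch 1 -> 2:", ratio(B=2))                           # 2.0
    print("4k -> 32k:", ratio(S=32768))                          # 8.0
    print("32k -> 128k:", ratio(S=131072) / ratio(S=32768))      # 4.0
```

Because the cache is a pure product, every ratio is independent of the baseline chosen. The same 8× and 4× hold for a 70B model.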

Plain Python Implementation

Before PyTorch, write the cache in raw numpy so the mechanism is naked. The naive version is deliberately ugly — that ugliness is the cost the cached version eliminates.

Naive vs cached self-attention (numpy): kv_cache_naive.py

attention_naive, the baseline with no cache. This is the version that wastes work: at every step it recomputes K and V for every token in the prefix, not just the new one.

  • Shape of the input. Assume a toy example with 4 tokens and d_model = 8, so tokens is a (4, 8) array and seq_len = 4.
  • Step loop. We generate one token at a time, as a real decoder does at inference; t runs from 0 to seq_len - 1.
  • The whole prefix. We slice the entire history up to and including token t. At t = 3 this is 4 rows, a (4, 8) matrix. We multiply all of it by W_k and W_v even though rows 0..t-1 have not changed.
  • Query projection. We compute Q for every token in the prefix. Only the last row (the current token) is needed for the new output, but the naive version pays for all of them.
  • K recomputed (the waste). Here is the disaster: we recompute every K vector we already computed on previous steps. The cost per step grows linearly, so total work is quadratic in sequence length. Example: at t = 3 we redo K for tokens 0, 1, and 2, work already done on earlier steps.
  • V recomputed (the waste). Same story for V: two large matmuls per step, both growing.
  • Attention scores. Q[-1:] is the new token's query row, shape (1, d_h), and K.T is (d_h, t+1). Their product is one row of scores: how much the new token attends to each token in the prefix.
  • Softmax. Turns the raw scores into a probability distribution over the prefix: one row of length t+1.
  • Weighted average of V. The attention output for token t, shape (1, d_h). Then we move on to step t+1.

attention_with_kv_cache, the fix. Same math, but K and V are never recomputed for past tokens; each step appends exactly one row.

  • Empty caches. K_cache and V_cache start with zero rows and grow by one row per step; each is (current_len, d_h).
  • Pick only the new token. x_t is a (1, d_model) slice holding just the current token, unlike the naive version, which grabs the whole prefix.
  • One Q for the current step. q_t is (1, d_h): one row, which is all we need for this step's output.
  • One K row. We compute the K vector for the new token only: constant projection work per step.
  • One V row. Likewise, one V vector for the new token only.
  • Append to cache. K_cache becomes (t+1, d_h). The first t rows hold the values computed on earlier steps. np.vstack copies them into a new array but never recomputes them.
  • Attend over the whole cache. We still attend to every past token; the cache lets us do so without redoing the projections. q_t @ K_cache.T is (1, t+1).
  • Weighted average. Same output shape as the naive version and the same result up to floating-point rounding. The only thing that changed is how much we computed to get here.
```python
import numpy as np


def attention_naive(tokens, W_q, W_k, W_v):
    """Recompute K, V for the entire prefix at every step."""
    seq_len, d_model = tokens.shape
    d_h = W_k.shape[1]
    outputs = []
    for t in range(seq_len):
        prefix = tokens[: t + 1]                 # (t+1, d_model)
        Q = prefix @ W_q                         # (t+1, d_h)
        K = prefix @ W_k                         # (t+1, d_h)  <- recomputed!
        V = prefix @ W_v                         # (t+1, d_h)  <- recomputed!
        scores = Q[-1:] @ K.T / np.sqrt(d_h)     # (1, t+1), scaled by 1/sqrt(d_h)
        weights = softmax(scores)
        outputs.append(weights @ V)              # (1, d_h)
    return np.vstack(outputs)


def attention_with_kv_cache(tokens, W_q, W_k, W_v):
    """Keep K, V from previous steps; only compute the new token's row."""
    seq_len, d_model = tokens.shape
    d_h = W_k.shape[1]
    K_cache = np.zeros((0, d_h))
    V_cache = np.zeros((0, d_h))
    outputs = []
    for t in range(seq_len):
        x_t = tokens[t : t + 1]                  # (1, d_model)
        q_t = x_t @ W_q                          # (1, d_h)
        k_t = x_t @ W_k                          # (1, d_h)
        v_t = x_t @ W_v                          # (1, d_h)
        K_cache = np.vstack([K_cache, k_t])      # grow cache by one row
        V_cache = np.vstack([V_cache, v_t])
        scores = q_t @ K_cache.T / np.sqrt(d_h)  # (1, t+1)
        weights = softmax(scores)
        outputs.append(weights @ V_cache)        # (1, d_h)
    return np.vstack(outputs)


def softmax(x):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


if __name__ == "__main__":
    # Toy check: 4 tokens, d_model = 8, d_h = 8; both versions must agree.
    rng = np.random.default_rng(0)
    tokens = rng.standard_normal((4, 8))
    W_q, W_k, W_v = (rng.standard_normal((8, 8)) for _ in range(3))
    assert np.allclose(attention_naive(tokens, W_q, W_k, W_v),
                       attention_with_kv_cache(tokens, W_q, W_k, W_v))
```

Run both on the same input and the outputs agree up to floating-point rounding. The difference is entirely in the work spent to get there: the naive version performs $O(S^2)$ projections, the cached version $O(S)$. For S = 1024 that is roughly a 500× reduction in projection work. The attention scores themselves still cost $O(S^2)$ in both versions, since each new query is compared against every earlier key; the cache removes only the redundant projections.
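The "roughly 500×" figure can be checked by counting projected rows instead of timing anything. In this minimal sketch, `projected_rows` is a hypothetical helper that mirrors the two loops. The naive loop projects the whole prefix ($t+1$ rows) at step $t$, while the cached loop projects only the new row. The same ratio applies to each of W_q, W_k, and W_v.

```python
def projected_rows(seq_len: int, cached: bool) -> int:
    """Count token rows multiplied by one projection matrix over a full generation."""
    rows = 0
    for t in range(seq_len):
        # Naive: re-project the whole prefix (t + 1 rows). Cached: only the new row.
        rows += 1 if cached else t + 1
    return rows


if __name__ == "__main__":
    S = 1024
    naive = projected_rows(S, cached=False)
    cached = projected_rows(S, cached=True)
    print(naive, cached, naive / cached)   # 524800 1024 512.5
```

The ratio is exactly $(S+1)/2$, which is 512.5 for S = 1024.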


PyTorch Implementation

Real systems do this inside an nn.Module with tensors of shape $(B, n_h, T, d_h)$. The cache lives across forward calls and is concatenated along the sequence axis. Hugging Face Transformers, vLLM, and TensorRT-LLM all build on this idea. The production engines, however, use different memory layouts, as the practical note below explains.

CachedSelfAttention in PyTorch: cached_attention.py

  • Module skeleton. We subclass nn.Module so PyTorch can collect parameters and run autograd. d_model is the hidden size, n_heads the number of attention heads.
  • Per-head dimension. Each head works on a slice of size d_h = d_model / n_heads. For Llama-2 7B that is 4096 / 32 = 128.
  • Q, K, V projections. Three bias-free linear layers mapping (B, T, d_model) → (B, T, d_model). These are the parameters that turn token embeddings into queries, keys, and values.
  • Output projection. After attention, W_o mixes the heads back together, as in a standard transformer block.
  • Forward signature. The kv_cache argument is the heart of this lesson. It is either None (the first call, i.e. prefill) or the (K_prev, V_prev) tuple returned by the previous call.
  • Input shape. B is the batch size and T_new is how many new tokens we process in this call. During generation T_new = 1; during prompt prefill it equals the prompt length.
  • Query for new tokens. Project x to Q, reshape to (B, T_new, n_heads, d_h), then transpose so heads come before sequence: (B, n_heads, T_new, d_h). This is the standard multi-head layout.
  • Key for new tokens. k_new contains K vectors only for the new tokens, never for the prefix, with the same shape conventions as Q. These are the rows we are about to add to the cache.
  • Value for new tokens. Same as K but for V. Together, (k_new, v_new) is exactly what gets appended to the cache.
  • First step, no cache yet. On the very first call there is no past, so the cache is just the new K and V. The next call receives this tuple back.
  • Subsequent step, concatenate. k_prev has shape (B, n_heads, T_prev, d_h). After torch.cat with k_new it is (B, n_heads, T_prev + T_new, d_h). dim=2 is the sequence axis. Nothing is recomputed.
  • Scaled dot-product. q has shape (B, n_heads, T_new, d_h) and k.transpose(-2, -1) has shape (B, n_heads, d_h, T_total), so the scores are (B, n_heads, T_new, T_total). During generation T_new = 1 and T_total = T_prev + 1.
  • Softmax over keys. dim=-1 normalises across the T_total key positions. This gives attention weights from each new query to every past and current token.
  • Weighted sum of V. Multiply the attention weights by V to mix the per-token value vectors, then reshape back to (B, T_new, d_model).
  • Return cache. We return both the output and the updated (k, v) tuple. The caller passes this tuple back on the next forward call, so the cache lives across calls.
```python
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_h = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None,
    ):
        # x: (B, T_new, d_model); usually T_new == 1 during generation
        B, T_new, _ = x.shape

        q = self.W_q(x).view(B, T_new, self.n_heads, self.d_h).transpose(1, 2)
        k_new = self.W_k(x).view(B, T_new, self.n_heads, self.d_h).transpose(1, 2)
        v_new = self.W_v(x).view(B, T_new, self.n_heads, self.d_h).transpose(1, 2)

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k_prev, v_prev = kv_cache                        # each (B, n_heads, T_prev, d_h)
            k = torch.cat([k_prev, k_new], dim=2)            # grow along sequence axis
            v = torch.cat([v_prev, v_new], dim=2)

        T_total = k.shape[2]
        scores = (q @ k.transpose(-2, -1)) / self.d_h ** 0.5  # (B, n_heads, T_new, T_total)

        # Causal mask: new token i sits at absolute position T_prev + i and may only
        # attend to keys 0 .. T_prev + i. This matters during prefill (T_new > 1);
        # for a single decode step the mask keeps every position.
        T_prev = T_total - T_new
        causal = torch.ones(T_new, T_total, dtype=torch.bool, device=x.device).tril(diagonal=T_prev)
        scores = scores.masked_fill(~causal, float("-inf"))

        attn = F.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T_new, -1)

        return self.W_o(out), (k, v)
```

A token-by-token generation loop then looks like this in spirit:

$$
\text{cache} \leftarrow \text{None}; \quad \text{for each new token: } (\text{out}, \text{cache}) \leftarrow \text{attn}(x_t, \text{cache}).
$$

The cache object grows by one slice along axis 2 each step. Across all layers, that growth is exactly the memory we have been computing.

Practical detail: production systems pre-allocate the cache to its maximum length and write into a slot, rather than concatenating. vLLM's PagedAttention goes further: it stores the cache in fixed-size pages, the same way an operating system pages virtual memory, so that GPU memory is not fragmented as sequences finish at different times. The algebra is identical; only the allocator changes.
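Here is a minimal sketch of the pre-allocated variant, assuming PyTorch, one layer, and a fixed maximum length. `StaticKVCache` and its `append` method are hypothetical names for illustration, not an API from any of the libraries mentioned above. Instead of calling `torch.cat`, each call writes into the next free slots, and attention reads only the filled prefix.

```python
import torch


class StaticKVCache:
    """Hypothetical pre-allocated KV cache for one attention layer.

    Buffers are allocated once at max_len. Each call writes new rows into the
    next free slots along the sequence axis (dim=2) instead of concatenating.
    """

    def __init__(self, batch: int, n_heads: int, max_len: int, d_h: int,
                 dtype: torch.dtype = torch.float16, device: str = "cpu"):
        shape = (batch, n_heads, max_len, d_h)
        self.k = torch.zeros(shape, dtype=dtype, device=device)
        self.v = torch.zeros(shape, dtype=dtype, device=device)
        self.length = 0  # number of filled positions

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        """Write (B, n_heads, T_new, d_h) tensors; return views of the filled prefix."""
        end = self.length + k_new.shape[2]
        if end > self.k.shape[2]:
            raise ValueError("KV cache is full")
        self.k[:, :, self.length:end] = k_new
        self.v[:, :, self.length:end] = v_new
        self.length = end
        return self.k[:, :, :end], self.v[:, :, :end]


if __name__ == "__main__":
    cache = StaticKVCache(batch=1, n_heads=4, max_len=16, d_h=8, dtype=torch.float32)
    # Prefill with a 5-token prompt, then decode 3 tokens one at a time.
    k, v = cache.append(torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8))
    for _ in range(3):
        k, v = cache.append(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8))
    print(k.shape)        # torch.Size([1, 4, 8, 8])
    print(cache.k.shape)  # torch.Size([1, 4, 16, 8]): memory reserved up front
```

The trade-off is visible in the last line. The full `max_len` buffer is reserved even if a sequence stops early, and that reserved-but-unused memory is what paging is designed to reclaim.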

Why It Becomes Catastrophic at Scale

Five factors move together as models and contexts grow. Four of them push the cache up; the fifth is the one lever that pulls it down.

| Pressure | Direction | Consequence |
|---|---|---|
| More layers $L$ | Deeper models | Linear blow-up of the cache; 80 layers ≈ 2.5× the cache of a 32-layer model at the same width. |
| More heads $n_h$ | More parallel attention paths | Linear blow-up; each head writes its own K and V. |
| Larger context $S$ | 32k → 128k → 1M | The cache is linear in $S$, but $S$ is also the axis users push hardest. The headline driver. |
| Higher batch $B$ | Throughput / cost | The cache is linear in $B$: doubling concurrent users doubles the cache, and throughput stalls when the cache hits the memory wall. |
| Lower precision $b$ | fp16 → fp8 → int4 | The only factor that goes down. Quantising the cache is the high-leverage lever, but quality must not regress. |

The model parameters do not grow with sequence length at all. Activations are transient: they exist for one forward pass and are then freed. The KV cache is the one quantity that is multiplied by every architectural choice you have made and by the user's usage pattern at the same time, and it persists for the whole conversation. That is why it dominates.

Reality check from production. A common way for a serving system to fall over is that the KV cache outgrows the GPU's HBM and the scheduler has to start preempting or evicting sessions. The model weights are static; activations are transient; the cache is the one resource that grows in real time as users converse.

Engineering Reality: What Production Systems Do

There are four families of solutions, listed roughly in the order they were invented. The rest of this chapter is the story of the fourth, DeepSeek's MLA, but in practice production stacks combine several of them:

  1. Shrink B or S (give up). Smaller batches, shorter contexts. This works but throws away product capability. Used as a last resort.
  2. Share K and V across heads — GQA / MQA. The next section. Group several query heads to share a single K, V pair. Cuts the cache by the grouping factor, but caps how aggressively you can compress.
  3. Quantise the cache. fp16 → fp8 → int4. Halves or quarters the cache for essentially free, as long as the model tolerates the precision loss. Now standard.
  4. Compress what is stored — MLA. Project K and V into a small shared latent space, cache only that latent, and reconstruct K and V on the fly. The remaining sections of this chapter derive this from scratch and show why DeepSeek V3 can hold a 128k context with a cache that fits where a 70B's 8k cache used to live.
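Item 3 is easy to see in miniature. The sketch below assumes symmetric int8 quantisation in numpy, with one fp32 scale per cached vector. `quantize_rows` and `dequantize_rows` are hypothetical helpers. Real engines use their own formats (fp8 types, per-channel or per-block scales), so this shows only the memory arithmetic and the size of the rounding error.

```python
import numpy as np


def quantize_rows(x: np.ndarray):
    """Symmetric int8 quantisation with one fp32 scale per row (one cached vector)."""
    scale = (np.abs(x).max(axis=-1, keepdims=True) / 127.0).astype(np.float32)
    scale[scale == 0] = 1.0                    # all-zero rows: avoid division by zero
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize_rows(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Map int8 codes back to float32 using the stored per-row scales."""
    return q.astype(np.float32) * scale


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    K = rng.standard_normal((4096, 128)).astype(np.float16)   # one head's K cache in fp16
    q, scale = quantize_rows(K.astype(np.float32))
    print(K.nbytes, q.nbytes + scale.nbytes)                  # 1048576 540672
    err = np.abs(dequantize_rows(q, scale) - K.astype(np.float32)).max()
    print(f"max abs error: {err:.4f}")
```

The int8 codes plus their scales take 540,672 bytes, against 1,048,576 for fp16: just over half. The per-row scales are the small overhead on top of the clean 2×.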

Adjacent to the cache itself, two systems-level techniques squeeze more out of whatever cache you have:

  • PagedAttention (vLLM). Manages the cache as pages, eliminating internal fragmentation so you can pack more concurrent sequences into the same memory.
  • Prefix sharing. When many users share the same system prompt, store one copy of its KV cache and reuse it for all of them. Saves billions of bytes in chatbot workloads.
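The paging idea can be sketched without any GPU code. Below is a toy allocator that assumes fixed 16-token blocks and one shared free list. `PagedKVAllocator`, `BLOCK_SIZE`, and the block-table layout are hypothetical illustrations of the concept, not vLLM's actual data structures. Each sequence keeps a block table that maps its logical blocks to physical blocks. Memory is claimed one block at a time and returned the moment a sequence finishes.

```python
BLOCK_SIZE = 16  # tokens per physical block (hypothetical choice)


class PagedKVAllocator:
    """Toy page-table allocator for KV cache blocks (bookkeeping only, no tensors)."""

    def __init__(self, num_blocks: int):
        self.free = list(range(num_blocks))           # physical block ids not in use
        self.block_tables: dict[int, list[int]] = {}  # seq_id -> physical block ids
        self.lengths: dict[int, int] = {}             # seq_id -> tokens stored

    def append_token(self, seq_id: int) -> tuple[int, int]:
        """Reserve space for one new token; return (physical_block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % BLOCK_SIZE == 0:                       # current block full (or none yet)
            if not self.free:
                raise MemoryError("no free KV blocks: scheduler must preempt")
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1
        return table[n // BLOCK_SIZE], n % BLOCK_SIZE

    def release(self, seq_id: int) -> None:
        """Return all of a finished sequence's blocks to the free list."""
        self.free.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


if __name__ == "__main__":
    alloc = PagedKVAllocator(num_blocks=4)
    for _ in range(20):
        alloc.append_token(seq_id=0)                  # 20 tokens -> 2 blocks
    alloc.append_token(seq_id=1)                      # 1 token -> 1 block
    print(len(alloc.block_tables[0]), len(alloc.free))  # 2 1
    alloc.release(0)
    print(len(alloc.free))                              # 3
```

At most one partially filled block per sequence is ever wasted. Compare the pre-allocated buffer, which reserves the full maximum length for every sequence.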

Summary and Bridge to MLA

Three things to carry into the rest of the chapter:

  1. The KV cache is a direct mathematical consequence of causal attention. It is not a hack — the attention sum literally calls for the past tokens' K and V vectors, and storing them once is the only sane option.
  2. Its size is $2 L n_h d_h S B b$. Every factor is something the architecture or the workload pushes upward. There is no axis along which the cache shrinks on its own.
  3. At frontier scale (70B+ models, 32k+ context, real batches) the cache is the dominant memory cost of inference — bigger than the weights, bigger than the activations, bigger than anything else on the GPU.

The next section examines the first serious architectural fix, Grouped-Query Attention. GQA shares K and V across groups of heads and was adopted by Llama 2 70B, Mistral, and most open models. Then we will see why GQA, while a big win, is not enough. Finally, we will derive Multi-Head Latent Attention, the technique that lets DeepSeek V3 carry long contexts at batch sizes a standard multi-head cache could not support.
