Chapter 9

KV Caching for Efficient Inference

Autoregressive Generation

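During autoregressive generation, a naive decoder reruns the full prefix at every step, recomputing the key and value projections of all previous tokens. This is wasteful: with causal masking, the keys and values of earlier positions never change once computed. KV caching stores these projections and computes K and V only for the newest token at each step, reducing the total number of K,V projections from O(n²) to O(n). Each new token still attends to every cached key, so per-step attention cost grows with context length, but the redundant recomputation is gone.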


The Problem: Redundant Computation

Without Caching

📝text
Step 1: Generate token 1
  Input: [<bos>]
  Compute K, V for: [<bos>]                    ← 1 token

Step 2: Generate token 2
  Input: [<bos>, The]
  Compute K, V for: [<bos>, The]               ← 2 tokens (recomputed <bos>!)

Step 3: Generate token 3
  Input: [<bos>, The, dog]
  Compute K, V for: [<bos>, The, dog]          ← 3 tokens (recomputed again!)

Step N: Generate token N
  Input: [<bos>, The, dog, ..., token_N-1]
  Compute K, V for all N tokens                ← N tokens (all recomputed!)

Total K,V computations: 1 + 2 + 3 + ... + N = N(N+1)/2 = O(N²)

With Caching

📝text
Step 1: Generate token 1
  Input: [<bos>]
  Compute K, V for: [<bos>]                    ← 1 token
  Cache: K₁, V₁

Step 2: Generate token 2
  Context: [<bos>, The]                        (only The is fed to the model)
  Compute K, V for: [The] only                 ← 1 new token
  Retrieve from cache: K₁, V₁
  Concatenate: K = [K₁, K₂], V = [V₁, V₂]

Step 3: Generate token 3
  Context: [<bos>, The, dog]                   (only dog is fed to the model)
  Compute K, V for: [dog] only                 ← 1 new token
  Retrieve from cache: K₁, K₂, V₁, V₂
  Concatenate: K = [K₁, K₂, K₃], V = [V₁, V₂, V₃]

Total K,V computations: 1 + 1 + ... + 1 (N terms) = N = O(N)

Speedup

📝text
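Without caching: O(N²) K,V projections
With caching:    O(N) K,V projections
Ratio:           N(N+1)/2 ÷ N = (N+1)/2

For N=100 tokens: ~50x fewer K,V projections
For N=1000 tokens: ~500x fewer K,V projections

These ratios count K,V projections, not wall-clock time. Measured speedups
are usually smaller: a GPU processes the prefix tokens of an uncached step
in parallel, and every step has fixed overheads.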
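The ratios above follow directly from the two sums. A quick pure-Python check of the projection counts (the function name `kv_projection_counts` is just for illustration):

```python
def kv_projection_counts(n: int) -> tuple[int, int]:
    """Count per-token K,V projections needed to generate n positions.

    Without a cache, step t re-projects all t tokens of the prefix;
    with a cache, each step projects only the single newest token.
    """
    without_cache = sum(range(1, n + 1))  # 1 + 2 + ... + n = n(n+1)/2
    with_cache = n                        # 1 + 1 + ... + 1 (n terms)
    return without_cache, with_cache


for n in (100, 1000):
    without_cache, with_cache = kv_projection_counts(n)
    ratio = without_cache / with_cache  # equals (n + 1) / 2
    print(f"N={n}: {without_cache} vs {with_cache} projections, ratio {ratio:.1f}x")

# N=100: 5050 vs 100 projections, ratio 50.5x
# N=1000: 500500 vs 1000 projections, ratio 500.5x
```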

Implementation

Multi-Head Attention with KV Cache

🐍python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Dict


class MultiHeadAttentionWithCache(nn.Module):
    """
    Multi-Head Attention with KV caching for efficient inference.

    During inference, stores key and value projections to avoid
    recomputing them for previous positions.

    Args:
        d_model: Model dimension
        num_heads: Number of attention heads
        dropout: Dropout probability

    Example:
        >>> mha = MultiHeadAttentionWithCache(512, 8)
        >>> x = torch.randn(2, 10, 512)
        >>> # First token
        >>> out, cache = mha(x[:, :1], x[:, :1], x[:, :1], use_cache=True)
        >>> # Second token (reuses the cached K, V of the first)
        >>> out, cache = mha(x[:, 1:2], x[:, 1:2], x[:, 1:2], past_kv=cache, use_cache=True)
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()

        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        static_kv: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Forward pass with optional KV caching.

        Args:
            query: [batch, query_len, d_model]
            key: [batch, key_len, d_model]
            value: [batch, key_len, d_model]
            mask: Attention mask (1 = attend, 0 = masked)
            past_kv: Cached (K, V) from previous steps, each [batch, heads, past_len, d_k]
            use_cache: Whether to return updated cache
            static_kv: If True, key/value come from a fixed source (the encoder
                output in cross-attention), so a provided past_kv is reused
                as-is instead of being re-projected and concatenated

        Returns:
            output: [batch, query_len, d_model]
            present_kv: Updated (K, V) cache if use_cache=True
        """
        batch_size = query.size(0)
        query_len = query.size(1)

        # Project and reshape query: [batch, heads, query_len, d_k]
        Q = self.W_q(query)
        Q = Q.view(batch_size, query_len, self.num_heads, self.d_k).transpose(1, 2)

        if static_kv and past_kv is not None:
            # Cross-attention: encoder K, V were projected on the first step; reuse them
            K, V = past_kv
        else:
            # Project key, value for the new positions only
            key_len = key.size(1)
            K = self.W_k(key).view(batch_size, key_len, self.num_heads, self.d_k).transpose(1, 2)
            V = self.W_v(value).view(batch_size, key_len, self.num_heads, self.d_k).transpose(1, 2)

            # Self-attention: append new K, V after the cached ones (sequence dim)
            if past_kv is not None:
                past_K, past_V = past_kv
                K = torch.cat([past_K, K], dim=2)
                V = torch.cat([past_V, V], dim=2)

        # Cache K, V: [batch, heads, total_len, d_k]
        present_kv = (K, V) if use_cache else None

        # Scaled dot-product attention over all cached + new keys
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        # Fully masked rows give NaN after softmax; zero them out
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)

        # Combine heads
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, query_len, self.d_model)

        output = self.W_o(context)

        return output, present_kv

Transformer Decoder Layer with Cache

🐍python
class TransformerDecoderLayerWithCache(nn.Module):
    """
    Transformer Decoder Layer with KV caching.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()

        self.self_attention = MultiHeadAttentionWithCache(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)

        self.cross_attention = MultiHeadAttentionWithCache(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Tuple]]]:
        """
        Forward with caching.

        Args:
            x: Decoder input [batch, tgt_len, d_model]
            memory: Encoder output [batch, src_len, d_model]
            past_kv: Dict with 'self' and 'cross' cached KVs
            use_cache: Whether to cache KVs

        Returns:
            output: [batch, tgt_len, d_model]
            present_kv: Updated cache dict
        """
        present_kv = {} if use_cache else None

        # Get past KV caches
        past_self_kv = past_kv.get('self') if past_kv else None
        past_cross_kv = past_kv.get('cross') if past_kv else None

        # Self-attention (Pre-LN): the cache grows by one entry per new token
        residual = x
        x = self.norm1(x)
        self_attn_out, self_kv = self.self_attention(
            x, x, x,
            mask=tgt_mask,
            past_kv=past_self_kv,
            use_cache=use_cache
        )
        x = residual + self.dropout(self_attn_out)

        if use_cache:
            present_kv['self'] = self_kv

        # Cross-attention (Pre-LN)
        residual = x
        x = self.norm2(x)

        # The encoder memory is fixed, so its K, V are projected once on the
        # first step and reused afterwards (static_kv=True prevents the cache
        # from being re-projected and concatenated at every step)
        cross_attn_out, cross_kv = self.cross_attention(
            x, memory, memory,
            mask=memory_mask,
            past_kv=past_cross_kv,
            use_cache=use_cache,
            static_kv=True
        )
        x = residual + self.dropout(cross_attn_out)

        if use_cache:
            present_kv['cross'] = cross_kv

        # FFN (Pre-LN)
        residual = x
        x = self.norm3(x)
        x = residual + self.dropout(self.ffn(x))

        return x, present_kv
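A useful sanity check for any KV-cache implementation is that feeding tokens one at a time through the cache reproduces full causal attention over the whole sequence. The sketch below is a minimal single-head version with random stand-in weights (`W_q`, `W_k`, `W_v` here are plain tensors, not the `nn.Linear` layers above, and the helper names are hypothetical):

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

torch.manual_seed(0)

d_k, seq_len = 16, 6
# Stand-in single-head projection weights (scaled to keep scores moderate)
W_q = torch.randn(d_k, d_k) / math.sqrt(d_k)
W_k = torch.randn(d_k, d_k) / math.sqrt(d_k)
W_v = torch.randn(d_k, d_k) / math.sqrt(d_k)
x = torch.randn(1, seq_len, d_k)  # [batch, seq_len, d_k]


def full_causal_attention(x: torch.Tensor) -> torch.Tensor:
    """Recompute Q, K, V for the whole prefix and apply a causal mask."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    causal = torch.tril(torch.ones(x.size(1), x.size(1), dtype=torch.bool))
    scores = scores.masked_fill(~causal, float('-inf'))
    return F.softmax(scores, dim=-1) @ V


def cached_attention_step(
    x_t: torch.Tensor,
    cache: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Project only the newest token and attend over all cached keys."""
    q, k, v = x_t @ W_q, x_t @ W_k, x_t @ W_v
    if cache is not None:
        k = torch.cat([cache[0], k], dim=1)  # append along the sequence dim
        v = torch.cat([cache[1], v], dim=1)
    # No mask needed: every cached key belongs to an earlier or the current position
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    return F.softmax(scores, dim=-1) @ v, (k, v)


cache = None
outputs = []
for t in range(seq_len):
    out_t, cache = cached_attention_step(x[:, t:t + 1], cache)
    outputs.append(out_t)
incremental = torch.cat(outputs, dim=1)

reference = full_causal_attention(x)
print(torch.allclose(incremental, reference, atol=1e-5))  # True
print(cache[0].shape)  # torch.Size([1, 6, 16])
```

The same idea applies to the multi-head layer above: compare a single forward pass over the full sequence (with a causal mask) against a loop that passes `past_kv` from step to step.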

Complete Cached Decoder

Full Implementation

🐍python
class TransformerDecoderWithCache(nn.Module):
    """
    Full Transformer Decoder with KV caching for efficient inference.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_len: int = 5000,
        dropout: float = 0.1,
        pad_id: int = 0
    ):
        super().__init__()

        self.d_model = d_model
        self.pad_id = pad_id

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)

        # Sinusoidal positional encoding (assumes even d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

        self.layers = nn.ModuleList([
            TransformerDecoderLayerWithCache(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        tgt_ids: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list] = None,
        use_cache: bool = False,
        position_offset: int = 0
    ) -> Tuple[torch.Tensor, Optional[list]]:
        """
        Forward pass with optional caching.

        Args:
            tgt_ids: Target token IDs [batch, tgt_len]
            memory: Encoder output [batch, src_len, d_model]
            memory_mask: Cross-attention mask
            past_key_values: List of cached KVs per layer
            use_cache: Whether to use/update cache
            position_offset: Absolute position of the first token in tgt_ids

        Returns:
            logits: [batch, tgt_len, vocab_size]
            present_key_values: Updated cache
        """
        batch_size, tgt_len = tgt_ids.shape

        # Embedding + positional encoding (shifted by the positions already generated)
        x = self.embedding(tgt_ids) * math.sqrt(self.d_model)
        x = x + self.pe[:, position_offset:position_offset + tgt_len, :]
        x = self.dropout(x)

        # Number of positions already stored in the self-attention cache
        past_len = past_key_values[0]['self'][0].size(2) if past_key_values else 0

        # Causal mask for the new tokens: [1, 1, tgt_len, past_len + tgt_len].
        # New token i sits at absolute position past_len + i and may attend to
        # every cached position plus the new positions up to itself.
        # A single new token may attend to everything, so no mask is needed.
        if tgt_len > 1:
            causal_mask = torch.tril(
                torch.ones(tgt_len, past_len + tgt_len, device=tgt_ids.device),
                diagonal=past_len
            ).unsqueeze(0).unsqueeze(0)
        else:
            causal_mask = None

        # Process through layers
        present_key_values = [] if use_cache else None

        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values else None

            x, present_kv = layer(
                x, memory,
                tgt_mask=causal_mask,
                memory_mask=memory_mask,
                past_kv=past_kv,
                use_cache=use_cache
            )

            if use_cache:
                present_key_values.append(present_kv)

        # Final norm and projection
        x = self.norm(x)
        logits = self.output_projection(x)

        return logits, present_key_values


class CachedGenerator:
    """
    Generator using KV caching for efficient autoregressive generation.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: TransformerDecoderWithCache,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        self.encoder = encoder
        self.decoder = decoder
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def generate(
        self,
        src_ids: torch.Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        do_sample: bool = False
    ) -> torch.Tensor:
        """
        Generate with KV caching.

        Much faster than without caching for long sequences.
        """
        self.encoder.eval()
        self.decoder.eval()

        batch_size = src_ids.size(0)
        device = src_ids.device

        # Encode source (once)
        memory = self.encoder(src_ids)

        # Initialize
        generated = torch.full(
            (batch_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=device
        )

        done = torch.zeros(batch_size, dtype=torch.bool, device=device)
        past_key_values = None
        position = 0

        for _ in range(max_length - 1):
            # Only process new token (using cache for previous)
            if past_key_values is None:
                # First step: process BOS
                input_ids = generated
            else:
                # Subsequent steps: only new token
                input_ids = generated[:, -1:]

            # Forward with cache
            logits, past_key_values = self.decoder(
                input_ids,
                memory,
                past_key_values=past_key_values,
                use_cache=True,
                position_offset=position
            )

            position += input_ids.size(1)

            # Get next token logits
            next_logits = logits[:, -1, :] / temperature

            # Sample or greedy
            if do_sample:
                probs = F.softmax(next_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
            else:
                next_tokens = next_logits.argmax(dim=-1)

            # Handle finished sequences
            next_tokens = next_tokens.masked_fill(done, self.pad_token_id)

            # Append
            generated = torch.cat([
                generated,
                next_tokens.unsqueeze(1)
            ], dim=1)

            # Update done status
            done = done | (next_tokens == self.eos_token_id)

            if done.all():
                break

        return generated
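The generator above feeds a single token per step once BOS has been processed, so it never needs a mask after the first call. If several new tokens are processed in one call on top of a non-empty cache (for example, appending a multi-token chunk to a cached prefix), the causal mask has to be rectangular: each new query may see every cached position plus the new positions up to itself. A sketch of that mask in the 1 = attend / 0 = masked convention used by `masked_fill(mask == 0, ...)` above (`causal_mask_with_past` is a hypothetical helper name):

```python
import torch


def causal_mask_with_past(tgt_len: int, past_len: int) -> torch.Tensor:
    """Mask for tgt_len new queries attending to past_len cached keys plus themselves.

    Returns [1, 1, tgt_len, past_len + tgt_len] with 1 = may attend, 0 = masked.
    New query i sits at absolute position past_len + i, so it may see keys
    0 .. past_len + i; torch.tril's diagonal offset encodes exactly that.
    """
    ones = torch.ones(tgt_len, past_len + tgt_len)
    return torch.tril(ones, diagonal=past_len).unsqueeze(0).unsqueeze(0)


mask = causal_mask_with_past(tgt_len=2, past_len=3)
print(mask.shape)             # torch.Size([1, 1, 2, 5])
print(mask[0, 0].tolist())    # [[1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0]]
```

With `past_len=0` this reduces to the usual square lower-triangular mask, so the same helper covers both the first (uncached) call and later chunked calls.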

Memory Considerations

Cache Memory Usage

📝text
Memory per token cached (in KB), assuming FP32 (4 bytes per value), batch size 1,
and self-attention K and V only: bytes/token = 2 × layers × d_model × 4

Small (d=256, layers=6):   12.00 KB/token
Base (d=512, layers=6):    24.00 KB/token
Large (d=1024, layers=12): 96.00 KB/token

Total cache memory (in MB, 1 MB = 1000 KB) for different sequence lengths:

Config    L=100       L=500       L=1000      L=2000
Small     1.2         6.0         12.0        24.0
Base      2.4         12.0        24.0        48.0
Large     9.6         48.0        96.0        192.0

Key insights:
- Cache grows linearly with sequence length
- Cache grows linearly with number of layers (and with d_model)
- Cache grows linearly with batch size (the figures above are per sequence)
- With long sequences and large batches, the cache can exceed the size of the model weights

Mitigation strategies:
- Sliding window attention (only cache the last N tokens)
- Sparse attention patterns
- Lower-precision caching (FP16 halves the figures above; INT8 quantization roughly quarters them)
- Multi-query attention (share K,V across heads)
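The table's figures come from a single formula, which makes it easy to budget other configurations. A minimal calculator, assuming the same conventions as the table (FP32 by default, self-attention K and V only, KB = 1024 bytes, MB = 1000 KB); `kv_cache_bytes` is a hypothetical helper:

```python
def kv_cache_bytes(
    d_model: int,
    num_layers: int,
    seq_len: int,
    batch_size: int = 1,
    bytes_per_value: int = 4,  # 4 = FP32, 2 = FP16/BF16, 1 = INT8
) -> int:
    """Bytes held by a self-attention KV cache.

    Each layer stores one K and one V vector of size d_model per cached token,
    so the total is 2 * num_layers * d_model * seq_len * batch_size values.
    """
    return 2 * num_layers * d_model * seq_len * batch_size * bytes_per_value


configs = {"Small": (256, 6), "Base": (512, 6), "Large": (1024, 12)}
for name, (d_model, num_layers) in configs.items():
    per_token_kb = kv_cache_bytes(d_model, num_layers, seq_len=1) / 1024
    # The table above converts KB to MB by dividing by 1000
    total_mb = kv_cache_bytes(d_model, num_layers, seq_len=2000) / 1024 / 1000
    print(f"{name}: {per_token_kb:.2f} KB/token, {total_mb:.1f} MB at L=2000")

# Small: 12.00 KB/token, 24.0 MB at L=2000
# Base: 24.00 KB/token, 48.0 MB at L=2000
# Large: 96.00 KB/token, 192.0 MB at L=2000
```

Passing `bytes_per_value=2` shows the FP16 saving, and `batch_size` makes the linear growth with batch size explicit.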

Cross-Attention Caching

Caching Encoder Projections

🐍python
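class OptimizedCrossAttention(nn.Module):
    """
    Cross-attention with encoder K,V caching.

    Encoder output doesn't change during generation,
    so we can cache K and V projections of the encoder output.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()

        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Cached encoder projections: [batch, heads, src_len, d_k]
        self._cached_k = None
        self._cached_v = None

    def cache_encoder(self, encoder_output: torch.Tensor):
        """
        Pre-compute and cache K, V for encoder output.
        Call once before generation.
        """
        batch_size, src_len, _ = encoder_output.shape

        K = self.W_k(encoder_output)
        V = self.W_v(encoder_output)

        self._cached_k = K.view(batch_size, src_len, self.num_heads, self.d_k).transpose(1, 2)
        self._cached_v = V.view(batch_size, src_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Attend from decoder states to the cached encoder K, V.

        Args:
            query: Decoder hidden states [batch, query_len, d_model]
            mask: Optional mask broadcastable to [batch, heads, query_len, src_len]
        """
        if self._cached_k is None:
            raise RuntimeError("Call cache_encoder() before forward().")

        batch_size, query_len, _ = query.shape

        # Only the query is projected per step; K, V come from the cache
        Q = self.W_q(query)
        Q = Q.view(batch_size, query_len, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, self._cached_k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Fully masked rows give NaN after softmax; zero them out
        attn_weights = torch.nan_to_num(F.softmax(scores, dim=-1), nan=0.0)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, self._cached_v)
        context = context.transpose(1, 2).contiguous().view(batch_size, query_len, self.d_model)

        return self.W_o(context)

    def clear_cache(self):
        """Clear cached encoder projections."""
        self._cached_k = None
        self._cached_v = None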

Benefits of Cross-Attention Caching

📝text
Without caching encoder K,V:
─────────────────────────────
Step 1: K_enc = W_k(encoder_output)  ← Computed
        V_enc = W_v(encoder_output)  ← Computed
Step 2: K_enc = W_k(encoder_output)  ← Recomputed!
        V_enc = W_v(encoder_output)  ← Recomputed!
...
Step N: K_enc = W_k(encoder_output)  ← Computed N times in total!
        V_enc = W_v(encoder_output)  ← Computed N times in total!

With caching encoder K,V:
─────────────────────────────
Setup:  K_enc = W_k(encoder_output)  ← Computed once
        V_enc = W_v(encoder_output)  ← Computed once
        cache(K_enc, V_enc)

Step 1: Use cached K_enc, V_enc
Step 2: Use cached K_enc, V_enc
...
Step N: Use cached K_enc, V_enc

Savings: (N-1) × encoder_projection_cost
For N=100: 99 of 100 encoder K,V projections are skipped (~99% fewer),
whatever src_len is. Each step still pays for the query projection, the
attention scores over all src_len keys, and the output projection.
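The accounting above can be checked with a toy call counter. This sketch assumes an `nn.Linear` subclass, `CountingLinear`, invented here purely to count calls; it shows that the cached tensors match what would have been recomputed, while the projections run once instead of once per step:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class CountingLinear(nn.Linear):
    """nn.Linear that counts how many times it is called (illustration only)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.calls = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return super().forward(x)


d_model, src_len, num_steps = 32, 20, 100
W_k = CountingLinear(d_model, d_model)
W_v = CountingLinear(d_model, d_model)
encoder_output = torch.randn(1, src_len, d_model)  # fixed during generation

with torch.no_grad():
    # Without caching: project the unchanging encoder output at every step
    recomputed = [(W_k(encoder_output), W_v(encoder_output)) for _ in range(num_steps)]
    calls_without_cache = W_k.calls + W_v.calls

    # With caching: project once, then reuse the same tensors at every step
    W_k.calls = W_v.calls = 0
    K_enc, V_enc = W_k(encoder_output), W_v(encoder_output)
    calls_with_cache = W_k.calls + W_v.calls

print(calls_without_cache, calls_with_cache)  # 200 2
print(all(torch.allclose(k, K_enc) and torch.allclose(v, V_enc) for k, v in recomputed))
```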

Summary

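| Aspect | Without Cache | With Cache |
|---|---|---|
| K,V projections at step t | O(t × d²) | O(d²) |
| Attention scores at step t | O(t² × d) | O(t × d) |
| Total K,V projections over N steps | O(N² × d²) | O(N × d²) |
| Memory held between steps | None (recomputed each step) | O(N × layers × d) |
| Implementation complexity | Simple | Moderate |

Compute costs are per layer, with d = d_model and t the current context length; memory is summed over layers.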

Best Practices

  1. Always use caching for inference (except for very short sequences)
  2. Cache encoder K,V once before generation starts
  3. Clear cache between different inputs
  4. Budget cache memory for long sequences and large batches

Implementation Checklist

  • Self-attention KV cache (grows with each token)
  • Cross-attention encoder cache (computed once)
  • Position offset tracking
  • Proper cache clearing between batches
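Several checklist items (cache growth, clearing between inputs, and the sliding-window mitigation mentioned earlier) can be gathered into one small container. The sketch below is hypothetical (`LayerKVCache` is not part of the code above) and assumes the `[batch, heads, seq_len, d_k]` layout used throughout. Note that a sliding window changes what the model attends to, so it only matches the model's behavior if the model was trained with (or tolerates) windowed attention; with a window, `len(cache)` no longer equals the absolute position, so the position offset must still be tracked separately.

```python
from typing import Optional, Tuple

import torch


class LayerKVCache:
    """Hypothetical per-layer KV cache with explicit reset and an optional window.

    Tensors are shaped [batch, heads, seq_len, d_k], matching the layout used by
    MultiHeadAttentionWithCache (concatenation along dim=2).
    """

    def __init__(self, max_len: Optional[int] = None):
        self.max_len = max_len  # None = unbounded; else keep only the last max_len positions
        self.k: Optional[torch.Tensor] = None
        self.v: Optional[torch.Tensor] = None

    def append(
        self, k_new: torch.Tensor, v_new: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add the new positions and return the full (K, V) to attend over."""
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        if self.max_len is not None and self.k.size(2) > self.max_len:
            # Sliding window: drop the oldest positions beyond the window
            self.k = self.k[:, :, -self.max_len:]
            self.v = self.v[:, :, -self.max_len:]
        return self.k, self.v

    def __len__(self) -> int:
        return 0 if self.k is None else self.k.size(2)

    def reset(self) -> None:
        """Call between unrelated inputs so stale keys/values are never attended to."""
        self.k = None
        self.v = None


cache = LayerKVCache(max_len=4)
for _ in range(6):
    cache.append(torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64))
print(len(cache))  # 4 (capped by the window)
cache.reset()
print(len(cache))  # 0
```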

Chapter Summary

In this chapter, we covered autoregressive generation:

  1. Fundamentals: Token-by-token generation process
  2. Greedy Decoding: Simple, fast, deterministic
  3. Beam Search: Higher quality, maintains multiple hypotheses
  4. Sampling: Temperature, top-k, top-p for diversity
  5. KV Caching: Critical optimization for inference speed

The transformer is now ready for efficient translation inference!


In Chapter 10: Training Pipeline, we'll implement data loading and batching, loss computation with label smoothing, learning rate scheduling, training loop with validation, and checkpointing and logging.