During autoregressive generation, a naive decoder recomputes the key and value projections for every previous token at each step, which is wasteful because those projections do not change once computed. KV caching stores the key and value projections and reuses them, cutting the projection work over a sequence of n tokens from O(n²) to O(n).
The Problem: Redundant Computation
Without Caching
```text
Step 1: Generate token 1
  Input: [<bos>]
  Compute K, V for: [<bos>]              ← 1 token

Step 2: Generate token 2
  Input: [<bos>, The]
  Compute K, V for: [<bos>, The]         ← 2 tokens (recomputed <bos>!)

Step 3: Generate token 3
  Input: [<bos>, The, dog]
  Compute K, V for: [<bos>, The, dog]    ← 3 tokens (recomputed again!)

Step N: Generate token N
  Input: [<bos>, The, dog, ..., token_N-1]
  Compute K, V for all N tokens          ← N tokens (all recomputed!)

Total K,V computations: 1 + 2 + 3 + ... + N = N(N+1)/2 = O(N²)
```
With Caching
```text
Step 1: Generate token 1
  Input: [<bos>]
  Compute K, V for: [<bos>]              ← 1 token
  Cache: K₁, V₁

Step 2: Generate token 2
  Input: [<bos>, The]
  Compute K, V for: [The] only           ← 1 new token
  Retrieve from cache: K₁, V₁
  Concatenate: K = [K₁, K₂], V = [V₁, V₂]

Step 3: Generate token 3
  Input: [<bos>, The, dog]
  Compute K, V for: [dog] only           ← 1 new token
  Retrieve from cache: K₁, K₂, V₁, V₂
  Concatenate: K = [K₁, K₂, K₃], V = [V₁, V₂, V₃]

Total K,V computations: 1 + 1 + 1 + ... = N = O(N)
```
Speedup
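The totals in the two traces above can be checked by counting one unit of work per token whose K and V are projected. This is only a counting sketch: `kv_computations` is a hypothetical helper name, and it ignores the cost of the attention scores themselves.

```python
def kv_computations(num_tokens: int, use_cache: bool) -> int:
    """Count per-token K/V projections needed to generate num_tokens tokens."""
    total = 0
    for step in range(1, num_tokens + 1):
        # Without a cache, step t re-projects all t tokens seen so far;
        # with a cache, only the single newest token is projected.
        total += 1 if use_cache else step
    return total


for n in (100, 1000):
    naive = kv_computations(n, use_cache=False)
    cached = kv_computations(n, use_cache=True)
    print(f"N={n}: naive={naive}, cached={cached}, ratio={naive / cached:.1f}x")
# N=100: naive=5050, cached=100, ratio=50.5x
# N=1000: naive=500500, cached=1000, ratio=500.5x
```

The ratio is (N+1)/2, which is where the roughly 50x and 500x figures come from.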
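```text
Without caching: O(N²) K,V computations
With caching:    O(N) K,V computations

For N=100 tokens:  ~50x fewer K,V computations
For N=1000 tokens: ~500x fewer K,V computations
```
These ratios count K,V projections only. Attention scores over all cached keys are still computed at every step, so the end-to-end speedup is smaller than the projection ratio.
Implementation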
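Before the PyTorch modules, the core idea can be sketched in plain Python for a single attention head: projecting only the newest token and appending it to a cache gives the same outputs as re-projecting the whole prefix. The weights, inputs, and helper names (`project`, `attend`) are arbitrary illustrative choices, not part of the implementation that follows.

```python
import math
from typing import List

Vector = List[float]


def project(x: Vector, w: List[Vector]) -> Vector:
    """Multiply row vector x by matrix w (given as a list of rows)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Scaled dot-product attention for one query over all keys/values."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    norm = sum(weights)
    return [
        sum(w / norm * v[j] for w, v in zip(weights, values))
        for j in range(len(values[0]))
    ]


# Toy weights and inputs (arbitrary numbers, for illustration only).
W_q = [[0.1, 0.2], [0.3, -0.1]]
W_k = [[0.5, -0.2], [0.1, 0.4]]
W_v = [[-0.3, 0.2], [0.6, 0.1]]
tokens = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [-1.0, 0.5]]

# Naive: at every step, re-project K and V for the whole prefix.
naive_outputs = []
for t in range(len(tokens)):
    prefix = tokens[: t + 1]
    keys = [project(x, W_k) for x in prefix]
    values = [project(x, W_v) for x in prefix]
    naive_outputs.append(attend(project(tokens[t], W_q), keys, values))

# Cached: project only the newest token and append it to the cache.
k_cache, v_cache, cached_outputs = [], [], []
for x in tokens:
    k_cache.append(project(x, W_k))
    v_cache.append(project(x, W_v))
    cached_outputs.append(attend(project(x, W_q), k_cache, v_cache))

max_diff = max(
    abs(a - b)
    for row_a, row_b in zip(naive_outputs, cached_outputs)
    for a, b in zip(row_a, row_b)
)
print(f"largest difference between naive and cached outputs: {max_diff}")
```

Only the query for the newest token is needed at each step, which is why the cache stores K and V but not Q.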
Multi-Head Attention with KV Cache
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Dict


class MultiHeadAttentionWithCache(nn.Module):
    """
    Multi-Head Attention with KV caching for efficient inference.

    During inference, stores key and value projections to avoid
    recomputing them for previous positions.

    Args:
        d_model: Model dimension
        num_heads: Number of attention heads
        dropout: Dropout probability

    Example:
        >>> mha = MultiHeadAttentionWithCache(512, 8)
        >>> # First token (query, key and value are the same in self-attention)
        >>> x0 = x[:, :1, :]
        >>> out, cache = mha(x0, x0, x0, use_cache=True)
        >>> # Second token (uses cache)
        >>> x1 = x[:, 1:2, :]
        >>> out, cache = mha(x1, x1, x1, past_kv=cache, use_cache=True)
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()

        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Forward pass with optional KV caching.

        Args:
            query: [batch, query_len, d_model]
            key: [batch, key_len, d_model] (new positions only when past_kv is given)
            value: [batch, key_len, d_model] (new positions only when past_kv is given)
            mask: Attention mask, broadcastable to [batch, heads, query_len, total_len]
            past_kv: Cached (K, V) from previous steps,
                each [batch, heads, past_len, d_k]
            use_cache: Whether to return updated cache

        Returns:
            output: [batch, query_len, d_model]
            present_kv: Updated (K, V) cache if use_cache=True
        """
        batch_size = query.size(0)
        query_len = query.size(1)

        # Project query, key, value
        Q = self.W_q(query)  # [batch, query_len, d_model]
        K = self.W_k(key)    # [batch, key_len, d_model]
        V = self.W_v(value)  # [batch, key_len, d_model]

        # Reshape for multi-head attention: [batch, heads, len, d_k]
        Q = Q.view(batch_size, query_len, self.num_heads, self.d_k).transpose(1, 2)

        key_len = key.size(1)
        K = K.view(batch_size, key_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, key_len, self.num_heads, self.d_k).transpose(1, 2)

        # Concatenate with past KV along the sequence dimension if provided
        if past_kv is not None:
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)
            V = torch.cat([past_V, V], dim=2)

        # Cache current K, V (past + new positions)
        present_kv = (K, V) if use_cache else None

        # Compute attention: [batch, heads, query_len, total_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        # Rows that are fully masked produce NaN after softmax; zero them out
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)

        # Combine heads
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, query_len, self.d_model)

        output = self.W_o(context)

        return output, present_kv
```
Transformer Decoder Layer with Cache
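A decoder layer keeps two caches that are meant to behave differently: the self-attention cache gains one position per decoded token, while the cross-attention cache holds projections of the fixed encoder output and should keep the same length throughout. The shape-only sketch below (no tensors) illustrates that intended bookkeeping. `step_layer_cache` and the sizes used are hypothetical.

```python
from typing import Dict, Tuple

Shape = Tuple[int, int, int, int]  # (batch, heads, seq_len, d_k)


def step_layer_cache(
    cache: Dict[str, Shape],
    batch: int,
    heads: int,
    d_k: int,
    src_len: int,
    new_tokens: int,
) -> Dict[str, Shape]:
    """Return the cache shapes after one decoding step for a single layer."""
    past_self_len = cache["self"][2] if "self" in cache else 0
    return {
        # Self-attention K/V gain one entry per newly decoded token.
        "self": (batch, heads, past_self_len + new_tokens, d_k),
        # Cross-attention K/V come from the fixed encoder output.
        "cross": (batch, heads, src_len, d_k),
    }


cache: Dict[str, Shape] = {}
for step in range(1, 4):
    cache = step_layer_cache(
        cache, batch=2, heads=8, d_k=64, src_len=20, new_tokens=1
    )
    print(step, cache["self"], cache["cross"])
# 1 (2, 8, 1, 64) (2, 8, 20, 64)
# 2 (2, 8, 2, 64) (2, 8, 20, 64)
# 3 (2, 8, 3, 64) (2, 8, 20, 64)
```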
```python
class TransformerDecoderLayerWithCache(nn.Module):
    """
    Transformer Decoder Layer with KV caching.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()

        self.self_attention = MultiHeadAttentionWithCache(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)

        self.cross_attention = MultiHeadAttentionWithCache(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Tuple]]]:
        """
        Forward with caching.

        Args:
            x: Decoder input [batch, tgt_len, d_model]
            memory: Encoder output [batch, src_len, d_model]
            tgt_mask: Causal mask for self-attention
            memory_mask: Mask for cross-attention
            past_kv: Dict with 'self' and 'cross' cached KVs
            use_cache: Whether to cache KVs

        Returns:
            output: [batch, tgt_len, d_model]
            present_kv: Updated cache dict
        """
        present_kv = {} if use_cache else None

        # Get past KV caches
        past_self_kv = past_kv.get('self') if past_kv else None
        past_cross_kv = past_kv.get('cross') if past_kv else None

        # Self-attention (Pre-LN)
        residual = x
        x = self.norm1(x)
        self_attn_out, self_kv = self.self_attention(
            x, x, x,
            mask=tgt_mask,
            past_kv=past_self_kv,
            use_cache=use_cache
        )
        x = residual + self.dropout(self_attn_out)

        if use_cache:
            present_kv['self'] = self_kv

        # Cross-attention (Pre-LN)
        residual = x
        x = self.norm2(x)

        # Encoder K, V do not change during generation, so they are projected
        # only once. When a cross cache already exists, an empty slice of the
        # memory is passed so the attention module appends nothing to the
        # cached K, V (otherwise the cache would grow by src_len every step).
        cross_kv_input = memory if past_cross_kv is None else memory[:, :0]
        cross_attn_out, cross_kv = self.cross_attention(
            x, cross_kv_input, cross_kv_input,
            mask=memory_mask,
            past_kv=past_cross_kv,
            use_cache=use_cache
        )
        x = residual + self.dropout(cross_attn_out)

        if use_cache:
            present_kv['cross'] = cross_kv

        # FFN (Pre-LN)
        residual = x
        x = self.norm3(x)
        x = residual + self.dropout(self.ffn(x))

        return x, present_kv
```
Complete Cached Decoder
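When more than one new token is fed to the decoder while a cache already exists, the queries cover only the new positions but the keys cover cached plus new positions, so the causal mask is rectangular rather than square. The plain-Python sketch below shows the pattern. `causal_mask_with_cache` is a hypothetical helper, and it assumes the cache length equals the position offset of the first new token.

```python
from typing import List


def causal_mask_with_cache(past_len: int, new_len: int) -> List[List[int]]:
    """Mask of shape [new_len, past_len + new_len]; 1 = may attend, 0 = blocked."""
    total_len = past_len + new_len
    return [
        # New token i sits at absolute position past_len + i and may attend
        # to every position up to and including its own.
        [1 if j <= past_len + i else 0 for j in range(total_len)]
        for i in range(new_len)
    ]


for row in causal_mask_with_cache(past_len=2, new_len=3):
    print(row)
# [1, 1, 1, 0, 0]
# [1, 1, 1, 1, 0]
# [1, 1, 1, 1, 1]
```

With `past_len=0` this reduces to the usual lower-triangular mask, and with `new_len=1` every entry is 1, which is why no mask is needed when decoding one token at a time.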
Full Implementation
```python
class TransformerDecoderWithCache(nn.Module):
    """
    Full Transformer Decoder with KV caching for efficient inference.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_len: int = 5000,
        dropout: float = 0.1,
        pad_id: int = 0
    ):
        super().__init__()

        self.d_model = d_model
        self.pad_id = pad_id

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)

        # Sinusoidal positional encoding table: [1, max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

        self.layers = nn.ModuleList([
            TransformerDecoderLayerWithCache(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        tgt_ids: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list] = None,
        use_cache: bool = False,
        position_offset: int = 0
    ) -> Tuple[torch.Tensor, Optional[list]]:
        """
        Forward pass with optional caching.

        Args:
            tgt_ids: Target token IDs [batch, tgt_len] (new tokens only when cached)
            memory: Encoder output [batch, src_len, d_model]
            memory_mask: Cross-attention mask
            past_key_values: List of cached KVs per layer
            use_cache: Whether to use/update cache
            position_offset: Number of positions already in the cache

        Returns:
            logits: [batch, tgt_len, vocab_size]
            present_key_values: Updated cache
        """
        batch_size, tgt_len = tgt_ids.shape

        # Embedding + positional encoding, shifted by the cached length
        x = self.embedding(tgt_ids) * math.sqrt(self.d_model)
        x = x + self.pe[:, position_offset:position_offset + tgt_len, :]
        x = self.dropout(x)

        # Causal mask over cached + new positions: [1, 1, tgt_len, total_len].
        # New token i (absolute position position_offset + i) may attend to
        # all cached positions and to new positions up to and including i.
        # A single new token may attend to everything, so no mask is needed.
        if tgt_len > 1:
            total_len = position_offset + tgt_len
            causal_mask = torch.tril(
                torch.ones(tgt_len, total_len, device=tgt_ids.device),
                diagonal=position_offset
            ).unsqueeze(0).unsqueeze(0)
        else:
            causal_mask = None

        # Process through layers
        present_key_values = [] if use_cache else None

        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values else None

            x, present_kv = layer(
                x, memory,
                tgt_mask=causal_mask,
                memory_mask=memory_mask,
                past_kv=past_kv,
                use_cache=use_cache
            )

            if use_cache:
                present_key_values.append(present_kv)

        # Final norm and projection
        x = self.norm(x)
        logits = self.output_projection(x)

        return logits, present_key_values
```
```python
class CachedGenerator:
    """
    Generator using KV caching for efficient autoregressive generation.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: TransformerDecoderWithCache,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        self.encoder = encoder
        self.decoder = decoder
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def generate(
        self,
        src_ids: torch.Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        do_sample: bool = False
    ) -> torch.Tensor:
        """
        Generate with KV caching.

        Much faster than without caching for long sequences.
        """
        self.encoder.eval()
        self.decoder.eval()

        batch_size = src_ids.size(0)
        device = src_ids.device

        # Encode source (once)
        memory = self.encoder(src_ids)

        # Initialize every sequence with the BOS token
        generated = torch.full(
            (batch_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=device
        )

        done = torch.zeros(batch_size, dtype=torch.bool, device=device)
        past_key_values = None
        position = 0

        for step in range(max_length - 1):
            # Only process new token (using cache for previous)
            if past_key_values is None:
                # First step: process BOS
                input_ids = generated
            else:
                # Subsequent steps: only new token
                input_ids = generated[:, -1:]

            # Forward with cache
            logits, past_key_values = self.decoder(
                input_ids,
                memory,
                past_key_values=past_key_values,
                use_cache=True,
                position_offset=position
            )

            position += input_ids.size(1)

            # Get next token logits
            next_logits = logits[:, -1, :] / temperature

            # Sample or greedy
            if do_sample:
                probs = F.softmax(next_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
            else:
                next_tokens = next_logits.argmax(dim=-1)

            # Handle finished sequences: emit padding after EOS
            next_tokens = next_tokens.masked_fill(done, self.pad_token_id)

            # Append
            generated = torch.cat([
                generated,
                next_tokens.unsqueeze(1)
            ], dim=1)

            # Update done status
            done = done | (next_tokens == self.eos_token_id)

            if done.all():
                break

        return generated
```
Memory Considerations
Cache Memory Usage
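The per-token figures in this section are consistent with the formula 2 × d × layers × 4 bytes, that is, FP32 values, batch size 1, and the self-attention cache only. That reading is inferred from the numbers rather than stated in the text. A small calculator under those assumptions (`kv_cache_bytes` is a hypothetical helper):

```python
def kv_cache_bytes(
    d_model: int,
    num_layers: int,
    seq_len: int,
    batch_size: int = 1,
    bytes_per_value: int = 4,  # 4 = FP32, 2 = FP16, 1 = INT8
) -> int:
    """Bytes used by the self-attention KV cache (K and V for every layer)."""
    return 2 * d_model * num_layers * seq_len * batch_size * bytes_per_value


configs = {"Small": (256, 6), "Base": (512, 6), "Large": (1024, 12)}
for name, (d_model, num_layers) in configs.items():
    fp32_kb = kv_cache_bytes(d_model, num_layers, seq_len=1) / 1024
    fp16_kb = kv_cache_bytes(d_model, num_layers, seq_len=1, bytes_per_value=2) / 1024
    print(f"{name}: {fp32_kb:.2f} KB/token (FP32), {fp16_kb:.2f} KB/token (FP16)")
# Small: 12.00 KB/token (FP32), 6.00 KB/token (FP16)
# Base: 24.00 KB/token (FP32), 12.00 KB/token (FP16)
# Large: 96.00 KB/token (FP32), 48.00 KB/token (FP16)
```

Multiplying by sequence length and batch size gives the total, which is why long sequences served in large batches dominate memory.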
```text
Memory per cached token (FP32, batch size 1, self-attention K and V):
  2 (K and V) × d × layers × 4 bytes

Small (d=256,  layers=6):  12.00 KB/token
Base  (d=512,  layers=6):  24.00 KB/token
Large (d=1024, layers=12): 96.00 KB/token

Total cache memory in MB (1 MB taken as 1000 KB) for different sequence lengths:

Config   L=100   L=500   L=1000   L=2000
Small      1.2     6.0     12.0     24.0
Base       2.4    12.0     24.0     48.0
Large      9.6    48.0     96.0    192.0

Key insights:
- Cache grows linearly with sequence length
- Cache grows linearly with number of layers
- Cache grows linearly with batch size
- For long sequences and large batches, the cache can exceed the model size

Mitigation strategies:
- Sliding window attention (only cache last N tokens)
- Sparse attention patterns
- Lower-precision caching (FP16 or INT8)
- Multi-query attention (share K,V across heads)
```
Cross-Attention Caching
Caching Encoder Projections
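```python
class OptimizedCrossAttention(nn.Module):
    """
    Cross-attention with encoder K,V caching.

    Encoder output doesn't change during generation,
    so we can cache K and V projections of the encoder output.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Cached encoder projections, each [batch, heads, src_len, d_k]
        self._cached_k = None
        self._cached_v = None

    def cache_encoder(self, encoder_output: torch.Tensor):
        """
        Pre-compute and cache K, V for encoder output.
        Call once before generation.
        """
        batch_size, src_len, _ = encoder_output.shape

        K = self.W_k(encoder_output)
        V = self.W_v(encoder_output)

        self._cached_k = K.view(batch_size, src_len, self.num_heads, self.d_k).transpose(1, 2)
        self._cached_v = V.view(batch_size, src_len, self.num_heads, self.d_k).transpose(1, 2)

    def clear_cache(self):
        """Clear cached encoder projections."""
        self._cached_k = None
        self._cached_v = None

    def forward(
        self,
        query: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        A minimal forward pass, assuming cache_encoder() has been called.

        Args:
            query: Decoder states [batch, query_len, d_model]
            mask: Optional mask broadcastable to [batch, heads, query_len, src_len]
        """
        if self._cached_k is None or self._cached_v is None:
            raise RuntimeError("Call cache_encoder() before forward().")

        batch_size, query_len, _ = query.shape

        # Only the query is projected per step; K and V come from the cache
        Q = self.W_q(query)
        Q = Q.view(batch_size, query_len, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, self._cached_k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(attn_weights, self._cached_v)

        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, query_len, self.d_model)
        return self.W_o(context)
```
Benefits of Cross-Attention Caching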
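The saving can be illustrated by counting how often the encoder projection runs over a decoding loop. In this sketch `CountingProjection` is a hypothetical stand-in for `W_k` or `W_v`: it doubles its input instead of doing a real matrix multiply, since only the call count matters here.

```python
from typing import List, Optional


class CountingProjection:
    """Stand-in for W_k or W_v that records how often it is applied."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, encoder_output: List[float]) -> List[float]:
        self.calls += 1
        return [2.0 * x for x in encoder_output]  # placeholder for a real matmul


def decode(num_steps: int, encoder_output: List[float], use_cache: bool) -> int:
    """Run num_steps decoding steps and return the number of projection calls."""
    proj = CountingProjection()
    cached: Optional[List[float]] = None
    for _ in range(num_steps):
        if use_cache:
            if cached is None:  # project once, on first use
                cached = proj(encoder_output)
            k_enc = cached
        else:
            k_enc = proj(encoder_output)  # re-projected at every step
        assert len(k_enc) == len(encoder_output)
    return proj.calls


naive = decode(100, [0.1, 0.2, 0.3], use_cache=False)
cached = decode(100, [0.1, 0.2, 0.3], use_cache=True)
print(naive, cached, f"{1 - cached / naive:.0%}")  # 100 1 99%
```

The reduction applies to the projection of the encoder output. The attention scores between each new query and the cached encoder keys are still computed at every step.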
```text
Without caching encoder K,V:
─────────────────────────────
Step 1: K_enc = W_k(encoder_output)  ← Computed
        V_enc = W_v(encoder_output)  ← Computed
Step 2: K_enc = W_k(encoder_output)  ← Recomputed!
        V_enc = W_v(encoder_output)  ← Recomputed!
...
Step N: K_enc = W_k(encoder_output)  ← Recomputed N times!
        V_enc = W_v(encoder_output)  ← Recomputed N times!

With caching encoder K,V:
─────────────────────────────
Setup:  K_enc = W_k(encoder_output)  ← Computed once
        V_enc = W_v(encoder_output)  ← Computed once
        cache(K_enc, V_enc)

Step 1: Use cached K_enc, V_enc
Step 2: Use cached K_enc, V_enc
...
Step N: Use cached K_enc, V_enc

Savings: (N-1) × encoder_projection_cost
For N=100, src_len=200: ~99% reduction in encoder K,V projection compute
(the attention over the src_len encoder positions is still computed each step)
```
Summary
| Aspect | Without Cache | With Cache |
|---|---|---|
| K,V values computed per step (per layer) | O(seq_len × d) | O(d) |
| K,V values computed in total (per layer) | O(seq_len² × d) | O(seq_len × d) |
| K,V memory | O(seq_len × d), transient per layer | O(seq_len × layers × d), kept across steps |
| Implementation complexity | Simple | Moderate |
Best Practices
- Always use caching for inference (except for very short sequences)
- Cache encoder K,V once before generation starts
- Clear cache between different inputs
- Consider memory for very long sequences
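For the memory point, one of the mitigation strategies listed earlier is a sliding window that keeps only the most recent tokens in the cache. A minimal sketch using `collections.deque` is shown below. `SlidingWindowKVCache` is a hypothetical helper, and dropping old entries only makes sense for a model that is meant to attend over a limited window.

```python
from collections import deque
from typing import Deque, List, Tuple

Vector = List[float]


class SlidingWindowKVCache:
    """Keeps K/V for only the most recent `window` tokens."""

    def __init__(self, window: int) -> None:
        self.keys: Deque[Vector] = deque(maxlen=window)
        self.values: Deque[Vector] = deque(maxlen=window)

    def append(self, k: Vector, v: Vector) -> None:
        # deque(maxlen=...) drops the oldest entry once the window is full.
        self.keys.append(k)
        self.values.append(v)

    def contents(self) -> Tuple[List[Vector], List[Vector]]:
        return list(self.keys), list(self.values)


cache = SlidingWindowKVCache(window=3)
for t in range(5):
    cache.append([float(t)], [float(t) * 10])

keys, values = cache.contents()
print(keys)    # [[2.0], [3.0], [4.0]]
print(values)  # [[20.0], [30.0], [40.0]]
```

Memory is then bounded by the window size instead of growing with the full sequence length.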
Implementation Checklist
- Self-attention KV cache (grows with each token)
- Cross-attention encoder cache (computed once)
- Position offset tracking
- Proper cache clearing between batches
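The last two checklist items, position offset tracking and cache clearing, can be sketched as a small per-request state object. `DecodeState` and its fields are hypothetical names used only for illustration. A real implementation would hold tensors rather than Python lists.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class DecodeState:
    """Per-request decoding state."""
    position: int = 0                                  # next positional index
    self_kv: List[list] = field(default_factory=list)  # grows by one per token
    cross_kv_ready: bool = False                       # encoder K/V projected?

    def start(self) -> None:
        """Reset everything before a new input so no stale entries carry over."""
        self.position = 0
        self.self_kv.clear()
        self.cross_kv_ready = False

    def step(self, new_kv: list) -> int:
        """Record one decoded token and return the position offset it used."""
        if not self.cross_kv_ready:
            self.cross_kv_ready = True  # encoder K/V would be computed once here
        offset = self.position
        self.self_kv.append(new_kv)
        self.position += 1
        return offset


state = DecodeState()
state.start()
offsets = [state.step([t]) for t in range(3)]
print(offsets, len(state.self_kv))  # [0, 1, 2] 3

state.start()  # new batch: cache cleared, offset back to 0
print(state.position, len(state.self_kv), state.cross_kv_ready)  # 0 0 False
```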
Chapter Summary
In this chapter, we covered autoregressive generation:
- Fundamentals: Token-by-token generation process
- Greedy Decoding: Simple, fast, deterministic
- Beam Search: Higher quality, maintains multiple hypotheses
- Sampling: Temperature, top-k, top-p for diversity
- KV Caching: Critical optimization for inference speed
The transformer is now ready for efficient translation inference!
In Chapter 10: Training Pipeline, we'll implement data loading and batching, loss computation with label smoothing, learning rate scheduling, training loop with validation, and checkpointing and logging.