Explain why encoder-decoder architectures need a mechanism to let the decoder selectively read from the encoder's output, and how cross-attention solves this.
Distinguish cross-attention from self-attention by identifying which sequence contributes the queries and which contributes the keys and values.
Derive the cross-attention formula $\mathrm{softmax}\!\left(\frac{Q_{dec} K_{enc}^\top}{\sqrt{d_k}}\right) V_{enc}$ and explain why the attention matrix is $(N_{dec} \times N_{enc})$, not necessarily square.
Compute cross-attention weights and outputs by hand for our shared example sentence “The cat sat on mat”.
Implement a complete, runnable Python class that supports different-length encoder and decoder sequences.
Connect cross-attention to its modern applications: machine translation, vision-language models, text-to-image generation, and retrieval-augmented generation.
Where this appears: Cross-attention is the bridge between any two information streams. It powers the decoder in the original Transformer (Vaswani et al., 2017), connects text encoders to image decoders in Stable Diffusion and DALL-E, links audio encoders to text decoders in Whisper, and bridges retrieved documents to generation in RAG systems. Whenever a model needs to “read from” a separate source of information, cross-attention is the mechanism that makes it possible.
The Real Problem
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). “Attention Is All You Need.” Advances in Neural Information Processing Systems (NeurIPS), 30.
In Chapter 1, we learned that self-attention lets every token in a sequence attend to every other token in the same sequence. But many of the most important problems in AI involve two separate sequences that need to communicate with each other:
In machine translation, the source sentence (English) and the target sentence (French) are separate sequences.
In image captioning, the image (a grid of patches) and the caption (a sequence of words) are separate.
In text-to-image generation, the text prompt and the image being generated are separate.
In speech recognition, the audio spectrogram and the transcribed text are separate.
In retrieval-augmented generation, the retrieved documents and the generated answer are separate.
Self-attention alone cannot solve these problems because it only lets tokens attend within a single sequence. What is needed is a mechanism that lets one sequence selectively read information from another.
The Encoder-Decoder Gap
The original Transformer has two halves: an encoder that processes the source sequence and a decoder that generates the target sequence one token at a time. The encoder produces a rich, contextualized representation of the source — but how does the decoder access it?
Before attention, the answer was simple and devastating: the encoder's entire output was compressed into a single fixed-size vector (Sutskever et al., 2014). This worked passably for short sentences but catastrophically failed for long ones. Translating a 50-word sentence required compressing all 50 words into, say, 512 numbers — and then somehow reconstructing the meaning in another language.
Bahdanau's Breakthrough
Bahdanau, Cho, & Bengio (2014) solved the compression problem by letting the decoder look back at every encoder hidden state at each generation step. This was the first cross-attention mechanism, though it used additive scoring (a small feedforward network) rather than dot products.
Their key insight: instead of asking “what did the source say overall?”, the decoder could ask “which source words are relevant to the word I am generating right now?” When generating the French word “chat”, the decoder could attend strongly to the English word “cat” and ignore the rest.
The Cross-Attention Insight
Vaswani et al. (2017) simplified this idea into the dot-product cross-attention mechanism that is now standard in Transformer encoder-decoder models. The insight is a clean separation of roles:
The decoder provides the Q (queries) — “what am I looking for right now?”
The encoder provides the K (keys) and V (values) — “here is what I know and here is the content to retrieve.”
This separation is what makes cross-attention fundamentally different from self-attention. In self-attention, Q, K, and V all come from the same sequence. In cross-attention, Q comes from one sequence and K, V come from another.
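The role separation can be sketched in a few lines of NumPy. This is a minimal illustration, assuming the learned projection matrices are left out for brevity; `enc` and `dec` are hypothetical random stand-ins for encoder and decoder states, and `attention` is an illustrative helper rather than code from this chapter. The same function serves both cases, and only the source of its arguments changes.

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention; returns (weights, output)."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights, weights @ V

rng = np.random.default_rng(0)
enc = rng.normal(size=(5, 4))  # hypothetical encoder states: 5 tokens, width 4
dec = rng.normal(size=(3, 4))  # hypothetical decoder states: 3 tokens, width 4

# Self-attention: Q, K, and V all come from the same sequence.
self_w, _ = attention(enc, enc, enc)

# Cross-attention: Q from the decoder, K and V from the encoder.
cross_w, cross_out = attention(dec, enc, enc)

print(self_w.shape, cross_w.shape, cross_out.shape)  # (5, 5) (3, 5) (3, 4)
```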
From Intuition to Mathematics
The Library Analogy
Imagine you are a student (the decoder) sitting in a library (the encoder). You have a question in mind (your query). Each book on the shelf has a title on its spine (its key) and content inside (its value). To answer your question, you:
Compare your question to every book title (compute similarity scores).
Rank the books by relevance (softmax to get weights).
Read from the most relevant books, blending their content proportionally (weighted sum of values).
In self-attention, you are a book comparing yourself to other books on the same shelf. In cross-attention, you are a student (a different entity) querying a library (a different entity). The student and the library are fundamentally separate — they may have different sizes (3 questions, 1000 books), different internal representations, and different roles.
Self-Attention vs Cross-Attention
| Property | Self-Attention (Ch. 1) | Cross-Attention (Ch. 4) |
|---|---|---|
| Source of Q | Same sequence (e.g., encoder input) | Decoder sequence |
| Source of K, V | Same sequence as Q | Encoder output (different sequence) |
| Attention matrix shape | $N \times N$ (always square) | $N_{dec} \times N_{enc}$ (often rectangular) |
| What it captures | Relationships within a single sequence | Relationships between two different sequences |
| Typical use | Every encoder layer, every decoder layer | Decoder layers only (connecting to encoder) |
The decoder uses both
In the original Transformer, each decoder layer contains three sub-layers: (1) causal self-attention (Chapter 3) over the decoder's own tokens, (2) cross-attention from the decoder to the encoder, and (3) a feedforward network. Self-attention and cross-attention are complementary, not competing.
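How the three sub-layers fit together can be sketched as follows. This is a simplified, single-head illustration under stated assumptions: layer normalization and the learned Q/K/V projections are omitted, residual connections are kept, and `decoder_layer`, `W1`, and `W2` are hypothetical names introduced only for this example.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V, mask=None):
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # masked positions get ~0 weight
    return softmax(scores) @ V

def decoder_layer(dec_x, enc_out, W1, W2):
    n_dec = dec_x.shape[0]
    causal = np.tril(np.ones((n_dec, n_dec), dtype=bool))
    # (1) causal self-attention over the decoder's own tokens
    h = dec_x + attention(dec_x, dec_x, dec_x, mask=causal)
    # (2) cross-attention: queries from the decoder, keys/values from the encoder
    h = h + attention(h, enc_out, enc_out)
    # (3) position-wise feedforward network with a ReLU
    h = h + np.maximum(h @ W1, 0.0) @ W2
    return h

rng = np.random.default_rng(0)
dec_x = rng.normal(size=(3, 4))    # 3 decoder tokens
enc_out = rng.normal(size=(5, 4))  # 5 encoder tokens
W1 = rng.normal(size=(4, 8)) * 0.1
W2 = rng.normal(size=(8, 4)) * 0.1

out = decoder_layer(dec_x, enc_out, W1, W2)
print(out.shape)  # (3, 4)
```

Because only sub-layer (1) is masked, the first decoder position never sees later decoder tokens, yet every decoder position can read the whole encoder output.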
The cross-attention formula is $\mathrm{CrossAttn}(Q_{dec}, K_{enc}, V_{enc}) = \mathrm{softmax}\!\left(\frac{Q_{dec} K_{enc}^\top}{\sqrt{d_k}}\right) V_{enc}$. Mathematically, this is identical to the scaled dot-product attention from Chapter 1; the only difference is where Q, K, and V come from. The same mathematical operation can serve fundamentally different purposes depending on the source of its inputs.
Symbol-by-Symbol Breakdown
| Symbol | Shape | Meaning |
|---|---|---|
| $Q_{dec}$ | $(N_{dec}, d_k)$ | Decoder queries: what is the decoder looking for at each position? |
| $K_{enc}$ | $(N_{enc}, d_k)$ | Encoder keys: what does each encoder position advertise? |
| $V_{enc}$ | $(N_{enc}, d_v)$ | Encoder values: what content does each encoder position provide? |
| $Q_{dec} K_{enc}^\top$ | $(N_{dec}, N_{enc})$ | Cross-sequence similarity scores: how well does each decoder query match each encoder key? |
| $\sqrt{d_k}$ | scalar | Scaling factor to prevent softmax saturation (same reason as Chapter 1) |
| softmax | $(N_{dec}, N_{enc})$ | Row-wise normalization: each decoder token distributes 100% attention across encoder tokens |
| Output | $(N_{dec}, d_v)$ | Each decoder position receives a weighted blend of encoder values |
What the Formula Says in Plain English
For each decoder token, cross-attention answers the question: “Given what I am looking for (my query), which encoder tokens have relevant information (matching keys), and what content should I retrieve from them (their values)?”
The decoder query $Q_{dec}[i]$ is compared against every encoder key $K_{enc}[j]$ via dot product. The resulting scores are scaled by $1/\sqrt{d_k}$ and passed through softmax to produce a probability distribution over encoder positions. The output is the weighted average of encoder values, where each weight indicates how relevant that encoder position is to this decoder position.
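The per-query description above and the matrix formula are the same computation. The sketch below checks this on random data; all array sizes and the helper name `attend_one` are arbitrary choices made for this illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n_dec, n_enc, d_k, d_v = 2, 6, 4, 3   # arbitrary illustrative sizes
Q_dec = rng.normal(size=(n_dec, d_k))
K_enc = rng.normal(size=(n_enc, d_k))
V_enc = rng.normal(size=(n_enc, d_v))

def attend_one(q, K, V):
    """One decoder query against every encoder key, written as explicit loops."""
    scores = np.array([q @ k for k in K]) / np.sqrt(len(q))
    w = np.exp(scores - scores.max())
    w = w / w.sum()                                   # distribution over encoder positions
    return w, sum(w_j * v for w_j, v in zip(w, V))    # weighted average of values

loop_out = np.stack([attend_one(q, K_enc, V_enc)[1] for q in Q_dec])

# Matrix form: softmax(Q_dec K_enc^T / sqrt(d_k)) V_enc
S = Q_dec @ K_enc.T / np.sqrt(d_k)
W = np.exp(S - S.max(axis=-1, keepdims=True))
W = W / W.sum(axis=-1, keepdims=True)
matrix_out = W @ V_enc

print(np.allclose(loop_out, matrix_out))  # True
```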
The Key Insight: Different Sequence Lengths
Unlike self-attention, where the attention matrix is always $N \times N$, cross-attention produces an $N_{dec} \times N_{enc}$ matrix. The two sequences can have completely different lengths:
Translating a 20-word English sentence into a 25-word French sentence: 25×20 attention matrix
Captioning a 196-patch image with a 15-word description: 15×196 attention matrix
Generating a 100-token answer from 3 retrieved documents with 500 tokens each: 100×1500 attention matrix
This is a fundamental architectural advantage
Self-attention forces Q and K to have the same length because they come from the same sequence. Cross-attention decouples them. This decoupling is what allows encoder-decoder models to handle input and output sequences of arbitrary and independent lengths.
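A quick shape check makes the decoupling concrete. The sketch below uses random queries and keys with the three length pairs listed above; the width `d_k = 64` and the helper name are assumptions made only for this example.

```python
import numpy as np

def cross_attention_weights(Q_dec, K_enc):
    scores = Q_dec @ K_enc.T / np.sqrt(Q_dec.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d_k = 64  # hypothetical query/key width
scenarios = {
    "translation": (25, 20),     # (N_dec, N_enc)
    "captioning": (15, 196),
    "retrieval": (100, 1500),
}

shapes = {}
for name, (n_dec, n_enc) in scenarios.items():
    W = cross_attention_weights(rng.normal(size=(n_dec, d_k)),
                                rng.normal(size=(n_enc, d_k)))
    shapes[name] = W.shape
    print(f"{name}: weights {W.shape}")
```

Nothing in the function refers to a shared length; the only dimension the two inputs must agree on is `d_k`.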
Step-by-Step Calculation
We use the shared example from Chapter 1, but now the decoder has its own queries. In a real transformer, the decoder's WQ projection is a different learned matrix from the encoder's WQ. We simulate this with per-dimension scaling:
$Q_{dec} = Q \odot [1.2, 0.8, 1.1, 0.9]$
where $\odot$ denotes element-wise multiplication (Hadamard product). The encoder keys K and values V remain unchanged: they come from the encoder's output, not the decoder.
Step 1: Decoder Queries
Apply the decoder projection to each token's query vector:
| Token | Q (encoder) | × scale | Q_dec (decoder) |
|---|---|---|---|
| The | [1.0, 0.0, 1.0, 0.0] | × [1.2, 0.8, 1.1, 0.9] | [1.2, 0.0, 1.1, 0.0] |
| cat | [0.0, 2.0, 0.0, 1.0] | × [1.2, 0.8, 1.1, 0.9] | [0.0, 1.6, 0.0, 0.9] |
| sat | [1.0, 1.0, 1.0, 0.0] | × [1.2, 0.8, 1.1, 0.9] | [1.2, 0.8, 1.1, 0.0] |
| on | [0.0, 0.0, 1.0, 1.0] | × [1.2, 0.8, 1.1, 0.9] | [0.0, 0.0, 1.1, 0.9] |
| mat | [1.0, 0.0, 0.0, 1.0] | × [1.2, 0.8, 1.1, 0.9] | [1.2, 0.0, 0.0, 0.9] |
Why different queries matter
The decoder query for “The” is now [1.2, 0.0, 1.1, 0.0] instead of [1.0, 0.0, 1.0, 0.0]. The amplified dim-0 (1.0 → 1.2) and dim-2 (1.0 → 1.1) mean the decoder version of “The” is slightly more interested in tokens that have high values in these dimensions. This is analogous to how a decoder's learned projection produces queries that are tuned for the generation task rather than the encoding task.
Step 2: Cross-Sequence Dot Products
Compute $Q_{dec}[\text{The}] \cdot K_{enc}[j]$ for each encoder token:

| Encoder token | K_enc[j] | Dot product | Result |
|---|---|---|---|
| The | [0.0, 1.0, 0.0, 1.0] | 1.2×0 + 0×1 + 1.1×0 + 0×1 | 0.00 |
| cat | [1.0, 0.0, 1.0, 0.0] | 1.2×1 + 0×0 + 1.1×1 + 0×0 | 2.30 |
| sat | [1.0, 1.0, 0.0, 0.0] | 1.2×1 + 0×1 + 1.1×0 + 0×0 | 1.20 |
| on | [0.0, 0.0, 1.0, 1.0] | 1.2×0 + 0×0 + 1.1×1 + 0×1 | 1.10 |
| mat | [1.0, 0.0, 0.5, 0.5] | 1.2×1 + 0×0 + 1.1×0.5 + 0×0.5 | 1.75 |
The decoder's “The” is most attracted to the encoder's “cat” (score 2.30), even more strongly than in self-attention (where the score was 2.00). The amplified dim-0 and dim-2 in the decoder query both increase alignment with K[cat] = [1, 0, 1, 0].
Step 3: Scaling
Divide by $\sqrt{d_k} = \sqrt{4} = 2$:
$\text{Scaled} = [0.000, 1.150, 0.600, 0.550, 0.875]$
Step 4: Softmax
Apply softmax to convert scores into a probability distribution:
| Encoder token | Scaled score | Cross-attention weight | Self-attention weight (Ch. 1) | Δ |
|---|---|---|---|---|
| The | 0.000 | 0.0989 | 0.1095 | −0.0106 |
| cat | 1.150 | 0.3123 | 0.2976 | +0.0147 |
| sat | 0.600 | 0.1802 | 0.1805 | −0.0003 |
| on | 0.550 | 0.1714 | 0.1805 | −0.0091 |
| mat | 0.875 | 0.2372 | 0.2318 | +0.0054 |
Cross-attention increases the decoder's focus on “cat” (0.3123 vs 0.2976 in self-attention) and “mat” (0.2372 vs 0.2318), while decreasing attention to “The” and “on”. The decoder's distinct projection creates a different reading pattern of the same source material.
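The comparison in the table can be reproduced with a short script. It uses the keys and the “The” query from the shared example; the helper name `attention_row` is introduced only for this sketch.

```python
import numpy as np

K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])
q_enc = np.array([1.0, 0.0, 1.0, 0.0])           # "The" query from Chapter 1
q_dec = q_enc * np.array([1.2, 0.8, 1.1, 0.9])   # simulated decoder query

def attention_row(q, K):
    """Softmax over scaled dot products of one query with all keys."""
    scaled = K @ q / np.sqrt(K.shape[-1])
    e = np.exp(scaled - scaled.max())
    return e / e.sum()

cross_w = attention_row(q_dec, K)
self_w = attention_row(q_enc, K)
delta = cross_w - self_w

print(np.round(cross_w, 4))  # [0.0989 0.3123 0.1802 0.1714 0.2372]
print(np.round(self_w, 4))   # [0.1095 0.2976 0.1805 0.1805 0.2318]
print(np.round(delta, 4))
```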
Step 5: Weighted Sum of Encoder Values
The output for decoder token “The” is the weighted sum of encoder values:
| Encoder token | V_enc[j] | × weight | Contribution |
|---|---|---|---|
| The | [1.0, 0.0, 0.0, 0.0] | × 0.0989 | [0.0989, 0.0000, 0.0000, 0.0000] |
| cat | [0.0, 1.0, 0.0, 0.0] | × 0.3123 | [0.0000, 0.3123, 0.0000, 0.0000] |
| sat | [0.0, 0.0, 1.0, 0.0] | × 0.1802 | [0.0000, 0.0000, 0.1802, 0.0000] |
| on | [0.0, 0.0, 0.0, 1.0] | × 0.1714 | [0.0000, 0.0000, 0.0000, 0.1714] |
| mat | [0.5, 0.5, 0.5, 0.5] | × 0.2372 | [0.1186, 0.1186, 0.1186, 0.1186] |
| Total | | | [0.2175, 0.4309, 0.2988, 0.2900] |
Compare to self-attention output for “The”: [0.2254, 0.4135, 0.2964, 0.2964]. The cross-attention output has a higher dim-1 component (0.4309 vs 0.4135) because the decoder allocates more attention to “cat” (whose value is entirely in dim-1).
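The two output vectors can be checked side by side. The sketch reuses the shared K and V matrices; `attend` is a hypothetical helper name for this example.

```python
import numpy as np

K = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]])
V = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]])

def attend(q, K, V):
    """Output for a single query: softmax-weighted sum of the value rows."""
    scaled = K @ q / np.sqrt(K.shape[-1])
    w = np.exp(scaled - scaled.max())
    w = w / w.sum()
    return w @ V

q_enc = np.array([1.0, 0.0, 1.0, 0.0])            # "The" in self-attention
q_dec = q_enc * np.array([1.2, 0.8, 1.1, 0.9])    # "The" as a decoder query

cross_out = attend(q_dec, K, V)
self_out = attend(q_enc, K, V)
print(np.round(cross_out, 4))  # [0.2175 0.4309 0.2988 0.29  ]
print(np.round(self_out, 4))   # [0.2254 0.4135 0.2964 0.2964]
```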
Full Attention Weights and Output
Interpreting the Weights
Cross-Attention Weight Matrix (5×5, decoder Q, encoder K):
| Q_dec \ K_enc | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.0989 | 0.3123 | 0.1802 | 0.1714 | 0.2372 |
| cat | 0.3660 | 0.1049 | 0.2334 | 0.1645 | 0.1313 |
| sat | 0.1297 | 0.2746 | 0.2364 | 0.1507 | 0.2086 |
| on | 0.1809 | 0.1999 | 0.1154 | 0.3136 | 0.1902 |
| mat | 0.1731 | 0.2011 | 0.2011 | 0.1731 | 0.2518 |
Output Matrix (5×4):
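| Decoder token | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.2175 | 0.4309 | 0.2988 | 0.2900 |
| cat | 0.4317 | 0.1705 | 0.2990 | 0.2301 |
| sat | 0.2340 | 0.3789 | 0.3407 | 0.2550 |
| on | 0.2760 | 0.2950 | 0.2105 | 0.4087 |
| mat | 0.2989 | 0.3269 | 0.3269 | 0.2989 |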
Key observations from the cross-attention weights:
Decoder “The” attends most to encoder “cat” (0.3123). The decoder's amplified dim-0 and dim-2 increase alignment with K[cat]=[1,0,1,0].
Decoder “cat” attends most to encoder “The” (0.3660). This pattern is similar to self-attention (0.4026) but slightly reduced because the decoder query for “cat” is [0.0, 1.6, 0.0, 0.9] vs the encoder's [0.0, 2.0, 0.0, 1.0].
Decoder “on” strongly attends to encoder “on” (0.3136), nearly identical to self-attention (0.3137). In this toy example the query for “on” is nonzero only in dim-2 and dim-3, where the scale factors (1.1 and 0.9) roughly offset each other, so its attention pattern barely changes.
Decoder “mat” shows the most uniform attention (range 0.1731–0.2518), spreading its attention relatively evenly across all encoder positions.
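These observations can be read off the weight matrix programmatically. The sketch below recomputes the 5×5 matrix from the shared example and reports, for each decoder token, its most-attended encoder token and the spread between its largest and smallest weight.

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "mat"]
Q = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 1.0], [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]])
Q_dec = Q * np.array([1.2, 0.8, 1.1, 0.9])

S = Q_dec @ K.T / np.sqrt(K.shape[-1])
W = np.exp(S - S.max(axis=-1, keepdims=True))
W = W / W.sum(axis=-1, keepdims=True)

top = [tokens[j] for j in W.argmax(axis=-1)]    # most-attended encoder token per row
spread = W.max(axis=-1) - W.min(axis=-1)        # small spread = near-uniform attention
most_uniform = tokens[int(spread.argmin())]

for t, best, s in zip(tokens, top, spread):
    print(f"decoder {t!r:6} -> encoder {best!r:6} spread={s:.4f}")
print("most uniform row:", most_uniform)
```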
Interactive: Different-Length Sequences
The most distinctive feature of cross-attention is that the encoder and decoder can have different sequence lengths. Here, 3 French decoder tokens (“Le”, “chat”, “assis”) attend to 5 English encoder tokens, producing a 3×5 attention matrix.
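A 3×5 case can be sketched with the shared encoder keys and values. The three decoder queries below are invented for this illustration (they are not taken from a trained model); they are chosen so that each French token aligns with its English counterpart.

```python
import numpy as np

enc_tokens = ["The", "cat", "sat", "on", "mat"]
dec_tokens = ["Le", "chat", "assis"]

K_enc = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]])
V_enc = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0],
                  [0.0, 0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]])

# Hypothetical decoder queries, made up for this example.
Q_dec = np.array([
    [0.0, 1.0, 0.0, 1.0],  # Le
    [1.0, 0.0, 1.0, 0.0],  # chat
    [1.0, 1.0, 0.0, 0.0],  # assis
])

S = Q_dec @ K_enc.T / np.sqrt(K_enc.shape[-1])
W = np.exp(S - S.max(axis=-1, keepdims=True))
W = W / W.sum(axis=-1, keepdims=True)   # (3, 5): rectangular
out = W @ V_enc                         # (3, 4): one row per decoder token

for d, row in zip(dec_tokens, W):
    print(d, "->", enc_tokens[int(row.argmax())], np.round(row, 3))
```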
Non-square attention matrices
In self-attention, the attention matrix is always $N \times N$. In cross-attention, it is $N_{dec} \times N_{enc}$. A 100-token decoder attending to a 500-token encoder produces a 100×500 matrix, which has only 20% of the entries of a 500×500 self-attention matrix. This asymmetry is both architecturally necessary (the sequences are different lengths) and computationally beneficial.
Applications Across Domains
Machine Translation
Cross-attention's original use case. The encoder processes the source sentence (“The cat sat on mat”) and the decoder generates the target sentence (“Le chat assis sur tapis”). At each decoder step, cross-attention lets the model look back at the source to find the right word to translate. The attention pattern typically shows a roughly diagonal structure (word order is often similar between languages) with deviations that capture reordering.
In practice: Google Translate, Meta's NLLB (No Language Left Behind, 200+ languages), and Helsinki-NLP's MarianMT all use cross-attention in their decoder layers.
Vision-Language Models
In models like Flamingo (DeepMind, 2022), a vision encoder processes an image into a sequence of visual embeddings, and the language decoder uses cross-attention to “look at” different parts of the image while generating text. (LLaVA takes a different route: it projects the visual embeddings into the language model's input sequence and relies on ordinary self-attention rather than dedicated cross-attention layers.) When generating “A cat sitting on a red mat,” the cross-attention weights would ideally focus on the cat patches when producing “cat” and on the mat patches when producing “mat.”
Text-to-Image Generation
Stable Diffusion and DALL-E 2/3 use cross-attention to connect text descriptions to image generation. The text encoder (typically CLIP or T5) produces embeddings for the prompt, and the image decoder (a U-Net or DiT) uses cross-attention to condition each spatial location on the text. This is how the model knows where to place “a sunset” vs “a mountain” in the generated image.
Speech Recognition
OpenAI's Whisper uses an encoder-decoder architecture where the encoder processes mel spectrogram features (audio) and the decoder generates text tokens. Cross-attention bridges the acoustic and linguistic domains, allowing the model to align speech segments with their textual transcriptions.
Retrieval-Augmented Generation
In RAG systems, retrieved documents serve as a “knowledge encoder” and the generation model serves as the decoder. Cross-attention (or its variants like Fusion-in-Decoder) lets the generator selectively read from multiple retrieved passages to produce grounded, factual answers.
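One way to picture this, loosely following the Fusion-in-Decoder idea of concatenating separately encoded passages before the decoder attends to them, is the sketch below. The passage encodings and decoder queries are random stand-ins, and the width `d = 16` is an arbitrary assumption; the point is only the shapes and the per-passage attention mass.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16                                  # hypothetical model width
n_docs, doc_len, n_dec = 3, 500, 100

# Stand-ins for three separately encoded passages, each of shape (500, d).
passages = [rng.normal(size=(doc_len, d)) for _ in range(n_docs)]
memory = np.concatenate(passages, axis=0)   # (1500, d): one long encoder memory

Q_dec = rng.normal(size=(n_dec, d))         # stand-in decoder queries
scores = Q_dec @ memory.T / np.sqrt(d)
scores = scores - scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights = weights / weights.sum(axis=-1, keepdims=True)   # (100, 1500)

# Total attention mass each decoder token assigns to each passage.
per_doc = weights.reshape(n_dec, n_docs, doc_len).sum(axis=-1)   # (100, 3)
print(weights.shape, per_doc.shape)  # (100, 1500) (100, 3)
```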
Connection to Modern Systems
Multi-Head Cross-Attention
In practice, cross-attention almost always uses multiple heads (Chapter 2). Each head can learn to attend to different aspects of the encoder output:
Head 1 might learn syntactic alignment (matching subject to subject)
Head 2 might learn semantic alignment (matching meaning-related tokens)
Head 3 might learn positional alignment (matching position roughly)
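The multi-head formula becomes: $\mathrm{MultiHead}(Q_{dec}, K_{enc}, V_{enc}) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_H)\, W^O$, where each $\mathrm{head}_h = \mathrm{CrossAttn}(Q_{dec} W_h^Q, K_{enc} W_h^K, V_{enc} W_h^V)$.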
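A compact NumPy sketch of multi-head cross-attention follows. It assumes the common implementation trick of storing the per-head projections as column blocks of one large matrix each (`Wq`, `Wk`, `Wv`), which is equivalent to separate per-head matrices; all weights here are random placeholders and the function name is hypothetical.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_cross_attention(X_dec, X_enc, Wq, Wk, Wv, Wo, n_heads):
    n_dec, d_model = X_dec.shape
    n_enc = X_enc.shape[0]
    d_head = d_model // n_heads
    # Project, then split the feature axis into heads: (n_heads, N, d_head)
    Q = (X_dec @ Wq).reshape(n_dec, n_heads, d_head).transpose(1, 0, 2)
    K = (X_enc @ Wk).reshape(n_enc, n_heads, d_head).transpose(1, 0, 2)
    V = (X_enc @ Wv).reshape(n_enc, n_heads, d_head).transpose(1, 0, 2)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (n_heads, n_dec, n_enc)
    weights = softmax(scores)
    heads = weights @ V                                    # (n_heads, n_dec, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n_dec, d_model)
    return concat @ Wo, weights

rng = np.random.default_rng(0)
d_model, n_heads = 8, 2
X_dec = rng.normal(size=(3, d_model))   # 3 decoder tokens
X_enc = rng.normal(size=(5, d_model))   # 5 encoder tokens
Wq, Wk, Wv, Wo = (rng.normal(size=(d_model, d_model)) * 0.3 for _ in range(4))

out, head_weights = multi_head_cross_attention(X_dec, X_enc, Wq, Wk, Wv, Wo, n_heads)
print(out.shape, head_weights.shape)  # (3, 8) (2, 3, 5)
```

Each head produces its own 3×5 attention pattern over the encoder tokens, which is what allows different heads to specialize.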
KV-Cache in Cross-Attention
Cross-attention has a special property that makes KV-caching particularly efficient: the encoder output never changes during decoding. Once the encoder has processed the source sequence, its K and V matrices are fixed for all decoder steps. This means the KV-cache for cross-attention needs to be computed only once and can be reused for every generated token.
Compare this to causal self-attention in the decoder, where the KV-cache grows by one entry per generated token. Cross-attention KV-cache is constant-size — a significant memory advantage during autoregressive generation.
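The contrast between the two caches can be sketched with a toy decoding loop. The projection matrices and decoder states are random stand-ins, and the loop ignores everything except cache bookkeeping; it is an illustration of the idea, not a full decoder.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d = 4
X_enc = rng.normal(size=(6, d))                       # encoder output, 6 tokens
Wk, Wv = rng.normal(size=(d, d)), rng.normal(size=(d, d))

# Cross-attention cache: projected once, before decoding starts.
K_enc, V_enc = X_enc @ Wk, X_enc @ Wv

self_K = np.empty((0, d))                             # self-attention cache starts empty
self_cache_lengths, cross_cache_lengths = [], []
for step in range(4):
    x_t = rng.normal(size=(1, d))                     # stand-in for the current decoder state
    self_K = np.vstack([self_K, x_t])                 # grows by one row per generated token
    context = softmax(x_t @ K_enc.T / np.sqrt(d)) @ V_enc   # reuses the fixed encoder cache
    self_cache_lengths.append(len(self_K))
    cross_cache_lengths.append(len(K_enc))

print(self_cache_lengths)   # [1, 2, 3, 4]
print(cross_cache_lengths)  # [6, 6, 6, 6]
```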
Flash Attention for Cross-Attention
Flash Attention (Chapter 13) works for cross-attention without modification. The tiling strategy operates on blocks of the attention matrix regardless of whether the matrix is square (self-attention) or rectangular (cross-attention). The memory savings from avoiding materialization of the full $N_{dec} \times N_{enc}$ attention matrix are equally valuable.
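The tiling idea can be illustrated in NumPy. This is not the Flash Attention kernel itself, only a sketch of the online-softmax accumulation it relies on: encoder keys and values are processed in blocks, and the full attention matrix is never formed. The function name and block size are assumptions for this example.

```python
import numpy as np

def tiled_cross_attention(Q, K, V, block=4):
    n_dec, d_k = Q.shape
    m = np.full((n_dec, 1), -np.inf)       # running row-wise max of the scores
    l = np.zeros((n_dec, 1))               # running softmax denominator
    acc = np.zeros((n_dec, V.shape[1]))    # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Q @ Kb.T / np.sqrt(d_k)        # scores for this encoder block only
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)     # rescales what was accumulated so far
        p = np.exp(s - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ Vb
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 4))     # 3 decoder queries
K = rng.normal(size=(10, 4))    # 10 encoder keys (not a multiple of the block size)
V = rng.normal(size=(10, 5))

# Reference: materialize the full 3x10 attention matrix.
S = Q @ K.T / np.sqrt(4)
W = np.exp(S - S.max(axis=-1, keepdims=True))
W = W / W.sum(axis=-1, keepdims=True)
reference = W @ V

tiled = tiled_cross_attention(Q, K, V, block=4)
print(np.allclose(tiled, reference))  # True
```

Nothing in the loop depends on the number of queries matching the number of keys, which is why the same scheme covers rectangular matrices.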
Prefix Tuning and Prompt Engineering
Prefix tuning (Li & Liang, 2021) can be understood as a form of learned cross-attention: the prefix tokens serve as a “virtual encoder” whose K and V embeddings are learned parameters, while the decoder queries come from the actual input. This connection highlights how cross-attention is a general mechanism for conditioning one sequence on another.
Complexity Analysis
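| Operation | Time Complexity | Space Complexity |
|---|---|---|
| $Q_{dec} K_{enc}^\top$ | $O(N_{dec} \cdot N_{enc} \cdot d_k)$ | $O(N_{dec} \cdot N_{enc})$ |
| Softmax | $O(N_{dec} \cdot N_{enc})$ | $O(N_{dec} \cdot N_{enc})$ |
| Weights × $V_{enc}$ | $O(N_{dec} \cdot N_{enc} \cdot d_v)$ | $O(N_{dec} \cdot d_v)$ |
| Total | $O(N_{dec} \cdot N_{enc} \cdot d)$ | $O(N_{dec} \cdot N_{enc})$ |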
Note: when $N_{dec} = N_{enc} = N$, this reduces to the same $O(N^2 d)$ complexity as self-attention. But in practice, $N_{dec}$ is often much smaller than $N_{enc}$ (e.g., during autoregressive generation where $N_{dec} = 1$ for each step), making cross-attention significantly cheaper than encoder self-attention.
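A back-of-the-envelope cost model makes the comparison concrete. The function below counts only the multiply-adds of the two matrix products and ignores the softmax; the function names and the width of 64 are assumptions for illustration.

```python
def cross_attention_cost(n_dec, n_enc, d_k, d_v):
    """Approximate multiply-adds: scores (n_dec*n_enc*d_k) plus output (n_dec*n_enc*d_v)."""
    return n_dec * n_enc * d_k + n_dec * n_enc * d_v

def attention_matrix_entries(n_dec, n_enc):
    """Number of entries in the attention weight matrix."""
    return n_dec * n_enc

full = cross_attention_cost(500, 500, 64, 64)   # square case, N_dec = N_enc = 500
step = cross_attention_cost(1, 500, 64, 64)     # one autoregressive step, N_dec = 1
print(full, step, full // step)                 # 32000000 64000 500

ratio = attention_matrix_entries(100, 500) / attention_matrix_entries(500, 500)
print(ratio)                                    # 0.2
```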
Python Implementation
The full cross-attention class. Compare with Chapter 1's ScaledDotProductAttention — the only structural difference is that forward() takes three separate arguments (Q_dec, K_enc, V_enc) instead of three arguments that happen to come from the same source.
Cross-Attention: NumPy Implementation
🐍cross_attention.py
import numpy as np
NumPy provides vectorized matrix operations. Q_dec @ K_enc.T runs as optimized C code, not Python loops.
import math
Python standard library. We use math.sqrt() to precompute the scaling factor.
class CrossAttention
Cross-attention class. Unlike ScaledDotProductAttention (Chapter 1), this class expects Q from the decoder and K, V from the encoder — the three inputs come from two different sequences.
def __init__(self, d_k: int)
Constructor. Takes one parameter d_k (dimension of query/key vectors). Stores it and precomputes the scaling factor. Identical to Chapter 1 — the scaling math doesn’t change for cross-attention.
EXECUTION STATE
⬇ input: d_k = 4
self.d_k = d_k
Store d_k as an instance attribute so other methods can access it.
EXECUTION STATE
self.d_k = 4
self.scale = math.sqrt(d_k)
Precompute √d_k once. This constant divides every dot product.
EXECUTION STATE
math.sqrt(4) = 2.0
self.scale = 2.0
def _softmax(self, x) → np.ndarray
Numerically stable softmax. In cross-attention the input x may be non-square (N_dec × N_enc). Softmax is applied along the last axis (encoder dimension) so each decoder row sums to 1.0.
EXECUTION STATE
⬇ input: x = shape may be (N_dec, N_enc) — not necessarily square
⬆ returns = np.ndarray — same shape, each row sums to 1.0
x_shifted = x - np.max(x, axis=-1, keepdims=True)
Subtract row-wise max to prevent exp() overflow. All 5 rows shown for our shared example (5×5 case):
EXECUTION STATE
axis=-1 = operate along the LAST axis (encoder tokens). For a 5×5 matrix, find the max of each row independently.
keepdims=True = keep the reduced axis as size-1 dimension. Returns shape (5,1) so broadcasting x(5×5) - max(5×1) works correctly.
def compute_scores(self, Q_dec, K_enc) → np.ndarray
Function takes decoder queries Q_dec (N_dec×d_k) and encoder keys K_enc (N_enc×d_k). Returns Q_dec @ K_enc.T — a (N_dec×N_enc) matrix of cross-sequence dot products. This is the core difference from self-attention: Q and K come from different sequences.
EXECUTION STATE
⬇ input: Q_dec (5×4) =
d0 d1 d2 d3
The 1.2 0.0 1.1 0.0
cat 0.0 1.6 0.0 0.9
sat 1.2 0.8 1.1 0.0
on 0.0 0.0 1.1 0.9
mat 1.2 0.0 0.0 0.9
⬇ input: K_enc (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
return Q_dec @ K_enc.T
Matrix multiply Q_dec (5×4) with K_enc transposed (4×5). Entry (i,j) = dot product of decoder query i with encoder key j. Measures how much decoder token i wants to attend to encoder token j.
The cat sat on mat
The 0.00 2.30 1.20 1.10 1.75
cat 2.50 0.00 1.60 0.90 0.45
sat 0.80 2.30 2.00 1.10 1.75
on 0.90 1.10 0.00 2.00 1.00
mat 0.90 1.20 1.20 0.90 1.65
def scale_scores(self, scores) → np.ndarray
Divides every element by self.scale=√d_k. Identical to self-attention — scaling is independent of where Q and K originate.
EXECUTION STATE
⬇ input: scores (5×5) =
The cat sat on mat
The 0.00 2.30 1.20 1.10 1.75
cat 2.50 0.00 1.60 0.90 0.45
sat 0.80 2.30 2.00 1.10 1.75
on 0.90 1.10 0.00 2.00 1.00
mat 0.90 1.20 1.20 0.90 1.65
self.scale = 2.0 (precomputed √4 in __init__)
⬆ returns = np.ndarray (5, 5) — each element ÷ 2.0
return scores / self.scale
Every score divided by 2.0.
EXECUTION STATE
⬆ return: scores / 2.0 =
The cat sat on mat
The 0.000 1.150 0.600 0.550 0.875
cat 1.250 0.000 0.800 0.450 0.225
sat 0.400 1.150 1.000 0.550 0.875
on 0.450 0.550 0.000 1.000 0.500
mat 0.450 0.600 0.600 0.450 0.825
def compute_weights(self, scaled) → np.ndarray
Applies softmax row-wise. Each row sums to 1.0 — the decoder token distributes 100% of its attention across all encoder tokens.
EXECUTION STATE
⬇ input: scaled (5×5) =
The cat sat on mat
The 0.000 1.150 0.600 0.550 0.875
cat 1.250 0.000 0.800 0.450 0.225
sat 0.400 1.150 1.000 0.550 0.875
on 0.450 0.550 0.000 1.000 0.500
mat 0.450 0.600 0.600 0.450 0.825
⬆ returns = np.ndarray (5, 5) — each row sums to 1.0
return self._softmax(scaled)
Calls _softmax on the full matrix. Each row independently becomes a probability distribution over encoder tokens.
EXECUTION STATE
⬆ return: weights =
The cat sat on mat
The 0.0989 0.3123 0.1802 0.1714 0.2372
cat 0.3660 0.1049 0.2334 0.1645 0.1313
sat 0.1297 0.2746 0.2364 0.1507 0.2086
on 0.1809 0.1999 0.1154 0.3136 0.1902
mat 0.1731 0.2011 0.2011 0.1731 0.2518
def compute_output(self, weights, V_enc) → np.ndarray
Each output row is the weighted average of encoder value vectors. The decoder token’s output is a blend of encoder information, weighted by attention.
EXECUTION STATE
⬇ input: weights (5×5) =
The cat sat on mat
The 0.0989 0.3123 0.1802 0.1714 0.2372
cat 0.3660 0.1049 0.2334 0.1645 0.1313
sat 0.1297 0.2746 0.2364 0.1507 0.2086
on 0.1809 0.1999 0.1154 0.3136 0.1902
mat 0.1731 0.2011 0.2011 0.1731 0.2518
⬇ input: V_enc (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
⬆ returns = np.ndarray (N_dec, d_v) — weighted sum of values
return weights @ V_enc
Matrix multiply weights (5×5) with V_enc (5×4). Each output row is a blend of encoder values, weighted by how much the decoder attends to each encoder position.
EXECUTION STATE
⬆ return: weights @ V_enc =
d0 d1 d2 d3
The 0.2175 0.4309 0.2988 0.2900
cat 0.4317 0.1705 0.2990 0.2301
sat 0.2340 0.3789 0.3407 0.2550
on 0.2760 0.2950 0.2105 0.4087
mat 0.2989 0.3269 0.3269 0.2989
def forward(self, Q_dec, K_enc, V_enc)
Main entry point. Chains all four steps. Q comes from the decoder; K and V come from the encoder. This signature is the fundamental difference from self-attention’s forward(Q, K, V).
EXECUTION STATE
⬇ input: Q_dec (5×4) =
d0 d1 d2 d3
The 1.2 0.0 1.1 0.0
cat 0.0 1.6 0.0 0.9
sat 1.2 0.8 1.1 0.0
on 0.0 0.0 1.1 0.9
mat 1.2 0.0 0.0 0.9
⬇ input: K_enc (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
⬇ input: V_enc (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
⬆ returns = (weights, output) — shapes (5,5) and (5,4)
raw_scores = self.compute_scores(Q_dec, K_enc)
Calls compute_scores() → returns Q_dec @ K_enc.T. The matrix of cross-sequence dot products.
EXECUTION STATE
raw_scores =
The cat sat on mat
The 0.00 2.30 1.20 1.10 1.75
cat 2.50 0.00 1.60 0.90 0.45
sat 0.80 2.30 2.00 1.10 1.75
on 0.90 1.10 0.00 2.00 1.00
mat 0.90 1.20 1.20 0.90 1.65
scaled_scores = self.scale_scores(raw_scores)
Divides every element by 2.0.
EXECUTION STATE
scaled_scores =
The cat sat on mat
The 0.000 1.150 0.600 0.550 0.875
cat 1.250 0.000 0.800 0.450 0.225
sat 0.400 1.150 1.000 0.550 0.875
on 0.450 0.550 0.000 1.000 0.500
mat 0.450 0.600 0.600 0.450 0.825
weights = self.compute_weights(scaled_scores)
Applies softmax row-wise.
EXECUTION STATE
weights =
The cat sat on mat
The 0.0989 0.3123 0.1802 0.1714 0.2372
cat 0.3660 0.1049 0.2334 0.1645 0.1313
sat 0.1297 0.2746 0.2364 0.1507 0.2086
on 0.1809 0.1999 0.1154 0.3136 0.1902
mat 0.1731 0.2011 0.2011 0.1731 0.2518
output = self.compute_output(weights, V_enc)
Weighted sum of encoder values.
EXECUTION STATE
output =
d0 d1 d2 d3
The 0.2175 0.4309 0.2988 0.2900
cat 0.4317 0.1705 0.2990 0.2301
sat 0.2340 0.3789 0.3407 0.2550
on 0.2760 0.2950 0.2105 0.4087
mat 0.2989 0.3269 0.3269 0.2989
return weights, output
Returns both matrices. The caller gets cross-attention weights (N_dec×N_enc) and the context-enriched output (N_dec×d_v).
EXECUTION STATE
⬆ return: weights = shape (5, 5)
⬆ return: output = shape (5, 4)
def explain(self, Q_dec, K_enc, V_enc, ...)
Diagnostic function. Takes decoder Q, encoder K/V, token names for both sequences, and which decoder token to trace. Recomputes all intermediates and prints a step-by-step trace.
for j, t in enumerate(enc_tokens): (scaled scores)
Same loop, now printing scaled scores (raw ÷ 2.0).
LOOP TRACE · 5 iterations
j=0, t='The'
S[The,The] = 0.00 / 2.0 = 0.0000
j=1, t='cat'
S[The,cat] = 2.30 / 2.0 = 1.1500
j=2, t='sat'
S[The,sat] = 1.20 / 2.0 = 0.6000
j=3, t='on'
S[The,on] = 1.10 / 2.0 = 0.5500
j=4, t='mat'
S[The,mat] = 1.75 / 2.0 = 0.8750
for j, t in enumerate(enc_tokens): (softmax weights)
Same loop, now printing attention weights after softmax.
LOOP TRACE · 5 iterations
j=0, t='The'
A[The,The] = 0.0989 |###|
j=1, t='cat'
A[The,cat] = 0.3123 |############|
j=2, t='sat'
A[The,sat] = 0.1802 |#######|
j=3, t='on'
A[The,on] = 0.1714 |######|
j=4, t='mat'
A[The,mat] = 0.2372 |#########|
print O[dec_tok] and sum of weights
Print the final output vector and verify weights sum to 1.0.
EXECUTION STATE
O[The] = [0.2175, 0.4309, 0.2988, 0.2900]
sum of weights[0] = 1.000000 ✓
tokens = [...]
The 5 tokens used in every chapter.
EXECUTION STATE
tokens = ['The', 'cat', 'sat', 'on', 'mat']
Q = np.array([...])
Encoder-side query matrix (same as Chapter 1). This serves as the base from which we derive decoder queries.
EXECUTION STATE
Q =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
K = np.array([...])
Encoder key matrix — unchanged. In cross-attention, K always comes from the encoder.
EXECUTION STATE
K =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
V = np.array([...])
Encoder value matrix — unchanged. In cross-attention, V always comes from the encoder.
EXECUTION STATE
V =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
decoder_scale = np.array([1.2, 0.8, 1.1, 0.9])
Per-dimension scaling factors that simulate a decoder’s learned W^Q projection. In a real transformer, the decoder has its own W^Q matrix that transforms decoder embeddings into queries. Here we approximate this with element-wise scaling to keep the example tractable.
EXECUTION STATE
decoder_scale = [1.2, 0.8, 1.1, 0.9]
Q_dec = Q * decoder_scale
Element-wise multiplication of Q (5×4) with decoder_scale (4,). NumPy broadcasts the 1D scale across all rows. Each token’s query is subtly shifted.
EXECUTION STATE
Q_dec =
d0 d1 d2 d3
The 1.2 0.0 1.1 0.0
cat 0.0 1.6 0.0 0.9
sat 1.2 0.8 1.1 0.0
on 0.0 0.0 1.1 0.9
mat 1.2 0.0 0.0 0.9
attn = CrossAttention(d_k=4)
Instantiate the class. Sets self.d_k=4 and self.scale=2.0.
EXECUTION STATE
attn.d_k = 4
attn.scale = 2.0
weights, output = attn.forward(Q_dec, K, V)
Runs the full cross-attention pipeline. Q_dec is the decoder query; K and V are the encoder’s key and value matrices.
EXECUTION STATE
weights =
The cat sat on mat
The 0.0989 0.3123 0.1802 0.1714 0.2372
cat 0.3660 0.1049 0.2334 0.1645 0.1313
sat 0.1297 0.2746 0.2364 0.1507 0.2086
on 0.1809 0.1999 0.1154 0.3136 0.1902
mat 0.1731 0.2011 0.2011 0.1731 0.2518
output =
d0 d1 d2 d3
The 0.2175 0.4309 0.2988 0.2900
cat 0.4317 0.1705 0.2990 0.2301
sat 0.2340 0.3789 0.3407 0.2550
on 0.2760 0.2950 0.2105 0.4087
mat 0.2989 0.3269 0.3269 0.2989
1import numpy as np
2import math
34classCrossAttention:5"""
6 Cross-Attention (Vaswani et al., 2017)
78 Computes: CrossAttn(Q_dec, K_enc, V_enc)
9 = softmax(Q_dec @ K_enc^T / sqrt(d_k)) @ V_enc
1011 Unlike self-attention where Q, K, V come from the same sequence,
12 cross-attention takes Q from the decoder and K, V from the encoder.
13 This lets the decoder 'read' from the encoder at every generation step.
14 """1516def__init__(self, d_k:int):17"""
18 Args:
19 d_k: Dimension of query/key vectors (used for scaling)
20 """21 self.d_k = d_k
22 self.scale = math.sqrt(d_k)2324def_softmax(self, x: np.ndarray)-> np.ndarray:25"""Numerically stable softmax along last axis."""26 x_shifted = x - np.max(x, axis=-1, keepdims=True)27 exp_x = np.exp(x_shifted)28return exp_x / np.sum(exp_x, axis=-1, keepdims=True)2930defcompute_scores(self, Q_dec: np.ndarray,31 K_enc: np.ndarray)-> np.ndarray:32"""Step 1: Cross-sequence dot products Q_dec @ K_enc^T."""33return Q_dec @ K_enc.T
3435defscale_scores(self, scores: np.ndarray)-> np.ndarray:36"""Step 2: Divide by sqrt(d_k) to control variance."""37return scores / self.scale
3839defcompute_weights(self, scaled: np.ndarray)-> np.ndarray:40"""Step 3: Apply softmax to get attention weights."""41return self._softmax(scaled)4243defcompute_output(self, weights: np.ndarray,44 V_enc: np.ndarray)-> np.ndarray:45"""Step 4: Weighted sum of encoder value vectors."""46return weights @ V_enc
4748defforward(self, Q_dec: np.ndarray,49 K_enc: np.ndarray, V_enc: np.ndarray):50"""
51 Full forward pass.
5253 Args:
54 Q_dec: Decoder query matrix (N_dec, d_k)
55 K_enc: Encoder key matrix (N_enc, d_k)
56 V_enc: Encoder value matrix (N_enc, d_v)
5758 Returns:
59 weights: Attention weight matrix (N_dec, N_enc)
60 output: Context-enriched output (N_dec, d_v)
61 """62 raw_scores = self.compute_scores(Q_dec, K_enc)63 scaled_scores = self.scale_scores(raw_scores)64 weights = self.compute_weights(scaled_scores)65 output = self.compute_output(weights, V_enc)66return weights, output
    def explain(self, Q_dec: np.ndarray, K_enc: np.ndarray,
                V_enc: np.ndarray, dec_tokens: list,
                enc_tokens: list, query_idx: int = 0):
        """
        Print a detailed trace of cross-attention
        for a specific decoder query token.
        """
        raw_scores = self.compute_scores(Q_dec, K_enc)
        scaled_scores = self.scale_scores(raw_scores)
        weights = self.compute_weights(scaled_scores)
        output = self.compute_output(weights, V_enc)

        dec_tok = dec_tokens[query_idx]
        print(f"\n=== Cross-Attention trace: "
              f"'{dec_tok}' (decoder row {query_idx}) ===")
        print(f"Q_dec[{query_idx}] = {Q_dec[query_idx]}")
        print(f"\nStep 1 - Raw cross-dot-products:")
        for j, t in enumerate(enc_tokens):
            print(f"  Q_dec[{dec_tok}] . K_enc[{t}]"
                  f" = {raw_scores[query_idx, j]:.4f}")

        print(f"\nStep 2 - Scaled (/ sqrt({self.d_k})"
              f" = / {self.scale:.1f}):")
        for j, t in enumerate(enc_tokens):
            print(f"  S[{dec_tok},{t}]"
                  f" = {scaled_scores[query_idx, j]:.4f}")

        print(f"\nStep 3 - Softmax weights:")
        for j, t in enumerate(enc_tokens):
            bar = '#' * int(weights[query_idx, j] * 40)
            print(f"  A[{dec_tok},{t}]"
                  f" = {weights[query_idx, j]:.4f} |{bar}|")

        print(f"\nStep 4 - Output (weighted sum of V_enc):")
        print(f"  O[{dec_tok}] = {output[query_idx]}")
        print(f"  Sum of weights"
              f" = {weights[query_idx].sum():.6f}")


# ── Shared Example (used in every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],  # The
    [0.0, 2.0, 0.0, 1.0],  # cat
    [1.0, 1.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.0, 1.0],  # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],  # The
    [0.0, 1.0, 0.0, 0.0],  # cat
    [0.0, 0.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 0.0, 1.0],  # on
    [0.5, 0.5, 0.5, 0.5],  # mat
])

# ── Simulate decoder queries ──
# In a real transformer, Q_dec = X_dec @ W^Q_dec
# We simulate this with per-dimension scaling
decoder_scale = np.array([1.2, 0.8, 1.1, 0.9])
Q_dec = Q * decoder_scale

print("Decoder Q (Q * [1.2, 0.8, 1.1, 0.9]):")
print(np.round(Q_dec, 1))

# ── Run ──
attn = CrossAttention(d_k=4)
weights, output = attn.forward(Q_dec, K, V)

print("\nCross-Attention Weight Matrix (5x5):")
print(np.round(weights, 4))

print("\nOutput Matrix (5x4):")
print(np.round(output, 4))

# Detailed trace for "The" (decoder token 0)
attn.explain(Q_dec, K, V, tokens, tokens, query_idx=0)
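As a quick cross-check on the trace for "The", the same row can be reproduced in a few lines of standalone NumPy. This is a minimal sketch that inlines the shared K and V and the scaled decoder query instead of using the CrossAttention class; the variable names (q_the, raw, scaled, out) are illustrative only.

```python
import numpy as np

# Encoder keys and values from the shared example
K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])
V = np.array([
    [1.0, 0.0, 0.0, 0.0],  # The
    [0.0, 1.0, 0.0, 0.0],  # cat
    [0.0, 0.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 0.0, 1.0],  # on
    [0.5, 0.5, 0.5, 0.5],  # mat
])

# Decoder query for "The": Q[0] * decoder_scale = [1.2, 0.0, 1.1, 0.0]
q_the = np.array([1.0, 0.0, 1.0, 0.0]) * np.array([1.2, 0.8, 1.1, 0.9])

raw = K @ q_the                      # one dot product per encoder token
scaled = raw / np.sqrt(K.shape[1])   # divide by sqrt(d_k) = 2
exp = np.exp(scaled - scaled.max())  # subtract the max for numerical stability
weights = exp / exp.sum()            # softmax over the 5 encoder positions
out = weights @ V                    # weighted sum of encoder value vectors

print("raw    :", np.round(raw, 2))      # about 0, 2.3, 1.2, 1.1, 1.75
print("weights:", np.round(weights, 4))  # about 0.0989, 0.3123, 0.1802, 0.1714, 0.2372
print("output :", np.round(out, 4))      # about 0.2175, 0.4309, 0.2988, 0.2900
```

The largest weight falls on the encoder token "cat", whose key [1, 0, 1, 0] lines up with the two non-zero dimensions of the decoder query.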
PyTorch Implementation
The PyTorch version adds GPU support, automatic differentiation, batched inputs, and an optional mask. Note the different-length demo at the bottom showing a 3-token decoder attending to a 5-token encoder.
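A minimal sketch of a module matching this description is shown here. It assumes the class is named CrossAttention as in the NumPy version, that forward returns (weights, output), and that the optional mask is boolean with True marking encoder positions to block. These are assumptions for illustration and not necessarily the choices made in the original listing; the random 3-token decoder input is likewise a stand-in.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossAttention(nn.Module):
    """Scaled dot-product cross-attention without learned projections."""

    def __init__(self, d_k: int):
        super().__init__()
        self.scale = math.sqrt(d_k)

    def forward(self, Q_dec, K_enc, V_enc, mask=None):
        # (..., N_dec, d_k) @ (..., d_k, N_enc) -> (..., N_dec, N_enc)
        scores = torch.matmul(Q_dec, K_enc.transpose(-2, -1)) / self.scale
        if mask is not None:
            # Assumed convention: True = encoder position is blocked (padding)
            scores = scores.masked_fill(mask, float("-inf"))
        weights = F.softmax(scores, dim=-1)    # normalize over encoder positions
        output = torch.matmul(weights, V_enc)  # (..., N_dec, d_v)
        return weights, output


# Shared example: encoder-side K and V, decoder queries Q * decoder_scale
Q = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                  [0.0, 2.0, 0.0, 1.0],
                  [1.0, 1.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0, 1.0],
                  [1.0, 0.0, 0.0, 1.0]])
K = torch.tensor([[0.0, 1.0, 0.0, 1.0],
                  [1.0, 0.0, 1.0, 0.0],
                  [1.0, 1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 1.0],
                  [1.0, 0.0, 0.5, 0.5]])
V = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0],
                  [0.0, 0.0, 0.0, 1.0],
                  [0.5, 0.5, 0.5, 0.5]])
Q_dec = Q * torch.tensor([1.2, 0.8, 1.1, 0.9])

attn = CrossAttention(d_k=4)
weights, output = attn(Q_dec, K, V)
print(weights.shape, output.shape)    # 5x5 weights, 5x4 output

# Different-length demo: 3 (random, illustrative) decoder tokens, 5 encoder tokens
torch.manual_seed(0)
Q_short = torch.randn(3, 4)
w_short, o_short = attn(Q_short, K, V)
print(w_short.shape, o_short.shape)   # 3x5 weights, 3x4 output
```

Because the module only uses matmul, transpose(-2, -1), and a softmax over the last dimension, the same forward pass works unchanged for inputs with leading batch and head dimensions.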
Cross-Attention: PyTorch Implementation
🐍cross_attention_torch.py
Line 1: Import PyTorch
torch is the core tensor library. torch.nn provides nn.Module for trainable components. torch.nn.functional provides stateless operations like softmax. math is used for sqrt.
Line 6: nn.Module subclass
By subclassing nn.Module, cross-attention plugs into PyTorch's device placement (.cuda(), .to(device)) and training loops, and composes with other modules; gradients flow through it via autograd. In a real transformer decoder layer, this module is instantiated alongside a self-attention module.
Line 19: def __init__(self, d_k: int)
Constructor. Same as self-attention: the scaling factor doesn’t change for cross-attention.
EXECUTION STATE
input: d_k = 4
self.scale = √4 = 2.0
Line 23: def forward(self, Q_dec, K_enc, V_enc, mask)
The key difference from Chapter 1: Q_dec comes from the decoder, K_enc and V_enc from the encoder. The attention matrix shape is (N_dec, N_enc), which is not necessarily square when the two sequences have different lengths.
torch.matmul(Q_dec, K_enc.transpose(-2, -1)) computes all pairwise scores between decoder queries and encoder keys. The .transpose(-2, -1) works for any number of batch dims.
EXECUTION STATE
K_enc.transpose(-2, -1) = swaps last two dims: (5,4) → (4,5). For 4D: only the inner (N,d_k) is transposed.
scores =
The cat sat on mat
The 0.00 2.30 1.20 1.10 1.75
cat 2.50 0.00 1.60 0.90 0.45
sat 0.80 2.30 2.00 1.10 1.75
on 0.90 1.10 0.00 2.00 1.00
mat 0.90 1.20 1.20 0.90 1.65
Line 48: Step 2: Scale
Divide by √d_k = 2.0. Same as self-attention.
EXECUTION STATE
scores (after scaling) =
The cat sat on mat
The 0.000 1.150 0.600 0.550 0.875
cat 1.250 0.000 0.800 0.450 0.225
sat 0.400 1.150 1.000 0.550 0.875
on 0.450 0.550 0.000 1.000 0.500
mat 0.450 0.600 0.600 0.450 0.825
Line 51: Step 3: Masking
In cross-attention, masking is used for padding (when encoder sequences in a batch have different lengths). Unlike causal self-attention, cross-attention typically does NOT use causal masks — the decoder should be able to attend to all encoder positions.
EXECUTION STATE
mask = None → skipped
typical use = padding mask: True where encoder tokens are padding
Line 55: Step 4: F.softmax(scores, dim=-1)
Softmax along dim=-1 (encoder dimension). Each decoder token distributes attention across all encoder tokens. The weights sum to 1.0 per decoder row.
EXECUTION STATE
dim=-1 = normalize along encoder positions. For (N_dec, N_enc), this normalizes across the N_enc columns.
weights =
The cat sat on mat
The 0.0989 0.3123 0.1802 0.1714 0.2372
cat 0.3660 0.1049 0.2334 0.1645 0.1313
sat 0.1297 0.2746 0.2364 0.1507 0.2086
on 0.1809 0.1999 0.1154 0.3136 0.1902
mat 0.1731 0.2011 0.2011 0.1731 0.2518
Line 58: Step 5: Weighted sum of V_enc
matmul(weights, V_enc): (N_dec, N_enc) @ (N_enc, d_v) → (N_dec, d_v). Each decoder output row is a blend of encoder value vectors.
EXECUTION STATE
output =
d0 d1 d2 d3
The 0.2175 0.4309 0.2988 0.2900
cat 0.4317 0.1705 0.2990 0.2301
sat 0.2340 0.3789 0.3407 0.2550
on 0.2760 0.2950 0.2105 0.4087
mat 0.2989 0.3269 0.3269 0.2989
Line 65: Shared example tensors
Same Q, K, V as the NumPy version, converted to torch.tensor. These are the encoder-side matrices used across all chapters.
EXECUTION STATE
Q.shape = torch.Size([5, 4])
K.shape = torch.Size([5, 4])
V.shape = torch.Size([5, 4])
Line 89: Q_dec = Q * decoder_scale
Element-wise multiplication with the decoder scaling factors. PyTorch broadcasts (4,) across all rows of (5,4).
EXECUTION STATE
decoder_scale = tensor([1.2, 0.8, 1.1, 0.9])
Q_dec =
d0 d1 d2 d3
The 1.2 0.0 1.1 0.0
cat 0.0 1.6 0.0 0.9
sat 1.2 0.8 1.1 0.0
on 0.0 0.0 1.1 0.9
mat 1.2 0.0 0.0 0.9
Line 92: attn(Q_dec, K, V)
Call the module. Note the asymmetry: Q_dec is the decoder query, K and V are encoder outputs. This is the signature difference from self-attention where all three come from the same source.
EXECUTION STATE
weights shape = torch.Size([5, 5])
output shape = torch.Size([5, 4])
Line 99: Different-length demo
The true power of cross-attention: the decoder and encoder can have different sequence lengths. Here 3 decoder tokens attend to 5 encoder tokens, giving a 3×5 attention matrix.
PyTorch's built-in F.scaled_dot_product_attention supports cross-attention natively. It does not require Q, K, and V to come from the same source, so you simply pass decoder queries together with encoder keys and values:
🐍builtin_cross_attention.py
import torch
import torch.nn.functional as F

# Q from decoder, K and V from encoder
Q_dec = torch.randn(1, 8, 100, 64)  # (batch, heads, N_dec, d_k)
K_enc = torch.randn(1, 8, 500, 64)  # (batch, heads, N_enc, d_k)
V_enc = torch.randn(1, 8, 500, 64)  # (batch, heads, N_enc, d_v)

# Cross-attention: Q and K/V have different sequence lengths
output = F.scaled_dot_product_attention(Q_dec, K_enc, V_enc)
# output.shape: (1, 8, 100, 64), i.e. 100 decoder tokens, each 64-dim

# With padding mask (mask encoder padding tokens)
# mask shape: (1, 1, 100, 500), broadcast over batch and heads
padding_mask = torch.zeros(1, 1, 100, 500, dtype=torch.bool)
padding_mask[:, :, :, 450:] = True  # last 50 encoder tokens are padding

# For a boolean attn_mask, True means "attend". That is the opposite of the
# padding_mask convention above (True = padding), so the mask is inverted with ~.
output = F.scaled_dot_product_attention(
    Q_dec, K_enc, V_enc, attn_mask=~padding_mask
)
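To see that the built-in call performs the same computation as the explicit formula, one can compare it against a manual softmax(Q K^T / sqrt(d_k)) V on small random tensors. This is a minimal sketch, assuming PyTorch 2.0 or later (where F.scaled_dot_product_attention is available); the tensor sizes and the names manual and builtin are illustrative.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q_dec = torch.randn(2, 4, 3, 8)  # (batch, heads, N_dec, d_k)
K_enc = torch.randn(2, 4, 5, 8)  # (batch, heads, N_enc, d_k)
V_enc = torch.randn(2, 4, 5, 8)  # (batch, heads, N_enc, d_v)

# Padding mask: True where encoder tokens are padding (last 2 positions)
padding_mask = torch.zeros(1, 1, 3, 5, dtype=torch.bool)
padding_mask[..., 3:] = True

# Manual cross-attention with the padded encoder positions blocked
scores = Q_dec @ K_enc.transpose(-2, -1) / math.sqrt(Q_dec.size(-1))
scores = scores.masked_fill(padding_mask, float("-inf"))
weights = F.softmax(scores, dim=-1)
manual = weights @ V_enc

# Built-in: a boolean attn_mask uses True = attend, hence the inversion
builtin = F.scaled_dot_product_attention(
    Q_dec, K_enc, V_enc, attn_mask=~padding_mask
)

print("outputs match:", torch.allclose(manual, builtin, atol=1e-5))
print("max weight on padding:", weights[..., 3:].max().item())
```

Filling blocked scores with negative infinity before the softmax gives those positions exactly zero weight, so the padded encoder tokens cannot influence the decoder output.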
Key Takeaways
Cross-attention separates the source of Q from the source of K and V. Q comes from the decoder (the sequence being generated), while K and V come from the encoder (the source of information). This is the fundamental difference from self-attention.
The math is identical to scaled dot-product attention. $\mathrm{softmax}\!\left(\frac{Q_{dec} K_{enc}^{\top}}{\sqrt{d_k}}\right) V_{enc}$ uses the same operations as Chapter 1. The difference is purely in what the matrices represent, not how they are combined.
The attention matrix can be non-square. Unlike self-attention's $N \times N$ matrix, cross-attention produces an $N_{dec} \times N_{enc}$ matrix because the two sequences can have different lengths.
Encoder K and V are fixed during decoding. This means cross-attention's KV-cache is computed once and reused for every generated token — unlike decoder self-attention where the cache grows with each step.
Cross-attention is the universal bridge. It connects text to images (Stable Diffusion), audio to text (Whisper), source to target language (translation), and retrieved documents to generation (RAG). Anywhere two information streams need to communicate, cross-attention is the mechanism.
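The caching takeaway can be made concrete with a small decoding loop. This is a minimal NumPy sketch under stated assumptions: the projection matrices (W_q, W_k, W_v), the helper cross_attend, and the random "decoder state" at each step are hypothetical stand-ins for a trained model. The encoder keys and values are projected once before the loop and reused unchanged at every step, while each step contributes only a new query.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k, n_enc, n_steps = 8, 4, 5, 3

# Hypothetical projection weights (stand-ins for trained parameters)
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

X_enc = rng.normal(size=(n_enc, d_model))  # encoder output, fixed for the whole decode

# Computed ONCE: the cross-attention cache does not grow during decoding
K_enc = X_enc @ W_k
V_enc = X_enc @ W_v


def cross_attend(q, K, V):
    """Single-query cross-attention: q is (d_k,), K and V are (N_enc, d_k)."""
    scores = K @ q / np.sqrt(K.shape[1])
    w = np.exp(scores - scores.max())  # stable softmax
    w /= w.sum()
    return w, w @ V


cache_shapes = []
for step in range(n_steps):
    x_dec = rng.normal(size=d_model)   # toy stand-in for the current decoder state
    q = x_dec @ W_q                    # only the query is new at each step
    w, context = cross_attend(q, K_enc, V_enc)
    cache_shapes.append(K_enc.shape)   # remains (n_enc, d_k) at every step
    print(step, np.round(w, 3), K_enc.shape)
```

In decoder self-attention, by contrast, each generated token appends a new key and value row, so that cache would grow by one row per step.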
Exercises
Exercise 1: Identical Q
What happens if the decoder queries are exactly the Chapter 1 self-attention queries (i.e., decoder_scale = [1, 1, 1, 1], so Q_dec = Q)? Run the code and compare the output to Chapter 1. Explain why the results are identical.
Exercise 2: Different Lengths
Modify the code to use only the first 3 rows of Q_dec as the decoder queries (simulating a 3-token decoder attending to a 5-token encoder). What is the shape of the attention weight matrix? Compute the output for the first decoder token by hand.
Exercise 3: Masked Cross-Attention
In encoder-decoder models with batched inputs, the encoder sequences may have different lengths. Shorter sequences are padded. Create a mask that blocks attention to padding positions (the last 2 encoder tokens) and recompute the attention weights. What happens to the weight distribution?
Exercise 4: Image Tokens
Imagine the encoder processes a 2×2 image grid (4 patches) and the decoder generates a 3-word caption. Create custom K and V matrices of shape (4, 4) representing image patch features, and Q_dec of shape (3, 4) representing caption tokens. Run cross-attention and interpret which patches each caption word attends to most.
References
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). “Attention Is All You Need.” Advances in Neural Information Processing Systems, 30.
Bahdanau, D., Cho, K., & Bengio, Y. (2014). “Neural Machine Translation by Jointly Learning to Align and Translate.” arXiv:1409.0473.
Sutskever, I., Vinyals, O., & Le, Q.V. (2014). “Sequence to Sequence Learning with Neural Networks.” Advances in Neural Information Processing Systems, 27.
Alayrac, J.B., Donahue, J., Luc, P., Miech, A., Barr, I., Hasson, Y., et al. (2022). “Flamingo: A Visual Language Model for Few-Shot Learning.” Advances in Neural Information Processing Systems, 35.
Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). “High-Resolution Image Synthesis with Latent Diffusion Models.” CVPR 2022.
Radford, A., Kim, J.W., Xu, T., Brockman, G., McLeavey, C., & Sutskever, I. (2023). “Robust Speech Recognition via Large-Scale Weak Supervision.” ICML 2023.
Li, X.L. & Liang, P. (2021). “Prefix-Tuning: Optimizing Continuous Prompts for Generation.” ACL 2021.
Izacard, G. & Grave, E. (2021). “Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering.” EACL 2021.