Chapter 10

Linear Attention

Learning Objectives

By the end of this chapter, you will:

  1. Understand why the quadratic bottleneck of standard attention is the fundamental scalability limit of transformers, and what real problems it causes at sequence lengths beyond 4K tokens
  2. Master the kernel trick that replaces softmax with a feature map $\phi(x)$, enabling the associativity reordering that eliminates the $N \times N$ matrix
  3. Derive and compute the full linear attention formula step by step using our shared example "The cat sat on the mat", tracing every matrix multiplication and normalization
  4. Understand why linear attention produces flatter weight distributions than softmax attention, and what this means for model quality in practice
  5. Implement both the non-causal and causal (recurrent) forms in Python/NumPy and PyTorch, understanding how the recurrent form enables $O(1)$ per-token inference
  6. Connect linear attention to modern systems including state-space models (Mamba), RetNet, RWKV, and the linear-attention renaissance in long-context architectures

The Real Problem

The Quadratic Wall

Every attention mechanism we have discussed so far (scaled dot-product, multi-head, causal, cross-attention, MQA, GQA, and all positional encoding variants) shares a fundamental bottleneck: each must compute and store the full $N \times N$ attention score matrix. The cost is $O(N^2 \cdot d)$ in time and $O(N^2)$ in memory, where $N$ is the sequence length and $d$ is the head dimension.

To grasp why this matters, consider the concrete numbers:

| Sequence Length (N) | Attention Matrix Size | Memory (FP32) | FLOPs (d=64) |
|---|---|---|---|
| 512 | 262K entries | 1 MB | 16.8M |
| 2,048 | 4.2M entries | 16 MB | 268M |
| 4,096 | 16.8M entries | 64 MB | 1.1B |
| 32,768 | 1.07B entries | 4 GB | 68.7B |
| 100,000 | 10B entries | 40 GB | 640B |
| 1,000,000 | 1T entries | 4 TB | 64T |

At $N = 4{,}096$ (a modest context for modern LLMs), the attention matrix alone consumes 64 MB per head per layer. A model with 32 heads and 40 layers that keeps every score matrix in memory at once (as naive training does for the backward pass) needs 64 MB $\times$ 32 $\times$ 40 = 80 GB just for attention scores, which is the entire memory of an 80 GB A100 GPU before any weights or activations are stored. At $N = 100{,}000$ (a single novel or legal document), the matrix has 10 billion entries. Processing a million-token document is impractical when the full quadratic attention matrix must be materialized.
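The arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch: the helper name `attention_matrix_cost` is made up for illustration, and the 32 heads and 40 layers are the illustrative model size used in the paragraph above.

```python
def attention_matrix_cost(n_tokens, bytes_per_entry=4):
    """Return (entries, MiB) for a single N x N attention score matrix (FP32 by default)."""
    entries = n_tokens * n_tokens
    mib = entries * bytes_per_entry / 2**20
    return entries, mib


entries, mib = attention_matrix_cost(4096)

# 32 heads and 40 layers, all score matrices held in memory at the same time
total_gib = mib * 32 * 40 / 1024

print(f"{entries:,} entries, {mib:.0f} MB per head per layer, {total_gib:.0f} GB total")
```

Changing the argument to `100_000` or `1_000_000` reproduces the last two rows of the table.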

The Quadratic Wall: Standard attention works beautifully for short sequences (N < 2K) but hits a hard memory and compute wall that prevents scaling to long documents, code repositories, genomic sequences, and multi-modal data. This wall motivated an entire research agenda to find sub-quadratic alternatives.

Who Solved It

In 2020, Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret published "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" at ICML 2020. Their key insight was deceptively simple: if you replace softmax with a kernel feature map $\phi(x)$, you can exploit the associativity of matrix multiplication to reorder the computation and eliminate the $N \times N$ matrix entirely.

The paper title reveals their deeper insight: with this reordering, a transformer layer becomes mathematically equivalent to an RNN, processing one token at a time with a fixed-size hidden state. This means linear attention can perform autoregressive inference in $O(1)$ time per token (constant, regardless of context length), compared to $O(N)$ for standard attention with KV-cache.


The Kernel Trick Intuition

Matrix Multiplication is Associative

The entire idea behind linear attention rests on a property you learned in linear algebra: matrix multiplication is associative. For matrices $A$, $B$, $C$ with compatible dimensions:

$$(A \cdot B) \cdot C = A \cdot (B \cdot C)$$

In standard attention, the computation order is fixed by softmax:

$$\text{output} = \underbrace{\text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right)}_{\text{must compute } N \times N \text{ first}} \cdot V$$

The softmax operates on the $N \times N$ score matrix, so you must compute $Q K^\top$ first, then apply softmax, then multiply by $V$. You cannot rearrange the parentheses because softmax is a nonlinear, row-wise operation that breaks associativity.

But what if we replace softmax with something that preserves the linear structure? Suppose we have a feature map $\phi$ such that the attention weight between query $i$ and key $j$ is simply:

$$a_{ij} = \phi(q_i)^\top \phi(k_j)$$

Then the output for query ii (before normalization) is:

$$o_i = \sum_j \left(\phi(q_i)^\top \phi(k_j)\right) v_j^\top = \phi(q_i)^\top \sum_j \phi(k_j) v_j^\top$$

Since $\phi(q_i)$ does not depend on $j$, we can pull it out of the sum. The inner sum $\sum_j \phi(k_j) v_j^\top$ is a $d \times d$ matrix that can be precomputed once for all queries. This is the associativity trick:

$$\underbrace{\phi(Q) \cdot \left(\phi(K)^\top \cdot V\right)}_{\text{Linear: } O(N d^2)} \quad \text{vs} \quad \underbrace{\left(\phi(Q) \cdot \phi(K)^\top\right) \cdot V}_{\text{Quadratic: } O(N^2 d)}$$

The key insight: By computing $\phi(K)^\top V$ first (a $d \times d$ matrix), then multiplying by $\phi(Q)$, we never form the $N \times N$ matrix. The cost drops from $O(N^2 d)$ to $O(N d^2)$. When $N \gg d$ (which is almost always true in practice), this cuts the operation count by a factor of about $N/d$.
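A quick numerical check of the reordering, as a minimal sketch: the random matrices stand in for $\phi(Q)$, $\phi(K)$ and $V$, and the sizes `N = 1000` and `d = 16` are arbitrary choices for illustration, not values from this chapter.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 1000, 16

# Strictly positive stand-ins for phi(Q) and phi(K); V can take any sign
phi_Q = rng.random((N, d)) + 0.1
phi_K = rng.random((N, d)) + 0.1
V = rng.standard_normal((N, d))

# Quadratic order: the (N x N) intermediate is built explicitly
quadratic = (phi_Q @ phi_K.T) @ V

# Linear order: only a (d x d) summary is ever built
summary = phi_K.T @ V
linear = phi_Q @ summary

print("summary shape:", summary.shape)
print("same result:", np.allclose(quadratic, linear))
```

Both orders give the same output up to floating-point rounding; only the size of the intermediate differs (1000 × 1000 versus 16 × 16 here).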

The Feature Map

The feature map $\phi$ must satisfy one critical constraint: it must produce non-negative values. This is because $\phi(q_i)^\top \phi(k_j)$ represents an unnormalized attention weight, and negative weights would mean "anti-attending" to certain tokens, which breaks the probabilistic interpretation.

Katharopoulos et al. chose $\phi(x) = \text{ELU}(x) + 1$:

$$\phi(x) = \text{ELU}(x) + 1 = \begin{cases} x + 1 & \text{if } x \geq 0 \\ e^x & \text{if } x < 0 \end{cases}$$

This maps every real number to a strictly positive value: for $x \geq 0$, the output is $x + 1 \geq 1$; for $x < 0$, the output is $e^x \in (0, 1)$. It is smooth, differentiable, and cheap to compute.

Why not just use ReLU? ReLU maps negative values to exactly zero, which would make many entries in $\phi(K)^\top V$ vanish, destroying information, and, as Katharopoulos et al. note, would also set the gradient to zero for negative inputs. ELU+1 keeps everything strictly positive, preserving a non-zero contribution from every key.
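The contrast between the two maps is easy to see numerically. A minimal sketch follows; the function names `elu_plus_one` and `relu` are chosen here for illustration, and the `np.minimum` guard is an addition to keep `np.exp` from overflowing on the branch that `np.where` throws away.

```python
import numpy as np


def elu_plus_one(x):
    """Feature map phi(x) = ELU(x) + 1: x + 1 for x >= 0, exp(x) for x < 0."""
    x = np.asarray(x, dtype=float)
    # np.where evaluates both branches, so clamp the argument of exp to <= 0
    return np.where(x >= 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def relu(x):
    """Plain ReLU, shown only for comparison."""
    return np.maximum(np.asarray(x, dtype=float), 0.0)


x = np.array([-3.0, -1.0, 0.0, 2.0])
print("ELU+1:", elu_plus_one(x))  # every entry is strictly positive
print("ReLU: ", relu(x))          # negative and zero inputs collapse to 0
```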

The Mathematical Definition

Symbol-by-Symbol Breakdown

The complete linear attention formula is:

$$O_i = \frac{\phi(Q_i) \cdot \left(\sum_{j=1}^{N} \phi(K_j)^\top V_j\right)}{\phi(Q_i) \cdot \sum_{j=1}^{N} \phi(K_j)}$$

Here $Q_i$, $K_j$, and $V_j$ are rows of the matrices, so $\phi(K_j)^\top V_j$ is an outer product and the denominator is a dot product. In matrix form, with the division applied row by row:

$$O = \frac{\phi(Q) \cdot \left(\phi(K)^\top \cdot V\right)}{\phi(Q) \cdot \left(\phi(K)^\top \mathbf{1}_{N}\right)}$$

Let us break this down symbol by symbol:

| Symbol | Shape | Meaning |
|---|---|---|
| Q, K | N × d_k | Query and key matrices (our familiar 5×4 matrices) |
| V | N × d_v | Value matrix (5×4 in our example) |
| φ(·) | element-wise | Feature map ELU(x)+1 applied to every element |
| φ(Q) | N × d_k | Transformed queries (5×4, all values > 0) |
| φ(K) | N × d_k | Transformed keys (5×4, all values > 0) |
| φ(K)ᵀ · V | d_k × d_v | THE KEY: a small d×d summary matrix (4×4!) |
| φ(Q) · (φ(K)ᵀ V) | N × d_v | Numerator: unnormalized output (5×4) |
| Σ φ(K_j) | d_k | Sum of all transformed key vectors (4-dim vector) |
| φ(Q) · Σφ(K) | N | Denominator: one normalization scalar per query (5 values) |
| O | N × d_v | Final output (5×4) = numerator / denominator |

Why Normalize?

Without normalization, the magnitude of the output would depend on the number of tokens and the magnitude of the feature-mapped keys. The denominator $Z_i = \phi(Q_i) \cdot \sum_j \phi(K_j)$ ensures that the implicit attention weights for each query sum to 1, just like softmax. This keeps the output values in a reasonable range regardless of sequence length.

Think of it this way: the numerator computes a weighted sum of values where the weights are $\phi(q_i)^\top \phi(k_j)$. The denominator divides by the total weight, converting raw scores into a proper weighted average. This is analogous to how softmax normalizes $e^{s_{ij}}$ by $\sum_j e^{s_{ij}}$.
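The claim that the denominator turns the output into a proper weighted average can be checked directly. This is a sketch on random data (the sizes and the seed are arbitrary); the $N \times N$ weight matrix is built here only so it can be inspected, which the linear form never needs to do.

```python
import numpy as np


def phi(x):
    """ELU(x) + 1, with exp guarded against overflow on the unused branch."""
    return np.where(x >= 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


rng = np.random.default_rng(1)
N, d = 6, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

Qf, Kf = phi(Q), phi(K)

# Linear form: numerator and one normalizer Z_i per query, no N x N matrix
Z = Qf @ Kf.sum(axis=0)                      # shape (N,)
out_linear = (Qf @ (Kf.T @ V)) / Z[:, None]

# Explicit form, for inspection only: implicit weights, then a weighted average
weights = (Qf @ Kf.T) / Z[:, None]           # shape (N, N)
out_explicit = weights @ V

print("row sums:", weights.sum(axis=1))
print("outputs match:", np.allclose(out_linear, out_explicit))
```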


Step-by-Step Calculation

Let us compute linear attention for our shared example "The cat sat on the mat" using the same Q, K, V matrices as every other chapter.

Step 1: Apply Feature Map $\phi$ to Q and K

Since all entries in our Q and K are $\geq 0$, the feature map simply adds 1 to every element:

| Token | Q[i] | φ(Q[i]) = Q[i] + 1 |
|---|---|---|
| The | [1.0, 0.0, 1.0, 0.0] | [2.0, 1.0, 2.0, 1.0] |
| cat | [0.0, 2.0, 0.0, 1.0] | [1.0, 3.0, 1.0, 2.0] |
| sat | [1.0, 1.0, 1.0, 0.0] | [2.0, 2.0, 2.0, 1.0] |
| on | [0.0, 0.0, 1.0, 1.0] | [1.0, 1.0, 2.0, 2.0] |
| mat | [1.0, 0.0, 0.0, 1.0] | [2.0, 1.0, 1.0, 2.0] |

| Token | K[i] | φ(K[i]) = K[i] + 1 |
|---|---|---|
| The | [0.0, 1.0, 0.0, 1.0] | [1.0, 2.0, 1.0, 2.0] |
| cat | [1.0, 0.0, 1.0, 0.0] | [2.0, 1.0, 2.0, 1.0] |
| sat | [1.0, 1.0, 0.0, 0.0] | [2.0, 2.0, 1.0, 1.0] |
| on | [0.0, 0.0, 1.0, 1.0] | [1.0, 1.0, 2.0, 2.0] |
| mat | [1.0, 0.0, 0.5, 0.5] | [2.0, 1.0, 1.5, 1.5] |

Step 2: Compute $KV = \phi(K)^\top \cdot V$ (shape $4 \times 4$, NOT $5 \times 5$!)

This is the critical computation. Instead of building a $5 \times 5$ attention matrix, we compute a $4 \times 4$ summary matrix:

$$KV[a, b] = \sum_{i=0}^{4} \phi(K[i, a]) \cdot V[i, b]$$

For row 0 of KV:

$$KV[0,:] = 1{\times}[1,0,0,0] + 2{\times}[0,1,0,0] + 2{\times}[0,0,1,0] + 1{\times}[0,0,0,1] + 2{\times}[0.5,0.5,0.5,0.5]$$

$$= [1,0,0,0] + [0,2,0,0] + [0,0,2,0] + [0,0,0,1] + [1,1,1,1] = [2.0, 3.0, 3.0, 2.0]$$

| | v₀ | v₁ | v₂ | v₃ |
|---|---|---|---|---|
| KV[0,:] | 2.00 | 3.00 | 3.00 | 2.00 |
| KV[1,:] | 2.50 | 1.50 | 2.50 | 1.50 |
| KV[2,:] | 1.75 | 2.75 | 1.75 | 2.75 |
| KV[3,:] | 2.75 | 1.75 | 1.75 | 2.75 |

This $4 \times 4$ matrix encodes all the information about how keys relate to values. At $N = 100{,}000$ with $d = 64$, this is 4,096 entries instead of 10 billion. That is the power of the associativity trick.
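Step 2 can be reproduced in NumPy with the K and V matrices of the shared example. This is a sketch that hard-codes $\phi(K) = K + 1$, which is valid only because every entry of K is non-negative here. It also checks that the matrix product equals a sum of per-token outer products, which is the view the recurrent form later in this chapter relies on.

```python
import numpy as np

# Rows: The, cat, sat, on, mat (shared example)
K = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [1, 1, 0, 0],
              [0, 0, 1, 1],
              [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0],
              [0, 1, 0, 0],
              [0, 0, 1, 0],
              [0, 0, 0, 1],
              [0.5, 0.5, 0.5, 0.5]], dtype=float)

phi_K = K + 1.0          # ELU(x) + 1 reduces to x + 1 for non-negative x
KV = phi_K.T @ V         # (4 x 5) @ (5 x 4) -> (4 x 4)

# Same matrix, built one token at a time as a sum of outer products
KV_outer = sum(np.outer(phi_K[i], V[i]) for i in range(len(K)))

print(KV)
print("outer-product sum matches:", np.allclose(KV, KV_outer))
```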

Step 3: Compute Numerator = $\phi(Q) \cdot KV$

Each query row is multiplied by the KV matrix. For "The" (row 0):

$$\text{num}[\text{The}] = [2, 1, 2, 1] \cdot KV$$

Computing each dimension:

  • dim 0: $2{\times}2.00 + 1{\times}2.50 + 2{\times}1.75 + 1{\times}2.75 = 12.75$
  • dim 1: $2{\times}3.00 + 1{\times}1.50 + 2{\times}2.75 + 1{\times}1.75 = 14.75$
  • dim 2: $2{\times}3.00 + 1{\times}2.50 + 2{\times}1.75 + 1{\times}1.75 = 13.75$
  • dim 3: $2{\times}2.00 + 1{\times}1.50 + 2{\times}2.75 + 1{\times}2.75 = 13.75$

| Token | Numerator |
|---|---|
| The | [12.7500, 14.7500, 13.7500, 13.7500] |
| cat | [16.7500, 13.7500, 15.7500, 14.7500] |
| sat | [15.2500, 16.2500, 16.2500, 15.2500] |
| on | [13.5000, 13.5000, 12.5000, 14.5000] |
| mat | [13.7500, 13.7500, 13.7500, 13.7500] |

Step 4: Compute Denominator = $\phi(Q) \cdot \sum \phi(K)$

First, sum all $\phi(K)$ rows:

$$\sum \phi(K) = [1{+}2{+}2{+}1{+}2, \; 2{+}1{+}2{+}1{+}1, \; 1{+}2{+}1{+}2{+}1.5, \; 2{+}1{+}1{+}2{+}1.5] = [8.0, 7.0, 7.5, 7.5]$$

Then dot each $\phi(Q)$ row with this sum:

| Token | φ(Q[i]) | Dot with [8, 7, 7.5, 7.5] | Denominator |
|---|---|---|---|
| The | [2, 1, 2, 1] | 2×8 + 1×7 + 2×7.5 + 1×7.5 | 45.5 |
| cat | [1, 3, 1, 2] | 1×8 + 3×7 + 1×7.5 + 2×7.5 | 51.5 |
| sat | [2, 2, 2, 1] | 2×8 + 2×7 + 2×7.5 + 1×7.5 | 52.5 |
| on | [1, 1, 2, 2] | 1×8 + 1×7 + 2×7.5 + 2×7.5 | 45.0 |
| mat | [2, 1, 1, 2] | 2×8 + 1×7 + 1×7.5 + 2×7.5 | 45.5 |

Step 5: Final Output = Numerator / Denominator

Divide each row of the numerator by its denominator:

| Token | Numerator / Denom | Output |
|---|---|---|
| The | [12.75, 14.75, 13.75, 13.75] / 45.5 | [0.2802, 0.3242, 0.3022, 0.3022] |
| cat | [16.75, 13.75, 15.75, 14.75] / 51.5 | [0.3252, 0.2670, 0.3058, 0.2864] |
| sat | [15.25, 16.25, 16.25, 15.25] / 52.5 | [0.2905, 0.3095, 0.3095, 0.2905] |
| on | [13.50, 13.50, 12.50, 14.50] / 45.0 | [0.3000, 0.3000, 0.2778, 0.3222] |
| mat | [13.75, 13.75, 13.75, 13.75] / 45.5 | [0.3022, 0.3022, 0.3022, 0.3022] |
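Steps 3 to 5 can be checked end to end with a short script. This is a sketch using the shared example's matrices, again with $\phi(x) = x + 1$ hard-coded because all entries are non-negative; it is a compact re-derivation of the tables above, not the chapter's class implementation.

```python
import numpy as np

# Rows: The, cat, sat, on, mat (shared example)
Q = np.array([[1, 0, 1, 0],
              [0, 2, 0, 1],
              [1, 1, 1, 0],
              [0, 0, 1, 1],
              [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [1, 1, 0, 0],
              [0, 0, 1, 1],
              [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0],
              [0, 1, 0, 0],
              [0, 0, 1, 0],
              [0, 0, 0, 1],
              [0.5, 0.5, 0.5, 0.5]], dtype=float)

phi_Q, phi_K = Q + 1.0, K + 1.0          # valid because all entries are >= 0

num = phi_Q @ (phi_K.T @ V)              # Step 3: (5 x 4) numerator
denom = phi_Q @ phi_K.sum(axis=0)        # Step 4: one scalar per query
out = num / denom[:, None]               # Step 5: row-wise division

print("denominators:", denom)
print(np.round(out, 4))
```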

Full Attention Weights and Output

Although linear attention never explicitly builds the $N \times N$ weight matrix, we can compute implicit attention weights to understand what the mechanism is doing. The implicit weight from query $i$ to key $j$ is:

$$w_{ij} = \frac{\phi(q_i)^\top \phi(k_j)}{\sum_l \phi(q_i)^\top \phi(k_l)}$$

Interpreting the Weights

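| Query \ Key | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1758 | 0.2198 | 0.1978 | 0.1978 | 0.2088 |
| cat | 0.2330 | 0.1748 | 0.2136 | 0.1942 | 0.1845 |
| sat | 0.1905 | 0.2095 | 0.2095 | 0.1905 | 0.2000 |
| on | 0.2000 | 0.2000 | 0.1778 | 0.2222 | 0.2000 |
| mat | 0.1978 | 0.1978 | 0.1978 | 0.1978 | 0.2088 |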

Compare this with the softmax attention weights from Chapter 1:

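| Query \ Key | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1095 | 0.2976 | 0.1805 | 0.1805 | 0.2318 |
| cat | 0.4026 | 0.0898 | 0.2442 | 0.1481 | 0.1153 |
| sat | 0.1519 | 0.2505 | 0.2505 | 0.1519 | 0.1951 |
| on | 0.1903 | 0.1903 | 0.1154 | 0.3137 | 0.1903 |
| mat | 0.1892 | 0.1892 | 0.1892 | 0.1892 | 0.2430 |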

The difference is striking:

  • Softmax attention is peaked: "cat" gives 40.3% of its attention to "The" and only 9.0% to itself. It has strong preferences.
  • Linear attention is flat: "cat" gives 23.3% to "The" and 17.5% to itself. All weights are close to the uniform value of 20%.
  • Standard deviation: Softmax row std ranges from 0.02 to 0.11. Linear row std ranges from 0.004 to 0.02. Row by row, linear attention is roughly 4-5x more uniform.

The Quality-Speed Trade-off: Linear attention is faster because its similarity $\phi(q_i)^\top \phi(k_j)$ factorizes into separate query and key terms, which the exponential inside softmax does not allow. But that same exponential is what creates the sharp, peaked distributions that allow standard attention to selectively focus on the most relevant tokens. Linear attention trades discrimination power for computational efficiency.
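Both weight tables and the per-row standard deviations can be recomputed from Q and K alone. This sketch assumes the softmax weights use the usual $1/\sqrt{d}$ scaling with $d = 4$, which reproduces the Chapter 1 values quoted above; the explicit $5 \times 5$ matrices are built only for comparison.

```python
import numpy as np

# Rows: The, cat, sat, on, mat (shared example)
Q = np.array([[1, 0, 1, 0],
              [0, 2, 0, 1],
              [1, 1, 1, 0],
              [0, 0, 1, 1],
              [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [1, 1, 0, 0],
              [0, 0, 1, 1],
              [1, 0, 0.5, 0.5]], dtype=float)
d = Q.shape[1]

# Implicit linear-attention weights with phi(x) = x + 1 (all entries are >= 0)
raw = (Q + 1.0) @ (K + 1.0).T
linear_w = raw / raw.sum(axis=1, keepdims=True)

# Scaled dot-product softmax weights (max subtracted for numerical stability)
scores = Q @ K.T / np.sqrt(d)
e = np.exp(scores - scores.max(axis=1, keepdims=True))
softmax_w = e / e.sum(axis=1, keepdims=True)

print("linear  row std:", np.round(linear_w.std(axis=1), 4))
print("softmax row std:", np.round(softmax_w.std(axis=1), 4))
```

A smaller row standard deviation means the row is closer to the uniform value of 0.2, so the printout gives a compact measure of how much flatter the linear weights are.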

Output Matrix — Linear Attention

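| Token | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.2802 | 0.3242 | 0.3022 | 0.3022 |
| cat | 0.3252 | 0.2670 | 0.3058 | 0.2864 |
| sat | 0.2905 | 0.3095 | 0.3095 | 0.2905 |
| on | 0.3000 | 0.3000 | 0.2778 | 0.3222 |
| mat | 0.3022 | 0.3022 | 0.3022 | 0.3022 |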

Note how the linear attention outputs are closer to the mean (0.30) compared to softmax outputs, which show more variance. This reflects the flatter attention distributions.


The Recurrent Form (Causal Linear Attention)

The most remarkable property of linear attention is that it can be reformulated as an RNN. This is why Katharopoulos et al. titled their paper "Transformers are RNNs." The recurrent form processes tokens one at a time using a fixed-size hidden state, enabling $O(1)$ per-token inference: constant time regardless of how many tokens have been processed.

The Recurrent Equations

Define two running states that are updated as each new token arrives:

$$S_i = S_{i-1} + \phi(k_i) \cdot v_i^\top \quad \in \mathbb{R}^{d_k \times d_v}$$

$$z_i = z_{i-1} + \phi(k_i) \quad \in \mathbb{R}^{d_k}$$

Both states start at zero ($S_0 = 0$, $z_0 = 0$). The output for token $i$ is then:

$$o_i = \frac{\phi(q_i)^\top S_i}{\phi(q_i)^\top z_i}$$

$S_i$ is a $d_k \times d_v$ matrix (4×4 in our example) that accumulates the outer products of all keys and values seen so far. $z_i$ is a $d_k$-dimensional vector that accumulates the sum of all transformed keys. Neither grows with the sequence: each token adds one outer product and one vector into fixed-size buffers, so the memory is $O(d_k \cdot d_v)$ regardless of sequence length.

| Property | Standard Attention | Linear Attention (Recurrent) |
|---|---|---|
| Per-token compute | O(N · d): attend over all N cached keys | O(d²): update S and z |
| Memory per layer | O(N · d): full KV cache | O(d²): just S and z |
| Total for N tokens | O(N² · d) | O(N · d²) |
| Context length limit | Bounded by KV cache memory | No memory-imposed limit (fixed-size state) |
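The recurrent equations translate almost line for line into a loop. The sketch below is one possible NumPy rendering (the function names are chosen here for illustration); it is checked against a masked quadratic reference that builds the lower-triangular $N \times N$ matrix explicitly, using random inputs with an arbitrary seed.

```python
import numpy as np


def phi(x):
    """ELU(x) + 1, with exp guarded against overflow on the unused branch."""
    return np.where(x >= 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def causal_linear_attention_recurrent(Q, K, V):
    """Process tokens one at a time with fixed-size states S (d_k x d_v) and z (d_k)."""
    N, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))
    z = np.zeros(d_k)
    out = np.empty((N, d_v))
    for i in range(N):
        k_i, q_i = phi(K[i]), phi(Q[i])
        S += np.outer(k_i, V[i])          # S_i = S_{i-1} + phi(k_i) v_i^T
        z += k_i                          # z_i = z_{i-1} + phi(k_i)
        out[i] = (q_i @ S) / (q_i @ z)    # o_i = phi(q_i)^T S_i / phi(q_i)^T z_i
    return out


def causal_linear_attention_masked(Q, K, V):
    """Quadratic reference: token i may only attend to tokens 0..i."""
    A = np.tril(phi(Q) @ phi(K).T)
    return (A / A.sum(axis=1, keepdims=True)) @ V


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
out_rec = causal_linear_attention_recurrent(Q, K, V)
out_ref = causal_linear_attention_masked(Q, K, V)
print("recurrent matches masked reference:", np.allclose(out_rec, out_ref))
```

The loop body touches only `S`, `z`, and the current token, so its cost does not depend on how many tokens came before. The first output row equals `V[0]`, because the first token can only attend to itself.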

Complexity Analysis

The complexity comparison between standard and linear attention depends on the relationship between $N$ (sequence length) and $d$ (head dimension):

| Metric | Standard Attention | Linear Attention | When Linear Wins |
|---|---|---|---|
| Time | O(N² · d) | O(N · d²) | N > d (almost always) |
| Memory | O(N²) | O(N · d + d²) | N > d |
| Per-token (causal) | O(N · d) | O(d²) | N > d (cost no longer depends on N) |
| Speedup factor | | N / d | e.g., 4096/64 = 64x |

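The crossover point is $N = d$. When $N < d$ (extremely rare in practice), standard attention is actually more efficient. But for typical values ($d = 64$ or $128$, $N > 512$), linear attention needs fewer operations by a factor of about $N/d$; measured wall-clock speedups are usually smaller, because constant factors and memory traffic also matter.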
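The crossover can be tabulated with leading-order operation counts. This is a rough sketch: `attention_flops` is an illustrative helper that ignores constants, the feature map, and normalization, so it says nothing about measured runtime.

```python
def attention_flops(n, d):
    """Leading-order multiply-add counts for the two multiplication orders."""
    return {"quadratic": n * n * d, "linear": n * d * d}


d = 64
for n in (32, 64, 512, 4096):
    f = attention_flops(n, d)
    ratio = f["quadratic"] / f["linear"]
    print(f"N={n:5d}  quadratic={f['quadratic']:>13,}  linear={f['linear']:>11,}  ratio={ratio:g}")
```

For $N = 32 < d$ the ratio falls below 1 (standard attention is cheaper), it equals 1 at $N = d = 64$, and it reaches 64 at $N = 4096$, matching the $N/d$ row of the table.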

Applications Across Domains

Linear attention's $O(N)$ scaling opens doors that quadratic attention cannot enter:

| Domain | Why Linear Attention Matters | Example Systems |
|---|---|---|
| Long-document NLP | Process entire books (100K+ tokens) in a single pass without chunking or hierarchical attention | LongT5, FNet (sub-quadratic, though not kernel-based linear attention) |
| Genomics | DNA sequences can be millions of bases long, so O(N²) is completely infeasible | Enformer variants, genomic foundation models |
| Time-series forecasting | Financial data, sensor streams, and weather models have extremely long sequences | Autoformer, linear-attention forecasters |
| Image generation | High-resolution images (1024×1024) have over 1M pixels; attending over them at pixel level with quadratic attention requires terabytes | Linear Transformer GAN, efficient ViT variants |
| Edge/mobile deployment | The recurrent form uses O(d²) memory per layer, feasible on devices with <1GB RAM | On-device language models |
| Streaming/real-time | Process infinite streams without growing memory. Each token is O(d²) = O(1) w.r.t. N | Real-time translation, live captioning |

Connection to Modern Systems

Linear attention was one of the earliest sub-quadratic attention mechanisms (2020), and its recurrent view is closely related to a family of modern architectures:

  • State-Space Models (Mamba, S4): Mamba (Gu & Dao, 2023) grew out of the state-space model line of work (S4) rather than out of linear attention itself, but it can be read as a selective, gated relative of linear attention. It shares the core idea of linear attention's recurrent form, a fixed-size hidden state updated once per token, and adds input-dependent selection of which information to keep or forget.
  • RetNet (Sun et al., 2023): Retention networks explicitly decompose attention into a recurrent form with exponential decay, directly building on the "Transformers are RNNs" insight. RetNet achieves the training parallelism of transformers with the inference efficiency of RNNs.
  • RWKV (Peng et al., 2023): Another linear-attention-inspired architecture that achieves RNN-like inference with transformer-like training. Uses a different feature map (WKV mechanism) but the same core principle of avoiding the N×N matrix.
  • Flash Attention (Dao et al., 2022): Takes the opposite approach — instead of changing the math (removing softmax), Flash Attention keeps exact softmax but optimizes the memory access pattern. Flash Attention is an IO-aware algorithm; linear attention is an algebraic reformulation. They solve the same problem from different angles.
  • Flash Linear Attention (Yang et al., 2024): Combines both ideas — applies Flash Attention's IO-aware tiling to the linear attention computation. Shows that the two approaches are complementary, not competing.
  • GLA (Gated Linear Attention): Adds a learnable gate to the recurrent state update, allowing the model to selectively forget old information. This addresses linear attention's tendency to accumulate noise in the state over very long sequences.
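To make the "decay" and "forgetting" ideas in this list concrete, here is a deliberately simplified sketch. It adds a single scalar decay `gamma` to the linear-attention state update; that scalar, the function name, and the value 0.9 are assumptions made for illustration. The published RetNet and GLA formulations differ in their details (for example, GLA's gates depend on the input), so this is not an implementation of either.

```python
import numpy as np


def decayed_state_update(S_prev, phi_k, v, gamma):
    """One step of S_i = gamma * S_{i-1} + phi(k_i) v_i^T (gamma: hypothetical scalar decay)."""
    return gamma * S_prev + np.outer(phi_k, v)


rng = np.random.default_rng(0)
N, d_k, d_v = 50, 4, 4
keys = rng.random((N, d_k)) + 0.1          # stand-ins for phi(k_i), strictly positive
values = rng.standard_normal((N, d_v))

S_plain = np.zeros((d_k, d_v))
S_decay = np.zeros((d_k, d_v))
for k, v in zip(keys, values):
    S_plain = decayed_state_update(S_plain, k, v, gamma=1.0)   # plain linear attention
    S_decay = decayed_state_update(S_decay, k, v, gamma=0.9)   # older tokens fade

# Weight that the very first token still carries in each final state
print("first-token weight, gamma=1.0:", 1.0 ** (N - 1))
print("first-token weight, gamma=0.9:", 0.9 ** (N - 1))
```

With `gamma = 1.0` the update reduces to the plain recurrent form, where every past token keeps full weight in the state. With `gamma < 1` the contribution of token $j$ is scaled by `gamma ** (N - 1 - j)`, so old information fades geometrically.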

The linear attention renaissance: After being overshadowed by Flash Attention (which made quadratic attention fast enough for most uses), linear attention is experiencing a resurgence in 2024-2025 through models like Mamba-2, RWKV-6, and GLA. The key insight: for very long contexts (100K+), no amount of IO optimization can overcome the fundamental $O(N^2)$ scaling.

Python Implementation

The full Python class implementation with both non-causal and causal (recurrent) forms. Each line of code is followed by an explanation and the exact values flowing through the computation.

Linear Attention: NumPy Implementation
linear_attention.py
import numpy as np

NumPy provides vectorized matrix operations. phi(K).T @ V runs as optimized C code, not Python loops.

import math

Python standard library. Not used directly in this implementation but available for scaling computations.

class LinearAttention:

Wraps the linear attention mechanism in a reusable class. The key innovation: never builds the N×N attention matrix. Instead, uses the associativity of matrix multiplication to compute the result in O(N·d²) time.

def __init__(self, d_k: int):

Constructor. Takes one parameter d_k (dimension of query/key vectors). Unlike standard attention, no scaling factor is needed because there is no softmax to saturate.

EXECUTION STATE
⬇ input: d_k = 4
self.d_k = d_k

Store d_k for use in the causal form where we initialize the running state matrix S as d_k × d_v.

EXECUTION STATE
self.d_k = 4
def _phi(self, x: np.ndarray) -> np.ndarray:

The feature map function. Takes any matrix x and applies ELU(x) + 1 element-wise. This replaces softmax in standard attention. The +1 ensures all outputs are strictly positive (required for a valid kernel).

EXECUTION STATE
⬇ input: x = any matrix (e.g., Q with shape 5×4 or K with shape 5×4)
⬆ returns = np.ndarray same shape as x, all values > 0
return np.where(x >= 0, x + 1, np.exp(x))

Element-wise: if x ≥ 0, return x+1. If x < 0, return e^x. Both branches are always positive. For our Q matrix with values [0, 1, 2], this gives [1, 2, 3]. For negative values, e^x is between 0 and 1.

EXECUTION STATE
np.where = element-wise conditional: np.where(condition, if_true, if_false)
x >= 0 branch = x + 1 → always ≥ 1 (e.g., 0→1, 1→2, 2→3)
x < 0 branch = np.exp(x) → always in (0, 1) (e.g., -1→0.368, -2→0.135)
── Example: phi(Q) ──
Q[The] = [1,0,1,0] = φ: [1+1, 0+1, 1+1, 0+1] = [2, 1, 2, 1]
Q[cat] = [0,2,0,1] = φ: [0+1, 2+1, 0+1, 1+1] = [1, 3, 1, 2]
Q[sat] = [1,1,1,0] = φ: [1+1, 1+1, 1+1, 0+1] = [2, 2, 2, 1]
Q[on] = [0,0,1,1] = φ: [0+1, 0+1, 1+1, 1+1] = [1, 1, 2, 2]
Q[mat] = [1,0,0,1] = φ: [1+1, 0+1, 0+1, 1+1] = [2, 1, 1, 2]
def compute_kv(self, K: np.ndarray, V: np.ndarray) -> np.ndarray:

THE KEY COMPUTATION. Computes phi(K)^T @ V — a small d_k × d_v matrix (4×4 in our example) that summarizes all key-value interactions. This is what makes linear attention linear: we never build the N×N matrix.

EXECUTION STATE
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      v0   v1   v2   v3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = np.ndarray (4, 4) — compact KV summary
Kf = self._phi(K)

Apply the feature map to K. Every element is transformed to be strictly positive.

EXECUTION STATE
φ(K) (5×4) =
      d0   d1   d2   d3
The  1.0  2.0  1.0  2.0
cat  2.0  1.0  2.0  1.0
sat  2.0  2.0  1.0  1.0
on   1.0  1.0  2.0  2.0
mat  2.0  1.0  1.5  1.5
return Kf.T @ V

Matrix multiply: φ(K)ᵀ (4×5) × V (5×4) = KV (4×4). This 4×4 matrix is the compact summary of ALL key-value relationships. At N=100K with d=64, this is 4,096 entries instead of 10 billion!

EXECUTION STATE
.T = NumPy transpose — φ(K)(5×4) becomes φ(K)ᵀ(4×5)
φ(K)ᵀ (4×5) =
     The   cat   sat    on   mat
d0  1.0   2.0   2.0   1.0   2.0
d1  2.0   1.0   2.0   1.0   1.0
d2  1.0   2.0   1.0   2.0   1.5
d3  2.0   1.0   1.0   2.0   1.5
⬆ return: KV (4×4) =
     v0     v1     v2     v3
d0  2.00   3.00   3.00   2.00
d1  2.50   1.50   2.50   1.50
d2  1.75   2.75   1.75   2.75
d3  2.75   1.75   1.75   2.75
def compute_key_sum(self, K: np.ndarray) -> np.ndarray:

Sum all φ(K) rows into a single d-dimensional vector. This will be the denominator’s target for the dot product with each φ(Q) row.

EXECUTION STATE
⬇ input: K (5×4) = the raw K matrix
⬆ returns = np.ndarray (4,) — column-wise sum of φ(K)
Kf = self._phi(K)

Apply feature map to K (same as in compute_kv). In a real implementation, you would compute φ(K) once and reuse it.

EXECUTION STATE
Kf = same φ(K) as above
return Kf.sum(axis=0)

Sum along axis=0 (sum each column across all 5 rows) to produce a single 4-element vector.

EXECUTION STATE
axis=0 = sum along the FIRST axis (rows). For a 5×4 matrix, this sums all 5 rows into one row of length 4.
⬆ return: K_sum = [8.0, 7.0, 7.5, 7.5]
d0: 1+2+2+1+2 = 8.0
d1: 2+1+2+1+1 = 7.0
d2: 1+2+1+2+1.5 = 7.5
d3: 2+1+1+2+1.5 = 7.5
def compute_numerator(self, Q: np.ndarray, KV: np.ndarray) -> np.ndarray:

Multiply φ(Q) (5×4) by KV (4×4) to get the unnormalized output (5×4). Each row is the weighted sum of value information for that query.

EXECUTION STATE
⬇ input: Q (5×4) = raw Q matrix
⬇ input: KV (4×4) = precomputed φ(K)ᵀ @ V
⬆ returns = np.ndarray (5, 4) — unnormalized output
Qf = self._phi(Q)

Apply feature map to Q.

EXECUTION STATE
φ(Q) (5×4) =
      d0   d1   d2   d3
The  2.0  1.0  2.0  1.0
cat  1.0  3.0  1.0  2.0
sat  2.0  2.0  2.0  1.0
on   1.0  1.0  2.0  2.0
mat  2.0  1.0  1.0  2.0
return Qf @ KV

Matrix multiply φ(Q) (5×4) × KV (4×4) = numerator (5×4). Each row i is the dot product of φ(Q[i]) with each column of KV.

EXECUTION STATE
@ = Python matrix multiplication operator
── Row 0 (The): φ(Q[The]) = [2,1,2,1] ──
dim0: 2×2.00 + 1×2.50 + 2×1.75 + 1×2.75 = 12.7500
dim1: 2×3.00 + 1×1.50 + 2×2.75 + 1×1.75 = 14.7500
dim2: 2×3.00 + 1×2.50 + 2×1.75 + 1×1.75 = 13.7500
dim3: 2×2.00 + 1×1.50 + 2×2.75 + 1×2.75 = 13.7500
⬆ return: numerator (5×4) =
        v0       v1       v2       v3
The  12.7500  14.7500  13.7500  13.7500
cat  16.7500  13.7500  15.7500  14.7500
sat  15.2500  16.2500  16.2500  15.2500
on   13.5000  13.5000  12.5000  14.5000
mat  13.7500  13.7500  13.7500  13.7500
def compute_denominator(self, Q: np.ndarray, K_sum: np.ndarray) -> np.ndarray:

Compute the normalization factor for each query. Dot product of each φ(Q[i]) with the sum of all φ(K) vectors.

EXECUTION STATE
⬇ input: Q (5×4) = raw Q matrix
⬇ input: K_sum (4,) = [8.0, 7.0, 7.5, 7.5]
⬆ returns = np.ndarray (5,) — one scalar per token
Qf = self._phi(Q)

Apply feature map to Q (same as in compute_numerator).

EXECUTION STATE
Qf = same φ(Q) as above
return Qf @ K_sum

Matrix-vector multiply: φ(Q) (5×4) × K_sum (4,) = denominator (5,). Each entry is the dot product of a query row with the summed keys.

EXECUTION STATE
── The: [2,1,2,1] · [8,7,7.5,7.5] ── = 2×8 + 1×7 + 2×7.5 + 1×7.5 = 45.5
── cat: [1,3,1,2] · [8,7,7.5,7.5] ── = 1×8 + 3×7 + 1×7.5 + 2×7.5 = 51.5
── sat: [2,2,2,1] · [8,7,7.5,7.5] ── = 2×8 + 2×7 + 2×7.5 + 1×7.5 = 52.5
── on: [1,1,2,2] · [8,7,7.5,7.5] ── = 1×8 + 1×7 + 2×7.5 + 2×7.5 = 45.0
── mat: [2,1,1,2] · [8,7,7.5,7.5] ── = 2×8 + 1×7 + 1×7.5 + 2×7.5 = 45.5
⬆ return: denom = [45.5, 51.5, 52.5, 45.0, 45.5]
def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):

Full forward pass. Computes both the output and the implicit attention weights. In practice, you would NOT compute the implicit weights (that’s O(N²)); they are shown here only for educational comparison with softmax attention.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      v0   v1   v2   v3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = (weights: 5×5, output: 5×4)
Qf = self._phi(Q)

Apply feature map to Q. All values become ≥ 1 here because every entry of Q is non-negative; in general the outputs are only guaranteed to be > 0.

EXECUTION STATE
Qf (5×4) =
      d0   d1   d2   d3
The  2.0  1.0  2.0  1.0
cat  1.0  3.0  1.0  2.0
sat  2.0  2.0  2.0  1.0
on   1.0  1.0  2.0  2.0
mat  2.0  1.0  1.0  2.0
Kf = self._phi(K)

Apply feature map to K. All values become ≥ 1 here because every entry of K is non-negative; in general the outputs are only guaranteed to be > 0.

EXECUTION STATE
Kf (5×4) =
      d0   d1   d2   d3
The  1.0  2.0  1.0  2.0
cat  2.0  1.0  2.0  1.0
sat  2.0  2.0  1.0  1.0
on   1.0  1.0  2.0  2.0
mat  2.0  1.0  1.5  1.5
KV = Kf.T @ V

The critical step: compute the 4×4 summary matrix instead of a 5×5 attention matrix.

EXECUTION STATE
KV (4×4) =
     v0     v1     v2     v3
d0  2.00   3.00   3.00   2.00
d1  2.50   1.50   2.50   1.50
d2  1.75   2.75   1.75   2.75
d3  2.75   1.75   1.75   2.75
K_sum = Kf.sum(axis=0)

Column-wise sum of φ(K) — one scalar per dimension.

EXECUTION STATE
axis=0 = sum along first axis (rows)
K_sum = [8.0, 7.0, 7.5, 7.5]
num = Qf @ KV

Numerator: φ(Q) (5×4) × KV (4×4) = (5×4).

EXECUTION STATE
num (5×4) =
        v0       v1       v2       v3
The  12.7500  14.7500  13.7500  13.7500
cat  16.7500  13.7500  15.7500  14.7500
sat  15.2500  16.2500  16.2500  15.2500
on   13.5000  13.5000  12.5000  14.5000
mat  13.7500  13.7500  13.7500  13.7500
denom = Qf @ K_sum

Denominator: dot product of each φ(Q) row with the key sum vector.

EXECUTION STATE
denom (5,) = [45.5, 51.5, 52.5, 45.0, 45.5]
denom = np.maximum(denom, 1e-6)

Clamp denominator to avoid division by zero. With ELU+1 feature map and positive inputs, this rarely activates, but is a safety net.

EXECUTION STATE
np.maximum = element-wise max — replaces any value < 1e-6 with 1e-6
denom (unchanged) = [45.5, 51.5, 52.5, 45.0, 45.5] (all >> 1e-6)
output = num / denom[:, None]

Divide each row of the numerator by its corresponding denominator. The [:, None] reshapes denom from (5,) to (5,1) for broadcasting.

EXECUTION STATE
[:, None] = reshape (5,) → (5,1) so division broadcasts: num(5×4) / denom(5×1) divides each row by its scalar
── Row 0 (The) ──
num = [12.75, 14.75, 13.75, 13.75]
denom = 45.5
output = [0.2802, 0.3242, 0.3022, 0.3022]
── Row 1 (cat) ──
num = [16.75, 13.75, 15.75, 14.75]
denom = 51.5
output = [0.3252, 0.2670, 0.3058, 0.2864]
── Row 2 (sat) ──
num = [15.25, 16.25, 16.25, 15.25]
denom = 52.5
output = [0.2905, 0.3095, 0.3095, 0.2905]
── Row 3 (on) ──
num = [13.50, 13.50, 12.50, 14.50]
denom = 45.0
output = [0.3000, 0.3000, 0.2778, 0.3222]
── Row 4 (mat) ──
num = [13.75, 13.75, 13.75, 13.75]
denom = 45.5
output = [0.3022, 0.3022, 0.3022, 0.3022]
raw_weights = Qf @ Kf.T

Compute implicit attention weights (for analysis ONLY). This is the N×N matrix we deliberately avoid in the actual computation. Each entry is φ(Q[i]) · φ(K[j]).

EXECUTION STATE
raw_weights (5×5) =
        The     cat     sat      on     mat
The    8.00  10.00   9.00   9.00   9.50
cat   12.00   9.00  11.00  10.00   9.50
sat   10.00  11.00  11.00  10.00  10.50
on     9.00   9.00   8.00  10.00   9.00
mat    9.00   9.00   9.00   9.00   9.50
weights = raw_weights / raw_weights.sum(axis=-1, keepdims=True)

Normalize each row to sum to 1.0. This is the linear attention equivalent of softmax: a simple division instead of exp+normalize.

EXECUTION STATE
axis=-1 = sum along last axis (each row independently)
keepdims=True = keep shape as (5,1) for broadcasting
weights (5×5) =
        The     cat     sat      on     mat
The  0.1758  0.2198  0.1978  0.1978  0.2088
cat  0.2330  0.1748  0.2136  0.1942  0.1845
sat  0.1905  0.2095  0.2095  0.1905  0.2000
on   0.2000  0.2000  0.1778  0.2222  0.2000
mat  0.1978  0.1978  0.1978  0.1978  0.2088
import numpy as np
import math

class LinearAttention:
    """
    Linear Attention (Katharopoulos et al., 2020)

    Replaces softmax with a kernel feature map phi(x) = ELU(x) + 1,
    enabling the associativity trick:
        Standard: (phi(Q) @ phi(K)^T) @ V   -> O(N^2 d)
        Linear:   phi(Q) @ (phi(K)^T @ V)   -> O(N d^2)

    The key insight: compute phi(K)^T @ V first (a d x d matrix),
    then multiply by phi(Q). Never builds the N x N attention matrix.
    """

    def __init__(self, d_k: int):
        """
        Args:
            d_k: Dimension of query/key vectors
        """
        self.d_k = d_k

    def _phi(self, x: np.ndarray) -> np.ndarray:
        """Feature map: ELU(x) + 1, always strictly positive."""
        return np.where(x >= 0, x + 1, np.exp(x))

    def compute_kv(self, K: np.ndarray, V: np.ndarray) -> np.ndarray:
        """Step 1: Precompute KV = phi(K)^T @ V. Shape: (d_k, d_v)."""
        Kf = self._phi(K)
        return Kf.T @ V

    def compute_key_sum(self, K: np.ndarray) -> np.ndarray:
        """Step 2: Sum of all phi(K) rows. Shape: (d_k,)."""
        Kf = self._phi(K)
        return Kf.sum(axis=0)

    def compute_numerator(self, Q: np.ndarray, KV: np.ndarray) -> np.ndarray:
        """Step 3: Numerator = phi(Q) @ KV. Shape: (N, d_v)."""
        Qf = self._phi(Q)
        return Qf @ KV

    def compute_denominator(self, Q: np.ndarray, K_sum: np.ndarray) -> np.ndarray:
        """Step 4: Denominator = phi(Q) @ sum(phi(K)). Shape: (N,)."""
        Qf = self._phi(Q)
        return Qf @ K_sum

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Full forward pass.

        Args:
            Q: Query matrix  (N, d_k)
            K: Key matrix    (N, d_k)
            V: Value matrix  (N, d_v)

        Returns:
            weights: Implicit attention weights  (N, N)
            output:  Context-enriched output     (N, d_v)
        """
        Qf = self._phi(Q)
        Kf = self._phi(K)

        KV = Kf.T @ V
        K_sum = Kf.sum(axis=0)

        num = Qf @ KV
        denom = Qf @ K_sum
        denom = np.maximum(denom, 1e-6)

        output = num / denom[:, None]

        # Implicit weights (for analysis only; never built in practice)
        raw_weights = Qf @ Kf.T
        weights = raw_weights / raw_weights.sum(axis=-1, keepdims=True)

        return weights, output

    def forward_causal(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Causal (autoregressive) form: processes one token at a time.
        Uses running state S and z, so each new token costs O(d^2).

        Args:
            Q, K, V: Same as forward()

        Returns:
            outputs: Array of shape (N, d_v), one output row per token
        """
        N = Q.shape[0]
        d_v = V.shape[1]
        Qf = self._phi(Q)
        Kf = self._phi(K)

        S = np.zeros((self.d_k, d_v))
        z = np.zeros(self.d_k)
        outputs = []

        for i in range(N):
            S += np.outer(Kf[i], V[i])
            z += Kf[i]
            num_i = Qf[i] @ S
            denom_i = max(Qf[i] @ z, 1e-6)
            outputs.append(num_i / denom_i)

        return np.array(outputs)

    def explain(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                tokens: list, query_idx: int = 0):
        """Print a detailed trace for a specific query token."""
        Qf = self._phi(Q)
        Kf = self._phi(K)

        KV = Kf.T @ V
        K_sum = Kf.sum(axis=0)

        num = Qf @ KV
        denom = Qf @ K_sum

        token = tokens[query_idx]
        print(f"\n=== Linear Attention trace for '{token}' (row {query_idx}) ===")
        print(f"Q[{query_idx}] = {Q[query_idx]}")
        print(f"phi(Q[{query_idx}]) = {Qf[query_idx]}")
        print(f"\nKV matrix (d_k x d_v):")
        print(np.round(KV, 4))
        print(f"\nK_sum = {K_sum}")
        print(f"\nNumerator = phi(Q) @ KV = {np.round(num[query_idx], 4)}")
        print(f"Denominator = phi(Q) @ K_sum = {denom[query_idx]:.4f}")
        print(f"\nOutput = num / denom = {np.round(num[query_idx] / denom[query_idx], 4)}")

        # Implicit weights
        raw = np.array([np.dot(Qf[query_idx], Kf[j]) for j in range(len(tokens))])
        w = raw / raw.sum()
        print(f"\nImplicit attention weights:")
        for j, t in enumerate(tokens):
            bar = '#' * int(w[j] * 40)
            print(f"  A[{token},{t}] = {w[j]:.4f} |{bar}|")


# ── Shared Example (used in every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

# ── Run ──
attn = LinearAttention(d_k=4)
weights, output = attn.forward(Q, K, V)

print("Implicit Attention Weights (5x5):")
print(np.round(weights, 4))

print("\nOutput Matrix (5x4):")
print(np.round(output, 4))

# Detailed trace for "The" (token 0)
attn.explain(Q, K, V, tokens, query_idx=0)

# Causal form
print("\n=== Causal (Recurrent) Form ===")
causal_output = attn.forward_causal(Q, K, V)
print("Causal Output:")
print(np.round(causal_output, 4))
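
One detail of the `_phi` method in the listing above is worth knowing: `np.where` evaluates both branches for every element, so `np.exp(x)` is also computed for large positive inputs, where it can overflow to `inf` and emit a `RuntimeWarning` even though that value is then discarded. The sketch below shows one possible way to avoid this by clamping the argument of `exp`. The name `phi_stable` is hypothetical and not part of the original code.

```python
import warnings
import numpy as np

def phi_naive(x):
    # Same expression as _phi in the listing
    return np.where(x >= 0, x + 1, np.exp(x))

def phi_stable(x):
    # Hypothetical variant: exp only ever sees values <= 0, so it cannot overflow
    return np.where(x >= 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

x = np.array([-2.0, -0.5, 0.0, 3.0, 1000.0])

# Record any warnings raised while evaluating each variant
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    naive = phi_naive(x)
    n_naive_warnings = len(caught)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    stable = phi_stable(x)
    n_stable_warnings = len(caught)

print("same values:", np.allclose(naive, stable))
print("warnings (naive, stable):", n_naive_warnings, n_stable_warnings)
```

Both variants return the same values; the difference is only in whether the discarded branch is allowed to overflow.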

PyTorch Implementation

The production-ready PyTorch implementation uses torch.einsum for clean batched operations and supports both causal and non-causal modes.

Linear Attention: PyTorch Implementation
linear_attention_torch.py
import torch

PyTorch provides GPU-accelerated tensor operations. torch.einsum gives flexible Einstein summation for batched matrix operations.

import torch.nn as nn

Neural network module. nn.Module is the base class; nn.Linear provides learnable weight matrices.

import torch.nn.functional as F

Functional API. We use F.elu() for the ELU activation function (the core of our feature map).

class LinearAttention(nn.Module)

Production-ready linear attention as a PyTorch module. Includes learnable projections W_q, W_k, W_v and supports both causal and non-causal modes.

def __init__(self, d_model, d_k, d_v)

Constructor. Creates three learnable projection matrices.

EXECUTION STATE
⬇ input: d_model = input embedding dimension (e.g., 512)
⬇ input: d_k = query/key dimension (e.g., 64)
⬇ input: d_v = value dimension (e.g., 64)
self.d_k = d_k

Store key dimension for potential scaling.

EXECUTION STATE
self.d_k = 64 (typical)
self.W_q = nn.Linear(d_model, d_k, bias=False)

Query projection: maps input embeddings from d_model to d_k dimensions. No bias for consistency with standard attention.

EXECUTION STATE
W_q shape = (d_k, d_model) = e.g. (64, 512)
self.W_k = nn.Linear(d_model, d_k, bias=False)

Key projection: same shape as W_q.

EXECUTION STATE
W_k shape = (d_k, d_model) = e.g. (64, 512)
self.W_v = nn.Linear(d_model, d_v, bias=False)

Value projection: maps to d_v dimensions. In practice d_v often equals d_k.

EXECUTION STATE
W_v shape = (d_v, d_model) = e.g. (64, 512)
def _phi(self, x) → torch.Tensor

Feature map using PyTorch’s built-in F.elu(). Identical to the NumPy version but runs on GPU.

EXECUTION STATE
⬇ input: x = any tensor (e.g., shape B×N×d_k)
⬆ returns = same shape, all values > 0
return F.elu(x) + 1

F.elu(x) returns x if x≥0, else e^x - 1. Adding 1 gives: x+1 if x≥0, e^x if x<0. Always positive.

EXECUTION STATE
F.elu(x) = x if x≥0, e^x - 1 if x<0
F.elu(x) + 1 = x+1 if x≥0, e^x if x<0 (always > 0)
def forward(self, x, causal=False) → torch.Tensor

Main entry point. Projects input x into Q, K, V, applies the feature map, then uses the associativity trick (or causal form).

EXECUTION STATE
⬇ input: x = shape (B, N, d_model) — batched input embeddings
⬇ input: causal = False = bidirectional, True = autoregressive
⬆ returns = shape (B, N, d_v) — context-enriched output
Q = self._phi(self.W_q(x))

Project input to queries and apply feature map in one step. Feature map is applied AFTER projection (the learned weights shape the feature space).

EXECUTION STATE
self.W_q(x) = shape (B, N, d_k) — raw queries
Q = φ(raw queries) = shape (B, N, d_k) — all positive
K = self._phi(self.W_k(x))

Project input to keys and apply feature map.

EXECUTION STATE
K = φ(raw keys) = shape (B, N, d_k) — all positive
V = self.W_v(x)

Project input to values. NO feature map on values — they can be any real number.

EXECUTION STATE
V = shape (B, N, d_v) — unrestricted
KV = torch.einsum('bnd,bnv->bdv', K, V)

Einstein summation: contract over n (sequence length) to get a compact d_k×d_v matrix per batch. This is the batched version of Kᵀ @ V. The n dimension is summed out, so the result is independent of sequence length!

EXECUTION STATE
einsum 'bnd,bnv->bdv' = for each batch b: KV[d,v] = Σ_n K[n,d] * V[n,v]
KV shape = (B, d_k, d_v) — e.g. (B, 64, 64) = 4096 entries regardless of N!
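
To make the contraction concrete, here is a small NumPy sketch (NumPy's `einsum` accepts the same subscript string as the PyTorch call). It checks that the einsum matches an explicit batched $K^\top V$ and that the result has the same shape for a short and a long sequence. The helper name `kv_summary` and the sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d_k, d_v = 2, 4, 3

def kv_summary(K, V):
    # Contract over the sequence axis n:
    # KV[b, d, v] = sum_n K[b, n, d] * V[b, n, v]
    return np.einsum('bnd,bnv->bdv', K, V)

shapes = {}
for N in (8, 512):
    K = rng.random((B, N, d_k))            # stand-in for phi-mapped keys
    V = rng.standard_normal((B, N, d_v))
    KV = kv_summary(K, V)
    # The same quantity written as an explicit batched K^T @ V
    KV_matmul = K.transpose(0, 2, 1) @ V
    assert np.allclose(KV, KV_matmul)
    shapes[N] = KV.shape
    print(f"N = {N:>3}: KV shape = {KV.shape}")
```

The summary has `B * d_k * d_v` entries in both runs, since the sequence axis is summed out before anything is stored.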
Z = K.sum(dim=1)

Sum all keys along the sequence dimension. This d_k-dimensional vector is used for the denominator.

EXECUTION STATE
dim=1 = sum along dimension 1 (the N/sequence dimension)
Z shape = (B, d_k)
num = torch.einsum('bnd,bdv->bnv', Q, KV)

Multiply each query with the KV summary. This is the batched version of φ(Q) @ KV.

EXECUTION STATE
einsum 'bnd,bdv->bnv' = for each batch b, token n: num[n,v] = Σ_d Q[n,d] * KV[d,v]
num shape = (B, N, d_v)
denom = torch.einsum('bnd,bd->bn', Q, Z)

Dot product of each query row with the summed keys. One scalar per token.

EXECUTION STATE
einsum 'bnd,bd->bn' = for each batch b, token n: denom[n] = Σ_d Q[n,d] * Z[d]
denom shape = (B, N)
denom = denom.clamp(min=1e-6)

Clamp minimum to prevent division by zero. .clamp() is PyTorch’s equivalent of np.maximum().

EXECUTION STATE
.clamp(min=1e-6) = replace any value < 1e-6 with 1e-6
return num / denom.unsqueeze(-1)

Divide numerator by denominator with broadcasting. unsqueeze(-1) adds a trailing dimension.

EXECUTION STATE
.unsqueeze(-1) = reshape (B,N) → (B,N,1) for broadcasting with (B,N,d_v)
⬆ return = shape (B, N, d_v) — the final output
def _causal_forward(self, Q, K, V) → torch.Tensor

Autoregressive version using cumulative sums. Each token only attends to itself and past tokens. Uses the recurrent form: S[i] = S[i-1] + φ(K[i]) ⊗ V[i].

EXECUTION STATE
⬇ input: Q, K, V = already φ-mapped Q and K, raw V. All shape (B, N, d)
⬆ returns = shape (B, N, d_v) — causal output
KV_outer = torch.einsum('bnd,bnv->bndv', K, V)

Outer product of each key-value pair. Creates a d_k×d_v matrix for each token in each batch.

EXECUTION STATE
einsum 'bnd,bnv->bndv' = for each (b,n): KV_outer[d,v] = K[d] * V[v]
KV_outer shape = (B, N, d_k, d_v)
S = KV_outer.cumsum(dim=1)

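Cumulative sum along the sequence dimension. S[i] = Σ_{j≤i} φ(K[j]) ⊗ V[j]. This is the vectorized version of the for-loop recurrence. Note that it materializes all N running states at once, so its memory grows as N × d_k × d_v per batch element, whereas the loop keeps only a single d_k × d_v state.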

EXECUTION STATE
.cumsum(dim=1) = running sum along dim 1 (sequence). S[n] = sum of KV_outer[0..n]
S shape = (B, N, d_k, d_v) — running KV state at each position
z = K.cumsum(dim=1)

Cumulative key sum. z[i] = Σ_{j≤i} φ(K[j]). Used for the denominator at each position.

EXECUTION STATE
z shape = (B, N, d_k) — running key sum at each position
num = torch.einsum('bnd,bndv->bnv', Q, S)

For each position n: dot product of φ(Q[n]) with the running state S[n].

EXECUTION STATE
num shape = (B, N, d_v)
denom = torch.einsum('bnd,bnd->bn', Q, z)

For each position n: dot product of φ(Q[n]) with the running key sum z[n].

EXECUTION STATE
denom shape = (B, N)
return num / denom.unsqueeze(-1)

Final causal output. Each token’s output is computed using only information from tokens at or before its position.

EXECUTION STATE
⬆ return = shape (B, N, d_v) — causal linear attention output
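
As a sanity check on the causal walkthrough, the NumPy sketch below computes the same causal output three ways: with cumulative sums, with the one-state recurrence, and with an explicit lower-triangular mask on the implicit weights. The inputs are random stand-ins for φ(Q), φ(K), and V (single batch element), and the variable names are ours.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_k, d_v = 6, 4, 3
Qf = rng.random((N, d_k)) + 0.1      # stand-in for phi(Q), strictly positive
Kf = rng.random((N, d_k)) + 0.1      # stand-in for phi(K), strictly positive
V = rng.standard_normal((N, d_v))

# 1) Vectorized form: all N running states are held in memory at once
KV_outer = np.einsum('nd,nv->ndv', Kf, V)     # (N, d_k, d_v)
S_all = KV_outer.cumsum(axis=0)               # S_all[i] = sum_{j<=i} Kf[j] outer V[j]
z_all = Kf.cumsum(axis=0)                     # z_all[i] = sum_{j<=i} Kf[j]
num = np.einsum('nd,ndv->nv', Qf, S_all)
denom = np.einsum('nd,nd->n', Qf, z_all)
out_cumsum = num / denom[:, None]

# 2) Recurrent form: a single (d_k, d_v) state updated token by token
S = np.zeros((d_k, d_v))
z = np.zeros(d_k)
out_loop = np.empty((N, d_v))
for i in range(N):
    S += np.outer(Kf[i], V[i])
    z += Kf[i]
    out_loop[i] = (Qf[i] @ S) / (Qf[i] @ z)

# 3) Reference: mask the implicit N x N weights so token i sees only j <= i
raw = np.tril(Qf @ Kf.T)
out_masked = (raw / raw.sum(axis=1, keepdims=True)) @ V

print("cumsum == loop:  ", np.allclose(out_cumsum, out_loop))
print("cumsum == masked:", np.allclose(out_cumsum, out_masked))
print("entries held: cumsum", S_all.size, "vs loop", S.size)
```

The three results agree; what differs is how much intermediate state each form keeps around.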
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearAttention(nn.Module):
    """
    Linear Attention in PyTorch.

    Uses ELU+1 feature map and the associativity trick
    to achieve O(N * d^2) complexity.
    """

    def __init__(self, d_model: int, d_k: int, d_v: int):
        super().__init__()
        self.d_k = d_k
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)

    def _phi(self, x: torch.Tensor) -> torch.Tensor:
        """Feature map: ELU(x) + 1, always positive."""
        return F.elu(x) + 1

    def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
        """
        Args:
            x:      Input embeddings (B, N, d_model)
            causal: If True, use autoregressive form

        Returns:
            output: (B, N, d_v)
        """
        Q = self._phi(self.W_q(x))   # (B, N, d_k)
        K = self._phi(self.W_k(x))   # (B, N, d_k)
        V = self.W_v(x)              # (B, N, d_v)

        if causal:
            return self._causal_forward(Q, K, V)

        # Non-causal: use the associativity trick
        KV = torch.einsum('bnd,bnv->bdv', K, V)    # (B, d_k, d_v)
        Z = K.sum(dim=1)                            # (B, d_k)

        num = torch.einsum('bnd,bdv->bnv', Q, KV)  # (B, N, d_v)
        denom = torch.einsum('bnd,bd->bn', Q, Z)    # (B, N)
        denom = denom.clamp(min=1e-6)

        return num / denom.unsqueeze(-1)

    def _causal_forward(self, Q: torch.Tensor, K: torch.Tensor,
                        V: torch.Tensor) -> torch.Tensor:
        """Causal (autoregressive) linear attention via cumulative sum."""
        B, N, d_k = Q.shape
        d_v = V.shape[-1]

        # Cumulative KV: S[i] = sum_{j<=i} K[j] outer V[j]
        KV_outer = torch.einsum('bnd,bnv->bndv', K, V)  # (B, N, d_k, d_v)
        S = KV_outer.cumsum(dim=1)                       # (B, N, d_k, d_v)

        # Cumulative key sum: z[i] = sum_{j<=i} K[j]
        z = K.cumsum(dim=1)                              # (B, N, d_k)

        num = torch.einsum('bnd,bndv->bnv', Q, S)       # (B, N, d_v)
        denom = torch.einsum('bnd,bnd->bn', Q, z)       # (B, N)
        denom = denom.clamp(min=1e-6)

        return num / denom.unsqueeze(-1)


# ── Run with shared example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q_raw = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])

K_raw = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])

V_raw = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

# Direct computation (bypass learned projections)
def phi(x):
    return F.elu(x) + 1

Qf = phi(Q_raw)
Kf = phi(K_raw)

# Non-causal
KV = Kf.T @ V_raw             # (4, 4)
K_sum = Kf.sum(dim=0)         # (4,)
num = Qf @ KV                 # (5, 4)
denom = Qf @ K_sum            # (5,)
output = num / denom.unsqueeze(-1)

print("Output (non-causal):")
print(output.round(decimals=4))

# Causal
S = torch.zeros(4, 4)
z = torch.zeros(4)
for i, t in enumerate(tokens):
    S += torch.outer(Kf[i], V_raw[i])
    z += Kf[i]
    num_i = Qf[i] @ S
    denom_i = (Qf[i] @ z).clamp(min=1e-6)
    print(f"{t}: {(num_i / denom_i).round(decimals=4)}")

Key Takeaways

  1. The quadratic wall is real: At $N = 100{,}000$, the attention matrix has 10 billion entries. This is the fundamental scalability limit that linear attention solves.
  2. The trick is algebraic, not approximate: Linear attention does not approximate softmax attention. It computes a different attention mechanism that happens to be computable in $O(N \cdot d^2)$ time by exploiting matrix associativity.
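  3. The feature map replaces softmax: $\phi(x) = \text{ELU}(x) + 1$ maps every value to a positive number, so each similarity $\phi(q) \cdot \phi(k)$ is positive and the normalizer cannot vanish. Because the similarity is a plain dot product of feature maps, the matrix product can be reordered by associativity.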
  4. The key computation is $\phi(K)^\top V$: This $d \times d$ matrix summarizes all key-value interactions without building the $N \times N$ attention matrix.
  5. Flatter weights are the cost: Without the exponential amplification of softmax, linear attention produces more uniform weight distributions. This reduces selectivity but enables scalability.
  6. The recurrent form enables O(1) per-token inference: By maintaining a fixed-size state $S \in \mathbb{R}^{d \times d}$, linear attention processes each new token with $O(d^2)$ work that does not grow with the sequence length, whereas standard attention with a KV-cache must still attend over all $N$ cached tokens.
  7. Linear attention spawned an entire research field: Mamba, RetNet, RWKV, and GLA all trace their lineage to the "Transformers are RNNs" insight. The principle of maintaining a compact, updateable state is the foundation of modern sub-quadratic architectures.
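
The fixed-size state in takeaway 6 can be sketched as a small streaming decoder. This is a minimal NumPy illustration, not a reference implementation: the class name `RecurrentLinearAttention` and its methods are hypothetical, the inputs are random, and the KV-cache figure simply counts one key and one value vector per past token.

```python
import numpy as np

def phi(x):
    # ELU(x) + 1; exp only sees non-positive inputs
    return np.where(x >= 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

class RecurrentLinearAttention:
    """Hypothetical helper holding the running state (S, z) for one head."""

    def __init__(self, d_k, d_v):
        self.S = np.zeros((d_k, d_v))   # running sum of phi(k) outer v
        self.z = np.zeros(d_k)          # running sum of phi(k)

    def step(self, q, k, v):
        """Consume one token's raw q, k, v and return its output vector."""
        q_f, k_f = phi(q), phi(k)
        self.S += np.outer(k_f, v)
        self.z += k_f
        return (q_f @ self.S) / max(q_f @ self.z, 1e-6)

    def state_entries(self):
        return self.S.size + self.z.size

rng = np.random.default_rng(0)
d_k, d_v, n_tokens = 4, 4, 1000
decoder = RecurrentLinearAttention(d_k, d_v)

for t in range(n_tokens):
    q = rng.standard_normal(d_k)
    k = rng.standard_normal(d_k)
    v = rng.standard_normal(d_v)
    out = decoder.step(q, k, v)
    if t == 0:
        first_v, first_out = v, out   # with one token seen, the output is just v

# A softmax KV-cache would instead hold one key and one value per past token
kv_cache_entries = n_tokens * (d_k + d_v)
print("recurrent state entries:", decoder.state_entries())
print("KV-cache entries after", n_tokens, "tokens:", kv_cache_entries)
```

The state stays at `d_k * d_v + d_k` numbers no matter how many tokens have been consumed, while the cache count grows linearly with the sequence.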

Exercises

Exercise 1: Verify the Associativity

Compute $\phi(Q) \cdot \phi(K)^\top$ (the $5 \times 5$ matrix) and then multiply by V. Verify that you get the same numerator as the $\phi(Q) \cdot (\phi(K)^\top V)$ computation above. Why is the first approach $O(N^2 d)$ and the second $O(N d^2)$? Count the multiply-add operations.

Exercise 2: Different Feature Maps

Replace $\phi(x) = \text{ELU}(x) + 1$ with $\phi(x) = \text{ReLU}(x) + \epsilon$ (where $\epsilon = 0.001$). How do the attention weights change? What happens when many key values are negative? Why did the authors prefer ELU over ReLU?

Exercise 3: Causal vs Non-Causal

Run the causal (recurrent) form for all 5 tokens and compare the output for "mat" (the last token) with the non-causal output. They should match — why? Now compare the output for "The" (the first token) between causal and non-causal. Why are they different?

Exercise 4: Scaling to Long Sequences

Generate random Q, K, V matrices with $N = 10{,}000$ and $d = 64$. Time both the quadratic computation ($\phi(Q) \cdot \phi(K)^\top \cdot V$) and the linear computation ($\phi(Q) \cdot (\phi(K)^\top \cdot V)$). Does the measured speedup match the theoretical $N/d \approx 156\times$? Why might it differ?
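
One possible starting point for this exercise is the timing harness below. It is a sketch under our own assumptions: the function names are hypothetical, the default size is kept smaller than the exercise's $N = 10{,}000$ so the $N \times N$ matrix stays cheap to allocate, and only the numerator is timed.

```python
import time
import numpy as np

def phi(x):
    # ELU(x) + 1; exp only sees non-positive inputs
    return np.where(x >= 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def quadratic_order(Qf, Kf, V):
    return (Qf @ Kf.T) @ V      # materializes the N x N matrix

def linear_order(Qf, Kf, V):
    return Qf @ (Kf.T @ V)      # materializes only a d x d matrix

def time_both(N=2000, d=64, repeats=3, seed=0):
    """Return (best quadratic time, best linear time) in seconds."""
    rng = np.random.default_rng(seed)
    Qf = phi(rng.standard_normal((N, d)))
    Kf = phi(rng.standard_normal((N, d)))
    V = rng.standard_normal((N, d))

    results = {}
    for name, fn in (("quadratic", quadratic_order), ("linear", linear_order)):
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            out = fn(Qf, Kf, V)
            best = min(best, time.perf_counter() - start)
        results[name] = (best, out)

    # Both orderings must give the same numerator up to rounding
    assert np.allclose(results["quadratic"][1], results["linear"][1])
    return results["quadratic"][0], results["linear"][0]

t_quad, t_lin = time_both()
print(f"quadratic: {t_quad:.6f} s, linear: {t_lin:.6f} s")
```

Calling `time_both(N=10_000)` reproduces the exercise's setting; note that the quadratic path then allocates a 10,000 × 10,000 float64 matrix (about 800 MB).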

Exercise 5: Memory Analysis

For a transformer with 32 heads and 40 layers processing $N = 32{,}768$ tokens with $d = 128$, calculate the total memory for: (a) Standard attention matrices, (b) Linear attention's KV matrices, (c) Linear attention's recurrent state. What is the memory reduction factor?


References

  1. Katharopoulos, A., Vyas, A., Pappas, N., & Fleuret, F. (2020). "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention." ICML 2020.
  2. Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). "Attention Is All You Need." NeurIPS 2017.
  3. Gu, A., & Dao, T. (2023). "Mamba: Linear-Time Sequence Modeling with Selective State Spaces." arXiv:2312.00752.
  4. Sun, Y., Dong, L., Huang, S., et al. (2023). "Retentive Network: A Successor to Transformer for Large Language Models." arXiv:2307.08621.
  5. Peng, B., Alcaide, E., Anthony, Q., et al. (2023). "RWKV: Reinventing RNNs for the Transformer Era." EMNLP 2023.
  6. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022.
  7. Yang, S., Wang, B., Shen, Y., Panda, R., & Kim, Y. (2024). "Gated Linear Attention Transformers with Hardware-Efficient Training." ICML 2024.