Chapter 14

Differential Attention

Learning Objectives

By the end of this section you will be able to:

  1. Explain why standard softmax attention leaks weight to irrelevant tokens and how this degrades long-context retrieval, hallucination rates, and in-context learning.
  2. Derive the differential attention formula $\bigl(\text{softmax}(Q_1 K_1^\top / \sqrt{d/2}) - \lambda \cdot \text{softmax}(Q_2 K_2^\top / \sqrt{d/2})\bigr)V$ and explain the role of every symbol.
  3. Compute the full 5×5 differential attention matrix by hand for our shared sentence “The cat sat on mat”.
  4. Implement differential attention from scratch in NumPy and PyTorch, and compare outputs against standard attention.
  5. Connect differential attention to noise-cancelling headphones, differential amplifiers, and modern systems like Flash Attention and KV-cache optimization.

The Real Problem: Attention Noise

Every attention mechanism we have studied so far shares a fundamental weakness: the softmax function cannot output exact zeros. Given a score vector $s = [s_1, s_2, \ldots, s_N]$, softmax produces $\text{softmax}(s)_j = e^{s_j} / \sum_k e^{s_k}$. Since $e^x > 0$ for all $x$, every token always receives some positive weight, no matter how irrelevant it is.

In short sequences this barely matters: the irrelevant weights are small. But at scale — 32K, 64K, 128K tokens — the noise accumulates. Consider a model trying to answer a question from a 100K-token document. The answer lives in 5 tokens, but softmax still sprinkles weight across the other 99,995. Each token contributes a tiny amount of noise to the output, and 99,995 tiny noises compound into a significant distortion.

This attention noise causes three practical failures:

  1. Information retrieval degrades: the model “forgets” specific facts buried in long contexts.
  2. Hallucination increases: when the correct answer gets diluted by noise, the model fills in plausible but wrong information.
  3. In-context learning becomes brittle: shuffling the order of few-shot examples changes which noise accumulates, causing high variance in output quality.
The core limitation: Softmax is structurally incapable of producing sparse attention. Even when the model has learned that a token is irrelevant, it cannot assign it exactly zero weight.
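
A quick numerical check makes the accumulation concrete. The sketch below is illustrative only: the helper `irrelevant_mass` and the 8-logit lead given to the five relevant tokens are assumptions chosen for the demonstration, not values taken from any model.

```python
import numpy as np

def softmax(scores: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D score vector."""
    shifted = scores - scores.max()
    weights = np.exp(shifted)
    return weights / weights.sum()

def irrelevant_mass(n_tokens: int, n_relevant: int = 5, gap: float = 8.0) -> float:
    """Total softmax weight that lands on the irrelevant tokens."""
    scores = np.zeros(n_tokens)
    scores[:n_relevant] = gap       # assumed: relevant tokens lead by `gap` logits
    weights = softmax(scores)
    assert np.all(weights > 0)      # every token still receives positive weight
    return float(weights[n_relevant:].sum())

for n in (1_000, 10_000, 100_000):
    print(f"N={n:>7}: weight on irrelevant tokens = {irrelevant_mass(n):.3f}")
```

Each irrelevant token receives only a tiny weight, but the total share of attention spent on them grows steadily as the context gets longer.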

From Engineering to AI: Differential Signaling

The solution comes from an idea that electrical engineers perfected decades ago: differential signaling.

Consider noise-cancelling headphones. They have two microphones: one inside the ear cup (signal + noise) and one outside (mostly noise). The electronics subtract the outside signal from the inside signal. Any sound that appears in both microphones — ambient noise — cancels out. Only the music that is present in the inside mic but absent from the outside mic survives.

In October 2024, Tianzhu Ye, Li Dong, and colleagues at Microsoft Research and Tsinghua University applied this exact principle to transformer attention (Ye et al., 2024). Their key insight: run the same input through two separate softmax attention maps using different projections of the query and key vectors. Both maps will capture the “noise floor” — the unavoidable weight softmax gives to irrelevant tokens. But the genuine signal will differ between the two maps because it depends on the specific projection.

When we subtract one map from the other:

  • Common-mode noise cancels: weight that both maps assign to irrelevant tokens gets subtracted away.
  • Signal survives: weight that only Map 1 assigns (and Map 2 does not) remains large after subtraction.

The result is sparser, sharper attention — without modifying the softmax function itself. The 6.8B-parameter Diff Transformer matches the performance of an 11B standard Transformer, requiring only 62.2% of the parameters (Ye et al., 2024).
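
The cancellation argument above can be sketched on a single toy attention row. All scores here are invented for illustration: both maps are assumed to share the same background scores, and only Map 1 sees a strong score for the relevant key (index 2).

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical scores for one query over six keys; key 2 is the relevant one.
background = np.array([0.2, 0.1, 0.0, 0.3, 0.1, 0.2])
bump = np.array([0.0, 0.0, 3.0, 0.0, 0.0, 0.0])

map1 = softmax(background + bump)   # signal plus noise floor
map2 = softmax(background)          # noise floor only
lam = 0.4                           # cancellation strength (assumed)

diff = map1 - lam * map2            # common-mode weight is subtracted away

irrelevant = np.arange(6) != 2
print("map1 on irrelevant keys:", np.round(map1[irrelevant], 4))
print("diff on irrelevant keys:", np.round(diff[irrelevant], 4))
print("diff on relevant key:   ", round(float(diff[2]), 4))
```

In this toy setting the irrelevant keys end up closer to zero after subtraction than they were in Map 1, while the relevant key keeps most of its weight.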


The Mathematical Definition

Given input $Q, K \in \mathbb{R}^{N \times d}$ and $V \in \mathbb{R}^{N \times d_v}$, differential attention is:

$$\text{DiffAttn}(Q, K, V) = \Bigl(\text{softmax}\!\bigl(\tfrac{Q_1 K_1^\top}{\sqrt{d/2}}\bigr) - \lambda \cdot \text{softmax}\!\bigl(\tfrac{Q_2 K_2^\top}{\sqrt{d/2}}\bigr)\Bigr) V$$

where:

  • $[Q_1 ; Q_2] = Q$: split $Q$ along the feature dimension into two halves, each $\in \mathbb{R}^{N \times d/2}$
  • $[K_1 ; K_2] = K$: same split applied to keys
  • $\lambda$: a learnable scalar that controls the cancellation strength
  • $\sqrt{d/2}$: the scaling factor. Each sub-attention operates on $d/2$ features, so scores are divided by the square root of the half dimension ($\sqrt{2}$ in our $d = 4$ example). The paper writes this factor as $\sqrt{d}$ because it uses $d$ for the per-map dimension.

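In this section's worked example and code, negative values are clamped to zero after the subtraction and each row is renormalized to sum to 1. This is a simplification for illustration: the formula above, like the original paper, applies no clamp and keeps the signed differences.

$$D_{ij} = \max\bigl(A_{1,ij} - \lambda \cdot A_{2,ij},\; 0\bigr), \quad A_{\text{diff},ij} = \frac{D_{ij}}{\sum_{k} D_{ik}}$$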

Symbol-by-Symbol Breakdown

| Symbol | Shape | What It Represents |
|---|---|---|
| Q, K | N × d | Full query and key matrices (d = 4 in our example) |
| Q₁, Q₂ | N × d/2 | First and second halves of Q (d/2 = 2) |
| K₁, K₂ | N × d/2 | First and second halves of K |
| V | N × dᵥ | Value matrix (NOT split) |
| λ | scalar | Learnable cancellation strength (0.4 in our demo) |
| A₁ | N × N | First softmax attention map |
| A₂ | N × N | Second softmax attention map |
| D | N × N | Raw differential: A₁ − λ·A₂, clamped ≥ 0 |
| A_diff | N × N | Renormalized differential attention weights |

The Lambda Parameter

The scalar $\lambda$ controls how aggressively Map 2 is subtracted from Map 1.

  • $\lambda = 0$: no cancellation; identical to standard attention on half the dimensions.
  • $\lambda = 1$: full-strength cancellation; wherever both maps assign the same weight, the difference is zero.
  • $0 < \lambda < 1$: partial cancellation. This section's demo fixes $\lambda = 0.4$; in the paper the initial value $\lambda_{\text{init}}$ depends on layer depth (see Lambda Reparameterization) and the model learns $\lambda$ during training.

Lambda Reparameterization

In the full Diff Transformer architecture, $\lambda$ is not a free scalar parameter; it is reparameterized through learnable vectors for stable training:

$$\lambda = \exp(\lambda_{q_1} \cdot \lambda_{k_1}) - \exp(\lambda_{q_2} \cdot \lambda_{k_2}) + \lambda_{\text{init}}$$

where $\lambda_{q_1}, \lambda_{k_1}, \lambda_{q_2}, \lambda_{k_2} \in \mathbb{R}^{d/2}$ are learnable vectors. The initialization depends on layer index $l \in [1, L]$:

$$\lambda_{\text{init}} = 0.8 - 0.6 \cdot \exp(-0.3 \cdot (l - 1))$$

The first layer ($l = 1$) gets $\lambda_{\text{init}} = 0.2$ (mild cancellation), while later layers get $\lambda_{\text{init}} \to 0.8$ (strong cancellation). This matches the intuition that early layers need broad context while later layers need precise retrieval.

Each head's output is then scaled: $\bar{h}_i = (1 - \lambda_{\text{init}}) \cdot \text{LN}(h_i)$, where LN is GroupNorm. The factor $(1 - \lambda_{\text{init}})$ aligns gradient magnitudes with standard transformers, ensuring training stability and enabling hyperparameter reuse.
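
The schedule and the reparameterization above are easy to tabulate. The following is a minimal sketch in plain Python; the function names `lambda_init` and `reparam_lambda` are hypothetical, and the all-zero vectors in the usage line are an assumption made only so the result is easy to check by hand.

```python
import math

def lambda_init(layer: int) -> float:
    """Layer-dependent initial value; `layer` is 1-indexed as in the text."""
    return 0.8 - 0.6 * math.exp(-0.3 * (layer - 1))

def reparam_lambda(lq1, lk1, lq2, lk2, lam_init: float) -> float:
    """lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lam_init, for plain lists."""
    dot = lambda a, b: sum(x * y for x, y in zip(a, b))
    return math.exp(dot(lq1, lk1)) - math.exp(dot(lq2, lk2)) + lam_init

for layer in (1, 2, 4, 8, 16):
    print(f"layer {layer:>2}: lambda_init = {lambda_init(layer):.4f}")

# With all four vectors at zero (assumed here), both exponentials equal 1
# and cancel, so lambda starts exactly at lambda_init.
zeros = [0.0, 0.0]
print(reparam_lambda(zeros, zeros, zeros, zeros, lambda_init(1)))
```

The printed schedule rises from 0.2 at the first layer toward 0.8 without ever reaching it.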


Step-by-Step Calculation

Let us trace the full differential attention computation for our shared sentence “The cat sat on mat” with $\lambda = 0.4$. We walk through all four steps for every token.

Step 1: Split Q and K into Two Halves

Split $Q$ (5×4) into $Q_1$ = columns [0,1] and $Q_2$ = columns [2,3].

$Q_1$:

| | d₀ | d₁ |
|---|---|---|
| The | 1.0 | 0.0 |
| cat | 0.0 | 2.0 |
| sat | 1.0 | 1.0 |
| on | 0.0 | 0.0 |
| mat | 1.0 | 0.0 |

$Q_2$:

| | d₂ | d₃ |
|---|---|---|
| The | 1.0 | 0.0 |
| cat | 0.0 | 1.0 |
| sat | 1.0 | 0.0 |
| on | 1.0 | 1.0 |
| mat | 0.0 | 1.0 |

Similarly, $K_1$ = K[:,0:2] and $K_2$ = K[:,2:4].

Step 2: Two Separate Attention Maps

Map 1: $A_1 = \text{softmax}(Q_1 K_1^\top / \sqrt{2})$

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
| cat | 0.3664 | 0.0891 | 0.3664 | 0.0891 | 0.0891 |
| sat | 0.1811 | 0.1811 | 0.3673 | 0.0893 | 0.1811 |
| on | 0.2000 | 0.2000 | 0.2000 | 0.2000 | 0.2000 |
| mat | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |

Map 2: $A_2 = \text{softmax}(Q_2 K_2^\top / \sqrt{2})$

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1337 | 0.2711 | 0.1337 | 0.2711 | 0.1904 |
| cat | 0.2711 | 0.1337 | 0.1337 | 0.2711 | 0.1904 |
| sat | 0.1337 | 0.2711 | 0.1337 | 0.2711 | 0.1904 |
| on | 0.1811 | 0.1811 | 0.0893 | 0.3673 | 0.1811 |
| mat | 0.2711 | 0.1337 | 0.1337 | 0.2711 | 0.1904 |

Notice that both maps spread weight across all tokens; neither produces any zeros. This is the fundamental softmax limitation. But they disagree on which tokens are important: for the “cat” query, $A_1$ gives 36.6% to “The” while $A_2$ gives 27.1%. The differential will amplify this disagreement.

Step 3: Differential Map

Compute $D = A_1 - 0.4 \cdot A_2$, showing the “cat” row (row 1) element by element:

| Key | A₁[cat,j] | 0.4·A₂[cat,j] | D[cat,j] | Clamped |
|---|---|---|---|---|
| The | 0.3664 | 0.1084 | +0.2579 | 0.2579 |
| cat | 0.0891 | 0.0535 | +0.0356 | 0.0356 |
| sat | 0.3664 | 0.0535 | +0.3129 | 0.3129 |
| on | 0.0891 | 0.1084 | −0.0194 | 0.0000 ← |
| mat | 0.0891 | 0.0762 | +0.0129 | 0.0129 |

The arrow marks the critical event: “on” receives a negative differential because the scaled Map 2 weight (0.1084) exceeds the Map 1 weight (0.0891). Map 1 already ranks “on” low, and Map 2, which plays the role of the noise estimate, ranks it high, so the subtraction overshoots zero and clamping removes the token entirely.

After clamping, the row sum is 0.6194. Renormalizing gives:

$$A_{\text{diff}}[\text{cat}] = [0.4164,\; 0.0575,\; 0.5052,\; 0.0000,\; 0.0209]$$

The noise cancellation in action: “sat” now receives 50.5% of cat's attention (vs 36.6% in $A_1$ and 13.4% in $A_2$). “on” drops from 8.9%/27.1% to exactly 0%. The signal was amplified; the noise was eliminated.

The same noise cancellation occurs in other rows. For “sat” (row 2), the “on” token is also clamped to zero, and “sat” receives 50.7% of its own attention.
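
To see how this row responds to other cancellation strengths, the same subtract, clamp, renormalize steps can be repeated for several values of λ. This is a small sketch using the rounded “cat” rows of Map 1 and Map 2 from Step 2; the helper name `diff_row` and the particular λ values are assumptions for illustration.

```python
import numpy as np

a1_cat = np.array([0.3664, 0.0891, 0.3664, 0.0891, 0.0891])  # A1[cat] from Step 2
a2_cat = np.array([0.2711, 0.1337, 0.1337, 0.2711, 0.1904])  # A2[cat] from Step 2

def diff_row(a1: np.ndarray, a2: np.ndarray, lam: float) -> np.ndarray:
    """Subtract lam * a2 from a1, clamp at zero, renormalize one attention row."""
    d = np.maximum(a1 - lam * a2, 0.0)
    total = d.sum()
    return d / total if total > 0 else d   # guard against an all-zero row

for lam in (0.0, 0.2, 0.4, 0.8):
    row = diff_row(a1_cat, a2_cat, lam)
    zeros = int((row == 0).sum())
    print(f"lam={lam:.1f}  row={np.round(row, 4)}  exact zeros={zeros}")
```

At λ = 0.4 the row agrees with the hand calculation above up to rounding; larger values of λ clamp more tokens to zero.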

Step 4: Output

The final output is $\text{output} = A_{\text{diff}} \cdot V$. Each row is a weighted combination of value vectors using noise-cancelled weights. For “cat”:

$$o_{\text{cat}} = 0.4164 \cdot V_{\text{The}} + 0.0575 \cdot V_{\text{cat}} + 0.5052 \cdot V_{\text{sat}} + 0 \cdot V_{\text{on}} + 0.0209 \cdot V_{\text{mat}} = [0.4269,\; 0.0679,\; 0.5156,\; 0.0104]$$

Notice that the value vector of “on” contributes exactly zero to cat's output. With softmax weights alone this cannot happen, because every weight is strictly positive.


Full Attention Weight Comparison

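Differential Attention Weights ($5 \times 5$)

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1170 | 0.2374 | 0.3290 | 0.0254 | 0.2912 |
| cat | 0.4164 | 0.0575 | 0.5052 | 0.0000 | 0.0209 |
| sat | 0.2062 | 0.1174 | 0.5069 | 0.0000 | 0.1695 |
| on | 0.2126 | 0.2126 | 0.2738 | 0.0884 | 0.2126 |
| mat | 0.0254 | 0.3290 | 0.3290 | 0.0254 | 0.2912 |

Standard Attention Weights (from Chapter 1)

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1095 | 0.2976 | 0.1805 | 0.1805 | 0.2318 |
| cat | 0.4026 | 0.0898 | 0.2442 | 0.1481 | 0.1153 |
| sat | 0.1519 | 0.2505 | 0.2505 | 0.1519 | 0.1951 |
| on | 0.1903 | 0.1903 | 0.1154 | 0.3137 | 0.1903 |
| mat | 0.1892 | 0.1892 | 0.1892 | 0.1892 | 0.2430 |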

Key differences:

  • Sparsity: Differential attention has two exact zeros (cat→on and sat→on). Standard attention has no zeros anywhere.
  • Sharpness: The maximum weight in differential attention is 0.5069 (sat attending to sat), versus 0.4026 (cat attending to The) in standard attention.
  • Signal amplification: cat→sat jumps from 0.2442 (standard) to 0.5052 (differential) — a 2.07× increase.

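Output Comparison ($5 \times 4$)

| Token | Diff d₀ | Diff d₁ | Diff d₂ | Diff d₃ | Std d₀ | Std d₁ | Std d₂ | Std d₃ |
|---|---|---|---|---|---|---|---|---|
| The | 0.2626 | 0.3830 | 0.4746 | 0.1710 | 0.2254 | 0.4135 | 0.2964 | 0.2964 |
| cat | 0.4269 | 0.0679 | 0.5156 | 0.0104 | 0.4602 | 0.1475 | 0.3018 | 0.2058 |
| sat | 0.2909 | 0.2021 | 0.5917 | 0.0848 | 0.2495 | 0.3481 | 0.3481 | 0.2495 |
| on | 0.3189 | 0.3189 | 0.3801 | 0.1947 | 0.2854 | 0.2854 | 0.2106 | 0.4089 |
| mat | 0.1710 | 0.4746 | 0.4746 | 0.1710 | 0.3108 | 0.3108 | 0.3108 | 0.3108 |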

The differential output vectors are more polarized. For “cat”, dimension 3 (the “on” direction) drops from 0.2058 to 0.0104, while dimension 2 (“sat”) increases from 0.3018 to 0.5156. The representation is more decisive.
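
The comparison above can be reproduced side by side in a few lines of NumPy. This is a compact sketch, not the section's reference listing: the function names `standard_attention` and `differential_attention` are hypothetical, and the differential variant is the clamp-and-renormalize version used in this worked example.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

def standard_attention(Q, K, V):
    """Single softmax over all d features, scaled by sqrt(d)."""
    A = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))
    return A, A @ V

def differential_attention(Q, K, V, lam=0.4):
    """Two half-width maps, subtract, clamp, renormalize (demo variant)."""
    h = Q.shape[-1] // 2
    A1 = softmax(Q[:, :h] @ K[:, :h].T / np.sqrt(h))
    A2 = softmax(Q[:, h:] @ K[:, h:].T / np.sqrt(h))
    D = np.maximum(A1 - lam * A2, 0.0)
    A = D / D.sum(axis=-1, keepdims=True)   # no all-zero rows occur in this example
    return A, A @ V

A_std, out_std = standard_attention(Q, K, V)
A_diff, out_diff = differential_attention(Q, K, V)

print("exact zeros  std:", int((A_std == 0).sum()), " diff:", int((A_diff == 0).sum()))
print("max weight   std:", round(float(A_std.max()), 4), " diff:", round(float(A_diff.max()), 4))
print("cat output   std:", np.round(out_std[1], 4))
print("cat output  diff:", np.round(out_diff[1], 4))
```

For inputs where an entire row could be clamped to zero, the division would need the same guard that the full NumPy listing in this section uses.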


Python (NumPy) Implementation

A complete, runnable implementation of differential attention. The walkthrough below annotates key lines with the exact values computed for our shared example; the full listing follows it.

Differential Attention — NumPy Implementation
🐍differential_attention.py
Line 1: import numpy as np

NumPy provides vectorized matrix operations. Q1 @ K1.T executes as optimized C code, not Python loops.

Line 2: import math

Python’s math module gives us sqrt() for the scaling factor, the square root of the half dimension.

EXAMPLE
math.sqrt(2) → 1.4142
Line 4: class DifferentialAttention

Encapsulates the full differential attention pipeline: split → two softmax maps → subtract → clamp → renormalize → weighted sum. One instance handles any (N, d_k) input.

Line 10: def __init__(self, d_k, lam) (constructor)

Initialize with full key dimension d_k=4 and lambda=0.4. Each sub-attention uses d_k/2 = 2 dimensions.

EXECUTION STATE
⬇ input: d_k = 4 (full key dimension)
⬇ input: lam = 0.4 (noise cancellation strength)
Line 11: self.d_half = d_k // 2

Split dimension in half. Q and K will each be sliced into two (N, 2) matrices.

EXECUTION STATE
d_k // 2 = 4 // 2 = 2
self.d_half = 2
Line 12: self.scale = math.sqrt(self.d_half)

Scaling factor √d_half = √2 ≈ 1.4142. Prevents dot products from growing too large before softmax.

EXECUTION STATE
math.sqrt(2) = 1.4142
self.scale = 1.4142
Line 13: self.lam = lam

Store lambda. In training this would be a learnable parameter; here fixed at 0.4 for demonstration.

EXECUTION STATE
self.lam = 0.4
Line 15: def _softmax(self, x) → np.ndarray

Numerically stable row-wise softmax. Subtracts row max before exp() to prevent overflow. Called twice: once for S1, once for S2.

EXECUTION STATE
⬇ input: x (5×5) = score matrix (e.g. S1 on first call)
⬆ returns = np.ndarray (5, 5) — each row sums to 1.0
Line 16: x_shifted = x - np.max(x, axis=-1, keepdims=True)

Subtract each row’s maximum. This prevents exp() overflow while producing identical softmax output. Showing values for S1 (first call):

EXECUTION STATE
axis=-1 = operate along LAST axis — find the max within each row independently
keepdims=True = return shape (5,1) not (5,) so broadcasting x(5×5) − max(5×1) works
── Row 0 (The) ── =
x = [0.0000, 0.7071, 0.7071, 0.0000, 0.7071]
max(x) = 0.7071
x_shifted = [-0.7071, 0.0000, 0.0000, -0.7071, 0.0000]
── Row 1 (cat) ── =
x = [1.4142, 0.0000, 1.4142, 0.0000, 0.0000]
max(x) = 1.4142
x_shifted = [0.0000, -1.4142, 0.0000, -1.4142, -1.4142]
── Row 2 (sat) ── =
x = [0.7071, 0.7071, 1.4142, 0.0000, 0.7071]
max(x) = 1.4142
x_shifted = [-0.7071, -0.7071, 0.0000, -1.4142, -0.7071]
── Row 3 (on) ── =
x = [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]
max(x) = 0.0000
x_shifted = [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]
── Row 4 (mat) ── =
x = [0.0000, 0.7071, 0.7071, 0.0000, 0.7071]
max(x) = 0.7071
x_shifted = [-0.7071, 0.0000, 0.0000, -0.7071, 0.0000]
Line 17: e = np.exp(x_shifted)

Exponentiate each shifted value. Negative shifted values become fractions < 1; zero stays at 1.0.

EXECUTION STATE
── Row 0 (The) ── =
exp(x_shifted) = [0.4931, 1.0000, 1.0000, 0.4931, 1.0000]
── Row 1 (cat) ── =
exp(x_shifted) = [1.0000, 0.2431, 1.0000, 0.2431, 0.2431]
── Row 2 (sat) ── =
exp(x_shifted) = [0.4931, 0.4931, 1.0000, 0.2431, 0.4931]
── Row 3 (on) ── =
exp(x_shifted) = [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]
── Row 4 (mat) ── =
exp(x_shifted) = [0.4931, 1.0000, 1.0000, 0.4931, 1.0000]
Line 18: return e / np.sum(e, axis=-1, keepdims=True)

Divide each row by its sum to get a probability distribution. Every row now sums to 1.0.

EXECUTION STATE
axis=-1 = sum each row independently
keepdims=True = shape (5,1) for broadcasting e(5×5) / sum(5×1)
── Row 0 (The) ── =
sum = 3.9861
⬆ return = [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
── Row 1 (cat) ── =
sum = 2.7294
⬆ return = [0.3664, 0.0891, 0.3664, 0.0891, 0.0891]
── Row 2 (sat) ── =
sum = 2.7223
⬆ return = [0.1811, 0.1811, 0.3673, 0.0893, 0.1811]
── Row 3 (on) ── =
sum = 5.0000
⬆ return = [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]
── Row 4 (mat) ── =
sum = 3.9861
⬆ return = [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
Line 20: def forward(self, Q, K, V) → np.ndarray

Main pipeline: split Q/K → two attention maps → subtract → clamp → renormalize → weighted sum of V.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = np.ndarray (5, 4) — context-aware representation per token
Line 23: Q1, Q2 = Q[:, :self.d_half], Q[:, self.d_half:]

Split Q along the feature dimension into two halves. Q1 gets dims [0,1], Q2 gets dims [2,3]. These drive two independent attention maps.

EXECUTION STATE
Q[:, :2] → Q1 (5×2) =
      d0   d1
The  1.0  0.0
cat  0.0  2.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
Q[:, 2:] → Q2 (5×2) =
      d2   d3
The  1.0  0.0
cat  0.0  1.0
sat  1.0  0.0
on   1.0  1.0
mat  0.0  1.0
Line 24: K1, K2 = K[:, :self.d_half], K[:, self.d_half:]

Split K the same way. K1 and K2 are the keys for the two independent attention computations.

EXECUTION STATE
K[:, :2] → K1 (5×2) =
      d0   d1
The  0.0  1.0
cat  1.0  0.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
K[:, 2:] → K2 (5×2) =
      d2   d3
The  0.0  1.0
cat  1.0  0.0
sat  0.0  0.0
on   1.0  1.0
mat  0.5  0.5
Line 27: S1 = Q1 @ K1.T / self.scale

Compute scaled dot-product scores for the first sub-attention. Q1(5×2) × K1ᵀ(2×5) → (5×5), then divide by √2.

EXECUTION STATE
@ = matrix multiply operator
K1.T (2×5) = K1 transposed so columns become rows
self.scale = 1.4142 (√2)
S1 (5×5) =
        The     cat     sat      on     mat
The  0.0000  0.7071  0.7071  0.0000  0.7071
cat  1.4142  0.0000  1.4142  0.0000  0.0000
sat  0.7071  0.7071  1.4142  0.0000  0.7071
on   0.0000  0.0000  0.0000  0.0000  0.0000
mat  0.0000  0.7071  0.7071  0.0000  0.7071
Line 28: A1 = self._softmax(S1)

Apply row-wise softmax to S1. This is the first attention map — it captures one view of token relevance.

EXECUTION STATE
A1 (5×5) =
        The     cat     sat      on     mat
The  0.1237  0.2509  0.2509  0.1237  0.2509
cat  0.3664  0.0891  0.3664  0.0891  0.0891
sat  0.1811  0.1811  0.3673  0.0893  0.1811
on   0.2000  0.2000  0.2000  0.2000  0.2000
mat  0.1237  0.2509  0.2509  0.1237  0.2509
Line 30: S2 = Q2 @ K2.T / self.scale

Same operation on the second half. Q2(5×2) × K2ᵀ(2×5) → (5×5), scaled by √2. Different dims capture a different similarity signal.

EXECUTION STATE
S2 (5×5) =
        The     cat     sat      on     mat
The  0.0000  0.7071  0.0000  0.7071  0.3536
cat  0.7071  0.0000  0.0000  0.7071  0.3536
sat  0.0000  0.7071  0.0000  0.7071  0.3536
on   0.7071  0.7071  0.0000  1.4142  0.7071
mat  0.7071  0.0000  0.0000  0.7071  0.3536
Line 31: A2 = self._softmax(S2)

Second attention map. Compare A2 with A1 — they agree on some tokens (noise) and disagree on others (signal).

EXECUTION STATE
A2 (5×5) =
        The     cat     sat      on     mat
The  0.1337  0.2711  0.1337  0.2711  0.1904
cat  0.2711  0.1337  0.1337  0.2711  0.1904
sat  0.1337  0.2711  0.1337  0.2711  0.1904
on   0.1811  0.1811  0.0893  0.3673  0.1811
mat  0.2711  0.1337  0.1337  0.2711  0.1904
Line 34: diff = A1 - self.lam * A2

The core differential step. Subtract λ·A2 from A1. Where both maps agree (common-mode noise), the value shrinks. Where A1 is large but A2 is small, the value stays large — genuine signal survives.

EXECUTION STATE
self.lam = 0.4
── Row 0 (The) ── =
A1[0] = [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
0.4·A2[0] = [0.0535, 0.1084, 0.0535, 0.1084, 0.0762]
diff[0] = [+0.0702, +0.1424, +0.1974, +0.0152, +0.1747]
── Row 1 (cat) ── =
A1[1] = [0.3664, 0.0891, 0.3664, 0.0891, 0.0891]
0.4·A2[1] = [0.1084, 0.0535, 0.0535, 0.1084, 0.0762]
diff[1] = [+0.2579, +0.0356, +0.3129, −0.0194, +0.0129]
── Row 2 (sat) ── =
diff[2] = [+0.1276, +0.0727, +0.3139, −0.0191, +0.1050]
── Row 3 (on) ── =
diff[3] = [+0.1276, +0.1276, +0.1643, +0.0531, +0.1276]
── Row 4 (mat) ── =
diff[4] = [+0.0152, +0.1974, +0.1974, +0.0152, +0.1747]
⚠ Negatives = cat→on = −0.0194, sat→on = −0.0191 (noise detected!)
Line 35: diff = np.maximum(diff, 0)

Clamp negative values to zero. A negative differential means the scaled second map (λ·A2) weighted that token more than the first map did, so this demo treats it as noise and zeroes it out.

EXECUTION STATE
diff (5×5) after clamp =
        The     cat     sat      on     mat
The  0.0702  0.1424  0.1974  0.0152  0.1747
cat  0.2579  0.0356  0.3129  0.0000  0.0129
sat  0.1276  0.0727  0.3139  0.0000  0.1050
on   0.1276  0.1276  0.1643  0.0531  0.1276
mat  0.0152  0.1974  0.1974  0.0152  0.1747
zeroed entries = cat→on and sat→on are now 0 (noise fully cancelled)
Line 36: row_sums = diff.sum(axis=-1, keepdims=True)

Sum each row. Since we subtracted and clamped, row sums are less than 1.0. We need to renormalize.

EXECUTION STATE
axis=-1 = sum along last axis (within each row)
keepdims=True = shape (5,1) for broadcasting
row_sums = The=0.6000, cat=0.6194, sat=0.6191, on=0.6000, mat=0.6000
Line 37: row_sums = np.where(row_sums == 0, 1, row_sums)

Safety guard: if every entry in a row got clamped to zero, the sum would be 0 and division would produce NaN. Replace 0 with 1 to prevent this. No rows are zero in our example.

EXECUTION STATE
np.where(condition, if_true, if_false) = element-wise conditional
result = no change — all sums > 0
Line 38: A_diff = diff / row_sums

Renormalize each row to sum to 1.0. This is the final differential attention weight matrix.

EXECUTION STATE
A_diff (5×5) =
        The     cat     sat      on     mat
The  0.1170  0.2374  0.3290  0.0254  0.2912
cat  0.4164  0.0575  0.5052  0.0000  0.0209
sat  0.2062  0.1174  0.5069  0.0000  0.1695
on   0.2126  0.2126  0.2738  0.0884  0.2126
mat  0.0254  0.3290  0.3290  0.0254  0.2912
Line 41: return A_diff @ V

Multiply differential attention weights by the value matrix. Each output row is a weighted combination of value vectors, with noise-cancelled weights.

EXECUTION STATE
⬆ return: A_diff @ V (5×4) =
        d0      d1      d2      d3
The  0.2626  0.3830  0.4746  0.1710
cat  0.4269  0.0679  0.5156  0.0104
sat  0.2909  0.2021  0.5917  0.0848
on   0.3189  0.3189  0.3801  0.1947
mat  0.1710  0.4746  0.4746  0.1710
Line 45: tokens = ["The", "cat", "sat", "on", "mat"]

Our shared 5-token sentence used across all 15 chapters for consistent comparison.

Line 47: Q = np.array([...]) (query matrix)

5×4 query matrix. Dims [0,1] will drive A1, dims [2,3] will drive A2.

EXECUTION STATE
Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
Line 54: K = np.array([...]) (key matrix)

5×4 key matrix. Same split: dims [0,1] pair with Q1, dims [2,3] pair with Q2.

EXECUTION STATE
K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
Line 61: V = np.array([...]) (value matrix)

5×4 value matrix. Near-identity with mat=[0.5,...] as a blended token. V is NOT split — the full 4-dim values are weighted by A_diff.

EXECUTION STATE
V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
Line 69: attn = DifferentialAttention(d_k=4, lam=0.4)

Create instance. d_k=4 means each sub-attention uses 2 dimensions. λ=0.4 is the value used throughout this demo.

EXECUTION STATE
attn.d_half = 2
attn.scale = 1.4142
attn.lam = 0.4
Line 70: output = attn.forward(Q, K, V)

Run the full pipeline. Internally: split → S1,A1 → S2,A2 → diff → clamp → renorm → A_diff @ V.

EXECUTION STATE
output (5×4) =
        d0      d1      d2      d3
The  0.2626  0.3830  0.4746  0.1710
cat  0.4269  0.0679  0.5156  0.0104
sat  0.2909  0.2021  0.5917  0.0848
on   0.3189  0.3189  0.3801  0.1947
mat  0.1710  0.4746  0.4746  0.1710
Line 72: for i, t in enumerate(tokens): (print output)

Loop through all 5 tokens and print their output vectors.

LOOP TRACE · 5 iterations
i=0, t='The'
output[0] = [0.2626, 0.3830, 0.4746, 0.1710]
i=1, t='cat'
output[1] = [0.4269, 0.0679, 0.5156, 0.0104]
i=2, t='sat'
output[2] = [0.2909, 0.2021, 0.5917, 0.0848]
i=3, t='on'
output[3] = [0.3189, 0.3189, 0.3801, 0.1947]
i=4, t='mat'
output[4] = [0.1710, 0.4746, 0.4746, 0.1710]
```python
import numpy as np
import math

class DifferentialAttention:
    """
    Differential Attention (Ye et al., 2024).
    Computes two softmax maps and subtracts to cancel noise.
    """

    def __init__(self, d_k: int, lam: float = 0.4):
        self.d_half = d_k // 2
        self.scale = math.sqrt(self.d_half)
        self.lam = lam

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        x_shifted = x - np.max(x, axis=-1, keepdims=True)
        e = np.exp(x_shifted)
        return e / np.sum(e, axis=-1, keepdims=True)

    def forward(self, Q: np.ndarray, K: np.ndarray,
                V: np.ndarray) -> np.ndarray:
        # Step 1: Split Q and K into two halves
        Q1, Q2 = Q[:, :self.d_half], Q[:, self.d_half:]
        K1, K2 = K[:, :self.d_half], K[:, self.d_half:]

        # Step 2: Two separate attention maps
        S1 = Q1 @ K1.T / self.scale
        A1 = self._softmax(S1)

        S2 = Q2 @ K2.T / self.scale
        A2 = self._softmax(S2)

        # Step 3: Differential (subtract, clamp, renormalize)
        diff = A1 - self.lam * A2
        diff = np.maximum(diff, 0)
        row_sums = diff.sum(axis=-1, keepdims=True)
        row_sums = np.where(row_sums == 0, 1, row_sums)
        A_diff = diff / row_sums

        # Step 4: Weighted sum of values
        return A_diff @ V


# === Shared example: "The cat sat on mat" ===
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])
K = np.array([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])
V = np.array([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

attn = DifferentialAttention(d_k=4, lam=0.4)
output = attn.forward(Q, K, V)

for i, t in enumerate(tokens):
    print(f"{t}: [{', '.join(f'{v:.4f}' for v in output[i])}]")
```

PyTorch Implementation

The PyTorch version makes $\lambda$ a learnable nn.Parameter. During training, backpropagation adjusts $\lambda$ alongside the projection weights, learning the optimal cancellation strength per layer.

Differential Attention — PyTorch Implementation
🐍differential_attention_pytorch.py
Line 1: import torch

PyTorch tensor library with GPU acceleration and automatic differentiation.

Line 2: import torch.nn as nn

Neural network module. nn.Module is the base class; nn.Parameter makes tensors learnable.

Line 3: import torch.nn.functional as F

Stateless functions: F.softmax computes softmax without storing weights.

Line 4: import math

For math.sqrt(2) = 1.4142 scaling factor.

Line 6: class DifferentialAttentionPT(nn.Module)

PyTorch module. Key difference from NumPy: self.lam is an nn.Parameter so gradients flow through it during training.

Line 12: def __init__(self, d_k, lam_init)

Constructor. lam_init=0.4 becomes the initial value for the learnable λ parameter.

EXECUTION STATE
⬇ input: d_k = 4
⬇ input: lam_init = 0.4
Line 16: self.lam = nn.Parameter(torch.tensor(lam_init))

nn.Parameter wraps a tensor so PyTorch’s optimizer can update it via backprop. During training, λ learns the optimal noise cancellation strength for each layer.

EXECUTION STATE
self.lam = Parameter(tensor(0.4000), requires_grad=True)
Line 18: def forward(self, Q, K, V) → torch.Tensor

Forward pass. Same algorithm as NumPy but uses torch operations for GPU support and autograd.

EXECUTION STATE
⬇ input: Q = torch.Tensor shape (5, 4)
⬇ input: K = torch.Tensor shape (5, 4)
⬇ input: V = torch.Tensor shape (5, 4)
Line 20: Q1 = Q[..., :self.d_half]

Ellipsis ... handles arbitrary leading batch dimensions. For shape (5,4), Q[..., :2] = Q[:, :2]. For batched (B,N,4), it slices only the last dim.

EXECUTION STATE
... = matches any number of leading dims — enables batched inference
Q1 shape = (5, 2)
Line 25: A1 = F.softmax(Q1 @ K1.transpose(-2, -1) / self.scale, dim=-1)

All in one line: dot product, scale, softmax. transpose(-2, -1) swaps the last two dims (works for any batch shape).

EXECUTION STATE
.transpose(-2, -1) = swap last two dims: (5,2) → (2,5)
dim=-1 = softmax along last axis (each row independently)
A1 shape = (5, 5)
Line 28: diff = A1 - self.lam * A2

Differential step. Since self.lam is an nn.Parameter, PyTorch tracks this operation in the computation graph for backpropagation.

EXECUTION STATE
self.lam = 0.4 (learnable)
diff shape = (5, 5)
Line 29: diff = torch.clamp(diff, min=0)

Clamp negatives to zero. torch.clamp is differentiable: the gradient is 1 where diff > 0 and 0 where diff < 0 (much like ReLU).

Line 35: return A_diff @ V

Final weighted sum. Output shape (5, 4) — identical to the NumPy result.

EXECUTION STATE
⬆ return shape = (5, 4)
⬆ return: output =
        d0      d1      d2      d3
The  0.2626  0.3830  0.4746  0.1710
cat  0.4269  0.0679  0.5156  0.0104
sat  0.2909  0.2021  0.5917  0.0848
on   0.3189  0.3189  0.3801  0.1947
mat  0.1710  0.4746  0.4746  0.1710
Line 63: model = DifferentialAttentionPT(d_k=4, lam_init=0.4)

Instantiate. model.lam is a learnable Parameter; this demo runs inference only and takes no optimizer step, so it stays at 0.4.

Line 64: with torch.no_grad():

Disable gradient tracking for inference. Saves memory and compute since we are not training.

Line 67: for i, t in enumerate(tokens): (print output)

Print results. Identical to NumPy output, confirming both implementations match.

LOOP TRACE · 5 iterations
i=0, t='The'
output[0] = [0.2626, 0.3830, 0.4746, 0.1710]
i=1, t='cat'
output[1] = [0.4269, 0.0679, 0.5156, 0.0104]
i=2, t='sat'
output[2] = [0.2909, 0.2021, 0.5917, 0.0848]
i=3, t='on'
output[3] = [0.3189, 0.3189, 0.3801, 0.1947]
i=4, t='mat'
output[4] = [0.1710, 0.4746, 0.4746, 0.1710]
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class DifferentialAttentionPT(nn.Module):
    """
    PyTorch Differential Attention (Ye et al., 2024).
    Lambda is a learnable nn.Parameter updated by backprop.
    """

    def __init__(self, d_k: int, lam_init: float = 0.4):
        super().__init__()
        self.d_half = d_k // 2
        self.scale = math.sqrt(self.d_half)
        self.lam = nn.Parameter(torch.tensor(lam_init))

    def forward(self, Q: torch.Tensor, K: torch.Tensor,
                V: torch.Tensor) -> torch.Tensor:
        Q1 = Q[..., :self.d_half]
        Q2 = Q[..., self.d_half:]
        K1 = K[..., :self.d_half]
        K2 = K[..., self.d_half:]

        A1 = F.softmax(Q1 @ K1.transpose(-2, -1) / self.scale, dim=-1)
        A2 = F.softmax(Q2 @ K2.transpose(-2, -1) / self.scale, dim=-1)

        diff = A1 - self.lam * A2
        diff = torch.clamp(diff, min=0)
        row_sums = diff.sum(dim=-1, keepdim=True)
        row_sums = torch.where(row_sums == 0,
                               torch.ones_like(row_sums), row_sums)
        A_diff = diff / row_sums

        return A_diff @ V


# === Shared example ===
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])
K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])
V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

model = DifferentialAttentionPT(d_k=4, lam_init=0.4)
with torch.no_grad():
    output = model(Q, K, V)

for i, t in enumerate(tokens):
    print(f"{t}: [{', '.join(f'{v:.4f}' for v in output[i])}]")
```

Both implementations produce the same values to the printed precision, confirming that the math is framework-independent. The PyTorch version additionally supports:

| Feature | NumPy | PyTorch |
|---|---|---|
| GPU acceleration | ✗ | ✓ (.cuda()) |
| Automatic differentiation | ✗ | ✓ (autograd) |
| Learnable λ | ✗ | ✓ (nn.Parameter) |
| Batched inference | Manual | ✓ (... indexing) |
| Mixed precision | ✗ | ✓ (torch.float16) |
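
To make the batching row concrete, here is a minimal sketch that feeds a batch of two sequences through the same module. The class is repeated in compact form only so the snippet runs on its own; the random inputs and the batch shape are illustrative assumptions, not part of the shared example.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class DifferentialAttentionPT(nn.Module):
    # Same computation as the class above, repeated so this snippet is self-contained.
    def __init__(self, d_k: int, lam_init: float = 0.4):
        super().__init__()
        self.d_half = d_k // 2
        self.scale = math.sqrt(self.d_half)
        self.lam = nn.Parameter(torch.tensor(lam_init))

    def forward(self, Q, K, V):
        h = self.d_half
        A1 = F.softmax(Q[..., :h] @ K[..., :h].transpose(-2, -1) / self.scale, dim=-1)
        A2 = F.softmax(Q[..., h:] @ K[..., h:].transpose(-2, -1) / self.scale, dim=-1)
        diff = torch.clamp(A1 - self.lam * A2, min=0)
        row_sums = diff.sum(dim=-1, keepdim=True)
        row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums)
        return (diff / row_sums) @ V

torch.manual_seed(0)
model = DifferentialAttentionPT(d_k=4)

# Illustrative random batch: 2 sequences, 5 tokens each, 4 features per token.
Q, K, V = (torch.randn(2, 5, 4) for _ in range(3))

with torch.no_grad():
    batched = model(Q, K, V)           # leading batch dim passes through `...` and `@`
    single = model(Q[0], K[0], V[0])   # the first sequence on its own

print(batched.shape)                                    # torch.Size([2, 5, 4])
print(torch.allclose(batched[0], single, atol=1e-5))    # batched result matches unbatched
print([name for name, _ in model.named_parameters()])   # ['lam']
```

Because the slicing uses `...` and the transpose targets the last two axes, no code changes are needed to add further leading dimensions such as a head axis.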

Connection to Modern Systems

Multi-Head Differential Attention

In the full Diff Transformer, multi-head attention works as:

$$\text{MultiHead}(X) = \text{Concat}(\bar{h}_1, \ldots, \bar{h}_h)\, W^O$$

where $\bar{h}_i = (1 - \lambda_{\text{init}}) \cdot \text{GroupNorm}(\text{DiffAttn}_i(X))$. To match the parameter count of a standard Transformer with $h$ heads, the Diff Transformer uses $h/2$ heads, each operating on $2d_{\text{head}}$ dimensions (split into the two sub-attentions).
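
The multi-head formula can be sketched in NumPy as follows. This is an illustration under stated assumptions rather than the reference implementation: the weights are random, `per_head_norm` is a hypothetical stand-in for GroupNorm (one group per head, no learnable gain or bias), and each head uses the plain subtraction without the clamp-and-renormalize step from the implementations above.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def diff_attn_head(X, Wq, Wk, Wv, lam):
    """One differential head: (softmax(Q1 K1^T / sqrt(d)) - lam * softmax(Q2 K2^T / sqrt(d))) V."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv          # each has shape (N, 2 * d_head)
    d = Q.shape[-1] // 2
    A1 = softmax(Q[:, :d] @ K[:, :d].T / np.sqrt(d))
    A2 = softmax(Q[:, d:] @ K[:, d:].T / np.sqrt(d))
    return (A1 - lam * A2) @ V

def per_head_norm(h, eps=1e-5):
    # Assumed stand-in for GroupNorm: normalize each token's head output
    # to zero mean and unit variance.
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(var + eps)

def multi_head_diff_attn(X, heads, Wo, lam, lam_init):
    outs = [(1.0 - lam_init) * per_head_norm(diff_attn_head(X, Wq, Wk, Wv, lam))
            for (Wq, Wk, Wv) in heads]
    return np.concatenate(outs, axis=-1) @ Wo

rng = np.random.default_rng(0)
N, d_model, d_head = 5, 16, 4              # a standard Transformer would use 4 heads of size 4
n_heads = d_model // (2 * d_head)          # the differential version uses 2 heads of size 2 * d_head
X = rng.normal(size=(N, d_model))
heads = [tuple(rng.normal(size=(d_model, 2 * d_head)) for _ in range(3))
         for _ in range(n_heads)]
Wo = rng.normal(size=(n_heads * 2 * d_head, d_model))

out = multi_head_diff_attn(X, heads, Wo, lam=0.4, lam_init=0.4)
print(out.shape)  # (5, 16)
```

With these shapes the query, key, and value projections hold 3 × 16 × 16 weights in total, the same count as four standard heads of size 4, which is the parameter-matching argument made above.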

Flash Attention Compatibility

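Differential attention requires two softmax computations instead of one. In the linear form, $(A_1 - \lambda A_2)V = A_1 V - \lambda A_2 V$, so each term can be computed by an independent call to Flash Attention's IO-aware tiled kernel (Chapter 13) and the two outputs subtracted. The memory savings carry over directly, and the only extra cost is an element-wise subtraction, which is negligible compared to the matrix multiplications. The clamp-and-renormalize step used in the implementations above acts on the attention weights themselves, which Flash Attention never materializes, so that variant needs a custom fused kernel to keep the same memory profile.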
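
A small NumPy check of the linearity that makes this work: subtracting the two attention maps before multiplying by $V$ gives the same result as running two separate attention calls and subtracting their outputs. The `attention` helper below is only a plain stand-in for a Flash Attention call, and the inputs are random. Note that the identity holds for the unclamped difference only; inserting a clamp between the subtraction and the product with $V$ breaks it.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Plain scaled dot-product attention; stands in for a fused attention kernel."""
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

rng = np.random.default_rng(0)
N, d, lam = 6, 4, 0.4
Q1, Q2, K1, K2 = (rng.normal(size=(N, d)) for _ in range(4))
V = rng.normal(size=(N, 2 * d))

# Route 1: materialize both N x N maps, subtract, then multiply by V.
A1 = softmax(Q1 @ K1.T / np.sqrt(d))
A2 = softmax(Q2 @ K2.T / np.sqrt(d))
out_materialized = (A1 - lam * A2) @ V

# Route 2: two independent attention calls, then subtract the outputs.
out_two_calls = attention(Q1, K1, V) - lam * attention(Q2, K2, V)

print(np.allclose(out_materialized, out_two_calls))  # True
```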

KV-Cache in Inference

For autoregressive generation, differential attention requires caching $K_1$ and $K_2$ separately instead of a single $K$. Since each half has $d/2$ dimensions, the total cache size is identical to standard attention's $K$ cache. Combining with GQA (Chapter 6) or MLA (Chapter 15) further reduces the cache footprint.
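
The cache-size claim is easy to verify with back-of-the-envelope arithmetic. The model shape below is hypothetical and chosen only for illustration; `kv_cache_bytes` is a helper written for this sketch, using the head layout described above ($h/2$ heads of width $2d_{\text{head}}$).

```python
def kv_cache_bytes(n_layers, n_kv_heads, key_dim, value_dim, seq_len, bytes_per_elem=2):
    """Bytes needed to cache keys and values for one sequence."""
    per_token = n_kv_heads * (key_dim + value_dim)
    return n_layers * seq_len * per_token * bytes_per_elem

# Hypothetical model shape (assumption, not taken from the paper).
n_layers, d_model, d_head, seq_len = 32, 4096, 128, 8192
h = d_model // d_head

# Standard attention: h heads, each caching a d_head key and a d_head value.
standard = kv_cache_bytes(n_layers, h, d_head, d_head, seq_len)

# Differential attention: h/2 heads, each caching K1 and K2 (d_head each)
# plus a value of width 2 * d_head.
differential = kv_cache_bytes(n_layers, h // 2, 2 * d_head, 2 * d_head, seq_len)

print(standard == differential)  # True
print(standard / 2**30)          # 4.0 (GiB at 2 bytes per element)
```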

Applications Beyond Language

  • Computer vision: In Vision Transformers (ViT), image patches often spread attention uniformly over background regions. Differential attention could suppress this background noise, sharpening attention on discriminative regions.
  • Scientific computing: Protein structure prediction and molecular dynamics involve long sequences, where accumulated attention noise can become a bottleneck.
  • Code generation: When generating code from a long specification, the model needs to retrieve specific requirements without being distracted by boilerplate text.

Complexity Analysis

| Operation | Time | Memory |
|---|---|---|
| Q₁K₁ᵀ + softmax | O(N² · d/2) | O(N²) |
| Q₂K₂ᵀ + softmax | O(N² · d/2) | O(N²) |
| A₁ − λ·A₂ + clamp + renorm | O(N²) | O(N²) |
| A_diff · V | O(N² · dᵥ) | O(N · dᵥ) |
| Total | O(N² · d) | O(N²) |

The asymptotic complexity is identical to standard attention: $O(N^2 d)$ time and $O(N^2)$ memory. The two half-dimension score computations together cost the same matrix-multiply FLOPs as one full-dimension computation. The second softmax and the element-wise subtraction, clamp, and renormalization each add only $O(N^2)$ work, which is dominated by the matrix multiplications.
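
A rough FLOP count makes the "dominated by the matrix multiplications" claim concrete. The counting model below is a simplification assumed for this sketch (2 FLOPs per multiply-add, one per exponential, normalization sums ignored), and the sizes are illustrative. It measures arithmetic only and says nothing about wall-clock time, which also depends on memory traffic and kernel fusion.

```python
def standard_attention_flops(N, d, d_v):
    scores = 2 * N * N * d        # Q K^T: N*N dot products of length d
    softmax = N * N               # one exponential per score
    weighted = 2 * N * N * d_v    # A V
    return scores + softmax + weighted

def differential_attention_flops(N, d, d_v):
    scores = 2 * (2 * N * N * (d // 2))  # two score matrices over d/2 dimensions each
    softmax = 2 * N * N                  # two softmax maps
    combine = 3 * N * N                  # scale by lambda, subtract, clamp
    weighted = 2 * N * N * d_v           # A_diff V
    return scores + softmax + combine + weighted

N, d, d_v = 4096, 128, 128  # illustrative sizes
std = standard_attention_flops(N, d, d_v)
diff = differential_attention_flops(N, d, d_v)
print(f"overhead: {100 * (diff / std - 1):.2f}%")  # overhead: 0.78%
```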

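The measured wall-clock overhead is reported to be small: Ye et al. (2024) describe throughput as comparable to a standard Transformer of the same size. The exact figure depends on model size, sequence length, and kernel implementation, since the two softmax operations can run in parallel and the element-wise operations can be fused into existing GPU kernels.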


Key Takeaways

  1. Softmax always leaks weight to irrelevant tokens because $e^x > 0$ for all $x$. In long contexts, this noise accumulates and degrades retrieval, increases hallucination, and makes in-context learning brittle.
  2. Differential attention subtracts two softmax maps to cancel common-mode noise. Genuine signal (present in one map but not the other) is amplified.
  3. Negative differentials are clamped to zero, producing genuinely sparse attention weights — something standard softmax cannot achieve.
  4. Lambda controls cancellation strength and is learned per layer via a reparameterized scalar. Its initial value grows with depth, so early layers start with mild cancellation and later layers with more aggressive cancellation.
  5. No asymptotic overhead: $O(N^2 d)$ time and $O(N^2)$ memory, identical to standard attention. Compatible with Flash Attention and KV-cache optimization.
  6. A 6.8B Diff Transformer matches the language-modeling loss of an 11B standard Transformer (Ye et al., 2024), which suggests that noise cancellation can be more parameter-efficient than simply scaling up.

Exercises

Exercise 1: Lambda Extremes

Compute the differential attention weights for “cat” (row 1) with $\lambda = 0$ and $\lambda = 1$. How does the sparsity change? At what value of $\lambda$ does the first zero appear in cat's attention weights?

Exercise 2: Gradient Through Lambda

Using the PyTorch implementation, remove the torch.no_grad() context and compute loss = output.sum(). Call loss.backward(). What is model.lam.grad? Explain its sign: should gradient descent increase or decrease $\lambda$?

Exercise 3: Three-Way Split

Extend the mechanism to a three-way split: $A_1 - \lambda_1 A_2 - \lambda_2 A_3$. Implement this in NumPy and compute the attention weights for our shared example with $\lambda_1 = 0.3$ and $\lambda_2 = 0.2$. Does the three-way split produce more sparsity than the two-way?

Exercise 4: Causal Differential Attention

Combine differential attention with causal masking (Chapter 3). Apply the causal mask to the scores behind both $A_1$ and $A_2$, before each softmax, and then subtract. Does the mask interact differently with differential attention compared to standard attention?

Exercise 5: Scaling Laws

The paper reports that a 6.8B Diff Transformer matches an 11B standard Transformer. Using the complexity table above, calculate the FLOP savings when running the 6.8B Diff Transformer instead of the 11B standard Transformer on a 4096-token sequence with $d_{\text{model}} = 4096$.


References

  1. Ye, T., Dong, L., Xia, Y., Sun, Y., Zhu, Y., Huang, G., & Wei, F. (2024). “Differential Transformer.” arXiv:2410.05258. Published at ICLR 2025.
  2. Vaswani, A., et al. (2017). “Attention Is All You Need.” NeurIPS 2017.
  3. Dao, T., et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022.
  4. Shazeer, N. (2019). “Fast Transformer Decoding: One Write-Head is All You Need.” arXiv:1911.02150.
  5. Ainslie, J., et al. (2023). “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” EMNLP 2023.