Chapter 15

Multi-Head Latent Attention (MLA)


DeepSeek-AI, "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model", 2024


Learning Objectives

After completing this chapter, you will be able to:

  1. Explain why the KV-cache is the dominant memory bottleneck during autoregressive inference in large language models
  2. Describe how MLA compresses keys and values through a learned low-rank bottleneck, caching only a small latent vector $c_{KV}$ instead of full $K$ and $V$
  3. Derive the compression and decompression equations and trace each matrix multiplication through our shared example "the cat sat on the mat"
  4. Implement MLA from scratch in both NumPy and PyTorch and compare its attention weights against standard attention
  5. Analyze the trade-off between cache compression ratio and reconstruction quality, and explain why the approximation error is acceptable
  6. Connect MLA to the broader landscape of KV-cache optimization techniques including MQA (Chapter 5), GQA (Chapter 6), and quantization approaches
Why this matters: DeepSeek-V2, a 236B-parameter Mixture-of-Experts model, supports a 128K-token context while caching only 576 values per token per layer, versus 32,768 for standard MHA with the same head configuration (roughly 57× less KV-cache memory, or 64× counting the latent alone). This makes long-context inference far cheaper to serve. MLA is not just an optimization: it changes which context lengths are practical to deploy in production.

The Real Problem

The KV-Cache Memory Wall

During autoregressive generation, a transformer must attend to every previous token at every layer. To avoid recomputing keys and values for all past tokens on each step, models store them in a KV-cache. This cache grows linearly with sequence length and number of layers.

For a model with $L$ layers, $H$ heads, head dimension $d_h$, and sequence length $N$, the total KV-cache size is:

$$\text{KV-cache} = 2 \times L \times H \times N \times d_h \times \text{bytes per element}$$

The factor of 2 accounts for storing both $K$ and $V$. For a standard-MHA model with DeepSeek-V2's dimensions, $L = 60$, $H = 128$, $d_h = 128$, and $N = 128{,}000$, stored in 16-bit precision (2 bytes per element):

$$2 \times 60 \times 128 \times 128{,}000 \times 128 \times 2 \text{ bytes} \approx 503 \text{ GB}$$

That is roughly 503 GB of GPU memory per sequence just for the cache, more than the roughly 472 GB needed to store all 236B parameters of DeepSeek-V2 in 16-bit precision. A single such sequence needs more than six times the 80 GB of an A100, and every additional concurrent user adds another cache of the same size. The KV-cache has become the single largest barrier to deploying long-context LLMs at scale.
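A small calculator makes these numbers easy to check. The function name `kv_cache_bytes` is illustrative, 2 bytes per element assumes FP16/BF16 storage, and the MLA line uses the 512 + 64 = 576 cached values per token per layer that DeepSeek-V2 uses (latent plus decoupled RoPE key).

```python
def kv_cache_bytes(values_per_token_per_layer: int, n_layers: int,
                   seq_len: int, bytes_per_value: int = 2) -> int:
    """Total KV-cache size in bytes for one sequence (2 bytes = FP16/BF16)."""
    return values_per_token_per_layer * n_layers * seq_len * bytes_per_value

# DeepSeek-V2 dimensions quoted in this chapter.
L, H, d_h, N = 60, 128, 128, 128_000

mha = kv_cache_bytes(2 * H * d_h, L, N)   # K and V for every head
mla = kv_cache_bytes(512 + 64, L, N)      # latent c_KV plus decoupled RoPE key

print(f"MHA: {mha / 1e9:.1f} GB")   # MHA: 503.3 GB
print(f"MLA: {mla / 1e9:.1f} GB")   # MLA: 8.8 GB
print(f"ratio: {mha / mla:.1f}x")   # ratio: 56.9x
```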

Why Not GQA or MQA?

We explored two earlier approaches to reduce KV-cache size. Multi-Query Attention (Chapter 5) shares a single K,V head across all query heads, reducing the cache by $H\times$. Grouped-Query Attention (Chapter 6) uses $G$ groups, reducing it by $H/G\times$.

| Method | Cache per token per layer | Reduction | Limitation |
|---|---|---|---|
| MHA | $2 \times H \times d_h$ | 1× (baseline) | Full cost |
| GQA | $2 \times G \times d_h$ | $H/G$× | Reduced head diversity |
| MQA | $2 \times d_h$ | $H$× | Single KV head; quality drops |
| MLA | $d_C$ (one latent) | $2Hd_h/d_C$× | Learned compression; minimal quality loss |
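To make the table concrete, the sketch below plugs in $H = 128$ and $d_h = 128$ from DeepSeek-V2. The GQA group count $G = 8$ is an illustrative assumption, not a DeepSeek-V2 setting, and the MLA entry counts only the latent ($d_C = 512$).

```python
def cache_per_token_per_layer(H: int, d_h: int, G: int, d_c: int) -> dict:
    """Number of cached values per token per layer for each attention variant."""
    return {
        "MHA": 2 * H * d_h,  # full K and V for every head
        "GQA": 2 * G * d_h,  # one K/V pair per group of query heads
        "MQA": 2 * d_h,      # a single K/V pair shared by all heads
        "MLA": d_c,          # one joint latent c_KV
    }

sizes = cache_per_token_per_layer(H=128, d_h=128, G=8, d_c=512)
for name, n in sizes.items():
    print(f"{name}: {n:>6} values, {sizes['MHA'] // n:>4}x smaller than MHA")
```

This prints reductions of 1×, 16×, 128×, and 64×. With these numbers MQA actually caches fewer values than MLA; MLA's case is not the smallest possible cache but a small one whose decompressed keys and values still differ per head.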

MQA and GQA reduce cache by sharing heads, but they still cache full-dimensional vectors. MLA takes a fundamentally different approach: compress the information itself. Instead of deciding which heads share K and V, MLA asks: can we learn a low-rank representation that captures the essential information in K and V with far fewer numbers?


From Intuition to Mathematics

The Autoencoder Analogy

Think of a photograph. A 1080p image has 6.2 million pixel values. Yet a JPEG compresses it to perhaps 200 KB with minimal visual quality loss. JPEG works because natural images have enormous redundancy — nearby pixels are correlated, color channels are correlated, and most energy concentrates in low spatial frequencies.

Key and value vectors in a transformer have a similar property. Although they live in a high-dimensional space, the vectors a layer actually produces tend to lie near a much lower-dimensional subspace: stacked into a matrix, they are typically close to low rank, so much of their information is redundant.

MLA is essentially a learned linear autoencoder for the KV-cache:

  1. Encoder (down-projection $W_{DKV}$): compress the full-dimension input to a small latent $c_{KV}$
  2. Bottleneck (the cache): store only $c_{KV}$, which has $d_C \ll d_{\text{model}}$ dimensions
  3. Decoder (up-projections $W_{UK}, W_{UV}$): reconstruct approximate $K'$ and $V'$ from the latent

The down-projection and up-projections are learned end-to-end during training. The model learns a low-rank subspace that preserves what attention needs while keeping the cache small.

The MLA Idea

Here is the key insight that makes MLA powerful: K and V for the same token share much of the same information, since both are linear projections of the same hidden state $h_t$. Rather than compressing K and V into two separate latents (which would cache twice as many numbers), MLA compresses them jointly through a single shared latent, caching $d_C$ numbers instead of $d_k + d_v$ (the total key and value widths per token), a reduction of $(d_k + d_v) / d_C$.

In standard attention, the hidden state hth_t is projected into K and V separately:

$$K_t = h_t \cdot W^K, \quad V_t = h_t \cdot W^V$$

In MLA, both pass through a shared bottleneck first:

$$c_{KV} = h_t \cdot W_{DKV}, \quad K'_t = c_{KV} \cdot W_{UK}, \quad V'_t = c_{KV} \cdot W_{UV}$$

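Only $c_{KV}$ enters the cache. At inference time, $K'$ and $V'$ are reconstructed from $c_{KV}$ using the up-projection matrices, which are fixed after training. This costs some extra compute, but decoding is usually limited by memory bandwidth rather than arithmetic, so trading a little compute for a much smaller cache pays off. DeepSeek-V2 goes a step further: because the up-projections are linear, $W_{UK}$ can be absorbed into the query projection and $W_{UV}$ into the output projection, so $K'$ and $V'$ never have to be materialized.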
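Because everything after the cache is linear, the up-projections can be folded into neighboring matrices. The sketch below checks the two identities $Q K'^{\top} = (Q W_{UK}^{\top})\, c_{KV}^{\top}$ and $(A V') W_O = (A\, c_{KV})(W_{UV} W_O)$ with random matrices, for a single head with made-up sizes; `W_O` is a hypothetical output projection added for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_model, d_c = 6, 16, 4                     # toy sizes chosen for illustration

h = rng.standard_normal((N, d_model))          # hidden states of cached tokens
q = rng.standard_normal((1, d_model))          # query of the current token
W_DKV = rng.standard_normal((d_model, d_c))
W_UK = rng.standard_normal((d_c, d_model))
W_UV = rng.standard_normal((d_c, d_model))
W_O = rng.standard_normal((d_model, d_model))  # hypothetical output projection

c_KV = h @ W_DKV                               # the only thing we cache
scale = np.sqrt(d_model)

# Path 1: materialize K' from the cache, then score.
scores_full = q @ (c_KV @ W_UK).T / scale
# Path 2: fold W_UK into the query and score directly against the latents.
scores_absorbed = (q @ W_UK.T) @ c_KV.T / scale

w = np.exp(scores_full - scores_full.max())
w /= w.sum()
# Path 1: materialize V', take the weighted sum, then apply W_O.
out_full = (w @ (c_KV @ W_UV)) @ W_O
# Path 2: weighted sum of latents, then one merged projection W_UV @ W_O.
out_absorbed = (w @ c_KV) @ (W_UV @ W_O)

print(np.allclose(scores_full, scores_absorbed), np.allclose(out_full, out_absorbed))  # True True
```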


The Mathematical Definition

Symbol-by-Symbol Breakdown

Compression (encoder):

$$c_{KV} = h \cdot W_{DKV}$$

| Symbol | Shape | Meaning |
|---|---|---|
| $h$ | $(N, d_{\text{model}})$ | Hidden state input (or the K matrix in our example) |
| $W_{DKV}$ | $(d_{\text{model}}, d_C)$ | Learned down-projection (encoder weights) |
| $c_{KV}$ | $(N, d_C)$ | Compressed latent; THIS is cached |
| $d_C$ | scalar | Latent dimension: 512 in DeepSeek-V2, versus 32,768 cached values per token per layer for full MHA |

Decompression (decoder):

$$K' = c_{KV} \cdot W_{UK}, \qquad V' = c_{KV} \cdot W_{UV}$$

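| Symbol | Shape | Meaning |
|---|---|---|
| $W_{UK}$ | $(d_C, d_{\text{model}})$ | Learned up-projection for key reconstruction |
| $W_{UV}$ | $(d_C, d_{\text{model}})$ | Learned up-projection for value reconstruction |
| $K'$ | $(N, d_{\text{model}})$ | Reconstructed keys (≈ original K) |
| $V'$ | $(N, d_{\text{model}})$ | Reconstructed values (≈ original V) |

These shapes follow our single-head example. In a multi-head model the up-projections map to $n_h \cdot d_h$ columns (16,384 in DeepSeek-V2, versus $d_{\text{model}} = 5{,}120$), which are then split into per-head keys and values.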

Attention (standard, using reconstructed K' and V'):

$$\text{Attention}(Q, K', V') = \text{softmax}\!\left(\frac{Q \cdot K'^{\top}}{\sqrt{d_k}}\right) \cdot V'$$
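The symbol tables describe a single head. A minimal multi-head sketch, assuming small made-up sizes, random weights in place of trained ones, no RoPE, and no causal mask, shows how one latent per token feeds every head: the up-projections produce $n_h \cdot d_h$ columns that are split into heads. The function name `mla_multihead` is hypothetical.

```python
import numpy as np

def mla_multihead(h, W_DKV, W_UK, W_UV, W_Q, n_heads):
    """Multi-head attention whose keys/values come from one shared latent.

    h: (N, d_model); W_DKV: (d_model, d_c); W_UK, W_UV: (d_c, n_heads*d_h);
    W_Q: (d_model, n_heads*d_h). Returns (output, c_KV).
    """
    N = h.shape[0]
    c_KV = h @ W_DKV                                  # (N, d_c): the cache
    d_h = W_UK.shape[1] // n_heads
    # Reconstruct per-head keys and values: (n_heads, N, d_h).
    K = (c_KV @ W_UK).reshape(N, n_heads, d_h).transpose(1, 0, 2)
    V = (c_KV @ W_UV).reshape(N, n_heads, d_h).transpose(1, 0, 2)
    Q = (h @ W_Q).reshape(N, n_heads, d_h).transpose(1, 0, 2)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_h)  # (n_heads, N, N)
    scores -= scores.max(axis=-1, keepdims=True)      # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ V                                 # (n_heads, N, d_h)
    return out.transpose(1, 0, 2).reshape(N, n_heads * d_h), c_KV

rng = np.random.default_rng(0)
N, d_model, d_c, n_heads, d_h = 5, 32, 8, 4, 8
h = rng.standard_normal((N, d_model))
out, c_KV = mla_multihead(
    h,
    W_DKV=rng.standard_normal((d_model, d_c)) / np.sqrt(d_model),
    W_UK=rng.standard_normal((d_c, n_heads * d_h)) / np.sqrt(d_c),
    W_UV=rng.standard_normal((d_c, n_heads * d_h)) / np.sqrt(d_c),
    W_Q=rng.standard_normal((d_model, n_heads * d_h)) / np.sqrt(d_model),
    n_heads=n_heads,
)
print(out.shape, c_KV.shape)   # (5, 32) (5, 8)
```

Here each token caches 8 values, while MHA with the same four 8-dimensional heads would cache $2 \times 4 \times 8 = 64$.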

What the Formulas Say in Plain English

  1. Compress: Take each token's full key/value vector and squeeze it through a narrow bottleneck, producing a tiny latent vector $c_{KV}$. This is like summarizing a paragraph into a tweet: you lose some nuance but keep the essential meaning.
  2. Cache: Store only the tweet-sized latents. For our example, that's 2 floats per token instead of 8, a 4× reduction. At DeepSeek-V2 scale, it's 512 vs 32,768, a 64× reduction (about 57× once the small decoupled RoPE key discussed later is included).
  3. Decompress: When a new query arrives and needs to attend to past tokens, expand the cached latents back to full-dimension keys and values using learned up-projections. The expansion is imperfect — some information was lost — but the model has learned to preserve the information that matters most for attention.
  4. Attend: Run standard scaled dot-product attention with the reconstructed K' and V', producing context-enriched output.
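The four steps map directly onto an autoregressive decoding loop. In this minimal sketch (single head, random weights, made-up sizes), each new token appends one $d_C$-sized row to the cache, and $K'$ and $V'$ are rebuilt from the whole cache at every step. The final check confirms that the last step matches full causal attention over the same sequence.

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, d_c, steps = 8, 2, 5
W_Q = rng.standard_normal((d_model, d_model))
W_DKV = rng.standard_normal((d_model, d_c))
W_UK = rng.standard_normal((d_c, d_model))
W_UV = rng.standard_normal((d_c, d_model))
hs = rng.standard_normal((steps, d_model))      # hidden states, one per step

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

cache = np.zeros((0, d_c))                      # the latent KV-cache
for t in range(steps):
    cache = np.vstack([cache, hs[t] @ W_DKV])   # 1. compress, 2. cache
    K_p, V_p = cache @ W_UK, cache @ W_UV       # 3. decompress all past tokens
    q = hs[t] @ W_Q
    out_t = softmax(q @ K_p.T / np.sqrt(d_model)) @ V_p   # 4. attend
    print(f"step {t}: cache holds {cache.shape[0]} x {cache.shape[1]} values")

# Reference: full causal attention over the whole sequence at once.
c_all = hs @ W_DKV
S = (hs @ W_Q) @ (c_all @ W_UK).T / np.sqrt(d_model)
S[np.triu_indices(steps, k=1)] = -np.inf        # mask future positions
ref = softmax(S) @ (c_all @ W_UV)
print(np.allclose(out_t, ref[-1]))              # True
```

The cache grows by $d_C$ values per token, never by the full key and value width.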

Why One Latent for Both K and V?

A natural question is: why not compress K and V into separate latents? DeepSeek-V2 uses a single shared latent for both, and it works well. K and V are both linear projections of the same hidden state $h_t$, so they share a common information core that one latent can capture efficiently. Separate latents of the same size would double the cache, with no guarantee of a proportional quality gain.


Step-by-Step Calculation

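Using our shared example: tokens = ["The", "cat", "sat", "on", "mat"] with $d = 4$ and $d_C = 2$. To keep the arithmetic traceable, we compress the key matrix K directly (standing in for the hidden state $h$) and use hand-picked projections instead of learned ones: $W_{DKV}$ maps dimensions 0 and 2 to latent 0 and dimensions 1 and 3 to latent 1 (weight 0.7 each), and $W_{UK}$ reverses this mapping ($W_{UK} = W_{DKV}^{\top}$). $W_{UV}$ is set equal to $W_{UK}$, so $V' = K'$ in this demo.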

Step 1: Compress K to Latent $c_{KV}$

Each token's 4D key vector is projected down to a 2D latent via $c_{KV} = K \cdot W_{DKV}$:

$$c_{KV}[\text{The}] = [0, 1, 0, 1] \cdot W_{DKV} = [0 \times 0.7 + 0 \times 0.7,\; 1 \times 0.7 + 1 \times 0.7] = [0.000,\; 1.400]$$

$$c_{KV}[\text{cat}] = [1, 0, 1, 0] \cdot W_{DKV} = [1 \times 0.7 + 1 \times 0.7,\; 0 \times 0.7 + 0 \times 0.7] = [1.400,\; 0.000]$$

$$c_{KV}[\text{sat}] = [1, 1, 0, 0] \cdot W_{DKV} = [1 \times 0.7 + 0 \times 0.7,\; 1 \times 0.7 + 0 \times 0.7] = [0.700,\; 0.700]$$

$$c_{KV}[\text{on}] = [0, 0, 1, 1] \cdot W_{DKV} = [0 \times 0.7 + 1 \times 0.7,\; 0 \times 0.7 + 1 \times 0.7] = [0.700,\; 0.700]$$

$$c_{KV}[\text{mat}] = [1, 0, 0.5, 0.5] \cdot W_{DKV} = [1 \times 0.7 + 0.5 \times 0.7,\; 0 \times 0.7 + 0.5 \times 0.7] = [1.050,\; 0.350]$$

| Token | $c_0$ | $c_1$ |
|---|---|---|
| The | 0.000 | 1.400 |
| cat | 1.400 | 0.000 |
| sat | 0.700 | 0.700 |
| on | 0.700 | 0.700 |
| mat | 1.050 | 0.350 |
Critical observation: "sat" and "on" map to the same latent [0.700, 0.700], even though their original K vectors were different: sat = [1,1,0,0] vs on = [0,0,1,1]. The bottleneck cannot distinguish between them because our projection sums dims (0+2) and (1+3). This is the information loss that comes with compression. A learned projection would minimize this for the specific data distribution encountered during training.
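Step 1 can be reproduced in a few lines of NumPy. The sketch rebuilds $W_{DKV}$ from the pattern described above and also confirms why "sat" and "on" collide: their difference lies in the null space of $W_{DKV}$.

```python
import numpy as np

K = np.array([[0, 1, 0, 1],       # The
              [1, 0, 1, 0],       # cat
              [1, 1, 0, 0],       # sat
              [0, 0, 1, 1],       # on
              [1, 0, 0.5, 0.5]])  # mat
W_DKV = np.zeros((4, 2))
W_DKV[[0, 2], 0] = 0.7            # dims 0 and 2 -> latent 0
W_DKV[[1, 3], 1] = 0.7            # dims 1 and 3 -> latent 1

c_KV = K @ W_DKV
print(c_KV)
# sat - on = [1, 1, -1, -1] is sent to [0, 0], so the two tokens share a latent.
print((K[2] - K[3]) @ W_DKV)
```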

Step 2: Decompress to K' and V'

Reconstruct keys from the cached latent: $K' = c_{KV} \cdot W_{UK}$.

$$K'[\text{The}] = [0.000, 1.400] \cdot W_{UK} = [0.000,\; 0.980,\; 0.000,\; 0.980]$$

Compare with the original: $K[\text{The}] = [0.0, 1.0, 0.0, 1.0]$. Close! Each reconstructed dimension equals $0.7 \times 0.7 = 0.49$ times the sum of its paired input dimensions (0 and 2, or 1 and 3). For "The" that sum is $1 + 1 = 2$, giving $0.49 \times 2 = 0.98$, slightly below 1.0.

| Token | Original K | Reconstructed K′ |
|---|---|---|
| The | [0.0, 1.0, 0.0, 1.0] | [0.000, 0.980, 0.000, 0.980] |
| cat | [1.0, 0.0, 1.0, 0.0] | [0.980, 0.000, 0.980, 0.000] |
| sat | [1.0, 1.0, 0.0, 0.0] | [0.490, 0.490, 0.490, 0.490] |
| on | [0.0, 0.0, 1.0, 1.0] | [0.490, 0.490, 0.490, 0.490] |
| mat | [1.0, 0.0, 0.5, 0.5] | [0.735, 0.245, 0.735, 0.245] |

Notice: "The" and "cat" reconstruct well because their K vectors already lie in the 2D subspace the projection preserves: their paired dimensions (0 and 2, or 1 and 3) carry equal values, so reconstruction only rescales them by 0.98. "sat" and "on" collapse to the same reconstruction because their information was merged in the bottleneck.
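Continuing the same toy setup, the sketch below decompresses the latents with $W_{UK} = W_{DKV}^{\top}$ (true for this example's matrices) and measures each token's reconstruction error $\lVert K - K' \rVert$. It rebuilds K and $W_{DKV}$ so it runs on its own.

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "mat"]
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0],
              [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
W_DKV = np.zeros((4, 2))
W_DKV[[0, 2], 0] = 0.7
W_DKV[[1, 3], 1] = 0.7
W_UK = W_DKV.T                     # in this demo the decoder mirrors the encoder

K_prime = (K @ W_DKV) @ W_UK
errors = np.linalg.norm(K - K_prime, axis=1)
for t, e in zip(tokens, errors):
    print(f"{t:>4}: error = {e:.4f}")
```

"The" and "cat" come back with an error of about 0.03, while "sat" and "on" lose about 1.0, the full norm of the information that distinguished them.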

Step 3: Scaled Dot-Product Attention

Compute $S = Q \cdot K'^{\top} / \sqrt{4}$ using the reconstructed keys. Here is the detailed computation for "cat" (row index 1):

$$Q[\text{cat}] = [0, 2, 0, 1]$$

$Q[\text{cat}] \cdot K'[\text{The}]^{\top} = 0 \times 0 + 2 \times 0.98 + 0 \times 0 + 1 \times 0.98 = 2.940$ → scaled: $2.940 / 2 = 1.470$

$Q[\text{cat}] \cdot K'[\text{cat}]^{\top} = 0 \times 0.98 + 2 \times 0 + 0 \times 0.98 + 1 \times 0 = 0.000$ → scaled: $0.000 / 2 = 0.000$

$Q[\text{cat}] \cdot K'[\text{sat}]^{\top} = 0 \times 0.49 + 2 \times 0.49 + 0 \times 0.49 + 1 \times 0.49 = 1.470$ → scaled: $1.470 / 2 = 0.735$

$Q[\text{cat}] \cdot K'[\text{on}]^{\top} = 1.470$ → scaled: $0.735$ (same as "sat" because $K'[\text{sat}] = K'[\text{on}]$)

$Q[\text{cat}] \cdot K'[\text{mat}]^{\top} = 0 \times 0.735 + 2 \times 0.245 + 0 \times 0.735 + 1 \times 0.245 = 0.735$ → scaled: $0.735 / 2 = 0.3675$

Softmax on [1.470, 0.000, 0.735, 0.735, 0.3675]:

$\max = 1.470$, shifted = [0.000, -1.470, -0.735, -0.735, -1.1025]

$$e^{0.000} = 1.000,\; e^{-1.470} = 0.230,\; e^{-0.735} = 0.480,\; e^{-0.735} = 0.480,\; e^{-1.103} = 0.332$$

$$\text{sum} = 2.521 \quad\Rightarrow\quad A[\text{cat}] = [0.397,\; 0.091,\; 0.190,\; 0.190,\; 0.132]$$

Comparison with standard attention: In standard attention (Chapter 1), cat's weights were [0.403, 0.090, 0.244, 0.148, 0.115]. In MLA, "sat" and "on" receive equal weight (0.190) because the bottleneck merged their key information. Standard attention gives "sat" 0.244 and "on" 0.148 — it can distinguish them. This illustrates the precision-vs-memory trade-off at the heart of MLA.
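Step 3 and the comparison with standard attention can be checked numerically. This sketch computes the "cat" row both ways from the toy matrices, rebuilt here so the snippet runs on its own.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0],
              [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
W_DKV = np.zeros((4, 2))
W_DKV[[0, 2], 0] = 0.7
W_DKV[[1, 3], 1] = 0.7
K_prime = K @ W_DKV @ W_DKV.T      # W_UK = W_DKV.T in this demo

q_cat = np.array([0, 2, 0, 1])
a_mla = softmax(q_cat @ K_prime.T / 2.0)   # sqrt(d) = sqrt(4) = 2
a_std = softmax(q_cat @ K.T / 2.0)         # standard attention on the original keys
print("MLA:     ", a_mla.round(4))
print("standard:", a_std.round(4))
```

The MLA row gives "sat" and "on" identical weights, while the standard row separates them.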

Step 4: Weighted Sum of V'

The final output for "cat" is $O[\text{cat}] = A[\text{cat}] \cdot V'$. Since $W_{UV} = W_{UK}$ in this demo, the rows of $V'$ are the same as those of $K'$:

$$\begin{aligned}
O[\text{cat}] &= 0.397 \times [0, 0.98, 0, 0.98] + 0.091 \times [0.98, 0, 0.98, 0] + 0.190 \times [0.49, 0.49, 0.49, 0.49] \\
&\quad + 0.190 \times [0.49, 0.49, 0.49, 0.49] + 0.132 \times [0.735, 0.245, 0.735, 0.245] \\
&= [0.373,\; 0.607,\; 0.373,\; 0.607]
\end{aligned}$$
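Step 4 is one more matrix product. This self-contained snippet rebuilds the toy matrices, recomputes the "cat" weights, and reproduces $O[\text{cat}]$; as in the demo, $V' = K'$ because $W_{UV} = W_{UK} = W_{DKV}^{\top}$.

```python
import numpy as np

K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0],
              [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
W_DKV = np.zeros((4, 2))
W_DKV[[0, 2], 0] = 0.7
W_DKV[[1, 3], 1] = 0.7
K_prime = K @ W_DKV @ W_DKV.T     # W_UK = W_DKV.T
V_prime = K_prime                 # W_UV = W_UK in this demo, so V' = K'

s = np.array([0, 2, 0, 1]) @ K_prime.T / 2.0   # cat's scaled scores
a_cat = np.exp(s - s.max())
a_cat /= a_cat.sum()
o_cat = a_cat @ V_prime            # weighted sum of reconstructed values
print(o_cat.round(4))
```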



Full Attention Weights and Output

Standard vs MLA Comparison

Standard Attention Weights (from Chapter 1):

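| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1095 | 0.2976 | 0.1805 | 0.1805 | 0.2318 |
| cat | 0.4026 | 0.0898 | 0.2442 | 0.1481 | 0.1153 |
| sat | 0.1519 | 0.2505 | 0.2505 | 0.1519 | 0.1951 |
| on | 0.1903 | 0.1903 | 0.1154 | 0.3137 | 0.1903 |
| mat | 0.1892 | 0.1892 | 0.1892 | 0.1892 | 0.2430 |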

MLA Attention Weights ($d_C = 2$):

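| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1109 | 0.2956 | 0.1811 | 0.1811 | 0.2313 |
| cat | 0.3967 | 0.0912 | 0.1902 | 0.1902 | 0.1317 |
| sat | 0.1508 | 0.2461 | 0.1927 | 0.1927 | 0.2178 |
| on | 0.2000 | 0.2000 | 0.2000 | 0.2000 | 0.2000 |
| mat | 0.2000 | 0.2000 | 0.2000 | 0.2000 | 0.2000 |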

Interpreting the Differences

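  • "The" row: Very similar weights. The query Q[The] = [1, 0, 1, 0] has equal values in the paired dimensions 0 and 2 and zeros in 1 and 3, so it lies entirely in the subspace the projection preserves. As a result, every MLA score in this row is exactly 0.98 times the corresponding standard score, and the softmax barely changes.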
  • "cat" row: The dominant pattern is preserved (strongest attention to "The"), but "sat" and "on" now receive equal weight (0.190 vs 0.244/0.148 in standard) because the bottleneck merged them.
  • "on" and "mat" rows: MLA produces perfectly uniform weights (0.200). Both queries put equal weight on the two latent directions ($q_0 + q_2 = q_1 + q_3 = 1$), and every key's entries sum to 2, so after compression all five scores are identical (0.49). Trained projections would generally not produce such exact ties.
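The exact tie in the "on" and "mat" rows can be checked against a closed form that holds only for these toy projections: with $W_{UK} = W_{DKV}^{\top}$, every scaled score equals $0.49 \cdot [(q_0+q_2)(k_0+k_2) + (q_1+q_3)(k_1+k_3)] / 2$. The sketch compares that formula with the direct computation.

```python
import numpy as np

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0],
              [0, 0, 1, 1], [1, 0, 0, 1]])
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0],
              [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
W_DKV = np.zeros((4, 2))
W_DKV[[0, 2], 0] = 0.7
W_DKV[[1, 3], 1] = 0.7

direct = Q @ (K @ W_DKV @ W_DKV.T).T / 2.0
# Closed form for this demo: sum paired dims (0,2) and (1,3), scale by 0.7 * 0.7.
q_pairs = np.stack([Q[:, 0] + Q[:, 2], Q[:, 1] + Q[:, 3]], axis=1)
k_pairs = np.stack([K[:, 0] + K[:, 2], K[:, 1] + K[:, 3]], axis=1)
closed = 0.49 * (q_pairs @ k_pairs.T) / 2.0

print(np.allclose(direct, closed))   # True
print(direct[3], direct[4])          # both rows are constant
```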

MLA Output ($5 \times 4$):

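| | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.6372 | 0.3428 | 0.6372 | 0.3428 |
| cat | 0.3726 | 0.6074 | 0.3726 | 0.6074 |
| sat | 0.5901 | 0.3899 | 0.5901 | 0.3899 |
| on | 0.5390 | 0.4410 | 0.5390 | 0.4410 |
| mat | 0.5390 | 0.4410 | 0.5390 | 0.4410 |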

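Notice that dim-0 = dim-2 and dim-1 = dim-3 in every row. This is a direct consequence of our symmetric projection design: $W_{UV}$ writes latent 0 into dimensions 0 and 2 and latent 1 into dimensions 1 and 3, so every row of $V'$, and therefore every weighted sum of those rows, has this symmetry. A learned projection would generally break it and produce distinct values in all four dimensions.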


The Approximation Trade-off

What Information Is Lost?

When $d_C < d_{\text{model}}$, the compression is lossy. The reconstructed matrix $K' = K \cdot W_{DKV} \cdot W_{UK}$ has rank at most $d_C$, so it is a rank-$d_C$ approximation of K. Any component of a key that $W_{DKV}$ maps to zero (its null space) is lost, and no choice of $W_{UK}$ can recover it.

In our example with $d_C = 2$, we can represent at most a 2D subspace of the 4D key space. Tokens whose K vectors differ only in the "invisible" directions (those $W_{DKV}$ sends to zero) become indistinguishable after compression. This is exactly what happened with "sat" and "on": their difference [1, 1, −1, −1] is mapped to [0, 0] by $W_{DKV}$.
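The rank argument is easy to verify on the toy example: K itself has full rank 4, but $K' = K W_{DKV} W_{UK}$ has rank at most $d_C = 2$, whatever K is.

```python
import numpy as np

K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0],
              [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
W_DKV = np.zeros((4, 2))
W_DKV[[0, 2], 0] = 0.7
W_DKV[[1, 3], 1] = 0.7
W_UK = W_DKV.T

K_prime = K @ W_DKV @ W_UK
print(np.linalg.matrix_rank(K), np.linalg.matrix_rank(K_prime))   # 4 2
```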

Why the Loss Is Acceptable

  1. Low-rank structure in practice: In trained transformers, K and V matrices tend to exhibit strong low-rank structure. DeepSeek-V2 reports that MLA, caching a 512-dimensional latent in place of 32,768 values per token per layer, matches or exceeds MHA quality in its comparisons (DeepSeek-AI, 2024).
  2. End-to-end learning: The projections are not fixed like PCA; they are learned jointly with the rest of the model. The model learns to put important information into the directions that $W_{DKV}$ preserves, and to tolerate lossy reconstruction of less critical dimensions.
  3. Graceful degradation: As $d_C$ grows, the bottleneck discards less, and once $d_C$ is large enough to span everything K and V contain, exact reconstruction becomes possible and MLA can match standard attention. DeepSeek-V2 settles on $d_C = 512$, a latent 64× smaller than the MHA cache.
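  4. Redundancy across layers: Only the cached keys and values pass through the bottleneck; the residual stream still carries each token's full hidden state to the next layer, where that layer's own projections compress it afresh. Information one layer's latent drops is therefore not destroyed, and the model can learn to spread what matters across layers, loosely analogous to redundancy in error-correcting codes.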
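To see graceful degradation in isolation, the sketch below swaps in truncated SVD, the best fixed rank-$d_C$ linear compressor in the least-squares sense (Eckart–Young), for MLA's learned projections. The synthetic keys, built to be approximately rank 4, are an assumption; real K and V statistics will differ.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, true_rank = 200, 16, 4
# Synthetic keys: a rank-4 signal plus a little noise.
K = rng.standard_normal((N, true_rank)) @ rng.standard_normal((true_rank, d))
K += 0.05 * rng.standard_normal((N, d))

_, _, Vt = np.linalg.svd(K, full_matrices=False)
errors = []
for d_c in range(1, d + 1):
    W_down = Vt[:d_c].T            # (d, d_c): project onto the top-d_c directions
    W_up = Vt[:d_c]                # (d_c, d): map back to the full space
    K_rec = K @ W_down @ W_up
    errors.append(np.linalg.norm(K - K_rec) / np.linalg.norm(K))

for d_c, err in enumerate(errors, start=1):
    print(f"d_c={d_c:2d}  relative error={err:.4f}")
```

The error falls steeply until $d_C$ reaches the true rank, then flattens at the noise level and reaches zero at $d_C = d$.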

Applications Across Domains

| Domain | How MLA Helps | Impact |
|---|---|---|
| Long-context LLMs | 128K-token context with roughly 57× less KV-cache memory than MHA | Makes long-document analysis economically viable |
| Multi-turn chat | Each conversation turn adds tokens; MLA keeps the cache compact | More concurrent users per GPU |
| Code generation | Large codebases (100K+ tokens) require long context | Enables whole-repository understanding |
| Retrieval-augmented generation | Retrieved passages expand context significantly | More passages in context without OOM |
| Vision transformers | High-resolution images produce many patches (tokens) | Higher-resolution attention with a smaller cache in autoregressive vision models |
| Scientific sequence modeling | Protein sequences and genomic data can be very long | Makes attention over longer biological sequences feasible |

Connection to Modern Systems

DeepSeek-V2 Architecture Details

DeepSeek-V2 (DeepSeek-AI, 2024) uses MLA with these specific parameters:

| Parameter | Value | Meaning |
|---|---|---|
| $d_{\text{model}}$ | 5,120 | Hidden dimension |
| $n_h$ | 128 | Number of attention heads |
| $d_h$ | 128 | Per-head dimension |
| $d_C$ (KV compression) | 512 | Compressed KV latent dimension |
| $d_C^{(Q)}$ (Q compression) | 1,536 | Compressed Q latent dimension |
| Standard KV cache | 32,768 values/token/layer | $2 \times 128 \times 128$ |
| MLA KV cache | 512 values/token/layer (576 with the RoPE key) | $d_C = 512$, plus $d_R = 64$ |
| Compression ratio | 64× (about 57× with the RoPE key) | 32,768 / 512 (32,768 / 576) |

DeepSeek-V2 also compresses Q through a latent, with a larger dimension ($d_C^{(Q)} = 1536$). This does not shrink the KV-cache, since queries are never cached (each is used only at its own generation step); the paper motivates it as a way to reduce activation memory during training.

RoPE Integration in MLA

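A technical challenge with MLA is Rotary Position Embedding (RoPE, Chapter 8), which applies a position-dependent rotation to each key after projection. If RoPE were applied to the reconstructed keys $K' = c_{KV} W_{UK}$, the rotation would sit between $W_{UK}$ and the query, so $W_{UK}$ could no longer be folded into the query projection (the trick that lets attention run directly on the cached latents), and every cached latent would have to be decompressed and rotated again at each step. DeepSeek-V2 solves this with decoupled RoPE: alongside $c_{KV}$ it caches one small RoPE key $k_t^{R}$ of dimension $d_R = 64$, shared across all heads and not compressed, while each head gets its own RoPE query. The total cache per token per layer is then $d_C + d_R = 512 + 64 = 576$ values, still vastly smaller than the standard 32,768.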
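Here is a minimal single-token sketch of the decoupled score. It assumes a textbook RoPE that rotates consecutive pairs of dimensions by angle position × $10000^{-2i/d_R}$, random weights, and made-up sizes. Each head scores a content part against a key rebuilt from $c_{KV}$ plus a RoPE part against one rotated key shared by all heads, and, as we understand the DeepSeek-V2 formulation, divides by $\sqrt{d_h + d_R}$. The helper names `rope` and `scores` are hypothetical.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs of x by position-dependent angles (textbook RoPE)."""
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)   # one frequency per pair
    cos, sin = np.cos(pos * freqs), np.sin(pos * freqs)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
n_heads, d_h, d_c, d_R = 4, 8, 6, 4
c_KV = rng.standard_normal(d_c)                 # cached latent of a past token
k_R = rng.standard_normal(d_R)                  # cached RoPE key, shared by all heads
W_UK = rng.standard_normal((d_c, n_heads * d_h))
q_C = rng.standard_normal((n_heads, d_h))       # per-head content queries
q_R = rng.standard_normal((n_heads, d_R))       # per-head RoPE queries

def scores(q_pos, k_pos):
    k_C = (c_KV @ W_UK).reshape(n_heads, d_h)   # per-head keys from the latent
    content = (q_C * k_C).sum(-1)               # position-free part
    positional = rope(q_R, q_pos) @ rope(k_R, k_pos)   # same k_R for every head
    return (content + positional) / np.sqrt(d_h + d_R)

print("cached values per token:", d_c + d_R)    # latent + shared RoPE key
print(np.allclose(scores(7, 3), scores(107, 103)))   # depends only on 7 - 3
```

Only $c_{KV}$ and $k^R$ need to be cached; the per-head RoPE queries are computed fresh for the current token.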

Flash Attention Compatibility

Once $K'$ and $V'$ are reconstructed from $c_{KV}$, the attention computation is standard scaled dot-product attention. MLA is therefore compatible with Flash Attention (Chapter 13): the IO-aware tiling algorithm applies unchanged to the reconstructed matrices. The decompression step (two matrix multiplications) adds $O(N \cdot d_C \cdot d)$ compute, which during prefill is small compared with the $O(N^2 d)$ attention itself once $N$ is well beyond $d_C$.
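The claim that tiling applies unchanged can be illustrated with a simplified, FlashAttention-style online softmax in NumPy. This sketches only the core recurrence (no GPU memory hierarchy, one query row, made-up block size), applied to $K'$ and $V'$ reconstructed from a random latent cache; the function name is hypothetical.

```python
import numpy as np

def online_softmax_attention(q, K, V, block=4):
    """Attention for one query, visiting K/V in blocks with a running max and sum."""
    m, s = -np.inf, 0.0                  # running max score and softmax denominator
    acc = np.zeros(V.shape[1])           # running weighted sum of values
    for start in range(0, K.shape[0], block):
        sc = K[start:start + block] @ q / np.sqrt(q.shape[0])
        m_new = max(m, sc.max())
        rescale = np.exp(m - m_new)      # shrink earlier partial sums to the new max
        p = np.exp(sc - m_new)
        s = s * rescale + p.sum()
        acc = acc * rescale + p @ V[start:start + block]
        m = m_new
    return acc / s

rng = np.random.default_rng(0)
N, d, d_c = 10, 8, 3
c_KV = rng.standard_normal((N, d_c))
K_p = c_KV @ rng.standard_normal((d_c, d))   # reconstructed keys
V_p = c_KV @ rng.standard_normal((d_c, d))   # reconstructed values
q = rng.standard_normal(d)

ref_scores = K_p @ q / np.sqrt(d)
w = np.exp(ref_scores - ref_scores.max())
reference = (w / w.sum()) @ V_p
print(np.allclose(online_softmax_attention(q, K_p, V_p), reference))   # True
```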


Complexity Analysis

| Operation | Time Complexity | Memory |
|---|---|---|
| Compress: $K \cdot W_{DKV}$ | $O(N \cdot d \cdot d_C)$ | $O(N \cdot d_C)$ |
| Decompress K: $c_{KV} \cdot W_{UK}$ | $O(N \cdot d_C \cdot d)$ | $O(N \cdot d)$ |
| Decompress V: $c_{KV} \cdot W_{UV}$ | $O(N \cdot d_C \cdot d)$ | $O(N \cdot d)$ |
| Attention: $QK'^{\top},\; \alpha V'$ | $O(N^2 d)$ | $O(N^2 + N d)$ |
| KV-cache per token | | $O(d_C)$ vs $O(d)$ standard |

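The key insight is that during prefill, when all $N$ queries are processed together, the decompression cost $O(N \cdot d_C \cdot d)$ is dominated by the attention cost $O(N^2 d)$ for any sequence longer than $d_C$ tokens. With $d_C = 512$ and typical sequences of 4K–128K tokens, MLA adds little overhead there while providing large memory savings. Decoding is different: a single new query costs only $O(N d)$ in attention, so re-decompressing all $N$ cached latents at every step would cost about $d_C$ times more than the attention itself. This is why DeepSeek-V2 computes attention directly against the latents, with $W_{UK}$ folded into the query projection and $W_{UV}$ into the output projection, rather than materializing $K'$ and $V'$.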
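The $O(N d_C d)$ versus $O(N^2 d)$ comparison assumes all $N$ queries are processed together. A back-of-the-envelope count of multiply-accumulates shows what changes during decoding, when one new query attends to $N$ cached tokens. It counts only key-side work for all heads, uses DeepSeek-V2's $n_h = 128$, $d_h = 128$, $d_C = 512$, ignores the RoPE part, and the function names are illustrative.

```python
def decompress_k_macs(N, d_c, n_h, d_h):
    """Multiply-accumulates to rebuild K' for N cached tokens: (N, d_c) @ (d_c, n_h*d_h)."""
    return N * d_c * n_h * d_h

def score_macs(n_queries, N, n_h, d_h):
    """Multiply-accumulates for Q K'^T over all heads."""
    return n_queries * N * n_h * d_h

N, d_c, n_h, d_h = 32_768, 512, 128, 128
dec = decompress_k_macs(N, d_c, n_h, d_h)
print("prefill  decompress/scores:", dec / score_macs(N, N, n_h, d_h))   # d_c / N
print("decoding decompress/scores:", dec / score_macs(1, N, n_h, d_h))   # d_c
```

Decompression is about 1.6% of the score work in a 32K-token prefill, but during decoding, rebuilding the whole cache for one query costs 512 times more than the scores themselves.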


Python Implementation

The complete MLA implementation as a Python class. The annotated walkthrough below traces each line with exact values from our "the cat sat on the mat" example, followed by the full listing.

Multi-Head Latent Attention: NumPy Implementation
mla_attention.py
import numpy as np

NumPy provides vectorized array operations. All matrix multiplications (Q @ K′.T, weights @ V′) execute as optimized C code.

import math

math.sqrt is used for the scaling factor √d_k.

class MultiHeadLatentAttention:

Encapsulates the MLA mechanism from DeepSeek-V2. Manages three learned projection matrices: W_DKV (down-projection), W_UK and W_UV (up-projections for K and V reconstruction).

def __init__(self, d_model, d_c):

Constructor: d_model is the full embedding dimension (4 in our example), d_c is the compressed latent dimension (2). The compression ratio is (d_k + d_v) / d_c.

EXECUTION STATE
⬇ input: d_model = 4 — full dimension of Q, K, V vectors
⬇ input: d_c = 2 — compressed latent dimension (d_c ≪ d_model)
self.d_model = d_model

Store the full dimension size for later use in scaling.

EXECUTION STATE
self.d_model = 4
self.d_c = d_c

Store the compressed dimension. This controls how many floats are cached per token.

EXECUTION STATE
self.d_c = 2
self.scale = math.sqrt(d_model)

Precompute √4 = 2.0. Used in attention scoring to prevent dot products from growing too large.

EXECUTION STATE
self.scale = 2.0 (√4)
self.W_DKV = np.zeros((d_model, d_c))

Initialize the down-projection matrix W_DKV as (4×2) zeros. This matrix compresses 4D vectors into 2D latent vectors.

EXECUTION STATE
self.W_DKV shape = (4, 2) — maps d_model → d_c
for i in range(d_model):

Loop over all 4 rows of W_DKV to set up the projection pattern.

LOOP TRACE · 4 iterations
i=0
W_DKV[0, 0%2=0] = 0.7 — dim 0 maps to latent 0
i=1
W_DKV[1, 1%2=1] = 0.7 — dim 1 maps to latent 1
i=2
W_DKV[2, 2%2=0] = 0.7 — dim 2 maps to latent 0
i=3
W_DKV[3, 3%2=1] = 0.7 — dim 3 maps to latent 1
j = i % d_c

Each input dimension maps to latent dimension (i mod d_c). Dims 0,2 → latent 0; dims 1,3 → latent 1.

EXECUTION STATE
pattern = dim0→c0, dim1→c1, dim2→c0, dim3→c1
self.W_DKV[i, j] = 0.7

Set the projection weight to 0.7. In practice, these are learned via backpropagation.

EXECUTION STATE
W_DKV (4×2) =
     c0   c1
d0  0.7  0.0
d1  0.0  0.7
d2  0.7  0.0
d3  0.0  0.7
self.W_UK = np.zeros((d_c, d_model))

Initialize the up-projection matrix for K reconstruction. Shape (2×4): maps latent back to full dimension.

EXECUTION STATE
self.W_UK shape = (2, 4) — maps d_c → d_model
self.W_UK[i, j] = 0.7

W_UK mirrors W_DKVᵀ pattern. Latent 0 → dims 0,2; latent 1 → dims 1,3.

EXECUTION STATE
W_UK (2×4) =
     d0   d1   d2   d3
c0  0.7  0.0  0.7  0.0
c1  0.0  0.7  0.0  0.7
self.W_UV = np.zeros((d_c, d_model))

Initialize up-projection for V reconstruction. Same structure as W_UK in this demo. In practice, W_UK ≠ W_UV so K′ and V′ differ.

EXECUTION STATE
W_UV (2×4) =
     d0   d1   d2   d3
c0  0.7  0.0  0.7  0.0
c1  0.0  0.7  0.0  0.7
def compress(self, K) -> np.ndarray:

Down-project K from (N, d_model) to (N, d_c). This is the compression step — the output c_KV is what gets cached.

EXECUTION STATE
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬆ returns = np.ndarray (5, 2) — compressed latent c_KV
return K @ self.W_DKV

Matrix multiply K(5×4) @ W_DKV(4×2) = c_KV(5×2). Each token’s 4D key is compressed to 2D.

EXECUTION STATE
⬆ return: c_KV (5×2) =
      c0      c1
The  0.0000  1.4000
cat  1.4000  0.0000
sat  0.7000  0.7000
on   0.7000  0.7000
mat  1.0500  0.3500
def decompress_K(self, c_KV) -> np.ndarray:

Up-project c_KV from (N, d_c) back to (N, d_model). Reconstructs approximate keys for attention computation.

EXECUTION STATE
⬇ input: c_KV (5×2) =
      c0      c1
The  0.0000  1.4000
cat  1.4000  0.0000
sat  0.7000  0.7000
on   0.7000  0.7000
mat  1.0500  0.3500
⬆ returns = np.ndarray (5, 4) — reconstructed K′
return c_KV @ self.W_UK

c_KV(5×2) @ W_UK(2×4) = K′(5×4). Note: K′ ≈ K but not identical — information was lost in the bottleneck.

EXECUTION STATE
⬆ return: K′ (5×4) =
       d0      d1      d2      d3
The  0.0000  0.9800  0.0000  0.9800
cat  0.9800  0.0000  0.9800  0.0000
sat  0.4900  0.4900  0.4900  0.4900
on   0.4900  0.4900  0.4900  0.4900
mat  0.7350  0.2450  0.7350  0.2450
def decompress_V(self, c_KV) -> np.ndarray:

Same structure as decompress_K but uses W_UV. In a trained model, W_UV ≠ W_UK so V′ differs from K′.

EXECUTION STATE
⬇ input: c_KV (5×2) = (same latent as above)
⬆ returns = np.ndarray (5, 4) — reconstructed V′
return c_KV @ self.W_UV

c_KV(5×2) @ W_UV(2×4) = V′(5×4). In this demo W_UV = W_UK, so V′ = K′.

EXECUTION STATE
⬆ return: V′ (5×4) =
       d0      d1      d2      d3
The  0.0000  0.9800  0.0000  0.9800
cat  0.9800  0.0000  0.9800  0.0000
sat  0.4900  0.4900  0.4900  0.4900
on   0.4900  0.4900  0.4900  0.4900
mat  0.7350  0.2450  0.7350  0.2450
def _softmax(self, x) -> np.ndarray:

Numerically stable softmax: subtract row max before exp to prevent overflow.

EXECUTION STATE
⬇ input: x = np.ndarray — scaled score matrix (5×5)
x_shifted = x - np.max(x, axis=-1, keepdims=True)

Subtract each row’s maximum. This prevents exp() overflow while giving identical softmax output.

EXECUTION STATE
axis=-1 = Operate along the LAST axis (columns within each row). For a 5×5 matrix, find the max of each row independently.
keepdims=True = Keep reduced axis as size-1 dimension. Returns shape (5,1) not (5,), so broadcasting x(5×5) - max(5×1) works correctly.
exp_x = np.exp(x_shifted)

Exponentiate the shifted scores. All values are ≤0, so exp_x ∈ (0, 1].

return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

Divide each row by its sum to get a probability distribution. Each row sums to 1.0.

EXECUTION STATE
axis=-1 = Sum along the last axis (each row independently)
keepdims=True = Shape (5,1) for broadcasting: exp_x(5×5) / sum(5×1)
def forward(self, Q, K, V):

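Full MLA forward pass. Takes Q, K, V matrices and returns attention weights, output, and the cached latent. V is accepted only for parity with standard attention and is not used: in this toy, K stands in for the hidden state, so both K′ and V′ are rebuilt from the latent of K.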

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  1.0  0.0  1.0  0.0
⬆ returns = (weights, output, c_KV) — attention weights (5×5), output (5×4), cached latent (5×2)
c_KV = self.compress(K)

COMPRESS: K(5×4) → c_KV(5×2). Each token’s 4-float key becomes a 2-float latent. THIS is what gets cached.

EXECUTION STATE
c_KV (5×2) — CACHED =
      c0      c1
The  0.0000  1.4000
cat  1.4000  0.0000
sat  0.7000  0.7000
on   0.7000  0.7000
mat  1.0500  0.3500
K_prime = self.decompress_K(c_KV)

DECOMPRESS K: c_KV(5×2) → K′(5×4). Reconstructed from cache at inference time.

EXECUTION STATE
K′ (5×4) =
       d0      d1      d2      d3
The  0.0000  0.9800  0.0000  0.9800
cat  0.9800  0.0000  0.9800  0.0000
sat  0.4900  0.4900  0.4900  0.4900
on   0.4900  0.4900  0.4900  0.4900
mat  0.7350  0.2450  0.7350  0.2450
V_prime = self.decompress_V(c_KV)

DECOMPRESS V: c_KV(5×2) → V′(5×4). The same latent produces both K′ and V′ via separate projections (identical in this demo, so V′ = K′).

EXECUTION STATE
V′ (5×4) =
       d0      d1      d2      d3
The  0.0000  0.9800  0.0000  0.9800
cat  0.9800  0.0000  0.9800  0.0000
sat  0.4900  0.4900  0.4900  0.4900
on   0.4900  0.4900  0.4900  0.4900
mat  0.7350  0.2450  0.7350  0.2450
scores = Q @ K_prime.T / self.scale

Standard scaled dot-product: Q(5×4) @ K′ᵀ(4×5) / 2.0 = scores(5×5). Each entry measures query-key similarity.

EXECUTION STATE
.T = NumPy transpose — K′(5×4) becomes K′ᵀ(4×5)
self.scale = 2.0 (√4)
scores (5×5) =
         The     cat     sat      on     mat
The   0.0000  0.9800  0.4900  0.4900  0.7350
cat   1.4700  0.0000  0.7350  0.7350  0.3675
sat   0.4900  0.9800  0.7350  0.7350  0.8575
on    0.4900  0.4900  0.4900  0.4900  0.4900
mat   0.4900  0.4900  0.4900  0.4900  0.4900
weights = self._softmax(scores)

Apply softmax row-wise. Converts raw scores to probability distributions.

EXECUTION STATE
weights (5×5) =
        The     cat     sat      on     mat
The  0.1109  0.2956  0.1811  0.1811  0.2313
cat  0.3967  0.0912  0.1902  0.1902  0.1317
sat  0.1508  0.2461  0.1927  0.1927  0.2178
on   0.2000  0.2000  0.2000  0.2000  0.2000
mat  0.2000  0.2000  0.2000  0.2000  0.2000
output = weights @ V_prime

Weighted sum of reconstructed values. weights(5×5) @ V′(5×4) = output(5×4).

EXECUTION STATE
output (5×4) =
       d0      d1      d2      d3
The  0.6372  0.3428  0.6372  0.3428
cat  0.3726  0.6074  0.3726  0.6074
sat  0.5901  0.3899  0.5901  0.3899
on   0.5390  0.4410  0.5390  0.4410
mat  0.5390  0.4410  0.5390  0.4410
return weights, output, c_KV

Return all three: weights for analysis, output for the next layer, c_KV for caching.

EXECUTION STATE
⬆ return: weights = (5, 5) attention probability matrix
⬆ return: output = (5, 4) context-enriched representations
⬆ return: c_KV = (5, 2) cached latent — only 2 floats per token!
import numpy as np
import math

class MultiHeadLatentAttention:
    """
    Multi-Head Latent Attention (MLA) — DeepSeek-V2, 2024

    Core idea: compress K and V into a low-rank latent c_KV
    before caching. Only c_KV is stored in the KV-cache.
    K' and V' are reconstructed from c_KV at inference time.

    Cache savings: d_C floats per token instead of (d_k + d_v).

    Toy simplifications: one head, hand-picked 0.7 weights instead
    of learned ones, and K standing in for the hidden state h.
    """

    def __init__(self, d_model: int, d_c: int):
        self.d_model = d_model
        self.d_c = d_c
        self.scale = math.sqrt(d_model)
        # Down-projection: input dim i feeds latent dim (i mod d_c).
        self.W_DKV = np.zeros((d_model, d_c))
        for i in range(d_model):
            j = i % d_c
            self.W_DKV[i, j] = 0.7
        # Key up-projection: latent i writes to every dim j with j mod d_c == i.
        self.W_UK = np.zeros((d_c, d_model))
        for i in range(d_c):
            for j in range(d_model):
                if j % d_c == i:
                    self.W_UK[i, j] = 0.7
        # Value up-projection: same pattern as W_UK here, so V' equals K'.
        self.W_UV = np.zeros((d_c, d_model))
        for i in range(d_c):
            for j in range(d_model):
                if j % d_c == i:
                    self.W_UV[i, j] = 0.7

    def compress(self, K: np.ndarray) -> np.ndarray:
        """Down-project K to low-rank latent c_KV."""
        return K @ self.W_DKV

    def decompress_K(self, c_KV: np.ndarray) -> np.ndarray:
        """Up-project latent back to key space."""
        return c_KV @ self.W_UK

    def decompress_V(self, c_KV: np.ndarray) -> np.ndarray:
        """Up-project latent back to value space."""
        return c_KV @ self.W_UV

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax along last axis."""
        x_shifted = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Full MLA forward pass.

        V is accepted for parity with standard attention but unused:
        K stands in for the hidden state, so K' and V' both come
        from the latent of K.

        Returns: (weights, output, c_KV)
        """
        c_KV = self.compress(K)             # (N, d_c): the only cached tensor
        K_prime = self.decompress_K(c_KV)   # (N, d_model) reconstructed keys
        V_prime = self.decompress_V(c_KV)   # (N, d_model) reconstructed values
        scores = Q @ K_prime.T / self.scale
        weights = self._softmax(scores)
        output = weights @ V_prime
        return weights, output, c_KV


tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])
K = np.array([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])
V = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
])

mla = MultiHeadLatentAttention(d_model=4, d_c=2)
weights, output, c_KV = mla.forward(Q, K, V)

print("Latent c_KV (CACHED):")
for i, t in enumerate(tokens):
    print(f"  {t}: {c_KV[i]}")

print("\nAttention weights:")
for i, t in enumerate(tokens):
    print(f"  {t}: {weights[i].round(4)}")

print("\nOutput:")
for i, t in enumerate(tokens):
    print(f"  {t}: {output[i].round(4)}")

print(f"\nCache: {mla.d_c} floats/token vs {2*mla.d_model} standard")

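The loops in __init__ build fixed 0.7-valued pattern matrices rather than learned ones. The sketch below rebuilds the same matrices in vectorized form, using a hypothetical helper pattern_projections that is not part of the chapter's code, and checks one consequence of the 2-float bottleneck: "sat" and "on" compress to the same latent, so any up-projection has to reconstruct identical keys for them.

```python
import numpy as np


def pattern_projections(d_model: int, d_c: int, w: float = 0.7):
    """Hypothetical helper: vectorized equivalent of the loops in __init__.

    W_DKV[i, j] = w when i % d_c == j, and W_UK[i, j] = w when j % d_c == i,
    so W_UK is exactly the transpose of W_DKV.
    """
    mask = (np.arange(d_model)[:, None] % d_c) == np.arange(d_c)[None, :]
    W_DKV = w * mask.astype(float)   # (d_model, d_c)
    W_UK = W_DKV.T.copy()            # (d_c, d_model)
    return W_DKV, W_UK


tokens = ["The", "cat", "sat", "on", "mat"]
K = np.array([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])

W_DKV, W_UK = pattern_projections(d_model=4, d_c=2)
c_KV = K @ W_DKV         # (5, 2): what the cache holds
K_prime = c_KV @ W_UK    # (5, 4): keys rebuilt from the cache

sat, on = tokens.index("sat"), tokens.index("on")
print("c_KV[sat] =", c_KV[sat], " c_KV[on] =", c_KV[on])
print("identical latents:", np.allclose(c_KV[sat], c_KV[on]))
print("K'[sat] =", K_prime[sat].round(4))
print("K'[on]  =", K_prime[on].round(4))
```

The two original keys, [1, 1, 0, 0] and [0, 0, 1, 1], are orthogonal, yet after compression they are indistinguishable; Exercise 1 asks what changes when the latent gets one more dimension.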
PyTorch Implementation

The same mechanism in PyTorch, using nn.Linear for learnable projections. In production, W_DKV, W_UK, and W_UV are learned end-to-end via backpropagation along with the rest of the model.

Multi-Head Latent Attention: PyTorch Implementation
mla_pytorch.py
import torch

PyTorch tensor library for GPU-accelerated matrix operations.

import torch.nn as nn

Neural network module. nn.Linear provides learnable weight matrices with automatic gradient tracking.

import math

Standard math library, used here for sqrt.

class MultiHeadLatentAttention(nn.Module):

Inherits from nn.Module so PyTorch can track parameters, manage devices (CPU/GPU), and compute gradients automatically.

def __init__(self, d_model, d_c):

Constructor. In our example d_model = 4 (the full embedding dimension) and d_c = 2 (the compressed latent dimension).

super().__init__()

Initializes the nn.Module parent class, which is required for parameter registration and hooks.

self.scale = math.sqrt(d_model)

Precomputes √d = √4 = 2.0 for attention scaling, so self.scale = 2.0.

self.W_DKV = nn.Linear(d_model, d_c, bias=False)

Learnable down-projection (4 → 2). nn.Linear stores its weight as (out_features, in_features), so W_DKV.weight has shape (2, 4), and it computes x @ weight.T. bias=False drops the additive bias: MLA projections are pure linear maps.

self.W_UK = nn.Linear(d_c, d_model, bias=False)

Learnable up-projection for K reconstruction (2 → 4). W_UK.weight has shape (4, 2).

self.W_UV = nn.Linear(d_c, d_model, bias=False)

Learnable up-projection for V reconstruction (2 → 4), with W_UV.weight of shape (4, 2). It is kept separate from W_UK so that K′ and V′ can differ.

def forward(self, Q, K, V) -> tuple:

The forward pass; PyTorch calls it when you write mla(Q, K, V). Q, K, and V are each tensors of shape (B, N, d_model) = (1, 5, 4).

c_KV = self.W_DKV(K)

COMPRESS: K (1, 5, 4) → c_KV (1, 5, 2). This latent is what gets cached during inference: only 2 floats per token.

K_prime = self.W_UK(c_KV)

DECOMPRESS K: c_KV (1, 5, 2) → K′ (1, 5, 4), reconstructed from the cached latent.

V_prime = self.W_UV(c_KV)

DECOMPRESS V: c_KV (1, 5, 2) → V′ (1, 5, 4). Same latent, different up-projection.

scores = Q @ K_prime.transpose(-2, -1) / self.scale

Scaled dot-product scores of shape (1, 5, 5). .transpose(-2, -1) swaps the last two axes, turning K′ (B, N, d) = (1, 5, 4) into K′ᵀ (B, d, N) = (1, 4, 5) for the matrix multiplication. Dividing by self.scale = 2.0 keeps the dot products from growing too large.

weights = torch.softmax(scores, dim=-1)

Softmax along the last axis, so each query row independently becomes a probability distribution summing to 1. weights has shape (1, 5, 5).

output = weights @ V_prime

Weighted sum of reconstructed values: weights (1, 5, 5) @ V′ (1, 5, 4) = output (1, 5, 4).

return output, weights, c_KV

Returns the context-enriched output (1, 5, 4) for the next layer, the weights for analysis, and c_KV (1, 5, 2) for caching. c_KV is the only thing stored in the KV-cache.
import torch
import torch.nn as nn
import math

class MultiHeadLatentAttention(nn.Module):
    """
    MLA in PyTorch, production-style with nn.Linear layers.
    The down/up projections are learnable parameters.
    """

    def __init__(self, d_model: int, d_c: int):
        super().__init__()
        self.d_model = d_model
        self.d_c = d_c
        self.scale = math.sqrt(d_model)

        # Learnable projections (bias-free linear maps)
        self.W_DKV = nn.Linear(d_model, d_c, bias=False)  # down: d_model -> d_c
        self.W_UK  = nn.Linear(d_c, d_model, bias=False)  # up: d_c -> d_model (keys)
        self.W_UV  = nn.Linear(d_c, d_model, bias=False)  # up: d_c -> d_model (values)

    def forward(self, Q: torch.Tensor, K: torch.Tensor,
                V: torch.Tensor) -> tuple:
        # K stands in for the hidden state that real MLA compresses.
        # V is unused: V' is rebuilt from the same latent as K'.
        # Compress -> shared latent (THIS is cached)
        c_KV = self.W_DKV(K)          # (B, N, d_c)

        # Decompress at inference time
        K_prime = self.W_UK(c_KV)     # (B, N, d_model)
        V_prime = self.W_UV(c_KV)     # (B, N, d_model)

        # Standard scaled dot-product attention
        scores = Q @ K_prime.transpose(-2, -1) / self.scale
        weights = torch.softmax(scores, dim=-1)
        output = weights @ V_prime     # (B, N, d_model)
        return output, weights, c_KV


# Run with our shared example
torch.manual_seed(0)  # nn.Linear weights are random; fix the seed for repeatable output
Q_t = torch.tensor([[[1.,0.,1.,0.],[0.,2.,0.,1.],
    [1.,1.,1.,0.],[0.,0.,1.,1.],[1.,0.,0.,1.]]])
K_t = torch.tensor([[[0.,1.,0.,1.],[1.,0.,1.,0.],
    [1.,1.,0.,0.],[0.,0.,1.,1.],[1.,0.,0.5,0.5]]])
V_t = torch.tensor([[[1.,0.,0.,0.],[0.,1.,0.,0.],
    [0.,0.,1.,0.],[0.,0.,0.,1.],[1.,0.,1.,0.]]])

mla = MultiHeadLatentAttention(d_model=4, d_c=2)
with torch.no_grad():
    output, weights, c_KV = mla(Q_t, K_t, V_t)

print(f"c_KV shape: {c_KV.shape}")    # torch.Size([1, 5, 2])
print(f"Output shape: {output.shape}") # torch.Size([1, 5, 4])
full = 2 * mla.d_model                 # floats/token for full K and V
print(f"Cache savings: {full}/{mla.d_c} = {full // mla.d_c}x smaller")

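The class above processes the whole sequence at once. During generation, the point of MLA is that each new token adds only its latent to the cache. The NumPy sketch below illustrates that decode loop under stated assumptions: keys and values are derived from a per-token hidden state h_t through the latent (rather than compressing K, as the toy does), queries use a plain projection W_Q for simplicity, positional encoding is omitted, the weights are random stand-ins for learned ones, and W_Q and decode_step are hypothetical names. It checks that decoding token by token from the latent cache gives the same outputs as one causal pass over the full sequence.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_c, n_steps = 4, 2, 5

# Random stand-ins for learned weights (hypothetical).
W_Q = rng.standard_normal((d_model, d_model))    # query projection
W_DKV = rng.standard_normal((d_model, d_c))      # hidden state -> latent c_KV
W_UK = rng.standard_normal((d_c, d_model))       # latent -> key
W_UV = rng.standard_normal((d_c, d_model))       # latent -> value
scale = np.sqrt(d_model)


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax along the last axis."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def decode_step(h_t: np.ndarray, cache: list) -> np.ndarray:
    """Hypothetical decode step: cache the new token's latent, then attend over the prefix."""
    cache.append(h_t @ W_DKV)               # only d_c floats stored for this token
    C = np.stack(cache)                     # (t, d_c): all cached latents so far
    K_prime = C @ W_UK                      # (t, d_model), rebuilt, never cached
    V_prime = C @ W_UV                      # (t, d_model), rebuilt, never cached
    q_t = h_t @ W_Q                         # (d_model,)
    w = softmax(q_t @ K_prime.T / scale)    # (t,) attention over the prefix
    return w @ V_prime                      # (d_model,)


H = rng.standard_normal((n_steps, d_model))  # hidden states of 5 tokens
cache = []
step_outputs = np.stack([decode_step(h, cache) for h in H])

# Reference: one causal pass over the whole sequence.
C_full = H @ W_DKV
scores = (H @ W_Q) @ (C_full @ W_UK).T / scale
scores[np.triu_indices(n_steps, k=1)] = -np.inf   # hide future tokens
full_outputs = softmax(scores) @ (C_full @ W_UV)

print("cached latents:", len(cache), "of width", cache[0].shape[0])
print("matches causal full pass:", np.allclose(step_outputs, full_outputs))
```

The sketch rebuilds K′ and V′ for the whole prefix at every step, trading extra compute for the smaller cache; it makes no attempt to optimize that recomputation.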
Key Takeaways

  1. MLA compresses the KV-cache through a learned bottleneck. Instead of caching full K and V vectors, it caches a tiny latent $c_{KV}$ and reconstructs K', V' on demand.
  2. One shared latent serves both K and V reconstruction. Since K and V are both projections of the same hidden state, a single compressed representation captures their shared information efficiently.
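  3. The compression ratio is $(d_k + d_v) / d_C$, where $d_k$ and $d_v$ are the total key and value widths across all heads. In DeepSeek-V2 this is 64× (counting only the latent, not the small decoupled RoPE key cached alongside it), reducing 37 GB of KV-cache to under 600 MB per sequence.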
  4. MLA is orthogonal to Flash Attention. After decompression, standard attention proceeds normally, so Flash Attention's IO-aware tiling applies directly.
  5. The trade-off is reconstruction fidelity vs cache size. A larger $d_C$ gives better reconstruction but a larger cache. The model learns to put important information into the preserved subspace.
  6. MLA fundamentally differs from MQA/GQA. MQA/GQA share heads; MLA compresses dimensions. MLA can achieve higher compression ratios with less quality loss because it targets redundancy within vectors, not across heads.
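The ratio in takeaway 3 and the head-sharing versus dimension-compression contrast in takeaway 6 can both be made concrete by counting cached floats per token per layer. A minimal sketch in plain Python; the function names are hypothetical, GQA is parameterized by its number of KV groups, and the DeepSeek-V2-like sizes (128 heads of width 128, $d_C = 512$) are assumed for illustration and ignore the decoupled RoPE key.

```python
def kv_floats_per_token_per_layer(scheme: str, n_h: int, d_h: int,
                                  n_groups: int = 1, d_c: int = 0) -> int:
    """Floats cached per token per layer under each scheme (hypothetical helper).

    MHA caches K and V for all n_h heads, GQA for n_groups shared KV heads,
    MQA for one shared KV head, and MLA only the d_c-dim latent c_KV.
    """
    if scheme == "MHA":
        return 2 * n_h * d_h
    if scheme == "GQA":
        return 2 * n_groups * d_h
    if scheme == "MQA":
        return 2 * d_h
    if scheme == "MLA":
        return d_c
    raise ValueError(f"unknown scheme: {scheme}")


def mla_compression_ratio(n_h: int, d_h: int, d_c: int) -> float:
    """(d_k + d_v) / d_C, with d_k = d_v = n_h * d_h summed over heads."""
    return (kv_floats_per_token_per_layer("MHA", n_h, d_h)
            / kv_floats_per_token_per_layer("MLA", n_h, d_h, d_c=d_c))


# The chapter's toy: one head of width 4, latent of width 2.
print("toy ratio:", mla_compression_ratio(n_h=1, d_h=4, d_c=2))  # 4.0
# Assumed DeepSeek-V2-like sizes (decoupled RoPE key ignored).
print("DeepSeek-V2-like ratio:", mla_compression_ratio(n_h=128, d_h=128, d_c=512))  # 64.0

# Per-token, per-layer floats at the same assumed sizes; 8 GQA groups is an arbitrary choice.
for scheme, extra in [("MHA", {}), ("GQA", {"n_groups": 8}),
                      ("MQA", {}), ("MLA", {"d_c": 512})]:
    print(f"{scheme}: {kv_floats_per_token_per_layer(scheme, 128, 128, **extra)}")
```

At these assumed sizes MQA stores fewer floats per token (256) than MLA (512); a per-token count says nothing about how much quality each scheme retains, which is where takeaway 6's argument lies.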

Exercises

Exercise 1: Change $d_C$ to 3

Modify the NumPy code to use $d_C = 3$ instead of 2. What happens to the reconstruction of "sat" and "on"? Do their latent representations become distinct? What is the new compression ratio?

Exercise 2: Reconstruction Error

Compute the Frobenius-norm reconstruction error $\|K - K'\|_F / \|K\|_F$ for $d_C = 1, 2, 3, 4$. Plot the results. Does the error ever reach zero with the 0.7-valued pattern projections, even at $d_C = 4$? Why or why not?

Exercise 3: Separate K and V Compression

Modify MLA to compress K and V through separate latents (each of dimension $d_C$), so the cache stores $2 d_C$ floats per token. Compare the attention weights against the shared-latent version. Is the quality improvement worth the doubled cache?

Exercise 4: SVD-Optimal Projection

Compute the SVD of the K matrix, $K = U \Sigma V^T$ (here $V$ holds the right singular vectors, not the value matrix). Use the top-2 right singular vectors as the columns of $W_{DKV}$ and their transpose as $W_{UK}$. How does this "optimal" linear projection compare to our simple pattern projection in terms of reconstruction error?

Exercise 5: Scaling to Real Dimensions

Calculate the KV-cache memory savings for a model with $d_{\text{model}} = 4096$, $n_h = 32$ heads, $d_h = 128$, $L = 40$ layers, and $N = 32{,}000$ tokens. Compare MHA, GQA ($G = 4$ groups), MQA, and MLA ($d_C = 256$) in GB at FP16 precision.


References

  1. DeepSeek-AI. "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model." arXiv:2405.04434, 2024.
  2. Liu, A., et al. "DeepSeek-V3 Technical Report." arXiv:2412.19437, 2024.
  3. Shazeer, N. "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150, 2019.
  4. Ainslie, J., et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." arXiv:2305.13245, 2023.
  5. Vaswani, A., et al. "Attention Is All You Need." NeurIPS, 2017.
  6. Dao, T., et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS, 2022.
  7. Su, J., et al. "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv:2104.09864, 2021.