Chapter 12

YaRN: Frequency-Domain Interpolation

Long-Context Extension

The Real Problem: NTK Alone Is Not Enough

The previous section ended with a victory and a warning. NTK-aware scaling fixed the most obvious failure of Position Interpolation: it stopped pretending that all RoPE dimensions live on the same timescale. By scaling the RoPE base $b \to b \cdot s^{d/(d-2)}$ instead of compressing every angle uniformly, NTK kept the highest-frequency rotations roughly intact while spreading the slowest ones across the new context window. The result was clean: at $s = 4$, you could push Llama-2 from 4K to 16K tokens with almost no loss of perplexity, no fine-tuning, no retraining.

Then people tried $s = 8$ and $s = 16$, and the cracks reappeared. Around 32K context, NTK-aware Llama-2 started forgetting the start of documents. By 64K, the loss curve looked like a polite version of the original PI catastrophe: not divergent, just quietly broken. Information retrieval at long range worked for the first ~20K tokens, then degraded smoothly into noise. Why?

The diagnosis took a careful look at what NTK actually does to each dim-pair. NTK rescales the base, which means it applies a single global rule (one new base) to the whole spectrum. But the spectrum is highly non-uniform. RoPE assigns dimension pair $i \in [0, d/2)$ the angular frequency $\theta_i = b^{-2i/d}$, so the wavelengths $\lambda_i = 2\pi / \theta_i$ span nearly four orders of magnitude. For Llama-2 with $d = 128$ and $b = 10{,}000$:

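| Dim pair i | θᵢ | λᵢ (positions) | Rotations in 4K context |
|---|---|---|---|
| 0 (fastest) | 1.00 | 6.28 | ~652 |
| 10 | 0.237 | 26.5 | ~155 |
| 20 | 0.0562 | 112 | ~36.7 |
| 30 | 0.0133 | 471 | ~8.7 |
| 40 | 0.00316 | 1,987 | ~2.06 |
| 50 | 7.50·10⁻⁴ | 8,379 | ~0.49 |
| 60 | 1.78·10⁻⁴ | 35,333 | ~0.116 |
| 63 (slowest) | 1.15·10⁻⁴ | 54,410 | ~0.075 |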
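The table can be regenerated from the definitions alone. A minimal sketch, assuming Llama-2's head dimension d = 128, base b = 10,000 and a 4,096-token training window (the function name `rope_spectrum` is ours, chosen for this illustration):

```python
import math

def rope_spectrum(d=128, base=10000.0, L_train=4096):
    """Return (i, theta_i, lambda_i, rotations_i) for every dim pair of one RoPE head."""
    rows = []
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)   # angular frequency of pair i
        lam = 2.0 * math.pi / theta      # wavelength in token positions
        rows.append((i, theta, lam, L_train / lam))
    return rows

for i, theta, lam, rot in rope_spectrum():
    if i % 10 == 0 or i == 63:
        print(f"{i:>2}  theta={theta:.3g}  lambda={lam:,.1f}  rotations={rot:.3g}")
```

The last column is just `L_train / lambda_i`; it is the quantity YaRN later uses to decide how to treat each pair.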

Look at that last column. Dim-pair 0 saw 652 full rotations during pretraining: it has been around the unit circle so many times that the model has memorised every angle. Dim-pair 63 saw less than one tenth of a rotation: the model never observed its sin and cos beyond the first ~27° of arc. These two ends of the spectrum need completely different treatment, and any method that applies a single global rule to both is going to get one or the other wrong.

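The structural insight: Position Interpolation compresses every dim uniformly; it ruins the fast dims (whose angles the model learned precisely) to fix the slow dims (whose angles the model never observed). NTK-aware scaling fixes this for moderate $s$, but at extreme scale factors its single global rule is still a compromise. Changing the base slows dim-pair $i$ by a factor of $s^{2i/(d-2)}$: the fastest pair is untouched and the slowest is slowed by the full $s$, but the fast-to-mid dims are still compressed somewhat (costing local resolution), while the mid-to-low dims, which saw only a handful of rotations in pretraining, are slowed by much less than $s$ and therefore sweep into angles the model has never seen.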
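To make the NTK compromise concrete, here is a small sketch that assumes the base-rescaling rule $b \to b \cdot s^{d/(d-2)}$ quoted earlier, the Llama-2 spectrum (d = 128, b = 10,000, 4K training window) and s = 16. For a few sampled pairs it prints how much NTK slows the pair (PI always slows by exactly s) and how many rotations the pair then sweeps over the 16× longer window. The helper `ntk_vs_pi` is hypothetical, written only for this illustration:

```python
import math

def ntk_vs_pi(i, d=128, base=10000.0, L_train=4096, s=16.0):
    """Per-pair slow-down under NTK-aware base scaling, and rotations over the longer window.

    Position Interpolation slows every pair by exactly s, so each pair sweeps the
    same number of rotations over L_train * s that it swept over L_train.
    """
    theta = base ** (-2.0 * i / d)
    r = L_train * theta / (2.0 * math.pi)   # rotations seen in pretraining
    ntk = s ** (2.0 * i / (d - 2))          # slow-down from b -> b * s**(d / (d - 2))
    r_ext = r * s / ntk                     # rotations over L_train * s under NTK
    return r, ntk, r_ext

for i in (0, 10, 20, 30, 40, 50, 63):
    r, ntk, r_ext = ntk_vs_pi(i)
    print(f"i={i:>2}  seen={r:8.3f}  NTK slow-down={ntk:5.2f}  "
          f"NTK rotations over 16x window={r_ext:8.3f}")
```

Under PI every pair would sweep exactly as many rotations as it saw in pretraining; under NTK the pairs around i = 40, which saw only about two rotations, are asked to sweep almost three times as many.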

YaRN (Yet another RoPE extensioN; Peng et al., 2023) is the paper that made the fix official: stop trying to find a single scaling rule and use three rules, one per band. The per-band ramp itself comes from the community "NTK-by-parts" method, which YaRN adopts and pairs with a small attention-temperature correction. The combination is the method behind DeepSeek-V2's and DeepSeek-V3's 128K context windows, Qwen2's extension beyond its 32K training context, and long-context Llama-2 fine-tunes such as the YaRN authors' own 64K and 128K checkpoints. It is also the rare technique whose algorithmic description fits on half a page.

Intuition: A Clock With Hands of Many Speeds

Imagine RoPE as a fancy clock with $d/2$ hands instead of three. The fastest hand sweeps the dial in 6.28 positions; the slowest hand takes over 50,000 positions (about 54,000 for Llama-2) to complete one orbit. Every token in the sequence advances every hand by its own fraction of a revolution, and the resulting configuration of hands is what the attention mechanism reads to know "how far apart are these two tokens?".

Now you want to extend the clock's range from 4,000 positions to 32,000. What should you do with each hand?

  1. The fast hands (high frequency). These hands completed hundreds of full revolutions during the original 4K-token training, so the model has seen every conceivable configuration of them. Compressing them — making them tick slower so they only complete the same number of revolutions in 32K — would destroy precise short-range information. The model knows what a 7-token gap looks like; do not lie to it. Solution: extrapolate — keep these hands ticking at the original rate. The model will see new configurations beyond 4K, but it has learned the geometry well enough to generalise.
  2. The slow hands (low frequency). These hands completed only a fraction of a revolution during 4K-token training. The model has no idea what their configuration looks like at, say, the 90° mark, because it never observed them past their initial ~30° arc. If you let them keep ticking at the original rate, by token 32,000 they will have swept into regions of angle space that have never appeared in training data: pure extrapolation, which language models are famously bad at. Solution: interpolate. Slow these hands down by a factor of $s$ so they sweep the same small arc across the 32K context that they used to sweep across 4K. The model will still be in familiar angle space.
  3. The medium hands. Some hands completed a few revolutions during training — enough to know the basic shape of sin and cos but not enough to nail every angle. A sharp cutoff between "extrapolate" and "interpolate" would create a discontinuity at the band boundary; the model would suddenly see two adjacent dim-pairs behave according to totally different rules. Solution: smoothly blend — for hands in the transition zone, mix the extrapolation and interpolation strategies in proportion to how many rotations they observed.

That is the entire YaRN algorithm. Three bands, one ramp, one line of math per dim. Plus one extra ingredient, the attention temperature, which we will get to in a moment, because rescaling frequencies changes the distribution of the $Q K^\top$ logits and the softmax needs a small nudge to compensate.

Why was this not obvious before 2023? Because the community had been thinking of RoPE as a unified positional encoding rather than a bank of independent oscillators. Once you squint at it as a spectrum (many simultaneous frequencies, each with its own training history), the "treat each band by its rotation count" idea becomes the natural answer. YaRN's contribution was not a deep theorem; it was reframing.

The Math: Splitting the Spectrum

Start with the standard RoPE frequency schedule. For a head of dimension $d$, dim-pair $i \in [0, d/2)$ and base $b$ (10,000 in Llama-2 and many other models; some newer models use much larger bases, e.g. 500,000 in Llama 3), the angular frequency is $\theta_i = b^{-2i/d}$ and the wavelength in token positions is $\lambda_i = 2\pi / \theta_i$.

Given a pretraining context length $L_{\text{train}}$ and a target scale factor $s = L_{\text{target}} / L_{\text{train}}$, define the per-dim rotation count

$$r_i = \frac{L_{\text{train}}}{\lambda_i},$$

which counts how many full revolutions dim-pair $i$ made during pretraining. This is the scalar that classifies each dim into one of three bands.

The ramp γ(r)

YaRN introduces a linear ramp $\gamma : \mathbb{R}_{\ge 0} \to [0, 1]$ parameterised by two thresholds $\alpha < \beta$:

$$\gamma(r) = \begin{cases} 0 & r \le \alpha \\ (r - \alpha) / (\beta - \alpha) & \alpha < r < \beta \\ 1 & r \ge \beta \end{cases}$$

Read this as: $\gamma = 1$ means "the dim is in the high-frequency band, keep its rotation rate unchanged"; $\gamma = 0$ means "the dim is in the low-frequency band, compress it by a factor of $s$"; intermediate values blend continuously. For the Llama family the YaRN paper recommends $\alpha = 1$ (the rotation count below which a dim never completed a full revolution and is treated as "slow") and $\beta = 32$ (the count above which the model has seen so many revolutions it can safely extrapolate). DeepSeek-V3 uses these values unchanged.
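Applying the thresholds to the Llama-2 spectrum shows how the 64 pairs split. A sketch assuming d = 128, b = 10,000, $L_{\text{train}}$ = 4,096 and the α = 1, β = 32 values; `yarn_bands` is a hypothetical helper name:

```python
import math

def yarn_bands(d=128, base=10000.0, L_train=4096, alpha=1.0, beta=32.0):
    """Group dim-pair indices into YaRN bands by their pretraining rotation count r."""
    bands = {"extrapolate": [], "blend": [], "interpolate": []}
    for i in range(d // 2):
        r = L_train * base ** (-2.0 * i / d) / (2.0 * math.pi)  # r_i = L_train / lambda_i
        if r >= beta:
            bands["extrapolate"].append(i)   # gamma = 1: keep theta_i
        elif r <= alpha:
            bands["interpolate"].append(i)   # gamma = 0: use theta_i / s
        else:
            bands["blend"].append(i)         # 0 < gamma < 1
    return bands

for name, pairs in yarn_bands().items():
    print(f"{name:>11}: pairs {pairs[0]}-{pairs[-1]} ({len(pairs)} pairs)")
```

For these settings it reports pairs 0–20 as extrapolated, 21–45 as blended and 46–63 as interpolated. Note that $s$ does not appear anywhere in the classification.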

The blended frequency

The YaRN-rescaled angular frequency for dim-pair $i$ is

$$\theta_i^{\text{YaRN}} = \gamma(r_i) \cdot \theta_i + (1 - \gamma(r_i)) \cdot \frac{\theta_i}{s}.$$

At the band edges this collapses to the familiar special cases:

  • High-freq dim ($r_i \ge \beta$): $\theta_i^{\text{YaRN}} = \theta_i$, pure extrapolation, identical to the original RoPE.
  • Low-freq dim ($r_i \le \alpha$): $\theta_i^{\text{YaRN}} = \theta_i / s$, pure interpolation, identical to Position Interpolation on that dim alone.
  • Transition dim: a convex combination of the two strategies, weighted by the position of $r_i$ within $[\alpha, \beta]$.

Notice how compact this is. The whole frequency-domain intervention is a single line of math per dim, and it degenerates to either of its predecessors at the boundary. That is exactly the kind of strict generalisation that lets you adopt YaRN as a drop-in replacement: a no-scale call ($s = 1$) recovers vanilla RoPE, since $\theta_i / 1 = \theta_i$ for every dim and the temperature introduced below is switched off by $\ln 1 = 0$.
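A direct transcription of the blend for a single dim pair, with the three limiting cases from the list above exercised numerically. This is a sketch: `yarn_theta` is our own name, and the r values passed in are illustrative:

```python
def yarn_theta(theta, r, s, alpha=1.0, beta=32.0):
    """One dim pair's YaRN frequency: gamma * theta + (1 - gamma) * theta / s."""
    g = min(max((r - alpha) / (beta - alpha), 0.0), 1.0)   # ramp gamma(r)
    return g * theta + (1.0 - g) * (theta / s)

# The three limiting cases from the text:
print(yarn_theta(1.0, r=652.0, s=8.0))     # high-freq pair: unchanged
print(yarn_theta(1.2e-4, r=0.08, s=8.0))   # low-freq pair: divided by s (PI)
print(yarn_theta(0.05, r=10.0, s=1.0))     # s = 1: vanilla RoPE for any r
```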

The Temperature Fix: Why Attention Scores Need a Nudge

If you stop here and run inference, you will find that perplexity at long context is better than PI and better than NTK, but still not as good as the original short-context perplexity. The YaRN authors observed that the remaining gap responds to a temperature on the attention logits: rescaling the pre-softmax scores by a constant lowers perplexity uniformly across samples and token positions in the extended window.

RoPE represents position by rotating Q and K vectors. For a query at position $t_a$ and a key at position $t_b$, the post-rotation dot product $(R_{t_a} q)^\top (R_{t_b} k) = q^\top R_{t_b - t_a}\, k$ depends only on their relative offset. One intuitive account of the gap: when you stretch the wavelengths (by lowering some $\theta_i$), a given offset produces smaller rotation angles, distant tokens look "closer" than they did in pretraining, and the logits become less spread out. The softmax then produces a smoother, less selective attention distribution, exactly the symptom of "the model is forgetting what to focus on". Sharpening the logits with a temperature counteracts this.

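The fix is a single multiplicative scalar applied to the pre-softmax attention scores:

$$A = \frac{1}{t} \cdot \frac{Q K^\top}{\sqrt{d_{\text{head}}}} \quad\text{with}\quad \frac{1}{\sqrt{t}} = 0.1 \cdot \ln(s) + 1.$$

The rule is stated for $1/\sqrt{t}$; the logits themselves are scaled by its square, $1/t$, because in practice $1/\sqrt{t}$ is applied to both $Q$ and $K$.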

This formula was derived empirically: the YaRN authors fit the one-parameter curve above to the temperature that minimised perplexity across a range of scale factors ($s \in [2, 64]$) on LLaMA models, and it carries over well to other model sizes and families. The implementation trick is even cleaner than the math: fold $1/\sqrt{t}$ into the cos/sin tables once at module init. Because those tables rotate both Q and K, every logit picks up the factor twice, which is exactly the $1/t$ above, and the attention layer sees no extra multiply at runtime. The cost of YaRN inference is therefore the same as vanilla RoPE inference.
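A one-pair check of the folding trick in pure Python. It assumes a single 2-D pair with toy values for q, k and θ (all hypothetical), scales the cos/sin values by $1/\sqrt{t}$ at s = 8, and compares the resulting logit with the unscaled one. The same scaled cos/sin rotate both q and k, so the factor enters the logit twice; the last line also confirms the logit still depends only on the offset between positions:

```python
import math

def rotate(x1, x2, angle, scale=1.0):
    """Rotate one (x1, x2) pair by `angle` using cos/sin values pre-scaled by `scale`."""
    c, sn = scale * math.cos(angle), scale * math.sin(angle)
    return x1 * c - x2 * sn, x1 * sn + x2 * c

def pair_logit(q, k, pos_q, pos_k, theta, scale=1.0):
    """Dot product of one rotated query pair with one rotated key pair."""
    q1, q2 = rotate(q[0], q[1], pos_q * theta, scale)
    k1, k2 = rotate(k[0], k[1], pos_k * theta, scale)
    return q1 * k1 + q2 * k2

m = 0.1 * math.log(8.0) + 1.0                  # 1/sqrt(t) at s = 8
q, k, theta = (0.3, -1.2), (0.7, 0.5), 0.01    # arbitrary toy values
plain = pair_logit(q, k, 40, 25, theta)
folded = pair_logit(q, k, 40, 25, theta, scale=m)
print(folded / plain, m * m)                   # the logit is scaled by m**2 = 1/t
print(pair_logit(q, k, 115, 100, theta, scale=m), folded)  # same offset, same logit
```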

At $s = 8$: $1/\sqrt{t} = 0.1 \cdot \ln(8) + 1 \approx 1.208$, so the logits are scaled by $1/t \approx 1.459$. At $s = 16$: $1/\sqrt{t} \approx 1.277$ ($1/t \approx 1.631$). At $s = 32$, a 4K → 128K extension: $1/\sqrt{t} \approx 1.347$ ($1/t \approx 1.813$); DeepSeek-V3, whose configuration uses $s = 40$, gets $1/\sqrt{t} \approx 1.369$. These are small numbers with outsized effects on long-context perplexity, exactly the kind of detail you cannot afford to skip in production.
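The per-vector factor and the resulting logit factor for a few scale factors, as a quick sketch of the arithmetic (`yarn_temperature` is a hypothetical helper; s = 40 is included as an extra data point):

```python
import math

def yarn_temperature(s):
    """Return (1/sqrt(t), 1/t): the factor on each of q and k, and its square on the logits."""
    m = 0.1 * math.log(s) + 1.0
    return m, m * m

for s in (2, 4, 8, 16, 32, 40):
    m, logit = yarn_temperature(s)
    print(f"s={s:>2}  1/sqrt(t)={m:.3f}  logit factor 1/t={logit:.3f}")
```

At s = 1 both factors are exactly 1, which is what makes the no-scale configuration a true no-op.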

Manual Numerical Walkthrough

A complete YaRN step on a d = 8 toy head with s = 4

Take a tiny head with $d = 8$ (so 4 dim-pairs), base $b = 10000$, pretraining context $L_{\text{train}} = 16$ tokens, target $L_{\text{target}} = 64$ (so $s = 4$). Use the paper defaults $\alpha = 1$, $\beta = 32$.

Step 1: compute θᵢ. $\theta_i = 10000^{-2i/8} = 10^{-i}$ (because $10000^{-1/4} = 10^{-1}$).

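| i | θᵢ | λᵢ = 2π/θᵢ | rᵢ = 16/λᵢ | band |
|---|---|---|---|---|
| 0 | 1.0 | 6.283 | 2.546 | blend |
| 1 | 0.1 | 62.83 | 0.2546 | interpolate |
| 2 | 0.01 | 628.3 | 0.02546 | interpolate |
| 3 | 0.001 | 6,283 | 0.002546 | interpolate |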

Step 2: compute γ(rᵢ). Only $r_0 = 2.546$ falls in the transition band $(\alpha, \beta) = (1, 32)$; the other three are all below $\alpha$.

  • $\gamma(2.546) = (2.546 - 1) / (32 - 1) = 1.546 / 31 \approx 0.0499$
  • $\gamma(0.2546) = 0$ (below α)
  • $\gamma(0.02546) = 0$
  • $\gamma(0.002546) = 0$

Step 3: compute θᵢ' via the YaRN blend.

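| i | γᵢ | γ·θᵢ | (1−γ)·θᵢ/s | θᵢ' (YaRN) | θᵢ' (PI baseline) |
|---|---|---|---|---|---|
| 0 | 0.0499 | 0.0499 | (0.9501)·(0.25) = 0.2375 | 0.2874 | 0.2500 |
| 1 | 0.00 | 0 | 0.0250 | 0.0250 | 0.0250 |
| 2 | 0.00 | 0 | 0.00250 | 0.00250 | 0.00250 |
| 3 | 0.00 | 0 | 0.000250 | 0.000250 | 0.000250 |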
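The whole walkthrough fits in a few lines. A sketch that recomputes Steps 1–3 for the toy head (d = 8, b = 10,000, $L_{\text{train}}$ = 16, s = 4, α = 1, β = 32); `toy_yarn` is a name chosen for this example:

```python
import math

def toy_yarn(d=8, base=10000.0, L_train=16, s=4.0, alpha=1.0, beta=32.0):
    """Rows of (i, r_i, gamma_i, YaRN theta_i', PI theta_i') for a toy RoPE head."""
    rows = []
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)                        # Step 1: 10**(-i)
        r = L_train * theta / (2.0 * math.pi)                 # rotations in pretraining
        g = min(max((r - alpha) / (beta - alpha), 0.0), 1.0)  # Step 2: ramp
        rows.append((i, r, g, g * theta + (1.0 - g) * theta / s, theta / s))  # Step 3
    return rows

for i, r, g, yarn, pi in toy_yarn():
    print(f"i={i}  r={r:.4g}  gamma={g:.4f}  yarn={yarn:.4g}  PI={pi:.4g}")
```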

Step 4: notice what happened. For dim-pair $i = 0$ (the fastest, and the only one the model saw complete full rotations, about 2.5 of them) YaRN sets $\theta_0^{\text{YaRN}} \approx 0.287$, while PI compresses it all the way to $0.25$. That extra ~15% in angular speed preserves short-range positional sharpness that PI would have given up. For dim-pairs 1–3 (all in the low-freq band) YaRN and PI agree exactly: divide by $s = 4$.

Step 5: the attention temperature. $1/\sqrt{t} = 0.1 \cdot \ln(4) + 1 = 0.1 \cdot 1.3863 + 1 \approx 1.1386$. Both q and k are scaled by ~1.14, so every pre-softmax logit is multiplied by $1/t \approx 1.1386^2 \approx 1.296$: a mild sharpening of the attention distribution that compensates for the softening caused by the slowed rotations.

Step 6: the cost of all this. Four θ computations, four γ evaluations, four blends, one log. On a real head with $d = 128$ it is 64 of each, a matter of microseconds on a single CPU thread, and it happens once at load time rather than on every forward pass. The expensive part of switching to YaRN is not the math; it is the continued pretraining on long documents that follows, which we will discuss in the massive-scale section below.

Visualizing the Three Bands

The math is simple but the geometry is hard to picture without a live diagram. The widget below plots two things at once:

  • Left panel: a bar per dim-pair, height equal to $\log_{10}(\lambda_i / L_{\text{train}})$: negative for fast dims (many rotations), positive for slow dims (less than one rotation). Bars are coloured by YaRN band: green for extrapolate, amber for blend, blue for interpolate.
  • Right panel: the ramp $\gamma(r)$ with the active $\alpha$ and $\beta$ marked. Every dim-pair is plotted as a dot at its $(r_i, \gamma_i)$ coordinate so you can see which band each dim landed in.

Try moving the $s$ slider from 2 to 32. The band counts in the legend do not change: each dim's band depends only on $r_i = L_{\text{train}} / \lambda_i$ and on $\alpha, \beta$, never on $s$. What $s$ changes is how far the blue (interpolated) and amber (blended) dims are slowed down. The attention-temperature value at the bottom of the right panel updates live to match.

Things to look for: Move $d$ from 16 to 128: the spectrum gets wider and denser, and more bars enter the extrapolate band. Drop $\beta$ to 8: the transition zone narrows, and dims with $8 \le r < 32$ switch to full extrapolation, so fewer dims get partial compression. Push $s$ to 32 (close to DeepSeek-V3's $s = 40$): the slowest dims' θ drops by a factor of 32, meaning a token at position 100,000 looks rotationally similar to a token at position 3,125 in the original model. That is exactly the trade-off YaRN is making: compression of slow rotations, preservation of fast rotations.
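A text-only stand-in for the explorer's legend, assuming the Llama-2 settings (d = 128, b = 10,000, $L_{\text{train}}$ = 4,096). It counts the pairs in each band for several values of s and for β = 8; `yarn_summary` is a hypothetical helper:

```python
import math

def yarn_summary(d=128, base=10000.0, L_train=4096, s=8.0, alpha=1.0, beta=32.0):
    """Count dim pairs per YaRN band and return the rescaled frequencies."""
    counts = {"extrapolate": 0, "blend": 0, "interpolate": 0}
    new_thetas = []
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)
        r = L_train * theta / (2.0 * math.pi)                 # rotations in pretraining
        g = min(max((r - alpha) / (beta - alpha), 0.0), 1.0)  # ramp gamma(r): no s here
        band = "extrapolate" if g == 1.0 else "interpolate" if g == 0.0 else "blend"
        counts[band] += 1
        new_thetas.append(g * theta + (1.0 - g) * theta / s)  # s only scales the result
    return counts, new_thetas

for s in (2.0, 8.0, 32.0):
    counts, thetas = yarn_summary(s=s)
    print(f"s={s:>4}: {counts}  slowest theta={thetas[-1]:.3e}")
print("beta=8:", yarn_summary(beta=8.0)[0])
```

The counts come out identical for every s (21 / 25 / 18 with the defaults), because s never enters the band test; lowering β to 8 moves ten pairs from blend to full extrapolation (31 / 15 / 18).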

Plain Python: YaRN From Scratch

Before we touch PyTorch, here is the algorithm in pure Python. No tensors, no broadcasting tricks — just a loop over dim-pairs and the four equations from the math section. If you understand this snippet, you understand YaRN.

YaRN scaling, plain Python (yarn_plain.py)
YarnParams: the numbers that define a YaRN run

YaRN is parameterised by four core numbers per head: the head dimension d, the RoPE base (10000 for Llama-2; some newer models use much larger bases, e.g. 500,000 in Llama 3), the pretraining context L_train (4096 for Llama-2, and also 4096 for DeepSeek-V2 and DeepSeek-V3 before their YaRN extension), and the desired scale factor s = L_target / L_train. Two more numbers, the thresholds α and β, bound the transition region; α = 1, β = 32 is the YaRN paper's recommendation for the Llama family, and DeepSeek-V3 uses the same values.

Example values:
d = 128 (Llama-2 head dim)
base = 10000.0 (Llama-2 RoPE base)
L_train = 4096 (pretraining context)
s = 8.0 (so target context = 32,768 tokens)
alpha = 1.0 (low-freq cutoff: dims with < 1 rotation during pretraining are compressed)
beta = 32.0 (high-freq cutoff: dims with > 32 rotations are kept as-is)
γ(r): the ramp function

The ramp's shape is fixed; its only free parameters are the thresholds α and β. γ is 0 in the low-frequency region, 1 in the high-frequency region, and linear in between. γ = 1 means 'keep θᵢ unchanged' (extrapolate); γ = 0 means 'shrink θᵢ to θᵢ/s' (interpolate). This ramp is inherited from the NTK-by-parts scheme that YaRN builds on: blending continuously avoids the discontinuity a hard cutoff would create between adjacent dim-pairs.

r = L_train / λᵢ: how many full rotations dim-pair i made during pretraining
return = γ ∈ [0, 1]: interpolation weight (0 = compress, 1 = keep)
yarn_thetas: the core of the algorithm

We loop over every dimension-pair i ∈ [0, d/2). For each, we compute the original RoPE θᵢ, its wavelength λᵢ, the rotation count rᵢ during pretraining, the ramp γ, and finally the YaRN-blended θᵢ'. The whole computation is O(d) and runs once per checkpoint, which is cheaper than a single softmax in a single forward pass. The expensive part of context extension is the continued training that follows, not this calculation.

Compute the original RoPE frequency for dim-pair i

This is the standard RoPE definition: dimension-pairs get geometrically spaced frequencies. Pair 0 is the fastest (θ₀ = 1, period 2π ≈ 6.3 positions); the last pair is the slowest (θ₆₃ ≈ 1.15·10⁻⁴ for d = 128, a wavelength of about 54,000 positions). Geometric spacing means each step in i multiplies the frequency by the same factor, b^(−2/d) ≈ 0.866 for b = 10,000 and d = 128, so the frequency halves roughly every 4.8 pairs.

Loop trace (3 iterations):
i = 0 (fastest)
theta = 10000^(-0/128) = 1.0
lambda = 2π ≈ 6.28 positions
r = 4096 / 6.28 ≈ 651.9 rotations
i = 32 (mid)
theta = 10000^(-64/128) = 0.01
lambda = 2π / 0.01 ≈ 628 positions
r = 4096 / 628 ≈ 6.52 rotations
i = 63 (slowest)
theta = 10000^(-126/128) ≈ 1.15·10⁻⁴
lambda ≈ 54,410 positions
r = 4096 / 54410 ≈ 0.075 rotations: less than 1/10 of a rotation!
Wavelength in token positions

λᵢ = 2π / θᵢ is the position distance over which dim-pair i completes one full sin/cos rotation. This is the most important quantity in the whole section. A small λᵢ (relative to L_train) means the model saw many full rotations during pretraining and 'knows' the curve well; a large λᵢ (relative to L_train) means the model saw only a partial arc and is essentially guessing the rest.

lam = wavelength of dim-pair i, in token positions
Rotation count rᵢ: the YaRN classifier

r = L_train / λ tells you, for dim-pair i, how many full rotations the model observed during pretraining. This is the single number that decides which YaRN band the dim falls into. With L_train = 4096 and the Llama-2 frequency schedule (d = 128, b = 10,000), pairs 0–20 have r ≥ 32 (extrapolate), pairs 46–63 have r < 1 (interpolate), and pairs 21–45 sit in the transition band.

r = 0.075 for i = 63 (slow), 651.9 for i = 0 (fast)
The YaRN blend: extrapolate, interpolate, or mix

This single line is YaRN's frequency rule. γ chooses how much of the original θ to keep vs how much to shrink. When γ = 1 (high-freq), θ' = θ: we extrapolate, because the model knows these rotations well. When γ = 0 (low-freq), θ' = θ/s: we interpolate, because the model only ever saw a short initial arc of these slow rotations. In between, the blend is smooth. Compare this to plain Position Interpolation, which sets θ' = θ/s for ALL i, compressing the high-frequency information the model relies on for short-range attention.

Loop trace (3 iterations):
i = 0 (extrapolate, γ = 1)
γ = 1.0 (r = 651.9 ≥ β = 32)
theta_y = 1.0 * 1.0 + 0.0 * (1.0/8) = 1.0: unchanged
i = 22 (blend, γ ≈ 0.85)
r ≈ 27.5 (in transition band)
γ = (27.5 - 1) / (32 - 1) ≈ 0.85
theta_y = 0.85·θ + 0.15·(θ/8): mostly preserved
i = 63 (interpolate, γ = 0)
γ = 0.0 (r = 0.075 ≤ α = 1)
theta_y = 0.0·θ + 1.0·(θ/8) = θ/8: compressed
The attention temperature

This empirical formula is the second half of YaRN. At s = 8 it gives 1/√t ≈ 1.208; scaling both q and k by that factor multiplies the pre-softmax logits by 1/t ≈ 1.208² ≈ 1.459. The YaRN authors fit the rule 1/√t = 0.1·ln(s) + 1 on LLaMA models and recommend it for the LLaMA and Llama 2 families. Implementation: either fold 1/√t into the RoPE rotation magnitudes (scaling the cos/sin tables, which rotate both q and k), or multiply the logits by 1/t directly. Both are mathematically equivalent; the folded version costs zero extra FLOPs.

s = 8.0
1/√t = 0.1·ln(8) + 1 = 1.2079 (logit factor 1/t ≈ 1.459)
Try it on a Llama-style head

Running this on d = 128, s = 8 prints θ₀ = 1.0 (unchanged: extrapolate band) and a last θ of about 1.44·10⁻⁵ (was 1.15·10⁻⁴, now divided by 8: interpolate band). In between, pairs 0–20 keep their frequency, pairs 21–45 are partially slowed by the ramp, and pairs 46–63 are divided by 8. The printed attention factor matches the 1/√t ≈ 1.2079 computed in the temperature discussion.

thetas[0] = 1.0: fastest dim untouched
thetas[-1] ≈ 1.44·10⁻⁵: slowest dim divided by 8
attention_temperature(8) ≈ 1.2079
```python
"""
A from-scratch YaRN rescaler for RoPE frequencies.

We compute, for every dimension-pair i of a transformer head:

  theta_i  = base ** (-2*i / d)         # original RoPE angular freq
  lambda_i = 2*pi / theta_i              # wavelength, in token positions
  r_i      = L_train / lambda_i          # number of full rotations during pretrain

YaRN classifies i by r_i:

  r_i >= beta         -> 'extrapolate': keep theta_i     (high-frequency dim)
  r_i <= alpha        -> 'interpolate': use theta_i / s  (low-frequency dim)
  alpha < r_i < beta  -> smoothly blend via ramp gamma(r_i) in [0, 1]

The blended frequency is

  theta_i_yarn = gamma(r_i) * theta_i + (1 - gamma(r_i)) * (theta_i / s)

YaRN also applies an attention temperature: q and k are each scaled by

  1 / sqrt(t) = 0.1 * ln(s) + 1

so the pre-softmax logits are scaled by 1/t = (0.1 * ln(s) + 1) ** 2.
"""
from __future__ import annotations
import math
from dataclasses import dataclass

@dataclass
class YarnParams:
    d: int                 # head dimension (must be even)
    base: float = 10000.0  # RoPE base frequency
    L_train: int = 4096    # original context length the model was pretrained on
    s: float = 8.0         # scale factor: target context = L_train * s
    alpha: float = 1.0     # low-frequency threshold (rotations during pretrain)
    beta: float = 32.0     # high-frequency threshold

def gamma(r: float, alpha: float, beta: float) -> float:
    """Linear ramp from 0 (interpolate) at r=alpha to 1 (extrapolate) at r=beta."""
    if r <= alpha:
        return 0.0
    if r >= beta:
        return 1.0
    return (r - alpha) / (beta - alpha)

def yarn_thetas(p: YarnParams) -> list[float]:
    """Return per-dim-pair angular frequencies after YaRN rescaling."""
    assert p.d % 2 == 0, "head dimension must be even"
    thetas = []
    for i in range(p.d // 2):
        theta = p.base ** (-2.0 * i / p.d)
        lam = 2.0 * math.pi / theta
        r = p.L_train / lam
        g = gamma(r, p.alpha, p.beta)
        # Blend extrapolation (theta) with interpolation (theta / s).
        theta_y = g * theta + (1.0 - g) * (theta / p.s)
        thetas.append(theta_y)
    return thetas

def attention_temperature(s: float) -> float:
    """YaRN's 1/sqrt(t) = 0.1*ln(s) + 1. Applied to both q and k (e.g. folded
    into the cos/sin tables), so the logits are scaled by its square, 1/t."""
    return 0.1 * math.log(s) + 1.0

# Example: a Llama-style head with d=128, extending 4K -> 32K context.
p = YarnParams(d=128, base=10000.0, L_train=4096, s=8.0, alpha=1.0, beta=32.0)
thetas = yarn_thetas(p)
print(f"first theta (fastest dim) = {thetas[0]:.4f}")
print(f"last  theta (slowest dim) = {thetas[-1]:.3e}")
print(f"attention 1/sqrt(t)       = {attention_temperature(p.s):.4f}")
```

Run this and you get back a list of 64 new RoPE frequencies, plus a single attention temperature scalar. That is the entire output of the YaRN procedure. Every downstream step — building the cos/sin tables, rotating Q and K, computing attention — is identical to vanilla RoPE.
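Those downstream steps are ordinary RoPE. A minimal pure-Python sketch, assuming the interleaved-pair layout (x[2j], x[2j+1]) that the PyTorch listing below also uses; `build_tables` and `apply_rope` are illustrative names, and the frequency list passed in could be the output of `yarn_thetas` or any other list of θ values:

```python
import math

def build_tables(thetas, T, attn_scale=1.0):
    """cos/sin tables indexed [position][pair], optionally pre-scaled by attn_scale."""
    cos = [[attn_scale * math.cos(t * th) for th in thetas] for t in range(T)]
    sin = [[attn_scale * math.sin(t * th) for th in thetas] for t in range(T)]
    return cos, sin

def apply_rope(x, t, cos, sin):
    """Rotate the interleaved pairs (x[2j], x[2j+1]) of vector x to position t."""
    out = []
    for j in range(len(x) // 2):
        x1, x2 = x[2 * j], x[2 * j + 1]
        c, sn = cos[t][j], sin[t][j]
        out += [x1 * c - x2 * sn, x1 * sn + x2 * c]
    return out

cos, sin = build_tables([1.0, 0.1], T=8)        # two illustrative frequencies
v = apply_rope([1.0, 2.0, 3.0, 4.0], 5, cos, sin)
print(v, sum(a * a for a in v))                 # rotation preserves the norm (30.0)
```

Passing `attn_scale` equal to 1/√t here folds the YaRN temperature into the tables; with the default of 1.0 this is plain RoPE.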

PyTorch: A Drop-In RoPE Replacement

The plain-Python version is for understanding; the PyTorch version below is what you would ship. It is structured as a single nn.Module that can stand in for the existing RotaryEmbedding class in a Llama/Mistral/Qwen/DeepSeek-style codebase, provided that codebase rotates interleaved (even, odd) pairs as this module does (some implementations split each head into two halves instead). No attention kernel changes are required.

YaRN as a drop-in rotary embedding (yarn_rotary.py)
Module signature: what the caller supplies

Four core inputs (head_dim, base, L_train, yarn_scale; only head_dim has no default) and three tunables (α, β, max_seq_len). In a codebase whose RotaryEmbedding uses the same interleaved-pair layout, this module replaces it line for line; you do not need to touch the attention kernel itself. The yarn_scale=1.0 default is a no-op, so you can ship this module as a strict superset of vanilla RoPE.

head_dim = e.g. 128 for Llama-2 7B; 64 for the decoupled RoPE dims of DeepSeek-V3's MLA heads
L_train = the context length of the BASE checkpoint (4096 for Llama-2)
yarn_scale = L_target / L_train (8 to go 4K → 32K)
max_seq_len = if set, precompute cos/sin tables up to this length
Original RoPE frequencies as a tensor

We allocate an FP32 tensor of shape [head_dim/2] holding the geometric frequency schedule. Keeping these in FP32 matters for precision, not range: BF16 carries only 8 significant bits, so slow-dim frequencies around 10⁻⁴ pick up a relative rounding error of up to ~0.4%, positions above 256 are no longer exact, and the angles t·θ at long context drift badly. Note we only need θ at dim-pair granularity; the same θ drives both elements of each pair.

i = tensor([0, 2, 4, ..., head_dim-2]): even indices only
theta = shape [head_dim/2], FP32, range [≈1.15·10⁻⁴, 1.0] for head_dim=128
Compute wavelengths and rotation counts, vectorised

We compute λ and r per dim-pair in two broadcasted ops. r is the quantity the ramp γ acts on. For Llama-2 with L_train = 4096, r ranges from ~652 (pair 0) down to ~0.075 (pair 63), a spread of nearly four orders of magnitude. That spread is exactly why YaRN's per-dim treatment matters: a single global scaling factor cannot do the right thing for both ends of this range.

lam = shape [head_dim/2], wavelengths in position units
r = shape [head_dim/2], rotations during pretraining
Ramp γ as a clamped linear function: one PyTorch call

Tensor.clamp gives us the piecewise-linear ramp without any branching. For dims where r ≤ α the clamp returns 0; where r ≥ β it returns 1; in between it gives the linear value. The whole ramp is a single vectorised op over all dim-pairs rather than a Python loop.

gamma = shape [head_dim/2], values in [0, 1]
The YaRN θ blend: one line that defines the algorithm

Same equation as the plain-Python version, but vectorised: an element-wise blend of θ (extrapolate) and θ/s (interpolate) weighted by γ. The result, inv_freq, is exactly the tensor that the standard RoPE module expects, which is why YaRN slots in so easily: from here on, every downstream op (the outer product with positions, the cos/sin cache, the q/k rotation in attention) is identical to vanilla RoPE.

inv_freq = shape [head_dim/2], the YaRN-rescaled angular frequencies
Attention temperature, computed once at module init

We compute the 1/√t factor exactly once and store it as a Python float. The max(yarn_scale, 1.0) guard keeps attn_scale at exactly 1 whenever yarn_scale ≤ 1 (log(1) = 0), so the no-op configuration never changes the logits. For yarn_scale = 8 we get attn_scale ≈ 1.2079, for yarn_scale = 16 ≈ 1.277. Because the factor is folded into the cos/sin tables that rotate both q and k, the logits are scaled by its square (≈ 1.459 at yarn_scale = 8).

self.attn_scale ≈ 1.2079 for yarn_scale=8
_build_cache: the standard RoPE cos/sin table

For T positions and d/2 frequencies, we build a [T, d/2] table of angles via an outer product, then apply cos and sin. This is the exact computation every RoPE implementation does. The only YaRN-specific touch is folding attn_scale into the table: since the same table rotates q and k, baking 1/√t into cos and sin scales every logit by 1/t without an extra element-wise multiply per attention call and without changing any downstream code.

freqs = shape [T, head_dim/2]: angles for every (position, dim-pair) cell
cos_cached = freqs.cos() * 1.2079: folded YaRN temperature (at yarn_scale = 8)
sin_cached = freqs.sin() * 1.2079: same trick
Forward: the same rotation every RoPE module does

The forward pass is identical to a standard RoPE module: split q and k into adjacent pairs (q1, q2), pull T rows of the cos/sin cache, and apply the 2-D rotation matrix per pair. The whole YaRN treatment is invisible here; it lives entirely in how inv_freq and cos_cached were initialised. That is why YaRN adds no inference cost over vanilla RoPE.

q = shape [batch, heads, T, head_dim], BF16 typically
cos = shape [T, head_dim/2], lookup into the precomputed cache
The 2-D rotation, element-wise

Each pair (q1, q2) is rotated by angle θᵢ·t, where t is the absolute position and θᵢ is the YaRN-rescaled frequency. We stack the rotated outputs into the last dim and flatten back to [..., head_dim]. The operation costs 4 multiplies and 2 adds per pair, negligible next to the attention GEMMs. Because the output keeps the input's shape, YaRN composes with any attention kernel: FlashAttention, xFormers, vLLM's PagedAttention, all just see normal rotated q/k tensors.

q_rot = shape [batch, heads, T, head_dim]: same shape, RoPE-rotated
```python
"""
A PyTorch stand-in for a standard RoPE module, with YaRN scaling baked in.
Written for attention layers that rotate interleaved (even, odd) pairs of
q/k. Pass yarn_scale=8.0 to extend a 4K-context checkpoint to 32K context,
then follow up with a short phase of continued pretraining.
"""
from __future__ import annotations

import math

import torch
from torch import nn, Tensor


class YarnRotaryEmbedding(nn.Module):
    def __init__(
        self,
        head_dim: int,
        base: float = 10000.0,
        L_train: int = 4096,
        yarn_scale: float = 1.0,
        alpha: float = 1.0,
        beta: float = 32.0,
        max_seq_len: int | None = None,
    ):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even"
        self.head_dim = head_dim
        self.yarn_scale = yarn_scale
        # 1. Original RoPE angular frequencies, one per dim-pair (kept in FP32).
        i = torch.arange(0, head_dim, 2, dtype=torch.float32)
        theta = base ** (-i / head_dim)                       # [head_dim/2]

        # 2. Per-pair rotation count during pretraining.
        lam = 2 * math.pi / theta                             # wavelength
        r = L_train / lam                                     # rotations

        # 3. Ramp γ(r): clamp gives the piecewise-linear shape.
        gamma = ((r - alpha) / (beta - alpha)).clamp(0.0, 1.0)

        # 4. YaRN blend: keep θ for fast dims, shrink to θ/s for slow ones.
        inv_freq = gamma * theta + (1.0 - gamma) * (theta / yarn_scale)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # 5. 1/√t: exactly 1 if no scaling, > 1 otherwise. It is folded into
        #    cos AND sin, which rotate both q and k, so logits scale by its square.
        self.attn_scale = 0.1 * math.log(max(yarn_scale, 1.0)) + 1.0

        # 6. Precompute cos/sin for the target context if requested.
        self._cached_len = 0
        if max_seq_len is not None:
            self._build_cache(max_seq_len)

    def _build_cache(self, T: int) -> None:
        t = torch.arange(T, dtype=torch.float32, device=self.inv_freq.device)  # [T]
        freqs = torch.outer(t, self.inv_freq)                 # [T, head_dim/2]
        # Pre-scaled by 1/√t so attention needs no second multiply.
        self.register_buffer("cos_cached", freqs.cos() * self.attn_scale, persistent=False)
        self.register_buffer("sin_cached", freqs.sin() * self.attn_scale, persistent=False)
        self._cached_len = T

    def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
        """q, k: [batch, n_heads, T, head_dim], at positions 0..T-1.
        Returns the same shapes with RoPE applied and the YaRN
        temperature folded into the cos/sin tables.
        """
        T = q.shape[-2]
        if T > self._cached_len:                              # build lazily if needed
            self._build_cache(T)
        cos = self.cos_cached[:T].to(q.dtype)                 # [T, head_dim/2]
        sin = self.sin_cached[:T].to(q.dtype)
        # Split each pair: x = (x1, x2) -> (x1·cos - x2·sin, x1·sin + x2·cos).
        q1, q2 = q[..., 0::2], q[..., 1::2]
        k1, k2 = k[..., 0::2], k[..., 1::2]
        q_rot = torch.stack((q1 * cos - q2 * sin, q1 * sin + q2 * cos), dim=-1)
        k_rot = torch.stack((k1 * cos - k2 * sin, k1 * sin + k2 * cos), dim=-1)
        return q_rot.flatten(-2), k_rot.flatten(-2)
```

The two key implementation choices to notice: (1) we keep $\theta$, the rotation count $r$, and the angle table in FP32, because BF16 keeps only 8 significant bits, and any rounding error in $\theta$ (up to ~0.2% relative) is multiplied by the position index when the angle is formed, which at tens of thousands of positions is enough to scramble it; (2) we fold the YaRN attention temperature $1/\sqrt{t}$ directly into the precomputed cos/sin tables, so q and k each carry the factor and every attention logit is scaled by $1/t$. That means the attention kernel does not need to know YaRN exists: FlashAttention, xFormers, and PagedAttention all just see standard rotated Q/K tensors and run unmodified.
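One way to convince yourself that folding the temperature into the tables is equivalent to scaling the logits is to rotate a single 2-D query/key pair by hand. This is a minimal pure-Python sketch, not the module above: one dim-pair stands in for a full head, `rotate` and `logit` are hypothetical helpers, and the scale $m = 0.1 \ln s + 1$ is the same formula the module stores as `attn_scale`.

```python
import math


def rotate(x1: float, x2: float, angle: float, m: float = 1.0) -> tuple[float, float]:
    """Rotate one RoPE dim-pair by `angle`, with cos/sin pre-scaled by m."""
    c, sn = m * math.cos(angle), m * math.sin(angle)
    return x1 * c - x2 * sn, x1 * sn + x2 * c


def logit(q, k, pos_q: int, pos_k: int, theta: float, m: float = 1.0) -> float:
    """q·k for a single dim-pair after rotating q and k to their positions."""
    qr = rotate(*q, pos_q * theta, m)
    kr = rotate(*k, pos_k * theta, m)
    return qr[0] * kr[0] + qr[1] * kr[1]


s = 32.0                          # YaRN scale factor
m = 0.1 * math.log(s) + 1.0       # sqrt(1/t): same formula as attn_scale
q, k, theta = (0.3, -1.2), (0.8, 0.5), 0.01

plain = logit(q, k, 5000, 4200, theta)
folded = logit(q, k, 5000, 4200, theta, m)
# Both tables carry m, so the logit is multiplied by m**2 = 1/t.
print(folded / plain, m ** 2)
# Shifting both positions by the same offset leaves the logit unchanged.
print(math.isclose(folded, logit(q, k, 5800, 5000, theta, m)))
```

Because the factor rides on the rotation itself, the relative-position property of RoPE is untouched; only the overall logit scale changes.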

At Massive Scale: DeepSeek-V3 from 4K to 128K

Here is how YaRN actually plays out on a 671B-parameter MoE extended to a 128K-token context. DeepSeek-V3's published recipe (technical report, Section 4.3) is a two-phase context extension after the main pretraining run:

| Phase | Context | Scale s | Tokens trained | Notes |
|---|---|---|---|---|
| Pretraining | 4,096 | 1 (no YaRN) | 14.8 T | Standard RoPE |
| Extension Phase 1 | 32,768 | 40 | ~63 B (1,000 steps × 1,920 sequences) | YaRN turned on; α = 1, β = 32 |
| Extension Phase 2 | 131,072 | 40 | ~63 B (1,000 steps × 480 sequences) | Same YaRN params, longer sequences |

Three things to notice about this recipe. First, the roughly 126 B extension tokens are under 1% of the pretraining tokens: YaRN is cheap because the underlying frequency change is small and the model only needs to adapt to a new dot-product distribution, not relearn language. Second, the YaRN hyperparameters do not change between phases at all: $s = 40$, $\alpha = 1$, $\beta = 32$ are set once ($40 \times 4\text{K} \approx 160\text{K}$, a little beyond the 128K target), and only the training sequence length grows. Third, going from 32K to 128K (another 4× stretch) costs the same ~63 B tokens as the first phase, not 4× more: the batch shrinks from 1,920 to 480 sequences, so tokens per step stay constant across the two phases.

DeepSeek-V3 reports about 119K H800 GPU-hours for its whole context-extension stage; on the 2,048-GPU cluster it was trained on, that is roughly 58 hours of wallclock, or about 29 hours per phase. That means a base 4K-context model becomes a production 128K-context model in under three days of training, on top of less than two months of pretraining. YaRN is in the rare category of techniques where the implementation cost and the training cost are both negligible compared to their benefit: most of what you are paying for is the 128K-context data pipeline, not the algorithmic change.

What changes at scale that did not appear in the toy example

  1. FP32 cos/sin tables become a memory line item. At 128K context and $d_{\text{head}} = 128$, each of the cos and sin tables holds $131072 \cdot 64$ FP32 values, i.e. $131072 \cdot 64 \cdot 4$ bytes ≈ 32 MB, so the pair costs about 64 MB. Built naively, once per head per layer on each device, DeepSeek-V3's 128 heads per layer and 61 layers multiply that to roughly 500 GB. The fix: share one cos/sin pair across heads (legal because every head uses the same RoPE frequency schedule) and reuse it across layers. The actual memory cost drops back to about 64 MB, which is trivial.
  2. The KV cache grows linearly with context. Under Multi-Head Latent Attention (MLA), DeepSeek-V3 caches a 512-dim compressed KV latent plus a 64-dim decoupled RoPE key per token per layer, so a single 128K-token sequence across 61 layers holds $131072 \cdot (512 + 64) \cdot 61 \approx 4.6 \times 10^{9}$ cached values, about 9 GB at 2 bytes per value. This dwarfs the cos/sin cost but is independent of YaRN: it is the cost of having long context at all. YaRN lets you use that KV cache by making the model attend correctly at long range; the KV cache itself is what the next chapter is about.
  3. Long-context training data is the real bottleneck. Pretraining corpora are heavy on web pages and dialogues, both of which are short. To fine-tune YaRN to 128K context you need a curated stream of long documents — books, technical manuals, code repos, multi-document QA pairs — which is expensive to collect at scale. DeepSeek-V3's extension phases use a heavy upsample of long-document data; ~70% of extension tokens are from sequences longer than 16K.
  4. Attention quality degrades super-linearly with $s$. At $s = 8$ the model recovers near-pretraining perplexity after ~10 B tokens; at $s = 32$, the same 10 B tokens leave a small but measurable gap. The cure is more tokens, but the law looks like $\text{tokens needed} \propto s^{1.5}$, which means each doubling of context costs ~2.8× more extension training. This is the dominant practical limit on how far you can push a single YaRN run.
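The arithmetic behind items 1 and 4 is easy to re-run for your own configuration. This is a minimal sketch: the function names are hypothetical, it sizes a single FP32 cos (or sin) table plus the naive per-head, per-layer duplication, and it treats the $s^{1.5}$ scaling from item 4 as a rule of thumb rather than a law.

```python
def rope_table_bytes(seq_len: int, head_dim: int, bytes_per_elem: int = 4) -> int:
    """Bytes for ONE cos (or sin) table: one entry per position per dim-pair."""
    return seq_len * (head_dim // 2) * bytes_per_elem


def naive_table_bytes(seq_len: int, head_dim: int, n_heads: int, n_layers: int) -> int:
    """The same table duplicated for every head in every layer (what sharing avoids)."""
    return rope_table_bytes(seq_len, head_dim) * n_heads * n_layers


def extension_tokens(base_tokens: float, base_scale: float, new_scale: float,
                     exponent: float = 1.5) -> float:
    """Rough extrapolation: tokens needed grow like s**exponent (rule of thumb)."""
    return base_tokens * (new_scale / base_scale) ** exponent


per_table = rope_table_bytes(131_072, 128)
naive = naive_table_bytes(131_072, 128, n_heads=128, n_layers=61)
print(f"one table: {per_table / 2**20:.0f} MiB")   # paid once if shared
print(f"naive per-head, per-layer copies: {naive / 2**30:.0f} GiB")
print(f"cost of one context doubling: {extension_tokens(1.0, 8, 16):.2f}x")
print(f"s=8 -> s=32 from a 10B-token base: {extension_tokens(10e9, 8, 32) / 1e9:.0f}B tokens")
```

Double both figures to account for cos and sin together; the naive total is exactly what sharing one pair across heads and layers avoids.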

Engineering Reality: What YaRN Costs You

YaRN is conceptually clean, but every production team that ships it discovers the same handful of gotchas. Here is the short list before you bring up your own long-context extension run.

Gotcha 1: tokenizer-specific rotation counts. The rotation count $r_i = L_{\text{train}} / \lambda_i$ is measured in tokens, not characters or bytes. If your new long-context dataset has a very different token-per-byte ratio (e.g. heavy code vs heavy English prose), the effective $L_{\text{train}}$ in semantic units shifts. Most teams handle this by setting $L_{\text{train}}$ from the pretraining config rather than recomputing, which is fine as long as the new corpus shares a tokenizer with the pretraining corpus.
Gotcha 2: α and β are not as universal as the paper claims. $\alpha = 1, \beta = 32$ work for Llama-2, Mistral, and Qwen because those models share a $d_{\text{head}}$ of 128 and base 10000. DeepSeek-V3 is a different case: under MLA it applies YaRN only to a decoupled 64-dim RoPE key, so just 32 dim-pairs carry position and each ramp band spans fewer of them. Its report still lists $\alpha = 1, \beta = 32$, but treat that as a choice to re-validate, not a default to assume. Always sweep if you change head dimension.
Gotcha 3: the temperature factor must apply during training too. A surprisingly common bug is to add the $1/\sqrt{t}$ factor at inference but forget it during the YaRN extension fine-tune. The model then learns weights that compensate for the wrong logit scale, and you get worse perplexity than with no temperature fix at all. Bake the factor into the cos/sin cache at module init and you cannot mismatch it later.
Gotcha 4: FlashAttention numerics at extreme $s$. At $s = 32$ the slowest dim's rotation rate drops to ~$3.6 \cdot 10^{-6}$ radians per token. If you store the cos/sin cache in BF16, which keeps only 8 significant bits, the cos entries for that dim round to exactly 1.0 for roughly the first 17,000 positions and then move in coarse steps of $2^{-8}$, so the slow dims lose much of the long-range positional resolution they exist to provide. The fix is to keep the cache in FP32 and only downcast at the multiply against Q/K, which costs an extra memory load but rescues numerical stability. DeepSeek-V3's production kernels do exactly this.
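Gotcha 4's precision problem can be reproduced on a CPU. This is a minimal sketch that assumes a Llama-2-style head ($d_{\text{head}} = 128$, base 10000). It emulates BF16 by rounding an FP32 bit pattern to its top 16 bits with round-to-nearest-even (the usual FP32-to-BF16 conversion) and tracks the cos entry of the slowest dim-pair after YaRN's full $1/s$ interpolation at $s = 32$. `to_bf16` is a hypothetical helper, not a library call.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a float to the nearest BF16 value (ties to even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # FP32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                   # round-to-nearest-even bias
    bits &= 0xFFFF0000                                    # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


head_dim, base, s = 128, 10_000.0, 32.0
i = head_dim - 2                                # slowest dim-pair (pair 63)
theta_slow = base ** (-i / head_dim) / s        # fully interpolated: gamma = 0

for pos in (1_000, 8_000, 16_000, 20_000, 64_000):
    c32 = math.cos(pos * theta_slow)
    print(f"pos {pos:>6}: fp32 cos = {c32:.6f}   bf16 cos = {to_bf16(c32):.6f}")

# First position at which the BF16 cos entry stops being exactly 1.0.
first = next(p for p in range(1, 200_000) if to_bf16(math.cos(p * theta_slow)) != 1.0)
print("bf16 cos is pinned to 1.0 until position", first)
```

The FP32 column keeps moving with position while the BF16 column sits at 1.0; that flat region is what keeping the cache in FP32 avoids.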

Why ship YaRN over the alternatives

| Method | Train tokens for 8× ext. | PPL gap vs base | Implementation |
|---|---|---|---|
| Position Interpolation (PI) | ~50 B | +2.1% | 1 line |
| NTK-aware scaling | ~0 (zero-shot ok) | +0.8% at s=4, +6% at s=16 | 1 line |
| ALiBi (alternative encoding) | Full retrain | +0.4% | different attention |
| YaRN | ~10 B | +0.3% at s=8 | 1 module, 50 lines |
| LongRoPE (Phi-3) | ~5 B | +0.2% at s=32 | YaRN + per-dim ramp search |

The table makes the trade-off clear: YaRN closes nearly all of the perplexity gap for a small extension budget, with an implementation footprint of one module, and the only row that beats it, LongRoPE, is essentially YaRN-style per-dimension rescaling with the factors found by search instead of taken from the linear default ramp: same idea, more tuning. Many production long-context LLMs shipped since late 2023, including DeepSeek-V2 (128K), DeepSeek-V3 (128K), Qwen2-72B (32K), Llama-3.1 (128K via a related approach), Yi-200K, and Mistral-Large, sit on top of either YaRN or a close variant of it.
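For reference, here is the fixed linear ramp that LongRoPE's search replaces, written per dim-pair in plain Python. It mirrors steps 1 to 4 of the module earlier in this section (frequency, rotation count, clamped ramp, blended frequency) and counts how many pairs keep their original frequency, are fully interpolated, or are blended. The function name and the Llama-2-style defaults (128-dim head, base 10000, 4K pretraining context, $s = 32$) are illustrative assumptions.

```python
import math


def yarn_inv_freq(head_dim: int = 128, base: float = 10_000.0, scale: float = 32.0,
                  l_train: int = 4096, alpha: float = 1.0, beta: float = 32.0):
    """Return (theta, gamma, new_theta) per dim-pair using YaRN's linear ramp."""
    out = []
    for i in range(0, head_dim, 2):
        theta = base ** (-i / head_dim)                           # original RoPE frequency
        r = l_train / (2 * math.pi / theta)                       # rotations in pretraining
        gamma = min(max((r - alpha) / (beta - alpha), 0.0), 1.0)  # linear ramp, clamped
        new_theta = gamma * theta + (1 - gamma) * theta / scale   # keep fast, shrink slow
        out.append((theta, gamma, new_theta))
    return out


pairs = yarn_inv_freq()
kept = sum(g == 1.0 for _, g, _ in pairs)
interp = sum(g == 0.0 for _, g, _ in pairs)
print(f"kept at original freq: {kept}, fully interpolated: {interp}, "
      f"blended: {len(pairs) - kept - interp}")
```

Passing `head_dim=64` shows the effect Gotcha 2 warns about: the ramp depends only on each pair's rotation count, so the same frequency bands are covered by half as many dim-pairs and every band boundary lands on a different pair index.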

The takeaway: YaRN is the cheap-and-correct answer to long context. It costs about three days of training, a long-document token budget that is a small fraction of pretraining, and a fifty-line module change. In return you get a 32× context window with sub-1% perplexity degradation. The reason so many modern frontier models use it (or a near-variant) is not that there is no alternative (there are several) but that none of them are this cheap, this clean, or this empirically robust across model scales.