Chapter 12

NTK-Aware Scaling

Long-Context Extension

The Real Problem: PI Crushes Local Resolution

Pretraining a 70B model on a 32K context costs far more than pretraining it on a 4K context: every sequence is eight times longer, the attention matrix grows as $O(L^2)$, the activation memory grows linearly, and the rare, extremely long documents that actually teach the model long-range structure are the bottleneck. Nobody wants to redo this. So almost every modern LLM is trained on a short context and extended afterwards: Llama 2 went 4K → 32K, Llama 3 went 8K → 128K, DeepSeek-V3 went 4K → 128K.

The obvious extension method, Position Interpolation (PI) by Chen et al. (2023), is breathtakingly simple. RoPE turns position $m$ into a rotation by angle $m\theta_i$ on each dimension pair $i$. To extend from $L_{\text{train}}$ to $L_{\text{new}} = s \cdot L_{\text{train}}$, PI says: feed $m/s$ in instead of $m$. The model now sees the same range of angles it always saw, just at finer spacing. Reasonable, right?

Wrong, and the reason is the only thing that matters in this section. RoPE encodes position across a spectrum of frequencies. The first dimension pair rotates by a full radian (about a sixth of a revolution) between adjacent tokens ($\theta_0 = 1$). The last dimension pair barely moves over the entire trained context ($\theta_{d/2-1} \approx 10^{-4}$). These two ends of the spectrum do different jobs. High-frequency pairs distinguish neighbours: they let attention say "this token is the next token, not the one five tokens away." Low-frequency pairs encode long-range position: they let attention say "this token came from paragraph 3, not paragraph 17."

What PI does wrong

PI divides every frequency by $s$. The slow, long-range pairs are the ones that actually needed compressing, and they have angular room to spare. The fast, local pairs get crushed for no reason: a rotation that used to take one token now takes $s$ tokens. Adjacent-token resolution collapses, and quality suffers on tasks that depend on knowing exactly which token is which: code, math, fact retrieval at long range.
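A minimal pure-Python sketch of the two claims above, assuming $d = 64$, base 10000 and $s = 8$ as illustrative values (the helper names `rope_thetas` and `angle` are hypothetical): feeding $m/s$ into RoPE gives the same angles as dividing every $\theta_i$ by $s$, and the fastest pair's per-token step shrinks from 1 rad to $1/s$ rad.

```python
BASE, D, S = 10000.0, 64, 8.0   # illustrative values, not tied to a specific model

def rope_thetas(d, base):
    """theta_i = base^(-2i/d) for each dimension pair i = 0 .. d/2 - 1."""
    return [base ** (-2 * i / d) for i in range(d // 2)]

def angle(m, theta):
    """RoPE rotation angle for position m on a pair with frequency theta."""
    return m * theta

thetas = rope_thetas(D, BASE)

# PI, view 1: keep the frequencies, feed the compressed position m / s.
# PI, view 2: keep the position, divide every frequency by s.
m = 12345
view1 = [angle(m / S, t) for t in thetas]
view2 = [angle(m, t / S) for t in thetas]
max_gap = max(abs(a - b) for a, b in zip(view1, view2))

# Adjacent-token step on the fastest pair (i = 0): theta_0 before, theta_0 / s after.
step_orig = angle(m + 1, thetas[0]) - angle(m, thetas[0])
step_pi = angle((m + 1) / S, thetas[0]) - angle(m / S, thetas[0])

print(f"max difference between the two PI views: {max_gap:.2e} rad")
print(f"pair 0 step per token: original {step_orig:.3f} rad, PI {step_pi:.3f} rad")
```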

NTK-Aware Scaling, proposed in mid-2023 by the pseudonymous Reddit user bloc97 and rapidly adopted across the LLaMA ecosystem, fixes this with a one-line change. Same goal as PI. Same zero training cost. But it concentrates the interpolation on the dimensions that need it.

Intuition: Position Is a Frequency Spectrum

Imagine a piano with $d/2$ keys, one per dimension pair. Striking key $i$ at position $m$ rings a single tone at frequency $\theta_i$; the chord of all keys at position $m$ is the position embedding. The low-index keys are the tweeters: they vibrate fast and distinguish tokens that are just one or two apart. The high-index keys are the woofers: they vibrate slowly and distinguish one chapter from another.

Now suppose you record the piano for exactly $L_{\text{train}}$ seconds (the trained context). You want to play it back for $L_{\text{new}} = s \cdot L_{\text{train}}$ seconds without re-recording. You have two options:

  1. Slow the whole tape by $s$ (PI). The bass survives, but every tweeter note is now an indistinct hum: you can no longer hear adjacent-token detail. The high frequencies have been driven below the model's ability to discriminate them.
  2. Keep the tweeters at their original speed; slow only the woofers (NTK-Aware). The low frequencies were the only ones that ever needed to be stretched, because they were the only ones whose period was longer than $L_{\text{train}}$ in the first place. The tweeters were already cycling through full revolutions every few tokens, so there is nothing to stretch.
The deep insight is that only the lowest frequencies are actually under-trained at long range. The highest frequencies already cycled through about $L_{\text{train}}/2\pi$ revolutions during training: they saw every phase, so extrapolating them past $L_{\text{train}}$ shows them nothing new modulo $2\pi$. The lowest frequencies cycled through less than one revolution across the entire trained context, so any position past $L_{\text{train}}$ is genuinely unseen angular territory for them. So compress the woofers and leave the tweeters alone. That single sentence is NTK-Aware Scaling.
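To see which pairs are woofers in this sense, compare each pair's period, $2\pi/\theta_i$ tokens per revolution, with $L_{\text{train}}$. This sketch assumes a $d = 128$ head, base 10000 and a 4K training context; `unrevolved_pairs` is a hypothetical helper.

```python
import math

def unrevolved_pairs(d, base, l_train):
    """Indices of pairs whose period 2*pi/theta_i exceeds the trained context.

    These pairs never completed a full revolution during training, so positions
    beyond l_train put them at angles they have never seen.
    """
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        period = 2 * math.pi / theta          # tokens per full revolution
        if period > l_train:
            out.append(i)
    return out

# Assumed configuration: d = 128 head, base 10000, 4K training context.
slow = unrevolved_pairs(128, 10000.0, 4096)
revs_fast = 4096 * 1.0 / (2 * math.pi)        # revolutions of pair 0 (theta_0 = 1)

print(f"{len(slow)} of 64 pairs never finished a revolution: i = {slow[0]}..{slow[-1]}")
print(f"pair 0 completed about {revs_fast:.1f} revolutions")
```

Under these assumptions only the top 18 pairs (indices 46 to 63) never complete a revolution, while pair 0 completes roughly 650.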

The name "NTK-Aware" comes from Neural Tangent Kernel theory: a network struggles to learn high-frequency structure unless the input encoding already supplies it, so the high-frequency components of the position encoding are the ones that carry fine-grained positional information through the network and must not be distorted. The piano analogy is the engineer's version of the same claim.

The Mathematical Idea

RoPE assigns dimension pair $i \in [0, d/2)$ the angular frequency

$$\theta_i = \text{base}^{-2i/d}, \qquad \text{base} = 10000.$$

At position $m$, the rotation angle is $m\theta_i$. The highest frequency is $\theta_0 = 1$; the lowest is $\theta_{d/2-1} = \text{base}^{-(d-2)/d}$, which is tiny ($\approx 1.33 \cdot 10^{-4}$ for $d = 64$, $\approx 1.15 \cdot 10^{-4}$ for $d = 128$).

Position Interpolation (PI)

Replace $m$ with $m/s$. Equivalently, the new effective frequencies are

$$\theta_i^{\text{PI}} = \theta_i / s \quad \text{for all } i.$$

NTK-Aware

Keep the position $m$ integer-valued. Replace the base with a new constant $\text{base}'$ chosen so that the lowest frequency gets exactly the same compression as PI:

$$(\text{base}')^{-2(d/2-1)/d} = \text{base}^{-2(d/2-1)/d} \cdot \frac{1}{s}$$

Solve for $\text{base}'$:

$$\text{base}' = \text{base} \cdot s^{d/(d-2)}$$

The new frequencies are then

$$\theta_i^{\text{NTK}} = (\text{base}')^{-2i/d}.$$

Look at the two endpoints. For $i = 0$: $\theta_0^{\text{NTK}} = 1$ (unchanged from original). For $i = d/2 - 1$: $\theta_{d/2-1}^{\text{NTK}} = \theta_{d/2-1}/s$ (matches PI). In between, the ratio $\theta_i^{\text{NTK}}/\theta_i$ smoothly interpolates from 1 (no compression at the high-frequency end) down to $1/s$ (full PI compression at the low-frequency end). That in-between bowed curve is what saves local resolution.
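Dividing the two frequency formulas gives the in-between curve in closed form, which makes the bow explicit: the per-pair compression is a geometric ramp in the pair index, pinned to 1 at $i = 0$ and to $1/s$ at $i = d/2 - 1$.

$$\frac{\theta_i^{\text{NTK}}}{\theta_i} = \left(\frac{\text{base}'}{\text{base}}\right)^{-2i/d} = \left(s^{d/(d-2)}\right)^{-2i/d} = s^{-2i/(d-2)}$$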

Why this exponent? The denominator $d-2$ appears because RoPE's lowest pair is at index $d/2-1$, so the exponent on $\text{base}$ for that pair is $-2(d/2-1)/d = -(d-2)/d$. Inverting that exponent to solve for $\text{base}'$ gives $s^{d/(d-2)}$. As $d \to \infty$ the exponent $d/(d-2)$ approaches 1, so the factor approaches $s$; for $d = 64$ it is $s^{1.032}$, barely larger than $s$.
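The derivation is easy to check numerically. This sketch (hypothetical helper names; base 10000 and $s = 8$ assumed) computes $\text{base}'$ for a few head sizes, confirms that the lowest-frequency pair is compressed by exactly $1/s$, and prints how close $d/(d-2)$ is to 1 for realistic heads.

```python
BASE, S = 10000.0, 8.0   # assumed example values

def ntk_base(base, s, d):
    """NTK-Aware base: base' = base * s^(d / (d - 2))."""
    return base * s ** (d / (d - 2))

def theta(i, d, base):
    """RoPE frequency of pair i: base^(-2i/d)."""
    return base ** (-2 * i / d)

results = {}
for d in (8, 64, 128):
    new_base = ntk_base(BASE, S, d)
    lowest = d // 2 - 1                                  # index of the slowest pair
    ratio = theta(lowest, d, new_base) / theta(lowest, d, BASE)
    results[d] = (new_base, ratio)
    print(f"d={d:3d}  exponent={d / (d - 2):.3f}  base'={new_base:,.0f}  "
          f"lowest-pair ratio={ratio:.6f}  (1/s={1 / S:.6f})")
```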

Manual Numerical Walkthrough

One worked example with $d = 8$, $s = 4$, $L_{\text{train}} = 1024$, $L_{\text{new}} = 4096$. Small enough to do every step by hand; structurally identical to the real $d = 128$ case.

Full numerical comparison of Original vs PI vs NTK-Aware

Step 1: original frequencies. Four pairs ($i = 0, 1, 2, 3$): $\theta_i = 10000^{-2i/8} = 10000^{-i/4}$.

$\theta_0 = 1.0$, $\theta_1 = 10000^{-0.25} = 0.1$, $\theta_2 = 10000^{-0.5} = 0.01$, $\theta_3 = 10000^{-0.75} = 0.001$.

Step 2: angles seen during training. Max position $m = 1024$: $m\theta_0 = 1024$ rad (≈ 163 revolutions), $m\theta_1 = 102.4$ rad (≈ 16 revs), $m\theta_2 = 10.24$ rad (≈ 1.6 revs), $m\theta_3 = 1.024$ rad (≈ 16% of one rev). Only the slowest pair fails to finish even one full revolution; that is the one in trouble at extension time.

Step 3: PI frequencies. Divide every $\theta$ by $s = 4$: $\theta_0^{\text{PI}} = 0.25$, $\theta_1^{\text{PI}} = 0.025$, $\theta_2^{\text{PI}} = 0.0025$, $\theta_3^{\text{PI}} = 0.00025$.

Step 4: NTK-Aware base. $\text{base}' = 10000 \cdot 4^{8/6} = 10000 \cdot 4^{1.333\ldots}$. Compute $4^{1.333} = e^{1.333 \ln 4} = e^{1.333 \cdot 1.386} = e^{1.848} \approx 6.35$. So $\text{base}' \approx 63{,}496$.

Step 5: NTK frequencies. $\theta_i^{\text{NTK}} = 63496^{-i/4}$:

$\theta_0^{\text{NTK}} = 1.0$ (unchanged, preserved), $\theta_1^{\text{NTK}} = 63496^{-0.25} \approx 0.0630$, $\theta_2^{\text{NTK}} = 63496^{-0.5} \approx 0.00397$, $\theta_3^{\text{NTK}} = 63496^{-0.75} = 0.00025 = \theta_3/4$ (matches PI exactly, as designed).

Step 6: angle at the new boundary $m = 4096$.

| pair i | original m·θ | PI m·θ | NTK m·θ | what model saw at L_train |
|---|---|---|---|---|
| 0 (tweeter) | 4096 rad | 1024 rad | 4096 rad | 1024 rad |
| 1 | 409.6 rad | 102.4 rad | 258.0 rad | 102.4 rad |
| 2 | 40.96 rad | 10.24 rad | 16.26 rad | 10.24 rad |
| 3 (woofer) | 4.096 rad | 1.024 rad | 1.024 rad | 1.024 rad |

Reading the table.

  • Pair 3 (woofer). Original would push it to 4.096 rad — past the 1.024 rad the model has ever been calibrated on. PI brings it back to exactly 1.024 — safe. NTK brings it to exactly 1.024 — same as PI, by construction. Both are correct here.
  • Pair 0 (tweeter). Original would push it to 4096 rad. PI brings it down to 1024 rad. NTK leaves it at 4096 rad. Is NTK wrong? No: at $\theta_0 = 1.0$, the model has already cycled through $1024/2\pi \approx 163$ full revolutions in training. Past 1024 there are no "new angles"; every angle modulo $2\pi$ has been seen many times. So extrapolation in the tweeter is free.
  • Pairs 1 and 2 (mids). NTK's 258.0 and 16.26 rad fall between PI's 102.4/10.24 and the original's 409.6/40.96. They are out of distribution compared to L_train, but only mildly so: both pairs completed at least one full revolution during training, so every individual phase has been seen. The bow between the endpoints is exactly the soft-extrapolation pattern that preserves local resolution.

The key number. Look at adjacent-token distinguishability: how much does pair 0's angle change between position $m$ and $m+1$? Original/NTK: 1.0 rad, a full radian per token, easily distinguishable. PI: 0.25 rad, four times smaller. Adjacent tokens that the model used to distinguish cleanly are now four times closer in angle space. With $s = 8$ or $16$ the step shrinks to 0.125 or 0.0625 rad, and the loss of resolution becomes severe.
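The walkthrough is easy to reproduce. This pure-Python sketch uses the walkthrough's own values ($d = 8$, $s = 4$, $L_{\text{train}} = 1024$, $L_{\text{new}} = 4096$) to rebuild the Step 6 angles for all three methods plus the adjacent-token step on pair 0; the helper name `thetas` is illustrative.

```python
BASE, D, S = 10000.0, 8, 4.0      # the walkthrough's values
L_TRAIN, L_NEW = 1024, 4096

def thetas(base, d):
    """RoPE frequencies theta_i = base^(-2i/d), i = 0 .. d/2 - 1."""
    return [base ** (-2 * i / d) for i in range(d // 2)]

orig = thetas(BASE, D)
pi = [t / S for t in orig]                        # PI: divide every frequency by s
ntk = thetas(BASE * S ** (D / (D - 2)), D)        # NTK-Aware: base' = base * s^(d/(d-2))

# (pair, original at L_new, PI at L_new, NTK at L_new, seen at L_train)
rows = [(i, L_NEW * orig[i], L_NEW * pi[i], L_NEW * ntk[i], L_TRAIN * orig[i])
        for i in range(D // 2)]

print(f"{'i':>2} {'orig':>9} {'PI':>9} {'NTK':>9} {'seen':>9}")
for i, a_orig, a_pi, a_ntk, a_seen in rows:
    print(f"{i:>2} {a_orig:9.3f} {a_pi:9.3f} {a_ntk:9.3f} {a_seen:9.3f}")

# Adjacent-token step on pair 0 (angle change from m to m + 1) is just theta_0.
print(f"pair-0 step per token: original/NTK {ntk[0]:.3f} rad, PI {pi[0]:.3f} rad")
```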

Visualizing the Frequency Stretch

The frequency visualizer for this section plots, for a chosen $d$ and $s$, the dimension-pair index $i$ on the x-axis, from $0$ (tweeters, left) to $d/2-1$ (woofers, right). The y-axis is the rotation angle each pair sees at the new boundary position $L_{\text{new}}$, plotted on a log scale. The dashed amber line is what the model saw at training; if a curve sits above it, that pair is being asked to extrapolate.


Three things are visible at every (d, s):

  • The grey original curve sits above the amber training line everywhere — that's the problem in one picture.
  • The red PI curve sits exactly on the amber line for every pair. Safe, but it pays for that safety at the high-frequency end, which is where local information lives.
  • The cyan NTK-Aware curve touches the amber line at the woofer end and rises smoothly toward the original at the tweeter end. The bowed shape is the entire point.
Crank $s$ up to 16 or 32 and notice how the NTK curve still keeps its highest-frequency pair exactly where the original put it. That is why NTK-Aware holds up at extension ratios where PI's loss of local resolution becomes severe. The same bow, though, leaves the mid-band pairs partly extrapolating, which is why bare NTK-Aware is still not enough at the most extreme ratios; for that we need YaRN, in the next section.
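Without the interactive plot, the same three observations can be checked numerically. This sketch assumes $d = 128$, $s = 16$ and $L_{\text{train}} = 4096$ as example settings; `boundary_curves` is a hypothetical helper, not part of the visualizer.

```python
import math

def boundary_curves(d, s, l_train, base=10000.0):
    """Angle each pair sees at L_new = s * l_train under the three schemes,
    plus the angle it saw at l_train during training (the amber line)."""
    l_new = s * l_train
    new_base = base * s ** (d / (d - 2))
    curves = {"train": [], "orig": [], "pi": [], "ntk": []}
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        curves["train"].append(l_train * theta)
        curves["orig"].append(l_new * theta)
        curves["pi"].append(l_new * theta / s)
        curves["ntk"].append(l_new * new_base ** (-2 * i / d))
    return curves

c = boundary_curves(d=128, s=16, l_train=4096)

# log10 of each angle, as on the plot: train, PI, NTK, original.
for i in (0, 16, 32, 48, 63):
    print(i, [round(math.log10(c[k][i]), 2) for k in ("train", "pi", "ntk", "orig")])
```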

Plain Python: Three Ways to Stretch RoPE

No tensors yet. Just three pure-Python functions that compute the per-pair frequencies for the original RoPE, for Position Interpolation, and for NTK-Aware Scaling. The three functions fit in under twenty lines; the rest of the script prints the comparison.

The three methods, side by side
🐍ntk_aware_freqs.py
Three constants set the entire problem

BASE = 10000 is the constant Vaswani picked for sinusoidal PE and RoPE inherited. D is one head's dimension — typically 64 or 128 in modern LLMs. L_TRAIN is the maximum sequence length the model has actually seen rotation angles for; L_NEW is the larger context we want. Everything in this section is about the ratio s = L_NEW / L_TRAIN: how do we make positions in [L_TRAIN, L_NEW] safe for a model that has never seen them?

EXECUTION STATE
BASE = 10000 (RoPE convention)
D = 64 (head dim, even)
L_TRAIN = 4096
L_NEW = 32768
s = 32768 / 4096 = 8.0
Original RoPE frequencies

RoPE splits the head dimension into d/2 pairs. Pair i rotates at angular frequency θ_i = base^(-2i/d). With base = 10000 and d = 64, pair i = 0 has θ_0 = 1.0 (it advances by 1 radian per position step, so it is fast) and pair i = 31 has θ_31 = 10000^(-62/64) ≈ 1.33·10⁻⁴ (it barely moves over thousands of positions, so it is slow). High-i pairs encode long-range distance; low-i pairs encode adjacent-token detail.

EXECUTION STATE
theta_0 = 10000^0 = 1.0
theta_1 = 10000^(-2/64) = 0.749
theta_31 = 10000^(-62/64) ≈ 1.33e-4
Position Interpolation (the naive fix)

PI compresses positions by s, so position m becomes m/s before going into RoPE. Equivalently, divide every θ_i by s. This guarantees the boundary position L_NEW lands at the same rotation angle the model saw at L_TRAIN, but it does so by squashing every frequency equally. The slow, long-range pairs can absorb this (they have angular room to spare). The fast, short-range pairs are now eight times slower, so adjacent tokens that the model used to see one radian apart now sit only 1/8 radian apart. Local resolution dies.

EXECUTION STATE
PI theta_0 = 1.0 / 8 = 0.125
PI theta_31 = 1.33e-4 / 8 ≈ 1.67e-5
NTK-Aware Scaling: stretch the base, not the positions

Same goal as PI (the lowest-frequency pair must land inside the trained range at L_NEW) but a different lever. Instead of dividing every θ, we replace the base. The new base is chosen so that the LOWEST-frequency pair (i = d/2 - 1) gets compressed by exactly s, same as PI, while every higher-frequency pair gets compressed less and less. The highest-frequency pair (i = 0) is left exactly untouched, since new_base^0 = 1, and its neighbours barely move: with d = 64 and s = 8, pair 1 is compressed by only about 6.5%. That is the whole trick: preserve fast local rotation, only interpolate the slow long-range rotation.

EXECUTION STATE
new_base = 10000 · 8^(64/62) ≈ 85550
NTK theta_0 = 85550^0 = 1.0 (unchanged!)
NTK theta_31 = 85550^(-62/64) ≈ 1.67e-5 (= PI)
What the model 'saw' at L_train vs what each method shows at L_new

At training time, the highest-freq pair rotates by L_TRAIN · θ_0 = 4096 radians at the last position. The lowest-freq pair rotates by only 4096 · 1.33e-4 ≈ 0.546 radians, less than one revolution across the entire trained context. Extension must keep every pair that never completed a revolution inside its L_TRAIN-rotation envelope at L_NEW; otherwise the model sees angles it has never been calibrated on.

EXECUTION STATE
θ_0 at L_train = 4096 rad ≈ 651.9 revolutions
θ_31 at L_train ≈ 0.546 rad
Reading the boundary table

PI brings every pair down to its L_TRAIN value (good). NTK leaves the highest pair at its unextended value (32768 · 1.0 = 32768 rad, far past L_TRAIN's 4096), but that pair was already so far past one revolution at L_train that 'staying calibrated' is meaningless: the angle is taken mod 2π and the model saw every phase. The lowest pair lands at the SAME value as PI (≈ 0.546 rad); that is the design constraint. The bowed shape in between is what saves local resolution.

EXECUTION STATE
PI angle, lowest pair = 32768 · 1.67e-5 ≈ 0.546 rad
NTK angle, lowest pair = 32768 · 1.67e-5 ≈ 0.546 rad (matches!)
import math

BASE = 10000          # original RoPE base frequency
D    = 64             # head dimension (must be even)
L_TRAIN = 4096        # context length the model was trained on
L_NEW   = 32768       # target context length after extension
s = L_NEW / L_TRAIN   # extension factor (8x here)

# 1) Original RoPE frequencies for one head, no extension.
def rope_freqs(d, base):
    # pair index i = 0..d/2-1; theta_i = base^(-2i/d)
    return [base ** (-2 * i / d) for i in range(d // 2)]

# 2) Position Interpolation (PI). Compress positions: m' = m / s.
#    Equivalently, divide every frequency by s.
def pi_freqs(d, base, s):
    return [theta / s for theta in rope_freqs(d, base)]

# 3) NTK-Aware Scaling. Stretch the base instead of the positions.
#    Solve: (new_base)^(-2 (d/2 - 1) / d) = base^(-2 (d/2 - 1) / d) / s
#    => new_base = base * s^(d / (d - 2))
def ntk_freqs(d, base, s):
    new_base = base * (s ** (d / (d - 2)))
    return rope_freqs(d, new_base)

orig = rope_freqs(D, BASE)
pi   = pi_freqs(D, BASE, s)
ntk  = ntk_freqs(D, BASE, s)

print(f"highest-freq pair (i=0):")
print(f"  orig theta = {orig[0]:.6f}, PI = {pi[0]:.6f}, NTK = {ntk[0]:.6f}")
print(f"lowest-freq pair  (i=d/2-1 = {D//2-1}):")
print(f"  orig theta = {orig[-1]:.3e}, PI = {pi[-1]:.3e}, NTK = {ntk[-1]:.3e}")

# What rotation angle does each pair see at the new boundary position?
def max_angle(thetas, m):
    return [m * t for t in thetas]

orig_at_train = max_angle(orig, L_TRAIN)
pi_at_new     = max_angle(pi,   L_NEW)
ntk_at_new    = max_angle(ntk,  L_NEW)

print()
print("Boundary-position rotation angles:")
print(f"  orig at L_train, highest pair:   {orig_at_train[0]:8.1f} rad")
print(f"  PI   at L_new,   highest pair:   {pi_at_new[0]:8.1f} rad")
print(f"  NTK  at L_new,   highest pair:   {ntk_at_new[0]:8.1f} rad")
print(f"  orig at L_train, lowest pair:    {orig_at_train[-1]:.4f} rad")
print(f"  PI   at L_new,   lowest pair:    {pi_at_new[-1]:.4f} rad")
print(f"  NTK  at L_new,   lowest pair:    {ntk_at_new[-1]:.4f} rad")

Running this reproduces the pattern of the walkthrough table at d = 64, s = 8: the highest-frequency pair is left alone by NTK and crushed by 8× by PI, and both methods land the lowest-frequency pair on the same value (≈ 0.546 rad).

PyTorch: Drop-In NTK-Aware RoPE

Production RoPE precomputes two tables, a cosine table and a sine table, both of shape $(L_{\max}, d/2)$, and reads from them during attention. Extension means rebuilding those tables. PI changes the positions; NTK-Aware changes the base constant. Everything else (the rotation kernel, the attention math, the KV cache) stays identical. This is why both methods can be applied without training: zero changes to weights, zero changes to the rest of the forward pass.

NTK-Aware RoPE as a one-line patch to the cos/sin table builder
🐍ntk_aware_rope.py
Precompute the cos/sin tables once per (d, max_pos, base)

A typical transformer that uses RoPE stores two tensors of shape (max_pos, d/2): the cosine and sine of m · θ_i for every position m and every dimension pair i. These tables are NOT learned; they are pure functions of (d, max_pos, base). Building them at startup is cheap: one outer product plus an elementwise cos and sin. The entire NTK-Aware vs PI vs original distinction lives inside these few lines.

EXECUTION STATE
pair_idx = tensor([0, 1, ..., 31]), shape (32,)
inv_freq = θ_i = 1 / base^(2i/d), shape (32,)
positions = tensor([0, 1, ..., max_pos-1])
angles = shape (max_pos, 32), one angle per (m, i)
rotate_half: the trick that makes RoPE a 2D rotation per pair

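RoPE acts on (q_2i, q_2i+1) like a 2D rotation by angle m · θ_i. The standard rewrite is q·cos(θ) + R(q)·sin(θ), where R swaps the two coordinates of each pair and negates one. This view lets the whole batched op use only elementwise multiplies, no matmul. The cos/sin tables are repeat_interleaved so each pair (q_2i, q_2i+1) sees the SAME (cos m·θ_i, sin m·θ_i). This code uses the interleaved pair layout; Hugging Face's Llama implementation instead pairs dimension j with j + d/2, which is why its rotate_half splits the head into two halves.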

Position Interpolation in code

PI does NOT change inv_freq. It changes the positions tensor, replacing torch.arange(max_pos) with torch.arange(max_pos) / s. So when we precompute angles = m/s · θ_i, row m of the PI table equals row m/s of the original table whenever s divides m, and the largest position L_NEW maps back onto L_TRAIN. The model sees no new angles; it just sees them spaced s× more finely. That density is the whole problem: adjacent tokens m and m+1 are now only 1/s of a step apart in angle, but the model has never been trained to distinguish positions that close.

EXECUTION STATE
positions / s = [0, 1/8, 2/8, ..., 32767/8]
row 32760 of PI cos table = row 4095 of original (32760 / 8 = 4095, by design)
NTK-Aware in code: one line of math, that's it

Replace base with base · s^(d/(d-2)). For d=64, s=8 → new_base ≈ 85550. Hand it to the same build_rope_cos_sin function and you are done. Positions stay integer-valued (no fractional indices), so the highest-frequency pair keeps exactly the θ_0 = 1.0 it was trained with and its neighbours change only slightly; the compression grows toward the SLOW pairs, which get the full 1/s. This is also why NTK-Aware is described as a 'training-free' extension: no fine-tune is required to test it, and apart from rebuilding the table once it costs ZERO new compute at inference.

EXECUTION STATE
new_base = 10000 · 8^(64/62) ≈ 85550
shape of cos table = (32768, 32), same as PI
Drop-in inference: nothing else changes

Notice the apply_rope call is byte-identical between PI, NTK-Aware, and original. The only difference is which cos/sin tensors we feed it. This is why NTK-Aware spread through community inference stacks so quickly after bloc97's Reddit post: it is a one-line change to the table builder. Production stacks expose the same idea as a configuration knob rather than a code change.

EXECUTION STATE
q (input) = (2, 8, 32768, 64): batch, heads, seq, dim
q_rope (output) = (2, 8, 32768, 64), same shape
import torch

def build_rope_cos_sin(d, max_pos, base):
    """Precompute cos/sin tables for RoPE (the same frequency schedule Llama/DeepSeek use)."""
    half = d // 2
    pair_idx  = torch.arange(half, dtype=torch.float32)   # [0, 1, ..., half-1]
    inv_freq  = 1.0 / (base ** (2 * pair_idx / d))        # theta_i
    positions = torch.arange(max_pos, dtype=torch.float32)
    # Outer product: angles[m, i] = m * theta_i
    angles = positions[:, None] * inv_freq[None, :]       # (max_pos, half)
    return angles.cos(), angles.sin()                     # both (max_pos, half)

def rotate_half(x):
    # Interleaved layout: pair k is (x[2k], x[2k+1]). Rotate each pair (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rope(q, cos, sin):
    # q: (..., seq, d). cos/sin: (seq, d/2); repeat each entry twice so both
    # members of a pair share the same angle, then broadcast over batch/heads.
    cos_full = cos.repeat_interleave(2, dim=-1)            # (seq, d)
    sin_full = sin.repeat_interleave(2, dim=-1)            # (seq, d)
    return q * cos_full + rotate_half(q) * sin_full

# ---- The two extension methods ----
D, L_TRAIN, L_NEW = 64, 4096, 32768
s = L_NEW / L_TRAIN

# (a) Position Interpolation: same base, but feed positions m / s into the table.
def pi_cos_sin(d, max_pos, base, s):
    half = d // 2
    inv_freq  = 1.0 / (base ** (2 * torch.arange(half, dtype=torch.float32) / d))
    positions = torch.arange(max_pos, dtype=torch.float32) / s   # KEY: divide positions
    angles    = positions[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()

# (b) NTK-Aware: new base; positions stay integer-valued.
def ntk_cos_sin(d, max_pos, base, s):
    new_base = base * (s ** (d / (d - 2)))                 # KEY: stretch base
    return build_rope_cos_sin(d, max_pos, new_base)

# Plug into a forward pass exactly like before:
q = torch.randn(2, 8, L_NEW, D)                            # (batch, heads, seq, d)
cos_ntk, sin_ntk = ntk_cos_sin(D, L_NEW, base=10000, s=s)
q_rope = apply_rope(q, cos_ntk, sin_ntk)

print("cos/sin table:", cos_ntk.shape)                     # torch.Size([32768, 32])
print("q after RoPE:", q_rope.shape)                       # torch.Size([2, 8, 32768, 64])
print("new base used:", 10000 * (s ** (D / (D - 2))))
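One property worth checking after any base change is that RoPE still encodes only relative position: the rotated dot product between a query at $m$ and a key at $n$ should depend on $m - n$ alone. This dependency-free sketch writes each pair rotation as a complex multiply (equivalent to `q * cos + rotate_half(q) * sin` for the interleaved layout above), uses the NTK-Aware base for $d = 64$, $s = 8$, and compares the same offset at two very different absolute positions. The test vectors and helper names are arbitrary choices for illustration.

```python
import cmath

def rope_pair(x, y, m, theta):
    """Rotate one (x, y) pair by angle m * theta, written as a complex multiply."""
    z = complex(x, y) * cmath.exp(1j * m * theta)
    return z.real, z.imag

def rope_dot(q, k, m, n, thetas):
    """Dot product of RoPE-rotated q (at position m) and k (at position n).

    q and k are flat lists in the interleaved layout: pair i is (v[2i], v[2i+1]).
    """
    total = 0.0
    for i, theta in enumerate(thetas):
        qx, qy = rope_pair(q[2 * i], q[2 * i + 1], m, theta)
        kx, ky = rope_pair(k[2 * i], k[2 * i + 1], n, theta)
        total += qx * kx + qy * ky
    return total

d, s = 64, 8.0
ntk_base = 10000.0 * s ** (d / (d - 2))
thetas = [ntk_base ** (-2 * i / d) for i in range(d // 2)]

q = [((7 * j) % 11 - 5) / 5 for j in range(d)]   # arbitrary deterministic test vectors
k = [((3 * j) % 13 - 6) / 6 for j in range(d)]

a = rope_dot(q, k, m=30000, n=29990, thetas=thetas)
b = rope_dot(q, k, m=10, n=0, thetas=thetas)
print(f"offset 10 at positions 30000/29990: {a:.6f}")
print(f"offset 10 at positions 10/0:        {b:.6f}")
```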

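In Hugging Face Transformers, the static substitution from this section amounts to setting the model config's rope_theta to base · s^(d/(d-2)). The built-in option rope_scaling = {"rope_type": "dynamic", "factor": s} (older releases spell the key "type") implements the Dynamic NTK variant described under Engineering Reality below, which recomputes the base from the current sequence length. vLLM and SGLang read the same rope_scaling entry from the model's Hugging Face config, and llama.cpp exposes the base directly through its --rope-freq-base option. Same math, same shapes, a few minutes to enable on any RoPE-based LLM.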
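As a sketch of the configuration side, assuming a Llama-style config dictionary with a `rope_theta` field and the Hugging Face `rope_scaling` convention (key names differ between library versions, so treat them as assumptions to check against your release), the two options differ only in which field they touch. The helper functions are hypothetical.

```python
def static_ntk_overrides(base, s, head_dim):
    """Static NTK-Aware: only the RoPE base changes, base' = base * s^(d/(d-2))."""
    return {"rope_theta": base * s ** (head_dim / (head_dim - 2))}

def dynamic_ntk_overrides(s):
    """Dynamic NTK: ask the library to recompute the base from the sequence length."""
    return {"rope_scaling": {"rope_type": "dynamic", "factor": s}}

static = static_ntk_overrides(10000.0, 8.0, 64)
dyn = dynamic_ntk_overrides(8.0)
print(static)   # rope_theta is roughly 85550 for d = 64, s = 8
print(dyn)
```

The static form bakes one extension ratio into the config; the dynamic form leaves short prompts untouched and only stretches once the input outgrows the trained context.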

At Massive Scale: Why NTK Alone Is Not Enough

NTK-Aware was a breakthrough because it let us push Llama-1 7B to 8K and 16K context with almost no perplexity regression and zero training cost. But at extreme extension ratios ($s \geq 16$) and on modern $d = 128$ heads, two problems appear that bare NTK-Aware cannot fix.

Problem 1: The middle band still gets distorted

The bowed curve in the visualizer is smooth, but it does not match where extrapolation is actually safe. Frequencies whose period is just barely longer than $L_{\text{train}}$ need PI-style compression (they have never seen a full revolution). Frequencies whose period is much shorter than $L_{\text{train}}$ need no compression at all (they cycled through every phase many times). NTK-Aware interpolates between these two regimes continuously, so the mid-band pairs at the crossover get the wrong treatment: partly extrapolated when they should be interpolated, or compressed when they could be left alone.

Problem 2: Attention logit scale changes

Stretching the base also shifts the distribution of the query-key logits, and at large $s$ the softmax, now spread over many more keys, tends to become more uniform (over-smoothed) than what the model was trained on. This shows up as degraded retrieval quality at very long context even when perplexity looks healthy.

The standard remedy: YaRN

YaRN, covered in the next section, addresses both:

  1. It thresholds the frequencies into three bands — fully extrapolate the high-freq band (NTK-style), fully interpolate the low-freq band (PI-style), and linearly ramp between them in the middle. This is a piecewise version of NTK-Aware that does not over-stretch the mids.
  2. It scales the attention logits by a temperature factor $1/t$ to compensate for the softmax shift, with $\sqrt{1/t} = 0.1 \ln s + 1$ (in practice, both queries and keys are scaled by $0.1 \ln s + 1$).
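For a feel of how large the temperature correction is, this small sketch evaluates $\sqrt{1/t} = 0.1 \ln s + 1$ and the resulting logit multiplier $1/t$ for a few extension ratios (the ratios are illustrative; `yarn_logit_scale` is a hypothetical helper).

```python
import math

def yarn_logit_scale(s):
    """Return (sqrt(1/t), 1/t) for YaRN's attention temperature, sqrt(1/t) = 0.1 ln s + 1."""
    root = 0.1 * math.log(s) + 1.0
    return root, root * root

for s in (4, 8, 16, 32, 40):
    root, mult = yarn_logit_scale(s)
    print(f"s={s:3d}: scale q and k by {root:.4f} -> logits x {mult:.4f}")
```

The correction grows only logarithmically: even at $s = 32$ each of $q$ and $k$ is scaled by about 1.35.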

DeepSeek-V3's 128K context is built with YaRN on top of an original 4K-trained MLA RoPE; the nominal ratio is 128K/4K = 32, and the released configuration sets the YaRN scale factor to $s = 40$. The middle-band correction matters enormously at such ratios: bare NTK-Aware there would likely lose retrieval performance on the "needle in a haystack" tests that 128K context exists to pass.

| Method | Train cost | High-freq pairs | Low-freq pairs | Typical max s | Used by |
|---|---|---|---|---|---|
| PI (Chen 2023) | 0 or short FT | compressed by s | compressed by s | ≈ 4× | Llama 1 long-ctx |
| NTK-Aware (bloc97 2023) | 0 (training-free) | unchanged | compressed by s | ≈ 8× | early CodeLlama, llama.cpp |
| YaRN (Peng 2023) | short fine-tune | unchanged | compressed by s | ≈ 32× | DeepSeek-V3, many open LLMs |

Engineering Reality and Gotchas

  1. Pre-compute once, share across the batch. The cos/sin tables depend only on (d, max_pos, base). Build them once at model load, store as buffers, and broadcast against any (batch, heads) shape. Re-building them per forward pass is a common performance bug.
  2. Dynamic NTK at inference. A useful variant, "Dynamic NTK", adjusts $s$ per forward pass based on the current sequence length: $s_t = \max(1, L_t / L_{\text{train}})$. For short prompts ($L_t \leq L_{\text{train}}$) you do nothing; for long prompts you stretch only as much as needed. The trade-off is that once the sequence passes $L_{\text{train}}$ the base keeps changing as it grows, so the cos/sin table must be rebuilt as the length increases, and in a streaming setting keys already in the KV cache were rotated with an older base.
  3. Combines with quantisation cleanly. NTK-Aware touches only the cos/sin tables, not the weights, so it applies to an FP8- or INT8-quantised model without re-quantising anything; keep the tables themselves in the usual higher precision. NTK-Aware extension on a quantised model is essentially free.
  4. The base inflation depends on d. For $d = 16$ the exponent $d/(d-2) = 1.14$ is meaningfully larger than 1, so the new base inflates more aggressively; for $d = 128$ it is 1.016, almost exactly $\text{base} \cdot s$. The shape of the compression curve does not change with $d$: pair $i$ is compressed by $s^{-2i/(d-2)}$, which runs from 1 at $i = 0$ to $1/s$ at $i = d/2-1$ for every head size. So always compute the new base from the model's actual head dimension rather than reusing a value derived for a different $d$.
  5. Always test on retrieval, not just perplexity. A well-known failure mode is that bare NTK-Aware at $s = 16$ can improve perplexity on long-form text (because the low frequencies are now correctly spaced) while degrading needle-in-a-haystack accuracy (because the mid-band pairs are subtly miscalibrated). Run BABILong, RULER, or needle-in-a-haystack tests before declaring victory.
  6. Short fine-tune helps a lot, even though it's "training-free". The original NTK-Aware post billed it as zero-training, and it does work zero-shot. But the standard production recipe today is: apply NTK-Aware or YaRN, then fine-tune for $\sim 1$B tokens at the new context length. This recovers the small perplexity gap that any frequency-domain rescaling introduces. DeepSeek-V3 explicitly does this in the YaRN stage.
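Items 2 and 4 can both be made concrete in a few lines. This sketch implements the Dynamic NTK rule exactly as written above ($s_t = \max(1, L_t/L_{\text{train}})$ followed by the usual base substitution); library implementations may use a slightly different formula, so treat it as an illustration rather than any framework's code. It then checks that the per-pair compression $s^{-2i/(d-2)}$ traces the same curve for $d = 16$ and $d = 128$ when pairs are matched by their relative position in the spectrum. Helper names are hypothetical.

```python
import math

def dynamic_ntk_base(base, d, l_train, l_current):
    """Dynamic NTK as stated above: s_t = max(1, L_t / L_train), then
    base_t = base * s_t^(d / (d - 2)). Must be recomputed whenever L_t grows."""
    s_t = max(1.0, l_current / l_train)
    return base * s_t ** (d / (d - 2))

def compression(i, d, s):
    """Per-pair NTK compression theta_i^NTK / theta_i = s^(-2i / (d - 2))."""
    return s ** (-2 * i / (d - 2))

BASE, L_TRAIN = 10000.0, 4096          # assumed example values
for l_t in (1024, 4096, 8192, 32768):
    print(f"L_t={l_t:6d}  base_t={dynamic_ntk_base(BASE, 128, L_TRAIN, l_t):,.1f}")

# Gotcha 4: the curve's shape does not depend on d. Pair i of a d=16 head sits at
# the same relative spot in the spectrum as pair 9*i of a d=128 head (7 vs 63 steps).
matched = all(
    math.isclose(compression(i, 16, 8.0), compression(9 * i, 128, 8.0))
    for i in range(8)
)
print("same compression curve for d=16 and d=128:", matched)
```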
NTK-Aware Scaling is, on the surface, a one-line code change: replace the RoPE base. Under the surface, it is the first technique that treated position as a frequency spectrum and asked which frequencies actually need adjustment. Every long-context method that followed — YaRN, LongRoPE, the dynamic variants — is a refinement of the same idea. If you understand which dimension pairs are tweeters and which are woofers, you understand modern long-context extension.