Beltagy, Peters, & Cohan, "Longformer: The Long-Document Transformer", arXiv:2004.05150, 2020
Learning Objectives
After completing this chapter, you will be able to:
Explain why full self-attention has an O(N²) bottleneck and why this limits practical context length.
Describe the sliding window mask mathematically and implement it from scratch in NumPy and PyTorch.
Walk through a complete worked example showing how window size W controls which tokens a query can attend to.
Analyse how stacking L layers with window W gives an effective receptive field of L×W tokens.
Connect sliding window attention to production systems like Longformer, Mistral-7B, and BigBird.
The Real Problem
The Quadratic Wall
Standard self-attention (Chapter 1) computes a score between every pair of tokens. For a sequence of N tokens, that is N² dot products, N² softmax entries, and an N×N weight matrix stored in memory. The scaling is harsh:
| Sequence length N | Score matrix entries | Memory (FP16) |
|---|---|---|
| 512 (BERT) | 262,144 | 0.5 MB |
| 4,096 (GPT-3) | 16,777,216 | 32 MB |
| 16,384 (Longformer) | 268,435,456 | 512 MB |
| 131,072 (Mistral) | 17,179,869,184 | 32 GB |
At 16,384 tokens, a single attention head requires 512 MB just for the score matrix. With 12 heads and 12 layers, holding every score matrix in memory at once (as training does for backpropagation) takes 72 GB, more than any consumer GPU offers. This is the quadratic wall: doubling context length quadruples the cost.
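The figures in the table follow from two multiplications. A minimal sketch, assuming 2 bytes per FP16 entry and binary units (1 MB = 1024² bytes); the helper name `score_matrix_cost` is illustrative and not part of the chapter's code:

```python
def score_matrix_cost(n_tokens, bytes_per_entry=2):
    """Return (entries, bytes) for one N x N attention score matrix."""
    entries = n_tokens * n_tokens
    return entries, entries * bytes_per_entry

MB = 1024 ** 2
for n in (512, 4096, 16384, 131072):
    entries, nbytes = score_matrix_cost(n)
    print(f"N={n:>7,}  entries={entries:>14,}  memory={nbytes / MB:>10,.1f} MB")

# 12 heads x 12 layers at N = 16,384, assuming every score matrix is kept at once
total_gb = 12 * 12 * score_matrix_cost(16384)[1] / 1024 ** 3
print(f"144 matrices at N=16,384: {total_gb:.0f} GB")
```

Doubling `n_tokens` multiplies both returned values by four, which is the quadratic wall in one line of arithmetic.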
The Locality Principle
But here is the critical insight: most dependencies in natural language are local. A noun is closest to its adjective. A verb is closest to its subject. Commas, articles, and prepositions bind to their immediate neighbours. Empirical studies of attention patterns in BERT and GPT models show that the majority of attention weight concentrates within a window of 100–500 tokens, even when the full context is available.
The Key Question: If most attention weight is local anyway, why compute all N² scores? What if we restrict each token to attend only to its W nearest neighbours on each side, dropping the complexity from O(N²) to O(N×W)?
This is exactly what sliding window attention does. For fixed W, the cost grows linearly in N, not quadratically — enabling context lengths of 16K, 32K, or even 128K tokens.
The Story Behind Sliding Window Attention
Longformer: From Research to Practice
The idea of restricting attention to a local window was not invented in a vacuum. Local convolutions in CNNs had proven that local connectivity is powerful enough for images. The question was whether the same principle could work for sequences.
In 2020, Iz Beltagy, Matthew Peters, and Arman Cohan at the Allen Institute for AI published Longformer. They were working on tasks that required processing entire documents — legal briefs, scientific papers, Wikipedia articles — documents that easily exceed 4,096 tokens. Standard BERT-style models simply could not process them.
Their solution had two key ingredients:
Sliding window attention for the majority of tokens — each token attends to its W nearest neighbours on each side, where W=256 or W=512 in practice.
Global attention on a few special tokens (like the [CLS] token or question tokens in QA tasks) that attend to the entire sequence. This hybrid approach retains the ability to capture long-range dependencies where they matter most.
Longformer could process sequences of 4,096 tokens and more, with attention memory that grows linearly rather than quadratically in sequence length. It reported state-of-the-art results on long-document tasks such as TriviaQA and WikiHop, and improvements over RoBERTa on IMDB review classification.
Historical note: The sliding window concept also appears independently in Sparse Transformers (Child et al., 2019) and later in Mistral-7B (Jiang et al., 2023), which used W=4096 with a rolling KV-cache to handle 32K context lengths. The principle is the same; the engineering differs.
The Mathematical Definition
Symbol-by-Symbol Breakdown
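Sliding window attention is standard scaled dot-product attention with one modification: a distance-based mask that blocks positions outside the window. The score function becomes:

$$\mathrm{score}[i,j] = \begin{cases} \dfrac{Q_i \cdot K_j^{\top}}{\sqrt{d_k}} & \text{if } |i-j| \le W \\ -\infty & \text{otherwise} \end{cases}$$

Here $Q_i$ is the query vector of token $i$, $K_j$ is the key vector of token $j$, $d_k$ is the key dimension, and $W$ is the window size.

After masking, softmax is applied row-wise as usual: $A_{ij} = \mathrm{softmax}_j(\mathrm{score}[i,:])$. Because $\exp(-\infty) = 0$, blocked positions receive exactly zero attention weight. The output is then $O_i = \sum_j A_{ij} V_j$, where the sum effectively runs over at most $2W+1$ visible positions (fewer for tokens near either end of the sequence).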
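Since blocked positions contribute nothing, the output for one query can be computed from its visible slice alone, without ever forming the N×N matrix. The sketch below checks this against the masked definition on random inputs. The function names, the random test data, and the sizes N=8, W=2 are assumptions chosen for illustration:

```python
import numpy as np

def windowed_attention_row(Q, K, V, i, W):
    """Attention output for query i using only keys with |i - j| <= W."""
    N, d_k = K.shape
    lo, hi = max(0, i - W), min(N, i + W + 1)    # visible slice [lo, hi)
    scores = K[lo:hi] @ Q[i] / np.sqrt(d_k)      # at most 2W+1 dot products
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V[lo:hi]

def masked_attention(Q, K, V, W):
    """Reference: full N x N scores with -inf outside the band."""
    N, d_k = K.shape
    scores = Q @ K.T / np.sqrt(d_k)
    idx = np.arange(N)
    scores[np.abs(idx[:, None] - idx[None, :]) > W] = -np.inf
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(8, 4)) for _ in range(3))
rows = np.stack([windowed_attention_row(Q, K, V, i, W=2) for i in range(8)])
print("slice-based and mask-based outputs agree:",
      np.allclose(rows, masked_attention(Q, K, V, W=2)))
```

The slice-based version is where the O(N×W) cost comes from; the mask-based version is easier to read but still does O(N²) work.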
The Window Mask Matrix
The mask is a band matrix centred on the diagonal with bandwidth 2W+1: entry (i, j) is visible when ∣i−j∣≤W and blocked otherwise.
The band is symmetric about the diagonal. Unlike the causal mask (Chapter 3), which is triangular, the sliding window mask lets token i see token j if and only if j can see i. This bidirectional locality is natural for encoder tasks like BERT-style understanding.
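To see the band concretely, the following sketch prints the mask for a small case and checks its symmetry. It assumes N=6 and W=1 purely for display, and it marks visible positions as True, which is the opposite polarity to the `build_window_mask` method later in the chapter (True means blocked there):

```python
import numpy as np

def band_mask(N, W):
    """True where token i may attend to token j, i.e. |i - j| <= W."""
    idx = np.arange(N)
    return np.abs(idx[:, None] - idx[None, :]) <= W

visible = band_mask(N=6, W=1)
for row in visible:
    print(" ".join("1" if v else "." for v in row))

print("symmetric:", np.array_equal(visible, visible.T))
print("visible positions per row:", visible.sum(axis=1))
```

Interior rows have 2W+1 visible positions; the first and last W rows have fewer because the window runs off the end of the sequence.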
Step-by-Step Calculation
We use the same shared example as every chapter: 5 tokens ("The cat sat on mat"), d_k=4, and W=1.
Step 1: Raw Dot Products
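Compute raw[i,j] = Q_i ⋅ K_j⊤ for all i, j. This is identical to standard attention (rows are queries, columns are keys):

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.00 | 2.00 | 1.00 | 1.00 | 1.50 |
| cat | 3.00 | 0.00 | 2.00 | 1.00 | 0.50 |
| sat | 1.00 | 2.00 | 2.00 | 1.00 | 1.50 |
| on | 1.00 | 1.00 | 0.00 | 2.00 | 1.00 |
| mat | 1.00 | 1.00 | 1.00 | 1.00 | 1.50 |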
Step 2: Scaling
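Divide by √d_k = √4 = 2.0:

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.000 | 1.000 | 0.500 | 0.500 | 0.750 |
| cat | 1.500 | 0.000 | 1.000 | 0.500 | 0.250 |
| sat | 0.500 | 1.000 | 1.000 | 0.500 | 0.750 |
| on | 0.500 | 0.500 | 0.000 | 1.000 | 0.500 |
| mat | 0.500 | 0.500 | 0.500 | 0.500 | 0.750 |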
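Steps 1 and 2 are a single matrix product followed by a scalar division. A short sketch using the chapter's shared Q and K matrices (the values are copied from the implementation section of this chapter):

```python
import numpy as np

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],  # The
    [0.0, 2.0, 0.0, 1.0],  # cat
    [1.0, 1.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.0, 1.0],  # mat
])
K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])

raw = Q @ K.T                        # Step 1: raw[i, j] = Q_i . K_j
scaled = raw / np.sqrt(Q.shape[1])   # Step 2: divide by sqrt(d_k) = 2.0

print(raw)
print(scaled)
```

Nothing here depends on W yet; the window only enters in Step 3.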
Step 3: Apply Sliding Window Mask
For W=1: any position where ∣i−j∣>1 is set to −∞:
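| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.000 | 1.000 | −∞ | −∞ | −∞ |
| cat | 1.500 | 0.000 | 1.000 | −∞ | −∞ |
| sat | −∞ | 1.000 | 1.000 | 0.500 | −∞ |
| on | −∞ | −∞ | 0.000 | 1.000 | 0.500 |
| mat | −∞ | −∞ | −∞ | 0.500 | 0.750 |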
The −∞ entries form the blocked region outside the band. Only the diagonal band of width 2W+1=3 retains real scores.
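The masking step can be reproduced directly from the scaled matrix of Step 2. A minimal sketch, using `np.where` rather than in-place assignment so that the scaled scores stay untouched (that choice is an illustration, not the chapter's implementation):

```python
import numpy as np

scaled = np.array([
    [0.00, 1.00, 0.50, 0.50, 0.75],  # The
    [1.50, 0.00, 1.00, 0.50, 0.25],  # cat
    [0.50, 1.00, 1.00, 0.50, 0.75],  # sat
    [0.50, 0.50, 0.00, 1.00, 0.50],  # on
    [0.50, 0.50, 0.50, 0.50, 0.75],  # mat
])
W = 1

idx = np.arange(scaled.shape[0])
blocked = np.abs(idx[:, None] - idx[None, :]) > W   # True outside the band
masked = np.where(blocked, -np.inf, scaled)

print(masked)
print("scores that survive the mask:", int(np.isfinite(masked).sum()))
```

For N=5 and W=1 only 13 of the 25 scores remain finite: three per interior row and two for each end row.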
Step 4: Softmax Over Visible Tokens
Softmax is computed per row, but exp(−∞)=0 so blocked positions contribute nothing. Let us trace the complete pipeline for "sat" (position 2, W=1):
Visible positions: ∣2−j∣≤1 gives j∈{1,2,3} = ("cat", "sat", "on"). "The" and "mat" are each 2 positions away, so both are blocked.
Scaled scores: [1.000, 1.000, 0.500]. "cat" and "sat" have the highest relevance.
Softmax: [0.3837, 0.3837, 0.2327]. "cat" and "sat" split evenly, "on" gets less.
Output: [0, 0.3837, 0.3837, 0.2327], a blend of the three visible value vectors.
What "sat" loses: In full attention, "sat" could see "The" (weight 0.1519) and "mat" (weight 0.1951). With W=1, those connections are severed. But the weights for the visible tokens are re-normalised — "cat" and "sat" each get 0.3837 instead of 0.2505. The model compensates by concentrating attention on the most locally relevant tokens.
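The re-normalisation for "sat" can be checked in a few lines. The sketch below takes the scaled row for "sat" from Step 2 and compares softmax over all five positions with softmax over the three visible ones; the helper `softmax` is defined locally for the example:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

sat_scores = np.array([0.5, 1.0, 1.0, 0.5, 0.75])   # scaled row for "sat"

full = softmax(sat_scores)                  # every token visible
windowed = np.zeros_like(sat_scores)
windowed[1:4] = softmax(sat_scores[1:4])    # only "cat", "sat", "on" (W=1)

print("full    :", np.round(full, 4))
print("windowed:", np.round(windowed, 4))

# The windowed weights are the full weights of the visible tokens, rescaled to sum to 1
print(np.allclose(windowed[1:4], full[1:4] / full[1:4].sum()))
```

The last line makes the point of the paragraph above explicit: masking does not change the relative ranking of the visible tokens, it only redistributes the weight that the blocked tokens would have received.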
Complete Attention Weights and Output
Sliding Window Attention Weights (W=1, 5×5)
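| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.2689 | 0.7311 | 0.0000 | 0.0000 | 0.0000 |
| cat | 0.5465 | 0.1220 | 0.3315 | 0.0000 | 0.0000 |
| sat | 0.0000 | 0.3837 | 0.3837 | 0.2327 | 0.0000 |
| on | 0.0000 | 0.0000 | 0.1863 | 0.5065 | 0.3072 |
| mat | 0.0000 | 0.0000 | 0.0000 | 0.4378 | 0.5622 |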
Output Matrix (5×4)
| | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.2689 | 0.7311 | 0.0000 | 0.0000 |
| cat | 0.5465 | 0.1220 | 0.3315 | 0.0000 |
| sat | 0.0000 | 0.3837 | 0.3837 | 0.2327 |
| on | 0.1536 | 0.1536 | 0.3399 | 0.6601 |
| mat | 0.2811 | 0.2811 | 0.2811 | 0.7189 |
Standard vs Sliding Window Comparison
Compare with full (standard) attention from Chapter 1:
| Token | Standard Output | Window (W=1) Output | Max Difference |
|---|---|---|---|
| The | [0.2254, 0.4135, 0.2964, 0.2964] | [0.2689, 0.7311, 0.0000, 0.0000] | 0.3176 |
| cat | [0.4602, 0.1475, 0.3018, 0.2058] | [0.5465, 0.1220, 0.3315, 0.0000] | 0.2058 |
| sat | [0.2495, 0.3481, 0.3481, 0.2495] | [0.0000, 0.3837, 0.3837, 0.2327] | 0.2495 |
| on | [0.2854, 0.2854, 0.2106, 0.4089] | [0.1536, 0.1536, 0.3399, 0.6601] | 0.2512 |
| mat | [0.3108, 0.3108, 0.3108, 0.3108] | [0.2811, 0.2811, 0.2811, 0.7189] | 0.4081 |
The differences are non-trivial because our toy sequence is very short (N=5). In real models with hundreds of tokens, the vast majority of attention weight is already local, so the sliding window output closely matches full attention for most positions.
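The comparison can be regenerated with one function that takes an optional window. This is a sketch rather than the chapter's class: the function name `attention` and the convention that `W=None` means full attention are assumptions made for the example. The printed maxima should match the table's last column up to rounding in the fourth decimal place.

```python
import numpy as np

def attention(Q, K, V, W=None):
    """Scaled dot-product attention; W=None means no window (full attention)."""
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    if W is not None:
        idx = np.arange(N)
        scores = np.where(np.abs(idx[:, None] - idx[None, :]) > W, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

full = attention(Q, K, V)
local = attention(Q, K, V, W=1)
max_diff = np.abs(full - local).max(axis=1)

for token, diff in zip(["The", "cat", "sat", "on", "mat"], max_diff):
    print(f"{token:>3}: max |standard - window| = {diff:.4f}")
```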
The Receptive Field Insight
A single sliding window layer with W=1 limits each token to 3 positions: itself and one neighbour on each side. This seems very restrictive. But transformers have multiple layers, and this changes everything.
After L layers, each token's output incorporates information from tokens up to L×W positions away. The mechanism works like a relay: in layer 1, "sat" aggregates information from "cat", "sat", and "on". In layer 2, "sat"'s output now contains information that "cat" gathered from "The" in layer 1 — so "sat" effectively "sees" "The" through "cat".
effective receptive field = L×W positions on each side
For a model like Mistral-7B with L=32 layers and W=4096:
receptive field = 32×4096 = 131,072 tokens
Because Mistral's attention is causal, this reach extends backwards only. It far exceeds the 32K context window, meaning information can in principle propagate across the entire input.
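The L×W rule can be checked numerically by treating the window mask as an adjacency matrix and composing it L times. The sketch below tracks only which tokens can influence which, not how much signal survives, so it gives an upper bound on information flow. The sizes N=21 and W=2 and the function name are assumptions for illustration:

```python
import numpy as np

def reach_after_layers(N, W, L):
    """reach[i, j] is True if token j can influence token i after L window layers."""
    idx = np.arange(N)
    one_layer = (np.abs(idx[:, None] - idx[None, :]) <= W).astype(int)
    reach = np.eye(N, dtype=int)              # before any layer: each token sees itself
    for _ in range(L):
        # i attends to k in this layer, and k already carries information from j
        reach = ((one_layer @ reach) > 0).astype(int)
    return reach.astype(bool)

N, W = 21, 2
centre = N // 2
for L in (1, 2, 3):
    span = np.flatnonzero(reach_after_layers(N, W, L)[centre])
    print(f"L={L}: token {centre} reaches positions {span.min()}..{span.max()} "
          f"(L*W = {L * W} on each side)")
```

Each additional layer widens the reachable span by W on both sides, until the span hits the ends of the sequence.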
Applications Across Domains
Long-Document NLP
Longformer was designed for tasks where standard BERT truncates the input. Legal document analysis, patent search, and scientific paper summarisation all require processing thousands of tokens. Longformer with W=256 and global tokens on [CLS] set new state-of-the-art results on WikiHop and TriviaQA, with strong results on HotpotQA.
Code Analysis
Source code has strong locality: variable references are usually within tens of lines of their declaration. A sliding window of W=512 tokens captures most local scopes (function bodies, class definitions) while ignoring distant import statements that rarely affect local logic.
Genomic Sequence Modeling
DNA sequences can be millions of base pairs long. Many functional relationships in genes (for example exon-intron boundaries) operate within windows of a few thousand base pairs, which makes local attention windows a natural fit for long genomic inputs.
Time Series Forecasting
Financial data, sensor readings, and weather measurements exhibit temporal locality: today's temperature depends mostly on yesterday's, not last month's. Sliding window attention naturally captures this with W chosen to match the seasonality period.
Connection to Modern Systems
Longformer: Global + Local Pattern
Longformer's innovation is the hybrid attention pattern: most tokens use sliding window attention, but a few designated "global" tokens attend to the entire sequence. In classification, the [CLS] token is global. In question answering, the question tokens are global. With a fixed window and a fixed number of global tokens, total cost stays O(N) while preserving the ability to aggregate global information where needed.
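The hybrid pattern amounts to a band mask with a few full rows and columns added. A minimal sketch, assuming N=8, W=1 and a single global token at position 0 standing in for [CLS]. The function name is hypothetical. The column rule (every token also attends to the global tokens) follows the Longformer paper's description of global attention as symmetric; it is not stated in the paragraph above.

```python
import numpy as np

def longformer_style_mask(N, W, global_idx):
    """True where attention is allowed: local band plus global rows and columns."""
    idx = np.arange(N)
    allowed = np.abs(idx[:, None] - idx[None, :]) <= W
    allowed[global_idx, :] = True    # global tokens attend to every position
    allowed[:, global_idx] = True    # every position attends to the global tokens
    return allowed

mask = longformer_style_mask(N=8, W=1, global_idx=[0])
print(mask.astype(int))
print("allowed pairs:", int(mask.sum()), "of", mask.size)
```

With g global tokens the number of allowed pairs grows roughly as N×(2W+1) + 2×g×N, which stays linear in N as long as g and W are fixed.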
Mistral and Sliding Window KV-Cache
Mistral-7B (Jiang et al., 2023) uses W=4096 sliding window attention with a rolling KV-cache. During autoregressive generation, the cache only stores the most recent W key-value pairs per layer. Older entries are evicted. This bounds the KV-cache to O(W×L×d_k) regardless of how long the generated sequence becomes, giving a constant memory budget during inference.
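A rolling cache can be sketched as a ring buffer in which timestep t overwrites slot t mod W, the layout the Mistral paper describes. Everything else below is an assumption for illustration: the class name `RollingKVCache`, its methods, the single-head single-layer setup, the dummy keys and values, and the omission of positional encodings.

```python
import numpy as np

class RollingKVCache:
    """Fixed-size ring buffer holding the most recent W (key, value) pairs."""

    def __init__(self, window, d_k, d_v):
        self.window = window
        self.keys = np.zeros((window, d_k))
        self.values = np.zeros((window, d_v))
        self.count = 0                       # total tokens seen so far

    def append(self, k, v):
        slot = self.count % self.window      # overwrites the oldest entry once full
        self.keys[slot] = k
        self.values[slot] = v
        self.count += 1

    def attend(self, q):
        """Attention of query q over whatever is currently cached."""
        n = min(self.count, self.window)
        scores = self.keys[:n] @ q / np.sqrt(q.shape[0])
        w = np.exp(scores - scores.max())
        w /= w.sum()
        return w @ self.values[:n]

cache = RollingKVCache(window=4, d_k=8, d_v=8)
q = np.ones(8)
for t in range(10):
    kv = np.full(8, float(t))    # dummy key/value that encodes its own timestep
    cache.append(kv, kv)
    out = cache.attend(q)

print("timesteps still cached:", sorted(cache.keys[:, 0].astype(int).tolist()))
print("cache shape stays:", cache.keys.shape)
```

After ten steps with a window of four, only timesteps 6 to 9 remain, and the arrays never grew beyond their initial size.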
| System | Window W | Layers L | Effective Range | KV-Cache Bound |
|---|---|---|---|---|
| Longformer | 256 | 12 | 3,072 | N/A (encoder) |
| Mistral-7B | 4,096 | 32 | 131,072 | 4,096 entries/layer |
| BigBird (Chapter 12) | Global+Local | 12 | Full sequence | N/A (encoder) |
Flash Attention with Block Sparsity
Flash Attention (Chapter 13) works by tiling the attention matrix into blocks that fit in SRAM. Sliding window attention maps well onto this: blocks that lie entirely outside the window band never need to be loaded from HBM, saving both memory traffic and compute. The FlashAttention-2 implementation supports sliding window (local) attention directly.
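The tile-skipping idea can be illustrated with plain bookkeeping: split the N×N score matrix into square tiles and mark which ones contain at least one in-window pair. This is not FlashAttention itself, only a sketch of which tiles a block-sparse kernel would need to visit. The function name and the sizes (N=1024, W=128, tile size 128) are assumptions.

```python
import numpy as np

def active_tiles(N, W, block):
    """Boolean grid: True for (block x block) tiles containing any pair with |i - j| <= W."""
    n_blocks = -(-N // block)                      # ceiling division
    active = np.zeros((n_blocks, n_blocks), dtype=bool)
    for bi in range(n_blocks):
        i_lo, i_hi = bi * block, min(N, (bi + 1) * block) - 1
        for bj in range(n_blocks):
            j_lo, j_hi = bj * block, min(N, (bj + 1) * block) - 1
            # smallest |i - j| over the tile; zero when row and column ranges overlap
            gap = max(0, j_lo - i_hi, i_lo - j_hi)
            active[bi, bj] = gap <= W
    return active

tiles = active_tiles(N=1024, W=128, block=128)
print(tiles.astype(int))
print(f"{int(tiles.sum())} of {tiles.size} tiles need to be visited")
```

Tiles that straddle the edge of the band are still visited and need the element-wise mask inside them; only tiles entirely outside the band are skipped.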
Complexity Analysis
| Metric | Standard Attention | Sliding Window (W fixed) |
|---|---|---|
| Score computation | O(N²⋅d_k) | O(N⋅W⋅d_k) |
| Memory (score matrix) | O(N²) | O(N⋅W) |
| Output computation | O(N²⋅d_v) | O(N⋅W⋅d_v) |
| Total FLOPs | O(N²⋅d) | O(N⋅W⋅d) |
| KV-cache (inference) | O(N⋅d) (grows with N) | O(W⋅d) (constant) |
The key result: For fixed W, sliding window attention is O(N) in both time and memory. Mistral-7B with W=4096 and N=32,768 reduces the score matrix from about 1.07 billion entries (N²) to about 134 million (N×W), an 8× reduction.
Trade-off: The savings grow with the ratio N/W. For short sequences where N≤W+1, the window covers the entire sequence for every token and there is no benefit. Sliding window attention is most powerful for long sequences.
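The trade-off can be made concrete by counting the scores that fall inside the band exactly. The sketch below assumes the bidirectional window used in this chapter's worked example (W neighbours on each side); the helper name `band_entries` and the example sizes are illustrative:

```python
def band_entries(N, W):
    """Exact count of (i, j) pairs with |i - j| <= W in an N x N score matrix."""
    W = min(W, N - 1)                  # a window wider than the sequence adds nothing
    return N * (2 * W + 1) - W * (W + 1)

for N, W in [(5, 1), (5, 4), (4096, 256), (16384, 256)]:
    full, band = N * N, band_entries(N, W)
    print(f"N={N:>6}, W={W:>4}: {band:>10,} of {full:>12,} scores kept "
          f"({full / band:.1f}x fewer)")
```

For the toy case N=5, W=4 the band is the whole matrix and nothing is saved. For a fixed W, quadrupling N from 4,096 to 16,384 roughly quadruples the saving factor, because the band grows linearly while the full matrix grows quadratically.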
Python Implementation
A complete, self-contained class with the shared example. The walkthrough below annotates key lines with their execution state, and the full listing follows.
Sliding Window Attention — NumPy Implementation
sliding_window_attention.py
import numpy as np
NumPy provides vectorized matrix operations. Q @ K.T runs as optimized C code, not Python loops.
import math
Standard library math module. We use math.sqrt() for the scaling factor.
class SlidingWindowAttention
A self-contained class implementing sliding window attention. The only difference from standard attention (Chapter 1) is the window mask applied before softmax: positions with |i - j| > W are blocked.
def __init__(self, d_k, window_size=1)
Constructor takes d_k (dimension of Q/K vectors) and window_size W. Pre-computes the scaling factor sqrt(d_k).
EXECUTION STATE
⬇ input: d_k = 4 (dimension of each query/key vector)
⬇ input: window_size = 1 (±1 neighbour on each side)
⬆ sets = self.d_k=4, self.W=1, self.scale=2.0
self.d_k = d_k
Store the key dimension for reference.
EXECUTION STATE
self.d_k = 4
self.W = window_size
Store window size. W=1 means each token sees itself and ±1 neighbour = 3 positions max.
EXECUTION STATE
self.W = 1
self.scale = math.sqrt(d_k)
Pre-compute √d_k = √4 = 2.0. Dividing scores by this prevents softmax saturation.
EXECUTION STATE
self.scale = 2.0
def _softmax(self, x) → np.ndarray
Numerically stable softmax that correctly handles -inf values from the window mask. Operates row-by-row: each row sums to 1.0.
EXECUTION STATE
⬇ input: x (5×5) =
The cat sat on mat
The 0.000 1.000 -inf -inf -inf
cat 1.500 0.000 1.000 -inf -inf
sat -inf 1.000 1.000 0.500 -inf
on -inf -inf 0.000 1.000 0.500
mat -inf -inf -inf 0.500 0.750
⬆ returns = np.ndarray (5,5) — attention weights, each row sums to 1.0
x_safe = np.where(np.isfinite(x), x, -1e9)
Replace -inf with -1e9. np.isfinite() returns False for -inf. This keeps every entry finite, so a row that happened to be fully masked would not turn x - max into nan (-inf minus -inf).
EXECUTION STATE
np.isfinite(x) = True for real numbers, False for -inf entries
-1e9 = -1000000000.0 — large enough that exp(-1e9) ≈ 0
x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)
Subtract the row-wise maximum from each element. This is the numerical stability trick: exp(x - max) avoids overflow while producing identical softmax results.
EXECUTION STATE
axis=-1 = find max along last axis — each row gets its own max
keepdims=True = result has shape (5,1) not (5,), so broadcasting x(5×5) - max(5×1) works
── Row 0 (The) ── =
max = 1.000 (from 'cat' position)
x_shifted[0] = [-1.000, 0.000, -1e9, -1e9, -1e9]
── Row 2 (sat) ── =
max = 1.000 (from 'cat' and 'sat' positions)
x_shifted[2] = [-1e9, 0.000, 0.000, -0.500, -1e9]
── Row 4 (mat) ── =
max = 0.750 (self-attention)
x_shifted[4] = [-1e9, -1e9, -1e9, -0.250, 0.000]
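exp_x = np.exp(x_shifted)
Element-wise exponentiation. Blocked positions have x_shifted ≈ -1e9, so exp(-1e9) underflows to 0. The max entry per row has x_shifted = 0, so exp(0) = 1.
return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
Divide each row by its sum so that every row of weights adds up to 1.0: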
The cat sat on mat
The 0.2689 0.7311 0.0000 0.0000 0.0000
cat 0.5465 0.1220 0.3315 0.0000 0.0000
sat 0.0000 0.3837 0.3837 0.2327 0.0000
on 0.0000 0.0000 0.1863 0.5065 0.3072
mat 0.0000 0.0000 0.0000 0.4378 0.5622
def build_window_mask(self, N) → np.ndarray
Builds the boolean window mask. True means the position is BLOCKED (|i - j| > W). This is the ONLY difference from standard attention.
EXECUTION STATE
⬇ input: N = 5 (number of tokens)
self.W = 1 (window size)
⬆ returns = np.ndarray (5,5) of bool — True where |i - j| > W
idx = np.arange(N)
Create position indices [0, 1, 2, 3, 4] for the 5 tokens.
EXECUTION STATE
idx = [0, 1, 2, 3, 4]
dist = np.abs(idx[:, None] - idx[None, :])
Compute the absolute distance matrix. idx[:, None] has shape (5,1) and idx[None,:] has shape (1,5) — broadcasting gives a (5,5) distance matrix.
The cat sat on mat
The 0 1 2 3 4
cat 1 0 1 2 3
sat 2 1 0 1 2
on 3 2 1 0 1
mat 4 3 2 1 0
return dist > self.W
Boolean comparison: True where distance exceeds window W=1. These positions will be set to -inf before softmax.
EXECUTION STATE
⬆ return: mask (5×5) =
The cat sat on mat
The False False True True True
cat False False False True True
sat True False False False True
on True True False False False
mat True True True False False
def forward(self, Q, K, V)
Full forward pass: raw scores → scale → window mask → softmax → weighted sum. Identical to standard attention except for the window mask step.
EXECUTION STATE
⬇ input: Q (5×4) =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
⬇ input: K (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
⬇ input: V (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
⬆ returns = (weights (5×5), output (5×4))
N = Q.shape[0]
Number of tokens in the sequence.
EXECUTION STATE
N = 5
raw_scores = Q @ K.T
Matrix multiply Q (5×4) with K transposed (4×5). Each entry [i,j] is the dot product of query i with key j.
EXECUTION STATE
K.T = K transposed from (5×4) to (4×5)
raw_scores (5×5) =
The cat sat on mat
The 0.00 2.00 1.00 1.00 1.50
cat 3.00 0.00 2.00 1.00 0.50
sat 1.00 2.00 2.00 1.00 1.50
on 1.00 1.00 0.00 2.00 1.00
mat 1.00 1.00 1.00 1.00 1.50
scaled_scores = raw_scores / self.scale
Divide every score by sqrt(d_k) = 2.0. This keeps variance at ~1, preventing softmax saturation.
EXECUTION STATE
self.scale = 2.0
scaled_scores (5×5) =
The cat sat on mat
The 0.000 1.000 0.500 0.500 0.750
cat 1.500 0.000 1.000 0.500 0.250
sat 0.500 1.000 1.000 0.500 0.750
on 0.500 0.500 0.000 1.000 0.500
mat 0.500 0.500 0.500 0.500 0.750
mask = self.build_window_mask(N)
Build the 5×5 boolean mask. True where |i-j| > 1.
EXECUTION STATE
mask (5×5) =
The cat sat on mat
The False False True True True
cat False False False True True
sat True False False False True
on True True False False False
mat True True True False False
masked_scores = scaled_scores.copy()
Copy to avoid modifying the original. The original scores are preserved.
masked_scores[mask] = -np.inf
Boolean indexing: everywhere mask is True (|i-j| > W), set score to -inf. This is the CORE OPERATION of sliding window attention.
EXECUTION STATE
masked_scores (5×5) =
The cat sat on mat
The 0.000 1.000 -inf -inf -inf
cat 1.500 0.000 1.000 -inf -inf
sat -inf 1.000 1.000 0.500 -inf
on -inf -inf 0.000 1.000 0.500
mat -inf -inf -inf 0.500 0.750
weights = self._softmax(masked_scores)
Softmax over the masked scores. Blocked positions get weight 0.0000. Each row sums to 1.0.
EXECUTION STATE
weights (5×5) =
The cat sat on mat
The 0.2689 0.7311 0.0000 0.0000 0.0000
cat 0.5465 0.1220 0.3315 0.0000 0.0000
sat 0.0000 0.3837 0.3837 0.2327 0.0000
on 0.0000 0.0000 0.1863 0.5065 0.3072
mat 0.0000 0.0000 0.0000 0.4378 0.5622
output = weights @ V
Multiply weights (5×5) by V (5×4). Each output row is a weighted sum of value vectors from ONLY visible (within-window) tokens.
EXECUTION STATE
output (5×4) =
d0 d1 d2 d3
The 0.2689 0.7311 0.0000 0.0000
cat 0.5465 0.1220 0.3315 0.0000
sat 0.0000 0.3837 0.3837 0.2327
on 0.1536 0.1536 0.3399 0.6601
mat 0.2811 0.2811 0.2811 0.7189
return weights, output
Return both weights (for visualization) and output (fed to next layer).
tokens = ["The", "cat", "sat", "on", "mat"]
The shared example sentence used throughout all 15 chapters. 5 tokens with d_k = 4.
EXECUTION STATE
tokens = ["The", "cat", "sat", "on", "mat"]
N = 5 tokens
Q = np.array([...])
Query matrix (5×4). Each row is what that token is 'asking for' — the features it wants to find in the sequence.
EXECUTION STATE
Q (5×4) =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
K = np.array([...])
Key matrix (5×4). Each row is what that token 'advertises' about itself.
EXECUTION STATE
K (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
V = np.array([...])
Value matrix (5×4). Each row is the content that token contributes when selected. Q and K decide WHO to attend to; V decides WHAT to retrieve.
EXECUTION STATE
V (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
attn = SlidingWindowAttention(d_k=4, window_size=1)
Instantiate with d_k=4 and W=1. Sets self.scale = sqrt(4) = 2.0.
EXECUTION STATE
attn.d_k = 4
attn.W = 1
attn.scale = 2.0
weights, output = attn.forward(Q, K, V)
Run the full sliding window attention forward pass. Returns 5×5 weight matrix and 5×4 output.
EXECUTION STATE
weights (5×5) =
The cat sat on mat
The 0.2689 0.7311 0.0000 0.0000 0.0000
cat 0.5465 0.1220 0.3315 0.0000 0.0000
sat 0.0000 0.3837 0.3837 0.2327 0.0000
on 0.0000 0.0000 0.1863 0.5065 0.3072
mat 0.0000 0.0000 0.0000 0.4378 0.5622
output (5×4) =
d0 d1 d2 d3
The 0.2689 0.7311 0.0000 0.0000
cat 0.5465 0.1220 0.3315 0.0000
sat 0.0000 0.3837 0.3837 0.2327
on 0.1536 0.1536 0.3399 0.6601
mat 0.2811 0.2811 0.2811 0.7189
attn.explain(Q, K, V, tokens, query_idx=2)
Print a detailed trace for 'sat' (row 2). Shows visible/blocked tokens and the full computation.
EXECUTION STATE
query_idx = 2 → 'sat'
visible = ['cat', 'sat', 'on'] (j=1,2,3)
blocked = ['The', 'mat'] (j=0,4)
```python
import numpy as np
import math


class SlidingWindowAttention:
    """
    Sliding Window Attention (Beltagy et al., 2020)

    Each token attends only to its W nearest neighbours on
    each side, reducing complexity from O(N^2) to O(N * W).
    Positions outside the window receive score = -inf so
    softmax gives them exactly zero weight.
    """

    def __init__(self, d_k: int, window_size: int = 1):
        """
        Args:
            d_k: Dimension of query/key vectors
            window_size: W, each token sees +/-W neighbours
        """
        self.d_k = d_k
        self.W = window_size
        self.scale = math.sqrt(d_k)

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax that handles -inf."""
        x_safe = np.where(np.isfinite(x), x, -1e9)
        x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def build_window_mask(self, N: int) -> np.ndarray:
        """Build boolean mask: True where |i - j| > W (blocked)."""
        idx = np.arange(N)
        dist = np.abs(idx[:, None] - idx[None, :])
        return dist > self.W

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Full forward pass.

        Args:
            Q: Query matrix (N, d_k)
            K: Key matrix (N, d_k)
            V: Value matrix (N, d_v)

        Returns:
            weights: Attention weight matrix (N, N)
            output: Context-enriched output (N, d_v)
        """
        N = Q.shape[0]
        raw_scores = Q @ K.T
        scaled_scores = raw_scores / self.scale
        mask = self.build_window_mask(N)
        masked_scores = scaled_scores.copy()
        masked_scores[mask] = -np.inf
        weights = self._softmax(masked_scores)
        output = weights @ V
        return weights, output

    def explain(self, Q, K, V, tokens, query_idx=0):
        """Print a detailed trace for a specific query token."""
        N = Q.shape[0]
        raw = Q @ K.T
        scaled = raw / self.scale
        mask = self.build_window_mask(N)

        t = tokens[query_idx]
        visible = [j for j in range(N) if not mask[query_idx, j]]
        blocked = [j for j in range(N) if mask[query_idx, j]]

        print(f"\n=== Sliding window trace for '{t}' (row {query_idx}, W={self.W}) ===")
        print(f"Q[{query_idx}] = {Q[query_idx]}")
        print(f"Visible: {[tokens[j] for j in visible]} (|i-j| <= {self.W})")
        print(f"Blocked: {[tokens[j] for j in blocked]} (|i-j| > {self.W})")

        vis_scores = [scaled[query_idx, j] for j in visible]
        exps = [math.exp(s) for s in vis_scores]
        total = sum(exps)
        ws = [e / total for e in exps]

        print(f"\nScaled scores (visible only):")
        for idx, j in enumerate(visible):
            print(f"  S[{t},{tokens[j]}] = {vis_scores[idx]:.4f}")

        print(f"\nSoftmax weights:")
        for idx, j in enumerate(visible):
            bar = '#' * int(ws[idx] * 40)
            print(f"  A[{t},{tokens[j]}] = {ws[idx]:.4f} |{bar}|")

        out = sum(ws[idx] * V[j] for idx, j in enumerate(visible))
        print(f"\nOutput O[{t}] = {np.round(out, 4)}")


# ── Shared Example (same Q, K, V as every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],  # The
    [0.0, 2.0, 0.0, 1.0],  # cat
    [1.0, 1.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.0, 1.0],  # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],  # The
    [0.0, 1.0, 0.0, 0.0],  # cat
    [0.0, 0.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 0.0, 1.0],  # on
    [0.5, 0.5, 0.5, 0.5],  # mat
])

# ── Run ──
attn = SlidingWindowAttention(d_k=4, window_size=1)
weights, output = attn.forward(Q, K, V)

print("Sliding Window Attention Weights (W=1):")
print(np.round(weights, 4))

print("\nSliding Window Output (W=1):")
print(np.round(output, 4))

# Detailed trace for "sat"
attn.explain(Q, K, V, tokens, query_idx=2)
```
PyTorch Implementation
The PyTorch version supports GPU acceleration, automatic differentiation, and integrates with standard training loops. The mask construction uses masked_fill for efficient GPU-friendly masking.
Sliding Window Attention — PyTorch Implementation
sliding_window_attention_torch.py
import torch
PyTorch provides GPU-accelerated tensor operations and autograd for backpropagation.
import torch.nn as nn
Neural network module. nn.Module is the base class for all PyTorch models.
The cat sat on mat
The 0 1 2 3 4
cat 1 0 1 2 3
sat 2 1 0 1 2
on 3 2 1 0 1
mat 4 3 2 1 0
return dist > self.window_size
Boolean comparison. True where distance exceeds W=1.
EXECUTION STATE
⬆ return: mask (5×5) =
The cat sat on mat
The False False True True True
cat False False False True True
sat True False False False True
on True True False False False
mat True True True False False
def forward(self, Q, K, V) → tuple
Forward pass: scaled dot-product with window masking.
Scaled dot-product scores. K.transpose(-2,-1) swaps last two dims.
EXECUTION STATE
.transpose(-2, -1) = swap last two dimensions: (N,d) → (d,N). Equivalent to .T for 2D.
scores (5×5) =
The cat sat on mat
The 0.000 1.000 0.500 0.500 0.750
cat 1.500 0.000 1.000 0.500 0.250
sat 0.500 1.000 1.000 0.500 0.750
on 0.500 0.500 0.000 1.000 0.500
mat 0.500 0.500 0.500 0.500 0.750
PyTorch masked_fill: where mask is True, replace with -inf. Cleaner than boolean indexing.
EXECUTION STATE
.masked_fill(mask, val) = where mask is True, replace with val. Non-masked positions unchanged.
scores after masking (5×5) =
The cat sat on mat
The 0.000 1.000 -inf -inf -inf
cat 1.500 0.000 1.000 -inf -inf
sat -inf 1.000 1.000 0.500 -inf
on -inf -inf 0.000 1.000 0.500
mat -inf -inf -inf 0.500 0.750
weights = F.softmax(scores, dim=-1)
Apply softmax along last dimension. -inf positions get weight 0.
EXECUTION STATE
dim=-1 = normalise along the last dimension (across the columns within each row), so each row sums to 1.0.
weights (5×5) =
The cat sat on mat
The 0.2689 0.7311 0.0000 0.0000 0.0000
cat 0.5465 0.1220 0.3315 0.0000 0.0000
sat 0.0000 0.3837 0.3837 0.2327 0.0000
on 0.0000 0.0000 0.1863 0.5065 0.3072
mat 0.0000 0.0000 0.0000 0.4378 0.5622
output = torch.matmul(weights, V)
Weighted sum of value vectors from visible tokens only.
EXECUTION STATE
output (5×4) =
d0 d1 d2 d3
The 0.2689 0.7311 0.0000 0.0000
cat 0.5465 0.1220 0.3315 0.0000
sat 0.0000 0.3837 0.3837 0.2327
on 0.1536 0.1536 0.3399 0.6601
mat 0.2811 0.2811 0.2811 0.7189
return output, weights
Return (output, weights) tuple.
def manual(Q, K, V, window_size=1)
Static method for running sliding window attention with pre-computed Q, K, V. Used for our shared example.
Key Takeaways
One mask, one change: Sliding window attention is standard attention with a single modification: positions where ∣i−j∣>W are set to −∞ before softmax.
Linear complexity: For fixed W, the cost is O(N×W) — linear in sequence length N.
Locality is sufficient: Most linguistic, code, and genomic dependencies are local. The window captures the vast majority of useful attention weight.
Layer stacking recovers range: After L layers, information propagates L×W positions, enabling full-sequence understanding despite local attention.
KV-cache savings: In autoregressive models like Mistral, the rolling KV-cache stores only W entries per layer, enabling constant-memory inference.
Composable with other techniques: Sliding window combines naturally with Flash Attention (block sparsity), global tokens (Longformer), and causal masking (Mistral).
Exercises
Vary the window: Using the Python class above, compute attention weights for W=0 (self-only), W=2, and W=4 (full). Verify that W=4 produces the same weights as Chapter 1's standard attention.
Causal + sliding window: Modify the class to combine the causal mask (j>i⇒−∞) with the sliding window mask (∣i−j∣>W⇒−∞). This is what Mistral-7B uses. How many active connections remain for N=5, W=1?
Receptive field proof: Prove by induction that after L layers of window-W attention, token i's output depends on tokens in the range [i−LW,i+LW].
Memory calculation: A model has 32 layers, 32 heads, d_k=128, and processes N=32,768 tokens. Calculate the total score matrix memory in FP16 for (a) full attention and (b) sliding window with W=4096. What is the ratio?
Dilated window: Instead of contiguous positions, consider a dilated window where token i attends to positions {i−2W,i−W,i,i+W,i+2W}. Implement this and compare the output to the standard sliding window.
References
Beltagy, I., Peters, M. E., & Cohan, A. (2020). Longformer: The Long-Document Transformer. arXiv:2004.05150.
Child, R., Gray, S., Radford, A., & Sutskever, I. (2019). Generating Long Sequences with Sparse Transformers. arXiv:1904.10509.
Jiang, A. Q., Sablayrolles, A., Mensch, A., et al. (2023). Mistral 7B. arXiv:2310.06825.
Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
Zaheer, M., Guruganesh, G., Dubey, A., et al. (2020). Big Bird: Transformers for Longer Sequences. NeurIPS 2020.
Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). Attention Is All You Need. NeurIPS 2017.