Chapter 11

Sliding Window Attention

Beltagy, Peters, & Cohan, "Longformer: The Long-Document Transformer", arXiv:2004.05150, 2020


Learning Objectives

After completing this chapter, you will be able to:

  1. Explain why full self-attention has an $O(N^2)$ bottleneck and why this limits practical context length.
  2. Describe the sliding window mask mathematically and implement it from scratch in NumPy and PyTorch.
  3. Walk through a complete worked example showing how the window size $W$ controls which tokens a query can attend to.
  4. Analyse how stacking $L$ layers with window $W$ gives an effective receptive field of $L \times W$ tokens on each side.
  5. Connect sliding window attention to production systems like Longformer, Mistral-7B, and BigBird.

The Real Problem

The Quadratic Wall

Standard self-attention (Chapter 1) computes a score between every pair of tokens. For a sequence of $N$ tokens, that is $N^2$ dot products, $N^2$ softmax entries, and an $N \times N$ weight matrix stored in memory. The scaling is harsh:

Sequence length N      Score matrix entries   Memory (FP16, one head)
512 (BERT)                          262,144   0.5 MB
4,096 (Llama 2)                  16,777,216   32 MB
16,384 (Longformer)             268,435,456   512 MB
131,072 (Mistral)            17,179,869,184   32 GB

At 16,384 tokens, a single attention head needs 512 MB just for its score matrix. With 12 heads and 12 layers, holding every score matrix at once (as a standard backward pass does) takes about 72 GB, more than any consumer GPU offers. This is the quadratic wall: doubling the context length quadruples the cost.
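The figures above follow directly from entries × bytes per entry. A quick sketch to reproduce them, assuming FP16 storage (2 bytes per entry), binary units (1 MB = 2^20 bytes), and counting only the score matrix itself:

```python
# Back-of-envelope memory for ONE head's N x N score matrix.
# Assumptions: FP16 storage (2 bytes per entry), binary units (1 MB = 2**20 bytes),
# nothing else counted (no Q/K/V, no other activations).

BYTES_PER_FP16 = 2


def score_matrix_bytes(n_tokens: int, bytes_per_entry: int = BYTES_PER_FP16) -> int:
    """Bytes needed to store an n_tokens x n_tokens score matrix."""
    return n_tokens * n_tokens * bytes_per_entry


def human(nbytes: int) -> str:
    """Format a byte count as MB, or GB once it reaches 2**30 bytes."""
    if nbytes >= 2**30:
        return f"{nbytes / 2**30:g} GB"
    return f"{nbytes / 2**20:g} MB"


for n in (512, 4_096, 16_384, 131_072):
    print(f"N = {n:>7,}: {n * n:>14,} entries, {human(score_matrix_bytes(n))}")

# 12 heads x 12 layers at N = 16,384, if every score matrix is held at once
total = 12 * 12 * score_matrix_bytes(16_384)
print(f"12 heads x 12 layers: {human(total)}")
```

Doubling `n` quadruples every figure, which is the quadratic wall in numeric form.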

The Locality Principle

But here is the critical insight: many dependencies in natural language are local. An adjective usually sits right next to the noun it modifies, and a verb is often close to its subject. Commas, articles, and prepositions bind to their immediate neighbours. Empirical studies of attention patterns in BERT and GPT models suggest that most attention weight concentrates within a window of roughly 100 to 500 tokens, even when the full context is available.

The Key Question: If most attention weight is local anyway, why compute all $N^2$ scores? What if we restrict each token to attend only to its $W$ nearest neighbours on each side, dropping the complexity from $O(N^2)$ to $O(N \times W)$?

This is exactly what sliding window attention does. For fixed $W$, the cost grows linearly in $N$, not quadratically, enabling context lengths of 16K, 32K, or even 128K tokens.


The Story Behind Sliding Window Attention

Longformer: From Research to Practice

The idea of restricting attention to a local window was not invented in a vacuum. Convolutional networks had already shown that local connectivity is powerful enough for images. The question was whether the same principle could work for sequences.

In 2020, Iz Beltagy, Matthew Peters, and Arman Cohan at the Allen Institute for AI published Longformer. They were working on tasks that required processing entire documents (legal briefs, scientific papers, Wikipedia articles) that routinely run to thousands of tokens, far beyond the 512-token limit of standard BERT-style models.

Their solution had two key ingredients:

  1. Sliding window attention for most tokens: each token attends to its $W$ nearest neighbours on each side. The released pretrained models use a window of 512 tokens in total, i.e. $W = 256$ on each side.
  2. Global attention on a few special tokens (like the $\texttt{[CLS]}$ token, or the question tokens in QA tasks) that attend to the entire sequence. This hybrid approach retains the ability to capture long-range dependencies where they matter most.

Because its memory grows linearly with sequence length, Longformer could process sequences of 4,096 tokens and more, well beyond BERT's 512-token limit. It achieved state-of-the-art results on WikiHop and TriviaQA and outperformed RoBERTa on other long-document tasks such as IMDB review classification.

Historical note: Local attention windows also appear in Sparse Transformers (Child et al., 2019), and Mistral-7B (Jiang et al., 2023) later used a causal window of $W = 4096$ with a rolling KV-cache to handle 32K context lengths. The principle is the same; the engineering differs.

The Mathematical Definition

Symbol-by-Symbol Breakdown

Sliding window attention is standard scaled dot-product attention with one modification: a distance-based mask that blocks positions outside the window. The score function becomes:

$$\text{score}[i, j] = \begin{cases} \dfrac{Q_i \cdot K_j^\top}{\sqrt{d_k}} & \text{if } |i - j| \leq W \\[6pt] -\infty & \text{if } |i - j| > W \end{cases}$$

Let us unpack every symbol:

Symbol       Meaning                                     In our example
$Q_i$        Query vector for the token at position i    $Q_2 = [1, 1, 1, 0]$ for "sat"
$K_j$        Key vector for the token at position j      $K_1 = [1, 0, 1, 0]$ for "cat"
$d_k$        Dimension of the Q/K vectors                4
$W$          Window size (half-width)                    1 (±1 neighbour)
$|i - j|$    Absolute distance between positions         $|2 - 0| = 2 > W$, so blocked
$-\infty$    Ensures $\exp(-\infty) = 0$ in softmax      Blocked positions get zero weight

After masking, softmax is applied row-wise as usual: $A_{ij} = \text{softmax}_j(\text{score}[i, :])$. Because $\exp(-\infty) = 0$, blocked positions receive exactly zero attention weight. The output is then $O_i = \sum_j A_{ij} V_j$, but the sum effectively runs only over the at most $2W + 1$ visible positions (fewer near the ends of the sequence).
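To see why the cost per query is proportional to $W$ rather than $N$, here is a minimal NumPy sketch that slices out only the visible keys and values for one query instead of masking a full $N \times N$ matrix. The function name `window_attention_row` is our own, not from any library; the Q, K, V arrays are the chapter's shared example:

```python
import numpy as np

# Shared example (same Q, K, V as the rest of the chapter)
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)


def window_attention_row(Q, K, V, i, W):
    """Output for query i using only keys/values j with |i - j| <= W.

    Work per query is proportional to the slice length (at most 2W + 1),
    independent of the sequence length N.
    """
    N, d_k = Q.shape
    lo, hi = max(0, i - W), min(N, i + W + 1)   # visible slice, clipped at the edges
    scores = K[lo:hi] @ Q[i] / np.sqrt(d_k)       # at most 2W + 1 scaled scores
    weights = np.exp(scores - scores.max())       # stable softmax over the slice only
    weights /= weights.sum()
    return weights @ V[lo:hi], (lo, hi)


out, (lo, hi) = window_attention_row(Q, K, V, i=2, W=1)
print(lo, hi, np.round(out, 4))   # the slice covers "cat", "sat", "on"

# All rows, one query at a time; no N x N matrix is ever built
outputs = np.stack([window_attention_row(Q, K, V, i, W=1)[0] for i in range(len(Q))])
```

Production kernels batch this work rather than looping in Python, but the access pattern (a fixed-width slice per query) is the same idea.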

The Window Mask Matrix

The mask is a band matrix centred on the diagonal with bandwidth $2W + 1$:

$$M_{ij} = \begin{cases} 0 \text{ (visible)} & \text{if } |i - j| \leq W \\ 1 \text{ (blocked)} & \text{if } |i - j| > W \end{cases}$$

For our 5-token sequence with $W = 1$ (0 = visible, 1 = blocked):

             The (j=0)   cat (j=1)   sat (j=2)   on (j=3)   mat (j=4)
The (i=0)    0           0           1           1          1
cat (i=1)    0           0           0           1          1
sat (i=2)    1           0           0           0          1
on (i=3)     1           1           0           0          0
mat (i=4)    1           1           1           0          0

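Notice the symmetric band pattern along the diagonal. Unlike the causal mask (Chapter 3), which is triangular, the sliding window mask is symmetric: token $i$ can see token $j$ if and only if $j$ can see $i$. This bidirectional locality suits encoder tasks such as BERT-style understanding. Decoder models such as Mistral-7B instead combine the window with the causal mask, so each token sees only a band of recent positions ending at itself.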
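Decoder models pair the window with a causal mask, which breaks the symmetry. A short NumPy sketch of both masks, using this chapter's convention that True means blocked. The causal version is an illustrative assumption: implementations differ on whether the window counts $W$ or $W + 1$ positions including the token itself.

```python
import numpy as np


def symmetric_window_mask(N, W):
    """True = blocked where |i - j| > W (this chapter's definition)."""
    idx = np.arange(N)
    return np.abs(idx[:, None] - idx[None, :]) > W


def causal_window_mask(N, W):
    """True = blocked; token i sees only positions i - W .. i (illustrative convention)."""
    idx = np.arange(N)
    diff = idx[:, None] - idx[None, :]   # i - j: positive means j is in the past
    return (diff < 0) | (diff > W)       # block future tokens and tokens too far back


sym = symmetric_window_mask(5, W=1)
causal = causal_window_mask(5, W=1)
print(sym.astype(int))      # matches the 5 x 5 mask table above
print(causal.astype(int))   # only the lower part of the band survives
```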


Step-by-Step Calculation

We use the same shared example as every chapter: 5 tokens ("The cat sat on mat"), $d_k = 4$, and $W = 1$.

Step 1: Raw Dot Products

Compute $\text{raw}[i,j] = Q_i \cdot K_j^\top$ for all $i, j$. This is identical to standard attention:

       The    cat    sat    on     mat
The    0.00   2.00   1.00   1.00   1.50
cat    3.00   0.00   2.00   1.00   0.50
sat    1.00   2.00   2.00   1.00   1.50
on     1.00   1.00   0.00   2.00   1.00
mat    1.00   1.00   1.00   1.00   1.50

Step 2: Scaling

Divide by $\sqrt{d_k} = \sqrt{4} = 2.0$:

       The     cat     sat     on      mat
The    0.000   1.000   0.500   0.500   0.750
cat    1.500   0.000   1.000   0.500   0.250
sat    0.500   1.000   1.000   0.500   0.750
on     0.500   0.500   0.000   1.000   0.500
mat    0.500   0.500   0.500   0.500   0.750

Step 3: Apply Sliding Window Mask

For $W = 1$, every position where $|i - j| > 1$ is set to $-\infty$:

       The     cat     sat     on      mat
The    0.000   1.000   -∞      -∞      -∞
cat    1.500   0.000   1.000   -∞      -∞
sat    -∞      1.000   1.000   0.500   -∞
on     -∞      -∞      0.000   1.000   0.500
mat    -∞      -∞      -∞      0.500   0.750

The $-\infty$ entries form the blocked region outside the band. Only the diagonal band of width $2W + 1 = 3$ retains real scores.

Step 4: Softmax Over Visible Tokens

Softmax is computed per row, but $\exp(-\infty) = 0$, so blocked positions contribute nothing. Let us trace row by row:

Row 0 ("The"): visible scores $[0.000, 1.000]$

  • $\exp(0.000) = 1.0000$, $\exp(1.000) = 2.7183$
  • Sum = 3.7183
  • Weights: $[0.2689, 0.7311, 0, 0, 0]$

Row 1 ("cat"): visible scores $[1.500, 0.000, 1.000]$

  • $\exp(1.500) = 4.4817$, $\exp(0.000) = 1.0000$, $\exp(1.000) = 2.7183$
  • Sum = 8.2000
  • Weights: $[0.5465, 0.1220, 0.3315, 0, 0]$

Row 2 ("sat"): visible scores $[1.000, 1.000, 0.500]$

  • $\exp(1.000) = 2.7183$, $\exp(1.000) = 2.7183$, $\exp(0.500) = 1.6487$
  • Sum = 7.0853
  • Weights: $[0, 0.3837, 0.3837, 0.2327, 0]$

Row 3 ("on"): visible scores $[0.000, 1.000, 0.500]$

  • $\exp(0.000) = 1.0000$, $\exp(1.000) = 2.7183$, $\exp(0.500) = 1.6487$
  • Sum = 5.3670
  • Weights: $[0, 0, 0.1863, 0.5065, 0.3072]$

Row 4 ("mat"): visible scores $[0.500, 0.750]$

  • $\exp(0.500) = 1.6487$, $\exp(0.750) = 2.1170$
  • Sum = 3.7657
  • Weights: $[0, 0, 0, 0.4378, 0.5622]$
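Steps 3 and 4 can be checked in a few lines of NumPy. This sketch hard-codes the Step 2 scores; subtracting the row maximum is safe here because every row keeps its finite diagonal entry, so no row is entirely $-\infty$:

```python
import numpy as np

# Scaled scores from Step 2 (rows = queries, columns = keys)
scaled = np.array([
    [0.0, 1.0, 0.5, 0.5, 0.75],
    [1.5, 0.0, 1.0, 0.5, 0.25],
    [0.5, 1.0, 1.0, 0.5, 0.75],
    [0.5, 0.5, 0.0, 1.0, 0.5],
    [0.5, 0.5, 0.5, 0.5, 0.75],
])

W = 1
idx = np.arange(len(scaled))
blocked = np.abs(idx[:, None] - idx[None, :]) > W

masked = np.where(blocked, -np.inf, scaled)              # Step 3: blocked -> -inf
shifted = masked - masked.max(axis=-1, keepdims=True)    # row max is finite (diagonal is visible)
weights = np.exp(shifted)                                # exp(-inf) evaluates to exactly 0.0
weights /= weights.sum(axis=-1, keepdims=True)           # Step 4: normalise each row

print(np.round(weights, 4))
```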

Step 5: Weighted Sum of Values

Each output $O_i = \sum_j A_{ij} V_j$, where the sum only runs over visible positions:

Row 0 ("The"):

$$O_0 = 0.2689 \times [1,0,0,0] + 0.7311 \times [0,1,0,0] = [0.2689, 0.7311, 0, 0]$$

Row 2 ("sat"):

$$O_2 = 0.3837 \times [0,1,0,0] + 0.3837 \times [0,0,1,0] + 0.2327 \times [0,0,0,1] = [0, 0.3837, 0.3837, 0.2327]$$

Row 3 ("on"):

$$O_3 = 0.1863 \times [0,0,1,0] + 0.5065 \times [0,0,0,1] + 0.3072 \times [0.5,0.5,0.5,0.5] = [0.1536, 0.1536, 0.3399, 0.6601]$$
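Rows 1 ("cat") and 4 ("mat") follow the same pattern. A quick check of all five rows as a single matrix product, using the rounded Step 4 weights (so the last digit can differ slightly from an unrounded computation):

```python
import numpy as np

# Rounded sliding window weights from Step 4 (W = 1)
A = np.array([
    [0.2689, 0.7311, 0.0,    0.0,    0.0   ],
    [0.5465, 0.1220, 0.3315, 0.0,    0.0   ],
    [0.0,    0.3837, 0.3837, 0.2327, 0.0   ],
    [0.0,    0.0,    0.1863, 0.5065, 0.3072],
    [0.0,    0.0,    0.0,    0.4378, 0.5622],
])

# Shared value vectors: The, cat, sat, on, mat
V = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

O = A @ V   # row i = sum over j of A[i, j] * V[j]; blocked j have A[i, j] = 0
print(np.round(O, 4))
```

For "mat", only "on" and "mat" itself contribute: $0.4378 \times [0,0,0,1] + 0.5622 \times [0.5,0.5,0.5,0.5] = [0.2811, 0.2811, 0.2811, 0.7189]$.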

Worked Example: What "sat" Sees

Let us trace the complete pipeline for "sat" (position 2, $W = 1$):

  1. Visible positions: $|2 - j| \leq 1$ gives $j \in \{1, 2, 3\}$ = ("cat", "sat", "on"). "The" and "mat" are each 2 positions away, so both are blocked.
  2. Scaled scores: $[1.000, 1.000, 0.500]$; "cat" and "sat" have the highest relevance.
  3. Softmax: $[0.3837, 0.3837, 0.2327]$; "cat" and "sat" split evenly, and "on" gets less.
  4. Output: $[0, 0.3837, 0.3837, 0.2327]$, a blend of the three visible value vectors.

What "sat" loses: In full attention, "sat" gives weight 0.1519 to "The" and 0.1951 to "mat". With $W = 1$, those connections are severed, and the remaining 65.3% of the weight is re-normalised to sum to 1: "cat" and "sat" each rise from 0.2505 to 0.3837, and "on" rises from 0.1519 to 0.2327. Because re-normalisation preserves the ratios between the visible weights, attention simply concentrates on the locally visible tokens.
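The re-normalisation can be verified directly: take the full-attention weights for "sat", zero out the blocked positions, and divide by the weight that is left. The result matches a softmax computed over the visible scores alone. A minimal sketch:

```python
import numpy as np

scaled_sat = np.array([0.5, 1.0, 1.0, 0.5, 0.75])      # row 2 ("sat") of the scaled scores
visible = np.array([False, True, True, True, False])   # |2 - j| <= 1

full = np.exp(scaled_sat) / np.exp(scaled_sat).sum()   # full-attention weights
kept = full[visible].sum()                             # share of weight inside the window

renorm = np.where(visible, full, 0.0) / kept           # drop blocked mass, rescale the rest

direct = np.exp(scaled_sat[visible])                   # softmax over the visible scores only
direct /= direct.sum()

print(np.round(full, 4), round(kept, 4))
print(np.round(renorm, 4))
```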

Complete Sliding Window Weights and Output

Sliding Window Attention Weights ($W = 1$, $5 \times 5$)

       The      cat      sat      on       mat
The    0.2689   0.7311   0.0000   0.0000   0.0000
cat    0.5465   0.1220   0.3315   0.0000   0.0000
sat    0.0000   0.3837   0.3837   0.2327   0.0000
on     0.0000   0.0000   0.1863   0.5065   0.3072
mat    0.0000   0.0000   0.0000   0.4378   0.5622

Output Matrix ($5 \times 4$)

       dim-0    dim-1    dim-2    dim-3
The    0.2689   0.7311   0.0000   0.0000
cat    0.5465   0.1220   0.3315   0.0000
sat    0.0000   0.3837   0.3837   0.2327
on     0.1536   0.1536   0.3399   0.6601
mat    0.2811   0.2811   0.2811   0.7189

Standard vs Sliding Window Comparison

Compare with full (standard) attention from Chapter 1:

Token   Standard output                    Window (W = 1) output              Max difference
The     [0.2254, 0.4135, 0.2964, 0.2964]   [0.2689, 0.7311, 0.0000, 0.0000]   0.3176
cat     [0.4602, 0.1475, 0.3018, 0.2058]   [0.5465, 0.1220, 0.3315, 0.0000]   0.2058
sat     [0.2495, 0.3481, 0.3481, 0.2495]   [0.0000, 0.3837, 0.3837, 0.2327]   0.2495
on      [0.2854, 0.2854, 0.2106, 0.4089]   [0.1536, 0.1536, 0.3399, 0.6601]   0.2512
mat     [0.3108, 0.3108, 0.3108, 0.3108]   [0.2811, 0.2811, 0.2811, 0.7189]   0.4081

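The differences are large here because the sequence is tiny ($N = 5$) and $W = 1$ blocks up to two of the five positions for every query, removing a sizeable share of each row's attention weight. In real models, $W$ spans hundreds or thousands of tokens; to the extent that most attention weight is local, as the locality principle suggests, the sliding window output stays close to full attention for most positions.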
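The comparison table can be reproduced with one helper that optionally applies the window mask. A sketch using the shared Q, K, V (the helper name `attention` is ours). The differences it computes may disagree with the table in the fourth decimal, because the table subtracts already-rounded outputs:

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "mat"]
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)


def attention(Q, K, V, W=None):
    """Scaled dot-product attention; if W is given, block positions with |i - j| > W."""
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    if W is not None:
        idx = np.arange(N)
        scores = np.where(np.abs(idx[:, None] - idx[None, :]) > W, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V


full = attention(Q, K, V)          # standard attention (Chapter 1)
window = attention(Q, K, V, W=1)   # sliding window, W = 1
max_diff = np.abs(full - window).max(axis=1)

for tok, d in zip(tokens, max_diff):
    print(f"{tok:>4}: max difference {d:.4f}")
```

Raising `W` to 4 (so $W \geq N - 1$) makes the two outputs identical, since nothing is masked any more.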


The Receptive Field Insight

A single sliding window layer with $W = 1$ limits each token to at most 3 positions: itself and one neighbour on each side. This seems very restrictive, but transformers stack many layers, and that changes the picture.

After $L$ layers, each token's output can incorporate information from tokens up to $L \times W$ positions away. The mechanism works like a relay: in layer 1, "sat" aggregates information from "cat", "sat", and "on". In layer 2, "sat"'s output also contains what "cat" gathered from "The" in layer 1, so "sat" indirectly "sees" "The" through "cat".

$$\text{effective receptive field} = L \times W \text{ positions on each side}$$

For a model like Mistral-7B with $L = 32$ layers and $W = 4096$:

$$\text{receptive field} = 32 \times 4096 = 131{,}072 \text{ tokens}$$

Because Mistral's window is causal, this span extends backwards only. It comfortably exceeds a 32K-token input, so in principle information from any earlier position can reach the final layer.
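The relay argument can be checked mechanically: treat one layer's visibility mask as an adjacency matrix and compose it $L$ times. This sketch assumes every layer uses the same symmetric window, and it tracks only whether information can reach a position, not how strongly it survives the trip:

```python
import numpy as np


def reach_after(N, W, L):
    """reach[i, j] is True if token j's information can reach output i after L layers.

    Assumes every layer uses the same symmetric window |i - j| <= W.
    """
    idx = np.arange(N)
    one_layer = (np.abs(idx[:, None] - idx[None, :]) <= W).astype(int)
    reach = np.eye(N, dtype=int)                       # before any layer: each token knows itself
    for _ in range(L):
        reach = (one_layer @ reach > 0).astype(int)    # one more hop through the window
    return reach.astype(bool)


tokens = ["The", "cat", "sat", "on", "mat"]
for L in (1, 2):
    r = reach_after(5, W=1, L=L)
    print(f"L={L}: 'sat' can draw on {[tokens[j] for j in np.flatnonzero(r[2])]}")

# In a long sequence, a middle token reaches L * W positions on each side
r = reach_after(101, W=3, L=4)
print(np.flatnonzero(r[50]).min(), np.flatnonzero(r[50]).max())   # 38 62

print(32 * 4096)   # Mistral-7B span: 131072 (backwards only, since its window is causal)
```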

Applications Across Domains

Long-Document NLP

Longformer was designed for tasks where standard BERT has to truncate the input. Legal document analysis, patent search, and scientific paper summarisation all require processing thousands of tokens. Longformer, with $W = 256$ and global attention on task tokens such as $\texttt{[CLS]}$, set new state-of-the-art results on WikiHop and TriviaQA and was competitive on HotpotQA.

Code Analysis

Source code has strong locality: variable references usually sit within tens of lines of their declaration. A sliding window of $W = 512$ tokens can capture many local scopes (function bodies, class definitions) while skipping distant import statements, which rarely affect local logic. Not every code model takes this route: Code Llama, for example, reaches long contexts with full attention and a modified rotary position embedding base rather than a sliding window.

Genomic Sequence Modeling

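DNA sequences can be millions of base pairs long. Many functional relationships in genes, such as exon-intron boundaries, are local, although regulatory elements like enhancers can act tens of thousands of base pairs away. Windowed attention is one way to scale to such lengths; Enformer, by comparison, combines convolutional downsampling with attention over roughly 200 kb of input rather than whole chromosomes.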

Time Series Forecasting

Financial data, sensor readings, and weather measurements often exhibit temporal locality: today's temperature depends more on yesterday's than on last month's. Sliding window attention captures this naturally, with $W$ chosen to cover the relevant seasonality period.


Connection to Modern Systems

Longformer: Global + Local Pattern

Longformer's innovation is the hybrid attention pattern: most tokens use sliding window attention, but a few designated "global" tokens attend to the entire sequence and are attended to by every token. In classification, the $\texttt{[CLS]}$ token is global; in question answering, the question tokens are global. Because the number of global tokens is small, the cost stays $O(N)$ while the model can still aggregate global information where needed.
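The hybrid pattern can be sketched as a mask: start from the band, then open the full row and column of each global token. This illustrates only the shape of the mask under the assumption that global attention is symmetric; the function name is hypothetical and this is not Longformer's actual implementation:

```python
import numpy as np


def local_plus_global_mask(N, W, global_idx):
    """True = blocked. Band |i - j| <= W plus full rows/columns for global tokens.

    Hypothetical helper illustrating the mask shape only.
    """
    idx = np.arange(N)
    blocked = np.abs(idx[:, None] - idx[None, :]) > W
    g = np.asarray(global_idx)
    blocked[g, :] = False   # global tokens attend to every position
    blocked[:, g] = False   # every position attends to the global tokens
    return blocked


mask = local_plus_global_mask(N=8, W=1, global_idx=[0])   # position 0 plays the role of [CLS]
print((~mask).astype(int))   # 1 = visible
```

Each global token adds one full row and one full column, so with $g$ global tokens the visible entries grow by about $2gN$, which keeps the total linear in $N$.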

Mistral and Sliding Window KV-Cache

Mistral-7B (Jiang et al., 2023) uses causal sliding window attention with $W = 4096$ and a rolling KV-cache. During autoregressive generation, the cache stores only the most recent $W$ key-value pairs per layer; older entries are overwritten. This bounds the KV-cache to $O(W \times L \times d_k)$ regardless of how long the generated sequence becomes, giving a constant memory budget during inference.

System                 Window W                  Layers L   Effective range             KV-cache bound
Longformer             256                       12         3,072 (each side)           N/A (encoder)
Mistral-7B             4,096 (causal)            32         131,072 (backwards only)    4,096 entries/layer
BigBird (Chapter 12)   Global + local + random   12         Full sequence               N/A (encoder)
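A rolling KV-cache can be sketched with a fixed-size buffer where position $t$ is written to slot $t \bmod W$, so each new token overwrites the entry from $W$ steps earlier. This is a simplified single-layer, single-head illustration; the class name `RollingKVCache` is hypothetical and not Mistral's code:

```python
import numpy as np


class RollingKVCache:
    """Fixed-size cache holding keys/values for the most recent W positions.

    Position t is written to slot t % W, so it overwrites position t - W.
    """

    def __init__(self, W: int, d: int):
        self.W = W
        self.keys = np.zeros((W, d))
        self.values = np.zeros((W, d))
        self.positions = np.full(W, -1)   # absolute position stored in each slot
        self.t = 0                        # number of tokens appended so far

    def append(self, k: np.ndarray, v: np.ndarray) -> None:
        slot = self.t % self.W
        self.keys[slot] = k
        self.values[slot] = v
        self.positions[slot] = self.t
        self.t += 1

    def window(self):
        """Cached (positions, keys, values), oldest first."""
        filled = min(self.t, self.W)
        order = [(self.t - filled + n) % self.W for n in range(filled)]
        return self.positions[order], self.keys[order], self.values[order]


cache = RollingKVCache(W=4, d=2)
for t in range(10):                                      # generate 10 tokens
    cache.append(np.full(2, float(t)), np.full(2, -float(t)))

positions, keys, values = cache.window()
print(positions)          # [6 7 8 9]: only the last W = 4 positions survive
print(cache.keys.shape)   # (4, 2): memory never grows past W entries
```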

Flash Attention with Block Sparsity

Flash Attention (Chapter 13) works by tiling the attention computation into blocks that fit in on-chip SRAM. Sliding window attention maps naturally onto this: tiles that lie entirely outside the window band can be skipped, so they are never loaded from HBM or computed. The FlashAttention-2 library (flash-attn) supports sliding window (local) attention directly through a window_size argument.
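To get a feel for how many tiles a block-sparse kernel could skip, this sketch counts the $B \times B$ (query tile, key tile) pairs that intersect the band $|i - j| \leq W$. The tile size $B$ is a free parameter chosen for illustration, not a value taken from FlashAttention; tiles that only partly overlap the band would still need masking inside the tile:

```python
import math


def band_tiles(N: int, W: int, B: int):
    """Count B x B (query tile, key tile) pairs that touch the band |i - j| <= W.

    Returns (tiles_needed, total_tiles). Tiles that never touch the band
    could be skipped entirely by a block-sparse kernel.
    """
    n_tiles = math.ceil(N / B)
    needed = 0
    for qb in range(n_tiles):
        q_lo, q_hi = qb * B, min(N, (qb + 1) * B) - 1
        for kb in range(n_tiles):
            k_lo, k_hi = kb * B, min(N, (kb + 1) * B) - 1
            gap = max(0, k_lo - q_hi, q_lo - k_hi)   # smallest |i - j| between the two tiles
            if gap <= W:
                needed += 1
    return needed, n_tiles * n_tiles


needed, total = band_tiles(N=4096, W=256, B=128)
print(needed, total, f"{needed / total:.1%} of tiles computed")
```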


Complexity Analysis

Metric                   Standard attention               Sliding window (W fixed)
Score computation        $O(N^2 \cdot d_k)$               $O(N \cdot W \cdot d_k)$
Memory (score matrix)    $O(N^2)$                         $O(N \cdot W)$
Output computation       $O(N^2 \cdot d_v)$               $O(N \cdot W \cdot d_v)$
Total FLOPs              $O(N^2 \cdot d)$                 $O(N \cdot W \cdot d)$
KV-cache (inference)     $O(N \cdot d)$ (grows with N)    $O(W \cdot d)$ (constant)

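The key result: For fixed $W$, sliding window attention is $O(N)$ in both time and memory. Take $N = 32{,}768$ with Mistral-7B's $W = 4096$: a symmetric window as defined in this chapter cuts the score matrix from about 1.07 billion entries ($N^2$) to at most about 268 million ($N(2W + 1)$), roughly a 4× reduction. Mistral's causal window attends only to recent positions, and its rolling KV-cache holds 4,096 instead of 32,768 entries per layer, an 8× reduction in cache memory.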

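Trade-off: The savings grow with the ratio $N / W$. When $N \leq W + 1$, every query's window already covers the whole sequence, so nothing is masked and there is no benefit. Sliding window attention is most useful for long sequences. These bounds also assume a kernel that computes only the band (or its blocks); an implementation that builds the full $N \times N$ matrix and then masks it, like the NumPy one below, shows the semantics but not the savings.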
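The exact number of visible entries, including the shorter rows at the sequence edges, has a simple closed form for $N \geq W$: the full band $N(2W + 1)$ minus the two clipped corner triangles, $W(W + 1)$ in total. A sketch that checks the formula against brute force:

```python
import numpy as np


def visible_entries(N: int, W: int) -> int:
    """Number of (i, j) pairs with |i - j| <= W in an N x N grid."""
    if N <= W:
        return N * N                       # every window covers the whole sequence
    return N * (2 * W + 1) - W * (W + 1)   # full band minus the clipped edge triangles


def visible_entries_brute(N: int, W: int) -> int:
    """Reference count by building the mask explicitly (small N only)."""
    idx = np.arange(N)
    return int((np.abs(idx[:, None] - idx[None, :]) <= W).sum())


N, W = 32_768, 4_096
v = visible_entries(N, W)
print(f"{v:,} of {N * N:,} entries ({N * N / v:.2f}x fewer)")
```

With the edge correction, the Mistral-sized example comes to about 252 million visible entries, slightly below the $N(2W + 1)$ upper bound.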

Python Implementation

A complete, self-contained class run on the shared example. Key lines are annotated below with their execution state (the exact values flowing through them), followed by the full listing.

Sliding Window Attention: NumPy Implementation (sliding_window_attention.py)
import numpy as np

NumPy provides vectorized matrix operations. Q @ K.T runs as optimized C code, not Python loops.

import math

Standard library math module. We use math.sqrt() for the scaling factor.

class SlidingWindowAttention

A self-contained class implementing sliding window attention. The only difference from standard attention (Chapter 1) is the window mask applied before softmax: positions with |i - j| > W are blocked.

def __init__(self, d_k, window_size=1)

Constructor takes d_k (dimension of Q/K vectors) and window_size W. Pre-computes the scaling factor sqrt(d_k).

EXECUTION STATE
⬇ input: d_k = 4 (dimension of each query/key vector)
⬇ input: window_size = 1 (±1 neighbour on each side)
⬆ sets = self.d_k=4, self.W=1, self.scale=2.0
self.d_k = d_k

Store the key dimension for reference.

EXECUTION STATE
self.d_k = 4
self.W = window_size

Store window size. W=1 means each token sees itself and ±1 neighbour = 3 positions max.

EXECUTION STATE
self.W = 1
self.scale = math.sqrt(d_k)

Pre-compute √d_k = √4 = 2.0. Dividing scores by this prevents softmax saturation.

EXECUTION STATE
self.scale = 2.0
def _softmax(self, x) → np.ndarray

Numerically stable softmax that correctly handles -inf values from the window mask. Operates row-by-row: each row sums to 1.0.

EXECUTION STATE
⬇ input: x (5×5) =
       The     cat     sat      on     mat
The  0.000   1.000    -inf    -inf    -inf
cat  1.500   0.000   1.000    -inf    -inf
sat   -inf   1.000   1.000   0.500    -inf
on    -inf    -inf   0.000   1.000   0.500
mat   -inf    -inf    -inf   0.500   0.750
⬆ returns = np.ndarray (5,5) — attention weights, each row sums to 1.0
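x_safe = np.where(np.isfinite(x), x, -1e9)

Replace -inf with -1e9 (np.isfinite() returns False for -inf). This guards against a fully masked row, where np.max() would return -inf and (-inf) - (-inf) would produce NaN. With a sliding window the diagonal is always visible, so every row here has a finite maximum anyway.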

EXECUTION STATE
np.isfinite(x) = True for real numbers, False for -inf entries
-1e9 = -1000000000.0 — large enough that exp(-1e9) ≈ 0
x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)

Subtract the row-wise maximum from each element. This is the numerical stability trick: exp(x - max) avoids overflow while producing identical softmax results.

EXECUTION STATE
axis=-1 = find max along last axis — each row gets its own max
keepdims=True = result has shape (5,1) not (5,), so broadcasting x(5×5) - max(5×1) works
── Row 0 (The) ── =
max = 1.000 (from 'cat' position)
x_shifted[0] = [-1.000, 0.000, -1e9, -1e9, -1e9]
── Row 2 (sat) ── =
max = 1.000 (from 'cat' and 'sat' positions)
x_shifted[2] = [-1e9, 0.000, 0.000, -0.500, -1e9]
── Row 4 (mat) ── =
max = 0.750 (self-attention)
x_shifted[4] = [-1e9, -1e9, -1e9, -0.250, 0.000]
exp_x = np.exp(x_shifted)

Element-wise exponentiation. Blocked positions have x_shifted ≈ -1e9, so exp(-1e9) ≈ 0. The max entry per row has x_shifted = 0, so exp(0) = 1.

EXECUTION STATE
── Row 0 (The) ── =
exp_x[0] = [0.3679, 1.0000, ≈0, ≈0, ≈0]
── Row 2 (sat) ── =
exp_x[2] = [≈0, 1.0000, 1.0000, 0.6065, ≈0]
── Row 4 (mat) ── =
exp_x[4] = [≈0, ≈0, ≈0, 0.7788, 1.0000]
return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

Divide each element by its row sum. Blocked positions end up with weight 0 (exp(-1e9) underflows to 0.0). Each row sums to 1.0 up to floating-point rounding.

EXECUTION STATE
row sums = [1.3679, 1.8297, 2.6065, 1.9744, 1.7788]
⬆ return: weights (5×5) =
       The     cat     sat      on     mat
The  0.2689  0.7311  0.0000  0.0000  0.0000
cat  0.5465  0.1220  0.3315  0.0000  0.0000
sat  0.0000  0.3837  0.3837  0.2327  0.0000
on   0.0000  0.0000  0.1863  0.5065  0.3072
mat  0.0000  0.0000  0.0000  0.4378  0.5622
def build_window_mask(self, N) → np.ndarray

Builds the boolean window mask. True means the position is BLOCKED (|i - j| > W). This is the ONLY difference from standard attention.

EXECUTION STATE
⬇ input: N = 5 (number of tokens)
self.W = 1 (window size)
⬆ returns = np.ndarray (5,5) of bool — True where |i - j| > W
idx = np.arange(N)

Create position indices [0, 1, 2, 3, 4] for the 5 tokens.

EXECUTION STATE
idx = [0, 1, 2, 3, 4]
dist = np.abs(idx[:, None] - idx[None, :])

Compute the absolute distance matrix. idx[:, None] has shape (5,1) and idx[None,:] has shape (1,5) — broadcasting gives a (5,5) distance matrix.

EXECUTION STATE
idx[:, None] = shape (5,1) — column vector [0,1,2,3,4]ᵀ
idx[None, :] = shape (1,5) — row vector [0,1,2,3,4]
dist (5×5) =
     The  cat  sat   on  mat
The    0    1    2    3    4
cat    1    0    1    2    3
sat    2    1    0    1    2
on     3    2    1    0    1
mat    4    3    2    1    0
return dist > self.W

Boolean comparison: True where distance exceeds window W=1. These positions will be set to -inf before softmax.

EXECUTION STATE
⬆ return: mask (5×5) =
       The    cat    sat     on    mat
The  False  False   True   True   True
cat  False  False  False   True   True
sat   True  False  False  False   True
on    True   True  False  False  False
mat   True   True   True  False  False
def forward(self, Q, K, V)

Full forward pass: raw scores → scale → window mask → softmax → weighted sum. Identical to standard attention except for the window mask step.

EXECUTION STATE
⬇ input: Q (5×4) =
       d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
       d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
       d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = (weights (5×5), output (5×4))
N = Q.shape[0]

Number of tokens in the sequence.

EXECUTION STATE
N = 5
raw_scores = Q @ K.T

Matrix multiply Q (5×4) with K transposed (4×5). Each entry [i,j] is the dot product of query i with key j.

EXECUTION STATE
K.T = K transposed from (5×4) to (4×5)
raw_scores (5×5) =
      The   cat   sat    on   mat
The  0.00  2.00  1.00  1.00  1.50
cat  3.00  0.00  2.00  1.00  0.50
sat  1.00  2.00  2.00  1.00  1.50
on   1.00  1.00  0.00  2.00  1.00
mat  1.00  1.00  1.00  1.00  1.50
scaled_scores = raw_scores / self.scale

Divide every score by sqrt(d_k) = 2.0. This keeps variance at ~1, preventing softmax saturation.

EXECUTION STATE
self.scale = 2.0
scaled_scores (5×5) =
       The     cat     sat      on     mat
The  0.000   1.000   0.500   0.500   0.750
cat  1.500   0.000   1.000   0.500   0.250
sat  0.500   1.000   1.000   0.500   0.750
on   0.500   0.500   0.000   1.000   0.500
mat  0.500   0.500   0.500   0.500   0.750
mask = self.build_window_mask(N)

Build the 5×5 boolean mask. True where |i-j| > 1.

EXECUTION STATE
mask (5×5) =
       The    cat    sat     on    mat
The  False  False   True   True   True
cat  False  False  False   True   True
sat   True  False  False  False   True
on    True   True  False  False  False
mat   True   True   True  False  False
masked_scores = scaled_scores.copy()

Copy to avoid modifying the original. The original scores are preserved.

masked_scores[mask] = -np.inf

Boolean indexing: everywhere mask is True (|i-j| > W), set score to -inf. This is the CORE OPERATION of sliding window attention.

EXECUTION STATE
masked_scores (5×5) =
       The     cat     sat      on     mat
The  0.000   1.000    -inf    -inf    -inf
cat  1.500   0.000   1.000    -inf    -inf
sat   -inf   1.000   1.000   0.500    -inf
on    -inf    -inf   0.000   1.000   0.500
mat   -inf    -inf    -inf   0.500   0.750
weights = self._softmax(masked_scores)

Softmax over the masked scores. Blocked positions get weight 0.0000. Each row sums to 1.0.

EXECUTION STATE
weights (5×5) =
       The     cat     sat      on     mat
The  0.2689  0.7311  0.0000  0.0000  0.0000
cat  0.5465  0.1220  0.3315  0.0000  0.0000
sat  0.0000  0.3837  0.3837  0.2327  0.0000
on   0.0000  0.0000  0.1863  0.5065  0.3072
mat  0.0000  0.0000  0.0000  0.4378  0.5622
output = weights @ V

Multiply weights (5×5) by V (5×4). Each output row is a weighted sum of value vectors from ONLY visible (within-window) tokens.

EXECUTION STATE
output (5×4) =
       d0      d1      d2      d3
The  0.2689  0.7311  0.0000  0.0000
cat  0.5465  0.1220  0.3315  0.0000
sat  0.0000  0.3837  0.3837  0.2327
on   0.1536  0.1536  0.3399  0.6601
mat  0.2811  0.2811  0.2811  0.7189
return weights, output

Return both weights (for visualization) and output (fed to next layer).

tokens = ["The", "cat", "sat", "on", "mat"]

The shared example sentence used throughout all 15 chapters. 5 tokens with d_k = 4.

EXECUTION STATE
tokens = ["The", "cat", "sat", "on", "mat"]
N = 5 tokens
Q = np.array([...])

Query matrix (5×4). Each row is what that token is 'asking for' — the features it wants to find in the sequence.

EXECUTION STATE
Q (5×4) =
       d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
K = np.array([...])

Key matrix (5×4). Each row is what that token 'advertises' about itself.

EXECUTION STATE
K (5×4) =
       d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
V = np.array([...])

Value matrix (5×4). Each row is the content that token contributes when selected. Q and K decide WHO to attend to; V decides WHAT to retrieve.

EXECUTION STATE
V (5×4) =
       d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
attn = SlidingWindowAttention(d_k=4, window_size=1)

Instantiate with d_k=4 and W=1. Sets self.scale = sqrt(4) = 2.0.

EXECUTION STATE
attn.d_k = 4
attn.W = 1
attn.scale = 2.0
weights, output = attn.forward(Q, K, V)

Run the full sliding window attention forward pass. Returns 5×5 weight matrix and 5×4 output.

EXECUTION STATE
weights (5×5) =
       The     cat     sat      on     mat
The  0.2689  0.7311  0.0000  0.0000  0.0000
cat  0.5465  0.1220  0.3315  0.0000  0.0000
sat  0.0000  0.3837  0.3837  0.2327  0.0000
on   0.0000  0.0000  0.1863  0.5065  0.3072
mat  0.0000  0.0000  0.0000  0.4378  0.5622
output (5×4) =
       d0      d1      d2      d3
The  0.2689  0.7311  0.0000  0.0000
cat  0.5465  0.1220  0.3315  0.0000
sat  0.0000  0.3837  0.3837  0.2327
on   0.1536  0.1536  0.3399  0.6601
mat  0.2811  0.2811  0.2811  0.7189
attn.explain(Q, K, V, tokens, query_idx=2)

Print a detailed trace for 'sat' (row 2). Shows visible/blocked tokens and the full computation.

EXECUTION STATE
query_idx = 2 → 'sat'
visible = ['cat', 'sat', 'on'] (j=1,2,3)
blocked = ['The', 'mat'] (j=0,4)
1import numpy as np
2import math
3
4class SlidingWindowAttention:
5    """
6    Sliding Window Attention (Beltagy et al., 2020)
7
8    Each token attends only to its W nearest neighbours on
9    each side, reducing complexity from O(N^2) to O(N * W).
10    Positions outside the window receive score = -inf so
11    softmax gives them exactly zero weight.
12    """
13
14    def __init__(self, d_k: int, window_size: int = 1):
15        """
16        Args:
17            d_k:         Dimension of query/key vectors
18            window_size: W — each token sees +/-W neighbours
19        """
20        self.d_k = d_k
21        self.W = window_size
22        self.scale = math.sqrt(d_k)
23
24    def _softmax(self, x: np.ndarray) -> np.ndarray:
25        """Numerically stable softmax that handles -inf."""
26        x_safe = np.where(np.isfinite(x), x, -1e9)
27        x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)
28        exp_x = np.exp(x_shifted)
29        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
30
31    def build_window_mask(self, N: int) -> np.ndarray:
32        """Build boolean mask: True where |i - j| > W (blocked)."""
33        idx = np.arange(N)
34        dist = np.abs(idx[:, None] - idx[None, :])
35        return dist > self.W
36
37    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
38        """
39        Full forward pass.
40
41        Args:
42            Q: Query matrix  (N, d_k)
43            K: Key matrix    (N, d_k)
44            V: Value matrix  (N, d_v)
45
46        Returns:
47            weights: Attention weight matrix (N, N)
48            output:  Context-enriched output  (N, d_v)
49        """
50        N = Q.shape[0]
51        raw_scores = Q @ K.T
52        scaled_scores = raw_scores / self.scale
53        mask = self.build_window_mask(N)
54        masked_scores = scaled_scores.copy()
55        masked_scores[mask] = -np.inf
56        weights = self._softmax(masked_scores)
57        output = weights @ V
58        return weights, output
59
60    def explain(self, Q, K, V, tokens, query_idx=0):
61        """Print a detailed trace for a specific query token."""
62        N = Q.shape[0]
63        raw = Q @ K.T
64        scaled = raw / self.scale
65        mask = self.build_window_mask(N)
66
        t = tokens[query_idx]
        visible = [j for j in range(N) if not mask[query_idx, j]]
        blocked = [j for j in range(N) if mask[query_idx, j]]

        print(f"\n=== Sliding window trace for '{t}' (row {query_idx}, W={self.W}) ===")
        print(f"Q[{query_idx}] = {Q[query_idx]}")
        print(f"Visible: {[tokens[j] for j in visible]} (|i-j| <= {self.W})")
        print(f"Blocked: {[tokens[j] for j in blocked]} (|i-j| > {self.W})")

        vis_scores = [scaled[query_idx, j] for j in visible]
        exps = [math.exp(s) for s in vis_scores]
        total = sum(exps)
        ws = [e / total for e in exps]

        print(f"\nScaled scores (visible only):")
        for idx, j in enumerate(visible):
            print(f"  S[{t},{tokens[j]}] = {vis_scores[idx]:.4f}")

        print(f"\nSoftmax weights:")
        for idx, j in enumerate(visible):
            bar = '#' * int(ws[idx] * 40)
            print(f"  A[{t},{tokens[j]}] = {ws[idx]:.4f} |{bar}|")

        out = sum(ws[idx] * V[j] for idx, j in enumerate(visible))
        print(f"\nOutput O[{t}] = {np.round(out, 4)}")


# ── Shared Example (same Q, K, V as every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

# ── Run ──
attn = SlidingWindowAttention(d_k=4, window_size=1)
weights, output = attn.forward(Q, K, V)

print("Sliding Window Attention Weights (W=1):")
print(np.round(weights, 4))

print("\nSliding Window Output (W=1):")
print(np.round(output, 4))

# Detailed trace for "sat"
attn.explain(Q, K, V, tokens, query_idx=2)
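The explain() trace normalises the visible scores with a bare math.exp. That is fine for this toy example, but math.exp raises OverflowError for arguments above about 709.78, so very large scores would break it. A common remedy is to subtract the row maximum before exponentiating; the weights are unchanged because the common factor cancels in the ratio. The helper below (stable_softmax is a hypothetical name, not part of the chapter's class) is a minimal plain-Python sketch of that idea, applied to the visible scores of "sat":

```python
import math

def stable_softmax(scores):
    """Softmax over a list of finite scores, shifted by the row maximum."""
    m = max(scores)                             # largest score in the row
    exps = [math.exp(s - m) for s in scores]    # every exponent is <= 0, so no overflow
    total = sum(exps)
    return [e / total for e in exps]

# Visible scaled scores for "sat" with W=1 (cat, sat, on), from the scores table
sat_weights = stable_softmax([1.0, 1.0, 0.5])
print([round(w, 4) for w in sat_weights])       # [0.3837, 0.3837, 0.2327]

# math.exp(1000.0) on its own raises OverflowError; after the shift this is safe
print(stable_softmax([1000.0, 999.5]))
```

The first print reproduces the "sat" row of the W=1 weight table; the second shows the shifted version handling scores that the unshifted trace could not.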

PyTorch Implementation

The PyTorch version supports GPU acceleration and automatic differentiation, and it plugs into standard training loops. The mask is applied with masked_fill, which replaces blocked scores with -inf in a single GPU-friendly operation.

Sliding Window Attention — PyTorch Implementation
🐍sliding_window_attention_torch.py
import torch

PyTorch provides GPU-accelerated tensor operations and autograd for backpropagation.

import torch.nn as nn

Neural network module. nn.Module is the base class for all PyTorch models.

import torch.nn.functional as F

Functional API: F.softmax, F.scaled_dot_product_attention.

import math

For math.sqrt(d_k), the scaling factor.

class SlidingWindowAttention(nn.Module)

PyTorch module for sliding window attention. Inherits nn.Module for parameter tracking, GPU support, and integration with training loops.

def __init__(self, d_model, window_size=1)

Initialize with model dimension and window size. In production, the module would also create learned W_Q, W_K, W_V projections.

EXECUTION STATE
⬇ input: d_model = 4
⬇ input: window_size = 1
self.d_model = d_model

Store model dimension.

EXECUTION STATE
self.d_model = 4
self.window_size = window_size

Store window size W.

EXECUTION STATE
self.window_size = 1
self.scale = math.sqrt(d_model)

Pre-compute scaling factor.

EXECUTION STATE
self.scale = 2.0 (√4)
def build_window_mask(self, N, device) → torch.Tensor

Build boolean mask on the correct device (CPU or GPU). True means blocked (|i-j| > W).

EXECUTION STATE
⬇ input: N = 5
⬇ input: device = cpu (or cuda:0 for GPU)
idx = torch.arange(N, device=device)

Position indices on the specified device.

EXECUTION STATE
idx = tensor([0, 1, 2, 3, 4])
dist = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs()

Compute pairwise absolute distance matrix via broadcasting. unsqueeze(1) creates a column, unsqueeze(0) a row.

EXECUTION STATE
.unsqueeze(1) = reshape (5,) → (5,1) — column vector
.unsqueeze(0) = reshape (5,) → (1,5) — row vector
dist (5×5) =
     The  cat  sat   on  mat
The    0    1    2    3    4
cat    1    0    1    2    3
sat    2    1    0    1    2
on     3    2    1    0    1
mat    4    3    2    1    0
return dist > self.window_size

Boolean comparison. True where distance exceeds W=1.

EXECUTION STATE
⬆ return: mask (5×5) =
       The    cat    sat     on    mat
The  False  False   True   True   True
cat  False  False  False   True   True
sat   True  False  False  False   True
on    True   True  False  False  False
mat   True   True   True  False  False
def forward(self, Q, K, V) → tuple

Forward pass: scaled dot-product with window masking.

EXECUTION STATE
⬇ input: Q = torch.Tensor (5, 4)
⬇ input: K = torch.Tensor (5, 4)
⬇ input: V = torch.Tensor (5, 4)
N = Q.size(-2)

Number of tokens.

EXECUTION STATE
N = 5
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

Scaled dot-product scores. K.transpose(-2,-1) swaps last two dims.

EXECUTION STATE
.transpose(-2, -1) = swap last two dimensions: (N,d) → (d,N). Equivalent to .T for 2D.
scores (5×5) =
       The     cat     sat      on     mat
The  0.000   1.000   0.500   0.500   0.750
cat  1.500   0.000   1.000   0.500   0.250
sat  0.500   1.000   1.000   0.500   0.750
on   0.500   0.500   0.000   1.000   0.500
mat  0.500   0.500   0.500   0.500   0.750
mask = self.build_window_mask(N, Q.device)

Build mask on the same device as Q (CPU or GPU).

scores = scores.masked_fill(mask, float("-inf"))

PyTorch masked_fill: where mask is True, replace with -inf. Cleaner than boolean indexing.

EXECUTION STATE
.masked_fill(mask, val) = where mask is True, replace with val. Non-masked positions unchanged.
scores after masking (5×5) =
       The     cat     sat      on     mat
The  0.000   1.000    -inf    -inf    -inf
cat  1.500   0.000   1.000    -inf    -inf
sat   -inf   1.000   1.000   0.500    -inf
on    -inf    -inf   0.000   1.000   0.500
mat   -inf    -inf    -inf   0.500   0.750
weights = F.softmax(scores, dim=-1)

Apply softmax along last dimension. -inf positions get weight 0.

EXECUTION STATE
dim=-1 = normalise over the last (key) dimension, i.e. across the columns of each row, so each row sums to 1.0.
weights (5×5) =
       The     cat     sat      on     mat
The  0.2689  0.7311  0.0000  0.0000  0.0000
cat  0.5465  0.1220  0.3315  0.0000  0.0000
sat  0.0000  0.3837  0.3837  0.2327  0.0000
on   0.0000  0.0000  0.1863  0.5065  0.3072
mat  0.0000  0.0000  0.0000  0.4378  0.5622
output = torch.matmul(weights, V)

Weighted sum of value vectors from visible tokens only.

EXECUTION STATE
output (5×4) =
       d0      d1      d2      d3
The  0.2689  0.7311  0.0000  0.0000
cat  0.5465  0.1220  0.3315  0.0000
sat  0.0000  0.3837  0.3837  0.2327
on   0.1536  0.1536  0.3399  0.6601
mat  0.2811  0.2811  0.2811  0.7189
return output, weights

Return (output, weights) tuple.

def manual(Q, K, V, window_size=1)

Static method for running sliding window attention with pre-computed Q, K, V. Used for our shared example.

EXECUTION STATE
⬇ input: Q = torch.Tensor (5, 4)
⬇ input: window_size = 1 (default)
d_k = K.size(-1)

Infer d_k from key matrix.

EXECUTION STATE
d_k = 4
N = Q.size(-2)

Number of tokens.

EXECUTION STATE
N = 5
scale = math.sqrt(d_k)

Scaling factor.

EXECUTION STATE
scale = 2.0
scores = Q @ K.transpose(-2, -1) / scale

Scaled dot-product scores.

idx = torch.arange(N, device=Q.device)

Position indices [0,1,2,3,4].

dist = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs()

5×5 distance matrix.

mask = dist > window_size

Boolean mask: True where |i-j| > W.

scores = scores.masked_fill(mask, float("-inf"))

Apply window mask.

weights = F.softmax(scores, dim=-1)

Softmax with blocked positions zeroed.

output = weights @ V

Final weighted sum.

return output, weights

Return results.

tokens = ["The", "cat", "sat", "on", "mat"]

Shared example sentence.

Q = torch.tensor([...])

Query matrix as PyTorch tensor. Same values as NumPy version.

K = torch.tensor([...])

Key matrix.

V = torch.tensor([...])

Value matrix.

output, weights = SlidingWindowAttention.manual(Q, K, V, window_size=1)

Run sliding window attention with W=1. Returns output (5×4) and weights (5×5).

EXECUTION STATE
output (5×4) =
       d0      d1      d2      d3
The  0.2689  0.7311  0.0000  0.0000
cat  0.5465  0.1220  0.3315  0.0000
sat  0.0000  0.3837  0.3837  0.2327
on   0.1536  0.1536  0.3399  0.6601
mat  0.2811  0.2811  0.2811  0.7189
for W in [0, 1, 2, 4]:

Compare different window sizes. Shows how sparsity decreases as W grows.

LOOP TRACE · 4 iterations
W=0 (self-only)
active connections = 5/25 (80% reduction; each token sees only itself)
W=1 (±1 neighbour)
active connections = 13/25 (48% reduction)
W=2 (±2 neighbours)
active connections = 19/25 (24% reduction)
W=4 (full — all 5 tokens visible)
active connections = 25/25 (0% reduction = full attention)
if torch.cuda.is_available():

GPU acceleration. Move the module and tensors to the GPU with .cuda(). The mask and softmax code is unchanged: build_window_mask creates the mask on Q.device, so it is allocated on the GPU automatically.
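The active-connection counts in the loop trace also follow from a closed form. Query $i$ sees every key with $|i - j| \le W$, so each query sees at most $2W + 1$ keys, and tokens near either end lose some of those neighbours off the edge of the sequence. Summing over all queries gives $N(2W + 1) - W(W + 1)$ for $W \le N - 1$, which grows linearly in $N$ for fixed $W$. A small plain-Python sketch checking the formula against brute-force enumeration (both function names are illustrative, not part of the chapter's class):

```python
def active_connections_brute(n, w):
    """Count query/key pairs (i, j) with |i - j| <= w by direct enumeration."""
    return sum(1 for i in range(n) for j in range(n) if abs(i - j) <= w)

def active_connections_closed_form(n, w):
    """n(2w + 1) - w(w + 1); windows of n - 1 or more already cover every row."""
    w = min(w, n - 1)
    return n * (2 * w + 1) - w * (w + 1)

for w in [0, 1, 2, 4]:
    print(f"W={w}: {active_connections_closed_form(5, w)}/25 active connections")
# W=0: 5/25, W=1: 13/25, W=2: 19/25, W=4: 25/25

# Illustrative long sequence (W=256 is an arbitrary choice): active fraction of N x N
n, w = 16_384, 256
print(active_connections_closed_form(n, w) / n**2)
```

For the 16,384-token case the last line prints roughly 0.031, i.e. about 3% of the full score matrix.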

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SlidingWindowAttention(nn.Module):
    """
    Sliding Window Attention in PyTorch (Beltagy et al., 2020)

    Supports GPU, automatic differentiation, and batched inputs.
    Uses masked_fill for efficient GPU-friendly masking.
    """

    def __init__(self, d_model: int, window_size: int = 1):
        super().__init__()
        self.d_model = d_model
        self.window_size = window_size
        self.scale = math.sqrt(d_model)

        # In production: learned projection matrices
        # self.W_Q = nn.Linear(d_model, d_model, bias=False)
        # self.W_K = nn.Linear(d_model, d_model, bias=False)
        # self.W_V = nn.Linear(d_model, d_model, bias=False)

    def build_window_mask(self, N: int, device: torch.device) -> torch.Tensor:
        """Build boolean mask: True where |i - j| > W."""
        idx = torch.arange(N, device=device)
        dist = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs()
        return dist > self.window_size

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            Q: (..., N, d_model), leading batch dimensions optional
            K: (..., N, d_model)
            V: (..., N, d_model)

        Returns:
            output:  (..., N, d_model)
            weights: (..., N, N)
        """
        N = Q.size(-2)  # sequence length; size(0) would be the batch size for batched input
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        mask = self.build_window_mask(N, Q.device)  # (N, N), broadcasts over batch dims
        scores = scores.masked_fill(mask, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)
        return output, weights

    @staticmethod
    def manual(Q, K, V, window_size=1):
        """Run with pre-computed Q, K, V (no learned weights)."""
        d_k = K.size(-1)
        N = Q.size(-2)
        scale = math.sqrt(d_k)

        scores = Q @ K.transpose(-2, -1) / scale
        idx = torch.arange(N, device=Q.device)  # build the mask on Q's device
        dist = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs()
        mask = dist > window_size
        scores = scores.masked_fill(mask, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        output = weights @ V
        return output, weights


# ── Shared Example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])

K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])

V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

# ── Run ──
output, weights = SlidingWindowAttention.manual(Q, K, V, window_size=1)

print("Sliding Window Weights (W=1):")
print(weights.round(decimals=4))

print("\nSliding Window Output (W=1):")
print(output.round(decimals=4))

# Compare different window sizes
for W in [0, 1, 2, 4]:
    out, w = SlidingWindowAttention.manual(Q, K, V, window_size=W)
    nonzero = (w > 0.001).sum().item()
    print(f"W={W}: {nonzero}/25 active connections")

# GPU acceleration (if available)
if torch.cuda.is_available():
    swa = SlidingWindowAttention(d_model=4, window_size=1).cuda()
    out_gpu, _ = swa(Q.cuda(), K.cuda(), V.cuda())
    print("GPU output matches:", torch.allclose(output, out_gpu.cpu(), atol=1e-4))
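Both forward and manual in the PyTorch listing compute the full $N \times N$ score matrix and only then mask it, so their cost is still $O(N^2)$: the mask changes what each token attends to, not how much work is done. Realising the $O(N \times W)$ cost means computing only the at most $2W + 1$ scores inside each query's window. The NumPy sketch below does this with a simple per-query loop, assuming unbatched 2-D inputs; banded_attention and dense_masked_attention are hypothetical helper names, and production systems use specialised GPU kernels rather than a Python loop. It checks that the banded version reproduces the dense masked result on the shared example.

```python
import numpy as np

def banded_attention(Q, K, V, window_size):
    """Sliding window attention that only scores keys with |i - j| <= window_size.

    Each query does O(W * d) work instead of O(N * d), so the total is O(N * W * d).
    """
    n, d_k = Q.shape
    scale = np.sqrt(d_k)
    out = np.zeros((n, V.shape[1]))
    for i in range(n):
        lo, hi = max(0, i - window_size), min(n, i + window_size + 1)  # clipped window
        s = K[lo:hi] @ Q[i] / scale           # at most 2W+1 dot products
        w = np.exp(s - s.max())               # max-shifted softmax over the window
        w /= w.sum()
        out[i] = w @ V[lo:hi]                 # weighted sum of in-window values
    return out

def dense_masked_attention(Q, K, V, window_size):
    """Reference: full N x N scores, then block |i - j| > W with -inf."""
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    idx = np.arange(n)
    scores[np.abs(idx[:, None] - idx[None, :]) > window_size] = -np.inf
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

# Shared example from the chapter
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]], dtype=float)

for W in [0, 1, 2, 4]:
    assert np.allclose(banded_attention(Q, K, V, W), dense_masked_attention(Q, K, V, W))
print(np.round(banded_attention(Q, K, V, 1), 4))
```

The printed array should match the W=1 output table from the walkthrough. The per-query loop keeps the operation count honest for illustration, but in Python it is slow; vectorised or kernel-level implementations gather the band of keys for many queries at once.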

Key Takeaways

  1. One mask, one change: Sliding window attention is standard attention with a single modification, namely that positions where $|i - j| > W$ are set to $-\infty$ before softmax.
  2. Linear complexity: For fixed $W$, computing only the in-window scores costs $O(N \times W)$, which is linear in sequence length $N$. The dense-mask code above still materialises all $N^2$ scores, so it demonstrates the attention pattern rather than the savings.
  3. Locality is often sufficient: Many linguistic, code, and genomic dependencies are local, so the window captures most of the useful attention weight.
  4. Layer stacking recovers range: After $L$ layers, information can propagate up to $L \times W$ positions in each direction, so once $L \times W \ge N - 1$ every token can influence every other despite purely local attention.
  5. KV-cache savings: In autoregressive models like Mistral, the rolling KV-cache stores only $W$ entries per layer, so cache memory stays constant however long the sequence grows.
  6. Composable with other techniques: Sliding window combines naturally with Flash Attention (block sparsity), global tokens (Longformer), and causal masking (Mistral).
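Takeaway 4 can be checked by treating one layer's window as a boolean reachability matrix $A$ with $A_{ij} = 1$ when $|i - j| \le W$. Token $j$ can influence token $i$'s output after $L$ layers exactly when there is a path of $L$ hops from $i$ to $j$ (self-loops on the diagonal let shorter paths pad out to $L$), i.e. when the boolean $L$-th power of $A$ is nonzero at $(i, j)$. The NumPy sketch below assumes every in-window connection carries some weight, so it measures the maximum possible receptive field rather than what a trained model actually uses; window_reach and receptive_field are illustrative names.

```python
import numpy as np

def window_reach(n, window_size):
    """One layer: entry (i, j) is True when token i can attend to token j."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= window_size

def receptive_field(n, window_size, num_layers):
    """Boolean reachability after stacking num_layers sliding window layers."""
    a = window_reach(n, window_size).astype(int)
    reach = np.eye(n, dtype=int)               # zero layers: each token sees itself
    for _ in range(num_layers):
        reach = (reach @ a > 0).astype(int)    # extend every path by one window hop
    return reach.astype(bool)

n, W = 64, 2
i = 32                                         # a token far from both ends
for L in [1, 2, 3]:
    visible = np.flatnonzero(receptive_field(n, W, L)[i])
    print(f"L={L}: token {i} sees positions {visible.min()}..{visible.max()} "
          f"(L*W bound: {i - L * W}..{i + L * W})")
```

For this interior token the observed span matches the $[i - LW, i + LW]$ bound at every depth; near the sequence edges the span is simply clipped to $[0, N - 1]$.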
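The rolling KV-cache in takeaway 5 behaves like a fixed-size ring buffer. The Mistral 7B paper (Jiang et al., 2023) describes storing the keys and values for position $i$ in slot $i \bmod W$, so once the buffer is full each new token overwrites the oldest entry. The sketch below illustrates only that indexing scheme for a single head with plain Python lists; RollingKVCache and its methods are hypothetical names, not Mistral's implementation, and here $W$ denotes the total cache size rather than the per-side window used elsewhere in this chapter.

```python
class RollingKVCache:
    """Fixed-size key/value cache: the entry for position t lives in slot t % window."""

    def __init__(self, window):
        self.window = window
        self.keys = [None] * window    # storage never grows beyond `window` slots
        self.values = [None] * window
        self.length = 0                # total number of tokens appended so far

    def append(self, key, value):
        slot = self.length % self.window   # once full, this overwrites the oldest entry
        self.keys[slot] = key
        self.values[slot] = value
        self.length += 1

    def visible_positions(self):
        """Absolute positions still held in the cache, oldest first."""
        start = max(0, self.length - self.window)
        return list(range(start, self.length))

    def items(self):
        """(position, key, value) triples in chronological order."""
        return [(p, self.keys[p % self.window], self.values[p % self.window])
                for p in self.visible_positions()]


cache = RollingKVCache(window=3)
for t in range(5):
    cache.append(key=[float(t)], value=[10.0 * t])

print(cache.visible_positions())          # [2, 3, 4]
print([k for _, k, _ in cache.items()])   # [[2.0], [3.0], [4.0]]
print(len(cache.keys))                    # 3
```

After five tokens only the last three survive, and storage stays at three slots no matter how many tokens are appended; a real decoder would keep one such buffer per layer (and per head) holding key and value vectors.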

Exercises

  1. Vary the window: Using the Python class above, compute attention weights for $W = 0$ (self-only), $W = 2$, and $W = 4$ (full). Verify that $W = 4$ produces the same weights as Chapter 1's standard attention.
  2. Causal + sliding window: Modify the class to combine the causal mask ($j > i \Rightarrow -\infty$) with the sliding window mask ($|i - j| > W \Rightarrow -\infty$). This combination is what Mistral-7B uses. How many active connections remain for $N = 5$, $W = 1$?
  3. Receptive field proof: Prove by induction that after $L$ layers of window-$W$ attention, token $i$'s output depends only on tokens in the range $[i - LW, i + LW]$ (clipped to the sequence boundaries).
  4. Memory calculation: A model has 32 layers, 32 heads, $d_k = 128$, and processes $N = 32{,}768$ tokens. Calculate the total score matrix memory in FP16 for (a) full attention and (b) sliding window with $W = 4096$. What is the ratio?
  5. Dilated window: Instead of contiguous positions, consider a dilated window where token $i$ attends to positions $\{i - 2W, i - W, i, i + W, i + 2W\}$. Implement this and compare the output to the standard sliding window.

References

  1. Beltagy, I., Peters, M. E., & Cohan, A. (2020). Longformer: The Long-Document Transformer. arXiv:2004.05150.
  2. Child, R., Gray, S., Radford, A., & Sutskever, I. (2019). Generating Long Sequences with Sparse Transformers. arXiv:1904.10509.
  3. Jiang, A. Q., Sablayrolles, A., Mensch, A., et al. (2023). Mistral 7B. arXiv:2310.06825.
  4. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
  5. Zaheer, M., Guruganesh, G., Dubey, A., et al. (2020). Big Bird: Transformers for Longer Sequences. NeurIPS 2020.
  6. Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). Attention Is All You Need. NeurIPS 2017.