Chapter 12

Sparse Attention: BigBird

Learning Objectives

By the end of this chapter, you will:

  1. Understand why sliding window attention alone fails for tasks requiring long-range dependencies, and how BigBird solves this with a three-pillar sparse attention pattern.
  2. Learn the graph-theoretic result that makes BigBird a universal approximator of sequence-to-sequence functions, like full attention, yet with $O(N)$ attention connections per layer instead of $O(N^2)$.
  3. Master the mathematical formulation of the BigBird sparse mask, including local window, global tokens, and random edges.
  4. Compute a complete worked example by hand using “The cat sat on mat”, verifying every matrix value.
  5. Implement BigBird sparse attention from scratch in both NumPy and PyTorch, ready to plug into real transformer architectures.
Why This Matters: BigBird was among the first sparse attention variants with a proof that it keeps the expressive power of full attention (universal approximation and Turing completeness) while using a linear number of attention connections. It is used for long-document NLP, and its local-plus-global design influenced later long-context encoders such as LongT5 and Pegasus-X. Understanding BigBird teaches you to think about attention as a graph problem, a perspective that underlies much of the subsequent efficient-attention research.

The Real Problem: When Local Windows Are Not Enough

In Chapter 11, we saw that Sliding Window Attention reduces complexity from $O(N^2)$ to $O(N \times W)$ by restricting each token to its $\pm W$ nearest neighbors. This works beautifully for local syntax: a verb is usually near its subject, an adjective near its noun. But real language is not purely local.

Consider a 10,000-token legal document. The opening paragraph states: “The Seller agrees to indemnify the Buyer.” Nine thousand tokens later, a clause reads: “In breach of the foregoing obligation, penalties shall apply.” The word “foregoing” at position 9,500 must attend to “indemnify” at position 42. With a window of $W = 256$, information would need to propagate through $\lceil 9458 / 256 \rceil = 37$ layers before these tokens can interact. That is more layers than a typical 12- or 24-layer encoder has, and the signal would be diluted at every intermediate hop.

Why Sliding Window Fails

Think of the attention mask as a graph where tokens are nodes and allowed attention connections are edges. In sliding window attention with $W = 1$, the graph looks like a chain:

The ↔ cat ↔ sat ↔ on ↔ mat

The diameter of this graph (the longest shortest path between any two nodes) is $N - 1$. For our 5-token example, that means “The” and “mat” need 4 hops (layers) to exchange information. For $N = 4{,}096$ tokens with $W = 1$, the diameter is 4,095, requiring thousands of transformer layers for end-to-end communication.

| Attention pattern | Graph diameter | Hops for $N = 4096$ |
|---|---|---|
| Full attention | 1 | 1 |
| Sliding window ($W = 1$) | $N - 1$ | 4,095 |
| Sliding window ($W = 256$) | $\lceil (N-1)/W \rceil$ | 16 |
| BigBird (local + global) | 2 | 2 |
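The diameters in this table can be checked numerically. The sketch below treats a boolean mask as a graph and runs breadth-first search from every node; the helper names `band_mask` and `graph_diameter` are illustrative, not from any library. It uses $N = 64$ instead of 4,096 so the all-pairs search finishes instantly, and the closed-form values scale the same way.

```python
import math
from collections import deque

import numpy as np


def band_mask(n: int, w: int, global_idx=()) -> np.ndarray:
    """Local window |i - j| <= w, plus full rows/columns for global tokens."""
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w
    for g in global_idx:
        mask[g, :] = True
        mask[:, g] = True
    return mask


def graph_diameter(mask: np.ndarray) -> int:
    """Longest shortest path (in hops) over all node pairs, via BFS from every node."""
    n = mask.shape[0]
    neighbors = [np.flatnonzero(row) for row in mask]
    diameter = 0
    for src in range(n):
        dist = np.full(n, -1)
        dist[src] = 0
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in neighbors[u]:
                if dist[v] == -1:          # not visited yet
                    dist[v] = dist[u] + 1
                    queue.append(v)
        diameter = max(diameter, int(dist.max()))  # all graphs here are connected
    return diameter


n = 64
print("full attention  :", graph_diameter(band_mask(n, n - 1)))           # 1
print("window W=1      :", graph_diameter(band_mask(n, 1)))               # 63 = N - 1
print("window W=8      :", graph_diameter(band_mask(n, 8)),
      "vs ceil((N-1)/W) =", math.ceil((n - 1) / 8))                         # 8 and 8
print("W=1 + global [0]:", graph_diameter(band_mask(n, 1, global_idx=[0])))  # 2
```

Adding a single global row and column collapses the 63-hop chain to 2 hops, which is the observation the next subsection builds on.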

The Graph-Theory Insight

Zaheer et al. (NeurIPS 2020) made a crucial observation: if you add even a single global token that connects to every other token, the graph diameter drops to 2, regardless of sequence length. Any token can reach any other token in at most 2 hops: go to the global token, then from the global token to the target.

This is the same principle behind hub-and-spoke networks in airline routing. You don't need direct flights between every pair of cities. A few major hubs (Chicago, Atlanta, Dallas) let you reach any destination in at most 2 flights.

The Analogy: Full attention is like having a direct flight between every pair of 4,096 cities: $4096^2 \approx 16.8\text{M}$ routes. BigBird is like keeping local bus routes (window), adding a few airport hubs (global tokens), and adding occasional charter flights (random edges): only $\sim 2\text{M}$ routes, yet you can still reach anywhere in 2 hops.

BigBird's Three-Pillar Solution

BigBird constructs its sparse attention mask by combining three independent patterns, each serving a distinct purpose:

Pillar 1: Local Window Attention

Identical to Chapter 11's sliding window. Token $i$ attends to all tokens $j$ where $|i - j| \leq W$. This captures the syntactic dependencies that are overwhelmingly local: subject-verb agreement, adjective-noun modification, preposition attachment.

In our example with $W = 1$, “sat” (position 2) attends to “cat” (1), “sat” (2), and “on” (3). This captures the core phrase structure “cat sat on”.
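Pillar 1 can be built without a Python loop. A minimal NumPy sketch, assuming the running example's five tokens and $W = 1$; `local_mask` is a hypothetical helper name:

```python
import numpy as np


def local_mask(n: int, w: int) -> np.ndarray:
    """Boolean N x N mask with True where |i - j| <= w (Pillar 1)."""
    idx = np.arange(n)
    # Broadcasting a column of row indices against a row of column indices
    # produces the full |i - j| distance matrix in one step.
    return np.abs(idx[:, None] - idx[None, :]) <= w


tokens = ["The", "cat", "sat", "on", "mat"]
n, w = len(tokens), 1
m = local_mask(n, w)
print([tokens[j] for j in np.flatnonzero(m[2])])  # ['cat', 'sat', 'on']
print(m.sum(axis=1))                              # [2 3 3 3 2]

# Edge rows see fewer neighbours, so the exact count is N(2W+1) - W(W+1).
assert m.sum() == n * (2 * w + 1) - w * (w + 1)   # 13
```

The closed form explains why the "$2W + 1$ per token" figure used later is an upper bound: the first and last $W$ rows are truncated by the sequence boundary.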

Pillar 2: Global Token Attention

A small set of designated tokens (typically the [CLS] token or the first few tokens) are made “global”: they attend to every token, and every token attends to them. These global tokens serve as information aggregators — they absorb information from the entire sequence and make it available to all other tokens.

Formally, if $g$ is a global token index:

  • Row $g$ is all active: token $g$ attends to every position.
  • Column $g$ is all active: every token attends to position $g$.

In our example, “The” (index 0) is the global token. This means “mat” (index 4) can access information about “The” directly, even though they are 4 positions apart, well beyond the $W = 1$ window.
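Adding a global token amounts to switching on one row and one column of the local mask. A sketch under the running example's assumptions ($W = 1$, global index 0):

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "mat"]
n, w, global_idx = len(tokens), 1, [0]

idx = np.arange(n)
mask = np.abs(idx[:, None] - idx[None, :]) <= w   # Pillar 1: local window
print("mat before:", [tokens[j] for j in np.flatnonzero(mask[4])])  # ['on', 'mat']

for g in global_idx:                               # Pillar 2: global hubs
    mask[g, :] = True   # the hub attends to every position
    mask[:, g] = True   # every position attends to the hub

print("mat after: ", [tokens[j] for j in np.flatnonzero(mask[4])])  # ['The', 'on', 'mat']
print("active per row:", mask.sum(axis=1), "total:", int(mask.sum()))  # [5 3 4 4 3], 19
```

The 19 active entries are exactly the pattern tabulated in the step-by-step calculation later in this chapter.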

Pillar 3: Random Attention

Each token additionally attends to $R$ randomly selected positions that are not already covered by the local or global patterns. This is motivated by Erdős–Rényi random graph theory: a random graph whose average degree exceeds roughly $\ln N$ is connected with high probability, and its diameter grows only logarithmically with $N$.

Random edges serve as shortcuts through the graph. Even without global tokens they shrink the diameter of a sliding-window chain from $N - 1$ to roughly logarithmic in $N$, and alongside global tokens they provide many routes that do not all funnel through the same hub. With just $R = 3$ random edges per token, the attention graph already behaves like a small-world network, where multi-layer propagation reaches everywhere in a few hops.
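To see the shortcut effect in isolation, the sketch below drops the global tokens entirely and compares a $W = 1$ chain with the same chain plus $R = 3$ random edges per token. The random edges are directed (token $i$ attends to $j$, not necessarily the reverse), matching the mask definition. $N = 128$ and the seed are arbitrary choices for the demo, and the exact diameter of the randomized graph depends on the seed.

```python
from collections import deque

import numpy as np


def diameter(mask: np.ndarray) -> int:
    """Max over all (src, dst) pairs of the BFS hop count along True entries."""
    n = mask.shape[0]
    nbrs = [np.flatnonzero(row) for row in mask]
    worst = 0
    for src in range(n):
        dist = np.full(n, -1)
        dist[src] = 0
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in nbrs[u]:
                if dist[v] < 0:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        worst = max(worst, int(dist.max()))
    return worst


n, w, r = 128, 1, 3
rng = np.random.default_rng(0)
idx = np.arange(n)
local = np.abs(idx[:, None] - idx[None, :]) <= w   # chain; no global tokens

with_random = local.copy()
for i in range(n):
    inactive = np.flatnonzero(~with_random[i])      # positions row i does not see yet
    picks = rng.choice(inactive, size=r, replace=False)
    with_random[i, picks] = True

print("local only  :", diameter(local))        # 127 (a chain)
print("local + R=3 :", diameter(with_random))  # a small, seed-dependent number
```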

The Theoretical Guarantee

Theorem (Zaheer et al., 2020, informal): Transformers whose sparse attention graph contains a star, that is, at least one global token connected to every position, are universal approximators of continuous sequence-to-sequence functions. BigBird's local + global + random pattern satisfies this condition (the paper also shows that BigBird is Turing complete), while each layer uses only $O(N)$ attention connections instead of $O(N^2)$.

This is a remarkable theoretical result. It means that the quadratic cost of full attention is not an inherent requirement for the expressivity of transformers; it is an artifact of the dense attention pattern. Expressive power can come from the multi-layer composition of attention and feedforward operations rather than from having every token attend to every other token in a single layer. The guarantee is not free, however: the same paper describes a task that full attention solves with a constant number of layers but that sparse attention needs a number of layers growing with $N$ to solve.


Mathematical Formulation

Mask Construction

The BigBird attention mask $M \in \{0, 1\}^{N \times N}$ is defined as the union (elementwise OR) of three binary masks:

$$M = M_{\text{local}} \cup M_{\text{global}} \cup M_{\text{random}}$$

where:

  1. Local mask: $M_{\text{local}}[i, j] = 1 \iff |i - j| \leq W$
  2. Global mask: $M_{\text{global}}[i, j] = 1 \iff i \in \mathcal{G} \text{ or } j \in \mathcal{G}$, where $\mathcal{G}$ is the set of global token indices
  3. Random mask: for each row $i$, select $R$ positions uniformly at random from the inactive set $\{j : M_{\text{local}}[i,j] = 0 \text{ and } M_{\text{global}}[i,j] = 0\}$
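The three definitions translate almost symbol for symbol into NumPy. In this sketch the function name `bigbird_mask` and the use of `np.random.default_rng` are illustrative choices; $M_{\text{random}}$ is sampled only from the positions that $M_{\text{local}} \cup M_{\text{global}}$ leaves inactive, so every random edge is genuinely new.

```python
import numpy as np


def bigbird_mask(n: int, w: int, global_idx, r: int, seed: int = 0) -> np.ndarray:
    """M = M_local OR M_global OR M_random, following the set definitions."""
    idx = np.arange(n)
    m_local = np.abs(idx[:, None] - idx[None, :]) <= w

    is_global = np.zeros(n, dtype=bool)
    is_global[list(global_idx)] = True
    m_global = is_global[:, None] | is_global[None, :]   # i in G or j in G

    base = m_local | m_global
    m_random = np.zeros((n, n), dtype=bool)
    rng = np.random.default_rng(seed)
    for i in range(n):
        inactive = np.flatnonzero(~base[i])   # {j : M_local = 0 and M_global = 0}
        k = min(r, inactive.size)             # global rows have nothing left to add
        if k:
            m_random[i, rng.choice(inactive, size=k, replace=False)] = True

    return base | m_random


m = bigbird_mask(n=16, w=1, global_idx=[0], r=2, seed=0)
print(m.sum(axis=1))  # row 0 (global) is full; every other row gains exactly r = 2
```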

The Sparse Attention Equation

Given the combined mask MM, the BigBird attention output is:

Raw scaled dot-product score:

$$S[i,j] = \frac{Q[i] \cdot K[j]^\top}{\sqrt{d_k}}$$

Masked score:

$$\tilde{S}[i,j] = \begin{cases} S[i,j] & \text{if } M[i,j] = 1 \\ -\infty & \text{if } M[i,j] = 0 \end{cases}$$

Attention weight (softmax over active positions only):

$$A[i,j] = \frac{\exp(\tilde{S}[i,j])}{\sum_{\ell : M[i,\ell]=1} \exp(\tilde{S}[i,\ell])}$$

Weighted sum of value vectors:

$$\text{out}[i] = \sum_{j : M[i,j]=1} A[i,j] \, V[j]$$

The key insight: setting blocked scores to $-\infty$ before softmax gives $\exp(-\infty) = 0$, so blocked positions contribute zero weight. The softmax denominator only sums over active positions, so the remaining weights still sum to 1.
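A quick numerical check of this claim, using the scores for the query “mat” from the worked example (row 4): masking with $-\infty$ and softmaxing the whole row gives the same weights as softmaxing only the gathered active entries and scattering them back. A minimal sketch:

```python
import numpy as np

scores = np.array([0.50, 0.50, 0.50, 0.50, 0.75])    # S[mat, :]
active = np.array([True, False, False, True, True])  # M[mat, :]

# Route 1: set blocked scores to -inf, then softmax the full row.
masked = np.where(active, scores, -np.inf)
e = np.exp(masked - masked.max())                     # exp(-inf) evaluates to 0.0
weights_masked = e / e.sum()

# Route 2: softmax over the active entries only, then scatter back.
sub = scores[active]
e_sub = np.exp(sub - sub.max())
weights_gathered = np.zeros_like(scores)
weights_gathered[active] = e_sub / e_sub.sum()

print(np.round(weights_masked, 4))   # approx. [0.3045, 0, 0, 0.3045, 0.3910]
assert np.allclose(weights_masked, weights_gathered)
assert np.isclose(weights_masked.sum(), 1.0)
```

Route 2 is what an efficient implementation wants to do (never materialize the blocked scores); Route 1 is what the dense reference code in this chapter does.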

Complexity Analysis

| Component | Connections per token | Total connections | Asymptotic |
|---|---|---|---|
| Local window | $2W + 1$ | $N \times (2W + 1)$ | $O(NW)$ |
| Global tokens | $G = \lvert\mathcal{G}\rvert$ | $2 \times N \times G$ | $O(NG)$ |
| Random edges | $R$ | $N \times R$ | $O(NR)$ |
| Total BigBird | $\leq 2W + 1 + G + R$ | $\leq N(2W + 1 + 2G + R)$ | $O(N)$ |
| Full attention | $N$ | $N^2$ | $O(N^2)$ |

For fixed $W$, $G$, and $R$, the total number of active connections is $O(N)$: linear in sequence length. The released BigBird models apply the three pillars to blocks of 64 tokens rather than to individual tokens (the paper's base configuration uses a 3-block sliding window, 2 global blocks, and 3 random blocks per query block). As an illustrative token-level configuration, take $W = 256$, $G = 2$ global tokens, and $R = 3$ random edges. For $N = 4{,}096$:

  • Full attention: $4096^2 = 16{,}777{,}216$ connections
  • BigBird ($W = 256$, $G = 2$, $R = 3$): at most $4096 \times (513 + 4 + 3) = 2{,}129{,}920$ connections, an 87% reduction
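The figure of 2,129,920 is an upper bound: it ignores the shorter windows at the sequence edges and double-counts positions that belong to more than one pillar. The sketch below builds the actual boolean mask and counts it, assuming the two global tokens sit at positions 0 and 1 (the text does not fix their positions). Every non-global row still has thousands of inactive positions, so each gains exactly $R$ random edges and the total does not depend on the seed.

```python
import numpy as np


def count_bigbird_connections(n: int, w: int, global_idx, r: int, seed: int = 0) -> int:
    """Build the N x N BigBird mask row by row and count its True entries."""
    mask = np.zeros((n, n), dtype=bool)
    for i in range(n):                                 # Pillar 1: local window
        mask[i, max(0, i - w):min(n, i + w + 1)] = True
    for g in global_idx:                               # Pillar 2: global rows/cols
        mask[g, :] = True
        mask[:, g] = True
    rng = np.random.default_rng(seed)
    for i in range(n):                                 # Pillar 3: random edges
        inactive = np.flatnonzero(~mask[i])
        k = min(r, inactive.size)
        if k:
            mask[i, rng.choice(inactive, size=k, replace=False)] = True
    return int(mask.sum())


n, w, g, r = 4096, 256, 2, 3
exact = count_bigbird_connections(n, w, global_idx=range(g), r=r)
bound = n * (2 * w + 1 + 2 * g + r)
print(f"exact: {exact:,}  bound: {bound:,}  full: {n * n:,}")
print(f"reduction vs full attention: {1 - exact / (n * n):.1%}")
```

With these assumptions the exact count comes out slightly below the bound, so the reduction is a little better than 87%.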

Step-by-Step Calculation

Let us trace the complete BigBird sparse attention computation on our shared example “The cat sat on mat” with $W = 1$, global token = “The” (index 0), and $R = 0$ (no random edges).

Step 1: Compute Raw Scores

First, compute all 25 pairwise scores $S = Q K^\top / \sqrt{d_k}$ exactly as in full attention. Each score measures the similarity between a query token and a key token:

| | $K_{\text{The}}$ | $K_{\text{cat}}$ | $K_{\text{sat}}$ | $K_{\text{on}}$ | $K_{\text{mat}}$ |
|---|---|---|---|---|---|
| $Q_{\text{The}}$ | 0.0000 | 1.0000 | 0.5000 | 0.5000 | 0.7500 |
| $Q_{\text{cat}}$ | 1.5000 | 0.0000 | 1.0000 | 0.5000 | 0.2500 |
| $Q_{\text{sat}}$ | 0.5000 | 1.0000 | 1.0000 | 0.5000 | 0.7500 |
| $Q_{\text{on}}$ | 0.5000 | 0.5000 | 0.0000 | 1.0000 | 0.5000 |
| $Q_{\text{mat}}$ | 0.5000 | 0.5000 | 0.5000 | 0.5000 | 0.7500 |

For example, $S[\text{cat}, \text{The}] = \frac{[0, 2, 0, 1] \cdot [0, 1, 0, 1]}{\sqrt{4}} = \frac{0 + 2 + 0 + 1}{2} = 1.5$
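The whole score table comes from one matrix product. The Q and K values below are the ones used in the Python implementation later in this section:

```python
import numpy as np

Q = np.array([[1, 0, 1, 0],    # The
              [0, 2, 0, 1],    # cat
              [1, 1, 1, 0],    # sat
              [0, 0, 1, 1],    # on
              [1, 0, 0, 1]],   # mat
             dtype=float)
K = np.array([[0, 1, 0, 1],    # The
              [1, 0, 1, 0],    # cat
              [1, 1, 0, 0],    # sat
              [0, 0, 1, 1],    # on
              [1, 0, 0.5, 0.5]])  # mat

d_k = Q.shape[1]
S = Q @ K.T / np.sqrt(d_k)     # all 25 scaled dot products at once
print(S)
print(S[1, 0])                 # S[cat, The] = 1.5
```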

Step 2: Build the BigBird Mask

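For each position pair $(i, j)$, check whether ANY condition holds:

| Query ↓ / Key → | The (0) | cat (1) | sat (2) | on (3) | mat (4) | Active |
|---|---|---|---|---|---|---|
| The (global) | G+L | G+L | G | G | G | 5/5 |
| cat | G+L | L | L | × | × | 3/5 |
| sat | G | L | L | L | × | 4/5 |
| on | G | × | L | L | L | 4/5 |
| mat | G | × | × | L | L | 3/5 |

G = global (row 0 or column 0), L = local ($|i-j| \leq 1$), × = blocked. Total: 19/25 connections active (76%), versus 25/25 for full attention.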

Step 3: Apply Mask to Scores

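Blocked positions are set to $-\infty$:

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.0000 | 1.0000 | 0.5000 | 0.5000 | 0.7500 |
| cat | 1.5000 | 0.0000 | 1.0000 | −∞ | −∞ |
| sat | 0.5000 | 1.0000 | 1.0000 | 0.5000 | −∞ |
| on | 0.5000 | −∞ | 0.0000 | 1.0000 | 0.5000 |
| mat | 0.5000 | −∞ | −∞ | 0.5000 | 0.7500 |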

Step 4: Row-Wise Softmax

Now we apply softmax to each row independently. The $-\infty$ positions produce $\exp(-\infty) = 0$, so only active positions contribute. Let us trace the most interesting row: “mat” (row 4).

Why row 4 (“mat”) is interesting: without the global token, “mat” would only see “on” and itself (window $W = 1$). With the global hub “The”, it gains access to a document-level summary, letting it borrow context from the very start of the sentence.

Active positions for “mat”: {The ($j = 0$, global), on ($j = 3$, local), mat ($j = 4$, local)}

Active scores: $[0.5000,\; 0.5000,\; 0.7500]$

Subtract the row max ($0.7500$): $[-0.2500,\; -0.2500,\; 0.0000]$

Exponentiate: $[0.7788,\; 0.7788,\; 1.0000]$, sum $= 2.5576$

Normalize: $A[\text{mat}] = [0.3045,\; 0,\; 0,\; 0.3045,\; 0.3910]$

Notice that “mat” distributes its attention across three sources: 30.45% to the global hub “The”, 30.45% to its neighbor “on”, and 39.10% to itself. The global connection lets “mat” access global context that would be invisible with only a local window.

The softmax for all five rows is traced step by step in the Python implementation walkthrough later in this section.

Step 5: Compute Output

$$\begin{aligned}
\text{out}[\text{mat}] &= 0.3045 \times V[\text{The}] + 0.3045 \times V[\text{on}] + 0.3910 \times V[\text{mat}] \\
&= 0.3045 \times [1, 0, 0, 0] + 0.3045 \times [0, 0, 0, 1] + 0.3910 \times [0.5, 0.5, 0.5, 0.5] \\
&= [0.3045, 0, 0, 0] + [0, 0, 0, 0.3045] + [0.1955, 0.1955, 0.1955, 0.1955] \\
&= [0.5000,\; 0.1955,\; 0.1955,\; 0.5000]
\end{aligned}$$
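Because blocked weights are exactly zero, the output step only needs the value rows at active positions. A small sketch for the “mat” row that gathers just those rows of V (values copied from the implementation later in this section):

```python
import numpy as np

V = np.array([[1.0, 0.0, 0.0, 0.0],   # The
              [0.0, 1.0, 0.0, 0.0],   # cat
              [0.0, 0.0, 1.0, 0.0],   # sat
              [0.0, 0.0, 0.0, 1.0],   # on
              [0.5, 0.5, 0.5, 0.5]])  # mat

active = np.array([0, 3, 4])           # The (global), on, mat (local)
scores = np.array([0.50, 0.50, 0.75])  # S[mat, active]

w = np.exp(scores - scores.max())
w /= w.sum()                           # approx. [0.3045, 0.3045, 0.3910]
out_mat = w @ V[active]                # only 3 of the 5 value rows are read
print(np.round(out_mat, 4))            # approx. [0.5, 0.1955, 0.1955, 0.5]
```

Gathering rows like this is what makes sparse attention cheaper in memory traffic, not just in arithmetic: blocked value vectors are never loaded.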


Full Attention Weight and Output Matrices

Attention Weights: BigBird Sparse ($5 \times 5$)

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1095 | 0.2976 | 0.1805 | 0.1805 | 0.2318 |
| cat | 0.5465 | 0.1220 | 0.3315 | 0.0000 | 0.0000 |
| sat | 0.1888 | 0.3112 | 0.3112 | 0.1888 | 0.0000 |
| on | 0.2350 | 0.0000 | 0.1425 | 0.3875 | 0.2350 |
| mat | 0.3045 | 0.0000 | 0.0000 | 0.3045 | 0.3910 |

Key observations: Row 0 (“The”, global) has nonzero weights for all 5 tokens. Rows 1 and 4 (“cat” and “mat”) have only 3 nonzero weights each. The zeros correspond to blocked positions.

Output Matrix: Sparse Attention ($5 \times 4$)

| | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.2254 | 0.4135 | 0.2964 | 0.2964 |
| cat | 0.5465 | 0.1220 | 0.3315 | 0.0000 |
| sat | 0.1888 | 0.3112 | 0.3112 | 0.1888 |
| on | 0.3525 | 0.1175 | 0.2600 | 0.5050 |
| mat | 0.5000 | 0.1955 | 0.1955 | 0.5000 |

Comparison: Sparse vs Full Attention

How much does the sparse mask change the output compared to full quadratic attention?

| Token | BigBird output | Full attention output | L2 error |
|---|---|---|---|
| The | [0.2254, 0.4135, 0.2964, 0.2964] | [0.2254, 0.4135, 0.2964, 0.2964] | 0.0000 |
| cat | [0.5465, 0.1220, 0.3315, 0.0000] | [0.4602, 0.1475, 0.3018, 0.2058] | 0.2265 |
| sat | [0.1888, 0.3112, 0.3112, 0.1888] | [0.2495, 0.3481, 0.3481, 0.2495] | 0.1004 |
| on | [0.3525, 0.1175, 0.2600, 0.5050] | [0.2854, 0.2854, 0.2106, 0.4089] | 0.2107 |
| mat | [0.5000, 0.1955, 0.1955, 0.5000] | [0.3108, 0.3108, 0.3108, 0.3108] | 0.3134 |

“The” has zero error because it is the global token and sees all 5 positions in both sparse and full attention: identical weights, identical output. The other tokens show nonzero L2 errors because the sparse mask redistributes their attention weight among fewer positions; “mat”, which loses two of its five positions, deviates the most. In practice the model is trained with the sparse mask in place, so its learned Q/K/V projections and its multiple layers adapt to the pattern rather than having to reproduce full attention row by row.
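The comparison table can be regenerated with a few lines of NumPy. The sketch reuses the worked example's Q, K, V, writes the 19-entry mask out explicitly, and measures the Euclidean (L2) distance between each sparse output row and its full-attention counterpart; the helper name `attend` is illustrative.

```python
import numpy as np

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0],
              [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0],
              [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
              [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])
mask = np.array([[1, 1, 1, 1, 1],     # The (global row)
                 [1, 1, 1, 0, 0],     # cat
                 [1, 1, 1, 1, 0],     # sat
                 [1, 0, 1, 1, 1],     # on
                 [1, 0, 0, 1, 1]],    # mat
                dtype=bool)


def attend(scores: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Row-wise softmax (blocked entries already -inf), then weights @ values."""
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)) @ values


S = Q @ K.T / np.sqrt(Q.shape[1])
full_out = attend(S, V)                             # no mask: dense attention
sparse_out = attend(np.where(mask, S, -np.inf), V)  # BigBird mask applied
l2 = np.linalg.norm(sparse_out - full_out, axis=1)

for tok, err in zip(["The", "cat", "sat", "on", "mat"], l2):
    print(f"{tok:>3s}  L2 error = {err:.4f}")
```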


Applications Across Domains

| Domain | Long-sequence problem | BigBird application | Result |
|---|---|---|---|
| Genomics | DNA sequences of 32K+ bases | Promoter-region and chromatin-profile prediction from long DNA context | Evaluated in the original BigBird paper, which reported improvements over prior models on both tasks |
| Legal NLP | Contracts with 10K+ tokens | Cross-referencing clauses separated by thousands of words | Illustrative use case; at $N = 4{,}096$ the configuration above needs about 87% fewer attention connections than full attention |
| Scientific papers | Full-text papers (8K+ tokens) | Long-document summarization (e.g., arXiv and PubMed) | Evaluated as BigBird-Pegasus in the original paper |
| Code understanding | Large codebases (16K+ tokens) | Linking import statements to distant function definitions | Illustrative use case (Code Llama's long-context mode uses dense attention with an increased RoPE base frequency, not BigBird-style sparsity) |
| Medical records | Patient histories over years | Connecting diagnosis notes to treatment outcomes months apart | Illustrative use case |

Connection to Modern Systems

| System | Relation to BigBird | Key difference |
|---|---|---|
| Longformer (Beltagy et al., 2020) | Same local + global idea, developed independently and concurrently | Longformer uses (optionally dilated) sliding windows and no random edges; BigBird adds random edges |
| LongT5 (Guo et al., 2022) | Builds on the ETC/BigBird line of local + global attention | Its Transient Global (TGlobal) attention computes global tokens on the fly by aggregating blocks of input tokens instead of designating fixed global tokens |
| Flash Attention (Dao et al., 2022) | Orthogonal optimization: same math, faster IO | Flash Attention is about hardware tiling, not the sparsity pattern |
| Mistral / sliding window | Uses a sliding-window pattern like BigBird's local pillar | Mistral uses a pure (causal) sliding window without global or random tokens |
| Gemini (Google, 2024) | Long-context architecture details are not fully public | Use of BigBird-like sparsity has not been publicly confirmed |

Key distinction: BigBird defines which positions attend to each other (the sparsity pattern). Flash Attention optimizes how the computation is executed on GPU hardware. They are complementary: you can use Flash Attention to efficiently compute BigBird's sparse attention pattern.


Python Implementation

The walkthrough below annotates each line of the BigBird sparse attention implementation with its execution state on the shared example (variable values and the row-by-row softmax computation), followed by the complete listing.

BigBird Sparse Attention: NumPy Implementation
sparse_bigbird_attention.py
import numpy as np

NumPy provides vectorized matrix operations. Q @ K.T runs as optimized C code, not Python loops.

import math

Python’s math module gives us math.sqrt() for the scaling factor √d_k.

class BigBirdSparseAttention

Encapsulates BigBird’s three-pillar sparse attention: local window + global tokens + random edges. All state (window size, global indices, random seed) is stored as instance attributes.

def __init__(self, d_k, W, global_indices, R, seed)

Constructor takes the head dimension d_k, window half-size W, list of global token positions, number of random edges R per token, and a random seed for reproducibility.

EXECUTION STATE
⬇ input: d_k = 4 (each token is a 4-dimensional vector)
⬇ input: W = 1 (attend to ±1 neighbors on each side)
⬇ input: global_indices = [0] (token 'The' at index 0 is the global hub)
⬇ input: R = 0 (no random edges in this demo)
⬇ input: seed = 42 (reproducible randomness)
self.d_k = d_k

Store head dimension. Used only for reference; the actual scaling uses self.scale.

EXECUTION STATE
self.d_k = 4
self.W = W

Window half-size. Token i attends to positions [i−W, i+W]. W=1 means ±1 neighbor = 3 tokens per window.

EXECUTION STATE
self.W = 1
self.global_indices = global_indices or [0]

Positions of global hub tokens. Default: index 0. Global tokens attend to ALL positions and ALL tokens attend to them.

EXECUTION STATE
self.global_indices = [0]
self.R = R

Number of random connections per token. Each token will randomly attend to R additional positions not already covered by local or global.

EXECUTION STATE
self.R = 0
self.scale = math.sqrt(d_k)

Precompute √d_k = √4 = 2.0. Dividing scores by this prevents softmax saturation when d_k is large.

EXECUTION STATE
self.scale = 2.0 (√4)
self.rng = np.random.RandomState(seed)

Create a seeded random number generator. Using RandomState instead of global np.random ensures reproducible random edges across runs.

EXECUTION STATE
self.rng = RandomState(42)
def _build_mask(self, N) -> np.ndarray

Constructs the N×N boolean BigBird mask by combining local window, global tokens, and random edges. Returns True where attention is allowed, False where it is blocked.

EXECUTION STATE
⬇ input: N = 5 (number of tokens in 'The cat sat on mat')
⬆ returns = np.ndarray (5, 5) boolean — True = attend, False = block
mask = np.zeros((N, N), dtype=bool)

Initialize a 5×5 mask of all False (all blocked). We will selectively set positions to True.

EXECUTION STATE
mask (5×5) =
      The    cat    sat     on    mat
The  False  False  False  False  False
cat  False  False  False  False  False
sat  False  False  False  False  False
on   False  False  False  False  False
mat  False  False  False  False  False
for i in range(N):  # local window

For each query token i, enable attention to positions within [i−W, i+W]. This is Pillar 1: local window attention.

LOOP TRACE · 5 iterations
i=0 (The)
lo = max(0, 0−1) = 0
hi = min(5, 0+1+1) = 2
mask[0, 0:2] = True = The attends to [The, cat]
i=1 (cat)
lo = max(0, 1−1) = 0
hi = min(5, 1+1+1) = 3
mask[1, 0:3] = True = cat attends to [The, cat, sat]
i=2 (sat)
lo = max(0, 2−1) = 1
hi = min(5, 2+1+1) = 4
mask[2, 1:4] = True = sat attends to [cat, sat, on]
i=3 (on)
lo = max(0, 3−1) = 2
hi = min(5, 3+1+1) = 5
mask[3, 2:5] = True = on attends to [sat, on, mat]
i=4 (mat)
lo = max(0, 4−1) = 3
hi = min(5, 4+1+1) = 5
mask[4, 3:5] = True = mat attends to [on, mat]
for g in self.global_indices:  # global tokens

Pillar 2: For each global token index g, enable its entire row (g attends to all) and entire column (all attend to g). This creates information hub nodes.

LOOP TRACE · 1 iteration
g=0 (The)
mask[0, :] = True = 'The' now attends to ALL 5 tokens
mask[:, 0] = True = ALL 5 tokens now attend to 'The'
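for i in range(N):  # random edges

Pillar 3: For each token, find the positions not yet active and randomly pick R of them (or all of them, if fewer than R remain). With R=0, this loop adds nothing. With R=1, each non-global token gains one random edge; the global token 'The' has no inactive positions left.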

LOOP TRACE · 5 iterations
i=0 (The)
inactive positions = [] (The is global — all already active)
added = none (already fully connected)
i=1 (cat)
inactive positions = [3, 4] (on, mat not in local or global)
added (R=0) = none
i=2 (sat)
inactive positions = [4] (mat not in local or global)
added (R=0) = none
i=3 (on)
inactive positions = [1] (cat not in local or global)
added (R=0) = none
i=4 (mat)
inactive positions = [1, 2] (cat, sat not in local or global)
added (R=0) = none
return mask

Returns the complete 5×5 boolean mask. True means the score is kept; False means it is replaced with −∞ before softmax.

EXECUTION STATE
⬆ return: mask (5×5) =
      The    cat    sat     on    mat
The   True   True   True   True   True
cat   True   True   True  False  False
sat   True   True   True   True  False
on    True  False   True   True   True
mat   True  False  False   True   True
def _softmax(self, x) -> np.ndarray

Numerically stable softmax: subtract row-wise max before exp() to prevent overflow. Positions set to −∞ become 0 after exp(−∞) = 0.

EXECUTION STATE
⬇ input: x (5×5) = Masked score matrix with −∞ at blocked positions
⬆ returns = np.ndarray (5, 5) — row-normalized probabilities (rows sum to 1.0)
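x_safe = np.where(np.isfinite(x), x, -1e9)

Replace −∞ with a very large negative number (−1e9) so the max-subtraction on the next line stays finite even if an entire row were masked (−∞ − (−∞) would be NaN). Blocked positions are still forced to 0 in the exp_x step, so the softmax result is unchanged. In BigBird no row is ever fully masked, because each token's own position lies inside its local window.

x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)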

Subtract the row maximum for numerical stability. The largest value in each row becomes 0, preventing exp() overflow.

EXECUTION STATE
axis=-1 = Operate along columns within each row. For a 5×5 matrix, find the max of each row independently.
keepdims=True = Keep shape as (5,1) not (5,), so broadcasting x(5×5) − max(5×1) works correctly.
exp_x = np.where(np.isfinite(x), np.exp(x_shifted), 0.0)

Exponentiate shifted scores where the original was finite; set blocked positions to 0.0. This ensures blocked positions contribute nothing to the softmax sum.

return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

Divide each row’s exp values by the row sum to get probabilities. Each row sums to 1.0. Blocked positions remain 0.

EXECUTION STATE
axis=-1 = Sum along columns within each row.
keepdims=True = Keep shape (5,1) for broadcasting division.
def forward(self, Q, K, V)

Main forward pass. Takes Q, K, V matrices, builds the sparse mask, computes masked attention, and returns weights + output.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = (weights, output, mask) — all as np.ndarray
N = Q.shape[0]

Number of tokens = number of rows in Q.

EXECUTION STATE
N = 5
mask = self._build_mask(N)

Calls _build_mask(5) to construct the 5×5 boolean BigBird mask with local + global + random patterns.

EXECUTION STATE
mask (5×5) =
      The    cat    sat     on    mat
The   True   True   True   True   True
cat   True   True   True  False  False
sat   True   True   True   True  False
on    True  False   True   True   True
mat   True  False  False   True   True
scores = Q @ K.T / self.scale

Compute ALL pairwise scaled dot-product scores. Q(5×4) @ K.T(4×5) = 5×5 score matrix, then divide by √d_k = 2.0.

EXECUTION STATE
@ = Python matrix multiplication operator (matmul)
K.T = Transpose of K: shape changes from (5,4) to (4,5)
self.scale = 2.0 (√4)
scores (5×5) =
        The     cat     sat      on     mat
The   0.0000  1.0000  0.5000  0.5000  0.7500
cat   1.5000  0.0000  1.0000  0.5000  0.2500
sat   0.5000  1.0000  1.0000  0.5000  0.7500
on    0.5000  0.5000  0.0000  1.0000  0.5000
mat   0.5000  0.5000  0.5000  0.5000  0.7500
masked = np.where(mask, scores, -np.inf)

Apply the sparse mask: keep scores where mask is True, replace with −∞ where mask is False. After softmax, −∞ → exp(−∞) = 0, so blocked tokens get zero weight.

EXECUTION STATE
masked (5×5) =
         The      cat      sat       on      mat
The    0.0000   1.0000   0.5000   0.5000   0.7500
cat    1.5000   0.0000   1.0000     -inf     -inf
sat    0.5000   1.0000   1.0000   0.5000     -inf
on     0.5000     -inf   0.0000   1.0000   0.5000
mat    0.5000     -inf     -inf   0.5000   0.7500
weights = self._softmax(masked)

Apply row-wise softmax over only the active (non-−∞) positions. Each row sums to 1.0, and blocked positions get weight 0.

EXECUTION STATE
── Row 0 (The) ── =
active scores = [0.0000, 1.0000, 0.5000, 0.5000, 0.7500] (all 5 active — global row)
max = 1.0000
shifted = [-1.0000, 0.0000, -0.5000, -0.5000, -0.2500]
exp = [0.3679, 1.0000, 0.6065, 0.6065, 0.7788]
sum = 3.3597
weights[The] = [0.1095, 0.2976, 0.1805, 0.1805, 0.2318]
── Row 1 (cat) ── =
active scores = [1.5000, 0.0000, 1.0000] (3 active: The, cat, sat)
max = 1.5000
shifted = [0.0000, -1.5000, -0.5000]
exp = [1.0000, 0.2231, 0.6065]
sum = 1.8297
weights[cat] = [0.5465, 0.1220, 0.3315, 0.0000, 0.0000]
── Row 2 (sat) ── =
active scores = [0.5000, 1.0000, 1.0000, 0.5000] (4 active: The, cat, sat, on)
max = 1.0000
shifted = [-0.5000, 0.0000, 0.0000, -0.5000]
exp = [0.6065, 1.0000, 1.0000, 0.6065]
sum = 3.2131
weights[sat] = [0.1888, 0.3112, 0.3112, 0.1888, 0.0000]
── Row 3 (on) ── =
active scores = [0.5000, 0.0000, 1.0000, 0.5000] (4 active: The, sat, on, mat)
max = 1.0000
shifted = [-0.5000, -1.0000, 0.0000, -0.5000]
exp = [0.6065, 0.3679, 1.0000, 0.6065]
sum = 2.5809
weights[on] = [0.2350, 0.0000, 0.1425, 0.3875, 0.2350]
── Row 4 (mat) ── =
active scores = [0.5000, 0.5000, 0.7500] (3 active: The, on, mat)
max = 0.7500
shifted = [-0.2500, -0.2500, 0.0000]
exp = [0.7788, 0.7788, 1.0000]
sum = 2.5576
weights[mat] = [0.3045, 0.0000, 0.0000, 0.3045, 0.3910]
output = weights @ V

Multiply attention weights (5×5) by value matrix V (5×4) to get the 5×4 output. Each row is a weighted combination of V vectors from the active positions.

EXECUTION STATE
output (5×4) =
        dim0     dim1     dim2     dim3
The   0.2254   0.4135   0.2964   0.2964
cat   0.5465   0.1220   0.3315   0.0000
sat   0.1888   0.3112   0.3112   0.1888
on    0.3525   0.1175   0.2600   0.5050
mat   0.5000   0.1955   0.1955   0.5000
return weights, output, mask

Return all three results: the sparse attention weight matrix, the output context vectors, and the boolean mask used.

EXECUTION STATE
⬆ return: weights = shape (5, 5) — sparse attention probabilities
⬆ return: output = shape (5, 4) — context-aware token representations
⬆ return: mask = shape (5, 5) — boolean mask (19/25 active)
import numpy as np
import math

class BigBirdSparseAttention:
    """
    BigBird Sparse Attention (Zaheer et al., NeurIPS 2020)

    Combines three attention patterns for O(N) complexity:
      1. Local window: each token attends to W neighbors on each side
      2. Global tokens: designated hub tokens attend to/from all
      3. Random edges: each token randomly attends to R others

    Theoretical result: because the pattern includes global tokens,
    it is a universal approximator of sequence-to-sequence functions,
    like full quadratic attention (possibly needing more layers).
    """

    def __init__(self, d_k: int, W: int = 1,
                 global_indices: list = None,
                 R: int = 0, seed: int = 42):
        """
        Args:
            d_k:  Dimension of each head (for scaling)
            W:    Window half-size (attend to ±W neighbors)
            global_indices: List of global hub token positions
                            (None or [] both fall back to [0])
            R:    Number of random connections per token
            seed: Random seed for reproducible random edges
        """
        self.d_k = d_k
        self.W = W
        self.global_indices = global_indices or [0]
        self.R = R
        self.scale = math.sqrt(d_k)
        self.rng = np.random.RandomState(seed)

    def _build_mask(self, N: int) -> np.ndarray:
        """Build the BigBird sparse mask (N x N boolean)."""
        mask = np.zeros((N, N), dtype=bool)

        # 1) Local window: |i - j| <= W
        for i in range(N):
            lo = max(0, i - self.W)
            hi = min(N, i + self.W + 1)
            mask[i, lo:hi] = True

        # 2) Global tokens: full row and column
        for g in self.global_indices:
            mask[g, :] = True   # global attends to all
            mask[:, g] = True   # all attend to global

        # 3) Random edges: up to R new targets among still-inactive positions
        for i in range(N):
            inactive = np.where(~mask[i])[0]
            if len(inactive) > 0 and self.R > 0:
                picks = self.rng.choice(
                    inactive,
                    size=min(self.R, len(inactive)),
                    replace=False
                )
                mask[i, picks] = True

        return mask

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax along the last axis.

        Assumes every row has at least one finite entry (always true
        here, since the diagonal lies inside the local window).
        """
        x_safe = np.where(np.isfinite(x), x, -1e9)
        x_shifted = x_safe - np.max(x_safe, axis=-1, keepdims=True)
        exp_x = np.where(np.isfinite(x), np.exp(x_shifted), 0.0)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def forward(self, Q: np.ndarray, K: np.ndarray,
                V: np.ndarray):
        """
        Args:
            Q: Query matrix  (N, d_k)
            K: Key matrix    (N, d_k)
            V: Value matrix  (N, d_v)
        Returns:
            weights: Sparse attention weights (N, N)
            output:  Context vectors (N, d_v)
            mask:    Boolean mask used (N, N)
        """
        N = Q.shape[0]
        mask = self._build_mask(N)

        # Scaled dot-product scores
        scores = Q @ K.T / self.scale

        # Apply sparse mask: blocked positions = -inf
        masked = np.where(mask, scores, -np.inf)

        # Softmax over active positions only
        weights = self._softmax(masked)

        # Weighted sum of values
        output = weights @ V

        return weights, output, mask

    def explain(self, Q, K, V, tokens, query_idx=0):
        """Print step-by-step trace for one query token."""
        # forward() rebuilds the mask, so with R > 0 every call
        # draws a fresh set of random edges.
        weights, output, mask = self.forward(Q, K, V)
        token = tokens[query_idx]
        active = np.where(mask[query_idx])[0]

        print(f"\n=== BigBird trace: '{token}' (idx {query_idx}) ===")
        print(f"W={self.W}, global={self.global_indices}, R={self.R}")
        print(f"Active connections: {len(active)}/{len(tokens)}")
        print()

        for j in range(len(tokens)):
            status = "ACTIVE" if mask[query_idx, j] else "blocked"
            w = weights[query_idx, j]
            bar = "#" * int(w * 40)
            reason = []
            if abs(query_idx - j) <= self.W:
                reason.append("local")
            if j in self.global_indices:
                reason.append("global-col")
            if query_idx in self.global_indices:
                reason.append("global-row")
            if not reason:
                # Neither local nor global: active only if randomly picked
                reason.append("random" if mask[query_idx, j] else "none")
            label = "+".join(reason)
            print(f"  [{status:7s}] {tokens[j]:4s} "
                  f"w={w:.4f} |{bar}| ({label})")

        print(f"\nOutput[{token}] = "
              f"{np.round(output[query_idx], 4)}")


129# ── Shared Example ("the cat sat on the mat") ──
130tokens = ["The", "cat", "sat", "on", "mat"]
131
132Q = np.array([
133    [1.0, 0.0, 1.0, 0.0],   # The
134    [0.0, 2.0, 0.0, 1.0],   # cat
135    [1.0, 1.0, 1.0, 0.0],   # sat
136    [0.0, 0.0, 1.0, 1.0],   # on
137    [1.0, 0.0, 0.0, 1.0],   # mat
138])
139
140K = np.array([
141    [0.0, 1.0, 0.0, 1.0],   # The
142    [1.0, 0.0, 1.0, 0.0],   # cat
143    [1.0, 1.0, 0.0, 0.0],   # sat
144    [0.0, 0.0, 1.0, 1.0],   # on
145    [1.0, 0.0, 0.5, 0.5],   # mat
146])
147
148V = np.array([
149    [1.0, 0.0, 0.0, 0.0],   # The
150    [0.0, 1.0, 0.0, 0.0],   # cat
151    [0.0, 0.0, 1.0, 0.0],   # sat
152    [0.0, 0.0, 0.0, 1.0],   # on
153    [0.5, 0.5, 0.5, 0.5],   # mat
154])
155
156# BigBird with W=1, global token at index 0, no random edges
157bigbird = BigBirdSparseAttention(
158    d_k=4, W=1, global_indices=[0], R=0, seed=42
159)
160weights, output, mask = bigbird.forward(Q, K, V)
161
162print("=== BigBird Sparse Attention (W=1, global=[0]) ===")
163print("\nAttention Weights:")
164header = "      " + "  ".join(f"{t:>7s}" for t in tokens)
165print(header)
166for i, t in enumerate(tokens):
167    row = "  ".join(f"{weights[i,j]:7.4f}" for j in range(5))
168    print(f"{t:3s}   {row}")
169
170print("\nOutput:")
171print("       dim0     dim1     dim2     dim3")
172for i, t in enumerate(tokens):
173    row = "  ".join(f"{output[i,j]:8.4f}" for j in range(4))
174    print(f"{t:3s}  {row}")
175
176# Trace for each token
177for idx in range(len(tokens)):
178    bigbird.explain(Q, K, V, tokens, query_idx=idx)
179
180# With random edges
181bigbird_r = BigBirdSparseAttention(
182    d_k=4, W=1, global_indices=[0], R=1, seed=42
183)
184w_r, out_r, mask_r = bigbird_r.forward(Q, K, V)
185print("\n=== With Random Edges (R=1) ===")
186print("Mask:")
187for i, t in enumerate(tokens):
188    row = " ".join("Y" if mask_r[i,j] else "." for j in range(5))
189    print(f"  {t}: {row}")
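
The deterministic part of the mask that `_build_mask` produces (local window plus global tokens, with R = 0) can also be written without Python loops, which makes the hand calculation easy to check. The sketch below is an illustration, not part of the class above: `deterministic_mask` and `masked_softmax` are hypothetical helper names, random edges are deliberately left out, and the softmax skips any all-blocked-row handling because the diagonal is always inside the window. It rebuilds the W = 1, global = [0] mask and the weights for “cat”, whose only active scores are 1.5 (The), 0 (cat) and 1 (sat).

```python
import numpy as np


def deterministic_mask(n, w, global_indices):
    """Local window + global tokens only (no random edges). True = may attend."""
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w  # band of width 2w+1
    g = np.asarray(global_indices, dtype=int)
    mask[g, :] = True  # global tokens attend to every position
    mask[:, g] = True  # every position attends to the global tokens
    return mask


def masked_softmax(scores, mask):
    """Row-wise softmax restricted to active positions."""
    masked = np.where(mask, scores, -np.inf)
    # Each row contains its (finite) diagonal, so the row max is finite.
    masked = masked - masked.max(axis=-1, keepdims=True)
    e = np.exp(masked)  # exp(-inf) = 0 for blocked positions
    return e / e.sum(axis=-1, keepdims=True)


Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0],
              [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0],
              [0, 0, 1, 1], [1, 0, 0.5, 0.5]], dtype=float)

mask = deterministic_mask(5, w=1, global_indices=[0])
weights = masked_softmax(Q @ K.T / np.sqrt(4), mask)

print(mask.astype(int))
print(np.round(weights[1], 4))  # row for "cat"
```

For “cat” this is just softmax([1.5, 0, 1]) over its three active keys; the two blocked keys (“on”, “mat”) get exactly zero weight.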

PyTorch Implementation

The equivalent PyTorch implementation below uses $\texttt{masked\_fill}$ and $\texttt{F.softmax}$. It runs on the GPU and works with PyTorch's autograd, so it can be used directly in training.
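
What working with autograd buys you here: because blocked scores are set to −∞ before the softmax, their weights are exactly 0 and no gradient flows back to them. The toy check below is a sketch with an arbitrary 3×3 score matrix and a hypothetical mask (values unrelated to the running example); it applies the same masked_fill-then-softmax recipe and inspects the gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy raw scores; requires_grad=True makes this a leaf whose .grad we can inspect.
scores = torch.randn(3, 3, requires_grad=True)

# Hypothetical sparse mask for illustration: True = may attend, False = blocked.
# Every row keeps its diagonal, so no row is fully blocked.
mask = torch.tensor([[True, True, False],
                     [True, True, True],
                     [False, True, True]])
values = torch.randn(3, 2)

# Same recipe as the listing: blocked scores -> -inf, then a per-row softmax.
weights = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
output = weights @ values

# Any scalar loss works; summing the outputs is enough to run backward().
output.sum().backward()

print(weights)
print(scores.grad)  # entries at blocked positions are exactly 0
```

In other words, the sparse pattern restricts the backward pass as well as the forward pass: a blocked key receives no learning signal from that query.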

BigBird Sparse Attention: PyTorch Implementation
sparse_bigbird_attention_torch.py

Walkthrough of the key lines (values are for the running example: N = 5, d_k = 4, W = 1, global_indices = [0], R = 0):

  1. import torch: PyTorch provides GPU-accelerated tensor operations and automatic differentiation for training.
  2. import torch.nn as nn: nn.Module is the base class for all neural network components in PyTorch.
  3. import torch.nn.functional as F: provides stateless functions such as F.softmax that operate on tensors without learnable parameters.
  4. import math: used for math.sqrt(d_k) to compute the scaling factor.
  5. class BigBirdSparseAttention(nn.Module): inherits from nn.Module so it integrates with PyTorch's training loop, parameter management, and device handling.
  6. def __init__(self, d_k, W, global_indices, R): stores the head dimension, window size, global token list, and random-edge count. These match the NumPy version's parameters except for seed: random edges use PyTorch's global random generator, so reproducibility relies on torch.manual_seed. Example inputs: d_k = 4, W = 1, global_indices = [0], R = 0.
  7. super().__init__(): calls nn.Module's constructor so the object is properly registered as a PyTorch module.
  8. self.scale = math.sqrt(d_k): precomputes √4 = 2.0 for score scaling.
  9. def _build_mask(self, N, device) -> torch.Tensor: builds the sparse mask as a boolean tensor on the same device (CPU or GPU) as the input data. For N = 5 it returns a (5, 5) boolean tensor. The mask is rebuilt on every forward call, so with R > 0 the random edges are resampled each time.
  10. mask = torch.zeros(N, N, dtype=torch.bool, device=device): starts from a 5×5 tensor of all False on the target device.
  11. scores = Q @ K.T / self.scale: the same scaled dot product as the NumPy version, Q (5×4) @ K.T (4×5) / 2.0, giving a 5×5 score matrix with the same values. For a 2-D tensor, .T is the ordinary transpose; for batched inputs of shape (B, N, d_k), use K.transpose(-2, -1) instead, because .T reverses all dimensions.
  12. scores = scores.masked_fill(~mask, float("-inf")): ~mask is True at blocked positions, and masked_fill replaces those scores with −∞, which become 0 after the softmax. This is the idiomatic way to apply a sparse mask in PyTorch.
  13. weights = F.softmax(scores, dim=-1): softmax along the last (key) dimension, so each row independently sums to 1.0 over its active positions, and −∞ entries get weight 0.
  14. weights = weights.nan_to_num(0.0): a defensive safeguard. A row whose scores are all −∞ would produce NaN, and nan_to_num replaces it with 0. With this mask that cannot happen, because the local window always includes the diagonal.
  15. output = weights @ V: weighted sum of the value vectors, weights (5×5) @ V (5×4) = output (5×4), with the same values as the NumPy version.
  16. return weights, output: returns the sparse attention weights, shape (5, 5), and the context-enriched output vectors, shape (5, 4).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Optional

class BigBirdSparseAttention(nn.Module):
    """
    BigBird Sparse Attention in PyTorch.

    Combines local window + global tokens + random edges
    into a single sparse attention mask.
    """

    def __init__(self, d_k: int, W: int = 1,
                 global_indices: Optional[List[int]] = None,
                 R: int = 0):
        super().__init__()
        self.d_k = d_k
        self.W = W
        # Default to token 0; an explicit [] means "no global tokens"
        self.global_indices = ([0] if global_indices is None
                               else list(global_indices))
        self.R = R
        self.scale = math.sqrt(d_k)

    def _build_mask(self, N: int,
                    device: torch.device) -> torch.Tensor:
        """Build BigBird sparse mask as a boolean tensor."""
        mask = torch.zeros(N, N, dtype=torch.bool,
                           device=device)

        # 1) Local window
        for i in range(N):
            lo = max(0, i - self.W)
            hi = min(N, i + self.W + 1)
            mask[i, lo:hi] = True

        # 2) Global tokens
        for g in self.global_indices:
            mask[g, :] = True
            mask[:, g] = True

        # 3) Random edges (resampled on every call; seed via torch.manual_seed)
        if self.R > 0:
            for i in range(N):
                inactive = (~mask[i]).nonzero(as_tuple=True)[0]
                if len(inactive) > 0:
                    perm = torch.randperm(len(inactive))
                    picks = inactive[perm[:self.R]]
                    mask[i, picks] = True

        return mask

    def forward(self, Q: torch.Tensor, K: torch.Tensor,
                V: torch.Tensor) -> tuple:
        """
        Args:
            Q: (N, d_k) query matrix
            K: (N, d_k) key matrix
            V: (N, d_k) value matrix
        Returns:
            weights: (N, N) sparse attention weights
            output:  (N, d_k) context vectors
        """
        N = Q.shape[0]
        mask = self._build_mask(N, Q.device)

        # Scaled dot-product scores
        scores = Q @ K.T / self.scale

        # Apply mask: blocked = -inf
        scores = scores.masked_fill(~mask, float("-inf"))

        # Softmax over active positions
        weights = F.softmax(scores, dim=-1)

        # Handle any NaN from all-masked rows
        weights = weights.nan_to_num(0.0)

        # Weighted sum
        output = weights @ V

        return weights, output


# ── Shared Example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])

K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])

V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

bigbird = BigBirdSparseAttention(
    d_k=4, W=1, global_indices=[0], R=0
)

with torch.no_grad():
    weights, output = bigbird(Q, K, V)

print("BigBird Attention Weights:")
for i, t in enumerate(tokens):
    row = " ".join(f"{weights[i,j]:.4f}" for j in range(5))
    print(f"  {t}: {row}")

print("\nOutput:")
for i, t in enumerate(tokens):
    row = " ".join(f"{output[i,j]:.4f}" for j in range(4))
    print(f"  {t}: {row}")
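
The listing's weights.nan_to_num(0.0) line guards against a query whose keys are all blocked. The BigBird mask never produces such a row (the diagonal is always active), but other masks can, for example when padding is combined with a sparse pattern. The sketch below uses a made-up 2×3 score matrix and mask to show what the guard changes.

```python
import torch
import torch.nn.functional as F

# Made-up scores and mask; the second row has every key blocked,
# which the BigBird mask in the listing never produces.
scores = torch.tensor([[0.5, 1.0, -0.5],
                       [0.2, 0.3, 0.4]])
mask = torch.tensor([[True, False, True],
                     [False, False, False]])

raw = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
safe = raw.nan_to_num(0.0)

print(raw)   # second row: softmax over all -inf is undefined (NaN)
print(safe)  # second row: zeros, so that query's output would be a zero vector
```

Without the guard, a single NaN row would propagate through weights @ V and into every later layer.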

Key Takeaways

  1. Three pillars: BigBird combines a local window (syntax), global tokens (information hubs), and random edges (shortcuts) into one sparse mask.
  2. Linear complexity: each non-global token attends to at most $2W + 1$ local keys, $G$ global keys, and $R$ random keys, while the $G$ global tokens attend to all $N$ keys. The total number of connections is therefore $O(N \times (2W + G + R))$, which is $O(N)$ for fixed $W$, $G$, $R$. At $N = 4096$ this is an 87% reduction compared with full attention.
  3. Universal approximation: BigBird's sparse attention (with global tokens) is provably a universal approximator of sequence-to-sequence functions and Turing complete, the same guarantees known for full attention, so the quadratic cost is not required for this representational power.
  4. Graph diameter of at most 2: through a global token, any token reaches any other in at most 2 hops (token → global → token), regardless of sequence length.
  5. Practical impact: enables processing of 4K–16K+ token sequences (documents, genomes, codebases) where full attention would run out of memory.
  6. Complementary to Flash Attention: BigBird defines the pattern (which positions interact), while Flash Attention optimizes the computation (how to compute it fast on GPUs). The two can be used together.
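
The linear-cost and two-hop claims are easy to check numerically. The sketch below is illustrative only: bigbird_mask and diameter are hypothetical helpers, the first G tokens are assumed to be the global ones (one of several possible placements), random keys are drawn uniformly from each row's still-blocked positions, and N = 512, W = 8, G = 2, R = 3 are arbitrary values, not the configuration behind the 87% figure.

```python
import numpy as np
from collections import deque


def bigbird_mask(n, w, g, r, seed=0):
    """Sketch of a BigBird-style mask (True = query row may attend to key column).

    Assumptions: the first g tokens are global, and each row gets r extra keys
    drawn uniformly from the positions it cannot already see.
    """
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w  # local band, width 2w+1
    mask[:g, :] = True  # global tokens attend everywhere
    mask[:, :g] = True  # every token attends to the global tokens
    for i in range(n):
        inactive = np.flatnonzero(~mask[i])
        if inactive.size:
            k = min(r, inactive.size)
            mask[i, rng.choice(inactive, size=k, replace=False)] = True
    return mask


def diameter(mask):
    """Largest shortest-path length (in hops) along attention edges, via BFS."""
    n = mask.shape[0]
    worst = 0
    for src in range(n):
        dist = np.full(n, -1)
        dist[src] = 0
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in np.flatnonzero(mask[u] & (dist < 0)):
                dist[v] = dist[u] + 1
                queue.append(v)
        worst = max(worst, int(dist.max()))
    return worst


n, w, g, r = 512, 8, 2, 3
mask = bigbird_mask(n, w, g, r)
active = int(mask.sum())
print(f"active connections: {active} of {n * n} "
      f"({100 * active / (n * n):.1f}% of full attention)")
print("max per non-global row:", int(mask[g:].sum(axis=1).max()),
      "<= 2w+1+g+r =", 2 * w + 1 + g + r)

small = bigbird_mask(64, w=1, g=1, r=0)
print("diameter with one global token:", diameter(small))
```

Doubling n roughly doubles the active count while n² quadruples, which is the linear scaling in takeaway 2; the diameter stays at 2 for any length because every path can route through the global token.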

Exercises

  1. Window size exploration: Recompute the BigBird attention weights with $W = 0$ (no neighbors, so each token sees only itself and the global token). How does the output for “mat” change? What happens to tokens that are not adjacent to the global token?
  2. Multiple global tokens: Make both “The” (index 0) and “mat” (index 4) global tokens. How many connections are active now? Compute the new attention weights for “sat”.
  3. Random edges: Add $R = 1$ random edge per token. Suppose the random connections are cat→mat, sat→mat, on→cat, and mat→cat. Recompute the weights for “cat” now that it can see “mat”.
  4. Scaling analysis: For a sequence of length $N = 8{,}192$ with $W = 128$, $G = 2$, $R = 3$, how many attention connections does BigBird use? What percentage of full attention is that?
  5. Implementation challenge: Modify the PyTorch implementation to support batched inputs with shape $(B, N, d_k)$ and multi-head attention with shape $(B, H, N, d_k)$.

References

  1. Zaheer, M., Guruganesh, G., Dubey, A., Ainslie, J., Alberti, C., Ontanon, S., Pham, P., Ravula, A., Wang, Q., Yang, L., & Ahmed, A. (2020). Big Bird: Transformers for Longer Sequences. Advances in Neural Information Processing Systems (NeurIPS) 33, pp. 17283–17297.
  2. Beltagy, I., Peters, M. E., & Cohan, A. (2020). Longformer: The Long-Document Transformer. arXiv preprint arXiv:2004.05150.
  3. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
  4. Guo, M., Ainslie, J., Uthus, D., Ontanon, S., Ni, J., Sung, Y.-H., & Yang, Y. (2022). LongT5: Efficient Text-To-Text Transformer for Long Sequences. Findings of NAACL 2022.
  5. Erdős, P. & Rényi, A. (1959). On Random Graphs I. Publicationes Mathematicae Debrecen, 6, 290–297.