Introduction
Before we dive into transformers, it's essential to understand the journey that led to their creation. This section traces the evolution of sequence modeling in deep learning, from the earliest recurrent neural networks to the challenges that eventually motivated the development of attention mechanisms and transformers.
Understanding this history will help you appreciate why transformers represent such a fundamental breakthrough, and why they've become the dominant architecture for nearly all modern AI systems.
The Road to Transformers
20 years of innovation in sequence modeling
LSTM (1997) — Hochreiter & Schmidhuber
Gating mechanisms to control information flow, addressing vanishing gradients

Seq2Seq (2014) — Sutskever et al.
Encoder-decoder architecture for machine translation

GRU (2014) — Cho et al.
Simplified gating with fewer parameters, comparable performance

Bahdanau Attention (2014) — Bahdanau, Cho & Bengio
First attention mechanism: direct connections between encoder and decoder

Luong Attention (2015) — Luong et al.
Simplified dot-product attention, global vs. local variants

Transformer (2017) — Vaswani et al.
"Attention Is All You Need" - removes recurrence entirely
Key insight: Each innovation addressed a specific limitation of its predecessor. LSTMs fixed vanishing gradients, attention fixed the bottleneck, and Transformers removed sequential processing entirely.
1.1 The Sequence Modeling Problem
What is Sequence Modeling?
Sequence modeling refers to the task of processing and generating sequential data—data where the order of elements matters. Examples include:
- Natural Language: Sentences are sequences of words where order changes meaning
- "The dog bit the man" vs "The man bit the dog"
- Time Series: Stock prices, weather readings, sensor data
- Audio: Speech as sequences of acoustic features
- Video: Sequences of image frames
- DNA: Sequences of nucleotides (A, T, G, C)
Key Challenges in Sequence Modeling
- Variable Length: Sequences can have different lengths
- Order Dependency: The meaning depends on element positions
- Long-Range Dependencies: Elements far apart may be related
- Context Sensitivity: The meaning of an element depends on surrounding elements
Word Representations Before Transformers
Before diving into RNNs, it's worth understanding how words were represented. This context explains why contextual representations (from RNNs and later Transformers) were such a breakthrough.
Static Word Embeddings
| Method | Year | Key Idea | Limitation |
|---|---|---|---|
| One-hot | Classic | Sparse vector, 1 at word index | No similarity info, huge vectors |
| Word2Vec | 2013 | Predict neighbors, dense vectors | One vector per word (no context) |
| GloVe | 2014 | Global co-occurrence statistics | Same: one vector per word |
| FastText | 2016 | Subword n-grams | Better OOV, still static |
The fundamental problem: In Word2Vec and GloVe, the word "bank" has exactly ONE vector, whether it means a financial institution or a river bank. Context is ignored.
```python
# Word2Vec/GloVe: Static embeddings
embeddings["bank"]  # Same vector for ALL uses of "bank"

# Sentences with different meanings:
"I went to the bank to deposit money"  # financial
"The river bank was steep"             # geographical
"Bank left in the turn"                # aviation/driving

# All three use THE SAME embedding for "bank"!
```

The Shift to Contextual Embeddings
The realization that word meaning depends on context led to:
- RNN-based representations: Hidden states encode context (but limited by sequential processing)
- ELMo (2018): BiLSTM-based contextual embeddings—same word gets different vectors in different sentences
- Transformers (BERT, GPT): Attention-based contextual embeddings—each token's representation depends on ALL other tokens
The key insight: A word's meaning isn't fixed—it emerges from its context. Transformers make computing context-dependent representations efficient and parallelizable.
Real-World Failure Examples
These aren't hypothetical—these are the kinds of errors that plagued pre-Transformer NLP systems:
| Task | Failure Example | Why RNNs Struggle |
|---|---|---|
| Translation | "The old man the boats" → mistranslated as elderly person | Garden-path sentences require reanalysis after "man" is seen as a verb, but RNNs commit early |
| Coreference | "The trophy doesn't fit in the suitcase because it is too big" → wrong "it" resolution | Must track which noun "it" refers to across 10+ tokens |
| Summarization | First paragraph of article ignored in summary | By the end, information from the beginning has decayed |
| Question Answering | Answer found in first sentence, question asks about it at the end | Gradient from answer position can't reach the evidence |
| Sentiment | "I thought the movie would be bad, but it was actually amazing" → negative | Initial negative words dominate if final context is lost |
The pattern: Whenever the "answer" or critical information is far from where it's needed, RNNs fail. This isn't a bug in specific implementations—it's a fundamental architectural limitation.
1.2 Recurrent Neural Networks (RNNs)
The RNN Architecture
RNNs were the first neural architecture designed specifically for sequences. The key idea: maintain a hidden state that gets updated as you process each element.
h_t = f(W_h · h_{t-1} + W_x · x_t + b)

Where:
- h_t is the hidden state at time t
- x_t is the input at time t
- W_h and W_x are weight matrices
- f is an activation function (typically tanh)
Micro Example: two time steps
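The two-step unrolling can be traced by hand in a few lines of NumPy. The weights and inputs below are made-up toy values; any small numbers would do:

```python
import numpy as np

# Toy dimensions: 2-d hidden state, 1-d input (illustrative values only)
W_h = np.array([[0.5, -0.3], [0.2, 0.4]])   # hidden-to-hidden weights
W_x = np.array([[1.0], [-1.0]])             # input-to-hidden weights
b = np.zeros((2, 1))

def rnn_step(h_prev, x_t):
    """One application of the recurrence h_t = tanh(W_h h_{t-1} + W_x x_t + b)."""
    return np.tanh(W_h @ h_prev + W_x @ x_t + b)

h0 = np.zeros((2, 1))                        # initial hidden state
x1, x2 = np.array([[1.0]]), np.array([[0.5]])

h1 = rnn_step(h0, x1)   # depends only on x1
h2 = rnn_step(h1, x2)   # depends on x2 AND, through h1, on x1
print(h1.ravel(), h2.ravel())
```

Notice that x1's influence on h2 is entirely mediated by h1: the only way early inputs reach later states is through this chain of hidden states.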
Visual Representation
RNN Unrolled Through Time
Each hidden state h_t depends on the previous state h_{t-1}
The Sequential Processing Bottleneck
Critical Limitation: RNNs process sequences one element at a time, in order.
- To compute h_t, you must first compute h_{t-1} and h_{t-2}, all the way back to h_1
- No parallelization possible during training or inference
- Training time scales linearly with sequence length
- Cannot leverage modern GPU parallelism effectively
For a sequence of length n:
- RNN: O(n) sequential operations
- Transformer (spoiler): O(1) sequential steps—all positions processed in parallel
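A toy sketch makes the contrast concrete: the RNN update is an unavoidable Python-level loop with one dependent step per token, while attention-style mixing is a single matrix multiply over all positions at once. The similarity weights here are a simplified stand-in, not the full Transformer formula:

```python
import numpy as np

n, d = 6, 4
rng = np.random.default_rng(0)
X = rng.standard_normal((n, d))      # n tokens, d features each
W = rng.standard_normal((d, d))

# RNN-style: n dependent steps -- step t cannot start until step t-1 is done
h = np.zeros(d)
for t in range(n):                   # inherently sequential loop
    h = np.tanh(W @ h + X[t])

# Attention-style: every position mixes with every other, in one shot
scores = X @ X.T                                            # (n, n) pairwise scores
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
mixed = weights @ X                                          # one parallel matmul
print(mixed.shape)
```

The loop has n sequential dependencies; the matmul has none, which is exactly what GPUs exploit.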
Mental model: an RNN is like the telephone game—each step passes along a compressed summary of everything so far, and details degrade with every hand-off.
1.3 The Vanishing and Exploding Gradient Problem
What Happens During Backpropagation?
When training RNNs, gradients must flow backward through time. For a sequence of length T:

∂L/∂h_1 = ∂L/∂h_T · ∏_{t=2}^{T} ∂h_t/∂h_{t-1}

This involves multiplying many Jacobian matrices together.
Gradient Flow Through Time (BPTT)
Watch how gradients change as they flow backward through the network
Vanishing Gradient Problem
With ‖W_h‖ = 0.5, gradients shrink exponentially: 1.0 → 0.5 → 0.25 → ... → ≈ 0.
Effect: Early layers learn extremely slowly or not at all. Long-range dependencies are forgotten.
Why Does This Happen Mechanically?
Recall the RNN recurrence:

h_t = tanh(W_h h_{t-1} + W_x x_t + b)

During backpropagation, the derivative at each step is:

∂h_t/∂h_{t-1} = diag(tanh'(a_t)) · W_h

where a_t = W_h h_{t-1} + W_x x_t + b is the pre-activation.
Two problems immediately appear:
Problem 1: The tanh Derivative
The derivative of tanh is:

tanh'(x) = 1 − tanh²(x)

This derivative is always ≤ 1 and often much smaller:
| Activation Value | tanh(x) | Derivative tanh'(x) |
|---|---|---|
| Saturated high | 0.99 | 1 − 0.99² = 0.02 |
| Moderately high | 0.80 | 1 − 0.64 = 0.36 |
| Near zero | 0.20 | 1 − 0.04 = 0.96 |
Saturation kills gradients: once activations sit near ±1, tanh'(x) is nearly zero, and almost no gradient survives that step.
Problem 2: Weight Matrix Magnitudes
Even small deviations from 1 create exponential effects over time:
| Multiplier | After 30 steps | After 50 steps | After 100 steps |
|---|---|---|---|
| 0.9 | 0.042 | 0.005 | ≈ 0 |
| 0.95 | 0.215 | 0.077 | 0.006 |
| 1.05 | 4.32 | 11.47 | 131.5 |
| 1.1 | 17.45 | 117.39 | 13,780 |
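Repeated multiplication is the whole story—you can reproduce the table's numbers in two lines:

```python
# Exponential effect of a per-step multiplier over many time steps
for factor in (0.9, 0.95, 1.05, 1.1):
    row = [factor ** steps for steps in (30, 50, 100)]
    print(f"{factor}: " + ", ".join(f"{v:.3g}" for v in row))
```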
The core insight: the backward pass multiplies one such factor per time step, so any consistent deviation from 1 compounds exponentially over the sequence.
Vanishing Gradients: Symptoms
If the weight matrices have eigenvalues < 1, gradients shrink exponentially:
- Loss decreases extremely slowly or plateaus early
- Gradients become literally 0.0 in early time steps
- Early positions receive almost no learning signal
- Long-term dependencies are impossible to learn
Real-world example: In "The cat that sat on the mat that was red was happy", the verb "was" must agree with "cat" (singular). But with vanishing gradients, the model cannot learn this dependency—the gradient from "was" never reaches "cat".
The subtle danger: exploding gradients announce themselves with NaNs, but vanishing gradients fail silently—training appears to work while long-range structure is simply never learned.
Exploding Gradients: Symptoms
If eigenvalues > 1, gradients grow exponentially:
- Loss suddenly becomes NaN or infinity
- Weights become enormously large (1e⁸, 1e¹⁵, ...)
- Training is unstable—loss oscillates wildly
- Model outputs become nonsensical after a few steps
This happens because the weight update becomes massive: W ← W − η · ∂L/∂W, and when the gradient norm explodes, a single update can wipe out everything learned so far.
See It In Code: Gradient Decay Visualization
Here's a minimal PyTorch example that demonstrates gradient vanishing in a simple RNN. Run this to see how gradients decay:
```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class SimpleRNN(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x, hidden=None):
        out, hidden = self.rnn(x, hidden)
        return self.fc(out[:, -1, :]), hidden

# Create a sequence and track gradients at each position
seq_length = 100
model = SimpleRNN()
x = torch.randn(1, seq_length, 1, requires_grad=True)

# Forward pass
output, _ = model(x)
loss = output.sum()

# Backward pass - capture gradients
loss.backward()

# Measure gradient magnitude at each time step
# (This is a simplified proxy - actual gradient flow is more complex)
grad_norms = []
for t in range(seq_length):
    # Gradient of loss w.r.t. input at position t
    grad_norm = x.grad[0, t, 0].abs().item()
    grad_norms.append(grad_norm)

# Plot the decay
plt.figure(figsize=(10, 4))
plt.plot(range(seq_length), grad_norms)
plt.xlabel('Position in sequence (earlier → later)')
plt.ylabel('Gradient magnitude')
plt.title('Gradient Decay in RNN: Earlier positions receive weaker gradients')
plt.yscale('log')  # Log scale to see the exponential decay
plt.grid(True, alpha=0.3)
plt.savefig('gradient_decay.png')
print(f"Gradient at position 0: {grad_norms[0]:.6f}")
print(f"Gradient at position 99: {grad_norms[-1]:.6f}")
print(f"Decay ratio (last/first): {grad_norms[-1]/grad_norms[0]:.2f}x")
```

Expected output: the gradient at position 0 will be orders of magnitude smaller than at position 99, demonstrating the vanishing gradient problem visually.
Try this experiment: swap the nn.RNN for nn.LSTM and compare the gradient decay. You'll see that LSTMs maintain more uniform gradients—but the sequential bottleneck remains.

Practical Debugging: Detecting Gradient Problems
Here's how to diagnose gradient issues in your own RNN training:
Symptoms Checklist
| Symptom | Likely Cause | Quick Fix |
|---|---|---|
| Loss stuck at initial value | Vanishing gradients | Reduce sequence length, check initialization |
| Loss becomes NaN after N steps | Exploding gradients | Add gradient clipping, reduce learning rate |
| Early tokens never learned | Vanishing + truncated BPTT | Increase BPTT window, use LSTM/GRU |
| Training unstable (oscillates) | Exploding gradients | Gradient clipping, smaller learning rate |
| Works on short sequences, fails on long | Vanishing gradients | Switch architecture (Transformer) |
Monitoring Code
```python
def monitor_gradients(model, threshold_low=1e-7, threshold_high=1e3):
    """Call after loss.backward() to check gradient health."""
    total_norm = 0.0
    for name, param in model.named_parameters():
        if param.grad is not None:
            param_norm = param.grad.data.norm(2).item()
            total_norm += param_norm ** 2

            # Flag problematic gradients
            if param_norm < threshold_low:
                print(f"⚠️ VANISHING: {name} grad norm = {param_norm:.2e}")
            elif param_norm > threshold_high:
                print(f"🔥 EXPLODING: {name} grad norm = {param_norm:.2e}")

    total_norm = total_norm ** 0.5
    print(f"Total gradient norm: {total_norm:.4f}")
    return total_norm

# Usage in training loop:
loss.backward()
grad_norm = monitor_gradients(model)
if grad_norm > 10:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

TensorBoard integration: the same per-parameter norms can be streamed to TensorBoard with SummaryWriter.add_scalar inside the loop for a live view of gradient health across training.
Why Transformers Don't Have This Problem
Transformers fundamentally avoid this issue because they don't multiply through time:
- No recurrence: No chaining of derivatives over T steps
- Direct connections: Each attention layer connects tokens directly via attention weights (softmax over QKᵀ scores)
- Depth vs. sequence: A 24-layer Transformer has 24 gradient steps, not 512 like an RNN processing 512 tokens
- Residual connections: Skip connections allow gradients to flow unchanged
This is one of the key reasons Transformers can handle sequences of thousands of tokens while RNNs struggle with hundreds.
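To see the "direct connections" point concretely, here is a minimal single-head self-attention sketch in NumPy (random toy weights; no masking, heads, or layer norm). Every output row is a weighted sum over all input rows, so any two positions are one step apart in the computation graph:

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """Minimal single-head self-attention: softmax(QK^T / sqrt(d)) V."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])          # direct pairwise links
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)   # row-wise softmax
    return weights @ V                              # every output mixes every input

rng = np.random.default_rng(0)
n, d = 8, 16
X = rng.standard_normal((n, d))
Wq, Wk, Wv = [rng.standard_normal((d, d)) for _ in range(3)]
out = self_attention(X, Wq, Wk, Wv)
# Position 7's output depends on position 0 through ONE weighted sum,
# not through 7 chained recurrence steps.
print(out.shape)
```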
Mitigations you should know
Weight Initialization: The First Line of Defense
Proper initialization can delay (but not prevent) gradient problems. The goal: keep the eigenvalues of W_h close to 1.
Initialization Strategies for RNNs
| Strategy | How It Works | Best For |
|---|---|---|
| Xavier/Glorot | Scale by √(2 / (fan_in + fan_out)) | Feedforward layers, not RNNs |
| Orthogonal | W^T W = I (eigenvalues = 1) | RNN hidden-to-hidden weights |
| Identity | W = I (start as identity) | Simple RNNs, preserves gradients |
| LSTM bias init | Set forget gate bias to 1-2 | Encourages remembering early |
Why Orthogonal Initialization Works
For a matrix to preserve gradient norms, we want:

‖W x‖ = ‖x‖ for all x

Orthogonal matrices satisfy W^T W = I, which means ‖W x‖² = x^T W^T W x = ‖x‖². This preserves gradient norms exactly—at least initially.
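A quick numerical check confirms the claim: a matrix produced by QR decomposition is orthogonal and leaves vector norms untouched.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((128, 128))
Q, _ = np.linalg.qr(A)          # Q is orthogonal: Q.T @ Q = I

x = rng.standard_normal(128)
print(np.linalg.norm(x), np.linalg.norm(Q @ x))  # identical norms
```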
```python
import torch
import torch.nn as nn

# Method 1: PyTorch built-in
rnn = nn.RNN(input_size=64, hidden_size=128)
nn.init.orthogonal_(rnn.weight_hh_l0)  # Hidden-to-hidden weights

# Method 2: Manual orthogonal initialization
def orthogonal_init(shape):
    """Create an orthogonal matrix of given shape via QR decomposition."""
    flat_shape = (shape[0], max(1, shape[1] if len(shape) > 1 else 1))
    a = torch.randn(flat_shape)
    q, r = torch.linalg.qr(a)
    # Fix the signs of the diagonal of r so the result is unique
    d = torch.diag(r)
    ph = d.sign()
    q *= ph
    return q[:shape[0], :shape[1]] if len(shape) > 1 else q[:shape[0], 0]

# Method 3: LSTM forget gate bias trick
lstm = nn.LSTM(input_size=64, hidden_size=128)
# Set forget gate bias to 1.0 (encourages remembering)
# Bias is stored as [input_gate, forget_gate, cell_gate, output_gate]
nn.init.constant_(lstm.bias_ih_l0[128:256], 1.0)
nn.init.constant_(lstm.bias_hh_l0[128:256], 1.0)
```

The catch: orthogonality is only guaranteed at initialization—gradient updates gradually move W_h away from it, so this delays gradient problems rather than preventing them.
Truncated Backpropagation Through Time (TBPTT)
In practice, full backpropagation through an entire sequence is often infeasible. Truncated BPTT is a workaround that limits how far back gradients flow.
How It Works
Instead of backpropagating through all T time steps, we split the sequence into chunks and backpropagate only within each chunk:
| Parameter | Typical Value | Effect |
|---|---|---|
| k₁ (forward steps) | 35-100 | How many tokens to process before an update |
| k₂ (backward steps) | 35-100 | How far back gradients flow |
| k₁ = k₂ | Common choice | Simpler implementation, balanced updates |
The Trade-off
| k₂ Value | Memory | Long-Range Learning | Training Speed |
|---|---|---|---|
| Small (20-35) | Low | Poor - can't learn distant dependencies | Fast |
| Medium (50-100) | Moderate | Okay - captures paragraph-level context | Moderate |
| Large (200+) | High | Better - but still truncated | Slow |
| Full sequence | Very High | Best (in theory) - but gradients vanish anyway | Very Slow |
The fundamental limitation: truncation means dependencies longer than k₂ steps receive no gradient at all, so they cannot be learned no matter how long you train.
```python
# Typical PyTorch truncated BPTT pattern
k1, k2 = 50, 50  # forward and backward window sizes
hidden = None

for i in range(0, seq_len, k1):
    # Forward k1 steps
    chunk = sequence[i:i+k1]
    output, hidden = rnn(chunk, hidden)

    # Compute loss for this chunk
    loss = criterion(output, targets[i:i+k1])

    # Backprop flows at most k2 steps (it stops at the last detach)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # CRITICAL: Detach hidden state to truncate gradient flow
    hidden = hidden.detach()  # Gradient stops here!
```

When to use larger windows: if the task genuinely depends on long context (e.g., document-level language modeling) and memory allows, a larger k₂ helps—but the vanishing-gradient ceiling remains.
1.4 LSTMs and GRUs: Addressing Vanishing Gradients
Long Short-Term Memory (LSTM)
Hochreiter & Schmidhuber (1997) introduced LSTMs with a key innovation: gating mechanisms that control information flow.
LSTM Components
- Cell State (C_t): The "memory highway" that can carry information unchanged
- Forget Gate (f_t): Decides what to remove from the cell state
- Input Gate (i_t): Decides what new information to store
- Output Gate (o_t): Decides what to output
Why LSTMs Help
The cell state acts as a gradient highway:
- When f_t ≈ 1, gradients flow through unchanged
- Mitigates vanishing gradient problem
- Can maintain information over hundreds of time steps
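A toy scalar simulation shows the difference between the two gradient paths. The per-step factors 0.5 and 0.99 are illustrative stand-ins for tanh-squashing in a vanilla RNN and a near-open forget gate in an LSTM:

```python
# How much of the step-0 signal survives 100 steps?
steps = 100

# Vanilla RNN path: signal squashed by roughly tanh'(x) * w each step (~0.5 here)
rnn_signal = 1.0
for _ in range(steps):
    rnn_signal *= 0.5

# LSTM cell-state path: multiplied only by the forget gate f_t (~0.99 here)
lstm_signal = 1.0
for _ in range(steps):
    lstm_signal *= 0.99

print(f"RNN:  {rnn_signal:.2e}")   # effectively zero
print(f"LSTM: {lstm_signal:.2f}")  # still a usable fraction of the signal
```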
Gated Recurrent Units (GRUs)
Cho et al. (2014) proposed a simplified variant:
- Fewer parameters than LSTM
- Often comparable performance
- Faster to train
Gate intuition: each gate is a sigmoid output in (0, 1) acting as a learned valve—0 blocks information entirely, 1 lets it pass unchanged.
LSTM vs GRU: quick guide
| Use case | Pick LSTM when... | Pick GRU when... |
|---|---|---|
| Long context | You need stronger control over what to forget/keep | Context is short-to-medium |
| Compute budget | You can afford more parameters | You need faster training/inference |
| Data size | Plenty of data to fit extra params | Data is smaller or noisier |
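The parameter difference can be computed from the gate structure alone. The formulas below mirror PyTorch's nn.LSTM / nn.GRU parameter layout (input weights, hidden weights, and two bias vectors per gate block):

```python
def lstm_params(i, h):
    # 4 gate blocks (input, forget, cell, output), each with
    # input weights (h*i), hidden weights (h*h), and two biases (2h)
    return 4 * (h * i + h * h + 2 * h)

def gru_params(i, h):
    # Same layout, but only 3 gate blocks (reset, update, candidate)
    return 3 * (h * i + h * h + 2 * h)

print(lstm_params(64, 128), gru_params(64, 128))  # GRU is exactly 3/4 the size
```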
1.5 Remaining Limitations
Despite their improvements, LSTMs and GRUs still suffer from fundamental limitations:
1. Sequential Processing
Both architectures still process sequences step-by-step:
- Cannot parallelize across time steps
- Training time still scales as O(n) in sequence length
- GPU utilization is poor
2. Limited Long-Range Modeling
Even with gates, information must still pass through many steps:
- Effective context window is limited (typically 100-300 tokens)
- Very long documents remain challenging
- Information gets diluted over distance
3. Fixed Computation Per Step
Every step uses the same computation regardless of:
- How relevant the current token is
- How much context is needed
- Whether the task requires local or global information
4. Difficulty with Bidirectional Context
Understanding text often requires both past and future context. Consider:
"The bank was steep" - Is this a river bank or a financial institution? You need future context ("steep") to disambiguate.
Bidirectional RNNs (BiRNNs)
The solution was to run two RNNs: one forward, one backward, then concatenate their hidden states:
| Property | Forward RNN | BiRNN | Transformer |
|---|---|---|---|
| Sees past context | Yes | Yes | Yes |
| Sees future context | No | Yes | Yes |
| Computation cost | O(n) | O(2n) | O(n²) but parallel |
| Parallelizable | No | No | Yes |
| Direct long-range links | No | No | Yes |
Why BiRNNs still aren't enough:
- Still sequential: You must wait for both passes to complete
- Separate passes: Forward and backward don't directly interact during processing
- No direct token connections: To relate token 1 to token 100, information still flows through all intermediate states in both directions
- Double memory: Hidden states from both directions must be stored
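The two-pass scheme can be sketched in a few lines of NumPy (toy random weights; a real implementation would use nn.RNN(bidirectional=True)):

```python
import numpy as np

def run_rnn(X, W_h, W_x):
    """Simple tanh RNN over the rows of X; returns all hidden states."""
    h = np.zeros(W_h.shape[0])
    states = []
    for x_t in X:
        h = np.tanh(W_h @ h + W_x @ x_t)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
n, d_in, d_h = 5, 3, 4
X = rng.standard_normal((n, d_in))
Wf_h, Wf_x = rng.standard_normal((d_h, d_h)) * 0.1, rng.standard_normal((d_h, d_in))
Wb_h, Wb_x = rng.standard_normal((d_h, d_h)) * 0.1, rng.standard_normal((d_h, d_in))

fwd = run_rnn(X, Wf_h, Wf_x)               # left-to-right pass
bwd = run_rnn(X[::-1], Wb_h, Wb_x)[::-1]   # right-to-left pass, re-aligned
bi = np.concatenate([fwd, bwd], axis=1)    # (n, 2*d_h): both contexts per token
print(bi.shape)
```

Note that the two passes never exchange information: token t's representation is a concatenation, not an interaction, of its left and right contexts.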
1.6 The Encoder-Decoder Bottleneck
Before attention mechanisms, the dominant approach for sequence-to-sequence tasks (like machine translation) was the encoder-decoder architecture. While revolutionary, it had a critical flaw.
The Seq2Seq Architecture (Sutskever et al., 2014)
The basic idea was elegant: use one RNN (the encoder) to read the input sequence and compress it into a fixed-size vector, then use another RNN (the decoder) to generate the output sequence from that vector.
The Encoder-Decoder Bottleneck
All source information must squeeze through a single fixed-size vector
Why This Is a Problem
Imagine trying to translate a complex legal document or a technical paper. The entire meaning of potentially thousands of words must be compressed into a vector of, say, 256 or 512 dimensions. This creates several issues:
- Information Compression: Early tokens in long sequences get "overwritten" as the encoder processes more tokens. By the time you reach the end, the beginning is a faint echo.
- Fixed Capacity: The context vector has the same size regardless of whether you're encoding "Hello" or "War and Peace." Clearly, more information requires more capacity.
- No Selective Access: When the decoder generates word 50, it might need to specifically look at word 3 of the input. But it can only see the blended context vector—it cannot selectively attend to specific parts of the input.
Empirical Evidence
Researchers observed a sharp performance cliff:
| Sentence Length | BLEU Score (approx.) | Quality |
|---|---|---|
| 5-10 words | 35-40 | Good translations |
| 15-20 words | 28-32 | Acceptable |
| 25-30 words | 20-25 | Noticeable errors |
| 40+ words | < 15 | Often incoherent |
The Bottleneck Insight: It's not that RNNs are bad at processing sequences—it's that forcing all information through a single fixed-size vector is fundamentally limiting. What if the decoder could look back at any encoder state, not just the final compressed one?
This motivated attention: let the decoder look back at every encoder state and learn, per output token, which ones matter.
1.7 The Path to Attention
The Core Insight
The key limitation of RNNs: indirect access to distant elements.
To relate the first and last word of a sentence:
- RNN: Information must pass through ALL intermediate hidden states
- What we want: Direct connection between ANY two positions
Bahdanau Attention (2014)
The first attention mechanism for sequence-to-sequence models:
Key insight: Create shortcuts in the computation graph!
```
Target word to generate: "bank" (river sense)
Encoder states: [h_river, h_flowing, h_fast]
Alignment scores (dot products): [2.1, 0.2, 0.1]
Softmax -> attention weights: [0.78, 0.12, 0.11]  (rounded)
Context vector: 0.78*h_river + 0.12*h_flowing + 0.11*h_fast
```

What the weights mean: each weight is the fraction of the context vector contributed by that encoder state—here the decoder "looks at" h_river most when generating "bank".
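The worked numbers are easy to verify in code (the encoder states here are made-up 2-d vectors for illustration):

```python
import numpy as np

def attention_weights(scores):
    """Numerically stable softmax over a 1-d array of alignment scores."""
    e = np.exp(scores - scores.max())
    return e / e.sum()

scores = np.array([2.1, 0.2, 0.1])   # alignment scores from the example
weights = attention_weights(scores)
print(weights.round(2))              # heavily favors the first state

# Context vector: attention-weighted sum of the encoder states
H = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])  # h_river, h_flowing, h_fast
context = weights @ H
print(context)
```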
Luong Attention (2015): Simplifying the Score
Luong et al. proposed several simplifications that would directly influence the Transformer's design:
Scoring Functions
| Name | Formula | Complexity |
|---|---|---|
| Bahdanau (additive) | v^T tanh(W₁h_j + W₂s_i) | O(d²) per pair |
| Luong (dot) | h_j^T s_i | O(d) per pair |
| Luong (general) | h_j^T W s_i | O(d²) per pair |
| Luong (concat) | v^T tanh(W[h_j; s_i]) | O(d²) per pair |
The key insight: Simple dot-product attention (h_jᵀ s_i) works nearly as well as the more complex additive variant—and it's much faster because it's just a matrix multiply.
Global vs. Local Attention
- Global attention: Attend to ALL encoder positions (like Bahdanau)
- Local attention: Attend only to a window around an aligned position
Local attention was an attempt to reduce the cost, foreshadowing later work on efficient Transformers (Longformer, BigBird).
Direct link to Transformers: the Transformer's scaled dot-product attention is essentially Luong's dot score with a 1/√d_k scaling factor added for stability.
Benefits of Attention
- Direct connections: Any position can attend to any other
- Interpretability: Attention weights show what the model "looks at"
- Variable context: Different queries access different information
Limitations of Early Attention
- Still used RNNs as the backbone
- Attention was a "add-on" to sequential processing
- The question emerged: What if we use ONLY attention?
Alternative Approaches (The Road Not Taken)
While attention was gaining traction, researchers explored other ways to solve the parallelization problem. These approaches are worth knowing because they influenced Transformer design and remain useful in specific domains.
1. Convolutional Models for Sequences
CNNs process data in parallel by design. Several architectures adapted them for sequences:
| Model | Year | Key Idea | Limitation |
|---|---|---|---|
| WaveNet | 2016 | Dilated causal convolutions for audio | Very deep for long range |
| ByteNet | 2016 | Encoder-decoder with dilated convs | Still O(log n) depth for range n |
| ConvSeq2Seq | 2017 | Fully convolutional translation | Required many layers for long context |
Why CNNs didn't win: While parallelizable, they need many stacked layers to capture long-range dependencies (receptive field grows logarithmically with depth). A 512-token dependency requires ~9 layers of dilated convolutions vs. a single attention layer.
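The receptive-field arithmetic is easy to check for a WaveNet-style stack (kernel size 2, dilation doubling each layer: 1, 2, 4, ...):

```python
# Receptive field of stacked kernel-size-2 causal convolutions with
# dilation doubling per layer, as in WaveNet-style architectures.
def receptive_field(layers, kernel=2):
    rf = 1
    for layer in range(layers):
        rf += (kernel - 1) * 2 ** layer   # each layer adds dilation * (kernel - 1)
    return rf

for layers in (1, 5, 9, 10):
    print(layers, receptive_field(layers))
```

Nine layers reach exactly 512 positions—versus a single attention layer, which reaches every position directly.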
2. Memory Networks & Neural Turing Machines
These architectures added explicit external memory that the model could read from and write to:
- Memory Networks (Weston, 2014): External memory + attention-based retrieval
- Neural Turing Machines (Graves, 2014): Differentiable tape with read/write heads
- Differentiable Neural Computers (Graves, 2016): More sophisticated memory management
Why they didn't win: Complex to train, hard to scale, and the memory addressing mechanisms introduced their own bottlenecks. However, the idea of attending to stored representations directly influenced the key-value attention in Transformers.
1.8 The Stage is Set
By 2017, the NLP community recognized:
- Sequential processing is a bottleneck - We need parallelization
- Long-range dependencies are hard - Direct connections help
- Attention mechanisms work - They capture relationships effectively
- Hardware is evolving - GPUs excel at parallel matrix operations
The Hardware Reality: GPU Utilization
The shift to Transformers wasn't just about model quality—it was about hardware efficiency. Here's what researchers were observing:
| Model | GPU Utilization | Tokens/sec (V100) | WMT Training Time |
|---|---|---|---|
| LSTM (seq) | 15-25% | ~8,000 | 2-3 weeks |
| LSTM (optimized) | 30-40% | ~15,000 | 1-2 weeks |
| ConvSeq2Seq | 50-60% | ~25,000 | 4-5 days |
| Transformer | 80-95% | ~50,000+ | 12-36 hours |
Why Transformers utilize GPUs better:
- Matrix multiplication: Attention is essentially batched matrix multiplies (QKᵀ, then softmax, then multiply by V)—exactly what GPUs are optimized for
- No sequential dependencies: All positions can be computed in parallel within a layer
- Predictable memory access: Dense operations with regular access patterns, unlike RNN's scattered hidden state updates
- Tensor cores: Modern GPUs have specialized hardware for the exact operations Transformers need
The economic argument: the same hardware budget buys an order of magnitude more training, which compounds into bigger models and faster iteration.
The question was: Can we build a sequence model using only attention, no recurrence?
The answer came in June 2017 with "Attention Is All You Need" - the Transformer architecture.
The Paper That Changed Everything
"Attention Is All You Need" (Vaswani et al., 2017) was published at NeurIPS 2017 by a team at Google Brain and Google Research. The title itself was provocative—a direct challenge to the RNN orthodoxy.
The Authors
The paper had eight authors, several of whom have become highly influential:
- Ashish Vaswani - Lead author, now at Essential AI
- Noam Shazeer - Co-inventor of many key techniques, co-founded Character.AI
- Niki Parmar - Research scientist, contributed to architecture design
- Jakob Uszkoreit - Co-founded Inceptive (RNA design)
- Llion Jones - Co-founded Sakana AI
- Aidan Gomez - Co-founded Cohere
- Łukasz Kaiser - Key contributor to Tensor2Tensor, now at OpenAI
- Illia Polosukhin - Co-founded NEAR Protocol
The diaspora effect: all eight authors have since left Google, going on to found or lead some of the most prominent companies in AI.
Initial Reception
The paper's reception was mixed at first:
| Reaction | Argument |
|---|---|
| Skeptical | "RNNs have inductive biases for sequences—you need recurrence for temporal modeling" |
| Skeptical | "O(n²) attention won't scale to long sequences" |
| Impressed | "The translation results are state-of-the-art with much less training time" |
| Curious | "If this works for translation, what else might it work for?" |
Within a year, BERT (2018) and GPT (2018) would answer that last question definitively: everything.
The provocative title: By claiming "Attention Is All You Need," the authors were making a bold statement: the inductive biases we thought were necessary (recurrence, convolution) were actually holding us back. The simplest architecture—just attention and feedforward layers—worked best.
Design checklist for the next step: direct connections between any two positions, full parallelism within a layer, and computation expressed as GPU-friendly matrix multiplies.
Summary
| Architecture | Long-Range | Parallelizable | Key Innovation |
|---|---|---|---|
| Vanilla RNN | Poor | No | Hidden state |
| LSTM | Better | No | Gating mechanisms |
| GRU | Better | No | Simplified gates |
| Transformer | Excellent | Yes | Self-attention |
Key Takeaways
- RNNs introduced sequential hidden states but suffer from vanishing gradients and slow training
- LSTMs/GRUs added gating mechanisms that help with gradient flow but don't solve the parallelization problem
- Attention mechanisms enable direct connections between any positions in a sequence
- The transformer removes recurrence entirely, using only attention for sequence modeling
Exercises
Conceptual Questions
- Why can't RNNs be parallelized across time steps during training?
- Explain intuitively why gradients vanish when multiplying many small numbers together.
- How does the cell state in an LSTM act as a "gradient highway"?
- What is the computational complexity of processing a sequence of length n with an RNN vs. a Transformer?
Thought Experiments
- Consider the sentence: "The trophy doesn't fit in the suitcase because it is too big."
- What does "it" refer to?
- Why is this hard for an RNN to model?
- If you had to design a neural network for sequence modeling from scratch, knowing what you know now, what properties would you want it to have?
Hands-on Drills
- Implement a two-layer RNN in NumPy/PyTorch and log hidden states over a toy sequence; visualize where information decays.
- Add gradient clipping to that RNN and compare loss curves with and without clipping.
- Write a tiny Bahdanau attention over a 3-token source and 2-token target; print the attention weights to see which source positions are favored.
Further Reading
- Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation.
- Cho, K., et al. (2014). Learning phrase representations using RNN encoder-decoder.
- Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate.
- Pascanu, R., Mikolov, T., & Bengio, Y. (2013). On the difficulty of training recurrent neural networks.
In the next section, we'll dive into the Transformer architecture itself, understanding how it eliminates recurrence entirely and achieves state-of-the-art performance across virtually all sequence modeling tasks.