Introduction
While attention allows tokens to communicate with each other, the Feed-Forward Network (FFN) transforms each token's representation independently. It's the "thinking" component where the model processes and enriches the information gathered by attention.
This section covers the FFN architecture, its purpose, and implementation.
1.1 The Role of FFN in Transformers
What Attention Doesn't Do
Self-attention is powerful for gathering information, but:
- It's essentially a linear operation (weighted sum of values)
- No non-linear transformation of individual tokens
- Limited capacity for complex per-token computations
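The "weighted sum" point can be seen directly: once the attention weights are fixed, the output is linear in the values, so no new per-token non-linearity is introduced. A minimal sketch (toy shapes, for illustration only):

```python
import torch

torch.manual_seed(0)
V = torch.randn(4, 8)                     # 4 tokens, value dim 8
w = torch.softmax(torch.randn(4), dim=0)  # attention weights for one query

out = w @ V  # the attention output: a weighted sum of value rows

# Linear in V: scaling the values scales the output identically.
# No non-linear transformation of the individual tokens has happened.
assert torch.allclose(w @ (2 * V), 2 * out)
```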
What FFN Adds
The FFN provides:
- Non-linearity: ReLU/GELU activation for complex patterns
- Per-token transformation: Each position processed independently
- Capacity expansion: Higher dimensional intermediate representation
- Local processing: Complement to attention's global mixing
The Transformer Block Pattern
```
┌───────────────────────────────────────┐
│     Input (tokens with positions)     │
└───────────────┬───────────────────────┘
                │
                ▼
┌───────────────────────────────────────┐
│     Multi-Head Self-Attention         │  ← Global mixing
│     (tokens communicate)              │
└───────────────┬───────────────────────┘
                │
                ▼
┌───────────────────────────────────────┐
│     Feed-Forward Network              │  ← Local transformation
│     (each token transformed)          │
└───────────────┬───────────────────────┘
                │
                ▼
┌───────────────────────────────────────┐
│   Output (enriched representations)   │
└───────────────────────────────────────┘
```

1.2 FFN Architecture
The Formula
From the original Transformer paper:
```
FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
       = ReLU(xW₁ + b₁)W₂ + b₂
```

More generally, with any activation:

```
FFN(x) = Linear₂(Activation(Linear₁(x)))
```

Dimensions
The FFN expands then contracts:
```
Input:      [batch, seq_len, d_model]   (e.g., 512)
                │
                ▼
Linear₁:    [batch, seq_len, d_ff]      (e.g., 2048 = 4 × 512)
                │
                ▼
Activation: [batch, seq_len, d_ff]      (same)
                │
                ▼
Linear₂:    [batch, seq_len, d_model]   (back to 512)
```

Why Expand Then Contract?
Expansion (d_model β d_ff):
- Creates higher-dimensional space
- More neurons = more patterns can be learned
- Acts like a "wide" hidden layer
Contraction (d_ff β d_model):
- Returns to original dimension
- Allows residual connection (same shapes)
- Compresses learned information
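The expand-then-contract shapes can be checked directly in PyTorch (using the example dimensions above):

```python
import torch
import torch.nn as nn

d_model, d_ff = 512, 2048
linear1 = nn.Linear(d_model, d_ff)   # expansion
linear2 = nn.Linear(d_ff, d_model)   # contraction

x = torch.randn(2, 10, d_model)      # [batch, seq_len, d_model]
hidden = torch.relu(linear1(x))      # [2, 10, 2048]
out = linear2(hidden)                # [2, 10, 512]

assert out.shape == x.shape          # same shape, so a residual connection works
```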
Typical Ratio
```
d_ff = 4 × d_model

Examples:
- d_model = 512  → d_ff = 2048
- d_model = 768  → d_ff = 3072
- d_model = 1024 → d_ff = 4096
```

Some models use different ratios (e.g., 8/3 for GLU variants, which keeps the total parameter count comparable despite the extra gate projection).
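As a concrete example of the 8/3 ratio: gated models typically round d_ff up to a hardware-friendly multiple (LLaMA rounds to a multiple of 256; the exact rounding rule varies by model). A sketch of that sizing rule:

```python
def glu_d_ff(d_model: int, ratio: float = 8 / 3, multiple_of: int = 256) -> int:
    """Round d_model * ratio up to the nearest multiple of `multiple_of`."""
    d_ff = int(d_model * ratio)
    return multiple_of * ((d_ff + multiple_of - 1) // multiple_of)

print(glu_d_ff(4096))  # 11008, matching LLaMA-7B's published intermediate size
```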
1.3 "Position-wise" Explained
What "Position-wise" Means
The same FFN is applied to each position independently:
```
Position 0: FFN(x₀) = y₀
Position 1: FFN(x₁) = y₁
Position 2: FFN(x₂) = y₂
...

All positions use the SAME weights W₁, b₁, W₂, b₂!
```

Visual Representation
```
Input:   [Token₀] [Token₁] [Token₂] [Token₃]
            │        │        │        │
            ▼        ▼        ▼        ▼
FFN:      [FFN]    [FFN]    [FFN]    [FFN]    (same weights!)
            │        │        │        │
            ▼        ▼        ▼        ▼
Output:  [Out₀]   [Out₁]   [Out₂]   [Out₃]
```

Implementation Efficiency
We don't need a loopβmatrix multiplication handles all positions:
```python
# x shape:  [batch, seq_len, d_model]
# W1 shape: [d_model, d_ff]

# This applies W1 to ALL positions at once:
hidden = x @ W1  # [batch, seq_len, d_ff]

# Same for all batch items AND all positions!
```

Why Same Weights?
Three key reasons:
- Parameter efficiency: One FFN serves all positions
- Translation invariance: Same transformation regardless of position
- Generalization: Works for any sequence length
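Position-independence is easy to verify: permuting the tokens before the FFN and permuting the outputs after gives identical results, because each position is transformed by the same weights. A small check (toy dimensions, chosen for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# A minimal FFN: expand, activate, contract
ffn = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))

x = torch.randn(1, 5, 8)              # [batch, seq_len, d_model]
perm = torch.tensor([4, 2, 0, 3, 1])  # shuffle the sequence

out_then_perm = ffn(x)[:, perm]       # apply FFN, then permute outputs
perm_then_out = ffn(x[:, perm])       # permute inputs, then apply FFN

# Identical: the FFN never looks at a token's position.
assert torch.allclose(out_then_perm, perm_then_out, atol=1e-6)
```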
1.4 Activation Functions
ReLU (Original Transformer)
```python
def relu(x):
    return torch.clamp(x, min=0)  # elementwise max(0, x)
    # or torch.relu(x)
```

Properties:
- Simple and fast
- Sparse activation (many zeros)
- Can cause "dead neurons" (always zero)
GELU (Modern Choice)
Gaussian Error Linear Unit:
```python
def gelu(x):
    # Tanh approximation of x * Φ(x), where Φ is the CDF of the standard normal
    return 0.5 * x * (1 + torch.tanh(
        math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)
    ))
    # or torch.nn.functional.gelu(x) for the exact erf-based version
```

Properties:
- Smooth (differentiable everywhere)
- No dead neurons
- Slightly better performance in practice
- Used by BERT, GPT-2, etc.
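The tanh formula shown above is an approximation of the exact erf-based GELU; the two agree closely over the whole input range (recent PyTorch versions expose both through `F.gelu`'s `approximate` argument):

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, 1001)

exact = F.gelu(x)  # exact, erf-based GELU
approx = 0.5 * x * (1 + torch.tanh(
    math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)
))

max_err = (exact - approx).abs().max().item()
assert max_err < 1e-2  # the approximation is close everywhere
```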
SiLU/Swish
```python
def silu(x):
    return x * torch.sigmoid(x)
    # or torch.nn.functional.silu(x)
```

Used by: LLaMA, PaLM (as the gate activation in SwiGLU)
Comparison
```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

x = torch.linspace(-3, 3, 100)

plt.figure(figsize=(10, 6))
plt.plot(x, F.relu(x), label='ReLU')
plt.plot(x, F.gelu(x), label='GELU')
plt.plot(x, F.silu(x), label='SiLU/Swish')
plt.legend()
plt.grid(True)
plt.title('Activation Functions')
plt.savefig('activations.png')
```

1.5 Implementation
Basic FFN
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.

    Applies two linear transformations with an activation in between,
    independently to each position in the sequence.

    FFN(x) = Activation(xW₁ + b₁)W₂ + b₂

    Args:
        d_model: Input/output dimension
        d_ff: Hidden layer dimension (typically 4 * d_model)
        dropout: Dropout probability
        activation: Activation function ('relu', 'gelu', 'silu')

    Example:
        >>> ffn = PositionwiseFeedForward(d_model=512, d_ff=2048)
        >>> x = torch.randn(2, 10, 512)  # [batch, seq_len, d_model]
        >>> output = ffn(x)              # [batch, seq_len, d_model]
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "relu"
    ):
        super().__init__()

        self.d_model = d_model
        self.d_ff = d_ff

        # Two linear transformations
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Activation function
        self.activation = self._get_activation(activation)

    def _get_activation(self, activation: str):
        """Get activation function by name."""
        activations = {
            "relu": F.relu,
            "gelu": F.gelu,
            "silu": F.silu,
        }
        if activation not in activations:
            raise ValueError(f"Unknown activation: {activation}")
        return activations[activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply position-wise feed-forward transformation.

        Args:
            x: Input tensor [batch, seq_len, d_model]

        Returns:
            Output tensor [batch, seq_len, d_model]
        """
        # First linear + activation
        hidden = self.activation(self.linear1(x))
        # hidden: [batch, seq_len, d_ff]

        # Dropout on hidden layer
        hidden = self.dropout(hidden)

        # Second linear
        output = self.linear2(hidden)
        # output: [batch, seq_len, d_model]

        return output

    def extra_repr(self) -> str:
        return f"d_model={self.d_model}, d_ff={self.d_ff}"


# Test
def test_ffn():
    d_model = 512
    d_ff = 2048
    batch_size = 2
    seq_len = 10

    ffn = PositionwiseFeedForward(d_model, d_ff, dropout=0.1, activation="gelu")

    x = torch.randn(batch_size, seq_len, d_model)
    output = ffn(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"\nFFN Architecture:")
    print(f"  Linear1: {d_model} → {d_ff}")
    print(f"  Activation: GELU")
    print(f"  Dropout: 0.1")
    print(f"  Linear2: {d_ff} → {d_model}")
    print(f"\nParameter count: {sum(p.numel() for p in ffn.parameters()):,}")

    # Verify shapes
    assert output.shape == x.shape
    print("\n✓ FFN test passed!")


test_ffn()
```

Output:
```
Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])

FFN Architecture:
  Linear1: 512 → 2048
  Activation: GELU
  Dropout: 0.1
  Linear2: 2048 → 512

Parameter count: 2,099,712

✓ FFN test passed!
```

1.6 Advanced: Gated Linear Units (GLU)
GLU Variants
Modern models often use gated variants that split the hidden dimension:
```python
class GatedFFN(nn.Module):
    """
    Gated Feed-Forward Network (used in PaLM, LLaMA).

    Instead of: Activation(xW₁)W₂
    Uses:       (xW₁ ⊙ Activation(xW_gate))W₂

    Where ⊙ is element-wise multiplication.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "silu"
    ):
        super().__init__()

        # Gate and up projections (can be combined for efficiency)
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.activation = F.silu if activation == "silu" else F.gelu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            [batch, seq_len, d_model]
        """
        # Gate mechanism
        gate = self.activation(self.gate_proj(x))
        up = self.up_proj(x)

        # Element-wise gating
        hidden = gate * up
        hidden = self.dropout(hidden)

        # Down projection
        output = self.down_proj(hidden)

        return output


# Test GLU variant
def test_gated_ffn():
    d_model = 512
    d_ff = 2048  # Note: three projections, so the parameter count differs from a plain FFN

    gated_ffn = GatedFFN(d_model, d_ff, dropout=0.0)

    x = torch.randn(2, 10, d_model)
    output = gated_ffn(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Parameter count: {sum(p.numel() for p in gated_ffn.parameters()):,}")

    print("\n✓ Gated FFN test passed!")


test_gated_ffn()
```

1.7 Understanding the FFN's Role
What Does FFN Learn?
Research suggests FFN layers act as key-value memories:
```
Linear1 (keys):    Pattern detectors - "Is this a verb?"
Hidden (values):   Associated information - "Verb properties"
Linear2 (output):  Combine into the output representation
```

FFN vs Attention Comparison
| Aspect | Attention | FFN |
|---|---|---|
| Scope | Global (all tokens) | Local (single token) |
| Nature | Linear combination | Non-linear transformation |
| Parameters | Fewer | More (typically 2/3 of layer) |
| Role | Information routing | Information processing |
Parameter Distribution
In a typical transformer layer:
- Attention: ~1/3 of parameters
- FFN: ~2/3 of parameters
```python
def count_parameters(d_model, d_ff):
    """Count parameters in attention vs FFN."""
    # Attention: W_Q, W_K, W_V, W_O (the number of heads doesn't change the count)
    attention_params = 4 * d_model * d_model

    # FFN: W1, b1, W2, b2
    ffn_params = 2 * d_model * d_ff + d_model + d_ff

    total = attention_params + ffn_params

    print(f"Attention parameters: {attention_params:,} ({attention_params/total*100:.1f}%)")
    print(f"FFN parameters: {ffn_params:,} ({ffn_params/total*100:.1f}%)")
    print(f"Total: {total:,}")


count_parameters(d_model=512, d_ff=2048)
```

Output:

```
Attention parameters: 1,048,576 (33.3%)
FFN parameters: 2,099,712 (66.7%)
Total: 3,148,288
```

1.8 Variations and Extensions
MoE (Mixture of Experts)
Instead of one FFN, use multiple "expert" FFNs:
```python
class MoEFFN(nn.Module):
    """
    Mixture of Experts FFN (simplified).

    Routes each token to top-k experts based on a learned gating function.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2
    ):
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k

        # Expert FFNs
        self.experts = nn.ModuleList([
            PositionwiseFeedForward(d_model, d_ff)
            for _ in range(num_experts)
        ])

        # Router
        self.router = nn.Linear(d_model, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts."""
        batch, seq_len, d_model = x.shape

        # Get routing weights
        router_logits = self.router(x)  # [batch, seq_len, num_experts]
        routing_weights = F.softmax(router_logits, dim=-1)

        # Select top-k experts and renormalize their weights
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Apply experts (simplified - real implementations batch tokens per expert)
        output = torch.zeros_like(x)
        for k in range(self.top_k):
            for e in range(self.num_experts):
                mask = (top_k_indices[..., k] == e)
                if mask.any():
                    expert_input = x[mask]  # [n_selected, d_model]
                    expert_output = self.experts[e](expert_input.unsqueeze(0)).squeeze(0)
                    output[mask] += top_k_weights[mask, k:k+1] * expert_output

        return output
```

Summary
FFN Architecture
```
Input [batch, seq_len, d_model]
        │
        ▼
Linear1 (d_model → d_ff)
        │
        ▼
Activation (ReLU/GELU)
        │
        ▼
Dropout
        │
        ▼
Linear2 (d_ff → d_model)
        │
        ▼
Output [batch, seq_len, d_model]
```

Key Points
| Property | Value |
|---|---|
| Typical d_ff | 4 Γ d_model |
| Activation | GELU (modern) or ReLU (original) |
| Parameters | ~2/3 of transformer layer |
| Processing | Position-independent (same weights) |
Implementation Checklist
- Two linear layers with dimension expansion/contraction
- Activation function between layers
- Dropout on hidden layer
- Shape preservation: input and output same dimension
Exercises
Implementation Exercises
1. Implement an FFN with configurable number of hidden layers (not just 2).
2. Add layer normalization inside the FFN (some models do this).
3. Implement a "bottleneck" FFN where d_ff < d_model.
Analysis Exercises
4. Compare training with ReLU vs GELU vs SiLU. Which converges faster?
5. Visualize the activation patterns in a trained FFN. What patterns emerge?
6. Experiment with different d_ff ratios (2Γ, 4Γ, 8Γ). How does it affect performance?
Conceptual Questions
7. Why is the FFN applied "position-wise" rather than across all positions?
8. What would happen if we removed the FFN entirely? How would the model behave?
Next Section Preview
In the next section, we'll dive deep into Layer Normalizationβthe normalization technique that makes transformer training stable. We'll understand why it's preferred over batch normalization and implement it from scratch.