Chapter 6

Position-wise Feed-Forward Networks

Feed Forward and Normalization

Introduction

While attention allows tokens to communicate with each other, the Feed-Forward Network (FFN) transforms each token's representation independently. It's the "thinking" component where the model processes and enriches the information gathered by attention.

This section covers the FFN architecture, its purpose, and implementation.


1.1 The Role of FFN in Transformers

What Attention Doesn't Do

Self-attention is powerful for gathering information, but:

  • Its output for each token is just a weighted sum of value vectors (a mixing operation)
  • It applies no non-linear transformation to individual token representations
  • It has limited capacity for complex per-token computations
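This is easy to verify: without an activation between them, two stacked linear layers collapse into a single linear layer, so depth alone adds no expressive power. A minimal sketch with random weights:

```python
import torch

torch.manual_seed(0)
d = 8
W1 = torch.randn(d, 4 * d)
W2 = torch.randn(4 * d, d)
x = torch.randn(3, d)

# Two stacked linear maps with no activation in between...
y_stacked = (x @ W1) @ W2

# ...are exactly one linear map with weight W1 @ W2:
y_single = x @ (W1 @ W2)

assert torch.allclose(y_stacked, y_single, atol=1e-4)
```

The activation between the FFN's two layers is what breaks this collapse.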

What FFN Adds

The FFN provides:

  • Non-linearity: ReLU/GELU activation for complex patterns
  • Per-token transformation: Each position processed independently
  • Capacity expansion: Higher dimensional intermediate representation
  • Local processing: Complement to attention's global mixing

The Transformer Block Pattern

πŸ“text
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Input (tokens with positions)        β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                 β”‚
                 β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Multi-Head Self-Attention            β”‚  ← Global mixing
β”‚  (tokens communicate)                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                 β”‚
                 β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Feed-Forward Network                 β”‚  ← Local transformation
β”‚  (each token transformed)             β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                 β”‚
                 β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Output (enriched representations)    β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

1.2 FFN Architecture

The Formula

From the original Transformer paper:

πŸ“text
FFN(x) = max(0, xW₁ + b₁)Wβ‚‚ + bβ‚‚
       = ReLU(xW₁ + b₁)Wβ‚‚ + bβ‚‚

More generally:

πŸ“text
FFN(x) = Linearβ‚‚(Activation(Linear₁(x)))

Dimensions

The FFN expands then contracts:

πŸ“text
Input:       [batch, seq_len, d_model]    (e.g., 512)
             β”‚
             β–Ό
Linear₁:     [batch, seq_len, d_ff]       (e.g., 2048 = 4 Γ— 512)
             β”‚
             β–Ό
Activation:  [batch, seq_len, d_ff]       (same)
             β”‚
             β–Ό
Linearβ‚‚:     [batch, seq_len, d_model]    (back to 512)

Why Expand Then Contract?

Expansion (d_model β†’ d_ff):

  • Creates higher-dimensional space
  • More neurons = more patterns can be learned
  • Acts like a "wide" hidden layer

Contraction (d_ff β†’ d_model):

  • Returns to original dimension
  • Allows residual connection (same shapes)
  • Compresses learned information
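The contraction point can be checked directly: because the second layer returns to d_model, the FFN's output can be added to its input by the residual connection. A small shape sketch:

```python
import torch
import torch.nn as nn

linear1 = nn.Linear(512, 2048)   # expand: d_model β†’ d_ff
linear2 = nn.Linear(2048, 512)   # contract: d_ff β†’ d_model

x = torch.randn(2, 10, 512)      # [batch, seq_len, d_model]
hidden = torch.relu(linear1(x))  # [batch, seq_len, d_ff]
y = linear2(hidden)              # [batch, seq_len, d_model]

# Same shape as the input, so the residual connection x + y is valid
assert (x + y).shape == x.shape
```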

Typical Ratio

πŸ“text
d_ff = 4 Γ— d_model

Examples:
- d_model = 512  β†’ d_ff = 2048
- d_model = 768  β†’ d_ff = 3072
- d_model = 1024 β†’ d_ff = 4096

Some models use different ratios (e.g., 8/3 for GLU variants).
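As a sketch of where 8/3 comes from: gated variants add a third projection matrix, so the hidden width is scaled by 2/3 (4 Γ— 2/3 = 8/3) to keep the parameter count comparable to a standard 4Γ— FFN. The round-up-to-a-multiple step below follows the convention in LLaMA's reference code; the helper name itself is ours:

```python
def glu_hidden_dim(d_model: int, multiple_of: int = 256) -> int:
    """Illustrative helper: hidden width for a gated (SwiGLU-style) FFN."""
    d_ff = int(8 * d_model / 3)  # 2/3 of the usual 4x expansion
    # Round up to a multiple of `multiple_of` for hardware efficiency
    return multiple_of * ((d_ff + multiple_of - 1) // multiple_of)

print(glu_hidden_dim(4096))  # 11008, matching LLaMA-7B's FFN width
```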


1.3 "Position-wise" Explained

What "Position-wise" Means

The same FFN is applied to each position independently:

πŸ“text
Position 0: FFN(xβ‚€) = yβ‚€
Position 1: FFN(x₁) = y₁
Position 2: FFN(xβ‚‚) = yβ‚‚
...

All use the SAME weights W₁, b₁, Wβ‚‚, bβ‚‚!

Visual Representation

πŸ“text
Input:  [Tokenβ‚€] [Token₁] [Tokenβ‚‚] [Token₃]
           β”‚        β”‚        β”‚        β”‚
           β–Ό        β–Ό        β–Ό        β–Ό
FFN:    [FFN]    [FFN]    [FFN]    [FFN]   (same weights!)
           β”‚        β”‚        β”‚        β”‚
           β–Ό        β–Ό        β–Ό        β–Ό
Output: [Outβ‚€]   [Out₁]   [Outβ‚‚]   [Out₃]

Implementation Efficiency

We don't need a loopβ€”matrix multiplication handles all positions:

🐍python
# x shape: [batch, seq_len, d_model]
# W1 shape: [d_model, d_ff]

# This applies W1 to ALL positions at once:
hidden = x @ W1  # [batch, seq_len, d_ff]

# Same for all batch items AND all positions!

Why Same Weights?

Three key reasons:

  • Parameter efficiency: One FFN serves all positions
  • Translation invariance: Same transformation regardless of position
  • Generalization: Works for any sequence length
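Weight sharing across positions is easy to demonstrate: feed the same token vector in at two different positions and the FFN produces identical outputs. A small sketch using an untrained stand-in FFN:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Minimal FFN stand-in: Linear β†’ GELU β†’ Linear
ffn = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))

x = torch.randn(1, 5, 16)  # [batch, seq_len, d_model]
x[0, 3] = x[0, 0]          # place the same token vector at positions 0 and 3

out = ffn(x)

# Same input vector β†’ same output vector, regardless of position
assert torch.allclose(out[0, 0], out[0, 3])
```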

1.4 Activation Functions

ReLU (Original Transformer)

🐍python
def relu(x):
    return torch.clamp(x, min=0)  # elementwise max(0, x)
    # or torch.relu(x)

Properties:

  • Simple and fast
  • Sparse activation (many zeros)
  • Can cause "dead neurons" (always zero)

GELU (Modern Choice)

Gaussian Error Linear Unit:

🐍python
def gelu(x):
    # Tanh approximation of x * Ξ¦(x), where Ξ¦ is the CDF of the standard normal
    return 0.5 * x * (1 + torch.tanh(
        math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)
    ))
    # or torch.nn.functional.gelu(x)

Properties:

  • Smooth (differentiable everywhere)
  • No dead neurons
  • Slightly better performance in practice
  • Used by BERT, GPT-2, etc.
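Recent PyTorch versions (1.12+) expose the tanh approximation directly via the `approximate` argument, which lets us check the hand-written formula against the library and see how close it stays to the exact GELU:

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, 101)

# Hand-written tanh approximation (same formula as the gelu() above)
manual = 0.5 * x * (1 + torch.tanh(
    math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)
))

# Matches PyTorch's built-in tanh-approximate GELU
assert torch.allclose(manual, F.gelu(x, approximate='tanh'), atol=1e-5)

# And stays within ~1e-3 of the exact GELU x * Ξ¦(x)
max_err = (F.gelu(x) - manual).abs().max().item()
assert max_err < 1e-2
```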

SiLU/Swish

🐍python
def silu(x):
    return x * torch.sigmoid(x)
    # or torch.nn.functional.silu(x)

Used by: LLaMA, PaLM (as the activation inside their gated SwiGLU FFNs)

Comparison

🐍python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

x = torch.linspace(-3, 3, 100)

plt.figure(figsize=(10, 6))
plt.plot(x, F.relu(x), label='ReLU')
plt.plot(x, F.gelu(x), label='GELU')
plt.plot(x, F.silu(x), label='SiLU/Swish')
plt.legend()
plt.grid(True)
plt.title('Activation Functions')
plt.savefig('activations.png')

1.5 Implementation

Basic FFN

🐍python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.

    Applies two linear transformations with an activation in between,
    independently to each position in the sequence.

    FFN(x) = Activation(xW₁ + b₁)Wβ‚‚ + bβ‚‚

    Args:
        d_model: Input/output dimension
        d_ff: Hidden layer dimension (typically 4 * d_model)
        dropout: Dropout probability
        activation: Activation function ('relu', 'gelu', 'silu')

    Example:
        >>> ffn = PositionwiseFeedForward(d_model=512, d_ff=2048)
        >>> x = torch.randn(2, 10, 512)  # [batch, seq_len, d_model]
        >>> output = ffn(x)  # [batch, seq_len, d_model]
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "relu"
    ):
        super().__init__()

        self.d_model = d_model
        self.d_ff = d_ff

        # Two linear transformations
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Activation function
        self.activation = self._get_activation(activation)

    def _get_activation(self, activation: str):
        """Get activation function by name."""
        activations = {
            "relu": F.relu,
            "gelu": F.gelu,
            "silu": F.silu,
        }
        if activation not in activations:
            raise ValueError(f"Unknown activation: {activation}")
        return activations[activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply position-wise feed-forward transformation.

        Args:
            x: Input tensor [batch, seq_len, d_model]

        Returns:
            Output tensor [batch, seq_len, d_model]
        """
        # x: [batch, seq_len, d_model]

        # First linear + activation
        hidden = self.activation(self.linear1(x))
        # hidden: [batch, seq_len, d_ff]

        # Dropout on hidden layer
        hidden = self.dropout(hidden)

        # Second linear
        output = self.linear2(hidden)
        # output: [batch, seq_len, d_model]

        return output

    def extra_repr(self) -> str:
        return f"d_model={self.d_model}, d_ff={self.d_ff}"


# Test
def test_ffn():
    d_model = 512
    d_ff = 2048
    batch_size = 2
    seq_len = 10

    ffn = PositionwiseFeedForward(d_model, d_ff, dropout=0.1, activation="gelu")

    x = torch.randn(batch_size, seq_len, d_model)
    output = ffn(x)

    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output.shape}")
    print("\nFFN Architecture:")
    print(f"  Linear1: {d_model} β†’ {d_ff}")
    print("  Activation: GELU")
    print("  Dropout: 0.1")
    print(f"  Linear2: {d_ff} β†’ {d_model}")
    print(f"\nParameter count: {sum(p.numel() for p in ffn.parameters()):,}")

    # Verify shapes
    assert output.shape == x.shape
    print("\nβœ“ FFN test passed!")


test_ffn()

Output:

πŸ“text
Input shape:  torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])

FFN Architecture:
  Linear1: 512 β†’ 2048
  Activation: GELU
  Dropout: 0.1
  Linear2: 2048 β†’ 512

Parameter count: 2,099,712

βœ“ FFN test passed!

1.6 Advanced: Gated Linear Units (GLU)

GLU Variants

Modern models often use gated variants that split the hidden dimension:

🐍python
class GatedFFN(nn.Module):
    """
    Gated Feed-Forward Network (used in PaLM, LLaMA).

    Instead of: Activation(xW₁)Wβ‚‚
    Uses:       (Activation(xW_gate) βŠ™ xW_up)W_down

    Where βŠ™ is element-wise multiplication.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "silu"
    ):
        super().__init__()

        # Gate and up projections (can be combined for efficiency)
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.activation = F.silu if activation == "silu" else F.gelu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            [batch, seq_len, d_model]
        """
        # Gate mechanism
        gate = self.activation(self.gate_proj(x))
        up = self.up_proj(x)

        # Element-wise gating
        hidden = gate * up
        hidden = self.dropout(hidden)

        # Down projection
        output = self.down_proj(hidden)

        return output


# Test GLU variant
def test_gated_ffn():
    d_model = 512
    d_ff = 2048  # Note: effective expansion differs due to the extra gate projection

    gated_ffn = GatedFFN(d_model, d_ff, dropout=0.0)

    x = torch.randn(2, 10, d_model)
    output = gated_ffn(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Parameter count: {sum(p.numel() for p in gated_ffn.parameters()):,}")

    print("\nβœ“ Gated FFN test passed!")


test_gated_ffn()

1.7 Understanding the FFN's Role

What Does FFN Learn?

Research suggests FFN layers act as key-value memories:

πŸ“text
Linear1 (keys):   Pattern detectors - "Is this a verb?"
Hidden (values):  Associated information - "Verb properties"
Linear2 (output): Combine into representation
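The key-value view can be made concrete: with biases omitted for clarity, the FFN output is literally a sum over hidden units, where unit i contributes activation(x Β· key_i) Γ— value_i. A small sketch with random weights:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 16, 64
W1 = torch.randn(d_model, d_ff)  # columns of W1 act as "keys"
W2 = torch.randn(d_ff, d_model)  # rows of W2 act as "values"
x = torch.randn(d_model)

# Standard FFN computation (biases omitted)
ffn_out = F.relu(x @ W1) @ W2

# Key-value view: each hidden unit i contributes relu(x Β· key_i) * value_i
kv_out = sum(F.relu(x @ W1[:, i]) * W2[i] for i in range(d_ff))

assert torch.allclose(ffn_out, kv_out, atol=1e-4)
```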

FFN vs Attention Comparison

πŸ“text
Aspect      β”‚ Attention            β”‚ FFN
Scope       β”‚ Global (all tokens)  β”‚ Local (single token)
Nature      β”‚ Linear combination   β”‚ Non-linear transformation
Parameters  β”‚ Fewer                β”‚ More (typically 2/3 of layer)
Role        β”‚ Information routing  β”‚ Information processing

Parameter Distribution

In a typical transformer layer:

  • Attention: ~1/3 of parameters
  • FFN: ~2/3 of parameters
🐍python
def count_parameters(d_model, d_ff, num_heads):
    """Count parameters in attention vs FFN (weight matrices only for attention)."""
    # num_heads does not change the total: the head dimensions sum to d_model

    # Attention: W_Q, W_K, W_V, W_O (each d_model Γ— d_model)
    attention_params = 4 * d_model * d_model

    # FFN: W1, b1, W2, b2
    ffn_params = 2 * d_model * d_ff + d_model + d_ff

    total = attention_params + ffn_params

    print(f"Attention parameters: {attention_params:,} ({attention_params/total*100:.1f}%)")
    print(f"FFN parameters: {ffn_params:,} ({ffn_params/total*100:.1f}%)")
    print(f"Total: {total:,}")


count_parameters(d_model=512, d_ff=2048, num_heads=8)

Output:

πŸ“text
1Attention parameters: 1,048,576 (20.0%)
2FFN parameters: 4,197,376 (80.0%)
3Total: 5,245,952

1.8 Variations and Extensions

MoE (Mixture of Experts)

Instead of one FFN, use multiple "expert" FFNs:

🐍python
class MoEFFN(nn.Module):
    """
    Mixture of Experts FFN (simplified).

    Routes each token to top-k experts based on a learned gating function.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2
    ):
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k

        # Expert FFNs
        self.experts = nn.ModuleList([
            PositionwiseFeedForward(d_model, d_ff)
            for _ in range(num_experts)
        ])

        # Router
        self.router = nn.Linear(d_model, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts."""
        batch, seq_len, d_model = x.shape

        # Get routing weights
        router_logits = self.router(x)  # [batch, seq_len, num_experts]
        routing_weights = F.softmax(router_logits, dim=-1)

        # Select top-k experts and renormalize their weights
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Apply experts (simplified - real implementations batch this more efficiently)
        output = torch.zeros_like(x)
        for k in range(self.top_k):
            for e in range(self.num_experts):
                mask = (top_k_indices[..., k] == e)  # tokens whose k-th choice is expert e
                if mask.any():
                    expert_input = x[mask]  # [n_tokens, d_model]
                    expert_output = self.experts[e](expert_input.unsqueeze(0)).squeeze(0)
                    output[mask] += top_k_weights[mask, k:k+1] * expert_output

        return output
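The routing step is the part most worth sanity-checking in isolation. Under the same top-k-and-renormalize scheme as above, each token's selected expert weights should sum to 1:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, top_k = 8, 2
router_logits = torch.randn(2, 5, num_experts)  # [batch, seq_len, num_experts]

routing_weights = F.softmax(router_logits, dim=-1)
top_k_weights, top_k_indices = torch.topk(routing_weights, top_k, dim=-1)
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

# Each token's renormalized expert weights sum to 1
assert torch.allclose(top_k_weights.sum(dim=-1), torch.ones(2, 5), atol=1e-6)

# All selected expert ids are valid
assert int(top_k_indices.max()) < num_experts and int(top_k_indices.min()) >= 0
```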

Summary

FFN Architecture

πŸ“text
Input [batch, seq_len, d_model]
           β”‚
           β–Ό
     Linear1 (d_model β†’ d_ff)
           β”‚
           β–Ό
     Activation (ReLU/GELU)
           β”‚
           β–Ό
        Dropout
           β”‚
           β–Ό
     Linear2 (d_ff β†’ d_model)
           β”‚
           β–Ό
Output [batch, seq_len, d_model]

Key Points

πŸ“text
Property      β”‚ Value
Typical d_ff  β”‚ 4 Γ— d_model
Activation    β”‚ GELU (modern) or ReLU (original)
Parameters    β”‚ ~2/3 of transformer layer
Processing    β”‚ Position-independent (same weights)

Implementation Checklist

  • Two linear layers with dimension expansion/contraction
  • Activation function between layers
  • Dropout on hidden layer
  • Shape preservation: input and output same dimension

Exercises

Implementation Exercises

1. Implement an FFN with configurable number of hidden layers (not just 2).

2. Add layer normalization inside the FFN (some models do this).

3. Implement a "bottleneck" FFN where d_ff < d_model.

Analysis Exercises

4. Compare training with ReLU vs GELU vs SiLU. Which converges faster?

5. Visualize the activation patterns in a trained FFN. What patterns emerge?

6. Experiment with different d_ff ratios (2Γ—, 4Γ—, 8Γ—). How does it affect performance?

Conceptual Questions

7. Why is the FFN applied "position-wise" rather than across all positions?

8. What would happen if we removed the FFN entirely? How would the model behave?


Next Section Preview

In the next section, we'll dive deep into Layer Normalizationβ€”the normalization technique that makes transformer training stable. We'll understand why it's preferred over batch normalization and implement it from scratch.