Chapter 16

Mixture of Experts (MoE)

Advanced Architectures

Introduction

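Mixture of Experts (MoE) is a sparse architecture that scales a model's total parameter count, in some published models past one trillion, while each token activates only a small subset of those parameters. Per-token compute therefore stays close to that of a much smaller dense model. This section explains the MoE architecture, sparse gating, and implementation details.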


2.1 Understanding Mixture of Experts

The Basic Concept

📝text
THE SCALING PROBLEM:
────────────────────

Dense Transformer: Every token uses ALL parameters

Input → [Attention] → [FFN with 100% params] → Output

Scaling issue:
• 2x parameters = 2x compute per token
• 2x parameters = 2x memory
• Very expensive to scale!


MoE SOLUTION: Conditional Computation
─────────────────────────────────────

Instead of one large FFN, use multiple "expert" FFNs.
Each token only activates a FEW experts.

Input → [Attention] → [Router] → Selected Experts → Output

                Expert 1: 12.5%  ─┐
                Expert 2:  0.0%   │
                Expert 3: 87.5%  ─┼→ Weighted combination
                Expert 4:  0.0%   │
                Expert 5:  0.0%   │
                Expert 6:  0.0%   │
                Expert 7:  0.0%   │
                Expert 8:  0.0%  ─┘

Famous MoE Models:

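| Model | Total Params | Active Params | Experts | Notes |
| --- | --- | --- | --- | --- |
| Switch-Base | 7B | 200M | 128 | Top-1 |
| Switch-Large | 26B | 800M | 128 | Top-1 |
| Switch-C | 1.6T | - | 2048 | Top-1 |
| Mixtral 8x7B | 47B | 13B | 8 | Top-2 |
| GPT-4 (rumored) | 1.8T | ~220B | 16 | Top-2? (unconfirmed) |
| DeepSeek-V2 | 236B | 21B | 160 | Top-6 routed, plus 2 shared experts |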

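Key Properties:

1. Total Parameters: N × expert_params (expert FFNs only; attention, embeddings, and the router add shared parameters on top)
Example: 8 experts × 1B params = 8B total expert params

2. Active Parameters: K × expert_params (where K < N, often K << N)
Example: Top-2 routing × 1B = 2B active expert params per token

3. Compute Efficiency (forward pass, at roughly 2 FLOPs per active parameter per token):
Dense 8B model: ~16 GFLOPs per token
MoE 8B (top-2): ~4 GFLOPs per token
4x less FFN compute for the same total parameter count. Memory still has to hold all 8B parameters.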
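
The arithmetic above can be checked with a small parameter counter. This is a minimal sketch for a single MoE feed-forward layer, assuming two-layer experts with biases and a bias-free linear router; `MoEFFNConfig` and the three helper functions are hypothetical names introduced for illustration.

```python
from dataclasses import dataclass


@dataclass
class MoEFFNConfig:
    """Hypothetical config for one MoE feed-forward layer (illustrative names)."""
    d_model: int
    d_ff: int
    num_experts: int
    top_k: int


def expert_params(cfg: MoEFFNConfig) -> int:
    # One expert = two linear layers with biases: d_model -> d_ff -> d_model
    return cfg.d_model * cfg.d_ff + cfg.d_ff + cfg.d_ff * cfg.d_model + cfg.d_model


def total_params(cfg: MoEFFNConfig) -> int:
    # All experts plus the bias-free linear router
    router = cfg.d_model * cfg.num_experts
    return cfg.num_experts * expert_params(cfg) + router


def active_params(cfg: MoEFFNConfig) -> int:
    # Only top_k experts run per token; the router always runs
    router = cfg.d_model * cfg.num_experts
    return cfg.top_k * expert_params(cfg) + router


cfg = MoEFFNConfig(d_model=1024, d_ff=4096, num_experts=8, top_k=2)
print(f"Total params:  {total_params(cfg):,}")
print(f"Active params: {active_params(cfg):,}")
print(f"Total / active: {total_params(cfg) / active_params(cfg):.2f}x")
```

The ratio lands just under `num_experts / top_k` because the router is always active. In a full transformer the gap between total and active parameters is smaller still, since attention and embeddings are shared by every token.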


2.2 Router (Gating Network)

Top-K Routing Implementation

🐍python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class TopKRouter(nn.Module):
    """
    Top-K routing for Mixture of Experts.

    Selects top-K experts for each token based on learned gating.
    """

    def __init__(
        self,
        d_model: int,
        num_experts: int,
        top_k: int = 2,
        noise_std: float = 0.1,
        capacity_factor: float = 1.25
    ):
        """
        Initialize router.

        Args:
            d_model: Input dimension
            num_experts: Number of experts
            top_k: Number of experts to route to
            noise_std: Standard deviation of noise during training
            capacity_factor: Expert capacity multiplier (stored only;
                this simple router does not enforce a capacity limit)
        """
        super().__init__()

        if not 1 <= top_k <= num_experts:
            raise ValueError("top_k must be between 1 and num_experts")

        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        self.capacity_factor = capacity_factor

        # Gating network (simple linear)
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute routing weights.

        Args:
            hidden_states: [batch_size, seq_len, d_model]

        Returns:
            router_weights: [batch_size, seq_len, top_k]
            selected_experts: [batch_size, seq_len, top_k]
            router_logits: [batch_size, seq_len, num_experts]
                (includes the exploration noise when in training mode)
        """
        # Compute router logits
        router_logits = self.gate(hidden_states)  # [B, S, E]

        # Add noise during training for exploration
        if self.training and self.noise_std > 0:
            noise = torch.randn_like(router_logits) * self.noise_std
            router_logits = router_logits + noise

        # Get top-k experts
        router_weights, selected_experts = torch.topk(
            router_logits, self.top_k, dim=-1
        )

        # Normalize weights (softmax over the selected experts only)
        router_weights = F.softmax(router_weights, dim=-1)

        return router_weights, selected_experts, router_logits


def demonstrate_routing():
    """
    Demonstrate router behavior.
    """
    print("Router Demonstration")
    print("=" * 60)

    # Setup
    batch_size = 2
    seq_len = 8
    d_model = 64
    num_experts = 4
    top_k = 2

    # Random input
    torch.manual_seed(42)
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Top-K Router (a new module starts in training mode, so noise is applied)
    router = TopKRouter(d_model, num_experts, top_k)
    weights, experts, logits = router(hidden_states)

    print(f"Input shape: [{batch_size}, {seq_len}, {d_model}]")
    print(f"Num experts: {num_experts}, Top-K: {top_k}")
    print()

    print("Sample routing decisions (first 4 tokens):")
    print("-" * 50)
    for i in range(4):
        exp = experts[0, i].tolist()
        wt = weights[0, i].tolist()
        print(f"Token {i}: Expert {exp[0]} ({wt[0]:.2f}) + Expert {exp[1]} ({wt[1]:.2f})")

    # Expert load distribution: each token contributes top_k assignments
    expert_counts = torch.zeros(num_experts)
    for e in range(num_experts):
        expert_counts[e] = (experts == e).sum().item()

    print("\nExpert load distribution:")
    for e in range(num_experts):
        bar = "█" * int(expert_counts[e] / 2)
        print(f"  Expert {e}: {expert_counts[e]:3.0f} assignments  {bar}")


demonstrate_routing()
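
The router above stores `capacity_factor` but never uses it. The sketch below shows one way a capacity limit could be applied to the router's `selected_experts`. It assumes the common convention of capacity = capacity_factor × tokens × top_k / num_experts, with first choices given priority over second choices. `expert_capacity` and `apply_capacity` are hypothetical helpers, not part of the code above.

```python
import math

import torch
import torch.nn.functional as F


def expert_capacity(num_tokens: int, num_experts: int, top_k: int,
                    capacity_factor: float = 1.25) -> int:
    # Maximum number of assignments a single expert accepts per batch
    return math.ceil(capacity_factor * num_tokens * top_k / num_experts)


def apply_capacity(selected_experts: torch.Tensor, num_experts: int,
                   capacity: int) -> torch.Tensor:
    """
    Args:
        selected_experts: [num_tokens, top_k] integer expert ids
    Returns:
        keep: [num_tokens, top_k] bool, False where the assignment overflows
    """
    num_tokens, top_k = selected_experts.shape
    # Choice-major order: every first choice is queued before any second choice
    flat = selected_experts.t().reshape(-1)                   # [K * T]
    one_hot = F.one_hot(flat, num_experts)                    # [K * T, E]
    # 0-based position of each assignment in its expert's queue
    position = ((one_hot.cumsum(dim=0) - 1) * one_hot).sum(dim=-1)
    keep = position < capacity
    return keep.reshape(top_k, num_tokens).t()                # [T, K]


# Four tokens that all pick expert 0 first; capacity of 2 per expert
selected = torch.tensor([[0, 1], [0, 1], [0, 2], [0, 3]])
keep = apply_capacity(selected, num_experts=4, capacity=2)
print(keep)
```

Dropped assignments would simply be skipped in the expert loop, leaving those tokens to pass through the residual connection. Whether to drop, re-route, or avoid capacity limits altogether is a design choice that varies between implementations.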

2.3 Expert Layer Implementation

Sparse MoE Layer

🐍python
# Continues from Section 2.2: reuses its imports and TopKRouter


class Expert(nn.Module):
    """
    Single expert (standard FFN).
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.0
    ):
        super().__init__()

        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class SparseMoELayer(nn.Module):
    """
    Sparse Mixture of Experts layer.

    Replaces standard FFN in transformer.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.0,
        noise_std: float = 0.1
    ):
        """
        Initialize MoE layer.

        Args:
            d_model: Model dimension
            d_ff: FFN intermediate dimension
            num_experts: Number of expert FFNs
            top_k: Number of experts per token
            dropout: Dropout rate
            noise_std: Router noise
        """
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k
        self.d_model = d_model

        # Router
        self.router = TopKRouter(d_model, num_experts, top_k, noise_std)

        # Experts
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff, dropout)
            for _ in range(num_experts)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through MoE.

        Args:
            hidden_states: [batch_size, seq_len, d_model]

        Returns:
            output: [batch_size, seq_len, d_model]
            router_logits: [batch_size, seq_len, num_experts]
        """
        # Get routing decisions
        router_weights, selected_experts, router_logits = self.router(hidden_states)

        # Initialize output
        output = torch.zeros_like(hidden_states)

        # Process each expert
        # This is a simple implementation; real implementations batch tokens per expert
        for expert_idx, expert in enumerate(self.experts):
            # This is inefficient but clear: one pass per (expert, choice slot)
            for k in range(self.top_k):
                # Tokens whose k-th choice is this expert
                mask = selected_experts[:, :, k] == expert_idx  # [B, S]

                if not mask.any():
                    continue

                # Get tokens for this expert
                expert_input = hidden_states[mask]  # [num_tokens, d_model]

                # Process through expert
                expert_output = expert(expert_input)

                # Get corresponding weights
                expert_weights = router_weights[:, :, k][mask]  # [num_tokens]

                # Weighted contribution, added to the output
                output[mask] += expert_output * expert_weights.unsqueeze(-1)

        return output, router_logits


def test_moe_layer():
    """
    Test MoE layer.
    """
    print("Testing MoE Layer")
    print("=" * 60)

    batch_size = 4
    seq_len = 16
    d_model = 256
    d_ff = 512
    num_experts = 4
    top_k = 2

    # Input
    x = torch.randn(batch_size, seq_len, d_model)

    # MoE layer
    moe = SparseMoELayer(d_model, d_ff, num_experts, top_k)

    # Forward
    output, router_logits = moe(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Router logits shape: {router_logits.shape}")

    # Compare params with dense FFN (weights + biases of both linear layers)
    dense_params = d_model * d_ff + d_ff + d_ff * d_model + d_model
    moe_params = num_experts * dense_params
    router_params = d_model * num_experts
    active_params = top_k * dense_params + router_params

    print("\nParameter comparison:")
    print(f"  Dense FFN: {dense_params:,}")
    print(f"  MoE ({num_experts} experts): {moe_params + router_params:,}")
    print(f"  Ratio: {(moe_params + router_params) / dense_params:.1f}x more params")
    print(f"  Active per token: {active_params:,} (top-{top_k})")


test_moe_layer()
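
The layer above calls each expert once per choice slot. A step toward the batched dispatch it mentions is to gather all tokens assigned to an expert, whichever slot they came from, and call the expert once. The sketch below assumes flattened inputs of shape [num_tokens, d_model]. `moe_forward_grouped` is a hypothetical helper, and the `nn.Linear` experts are stand-ins for the FFN experts.

```python
import torch
import torch.nn as nn


def moe_forward_grouped(
    x: torch.Tensor,                 # [T, d_model]
    router_weights: torch.Tensor,    # [T, K]
    selected_experts: torch.Tensor,  # [T, K]
    experts: nn.ModuleList,
) -> torch.Tensor:
    num_tokens, top_k = selected_experts.shape
    output = torch.zeros_like(x)

    # Flatten the (token, slot) grid; entry j belongs to token j // top_k
    flat_experts = selected_experts.reshape(-1)
    flat_weights = router_weights.reshape(-1)
    token_ids = torch.arange(num_tokens, device=x.device).repeat_interleave(top_k)

    for expert_idx, expert in enumerate(experts):
        slots = (flat_experts == expert_idx).nonzero(as_tuple=True)[0]
        if slots.numel() == 0:
            continue
        rows = token_ids[slots]
        # One expert call for all tokens routed to it
        expert_out = expert(x[rows]) * flat_weights[slots].unsqueeze(-1)
        output.index_add_(0, rows, expert_out)
    return output


torch.manual_seed(0)
experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])  # stand-in experts
x = torch.randn(10, 8)
weights, idx = torch.topk(torch.randn(10, 4), k=2, dim=-1)
weights = torch.softmax(weights, dim=-1)
out = moe_forward_grouped(x, weights, idx, experts)
print(out.shape)
```

This still loops over experts in Python. Production kernels typically sort tokens by expert and use grouped matrix multiplies, which is beyond the scope of this sketch.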

2.4 Load Balancing Loss

Preventing Expert Collapse

🐍python
# Continues from Section 2.2: reuses its imports


class MoEAuxiliaryLoss(nn.Module):
    """
    Auxiliary losses for MoE training.

    Prevents:
    1. Router collapse (all tokens to few experts)
    2. Expert imbalance (some experts never trained)
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        aux_loss_weight: float = 0.01,
        z_loss_weight: float = 0.001
    ):
        """
        Initialize auxiliary loss.

        Args:
            num_experts: Number of experts
            top_k: Number of selected experts per token
            aux_loss_weight: Weight for the load balancing loss
            z_loss_weight: Weight for the router z-loss (small by design)
        """
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss_weight = aux_loss_weight
        self.z_loss_weight = z_loss_weight

    def load_balancing_loss(
        self,
        router_logits: torch.Tensor
    ) -> torch.Tensor:
        """
        Load balancing loss from Switch Transformer.

        Encourages uniform expert utilization.

        Args:
            router_logits: [batch_size, seq_len, num_experts]

        Returns:
            Auxiliary loss scalar
        """
        # Get routing probabilities
        router_probs = F.softmax(router_logits, dim=-1)  # [B, S, E]

        # f_i: fraction of routing assignments dispatched to expert i,
        # recomputed from the logits with the same top-k rule as the router
        _, selected_experts = torch.topk(router_logits, self.top_k, dim=-1)
        dispatch = F.one_hot(selected_experts, self.num_experts).sum(dim=-2)  # [B, S, E]
        expert_usage = dispatch.float().mean(dim=(0, 1)) / self.top_k  # [E], sums to 1

        # P_i: mean router probability assigned to expert i
        expert_load = router_probs.mean(dim=(0, 1))  # [E], sums to 1

        # Loss: num_experts * sum of f_i * P_i
        # Equals 1 under uniform routing (f_i = P_i = 1/num_experts) and grows
        # as routing concentrates; gradients flow through P_i only
        aux_loss = (expert_usage * expert_load).sum() * self.num_experts

        return aux_loss * self.aux_loss_weight

    def router_z_loss(
        self,
        router_logits: torch.Tensor
    ) -> torch.Tensor:
        """
        Router z-loss to prevent router logits from growing too large.

        From ST-MoE paper.

        Args:
            router_logits: [batch_size, seq_len, num_experts]

        Returns:
            Z-loss scalar
        """
        # Log-sum-exp of router logits
        log_z = torch.logsumexp(router_logits, dim=-1)  # [B, S]

        # Loss is mean of squared log_z
        z_loss = log_z.pow(2).mean()

        return z_loss * self.z_loss_weight

    def forward(
        self,
        router_logits: torch.Tensor
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute total auxiliary loss.

        Args:
            router_logits: Router logits from one MoE layer

        Returns:
            Total aux loss, dict of individual losses
        """
        load_loss = self.load_balancing_loss(router_logits)
        z_loss = self.router_z_loss(router_logits)

        total_loss = load_loss + z_loss

        return total_loss, {
            'load_balancing_loss': load_loss.item(),
            'router_z_loss': z_loss.item()
        }
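
To see how the load balancing term reacts to imbalance, it helps to evaluate it on two hand-built routing patterns. The sketch below is a standalone, unweighted version of the Switch-style loss, under the assumption that f_i is the fraction of top-k assignments and P_i the mean router probability. `switch_load_balancing_loss` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def switch_load_balancing_loss(router_logits: torch.Tensor, top_k: int = 1) -> torch.Tensor:
    num_experts = router_logits.size(-1)
    logits = router_logits.reshape(-1, num_experts)          # [T, E]
    probs = F.softmax(logits, dim=-1)                        # [T, E]
    _, selected = torch.topk(logits, top_k, dim=-1)          # [T, K]
    dispatch = F.one_hot(selected, num_experts).sum(dim=1)   # [T, E]
    f = dispatch.float().mean(dim=0) / top_k                 # fraction of assignments
    p = probs.mean(dim=0)                                    # mean router probability
    return num_experts * (f * p).sum()


num_experts, num_tokens = 4, 8

# Balanced: token t strongly prefers expert t % num_experts
balanced = torch.zeros(1, num_tokens, num_experts)
balanced[0, torch.arange(num_tokens), torch.arange(num_tokens) % num_experts] = 5.0

# Collapsed: every token strongly prefers expert 0
collapsed = torch.zeros(1, num_tokens, num_experts)
collapsed[0, :, 0] = 5.0

balanced_loss = switch_load_balancing_loss(balanced).item()
collapsed_loss = switch_load_balancing_loss(collapsed).item()
print(f"balanced:  {balanced_loss:.3f}")
print(f"collapsed: {collapsed_loss:.3f}")
```

The balanced pattern gives a value of about 1, the floor for uniform routing, while the collapsed pattern approaches `num_experts`. The training signal comes from that gap.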

Without Load Balancing:

📝text
Problem: Router learns to always select the same experts

Training progression (illustrative numbers):
Epoch 1: Expert distribution [30%, 20%, 25%, 25%]
Epoch 2: Expert distribution [45%, 15%, 25%, 15%]
Epoch 3: Expert distribution [70%, 10%, 15%, 5%]
Epoch 4: Expert distribution [95%, 2%, 2%, 1%]

Result: Only 1 expert is trained effectively!
This is called "expert collapse" or "winner-take-all"

With Load Balancing Loss:

📝text
Loss = CrossEntropy + α * LoadBalancingLoss

LoadBalancingLoss encourages uniform distribution:
- Penalizes when some experts get too much traffic
- Penalizes when some experts get too little

Training with aux loss (illustrative numbers):
Epoch 1: Expert distribution [30%, 20%, 25%, 25%]
Epoch 2: Expert distribution [28%, 24%, 24%, 24%]
Epoch 3: Expert distribution [26%, 25%, 25%, 24%]
Epoch 4: Expert distribution [25%, 25%, 25%, 25%]

Result: All experts receive a similar share of tokens and training signal!
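
The LoadBalancingLoss term can be written out explicitly. This is a sketch following the Switch Transformer formulation. The symbols are notation introduced here: N is the number of experts, T the number of tokens in the batch, f_i the fraction of tokens dispatched to expert i, and P_i the mean router probability for expert i.

```latex
\mathcal{L} = \mathcal{L}_{\text{CE}} + \alpha \, \mathcal{L}_{\text{balance}},
\qquad
\mathcal{L}_{\text{balance}} = N \sum_{i=1}^{N} f_i \, P_i

f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbb{1}\left[\text{token } t \text{ is routed to expert } i\right],
\qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} p_i(x_t)

% Worked values for N = 4, assuming f_i = P_i equal to the listed shares:
% uniform   [0.25, 0.25, 0.25, 0.25]:  4 \cdot 4 \cdot (0.25 \cdot 0.25) = 1
% collapsed [0.95, 0.02, 0.02, 0.01]:  4 \cdot (0.9025 + 0.0004 + 0.0004 + 0.0001) \approx 3.61
```

Under these assumptions the collapsed distribution from the first example scores about 3.6 times higher than the uniform one, which is the gap the optimizer is pushed to close.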

2.5 Complete MoE Transformer Block

Putting It All Together

🐍python
# Continues from Sections 2.2-2.4: reuses their imports, SparseMoELayer,
# and MoEAuxiliaryLoss


class MoETransformerBlock(nn.Module):
    """
    Transformer block with MoE FFN.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.1,
        use_moe: bool = True
    ):
        """
        Initialize MoE transformer block.

        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: FFN intermediate dimension
            num_experts: Number of experts
            top_k: Experts per token
            dropout: Dropout rate
            use_moe: Whether to use MoE (vs dense FFN)
        """
        super().__init__()

        self.use_moe = use_moe

        # Self-attention
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # FFN (MoE or dense)
        if use_moe:
            self.ffn = SparseMoELayer(
                d_model, d_ff, num_experts, top_k, dropout
            )
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model)
            )

        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass.

        Args:
            x: Input [batch_size, seq_len, d_model]
            attn_mask: Optional attention mask, passed to
                nn.MultiheadAttention as attn_mask ([seq_len, seq_len])

        Returns:
            output: [batch_size, seq_len, d_model]
            router_logits: Optional router logits for aux loss
        """
        # Self-attention with residual (pre-norm)
        residual = x
        x = self.norm1(x)
        x, _ = self.self_attn(x, x, x, attn_mask=attn_mask)
        x = self.dropout1(x) + residual

        # FFN with residual (pre-norm)
        residual = x
        x = self.norm2(x)

        if self.use_moe:
            x, router_logits = self.ffn(x)
        else:
            x = self.ffn(x)
            router_logits = None

        x = self.dropout2(x) + residual

        return x, router_logits


class MoETransformer(nn.Module):
    """
    Transformer with MoE layers.

    Typically, MoE is applied to every other layer or every 4th layer.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 12,
        d_ff: int = 2048,
        num_experts: int = 8,
        top_k: int = 2,
        moe_frequency: int = 2,  # Apply MoE every N layers
        dropout: float = 0.1,
        max_seq_len: int = 512
    ):
        """
        Initialize MoE Transformer.

        Args:
            vocab_size: Vocabulary size
            d_model: Model dimension
            num_heads: Number of attention heads
            num_layers: Number of transformer layers
            d_ff: FFN intermediate dimension
            num_experts: Number of experts
            top_k: Experts per token
            moe_frequency: Apply MoE every N layers
            dropout: Dropout rate
            max_seq_len: Maximum sequence length
        """
        super().__init__()

        self.d_model = d_model
        self.num_experts = num_experts

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        # Transformer layers (alternating dense and MoE)
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            use_moe = (i % moe_frequency == moe_frequency - 1)  # Every Nth layer
            self.layers.append(
                MoETransformerBlock(
                    d_model, num_heads, d_ff, num_experts, top_k, dropout, use_moe
                )
            )

        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)

        # Auxiliary loss
        self.aux_loss_fn = MoEAuxiliaryLoss(num_experts, top_k)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            input_ids: [batch_size, seq_len], with seq_len <= max_seq_len
            attention_mask: Optional [seq_len, seq_len] mask (e.g. causal),
                forwarded to each block as attn_mask

        Returns:
            logits: [batch_size, seq_len, vocab_size]
            aux_loss: Auxiliary loss for MoE (already weighted,
                averaged over the MoE layers)
        """
        batch_size, seq_len = input_ids.shape

        # Embeddings
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)

        # Collect router logits for aux loss
        all_router_logits = []

        # Transformer layers
        for layer in self.layers:
            x, router_logits = layer(x, attention_mask)
            if router_logits is not None:
                all_router_logits.append(router_logits)

        x = self.norm(x)
        logits = self.output(x)

        # Compute auxiliary loss per MoE layer, then average. Pooling the
        # logits of all layers would hide imbalance within a single layer.
        if all_router_logits:
            aux_loss = torch.stack([
                self.aux_loss_fn(layer_logits)[0]
                for layer_logits in all_router_logits
            ]).mean()
        else:
            aux_loss = torch.tensor(0.0, device=input_ids.device)

        return logits, aux_loss
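
A model that returns `(logits, aux_loss)` needs a training step that adds the two losses before backpropagation. The sketch below assumes that the targets are already shifted for next-token prediction and that `aux_loss` arrives pre-weighted, as it does in the model above. `training_step` and `TinyStandIn` are hypothetical names; the stand-in only mimics the return contract so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def training_step(model, optimizer, input_ids, targets):
    """One optimization step on language-model loss plus MoE auxiliary loss."""
    model.train()
    optimizer.zero_grad()

    logits, aux_loss = model(input_ids)                  # [B, S, V], scalar
    lm_loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),             # [B * S, V]
        targets.reshape(-1),                             # [B * S]
    )
    loss = lm_loss + aux_loss                            # aux_loss is pre-weighted

    loss.backward()
    optimizer.step()
    return {
        "loss": loss.item(),
        "lm_loss": lm_loss.item(),
        "aux_loss": aux_loss.item(),
    }


class TinyStandIn(nn.Module):
    """Hypothetical stand-in with the same (logits, aux_loss) return contract."""

    def __init__(self, vocab_size: int = 32, d_model: int = 16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        logits = self.out(self.emb(input_ids))
        return logits, torch.tensor(0.0)                 # no MoE layers here


torch.manual_seed(0)
model = TinyStandIn()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
input_ids = torch.randint(0, 32, (2, 8))
targets = torch.randint(0, 32, (2, 8))
metrics = training_step(model, optimizer, input_ids, targets)
print(metrics)
```

Logging `lm_loss` and `aux_loss` separately makes it easier to spot a run where the auxiliary term dominates or where routing starts to collapse.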

Summary

MoE Key Points:

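| Aspect | Dense Model | MoE Model |
| --- | --- | --- |
| Parameters | All active | Sparse (top-K active) |
| FLOPs per token | Proportional to total params | Proportional to active params; roughly constant as experts are added |
| Memory (inference) | Full model | Full model (all experts must be loaded) |
| Memory (training) | Activations + gradients | Same + routing overhead |
| Scaling | Compute grows with every added parameter | Parameters grow faster than per-token compute |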

Design Decisions:

1. Number of experts: typically 8-128 (8-16 is a common starting point; some models in the table above use far more)
2. Top-K: 1-2 (top-2 most common)
3. MoE frequency: Every layer or every 2nd/4th layer
4. Expert size: Same as dense FFN or smaller
5. Aux loss weight: 0.01-0.1

Exercises:

1. Implement Top-1 routing and compare with Top-2.
2. Visualize expert utilization during training.
3. Train a small MoE model on text generation and compare with dense.
4. Implement Expert Choice routing and compare load balance.
5. Measure inference latency of MoE vs dense with same active parameters.

Next Section: In Section 3, we'll explore modern position encoding techniques including Rotary Position Embeddings (RoPE) and ALiBi.
