Introduction
Mixture of Experts (MoE) is a sparse architecture that makes it practical to train models with a trillion or more parameters, because each token activates only a small fraction of them and per-token compute stays manageable. This section explains the MoE architecture, sparse gating, and implementation details.
2.1 Understanding Mixture of Experts
The Basic Concept
THE SCALING PROBLEM:
────────────────────

Dense Transformer: Every token uses ALL parameters

Input → [Attention] → [FFN with 100% params] → Output

Scaling issue:
• 2x parameters = 2x compute per token
• 2x parameters = 2x memory
• Very expensive to scale!


MoE SOLUTION: Conditional Computation
─────────────────────────────────────

Instead of one large FFN, use multiple "expert" FFNs.
Each token only activates a FEW experts.

Input → [Attention] → [Router] → Selected Experts → Output
                         ↓
                 Expert 1: 12.5% ─┐
                 Expert 2:  0.0%  │
                 Expert 3: 87.5% ─┼→ Weighted combination
                 Expert 4:  0.0%  │
                 Expert 5:  0.0%  │
                 Expert 6:  0.0%  │
                 Expert 7:  0.0%  │
                 Expert 8:  0.0% ─┘

Famous MoE Models:
| Model | Total Params | Active | Experts | Notes |
|---|---|---|---|---|
| Switch-Base | 7B | 200M | 128 | Top-1 |
| Switch-Large | 26B | 800M | 128 | Top-1 |
| Switch-C | 1.6T | - | 2048 | Top-1 |
| Mixtral 8x7B | 47B | 13B | 8 | Top-2 |
| GPT-4 (rumored) | 1.8T | ~220B | 16 | Top-2? |
| DeepSeek-V2 | 236B | 21B | 160 | Top-6 |
Key Properties:
1. Total Parameters: N × expert_params (plus shared weights such as attention and embeddings)
   Example: 8 experts × 1B params = 8B total expert params
2. Active Parameters: K × expert_params (where K << N)
   Example: Top-2 routing × 1B = 2B active per token
3. Compute Efficiency (per-token FLOPs scale with active parameters, not total parameters):
   Dense 8B model: all 8B params used per token
   MoE 8B (top-2): 2B params used per token
   About 4x less compute for the same total capacity!
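The arithmetic above can be checked with a few lines of Python. This is a minimal sketch, not part of the implementation that follows: `moe_param_counts` and its `shared_params` argument are illustrative names introduced here, and it assumes equally sized experts and per-token compute roughly proportional to active parameters.

```python
def moe_param_counts(num_experts: int, top_k: int, expert_params: float,
                     shared_params: float = 0.0) -> dict:
    """Return total and per-token active parameter counts for an MoE model.

    shared_params (hypothetical argument) covers weights that every token
    uses, such as attention and embeddings. It defaults to 0 to match the
    simplified example in the text.
    """
    total = num_experts * expert_params + shared_params
    active = top_k * expert_params + shared_params
    return {"total": total, "active": active, "compute_ratio": total / active}


# The example from the list above: 8 experts of 1B params, top-2 routing.
counts = moe_param_counts(num_experts=8, top_k=2, expert_params=1e9)
print(f"Total params:  {counts['total'] / 1e9:.0f}B")
print(f"Active params: {counts['active'] / 1e9:.0f}B")
print(f"A dense model of the same size needs {counts['compute_ratio']:.0f}x the compute")
```

With a nonzero `shared_params` the ratio falls below N/K, which is why real models save less compute than the idealized 4x: only the expert FFNs are replicated, while attention and embeddings run for every token.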
2.2 Router (Gating Network)
Top-K Routing Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class TopKRouter(nn.Module):
    """
    Top-K routing for Mixture of Experts.

    Selects top-K experts for each token based on learned gating.
    """

    def __init__(
        self,
        d_model: int,
        num_experts: int,
        top_k: int = 2,
        noise_std: float = 0.1,
        capacity_factor: float = 1.25
    ):
        """
        Initialize router.

        Args:
            d_model: Input dimension
            num_experts: Number of experts
            top_k: Number of experts to route to
            noise_std: Standard deviation of noise during training
            capacity_factor: Expert capacity multiplier (stored only; this
                simplified router does not drop tokens over capacity)
        """
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        self.capacity_factor = capacity_factor

        # Gating network (simple linear)
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute routing weights.

        Args:
            hidden_states: [batch_size, seq_len, d_model]

        Returns:
            router_weights: [batch_size, seq_len, top_k]
            selected_experts: [batch_size, seq_len, top_k]
            router_logits: [batch_size, seq_len, num_experts]
                (includes the exploration noise when training)
        """
        # Compute router logits
        router_logits = self.gate(hidden_states)  # [B, S, E]

        # Add noise during training for exploration
        if self.training and self.noise_std > 0:
            noise = torch.randn_like(router_logits) * self.noise_std
            router_logits = router_logits + noise

        # Get top-k experts
        router_weights, selected_experts = torch.topk(
            router_logits, self.top_k, dim=-1
        )

        # Normalize weights (softmax over the selected experts only)
        router_weights = F.softmax(router_weights, dim=-1)

        return router_weights, selected_experts, router_logits


def demonstrate_routing():
    """
    Demonstrate router behavior.
    """
    print("Router Demonstration")
    print("=" * 60)

    # Setup
    batch_size = 2
    seq_len = 8
    d_model = 64
    num_experts = 4
    top_k = 2

    # Random input
    torch.manual_seed(42)
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Top-K Router
    router = TopKRouter(d_model, num_experts, top_k)
    weights, experts, logits = router(hidden_states)

    print(f"Input shape: [{batch_size}, {seq_len}, {d_model}]")
    print(f"Num experts: {num_experts}, Top-K: {top_k}")
    print()

    print("Sample routing decisions (first 4 tokens):")
    print("-" * 50)
    for i in range(4):
        exp = experts[0, i].tolist()
        wt = weights[0, i].tolist()
        print(f"Token {i}: Expert {exp[0]} ({wt[0]:.2f}) + Expert {exp[1]} ({wt[1]:.2f})")

    # Expert load distribution. Each token is counted once per selected
    # expert, so the counts sum to batch_size * seq_len * top_k.
    expert_counts = torch.zeros(num_experts)
    for e in range(num_experts):
        expert_counts[e] = (experts == e).sum().item()

    print("\nExpert load distribution:")
    for e in range(num_experts):
        bar = "█" * int(expert_counts[e] / 2)
        print(f"  Expert {e}: {expert_counts[e]:3.0f} assignments {bar}")


demonstrate_routing()

2.3 Expert Layer Implementation
Sparse MoE Layer
class Expert(nn.Module):
    """
    Single expert (standard FFN).
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.0
    ):
        super().__init__()

        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class SparseMoELayer(nn.Module):
    """
    Sparse Mixture of Experts layer.

    Replaces standard FFN in transformer.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.0,
        noise_std: float = 0.1
    ):
        """
        Initialize MoE layer.

        Args:
            d_model: Model dimension
            d_ff: FFN intermediate dimension
            num_experts: Number of expert FFNs
            top_k: Number of experts per token
            dropout: Dropout rate
            noise_std: Router noise
        """
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k
        self.d_model = d_model

        # Router
        self.router = TopKRouter(d_model, num_experts, top_k, noise_std)

        # Experts
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff, dropout)
            for _ in range(num_experts)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through MoE.

        Args:
            hidden_states: [batch_size, seq_len, d_model]

        Returns:
            output: [batch_size, seq_len, d_model]
            router_logits: [batch_size, seq_len, num_experts]
        """
        # Get routing decisions
        router_weights, selected_experts, router_logits = self.router(hidden_states)

        # Initialize output
        output = torch.zeros_like(hidden_states)

        # Process each expert
        # This is a simple implementation; real implementations batch tokens per expert
        for expert_idx in range(self.num_experts):
            # Find tokens assigned to this expert
            expert_mask = (selected_experts == expert_idx)  # [B, S, K]

            if not expert_mask.any():
                continue

            # Visit each top-k slot separately.
            # This is inefficient but clear
            for k in range(self.top_k):
                mask = expert_mask[:, :, k]  # [B, S]

                if not mask.any():
                    continue

                # Get tokens for this expert
                expert_input = hidden_states[mask]  # [num_tokens, d_model]

                # Process through expert
                expert_output = self.experts[expert_idx](expert_input)

                # Get corresponding weights
                expert_weights = router_weights[:, :, k][mask]  # [num_tokens]

                # Weighted contribution
                weighted_output = expert_output * expert_weights.unsqueeze(-1)

                # Add to output
                output[mask] += weighted_output

        return output, router_logits


def test_moe_layer():
    """
    Test MoE layer.
    """
    print("Testing MoE Layer")
    print("=" * 60)

    batch_size = 4
    seq_len = 16
    d_model = 256
    d_ff = 512
    num_experts = 4
    top_k = 2

    # Input
    x = torch.randn(batch_size, seq_len, d_model)

    # MoE layer
    moe = SparseMoELayer(d_model, d_ff, num_experts, top_k)

    # Forward
    output, router_logits = moe(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Router logits shape: {router_logits.shape}")

    # Compare params with dense FFN (weights + biases of both linear layers)
    dense_params = d_model * d_ff + d_ff + d_ff * d_model + d_model
    moe_params = num_experts * dense_params
    router_params = d_model * num_experts
    # Each token runs the router plus top_k experts
    active_params = top_k * dense_params + router_params

    print("\nParameter comparison:")
    print(f"  Dense FFN: {dense_params:,}")
    print(f"  MoE ({num_experts} experts): {moe_params + router_params:,}")
    print(f"  Ratio: {(moe_params + router_params) / dense_params:.1f}x more params")
    print(f"  Active per token: {active_params:,} (top-{top_k})")


test_moe_layer()

2.4 Load Balancing Loss
Preventing Expert Collapse
class MoEAuxiliaryLoss(nn.Module):
    """
    Auxiliary losses for MoE training.

    Prevents:
    1. Router collapse (all tokens to few experts)
    2. Expert imbalance (some experts never trained)
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        aux_loss_weight: float = 0.01,
        z_loss_weight: float = 0.001
    ):
        """
        Initialize auxiliary loss.

        Args:
            num_experts: Number of experts
            top_k: Number of selected experts per token
            aux_loss_weight: Weight for the load balancing loss
            z_loss_weight: Weight for the router z-loss
        """
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss_weight = aux_loss_weight
        self.z_loss_weight = z_loss_weight

    def load_balancing_loss(
        self,
        router_logits: torch.Tensor
    ) -> torch.Tensor:
        """
        Load balancing loss from Switch Transformer.

        Encourages uniform expert utilization.

        Args:
            router_logits: [batch_size, seq_len, num_experts]

        Returns:
            Auxiliary loss scalar
        """
        # Get routing probabilities
        router_probs = F.softmax(router_logits, dim=-1)  # [B, S, E]

        # f_i: fraction of routing assignments dispatched to expert i.
        # This uses the hard top-k selection, so no gradient flows through it.
        _, selected_experts = torch.topk(router_logits, self.top_k, dim=-1)  # [B, S, K]
        dispatch_mask = F.one_hot(selected_experts, self.num_experts).sum(dim=-2)  # [B, S, E]
        expert_usage = dispatch_mask.float().mean(dim=(0, 1)) / self.top_k  # [E], sums to 1

        # P_i: mean router probability assigned to expert i.
        # This term is differentiable and carries the gradient.
        expert_load = router_probs.mean(dim=(0, 1))  # [E], sums to 1

        # Loss: num_experts * sum of f_i * P_i
        # Equals 1.0 under uniform routing: f_i = P_i = 1/num_experts
        aux_loss = (expert_usage * expert_load).sum() * self.num_experts

        return aux_loss * self.aux_loss_weight

    def router_z_loss(
        self,
        router_logits: torch.Tensor
    ) -> torch.Tensor:
        """
        Router z-loss to prevent router logits from growing too large.

        From ST-MoE paper.

        Args:
            router_logits: [batch_size, seq_len, num_experts]

        Returns:
            Z-loss scalar
        """
        # Log-sum-exp of router logits
        log_z = torch.logsumexp(router_logits, dim=-1)  # [B, S]

        # Loss is mean of squared log_z
        z_loss = log_z.pow(2).mean()

        return z_loss * self.z_loss_weight  # Small weight

    def forward(
        self,
        router_logits: torch.Tensor
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute total auxiliary loss.

        Args:
            router_logits: Router logits from MoE layers

        Returns:
            Total aux loss, dict of individual losses
        """
        load_loss = self.load_balancing_loss(router_logits)
        z_loss = self.router_z_loss(router_logits)

        total_loss = load_loss + z_loss

        return total_loss, {
            'load_balancing_loss': load_loss.item(),
            'router_z_loss': z_loss.item()
        }

Without Load Balancing:
Problem: Router learns to always select same experts

Training progression:
Epoch 1: Expert distribution [30%, 20%, 25%, 25%]
Epoch 2: Expert distribution [45%, 15%, 25%, 15%]
Epoch 3: Expert distribution [70%, 10%, 15%, 5%]
Epoch 4: Expert distribution [95%, 2%, 2%, 1%]

Result: Only 1 expert is trained effectively!
This is called "expert collapse" or "winner take all"

With Load Balancing Loss:
Loss = CrossEntropy + α * LoadBalancingLoss

LoadBalancingLoss encourages uniform distribution:
- Penalizes when some experts get too much traffic
- Penalizes when some experts get too little

Training with aux loss:
Epoch 1: Expert distribution [30%, 20%, 25%, 25%]
Epoch 2: Expert distribution [28%, 24%, 24%, 24%]
Epoch 3: Expert distribution [26%, 25%, 25%, 24%]
Epoch 4: Expert distribution [25%, 25%, 25%, 25%]

Result: All experts trained equally!

2.5 Complete MoE Transformer Block
Putting It All Together
class MoETransformerBlock(nn.Module):
    """
    Transformer block with MoE FFN.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.1,
        use_moe: bool = True
    ):
        """
        Initialize MoE transformer block.

        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: FFN intermediate dimension
            num_experts: Number of experts
            top_k: Experts per token
            dropout: Dropout rate
            use_moe: Whether to use MoE (vs dense FFN)
        """
        super().__init__()

        self.use_moe = use_moe

        # Self-attention
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # FFN (MoE or dense)
        if use_moe:
            self.ffn = SparseMoELayer(
                d_model, d_ff, num_experts, top_k, dropout
            )
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model)
            )

        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass.

        Args:
            x: Input [batch_size, seq_len, d_model]
            attn_mask: Optional attention mask, passed to
                nn.MultiheadAttention as attn_mask
                (for example a [seq_len, seq_len] causal mask)

        Returns:
            output: [batch_size, seq_len, d_model]
            router_logits: Optional router logits for aux loss
        """
        # Self-attention with residual (pre-norm)
        residual = x
        x = self.norm1(x)
        x, _ = self.self_attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = self.dropout1(x) + residual

        # FFN with residual (pre-norm)
        residual = x
        x = self.norm2(x)

        if self.use_moe:
            x, router_logits = self.ffn(x)
        else:
            x = self.ffn(x)
            router_logits = None

        x = self.dropout2(x) + residual

        return x, router_logits


class MoETransformer(nn.Module):
    """
    Transformer with MoE layers.

    Typically, MoE is applied to every other layer or every 4th layer.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 12,
        d_ff: int = 2048,
        num_experts: int = 8,
        top_k: int = 2,
        moe_frequency: int = 2,  # Apply MoE every N layers
        dropout: float = 0.1,
        max_seq_len: int = 512
    ):
        """
        Initialize MoE Transformer.

        Args:
            vocab_size: Vocabulary size
            d_model: Model dimension
            num_heads: Number of attention heads
            num_layers: Number of transformer layers
            d_ff: FFN intermediate dimension
            num_experts: Number of experts
            top_k: Experts per token
            moe_frequency: Apply MoE every N layers
            dropout: Dropout rate
            max_seq_len: Maximum sequence length
        """
        super().__init__()

        self.d_model = d_model
        self.num_experts = num_experts

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        # Transformer layers (alternating dense and MoE)
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            use_moe = (i % moe_frequency == moe_frequency - 1)  # Every Nth layer
            self.layers.append(
                MoETransformerBlock(
                    d_model, num_heads, d_ff, num_experts, top_k, dropout, use_moe
                )
            )

        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)

        # Auxiliary loss
        self.aux_loss_fn = MoEAuxiliaryLoss(num_experts, top_k)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            input_ids: [batch_size, seq_len], with seq_len <= max_seq_len
            attention_mask: Optional mask, forwarded to every block as
                attn_mask (for example a [seq_len, seq_len] causal mask)

        Returns:
            logits: [batch_size, seq_len, vocab_size]
            aux_loss: Auxiliary loss for MoE
        """
        batch_size, seq_len = input_ids.shape

        # Embeddings (position embeddings broadcast over the batch dimension)
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)

        # Collect router logits for aux loss
        all_router_logits = []

        # Transformer layers
        for layer in self.layers:
            x, router_logits = layer(x, attention_mask)
            if router_logits is not None:
                all_router_logits.append(router_logits)

        x = self.norm(x)
        logits = self.output(x)

        # Compute auxiliary loss
        if all_router_logits:
            # Concatenate along the batch dimension so the loss is
            # averaged over all MoE layers
            combined_logits = torch.cat(all_router_logits, dim=0)
            aux_loss, _ = self.aux_loss_fn(combined_logits)
        else:
            aux_loss = torch.tensor(0.0, device=input_ids.device)

        return logits, aux_loss

Summary
MoE Key Points:
| Aspect | Dense Model | MoE Model |
|---|---|---|
| Parameters | All active | Sparse (top-K active) |
| FLOPs per token | Proportional to total params | Proportional to active params (roughly independent of expert count) |
| Memory (inference) | Full model | Full model |
| Memory (training) | Activations + gradients | Same + routing overhead |
| Scaling | Expensive | Efficient |
Design Decisions:
1. Number of experts: 8-128 (sweet spot: 8-16)
2. Top-K: 1-2 (top-2 most common)
3. MoE frequency: Every layer or every 2nd/4th layer
4. Expert size: Same as dense FFN or smaller
5. Aux loss weight: 0.01-0.1
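To get a feel for the aux loss weight in item 5, the Switch-style balancing term from Section 2.4 can be evaluated by hand on the expert distributions shown there. The sketch below is illustrative only: `balance_term` is a name introduced here, and it assumes the router's mean probabilities P_i equal the observed token fractions f_i, which real training runs only approximate.

```python
from typing import Sequence


def balance_term(token_fraction: Sequence[float],
                 mean_prob: Sequence[float]) -> float:
    """Switch-style balancing term: num_experts * sum_i(f_i * P_i)."""
    assert len(token_fraction) == len(mean_prob)
    num_experts = len(token_fraction)
    return num_experts * sum(f * p for f, p in zip(token_fraction, mean_prob))


# Assumption for illustration: P_i equals f_i, so each distribution
# is passed for both arguments.
balanced = [0.25, 0.25, 0.25, 0.25]
collapsed = [0.95, 0.02, 0.02, 0.01]

for name, dist in [("balanced", balanced), ("collapsed", collapsed)]:
    raw = balance_term(dist, dist)
    for alpha in (0.01, 0.1):
        print(f"{name:9s} alpha={alpha:<4} raw={raw:.3f} weighted={alpha * raw:.4f}")
```

Under this assumption the raw term is 1.0 at perfect balance and approaches the number of experts as routing collapses onto one expert. The weight α scales that value, so a larger α pushes harder toward balance at the cost of competing more with the main loss.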
Exercises:
1. Implement Top-1 routing and compare with Top-2.
2. Visualize expert utilization during training.
3. Train a small MoE model on text generation and compare with dense.
4. Implement Expert Choice routing and compare load balance.
5. Measure inference latency of MoE vs dense with same active parameters.
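As a possible starting point for Exercise 2, the helper below turns routing decisions into per-expert utilization fractions that can be logged or plotted at each training step. It is a sketch with assumed names (`expert_utilization` is not part of the code above) and it works on a plain list of expert indices, which could be obtained from the router output with `selected_experts.flatten().tolist()`.

```python
from collections import Counter
from typing import Dict, Iterable


def expert_utilization(expert_indices: Iterable[int],
                       num_experts: int) -> Dict[int, float]:
    """Fraction of routing assignments received by each expert."""
    counts = Counter(expert_indices)
    total = sum(counts.values())
    if total == 0:
        # No assignments recorded: report zero usage for every expert.
        return {e: 0.0 for e in range(num_experts)}
    return {e: counts.get(e, 0) / total for e in range(num_experts)}


# Hypothetical top-2 assignments for 4 tokens routed over 4 experts.
assignments = [0, 2, 0, 1, 0, 2, 3, 0]
usage = expert_utilization(assignments, num_experts=4)
for expert, fraction in usage.items():
    print(f"Expert {expert}: {fraction:5.1%} {'█' * round(fraction * 20)}")
```

Collecting these fractions once per step and per MoE layer gives the data needed to reproduce the kind of distribution tables shown in Section 2.4.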
Next Section: In Section 3, we'll explore modern position encoding techniques including Rotary Position Embeddings (RoPE) and ALiBi.