Chapter 5

The Routing Mechanism

Mixture-of-Experts: DeepSeekMoE

In the last section we treated the router as a single line of math: score every expert, take the top $k$, softmax the survivors. That equation is true, and dangerously incomplete. The moment you try to actually train an MoE model with it, the router becomes the most fragile component in the entire system. Tokens collide on the same expert; gradients refuse to flow through a discrete argmax; logits drift to infinity; the wrong expert wins for the wrong reason and the model never recovers.

This section opens the router and looks at the engineering that makes sparse routing trainable in practice. By the end you should be able to read the routing code of GShard, Switch Transformer, or DeepSeek-V3 and recognise every line as the answer to a specific failure mode.

The thesis of this section. A router is not a softmax. A router is four hacks stacked on top of a softmax, each one fixing a problem that nearly killed MoE training before someone solved it.

Four Hard Problems Inside the Router

Look at the basic top-$k$ rule from section 1: $g_i(x) = \mathrm{softmax}(W_r x)_i$, restricted to the top $k$ entries. Try to train with it and four problems surface, none of them visible in the math itself:

| Problem | What goes wrong | The fix this section covers |
| --- | --- | --- |
| Non-differentiable top-k | argmax has zero gradient, so the router cannot learn from its choice of experts. | Route gradients through the gate values, not the indices. |
| Capacity overflow | Some experts get hammered, GPU tensors must be rectangular, and excess tokens have no slot. | Capacity factor + token dropping + residual fallback. |
| Premature confidence | An untrained router locks onto a few experts and the others never train (routing collapse). | Learnable Gaussian noise on the logits at training time. |
| Logit drift | Over long training runs router logits grow toward ±∞, the softmax saturates, and gradients vanish. | Router z-loss: a tiny penalty on the squared log of the softmax denominator. |

Each row above gets its own subsection below. None of them are theoretical niceties: production MoE codebases implement each one or a deliberate substitute for it (DeepSeek-V3, for instance, replaces routing noise with bias-term load balancing), and removing one without a replacement typically shows up quickly as worse loss or unstable training.

How Gradients Cross the Top-k Wall

Backprop needs derivatives. Top-$k$ is a discrete selector: it returns a set of indices, not a smooth function of $x$. Move any router input by $\varepsilon$ and the chosen set either does not change (gradient = 0) or jumps to a new set (gradient undefined). On paper, the router cannot learn.

The trick is to notice that top-$k$ returns two things: indices and values. The indices are not differentiable, but the values are, so we route the gradient through the values. Concretely, the MoE output is $y = \sum_{i \in \mathcal{T}} g_i(x)\, E_i(x)$, where $\mathcal{T}$ is the set of chosen experts, and the gate $g_i(x)$ is a smooth function of $W_r x$. Differentiating with respect to a router weight $W_{r,ab}$ (the weight connecting input dimension $b$ to expert $a$'s logit):

$$\frac{\partial y}{\partial W_{r,ab}} = \sum_{i \in \mathcal{T}} \frac{\partial g_i}{\partial W_{r,ab}}\, E_i(x)$$

The sum runs over the currently chosen experts only. With the softmax taken over the survivors, an expert $a \notin \mathcal{T}$ does not appear in any $g_i$, so $\partial g_i / \partial W_{r,ab} = 0$ for every $i \in \mathcal{T}$: the router gets no signal about that expert this step. The router only learns about experts it actually tried.
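The bandit-style signal is easy to observe numerically. The sketch below assumes a toy router with top-2 selection, a softmax over the survivors, and linear experts; the sizes, seed, and squared-norm loss are arbitrary choices for illustration. It estimates the gradient of the loss with respect to every router weight by central finite differences: the rows of $W_r$ that belong to unchosen experts come out exactly zero, while the chosen rows do not.

```python
import numpy as np

# Toy setup (illustrative assumptions): one token, E = 4 experts, top-2 routing,
# softmax over the surviving logits, linear experts E_i(x) = A[i] @ x.
rng = np.random.default_rng(1)
D, E, k = 3, 4, 2
x = rng.standard_normal(D)
W_r = rng.standard_normal((E, D))
A = rng.standard_normal((E, D, D))

def moe_forward(W_r):
    """Return the MoE output y and the chosen expert set for router weights W_r."""
    logits = W_r @ x                           # (E,)
    chosen = np.argsort(-logits)[:k]           # the discrete set T
    z = logits[chosen] - logits[chosen].max()
    gates = np.exp(z) / np.exp(z).sum()        # softmax over survivors
    y = sum(g * (A[i] @ x) for g, i in zip(gates, chosen))
    return y, chosen

def loss(W_r):
    y, _ = moe_forward(W_r)
    return float(np.sum(y ** 2))               # any smooth scalar loss works

# Central finite differences on every router weight W_r[a, b].
eps = 1e-6
grad = np.zeros_like(W_r)
for a in range(E):
    for b in range(D):
        Wp, Wm = W_r.copy(), W_r.copy()
        Wp[a, b] += eps
        Wm[a, b] -= eps
        grad[a, b] = (loss(Wp) - loss(Wm)) / (2 * eps)

_, chosen = moe_forward(W_r)
unchosen = [a for a in range(E) if a not in chosen]
print("chosen experts:", sorted(chosen.tolist()))
print("max |grad| on unchosen rows:", np.abs(grad[unchosen]).max())
print("max |grad| on chosen rows:  ", np.abs(grad[chosen]).max())
```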

The deep consequence. The router's gradient is a bandit-style signal: it can only update the experts it sampled, and it has to discover the others by exploration. That is why the noise discussed under Noisy Top-k below is not optional unless something else supplies that exploration: without it, an expert that is never sampled stays invisible to the router.

Why not just use Gumbel-softmax or REINFORCE?

Both have been tried. The Gumbel-softmax relaxation replaces top-$k$ with a soft sample that has a smooth gradient everywhere, but it adds compute (you have to evaluate all experts during training to use the soft weights) and gives up the sparsity that motivated MoE in the first place. REINFORCE-style policy gradients work but have high variance at the scale of billion-token training runs. Top-$k$ plus softmax-on-survivors won in practice because it is cheap, sparse, and good enough; the price is needing the other three fixes in this section to make it stable.

Expert Capacity and Dropped Tokens

GPUs do not like ragged tensors. To dispatch tokens to experts in parallel, every expert must receive the same number of tokens — or at least a number known in advance so the dispatch buffer is rectangular. But routing is data-dependent: there is no guarantee that the 1024 tokens in a batch will split evenly across 8 experts. Some experts get 300 tokens, others get 50, and the matmul shapes refuse to cooperate.

The solution is brutal and pragmatic: declare an expert capacity up front and drop anything that overflows. The capacity is

$$\mathrm{cap} = \Big\lceil C \cdot \frac{N \cdot k}{E} \Big\rceil$$

where $N$ is the number of tokens in the batch, $k$ is the number of experts chosen per token (top-$k$), $E$ is the number of experts, and $C$ is the capacity factor, typically 1.0 to 2.0.

Read the formula slowly. $Nk/E$ is the perfectly balanced share: if every expert received the same number of token assignments, each would get exactly this many (under uniformly random routing it is the expected load). $C$ is the headroom we grant to absorb imbalance. $C = 1.0$ means no headroom: any imbalance causes drops. $C = 2.0$ means every expert can absorb twice its fair share, so almost nothing is dropped, but you have allocated 2× the dispatch memory.
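The capacity formula is a one-liner. A minimal sketch in plain Python (expert_capacity is a name chosen here, not a library function), evaluated on the settings used later in this section:

```python
import math

def expert_capacity(num_tokens: int, k: int, num_experts: int,
                    capacity_factor: float) -> int:
    """cap = ceil(C * N * k / E): slots per expert in the dispatch buffer."""
    fair_share = num_tokens * k / num_experts        # perfectly balanced load
    return math.ceil(capacity_factor * fair_share)

for N, k, E, C in [(6, 1, 3, 1.0), (16, 1, 4, 1.25), (16, 1, 4, 2.0), (1024, 2, 8, 1.25)]:
    cap = expert_capacity(N, k, E, C)
    slots = cap * E
    print(f"N={N:5d} k={k} E={E} C={C:<4} -> cap={cap:4d}, "
          f"{slots - N * k} of {slots} slots spare even under perfect balance")
```

The spare-slot column is the memory price of headroom: even if routing were perfectly even, $(C - 1)$ times the fair share per expert would sit empty.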

Dropped tokens are not lost. When a token is dropped at the MoE layer, its FFN output is zero for that block, but the residual connection still carries the input through: $x_{\mathrm{out}} = x + 0 = x$. The model still sees the token and gradients still flow through the residual, but this block contributed nothing to its representation. As a rough guide, small drop rates (1-5%) are usually tolerable; well above 10%, the model tends to underfit noticeably.

The capacity dial in real systems

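| System | Capacity factor C | Notes |
| --- | --- | --- |
| Switch Transformer (Fedus 2021) | 1.0 | Aggressive: "drop or die". Required a strong balance loss. |
| GShard (Lepikhin 2020) | 1.25 | Standard default. Small headroom, modest balance loss. |
| DeepSeek-V2 | no fixed per-expert cap; device-level capacity | Tokens are dropped per device; balance comes from auxiliary losses. |
| DeepSeek-V3 | no capacity, no token dropping | Bias-term load balancing keeps loads even; see ch. 6. |
| Mixtral 8×7B (inference) | implicit, k = 2 only | Inference does not need fixed shapes, so drops disappear. |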

Notice the trend: newer systems push the capacity factor down or remove it entirely, relying instead on better load-balancing mechanisms to keep routing roughly uniform. The capacity factor is a crutch; the next chapter is about throwing the crutch away.

Manual Numerical Walkthrough

Let us trace a tiny example end-to-end: 6 tokens, 3 experts, top-1 routing, capacity factor $C = 1.0$. We will compute logits, pick experts, hit overflow, and watch a token get dropped.

Worked example: 6 tokens, 3 experts, k = 1, by hand

Setup. Capacity $= \lceil 1.0 \cdot 6 \cdot 1 / 3 \rceil = 2$, so each expert may hold at most 2 tokens. We have:

Router logits per token (rows = tokens, cols = experts):

        E1     E2     E3
t0:   [ 2.1,  0.4,  0.7]   → argmax = E1
t1:   [ 1.8,  0.6,  0.2]   → argmax = E1
t2:   [ 2.4,  0.9,  0.5]   → argmax = E1   ← problem
t3:   [ 0.1,  1.9,  0.5]   → argmax = E2
t4:   [ 0.3,  0.4,  2.2]   → argmax = E3
t5:   [ 0.6,  2.0,  0.9]   → argmax = E2

Greedy assignment (batch order).

  • t0 → E1 (E1 count 1/2) ✓
  • t1 → E1 (E1 count 2/2) ✓ — E1 is now full
  • t2 → E1 wanted, but E1 is full → DROPPED
  • t3 → E2 (E2 count 1/2) ✓
  • t4 → E3 (E3 count 1/2) ✓
  • t5 → E2 (E2 count 2/2) ✓

Final counts: E1 = 2, E2 = 2, E3 = 1. Drop rate = 1/6 ≈ 17%. Notice that E3 still has a free slot: token t2 could in principle have been sent there as a fallback, even though E3 is t2's lowest-scoring expert, but the deterministic top-1 rule never considers a second choice for t2.

What happens to t2 in the forward pass? Its gate is multiplied by the assignment mask (0), so its MoE output is zero. The residual connection still passes the token through, so the next transformer block sees the unmodified hidden state. The model is not broken — it just lost one MoE block of representation work for this token.

What happens during backprop? Because t2 has $g \cdot \mathrm{mask} = 0$, the gradient through that gate is also zero, so the router gets no signal from this token at this layer. This is one reason drops hurt: they consume training data but give the router nothing to learn from.
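The hand trace can be replayed in a few lines of NumPy. This sketch hard-codes the logits from the worked example and uses the same greedy, batch-order assignment; the hidden state and expert output for t2 at the end are made-up numbers used only to show the residual fallback.

```python
import numpy as np

# Router logits from the worked example: rows = tokens t0..t5, cols = experts E1..E3.
logits = np.array([
    [2.1, 0.4, 0.7],
    [1.8, 0.6, 0.2],
    [2.4, 0.9, 0.5],
    [0.1, 1.9, 0.5],
    [0.3, 0.4, 2.2],
    [0.6, 2.0, 0.9],
])
N, E = logits.shape
k, C = 1, 1.0
capacity = int(np.ceil(C * N * k / E))      # = 2

choice = logits.argmax(axis=-1)             # top-1 expert per token
counts = np.zeros(E, dtype=int)
kept = np.zeros(N, dtype=bool)
for n in range(N):                          # greedy, in batch order
    e = choice[n]
    if counts[e] < capacity:
        counts[e] += 1
        kept[n] = True

print("capacity:", capacity)                                # 2
print("counts:", counts.tolist())                           # [2, 2, 1]
print("dropped tokens:", np.flatnonzero(~kept).tolist())    # [2], i.e. t2
print(f"drop rate: {(~kept).mean():.1%}")                   # 16.7%

# Residual fallback for the dropped token: the MoE term is masked to zero,
# so the block's output equals its input. Toy hidden state and expert output.
x_t2 = np.array([0.5, -1.0, 2.0])
expert_out = np.array([9.0, 9.0, 9.0])      # whatever E1 would have produced
gate = 1.0                                   # top-1 gate before masking
x_out = x_t2 + gate * kept[2] * expert_out
print("t2 passes through unchanged:", np.array_equal(x_out, x_t2))   # True
```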

Visualising Capacity Overflow

The trade-off becomes tangible when you sweep the capacity factor over random batches. Lower it and tokens get dropped; raise it and GPU utilisation collapses as unpopular experts sit half-empty. Rerolling the router scores shows how the dynamics shift when the popularity skew changes.

Two patterns are worth pausing on. First, at C = 1.0 you almost always drop something: the perfectly balanced assignment is a knife-edge that random batches rarely hit. Second, dropping is bursty: it concentrates on whichever expert the batch happens to over-favour, which is exactly the signature of an unbalanced router. The whole point of the next chapter (auxiliary-loss-free balancing) is to make every expert's load roughly equal, so that little or nothing ever has to be dropped.
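The same trade-off can be reproduced with a short simulation. This sketch assumes top-1 routing and a made-up popularity skew added to random router scores; it counts overflow beyond each expert's capacity and the fraction of allocated slots that are actually filled. Under greedy assignment the number of drops per expert is max(load - cap, 0) regardless of token order; the order only decides which tokens are the ones dropped.

```python
import numpy as np

def drop_stats(choice, num_experts, capacity_factor):
    """Top-1 routing: drop rate and slot utilisation for one batch at capacity factor C."""
    n_tokens = len(choice)
    cap = int(np.ceil(capacity_factor * n_tokens / num_experts))
    load = np.bincount(choice, minlength=num_experts)
    dropped = np.maximum(load - cap, 0).sum()                   # overflow beyond each cap
    utilisation = (n_tokens - dropped) / (cap * num_experts)    # filled / allocated slots
    return dropped / n_tokens, utilisation

rng = np.random.default_rng(0)
N, E = 1024, 8
factors = (1.0, 1.25, 1.5, 2.0)
skew = np.linspace(0.0, 1.0, E)            # assumed popularity skew toward later experts

for batch in range(3):                     # a few freshly sampled batches
    logits = rng.standard_normal((N, E)) + skew
    choice = logits.argmax(axis=-1)        # top-1 expert per token
    cells = []
    for C in factors:
        d, u = drop_stats(choice, E, C)
        cells.append(f"C={C}: drop {d:.1%}, util {u:.0%}")
    print(f"batch {batch}: " + " | ".join(cells))
```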

Noisy Top-k: Why Routers Need Randomness

Imagine you start training with a freshly initialised router. The logits are tiny and near-random. Token $x$ sees, say, $s = [0.31, 0.29, 0.28, 0.30]$. Expert 1 wins by a hair. Now backprop: expert 1 produced something useful, so its gate goes up and its logit gets a small positive push. On the next forward pass with the same token, the gap widens slightly. Within a few thousand steps every token can end up routed to expert 1. The other experts never see traffic, never receive gradient, and the MoE layer is effectively dead.

This is routing collapse in its purest form, and it is the reason naive top-$k$ does not work. The fix from Shazeer et al. (2017) is to add learnable Gaussian noise to the logits before the top-$k$ selection:

$$\tilde{s}_i = (W_r x)_i + \mathrm{softplus}\big((W_n x)_i\big) \cdot \varepsilon_i, \quad \varepsilon_i \sim \mathcal{N}(0, 1)$$

Read each piece. $W_r x$ is the usual vector of router logits. $W_n x$ comes from a second linear map that produces a per-(token, expert) noise scale; softplus keeps it positive. $\varepsilon_i$ is fresh Gaussian noise on every forward pass. The two terms are added: the deterministic signal plus the noisy perturbation.

Two experts whose deterministic logits differ by 0.01 might routinely swap rank under noise of std 0.3. Two experts whose deterministic logits differ by 3.0 will almost never swap. Because the noise scale is learned, the router can shrink $\mathrm{softplus}(W_n x)$ where it is confident and leave it large in regions where the right expert is genuinely ambiguous. The intended result: every expert sees traffic for long enough to train, but the router still ends up close to deterministic once it knows what it is doing.
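The rank-swapping claim is easy to check by simulation. This sketch uses a fixed noise standard deviation instead of the learned $\mathrm{softplus}(W_n x)$, an assumption made to isolate the effect; it counts how often each expert wins top-1 once noise is added.

```python
import numpy as np

rng = np.random.default_rng(0)

def win_rates(logits, noise_std, trials=20_000):
    """Fraction of trials in which each expert is top-1 after adding N(0, noise_std^2) noise."""
    logits = np.asarray(logits, dtype=float)
    noisy = logits + rng.standard_normal((trials, logits.size)) * noise_std
    return np.bincount(noisy.argmax(axis=-1), minlength=logits.size) / trials

near_tie = [0.31, 0.29, 0.28, 0.30]     # untrained router: near-uniform logits
confident = [3.0, 0.0, 0.0, 0.0]        # trained router: one clear winner

print("near tie, no noise :", win_rates(near_tie, 0.0))    # expert 0 every time
print("near tie, std 0.3  :", win_rates(near_tie, 0.3))    # every expert gets traffic
print("confident, std 0.3 :", win_rates(confident, 0.3))   # expert 0 almost always
```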

The connection to exploration / exploitation. A router is a bandit. Noise = exploration. As training proceeds, learned confidence reduces noise and the router becomes greedier. This is the same trade-off as ε-greedy in reinforcement learning, just baked into the architecture.

Switch Transformer's alternative: jitter

Switch Transformer (Fedus 2021) dropped the noise network and replaced it with multiplicative jitter on the router input: $x' = x \cdot u,\ u \sim \mathrm{Uniform}(1 - \epsilon, 1 + \epsilon)$. It serves the same purpose at lower cost. DeepSeek-V3 dropped both: it relies on its bias-term load balancer (Chapter 6) to do the exploration work, showing that with the right balance signal you can route deterministically without collapsing.
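A sketch of multiplicative jitter in NumPy. It draws an independent uniform factor for every element of the router input, which is one reading of the formula above; eps = 0.01 and the toy shapes are illustrative assumptions, not values taken from a specific codebase. At inference the input passes through untouched.

```python
import numpy as np

def jitter(x, eps=1e-2, training=True, rng=None):
    """Multiplicative input jitter: x' = x * u, u ~ Uniform(1 - eps, 1 + eps), elementwise.

    Applied to the router input during training only; identity at inference.
    """
    if not training:
        return x
    rng = np.random.default_rng() if rng is None else rng
    u = rng.uniform(1.0 - eps, 1.0 + eps, size=x.shape)
    return x * u

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))                  # 4 tokens, d_model = 8
W_r = rng.standard_normal((3, 8)) * 0.3          # 3 experts
logits_train = jitter(x, rng=rng) @ W_r.T        # slightly perturbed router scores
logits_eval = jitter(x, training=False) @ W_r.T  # deterministic at inference
print(np.abs(logits_train - logits_eval).max())  # a small perturbation
```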

Sigmoid vs Softmax Gates (DeepSeek-V3)

Every routing recipe up to GShard used a softmax over experts: $g_i(x) = e^{s_i} / \sum_j e^{s_j}$. The softmax is the natural choice when you want a probability distribution over experts. But it has a subtle problem at large $E$: its outputs are normalised, so boosting one expert's logit necessarily suppresses all the others. The router has to fight the normalisation to learn fine-grained preferences.

DeepSeek-V3 broke from tradition and used a sigmoid gate per expert: $g_i(x) = \sigma(s_i) = 1 / (1 + e^{-s_i})$. Each expert is gated independently. The gates do not sum to 1; they are top-$k$-selected and then re-normalised (divided by their sum) just before being applied. The math at output time looks similar; the math at gradient time is very different. Apart from the renormalisation among the $k$ survivors, every expert's logit can move independently, which is the intuition for why sigmoid gating suits routers with many experts.
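The coupling difference shows up directly in the Jacobians. The sketch below assumes five experts with made-up logits and leaves out details of DeepSeek-V3's production router, such as its load-balancing bias. Every row of the softmax Jacobian sums to zero, so raising one logit lowers every other gate, whereas the sigmoid Jacobian is diagonal.

```python
import numpy as np

def sigmoid_topk_gates(s, k):
    """Sigmoid gating as described above: independent sigmoid per expert,
    keep the top-k, renormalise the survivors so they sum to 1."""
    a = 1.0 / (1.0 + np.exp(-s))
    top = np.argsort(-a)[:k]
    return top, a[top] / a[top].sum()

s = np.array([1.2, -0.3, 0.8, 2.0, 0.1])      # made-up logits for E = 5 experts
top, gates = sigmoid_topk_gates(s, k=2)
print(top, gates)

# Coupling between experts: Jacobian of each gating function w.r.t. the logits.
p = np.exp(s - s.max())
p /= p.sum()                                  # softmax over all experts
J_softmax = np.diag(p) - np.outer(p, p)       # dense; every row sums to 0 (zero-sum)
a = 1.0 / (1.0 + np.exp(-s))
J_sigmoid = np.diag(a * (1.0 - a))            # diagonal: experts move independently
print(np.round(J_softmax, 3))
print(np.round(J_sigmoid, 3))
```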

| Property | Softmax gate | Sigmoid gate |
| --- | --- | --- |
| Sum over experts before top-k | = 1 | unconstrained |
| Per-expert gradient coupling | high (zero-sum) | low (independent) |
| Behaviour at large E | diffuse, slow to sharpen | sharper signal per expert |
| Used in DeepSeek-V3 (E = 256) | no | default |
| Used in Mixtral (E = 8) | default | no |
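Rule of thumb (a heuristic, not a hard limit). Softmax is the well-trodden choice at small expert counts such as Mixtral's 8, and softmax routers have also been trained with far more experts: GShard and Switch Transformer scaled to thousands, and DeepSeek-V2 used 160 routed experts. The argument for sigmoid is that gradient interference between competing logits grows as experts multiply; DeepSeek-V3's switch to sigmoid is one of several changes that accompanied its move to 256 routed experts.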

Router Z-Loss: Keeping Logits Sane

Train an MoE model for long enough and the router logits do something unpleasant: they drift upward in absolute magnitude. $s_i$ values that started at 0.3 are now 30. The softmax saturates. $e^{30}$ overflows in fp16, whose largest finite value is 65504. Even in fp32 the gradient through the softmax becomes vanishingly small for the unselected entries, and the router stops learning.

ST-MoE (Zoph 2022) introduced the router z-loss to suppress this drift. The idea: the log-denominator of the softmax, $\log \sum_j e^{s_j}$, is zero exactly when $\sum_j e^{s_j} = 1$ and grows with the logits. Penalise its square and you penalise the router for using unnecessarily large logit magnitudes:

$$\mathcal{L}_z = \lambda_z \cdot \frac{1}{N} \sum_{n=1}^{N} \Big( \log \sum_{j=1}^{E} e^{s_{n,j}} \Big)^2$$

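$\lambda_z$ is tiny, typically $10^{-3}$. The loss is added to the main cross-entropy. Two useful properties: (1) the softmax is invariant to adding a constant to all logits, but the log-sum-exp is not (it shifts by that same constant), so the penalty can be reduced simply by shifting all logits down together. That leaves the routing probabilities untouched; the z-loss does not force the router to give up its preferences. (2) Its gradient, $\partial \mathcal{L}_z / \partial s_{n,j} = \lambda_z \cdot \frac{2}{N} \cdot \mathrm{lse}_n \cdot \mathrm{softmax}(s_n)_j$, where $\mathrm{lse}_n$ is token $n$'s log-sum-exp, shrinks as $\mathrm{lse}_n$ approaches zero, so it barely perturbs a healthy router.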
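A NumPy sketch of the z-loss and its gradient, assuming router logits of shape (N, E) and $\lambda_z = 10^{-3}$. It demonstrates property (1): shifting every logit by 30 leaves the softmax untouched but moves the log-sum-exp by 30, so the penalty grows sharply. The analytic gradient follows the formula above.

```python
import numpy as np

def logsumexp(s):
    """Row-wise log(sum_j exp(s_j)), computed stably."""
    m = s.max(axis=-1, keepdims=True)
    return m.squeeze(-1) + np.log(np.exp(s - m).sum(axis=-1))

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def z_loss(logits, coef=1e-3):
    """L_z = coef * mean over tokens of logsumexp(logits)^2, logits shape (N, E)."""
    return coef * np.mean(logsumexp(logits) ** 2)

def z_loss_grad(logits, coef=1e-3):
    """dL_z/ds[n, j] = coef * (2 / N) * lse[n] * softmax(s[n])[j]."""
    n_tokens = logits.shape[0]
    return coef * (2.0 / n_tokens) * logsumexp(logits)[:, None] * softmax(logits)

rng = np.random.default_rng(0)
small = rng.standard_normal((8, 4)) * 0.3     # healthy router logits
drifted = small + 30.0                         # same preferences, shifted up by 30

print(np.allclose(softmax(small), softmax(drifted)))   # True: routing unchanged
print(z_loss(small), z_loss(drifted))                   # small vs. large penalty
```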

Where this lives in code. It is almost always computed inside the router's forward pass and returned as a side output, so the training loop can add it to the cross-entropy loss before .backward(). The PyTorch snippet below does exactly this.

Plain Python: Noisy Top-k With Capacity

Here is the entire routing mechanism — noisy top-k, gate softmax, capacity-respecting assignment — in pure NumPy, small enough to read in one sitting:

🐍router_numpy.py
Reproducible randomness (rng = np.random.default_rng(0))

We fix a seed so every walkthrough lands on the same numbers. The seed fixes the sequence of draws, not a single draw: each call still gets fresh noise, which is what real training needs, since noise that never changed would provide no exploration.

The toy world (N, D, E, k = 16, 4, 4, 1)

16 tokens, model dim 4, four experts, top-1 routing. Tiny enough to print, big enough to overflow at least one expert when the capacity factor drops.

Capacity factor C (capacity_factor = 1.25)

Each expert can hold ceil(C · N · k / E) tokens. With C = 1.25 that is 5 slots per expert in our batch, 25% headroom over the perfectly balanced share of 4. Smaller C means more drops; larger C means more wasted compute.

Input batch (X, shape (16, 4))

A real model would have shape (batch, seq_len, D). We collapse it to (N, D): one row per token, exactly what the router sees.

Router weights (W_r, shape (4, 4))

A single linear layer mapping a token to E logits. It is initialised small so the early router is nearly uniform. This matters: a confident untrained router tends to collapse onto a few experts and never learn.

Noise scale network (W_n, shape (4, 4))

A separate linear map that decides how much noise to add per (token, expert) pair. In a trained model it is learned, so the router learns not only which expert to pick but how confident to be: a large noise scale turns a close call into a fair coin flip, a small one leaves a near-deterministic choice. In this toy it is just a fixed random matrix.

Compute routing logits (logits = X @ W_r.T)

This cheap matmul scores every (token, expert) pair, giving logits.shape = (16, 4). Total cost O(N · D · E), negligible compared to running the experts themselves.

Per-expert noise standard deviation (noise_std = softplus(X @ W_n.T))

softplus is always positive, so the result is a valid σ. Each (token, expert) pair gets its own noise budget; when W_n is learned, the router can decide when to be confident and when to explore.

Add Gaussian noise: the exploration step

Multiplying standard normal draws by noise_std gives ε ~ N(0, σ²) per entry. Without this, a token always picks the same expert and other experts never see traffic. With it, tokens whose top logits are close can route differently from step to step, giving every expert a chance to train.

Pick top-k indices per token (np.argsort(-logits, axis=-1)[:, :k])

argsort(-logits) gives descending order; we keep the first k, so topk_idx.shape = (16, 1). These are the only experts that will actually be evaluated for each token.

Gather the surviving logits (np.take_along_axis)

take_along_axis pulls out, for every token, just the k logits we are going to softmax. In an autodiff framework this is where gradients enter: the loss reads these values and pushes them up or down, and that is how the router learns. NumPy computes no gradients itself; the PyTorch version below does.

Stable softmax: the gate weights

Subtract the row max, exponentiate, normalise. The result is k weights per token that sum to 1 (gates.shape = (16, 1)). These weights, not the discrete indices, are the differentiable signal back into the router. One caveat for this top-1 toy: a softmax over a single survivor is identically 1, so it carries no gradient at all. Switch Transformer, which routes top-1, avoids this by taking the softmax over all E logits and using the winning expert's probability as its gate.

Compute per-expert capacity (assign_with_capacity)

capacity = ceil(C · N · k / E). With C = 1.25, N = 16, k = 1, E = 4 this is ceil(5.0) = 5. Each expert may accept up to 5 tokens in this batch.

Greedy first-come-first-served assignment

Walk tokens in batch order. If the chosen expert has room, accept; if it is full, leave the slot unassigned and the token is dropped at that expert. Some real implementations first sort tokens by priority (for example by gate value) so the most confident routings survive.

Count dropped tokens (dropped = (~assigned.any(axis=-1)).sum())

A token is fully dropped when none of its k chosen experts accepted it. For top-1 that means its one chosen expert was full. Dropped tokens contribute zero to the MoE output; only the residual connection carries them forward.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, E, k = 16, 4, 4, 1                     # 16 tokens, dim 4, 4 experts, top-1
capacity_factor = 1.25

X = rng.standard_normal((N, D))              # token batch
W_r = rng.standard_normal((E, D)) * 0.3      # router
W_n = rng.standard_normal((E, D)) * 0.3      # noise scale net (per-expert)

def softplus(x):
    return np.log1p(np.exp(x))

def noisy_topk_route(X, training=True):
    logits = X @ W_r.T                       # (N, E)
    if training:
        noise_std = softplus(X @ W_n.T)      # (N, E), positive
        logits = logits + rng.standard_normal(logits.shape) * noise_std
    topk_idx = np.argsort(-logits, axis=-1)[:, :k]   # (N, k)
    topk_log = np.take_along_axis(logits, topk_idx, axis=-1)
    z = topk_log - topk_log.max(axis=-1, keepdims=True)
    gates = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    return topk_idx, gates                   # (N, k), (N, k)

def assign_with_capacity(topk_idx):
    capacity = int(np.ceil(capacity_factor * N * k / E))
    counts = np.zeros(E, dtype=int)
    assigned = np.zeros((N, k), dtype=bool)
    for n in range(N):
        for slot in range(k):
            i = topk_idx[n, slot]
            if counts[i] < capacity:
                counts[i] += 1
                assigned[n, slot] = True
    return assigned, counts, capacity

topk_idx, gates = noisy_topk_route(X)
assigned, counts, cap = assign_with_capacity(topk_idx)
dropped = (~assigned.any(axis=-1)).sum()
print(f"capacity={cap}, counts={counts.tolist()}, dropped={dropped}")
```

Running this script prints something like capacity=5, counts=[5, 4, 2, 4], dropped=1. One token wanted the first expert but its bucket was full; with $C = 1.25$ it gets dropped. Bump capacity_factor to 2.0 and the capacity rises to 8, enough to absorb this overflow, but the four experts then offer 32 slots for 16 tokens, so half of the dispatch slots are wasted on air.

PyTorch: A Production-Style Router

The PyTorch version below adds the two pieces a serious training run needs: a z-loss returned as a side output, and a self.training-gated noise so that inference is fully deterministic. The interface is intentionally narrow: forward(x_flat) returns everything the dispatch layer and the loss layer need.

🐍router_pytorch.py
Subclass nn.Module: one class for the whole routing decision

Everything that decides which expert runs lives in one place. The downstream MoE layer just consumes (topk_idx, gates) and stays oblivious to noise, capacity, and z-loss.

Capacity factor as a hyperparameter (capacity_factor=1.25)

Typical values: 1.0 (Switch Transformer, very aggressive), 1.25 (GShard default), 2.0 (safer for noisier domains). Lower means faster but more drops; higher means more wasted compute but less information loss.

Z-loss coefficient (z_loss_coef=1e-3)

Tiny (~1e-3): just enough to keep router logits from drifting toward ±∞ during long training runs. The math is in the Router Z-Loss section above.

Gate weights, no bias (self.w_gate)

A single Linear from d_model to num_experts with bias=False. Note that a learned per-expert bias would not be a no-op: it would shift every token's preference toward some experts regardless of content. Only an offset shared by all experts cancels in the softmax.

Separate noise network (self.w_noise)

A learned per-(token, expert) noise scale in the style of Shazeer et al. (2017). Switch Transformer replaces it with cheaper input jitter, GShard instead randomises whether a token's second expert is used, and DeepSeek-V3 drops routing noise altogether. The cost here is one extra small Linear; the intended benefit is more stable early-training dynamics.

Flat token stream (forward(self, x_flat))

We assume (batch, seq) has already been flattened, so x_flat.shape = (N, D). This is the standard pattern: routing is per-token, so the batch and time axes do not matter to the router.

Logits per (token, expert) (logits = self.w_gate(x_flat))

The result has shape (N, E). Each row is the affinity of one token to every expert.

self.training gates the noise

Critical detail: noise is only added during training. At inference the router scores are deterministic, so the same token is always scored the same way across calls.

Softplus → positive σ

Softplus(x) = log(1 + e^x) is smooth, positive, and approaches the identity for large x. It guarantees a valid standard deviation without the hard zero of ReLU.

Sample Gaussian noise

torch.randn_like generates ε with the same shape as logits. Multiplying by noise_std gives per-element Gaussian noise, so the router sees a perturbed view of every logit on every training forward pass.

Z-loss formula

torch.logsumexp(logits, dim=-1) is the log of the softmax denominator. Squaring and averaging penalises growth in that denominator, which in practice keeps the absolute logits small.

Top-k selection: discrete but cheap (logits.topk(self.k, dim=-1))

torch.topk costs roughly O(N · E) and returns both values and indices, with topk_logits.shape = topk_idx.shape = (N, k). The values carry gradients; the indices are integer choices with no gradient.

Softmax over the survivors

Same pattern as section 1: the gates sum to 1 across the k chosen experts. This is where gradient information enters the router: dLoss/dgate flows back through the softmax into the router weights. As in the NumPy version, with k = 1 the single gate is identically 1 and no gradient reaches the router, so top-1 routing needs a Switch-style softmax over all E logits instead.

Compute capacity

int(C · N · k / E) + 1 is floor plus one, a conservative variant of the ceil in the capacity formula: it is never smaller than ceil(C · N · k / E) and adds one spare slot when C · N · k / E is an exact integer. With N = 1024, k = 2, E = 8, C = 1.25 the formula gives 320 and this code gives 321 slots per expert.

Greedy assignment loop

Iterate over tokens in batch order and accept each slot if its expert has room. In production this loop is replaced by a parallel cumulative-sum and mask trick that runs as vectorised GPU operations, but the decisions are identical.

Zero out dropped gates (gates = gates * assigned)

Crucial. A dropped slot must contribute zero to the MoE output and send zero gradient to its expert and back into the router. Multiplying the gate by the assigned mask achieves both at once: the gate's value becomes zero, and so does any gradient flowing back through it.

Return everything downstream needs

(topk_idx, gates) drive the expert dispatch. counts feeds the load-balance auxiliary losses of the next chapter. z_loss is added directly to the main training loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Router(nn.Module):
    """Noisy top-k router with capacity and z-loss, à la GShard / Switch / DeepSeek."""

    def __init__(self, d_model: int, num_experts: int, k: int,
                 capacity_factor: float = 1.25, z_loss_coef: float = 1e-3):
        super().__init__()
        self.E, self.k = num_experts, k
        self.capacity_factor = capacity_factor
        self.z_loss_coef = z_loss_coef
        self.w_gate  = nn.Linear(d_model, num_experts, bias=False)
        self.w_noise = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x_flat: torch.Tensor):
        N = x_flat.size(0)
        logits = self.w_gate(x_flat)                               # (N, E)

        if self.training:
            noise_std = F.softplus(self.w_noise(x_flat))           # (N, E)
            logits = logits + torch.randn_like(logits) * noise_std

        # z-loss penalises large logits, keeping the softmax numerically stable.
        z_loss = self.z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()

        topk_logits, topk_idx = logits.topk(self.k, dim=-1)        # (N, k)
        gates = F.softmax(topk_logits, dim=-1)                     # (N, k)

        # Capacity assignment.
        capacity = int((self.capacity_factor * N * self.k) / self.E) + 1
        assigned = torch.zeros_like(topk_idx, dtype=torch.bool)
        counts   = torch.zeros(self.E, dtype=torch.long, device=x_flat.device)
        for n in range(N):
            for slot in range(self.k):
                e = topk_idx[n, slot].item()
                if counts[e] < capacity:
                    counts[e] += 1
                    assigned[n, slot] = True

        gates = gates * assigned                                   # zero out dropped slots
        return topk_idx, gates, counts, z_loss
```

Slot this Router into the MoEFFN from section 1 by replacing the inline router lines with topk_idx, gates, counts, z_loss = self.router(x_flat), and you have the skeleton of a real production MoE layer. The training loop adds z_loss to the cross-entropy and stores counts for the load-balance auxiliary loss in the next chapter.

The one missing piece. The greedy assignment loop is O(N · k) Python — fine for a teaching toy, ruinous on a GPU. Real implementations replace it with a parallel exclusive-cumulative-sum trick that runs in a single CUDA kernel. The math is identical, the cost is invisible. See Megatron-LM's moe directory for a reference implementation.
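The cumulative-sum trick can be sketched in NumPy. This illustrates the idea under the same token-major order as the loop in router_numpy.py; it is not a copy of Megatron-LM's code. One-hot encode each (token, slot) request, take an exclusive cumulative sum down the batch to get each request's position in its expert's queue, and keep the requests whose position is below capacity. On a GPU the same steps map onto a handful of batched tensor operations.

```python
import numpy as np

def assign_with_capacity_loop(topk_idx, num_experts, capacity):
    """Reference: greedy token-by-token assignment, as in router_numpy.py."""
    n_tokens, k = topk_idx.shape
    counts = np.zeros(num_experts, dtype=int)
    assigned = np.zeros((n_tokens, k), dtype=bool)
    for n in range(n_tokens):
        for slot in range(k):
            e = topk_idx[n, slot]
            if counts[e] < capacity:
                counts[e] += 1
                assigned[n, slot] = True
    return assigned, counts

def assign_with_capacity_cumsum(topk_idx, num_experts, capacity):
    """Same decisions with no Python loop: exclusive cumsum over one-hot requests."""
    n_tokens, k = topk_idx.shape
    flat = topk_idx.reshape(-1)                      # requests in (token, slot) order
    one_hot = np.eye(num_experts, dtype=int)[flat]   # (n_tokens * k, E)
    # Exclusive cumulative sum: number of earlier requests for the same expert.
    earlier = np.cumsum(one_hot, axis=0) - one_hot
    position = earlier[np.arange(flat.size), flat]   # this request's queue position
    assigned = (position < capacity).reshape(n_tokens, k)
    counts = np.minimum(np.bincount(flat, minlength=num_experts), capacity)
    return assigned, counts

rng = np.random.default_rng(0)
N, E, k, cap = 64, 8, 2, 17
topk_idx = np.argsort(-rng.standard_normal((N, E)), axis=-1)[:, :k]
a1, c1 = assign_with_capacity_loop(topk_idx, E, cap)
a2, c2 = assign_with_capacity_cumsum(topk_idx, E, cap)
print(np.array_equal(a1, a2), np.array_equal(c1, c2))   # True True
```

The two are equivalent because a request is accepted exactly when fewer than capacity earlier requests targeted the same expert. Implementations that serve all first-choice slots before any second-choice slot would simply flatten in (slot, token) order instead.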

What Changes at Massive Scale

At toy scale the router is the cheap part. At trillion-parameter scale the router's outputs determine the cost of the most expensive operation in distributed training: the all-to-all token shuffle. Every token chosen for an expert that lives on a different GPU has to be sent over NVLink or InfiniBand. The router's decisions, made one batch at a time, directly set the network volume.

| Routing decision | Direct effect on cost | Engineering knob |
| --- | --- | --- |
| How balanced are the expert counts? | Sets the size of the slowest expert's bucket, hence straggler latency. | Auxiliary loss or bias-term balancing (ch. 6). |
| What capacity factor C? | Sets the all-to-all buffer size: C · N · k · D elements per device. | Tune per cluster: smaller for fast networks, larger for slow ones. |
| Where do the chosen experts live? | Local experts are free; remote experts cost a network hop. | Topology-aware expert placement (ch. 5, section 5). |
| Are routing decisions stable across steps? | Stable routing improves cache locality and reduces kernel re-tuning. | Lower noise late in training; z-loss to suppress drift. |
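The buffer-size row is simple arithmetic. In this sketch the example sizes (4096 tokens per device, k = 8, D = 7168, bf16 at 2 bytes per element) are assumptions chosen for illustration, not any system's measured configuration, and only the dispatch direction is counted; the combine step sends results back as well.

```python
def dispatch_buffer_bytes(capacity_factor, tokens_per_device, k, d_model, bytes_per_elem):
    """Size of one all-to-all dispatch buffer: C * N * k * D elements per device."""
    elems = capacity_factor * tokens_per_device * k * d_model
    return int(round(elems * bytes_per_elem))

# Illustrative numbers only: 4096 tokens per device, top-8, d_model = 7168, bf16.
for C in (1.0, 1.25, 2.0):
    b = dispatch_buffer_bytes(C, 4096, 8, 7168, 2)
    print(f"C={C}: {b / 2**20:.0f} MiB per dispatch buffer")
```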

DeepSeek-V3 is the cleanest illustration of how much of the router's design is dictated by scale. With 256 experts distributed across hundreds of GPUs, the cost difference between a well-balanced router and a poorly balanced one is enormous, not in FLOPs (those are fixed by the $k/E$ sparsity) but in network seconds, which dominate the step time. That is why DeepSeek invested in bias-term load balancing instead of relying on auxiliary losses: the balance has to be near-perfect, and the gradient interference of a balance loss was hurting quality.

The Engineering Reality of Routing

Three patterns recur in every MoE codebase that actually runs at scale, and they are worth knowing before you try to implement one:

  1. The router is the first thing to monitor. Track per-expert token counts every step, track router logit magnitude every 100 steps, and track drop rate as a single scalar in your training dashboard. Trouble in an MoE run often shows up in routing statistics well before the loss reflects it.
  2. Routing is kept in higher precision even in mixed-precision training. The router's softmax is one of the few places where fp16 overflow has historically been a real bug. Switch Transformer casts the router computation to fp32 ("selective precision"), and DeepSeek-V3, which runs most of its matmuls in FP8, keeps its gating module in higher precision. The extra cost is rounding error in the total step time.
  3. Inference uses a different router than training. At serve time you disable noise, skip the z-loss (it only matters for training), disable capacity dropping (no fixed shape needed), and may switch from top-k to a tighter top-1 to reduce latency. The router should behave differently in each mode; in PyTorch, the training flag that nn.Module's .train() and .eval() set is the natural switch, as in the Router class above.
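As a starting point for the first pattern, here is a sketch of per-step routing statistics worth logging. The function and metric names are hypothetical; the inputs are the same logits, topk_idx, and assigned arrays the NumPy router in this section produces.

```python
import numpy as np

def routing_stats(logits, topk_idx, assigned, num_experts):
    """Per-step routing health metrics (a sketch; metric names are our own).

    logits:   (N, E) router logits for the step
    topk_idx: (N, k) chosen experts
    assigned: (N, k) bool, True where the slot fit under capacity
    """
    m = logits.max(axis=-1, keepdims=True)
    lse = m.squeeze(-1) + np.log(np.exp(logits - m).sum(axis=-1))
    load = np.bincount(topk_idx[assigned], minlength=num_experts)
    return {
        "tokens_per_expert": load.tolist(),
        "max_over_mean_load": float(load.max() / max(load.mean(), 1e-9)),
        "drop_rate": float((~assigned.any(axis=-1)).mean()),
        "max_abs_logit": float(np.abs(logits).max()),
        "mean_logsumexp": float(lse.mean()),
    }

rng = np.random.default_rng(0)
N, E, k = 32, 4, 1
logits = rng.standard_normal((N, E))
topk_idx = np.argsort(-logits, axis=-1)[:, :k]
assigned = np.ones((N, k), dtype=bool)
assigned[:3] = False          # pretend the first three tokens overflowed
print(routing_stats(logits, topk_idx, assigned, E))
```

Rising max_abs_logit or mean_logsumexp is the drift the z-loss targets; a climbing max_over_mean_load is the imbalance that precedes drops.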
The bridge to the next chapter. Everything in this section makes imbalance survivable: capacity caps, noise, sigmoid gates, z-loss. None of them forces the router to route evenly; they just make uneven routing less catastrophic. The next chapter (Chapter 6) attacks the imbalance directly: first with auxiliary losses (and why they hurt quality), then with DeepSeek's elegant bias-term solution, which sidesteps that trade-off.

The one sentence to carry forward: a router is a softmax wrapped in four corrections (gradient routing through values, capacity with drop, noise for exploration, z-loss for stability), and removing any one of them without a substitute breaks MoE training. Most of any production router you will ever read is one of those four corrections in code.
