Chapter 6

Auxiliary Loss Approaches and Their Cost

Auxiliary-Loss-Free Load Balancing

The previous section left us with a system on the edge of collapse: every MoE layer wants to route tokens to a handful of experts, and the unrouted experts learn nothing, contribute nothing, and waste their share of the cluster. For four years between GShard (2020) and DeepSeek V3 (2024), every MoE in production reached for the same fix — an extra term in the loss function whose job was to spread tokens across experts. This section is the careful, quantitative story of that fix: why it works, why it is universal, and why every team that scaled it eventually noticed the same uncomfortable side effect — the model trained on a joint objective is measurably worse at the task it was actually built to do.

The bargain. Auxiliary balancing loss buys MoE feasibility: without it, expert parallelism degenerates into one busy GPU and seven idle ones. The bill comes back as a quality tax that compounds across 58 MoE layers and 14.8 trillion training tokens. That tax, and the search for a way out of it, is why Chapter 6 exists.

The Problem Auxiliary Loss Was Invented to Solve

Section 6.1 showed routing collapse as a dynamical instability: a slightly favored expert receives more tokens, its parameters improve faster, it gets more favored, and a positive feedback loop locks the router into a degenerate distribution. The pure language-modeling loss has no reason to prevent this. Cross-entropy only cares whether the next token is predicted; if expert 17 can predict the next token alone, the loss is happy and the other 255 experts can sit idle for the rest of training.

From the data-flow side, the consequences are worse than just wasted parameters. Recall from §5.5 that every MoE layer dispatches tokens to their chosen experts via an all-to-all collective with a fixed capacity per expert (the receive buffer size). Tokens that overflow a popular expert's buffer get dropped: they skip the MoE block entirely. So a router collapse causes two simultaneous failures: a quality failure (most experts contribute nothing) and a throughput failure (popular experts overflow and drop their excess). The system is not just suboptimal; it actively discards learning signal, because a dropped token passes through the layer on the residual path alone and no expert in that layer learns anything from it.

What we need from the fix. A differentiable signal, added to the loss, that grows when routing is imbalanced and shrinks when it is uniform, and that is small enough at the optimum for the model to still learn the task. GShard's authors (Lepikhin et al., 2020) proposed exactly this; Switch Transformer (Fedus et al., 2021) reused it almost unchanged; every major MoE that followed, up to DeepSeek-V3, did the same. Four years of practice, one expression.

Intuition: A Penalty That Punishes Popularity

Picture a school with eight cafeteria lines and a thousand students with strong opinions about which line they prefer. Left alone, they all pile into the same two lines: lines 1 and 3 overflow while the other six stand empty. The cafeteria manager has a choice. She could force students into specific lines by assignment, but that ignores their preferences and produces grumpy diners. Or she could put a small penalty on the popular lines: a one-second wait per extra student in front of you. Now student preference still drives most of the decision, but at the margin, when two lines are nearly equal, the lightly used line wins.

That penalty is what an auxiliary loss does. The router still picks based on token content. But the loss adds a small, differentiable nudge: the more concentrated your routing distribution, the higher the penalty. At equilibrium, the router learns to break content ties in favor of underutilized experts. The popular lines stay popular; the empty lines start to fill up.

The price the cafeteria manager pays is that a student who really wanted line 1 sometimes ends up in line 4. The price the MoE layer pays is that a token whose content genuinely best matches expert 17 sometimes lands at expert 42. We will quantify that price below — but the intuition to anchor is: auxiliary loss does not steer the router; it perturbs it. And every perturbation is, by definition, a deviation from what the task gradient was asking for.

The Math: GShard and Switch Auxiliary Loss

Let $T$ be the number of tokens in the current batch and $E$ the number of experts. The router produces a probability matrix $P \in \mathbb{R}^{T \times E}$ where $P_{t,i}$ is the softmax probability that token $t$ is routed to expert $i$. From this matrix we extract two summary statistics per expert.

First, the average routing probability assigned to expert $i$ across the batch: $\bar P_i = \frac{1}{T} \sum_{t=1}^{T} P_{t,i}$. This is differentiable with respect to the router weights; it is just an average of softmax outputs.

Second, the actual fraction of tokens whose top-1 pick was expert $i$: $f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbb{1}\!\left[\arg\max_j P_{t,j} = i\right]$. This is not differentiable (argmax has zero gradient almost everywhere), but it serves as a faithful magnitude: if expert $i$ grabbed half the batch, then $f_i = 0.5$ regardless of how the softmax was shaped.

The auxiliary loss combines them (written here in Switch Transformer's normalization; GShard's original expression differs only by a constant factor): $\mathcal{L}_{\text{aux}} = E \cdot \sum_{i=1}^{E} f_i \cdot \bar P_i$. The factor $E$ rescales the loss so that perfectly uniform routing scores exactly 1.0: with $f_i = \bar P_i = 1/E$, we get $\mathcal{L}_{\text{aux}} = E \cdot E \cdot (1/E)^2 = 1$. Skewed routing, where $f$ and $\bar P$ favor the same experts, pushes the value above 1, and because the loss multiplies $f$ by $\bar P$, concentrated routing is penalized super-linearly.

The full training objective is $\mathcal{L} = \mathcal{L}_{\text{task}} + \alpha \cdot \mathcal{L}_{\text{aux}}$, where $\alpha$ is a tuned hyperparameter: Switch Transformer used $\alpha = 10^{-2}$, and GShard reported similar values. The gradient that reaches the router weights is the sum of the gradient from the language modeling loss and a small gradient from $\mathcal{L}_{\text{aux}}$ that pushes $\bar P_i$ downward for popular experts and upward for unpopular ones.
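To see how that push reaches the router, hold $f$ fixed (the loss treats it as a constant) and differentiate through the softmax. For router logits $z_{t,j}$ this gives $\partial (\alpha \mathcal{L}_{\text{aux}}) / \partial z_{t,j} = \frac{\alpha E}{T} \, P_{t,j} \left(f_j - \sum_i f_i P_{t,i}\right)$. Gradient descent therefore lowers the logit of any expert whose dispatch fraction exceeds the token's probability-weighted average of $f$, and raises the others. The NumPy sketch below checks that closed form against central finite differences on random logits. The helper names (`aux_loss_fixed_f`, `aux_grad_logits`) and the toy sizes are our own, not taken from either paper.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def aux_loss_fixed_f(logits, f, alpha):
    """alpha * E * sum_i f_i * mean_P_i, with f treated as a constant."""
    E = logits.shape[1]
    return alpha * E * np.sum(f * softmax(logits).mean(axis=0))

def aux_grad_logits(logits, f, alpha):
    """Closed form: (alpha * E / T) * P[t, j] * (f_j - sum_i f_i * P[t, i])."""
    T, E = logits.shape
    P = softmax(logits)
    f_bar = P @ f                            # (T,) each token's P-weighted average of f
    return (alpha * E / T) * P * (f[None, :] - f_bar[:, None])

rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 4))             # toy sizes: 6 tokens, 4 experts
f = np.eye(4)[softmax(logits).argmax(axis=1)].mean(axis=0)   # top-1 dispatch fractions
alpha = 0.01

# Central finite differences, perturbing one logit at a time with f held fixed.
eps = 1e-6
numeric = np.zeros_like(logits)
for t in range(logits.shape[0]):
    for j in range(logits.shape[1]):
        step = np.zeros_like(logits)
        step[t, j] = eps
        numeric[t, j] = (aux_loss_fixed_f(logits + step, f, alpha)
                         - aux_loss_fixed_f(logits - step, f, alpha)) / (2 * eps)

analytic = aux_grad_logits(logits, f, alpha)
print("max |analytic - numeric|:", np.abs(analytic - numeric).max())
print("per-token row sums:", analytic.sum(axis=1))
```

Each token's row of the gradient sums to zero: the aux loss does not change how confident the router is overall, it only shifts probability mass from over-dispatched experts toward under-dispatched ones.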

The Switch Transformer variation

Switch Transformer published essentially the same formula, with a different scaling argument and an emphasis on top-1 routing only. The mathematical content is identical: $\mathcal{L}_{\text{aux}}^{\text{Switch}} = \alpha \cdot E \cdot \sum_{i} f_i \cdot \bar P_i$. Where Switch and GShard genuinely differ is in dispatch: Switch sends each token to a single expert (top-1), while GShard uses top-2 gating in which the second expert is chosen stochastically, with probability proportional to its gate weight. Both cap each expert's buffer with a capacity factor and let overflow tokens skip the layer through the residual connection. As a balancing mechanism, the loss is one thing, not two.

Why combine $f$ and $\bar P$ at all? One might ask why we don't just penalize the KL divergence of $\bar P$ to uniform: that is differentiable and obviously balance-encouraging. The answer is that $\bar P$ alone is a weak signal. The router can keep it close to uniform while still concentrating all its top-1 picks on one expert, by giving every token a nearly flat distribution with only a slight edge for the favored expert. Multiplying by $f$ forces the loss to track actual dispatches, not just soft mass.
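A small numerical check of that argument, on a hand-picked batch (our own assumption) in which every token gives expert 0 a slight edge over otherwise flat probabilities. The batch-average $\bar P$ sits almost on uniform, so a KL-to-uniform penalty is nearly zero and its gradient with respect to $\bar P$ is nearly flat. On the probability simplex only the differences between gradient components matter, and those are small. The dispatch fractions, by contrast, put every token on expert 0, so the aux loss's gradient with respect to $\bar P$, which is $E \cdot f$, points squarely at the overloaded expert.

```python
import numpy as np

E = 4
# Hypothetical batch: four tokens, each nearly flat with a slight edge for expert 0.
P = np.tile([0.28, 0.24, 0.24, 0.24], (4, 1))

mean_P = P.mean(axis=0)                              # almost uniform soft mass
f = np.eye(E)[P.argmax(axis=1)].mean(axis=0)         # every dispatch goes to expert 0

kl_to_uniform = float(np.sum(mean_P * np.log(mean_P * E)))
aux = float(E * np.sum(f * mean_P))

# Gradients with respect to mean_P (f held fixed for the aux loss).
grad_kl = np.log(mean_P * E) + 1.0
grad_aux = E * f

# A constant offset does nothing on the simplex, so compare the spread (max - min).
spread_kl = grad_kl.max() - grad_kl.min()
spread_aux = grad_aux.max() - grad_aux.min()
print(f"KL to uniform: {kl_to_uniform:.4f}   aux loss: {aux:.2f}")
print(f"gradient spread  KL: {spread_kl:.3f}   aux: {spread_aux:.1f}")
```

The aux loss value itself is modest here; what matters is that its gradient singles out expert 0 while the KL gradient barely distinguishes the experts.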

Manual Numerical Walkthrough

Let's compute $\mathcal{L}_{\text{aux}}$ for one tiny batch by hand. The numbers are small enough to follow on paper.

Worked example: GShard aux loss on 8 tokens / 4 experts

Setup. Eight tokens, four experts, top-1 routing. The router has produced the following probability matrix (rows sum to 1, columns are experts E0..E3):

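| token | E0 | E1 | E2 | E3 |
|---|---|---|---|---|
| t0 | 0.60 | 0.10 | 0.20 | 0.10 |
| t1 | 0.55 | 0.05 | 0.30 | 0.10 |
| t2 | 0.15 | 0.10 | 0.65 | 0.10 |
| t3 | 0.50 | 0.10 | 0.30 | 0.10 |
| t4 | 0.20 | 0.10 | 0.60 | 0.10 |
| t5 | 0.45 | 0.10 | 0.35 | 0.10 |
| t6 | 0.10 | 0.10 | 0.70 | 0.10 |
| t7 | 0.10 | 0.60 | 0.20 | 0.10 |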

Step 1: column averages, $\bar P_i$. Add each column and divide by 8.

  • $\bar P_0 = (0.60+0.55+0.15+0.50+0.20+0.45+0.10+0.10)/8 = 2.65/8 \approx 0.331$
  • $\bar P_1 = (0.10+0.05+0.10+0.10+0.10+0.10+0.10+0.60)/8 = 1.25/8 \approx 0.156$
  • $\bar P_2 = (0.20+0.30+0.65+0.30+0.60+0.35+0.70+0.20)/8 = 3.30/8 \approx 0.413$
  • $\bar P_3 = (0.10+0.10+0.10+0.10+0.10+0.10+0.10+0.10)/8 = 0.80/8 = 0.100$

Sanity check: $\sum_i \bar P_i = 0.331 + 0.156 + 0.413 + 0.100 = 1.000$. Good. Expert 2 is the most popular by soft mass; expert 3 the least.

Step 2: top-1 fractions, $f_i$. Take the argmax of each row.

  • Top-1 picks: t0→E0, t1→E0, t2→E2, t3→E0, t4→E2, t5→E0, t6→E2, t7→E1.
  • $f_0 = 4/8 = 0.500$ (four tokens picked E0)
  • $f_1 = 1/8 = 0.125$
  • $f_2 = 3/8 = 0.375$
  • $f_3 = 0/8 = 0.000$ (E3 was never anyone's favorite)

Step 3: assemble the loss. $\mathcal{L}_{\text{aux}} = E \cdot \sum_i f_i \bar P_i = 4 \cdot (0.500 \cdot 0.331 + 0.125 \cdot 0.156 + 0.375 \cdot 0.413 + 0 \cdot 0.100)$.

  • $0.500 \cdot 0.331 = 0.1655$
  • $0.125 \cdot 0.156 = 0.0195$
  • $0.375 \cdot 0.413 = 0.1549$
  • $0 \cdot 0.100 = 0$
  • Sum: $0.1655 + 0.0195 + 0.1549 + 0 = 0.3399$
  • $\mathcal{L}_{\text{aux}} = 4 \cdot 0.3399 = 1.3596$

The result, $\approx 1.36$, is above the balanced value of 1.0, confirming this batch is imbalanced. With $\alpha = 0.01$, the aux loss contributes $0.01 \cdot 1.36 = 0.0136$ to the total loss. That is tiny in absolute terms, but its gradient targets the router weights specifically, and a small absolute number applied to a tiny subset of parameters can be a relatively large update on those parameters.

The takeaway. Perfectly uniform routing scores exactly 1.0, and skew raises the value through two channels at once: through $f$ (whose imbalance is discrete, set by top-1 counts) and through $\bar P$ (whose imbalance is continuous). When both align, with an expert popular both by soft mass and by argmax, the loss penalizes the router heavily. That double pressure is the whole point.
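The same three steps in NumPy, as a check on the hand computation. The hand steps rounded $\bar P$ to three decimals; the unrounded result is 1.359375, which still rounds to 1.36. Here `np.bincount` counts the top-1 picks, an equivalent alternative to building one-hot rows.

```python
import numpy as np

# The 8-token / 4-expert probability matrix from the worked example (rows = tokens).
P = np.array([
    [0.60, 0.10, 0.20, 0.10],
    [0.55, 0.05, 0.30, 0.10],
    [0.15, 0.10, 0.65, 0.10],
    [0.50, 0.10, 0.30, 0.10],
    [0.20, 0.10, 0.60, 0.10],
    [0.45, 0.10, 0.35, 0.10],
    [0.10, 0.10, 0.70, 0.10],
    [0.10, 0.60, 0.20, 0.10],
])
T, E = P.shape
alpha = 0.01

mean_P = P.mean(axis=0)                                   # Step 1: column averages
f = np.bincount(P.argmax(axis=1), minlength=E) / T        # Step 2: top-1 fractions
L_aux = E * np.sum(f * mean_P)                            # Step 3: assemble the loss

print("mean_P:", mean_P)
print("f     :", f)
print("L_aux :", round(float(L_aux), 6), " alpha * L_aux:", round(float(alpha * L_aux), 4))
```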

Visualizing the Tension

Picture sweeping $\alpha$ upward from zero. At $\alpha = 0$ the router freely follows content and the routing distribution is wildly imbalanced: that is the collapse from §6.1. As $\alpha$ grows, the per-expert loads equalize toward uniform, but the quality penalty rises with every step of balance you buy.

Two things to notice. First, the relationship is not linear: a small $\alpha$ buys a lot of balance for almost no quality cost (imbalance drops fast while the quality penalty stays near zero), but past a certain $\alpha$ the quality penalty accelerates while the marginal balance improvement flattens out. This diminishing-returns shape is what makes auxiliary loss useful in practice: there is a sweet spot. Second, imbalance and quality penalty never reach zero together. No setting of $\alpha$ gives perfect balance and zero quality penalty simultaneously. The whole reason §6.3's bias-term approach exists is to find a mechanism that can deliver both.

The trade-off above treats balance and quality as if the cluster had infinite capacity. Real systems do not. Each expert's receive buffer is sized at compile time as $\text{capacity} = C \cdot T \cdot k / E$, where $C$ is the capacity factor (typically between 1.0 and 1.5). Tokens whose expert is already full are dropped: they bypass the MoE block entirely.

Lower the capacity factor and, even with the same router decisions, dropped tokens appear as soon as any one expert's slot count is exceeded. Raise it and the drops vanish, but most of the slot grid sits empty and GPU utilization tanks. The balancing loss is what keeps you off both horns of this trade-off: under balanced routing every expert's bucket is roughly the same size, so a capacity factor at or slightly above 1.0 absorbs the load with few or no drops. Without balancing, no realistic capacity factor can keep up, short of sizing every expert's buffer for the entire batch.
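A minimal simulation of the drop mechanics with 16 tokens, 4 experts, and top-1 routing, comparing a skewed assignment against a perfectly balanced one. Two details are our own assumptions rather than any particular system's behavior: capacity is rounded up with `ceil`, and each expert accepts tokens first-come, first-served until its buffer is full. Real implementations differ in both the rounding and which overflow tokens they drop.

```python
import math
import numpy as np

def expert_capacity(T, E, k, C):
    """Slots per expert: C * T * k / E, rounded up (the rounding rule is an assumption)."""
    return math.ceil(C * T * k / E)

def dispatch(expert_ids, E, capacity):
    """First-come, first-served: tokens beyond an expert's capacity are dropped."""
    load = np.zeros(E, dtype=int)
    dropped = 0
    for e in expert_ids:
        if load[e] < capacity:
            load[e] += 1
        else:
            dropped += 1
    return load, dropped

T, E, k = 16, 4, 1
skewed = [0] * 8 + [1] * 4 + [2] * 3 + [3] * 1   # expert 0 receives half the batch
balanced = [0, 1, 2, 3] * 4                      # every expert receives 4 tokens

for C in (1.0, 1.5, 2.0):
    cap = expert_capacity(T, E, k, C)
    for name, ids in (("skewed", skewed), ("balanced", balanced)):
        load, dropped = dispatch(ids, E, cap)
        util = load.sum() / (E * cap)            # fraction of allocated slots actually used
        print(f"C={C:.1f} cap={cap} {name:8s} dropped={dropped:2d} slot_util={util:.0%}")
```

With the skewed assignment, $C = 1.0$ drops a quarter of the batch and only $C = 2.0$ eliminates drops, at the price of half the slots sitting idle. The balanced assignment fits exactly at $C = 1.0$.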

The system-level argument for the loss. Even if you believed the quality cost was zero, you would still need a balancing mechanism, because expert parallelism with a bounded capacity factor demands it. The choice is not between "use the loss" and "skip the loss"; it is between "use the auxiliary loss with its quality cost" and "find a non-loss mechanism that delivers the same balance." That second option is what Chapter 6 is building toward.

Plain Python: The GShard Loss From Scratch

Before wiring this into a training loop, let's implement the loss with nothing but NumPy. Every step mirrors the hand walkthrough above. The batch here is a different, slightly larger one (12 tokens), so the numbers differ, but each quantity is computed the same way.

gshard_aux_numpy.py

```python
# gshard_aux_numpy.py
import numpy as np

# Toy MoE: 12 tokens, 4 experts, top-1 routing.
T, E, k = 12, 4, 1

# Router logits per token (T, E). In real code these come from x @ W_router.T.
# We hand-craft an imbalanced batch: experts 0 and 2 are favored.
logits = np.array([
    [2.0, 0.1, 1.5, 0.2],   # picks 0
    [1.8, 0.0, 1.0, 0.4],   # picks 0
    [0.3, 0.2, 2.4, 0.1],   # picks 2
    [2.1, 0.0, 1.0, 0.0],   # picks 0
    [0.1, 0.3, 2.3, 0.0],   # picks 2
    [0.2, 0.4, 2.0, 0.1],   # picks 2
    [2.4, 0.1, 0.5, 0.2],   # picks 0
    [0.0, 0.3, 2.2, 0.1],   # picks 2
    [1.9, 0.4, 0.5, 0.6],   # picks 0
    [0.1, 0.7, 1.8, 0.2],   # picks 2
    [0.2, 0.2, 0.4, 1.3],   # picks 3
    [0.3, 1.4, 0.1, 0.2],   # picks 1
])

# 1. Soft routing probabilities (the "P" tensor in GShard / Switch).
#    P[t, i] = probability the router would route token t to expert i.
def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)   # shift by the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

P = softmax(logits, axis=-1)                 # (T, E)

# 2. Mean probability per expert across the batch: the "P_i" in the loss.
mean_P = P.mean(axis=0)                      # (E,)

# 3. Actual top-1 assignment counts as fractions: the "f_i".
top1 = P.argmax(axis=-1)                     # (T,)
one_hot = np.eye(E)[top1]                    # (T, E)
fraction_f = one_hot.mean(axis=0)            # (E,)

# 4. GShard / Switch auxiliary loss (Switch normalization).
L_aux = E * np.sum(fraction_f * mean_P)

# 5. Inspect each piece.
print("mean_P     :", np.round(mean_P, 3))     # tilted toward experts 0 and 2
print("fraction_f :", np.round(fraction_f, 3)) # discrete, skewed the same way
print("L_aux      :", round(float(L_aux), 4))  # exactly 1.0 under perfectly uniform routing
```

Notes on the listing:

  • Batch shape. T = 12 tokens, E = 4 experts, top-1 routing for clarity; the names match the symbols in the formulas above. For top-k routing, f counts every one of a token's k selections, and implementations differ in how they normalize that count.
  • Hand-crafted imbalanced logits. Each row is one token's affinity to the 4 experts, so logits.shape = (12, 4). The batch is engineered so that experts 0 and 2 win most argmaxes (five tokens each, versus one each for experts 1 and 3), exactly the collapse scenario §6.1 worried about, so the aux loss has something visible to fight.
  • Soft routing probabilities. Softmax over experts turns the logits into a (T, E) probability matrix. P[t, i] is the differentiable signal: changing logits[t, i] moves P[t, i] continuously. This is the matrix the aux loss reaches into to shape the router's gradient.
  • mean_P, the differentiable load proxy. Averaging P over the batch gives each expert's share of probability mass, shape (4,). Unlike argmax, it is differentiable with respect to the router logits, so this is the term the gradient actually flows through. Under perfectly uniform routing it would equal [0.25, 0.25, 0.25, 0.25].
  • Top-1 hard assignment. Argmax across experts picks each token's expert, shape (12,). This is the discrete event that happens at routing time: what gets dispatched to which device.
  • fraction_f, dispatch counts as fractions. The one-hot rows encode each dispatch; averaging them over tokens gives the fraction of the batch each expert received, shape (4,). f_i is not differentiable (argmax has zero gradient almost everywhere), but it supplies the scale: how over- or under-loaded each expert is.
  • The auxiliary loss. L_aux = E · Σ_i f_i · P_i: the non-differentiable f provides amplitude and the differentiable P carries gradient. The factor E makes perfectly uniform routing score exactly 1.0 for any E: with f = P = (1/E, …, 1/E), L_aux = E · E · (1/E)² = 1. This is the single most-cited expression in MoE training.
  • Inspection. For this batch, mean_P ≈ [0.317, 0.146, 0.401, 0.136] leans toward experts 0 and 2, fraction_f ≈ [0.417, 0.083, 0.417, 0.083] leans the same way more sharply, and L_aux ≈ 1.29, above the balanced value of 1.0.

The whole loss is six lines: softmax, mean across the batch dim, argmax plus one-hot for $f$, then the dot product. Nothing about it is sophisticated. The sophistication is in appreciating which lines have a gradient and which do not; that is the ingredient PyTorch turns into a training signal.

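The balanced-equals-one trick. Why does the formula multiply by $E$? Because $\sum_i f_i = \sum_i \bar P_i = 1$ always, perfectly uniform routing gives $\sum_i f_i \bar P_i = E \cdot (1/E)^2 = 1/E$, so multiplying by $E$ makes the balanced value exactly 1 regardless of $E$. At the other extreme, $\sum_i f_i \bar P_i \le \max_i \bar P_i \le 1$, so the loss never exceeds $E$, a value reached when every dispatch and all probability mass land on one expert. Strictly, 1 is not a hard floor: a batch whose $f$ and $\bar P$ lean toward different experts can score slightly below it. But because $f$ is the argmax of the same $P$, the two almost always skew the same way, and in practice 1.0 behaves as the floor.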
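A quick numerical check of the normalization, using three hand-built batches of our own. Routing that spreads dispatches and probability mass evenly scores exactly 1; routing that puts everything on one expert scores $E$; and a small two-expert batch whose $f$ and $\bar P$ lean toward different experts dips slightly below 1.

```python
import numpy as np

def aux_loss(P):
    """Top-1 aux loss in Switch normalization: E * sum_i f_i * mean_P_i."""
    T, E = P.shape
    f = np.bincount(P.argmax(axis=1), minlength=E) / T
    return E * float(np.sum(f * P.mean(axis=0)))

E = 4
# Balanced: token t puts 0.7 on expert t % E and 0.1 on each of the others.
balanced = np.full((8, E), 0.1)
balanced[np.arange(8), np.arange(8) % E] = 0.7

# Collapsed: every token puts all its mass on expert 0.
collapsed = np.tile(np.eye(E)[0], (8, 1))

# Misaligned (two experts): two tokens narrowly pick expert 0 and one token picks
# expert 1 with certainty, so f favors expert 0 while mean_P favors expert 1.
misaligned = np.array([[0.51, 0.49], [0.51, 0.49], [0.0, 1.0]])

for name, batch in (("balanced", balanced), ("collapsed", collapsed), ("misaligned", misaligned)):
    print(f"{name:10s} L_aux = {aux_loss(batch):.4f}")
```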

PyTorch: Wiring the Loss Into a Training Step

Production MoE modules return their aux loss alongside the activation output. The trainer is responsible for adding it to the task loss with the right coefficient and calling backward on the sum. Here is the cleanest possible end-to-end version.

moe_with_aux_loss.py

```python
# moe_with_aux_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoEWithAuxLoss(nn.Module):
    """Top-k MoE layer that returns its own aux load-balancing loss."""
    def __init__(self, d_model, d_ff, num_experts, k, aux_alpha=0.01):
        super().__init__()
        self.E = num_experts
        self.k = k
        self.aux_alpha = aux_alpha
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(),
                          nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])

    def forward(self, x):                          # x: (B, S, D)
        B, S, D = x.shape
        x_flat = x.reshape(-1, D)                  # (T, D)  with T = B*S

        # ---- ROUTER ----
        logits = self.router(x_flat)               # (T, E)
        P = F.softmax(logits, dim=-1)              # (T, E)  differentiable

        # Top-k gating. For k > 1 we renormalize the kept weights. For k = 1 we keep
        # the raw probability (as Switch does): renormalizing a single weight turns it
        # into the constant 1, and the router would get no gradient from the task loss.
        topw, topi = P.topk(self.k, dim=-1)        # (T, k)
        if self.k > 1:
            topw = topw / topw.sum(dim=-1, keepdim=True)

        # ---- AUXILIARY LOAD-BALANCE LOSS (GShard / Switch) ----
        # f_i = fraction of tokens (in this batch) whose argmax is expert i.
        with torch.no_grad():
            one_hot = F.one_hot(P.argmax(dim=-1), self.E).float()
            f = one_hot.mean(dim=0)                # (E,)  constant weights, no grad

        P_mean = P.mean(dim=0)                     # (E,)  differentiable
        aux_loss = self.E * (f * P_mean).sum()     # scalar, 1.0 under uniform routing

        # ---- EXPERT COMPUTE (kept simple; not the focus here) ----
        out = torch.zeros_like(x_flat)
        for slot in range(self.k):
            eid_slot = topi[:, slot]               # (T,)
            w_slot = topw[:, slot].unsqueeze(-1)   # (T, 1)
            for e in range(self.E):
                mask = eid_slot == e
                if mask.any():
                    out[mask] += w_slot[mask] * self.experts[e](x_flat[mask])

        return out.view(B, S, D), aux_loss

# -------------------- TRAINING STEP --------------------
model = MoEWithAuxLoss(d_model=64, d_ff=128, num_experts=4, k=1, aux_alpha=0.01)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(4, 16, 64)                         # toy batch (B=4, S=16, D=64)
target = torch.randn_like(x)                       # pretend supervised target

y, aux = model(x)
task_loss = F.mse_loss(y, target)                  # the "real" loss
total = task_loss + model.aux_alpha * aux          # the JOINT objective

opt.zero_grad()
total.backward()                                   # gradients flow through BOTH terms
opt.step()

print(f"task_loss={task_loss.item():.4f}  "
      f"aux={aux.item():.4f}  "
      f"alpha*aux={(model.aux_alpha * aux).item():.4f}")
```

Notes on the listing:

  • aux_alpha lives on the module. Each MoE layer computes its own aux loss, so keeping α as a module attribute places the coefficient next to the loss it scales. The trainer reads it back (model.aux_alpha) when building the joint objective, which makes annealing or zeroing it a one-line change.
  • Flatten (B, S, D) → (T, D). MoE routes per token; sequence position plays no role in routing. Collapsing the batch and sequence dims gives T = B·S, which is 64 for the toy batch (B = 4, S = 16). Every per-token quantity (P, f, P_mean) is computed over this flattened batch, so x_flat.shape = (T, D).
  • Soft router output P. Same softmax as the NumPy version, shape (T, E). This is the matrix the aux loss reaches into for its differentiable signal.
  • Top-k gating. topi, shape (T, k), holds the chosen expert ids and topw the matching routing weights; this dispatch decision is what f measures. For k > 1 the kept weights are renormalized to sum to 1 per token. For k = 1 the raw router probability is kept as the gate weight, Switch-style, so the task loss still sends gradient to the router.
  • f under torch.no_grad. f comes from argmax, whose integer output carries no gradient anyway; the no_grad block makes that explicit. Inside the loss, f acts as a constant per-expert weight on P_mean, shape (E,).
  • P_mean is the gradient highway. This is where the aux loss earns its keep. P_mean[i] is differentiable with respect to every router weight, so reducing aux_loss lowers P[t, i] for over-loaded experts and raises it for under-loaded ones. That is the load-balancing signal.
  • The aux_loss formula. E · Σ f_i · P_i. Perfectly uniform routing scores exactly 1.0; anything above that is the penalty the optimizer will try to remove.
  • task_loss is what the user actually cares about. In a real model this is the language-modeling cross-entropy (here, an MSE proxy). Its gradients are what teach the model; everything else exists to keep the pipeline running.
  • The joint objective, and the source of all interference. total = task_loss + α · aux_loss, so backward() produces a gradient that is the sum of two unrelated drivers: "predict the next token well" and "spread tokens evenly across experts". These are not, in general, aligned; whenever they disagree, the router's update is off-target for at least one of them. Example: if expert 3 is genuinely the best for math tokens, the task gradient pulls those tokens toward expert 3, while the aux gradient, if expert 3 is already overloaded, pushes them away. The net update lands somewhere in between, a compromise that fully serves neither objective.
  • backward() flows through both paths. Autograd does the bookkeeping. The router weights receive gradient from both sources every step, and there is no way to tell how much of an update is task-driven versus balance-driven without explicit hooks. §6.3 introduces a mechanism (a bias term) that receives no gradient at all and is updated by a non-gradient rule, which removes exactly this interference.

The pattern generalizes to deep models with many MoE layers: each layer returns its own aux loss, and the trainer either sums or averages them before scaling by $\alpha$. DeepSeek-V2's public code, for instance, computes per-layer aux losses and a separate device-level balance loss (we will see that in §6.4) and sums them all into the final scalar.

  1. One $\alpha$ per balancing mechanism, not per layer. All MoE layers share the same $\alpha$ in standard implementations. Per-layer tuning was tried by several teams; it did not generalize across model sizes and was abandoned.
  2. The no_grad around f documents intent; it does not change the result. argmax returns an integer tensor with no autograd history, so f carries no gradient either way. Wrapping the branch in torch.no_grad makes that explicit to readers and guards against a later edit (for example, replacing argmax with a soft approximation) silently opening a gradient path through f.
  3. Aux loss is sometimes annealed. Some implementations start $\alpha$ high (to bootstrap balance quickly) and decay it as training progresses, on the theory that experts diversify and balance naturally once the router has learned something. DeepSeek-V3 ablations (§6.3, Table 3) found this brittle: the balance regresses as soon as $\alpha$ shrinks, and you cannot tell from the loss curve whether you are regressing.

The Hidden Cost: Gradient Interference

The argument against auxiliary loss is not philosophical; it is a precise statement about gradient geometry. The router weights $W_r$ receive an update every step that is the sum of two contributions:

$$\nabla_{W_r} \mathcal{L} = \nabla_{W_r} \mathcal{L}_{\text{task}} + \alpha \cdot \nabla_{W_r} \mathcal{L}_{\text{aux}}.$$

The first term is the gradient that actually improves the model: it pushes routing decisions toward whatever the data says is the right expert for each token. The second term pushes routing decisions toward uniformity, regardless of data. For any given batch, these two vectors live in the same parameter space and can be added, but they generally point in different directions. Let $\hat g_{\text{task}} = \nabla \mathcal{L}_{\text{task}} / \lVert \nabla \mathcal{L}_{\text{task}} \rVert$ be the unit task direction, and decompose the aux gradient into the component parallel to it and the component orthogonal to it:

$$\nabla \mathcal{L}_{\text{aux}} = (\hat g_{\text{task}} \cdot \nabla \mathcal{L}_{\text{aux}}) \, \hat g_{\text{task}} + g_\perp.$$

The parallel component is harmless: it just rescales the task gradient up or down. The orthogonal component $g_\perp$ is the toxic part: it deflects the update vector away from the direction the task would have taken on its own. On batches where the task and balance objectives happen to agree, the deflection is small. On batches where they disagree, when the task gradient pulls hard toward a popular expert because that expert genuinely knows the content best, the deflection is large.

Why this is fundamentally different from a learning-rate adjustment. A smaller learning rate would scale both gradients down equally and the model would just learn more slowly. The aux gradient's orthogonal component does not vanish under any $\alpha > 0$; it merely shrinks. There is no choice of $\alpha$ that turns off the interference, only choices that make it small enough to tolerate.
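A toy illustration of the decomposition. The two-dimensional gradients are made up and stand in for the router's real (much larger) gradient vectors. The sketch splits the aux gradient into its parts parallel and orthogonal to the task gradient, then measures how far the joint update $g_{\text{task}} + \alpha \, g_{\text{aux}}$ turns away from the task direction as $\alpha$ shrinks. With these vectors the parallel part is (1, 0) and the orthogonal part is (0, 4).

```python
import numpy as np

def decompose(g_task, g_aux):
    """Split g_aux into components parallel and orthogonal to g_task."""
    u = g_task / np.linalg.norm(g_task)          # unit vector along the task gradient
    parallel = np.dot(u, g_aux) * u
    return parallel, g_aux - parallel

def deflection_deg(g_task, g_aux, alpha):
    """Angle in degrees between the joint update g_task + alpha * g_aux and g_task."""
    total = g_task + alpha * g_aux
    cos = np.dot(total, g_task) / (np.linalg.norm(total) * np.linalg.norm(g_task))
    return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))))

g_task = np.array([3.0, 0.0])   # hypothetical task gradient
g_aux = np.array([1.0, 4.0])    # hypothetical aux gradient

parallel, g_perp = decompose(g_task, g_aux)
print("parallel:", parallel, " orthogonal:", g_perp)

for alpha in (0.0, 0.001, 0.01, 0.1):
    print(f"alpha={alpha:<6} deflection={deflection_deg(g_task, g_aux, alpha):.4f} deg")
```

The deflection shrinks roughly in proportion to $\alpha$ but is exactly zero only at $\alpha = 0$, which is the point of the paragraph above in miniature.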

The orthogonal component compounds in two ways. First, across MoE layers: DeepSeek-V3's 58 MoE layers mean 58 separate routers, each receiving its own deflected gradient. Second, across training steps: 14.8 trillion tokens means hundreds of thousands of optimizer steps, each adding a small amount of unwanted noise to the router. The effect is statistical and gradual; there is no single step where the model breaks. There is only a steady drift, measurable only by ablation, of the trained model away from where the pure task gradient would have taken it.
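A back-of-the-envelope count of those deflected router updates. The batch shape is an assumption for illustration: a constant 15,360 sequences of 4,096 tokens, roughly the final batch size DeepSeek-V3 reports. The real run ramped its batch size up early in training, so a constant batch slightly undercounts the steps.

```python
tokens = 14.8e12              # total training tokens
seqs_per_batch = 15_360       # assumption: constant batch size in sequences
seq_len = 4_096               # assumption: tokens per sequence
moe_layers = 58               # MoE layers, each with its own router

tokens_per_step = seqs_per_batch * seq_len
steps = tokens / tokens_per_step
router_updates = steps * moe_layers

print(f"tokens per step  = {tokens_per_step:,}")
print(f"optimizer steps ~= {steps:,.0f}")
print(f"router updates  ~= {router_updates:,.0f}")
```

That works out to about 235,000 optimizer steps and, across 58 routers, on the order of 13.6 million individually deflected router updates.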

What Goes Wrong at Massive Scale

DeepSeek-V2 and V3 papers publish ablations that quantify this drift. Holding model size, data, and compute fixed, training the same MoE with vs. without an auxiliary balancing loss produces measurable differences on downstream benchmarks. The numbers move with model scale, but the pattern is consistent.

| Model & setup | Validation PPL | MMLU | Reasoning bench |
|---|---|---|---|
| MoE 16B, no balancing (collapses) | fails | fails | fails |
| MoE 16B, aux loss α = 0.01 | baseline | baseline | baseline |
| MoE 16B, aux loss α = 0.1 (over-balanced) | +1.2% | −0.8 pts | −1.4 pts |
| MoE 16B, bias-term (§6.3), no aux | −0.4% | +0.6 pts | +1.1 pts |

The numbers above are illustrative of the directional pattern reported in the DeepSeek-V2 and V3 ablations; exact values vary by setup. What matters is not the magnitudes but the signs. The bias-term variant (the subject of §6.3) consistently beats the well-tuned aux-loss variant. It is not a tie, and it is not noise: in the V3 ablations the gap is reported as statistically robust across multiple training seeds.

Where the regression lives

Two regions of model behavior absorb most of the quality penalty. First, rare-domain tokens: code, math, multilingual text — anything where one or two experts have plausibly specialized. The aux loss pulls those tokens toward generalist experts they would otherwise have skipped, diluting the specialization that made MoE worth doing in the first place. Second, the long tail of the loss curve: training-loss ablations show that early in training the aux loss is almost free (because the router barely knows anything anyway), but its cost grows over the course of training as specialization that could have emerged is dampened by the balancing pressure.

Engineering Reality: Why "Just Tune Alpha" Fails

The natural reply is: pick a better $\alpha$. If $\alpha = 0.01$ is too aggressive, try $\alpha = 0.001$. If that lets the router collapse again, try $\alpha = 0.005$. Several teams spent serious resources on exactly this hyperparameter sweep. Four reasons it fails to close the gap:

  1. The optimum depends on the data distribution, and the data distribution changes during training. Early data tends to be general; later curriculum stages include more specialized domains. The $\alpha$ that prevents collapse on generalist data may suppress specialization on math data. There is no single $\alpha$ that is optimal throughout. (DeepSeek tried annealing schedules; quality variance across seeds dwarfed the schedule's benefit.)
  2. The optimum depends on model size. A 7B MoE tolerates a different $\alpha$ than a 671B MoE because the router's relative parameter share differs. You cannot ablate $\alpha$ on a 7B model and extrapolate; you have to retune at every target scale, which is impossibly expensive on the budgets that justify MoE in the first place.
  3. Diagnosing the regression is hard. The aux loss does its damage subtly. Training cross-entropy looks fine. Eval benchmarks move by 0.5–2 points — well within run-to-run noise on a single seed. You only see the effect by running multiple seeds and comparing distributions. Most teams ship without doing that carefully and never know what they left on the table.
  4. The orthogonality argument offers no interference-free $\alpha$. As shown above, $g_\perp$ does not vanish at any positive $\alpha$. The best you can do by tuning is to scale it down, and below a certain magnitude the balance signal becomes too weak to prevent collapse. The window where both objectives are mostly satisfied does exist, but it is narrow and shifts with everything that changes about the run.

The exit. The only way out of this trade-off is to balance the router using a mechanism that does not pass through the gradient. If the load-balancing signal can be applied directly to routing decisions (via a non-differentiable bias) while leaving the gradient 100% task-driven, the orthogonality argument disappears entirely: there is no aux gradient to interfere with the task gradient because there is no aux gradient at all. That is precisely what DeepSeek's bias-term load balancing does, and it is the entire content of the next section.

For now, the picture to carry forward is this: auxiliary balancing loss is the standard, the obvious choice, and quietly the single largest avoidable cost in modern MoE training. Every line of math we wrote is correct; every implementation works; every model trained this way ships. They ship slightly worse than they could. Closing that gap, on the scale of the benchmark-point and fraction-of-a-percent perplexity differences in the pattern above, is what justifies DeepSeek-V3's bias-term approach in §6.3, and what makes the four-year-old GShard formula not the last word on MoE balancing.
