Chapter 5

Expert Parallelism at Scale

Mixture-of-Experts: DeepSeekMoE

The previous four sections built DeepSeekMoE on a single device: 256 routed experts, 1 shared expert, top-8 gating. The arithmetic is clean on paper and impossible on hardware. A single H800 holds 80 GB, while the routed experts of DeepSeek-V3's MoE layers together weigh hundreds of gigabytes. The expert pool simply does not fit on one GPU. Expert parallelism (sharding the experts across many GPUs and shuttling tokens to wherever their chosen expert lives) is the engineering trick that turns the single-device math into a 671B-parameter system that actually trains.

The bet of expert parallelism. Stop trying to fit every expert on every GPU. Put each expert on exactly one GPU of the expert-parallel group, and mail the tokens to the expert instead. You pay for two all-to-all collectives per layer; in exchange you can train a model far larger than any single accelerator could hold (671B parameters are about 1.34 TB in BF16, roughly 17× an H800's 80 GB).

The Memory Wall of 671B Parameters

Start with a concrete accounting. DeepSeek-V3 has 256 routed experts per MoE layer. Treat each as a 2-layer FFN (two weight matrices) with hidden dimension $d_{ff} = 2048$ and model dimension $d = 7168$. That is roughly $2 \cdot d \cdot d_{ff} \approx 2.9 \cdot 10^7$ parameters per expert, or about 58 MB at BF16. One layer's 256 experts therefore weigh $256 \cdot 58\,\text{MB} \approx 15\,\text{GB}$. Multiply across 58 MoE layers and you reach roughly 870 GB of expert parameters alone, before counting attention, embeddings, gradients, activations, or optimizer state (Adam's two FP32 moment buffers alone take 4× the BF16 parameter footprint). DeepSeek-V3's actual experts are SwiGLU FFNs with three weight matrices, so the true figures are about 1.5× larger still.
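The accounting is easy to reproduce. A minimal sketch, assuming weights only, BF16 storage (2 bytes per parameter), and the two-matrix expert shape used above; expert_pool_bytes is a hypothetical helper, and the matrices_per_expert=3 call models the SwiGLU variant:

```python
def expert_pool_bytes(d_model, d_ff, n_experts, n_moe_layers,
                      matrices_per_expert=2, bytes_per_param=2):
    """Bytes of routed-expert weights for a whole MoE stack (weights only)."""
    params_per_expert = matrices_per_expert * d_model * d_ff
    per_layer = n_experts * params_per_expert * bytes_per_param
    return params_per_expert, per_layer, per_layer * n_moe_layers

params, layer_bytes, total_bytes = expert_pool_bytes(7168, 2048, 256, 58)
print(f"params/expert: {params:.3e}")                # 2.936e+07
print(f"one layer:     {layer_bytes / 1e9:.1f} GB")  # 15.0 GB
print(f"58 layers:     {total_bytes / 1e9:.0f} GB")  # 872 GB

# Three SwiGLU matrices per expert grow the pool by 1.5x.
_, _, swiglu_total = expert_pool_bytes(7168, 2048, 256, 58, matrices_per_expert=3)
print(f"SwiGLU pool:   {swiglu_total / 1e9:.0f} GB")  # 1308 GB
```

Gradients and optimizer state multiply these weight-only numbers several times over, which is why the pool has to be sharded rather than replicated.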

An 80 GB H800 can hold one MoE layer's BF16 experts (about 15 GB), but not the expert pool of all 58 MoE layers, let alone the whole model with gradients and optimizer state. Pure data parallelism is dead on arrival: replicating the full expert pool on every GPU is not an arithmetic problem, it is a physics problem. Tensor parallelism (the Megatron-style trick of splitting each matmul across devices) helps for a single dense FFN, but it does not exploit the structural sparsity of MoE: it shards every expert across GPUs even though most experts do nothing for most tokens.

What we actually need. A parallelism strategy that exploits the fact that MoE experts are independent and only a few are active per token. Expert parallelism is exactly that: assign each expert to one device, and route tokens to the device that owns their expert. Memory is sharded perfectly. Compute is balanced exactly when the router is balanced.

Sharding Experts Across the Cluster

Picture a hospital again, but now spread across four cities. Each city has two specialists; eight specialists in total. When a patient arrives at a local clinic, the triage desk decides which specialist they need, and the patient flies to the city where that specialist works. After the appointment they fly home. The clinics are tiny (one router each). The specialists are stationary. The patients travel.

The four cities are GPUs. The eight specialists are experts (two per GPU). The patients are tokens. The flight is the all-to-all collective: in a single coordinated step, every GPU sends its tokens to wherever they need to go and receives the tokens it must process — all at the same time, all over the network. After expert compute, a reverse all-to-all flies the outputs home.

Why one all-to-all and not a series of point-to-point sends? Every device generally has tokens for every other device, so each step involves $W(W-1)$ cross-device transfers. Issuing them one after another would leave most links idle. An all-to-all is a single collective that posts every send and receive at once, so the network can run them concurrently and keep all links busy. PyTorch's NCCL backend issues it as a batch of grouped point-to-point sends and receives to every peer, launched together.

Two pieces of the system stay replicated, not sharded: the router and the shared experts. The router is small (a single $d \times E$ linear) and runs on every device, so each device decides where its own tokens go without asking anyone. The shared experts, also small relative to the routed pool, live on every device and run locally for every token. The expensive, sharded part is the routed pool.
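To see how lopsided the split is, here is a rough per-rank parameter count for one MoE layer. A sketch under stated assumptions: two-matrix experts, one shared expert with the same shape as a routed one, and a hypothetical per_rank_params helper; the EP sizes in the loop are illustrative.

```python
def per_rank_params(d, d_ff, E, W, n_shared=1, matrices=2):
    """Parameters one rank holds for one MoE layer under expert parallelism."""
    router = d * E                              # replicated: the full d x E router
    shared = n_shared * matrices * d * d_ff     # replicated: every shared expert
    routed = (E // W) * matrices * d * d_ff     # sharded: only E/W routed experts
    return router, shared, routed

for W in (1, 8, 64):
    router, shared, routed = per_rank_params(d=7168, d_ff=2048, E=256, W=W)
    print(f"W={W:>2}: router {router:,}  shared {shared:,}  routed {routed:,}")
```

The replicated pieces stay constant as W grows, while the routed share shrinks as 1/W.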

The Math: Permutations and All-to-All

Let $W$ be the world size of the expert-parallel group and $E$ the number of routed experts, and assume $E$ is divisible by $W$. Each device $r \in \{0, \dots, W-1\}$ owns the contiguous expert block $[r \cdot E/W,\; (r+1) \cdot E/W)$. Define $\mathrm{owner}(e) = \lfloor e / (E/W) \rfloor$.
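The owner map and each rank's block translate directly to code. A small sketch with hypothetical helper names; E = 256 and W = 64 are illustrative values that satisfy the divisibility assumption:

```python
def owner(e, E, W):
    """Rank that owns expert e under the contiguous block layout."""
    assert E % W == 0, "E must be divisible by W"
    return e // (E // W)

def local_experts(r, E, W):
    """Global expert ids [r*E/W, (r+1)*E/W) hosted on rank r."""
    block = E // W
    return range(r * block, (r + 1) * block)

E, W = 256, 64
# Every expert in rank r's block maps back to rank r.
assert all(owner(e, E, W) == r for r in range(W) for e in local_experts(r, E, W))
print(list(local_experts(3, E, W)))   # [12, 13, 14, 15]
print(owner(255, E, W))               # 63
```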

On step $t$, device $r$ holds a local token slice of shape $(T_r, d)$. After routing, it has top-$k$ expert ids per token, which we flatten into a list of $T_r \cdot k$ dispatch decisions. Grouping these by destination device gives per-destination send counts $c_{r \to r'}$. The dispatch all-to-all is the permutation:

$$\text{recv}_{r'} = \bigsqcup_{r=0}^{W-1} \text{send}_{r \to r'}, \qquad |\text{send}_{r \to r'}| = c_{r \to r'}$$

Each device $r'$ then runs its local experts on the received rows, with no communication during compute. The combine all-to-all is the inverse permutation, mailing each output back to its source device. After applying the gates and summing across the $k$ slots, the per-token output is restored.
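The permutation view can be checked numerically. The sketch below simulates W devices in one process with toy sizes. It assumes each expert is a plain d × d linear map, draws top-k ids at random, and draws gates from a Dirichlet. It dispatches by sorting rows by destination and slicing by the counts $c_{r \to r'}$, runs only local experts, inverts the permutation on the way back, and compares against a single-device MoE on the same tokens:

```python
import numpy as np

rng = np.random.default_rng(0)
W, E, k, d, T_r = 4, 8, 2, 3, 5            # toy sizes, chosen for illustration
per = E // W                               # experts per rank (contiguous blocks)
experts = rng.standard_normal((E, d, d))   # toy expert e: a d x d linear map

x = [rng.standard_normal((T_r, d)) for _ in range(W)]
eid = [np.stack([rng.choice(E, size=k, replace=False) for _ in range(T_r)])
       for _ in range(W)]                                    # (T_r, k) top-k ids
g = [rng.dirichlet(np.ones(k), size=T_r) for _ in range(W)]  # gates per token

# ---- Dispatch: flatten (token, slot), group rows by destination rank ----
send, orders = {}, []
for r in range(W):
    flat_e = eid[r].reshape(-1)                  # row t*k + j = token t, slot j
    flat_x = np.repeat(x[r], k, axis=0)
    dst = flat_e // per                          # owner(e)
    order = np.argsort(dst, kind="stable")
    c = np.bincount(dst, minlength=W)            # c_{r -> r'} for every r'
    assert c.sum() == T_r * k                    # every row goes somewhere
    offs = np.concatenate([[0], np.cumsum(c)])
    for rp in range(W):
        rows = order[offs[rp]:offs[rp + 1]]
        send[(r, rp)] = (flat_x[rows], flat_e[rows])
    orders.append(order)

# ---- recv_{r'} = disjoint union of send_{r -> r'}; run local experts ----
out = {}
for rp in range(W):
    for r in range(W):
        xs, es = send[(r, rp)]
        assert np.all(es // per == rp)           # only local experts are needed
        out[(r, rp)] = np.einsum("nij,nj->ni", experts[es], xs)

# ---- Combine: inverse permutation, apply gates, sum over the k slots ----
y_ep = []
for r in range(W):
    back = np.concatenate([out[(r, rp)] for rp in range(W)])  # destination-sorted
    rows = back[np.argsort(orders[r])]                        # undo the sort
    y_ep.append((rows.reshape(T_r, k, d) * g[r][..., None]).sum(axis=1))

# Reference: the same tokens through a single-device MoE.
y_ref = [np.einsum("tk,tkij,tj->ti", g[r], experts[eid[r]], x[r]) for r in range(W)]
print(all(np.allclose(a, b) for a, b in zip(y_ep, y_ref)))   # True
```

Nothing in the expert math changes; only the order and location of rows do, which is exactly the claim that dispatch and combine are a permutation and its inverse.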

Communication cost

Each device sends at most $T_r \cdot k \cdot d$ elements during dispatch and the same again during combine (at most, because rows whose expert lives on the sending device never touch the network). The total volume per layer per device is therefore roughly $2 \cdot T_r \cdot k \cdot d \cdot b$ bytes, where $b$ is bytes per element (2 for BF16). With $T_r = 4096$, $k = 8$, $d = 7168$, and $b = 2$, that is about 0.94 GB per device per layer per step, and DeepSeek-V3 has 58 such layers. The all-to-all traffic is the single largest bandwidth consumer in MoE training.
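The bandwidth arithmetic fits in a hypothetical helper. This sketch assumes BF16 in both directions and the same $T_r$, $k$, $d$ as above. The optional frac_remote argument discounts rows that stay local; the 63/64 value assumes uniform routing over a 64-rank group:

```python
def a2a_bytes_per_layer(T_r, k, d, bytes_per_elem=2, frac_remote=1.0):
    """Bytes one rank sends per MoE layer for dispatch plus combine.

    frac_remote < 1 discounts rows whose expert lives on the sending rank.
    """
    return 2 * T_r * k * d * bytes_per_elem * frac_remote

per_layer = a2a_bytes_per_layer(T_r=4096, k=8, d=7168)
print(f"{per_layer / 1e9:.2f} GB per layer")       # 0.94 GB per layer
print(f"{58 * per_layer / 1e9:.1f} GB per step")   # 54.5 GB per step
remote = a2a_bytes_per_layer(4096, 8, 7168, frac_remote=63 / 64)
print(f"{remote / 1e9:.3f} GB if 1/64 of rows stay local")  # 0.925 GB
```

The number of routed experts never appears among the arguments.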

The bandwidth tax. Communication scales linearly with $k \cdot d$, not with the number of routed experts $N_r$. At fixed $k$, increasing $N_r$ (granularity) is therefore essentially free for the network: the cost lands on expert memory, not on the wire. This is one reason DeepSeek-V3 can push $N_r$ to 256: more experts ≠ more bytes on the wire.

The load-balance precondition

The math above assumes $c_{r \to r'}$ is roughly uniform. If one device's experts are popular and the rest are idle, that device's inbox overflows while the others wait, and both compute and communication stall. Expert parallelism therefore depends on the load-balancing mechanisms we will derive in chapter 6 (the auxiliary-loss-free bias terms). Without balance, the whole sharding strategy collapses.
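One way to quantify the balance precondition is the ratio of the busiest rank's receive count to the mean, since the slowest rank sets the step time. recv_imbalance is a hypothetical helper; both count matrices are made-up 4-rank examples (rows = senders, columns = receivers):

```python
import numpy as np

def recv_imbalance(C):
    """max/mean of per-rank receive counts (column sums of C); 1.0 is perfect."""
    recv = np.asarray(C).sum(axis=0)
    return recv.max() / recv.mean()

balanced = [[1, 1, 1, 0], [1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1, 1]]
skewed   = [[3, 0, 0, 0], [2, 1, 0, 0], [0, 1, 1, 1], [0, 0, 2, 1]]
print(recv_imbalance(balanced))  # 1.0
print(recv_imbalance(skewed))    # about 1.667: rank 0 receives 5 rows vs a mean of 3
```

Both matrices send 3 rows from every rank; only the receive side differs, which is where the straggler appears.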

Manual Numerical Walkthrough

Let us trace a tiny dispatch by hand. Four devices, eight experts (two per device), top-1 routing, three tokens per device.


Setup. $W = 4$ devices, $E = 8$ experts, two experts per device, so $\mathrm{owner}(0..1) = 0$, $\mathrm{owner}(2..3) = 1$, $\mathrm{owner}(4..5) = 2$, $\mathrm{owner}(6..7) = 3$. Each device has 3 local tokens and routes top-1. The routed expert ids are:

  • Device 0's tokens pick experts $[5, 2, 0]$.
  • Device 1's tokens pick experts $[7, 3, 1]$.
  • Device 2's tokens pick experts $[4, 6, 2]$.
  • Device 3's tokens pick experts $[0, 5, 7]$.

Step 1: expert id → destination device. Apply $\mathrm{owner}$:

  • Device 0 dests: $[2, 1, 0]$.
  • Device 1 dests: $[3, 1, 0]$.
  • Device 2 dests: $[2, 3, 1]$.
  • Device 3 dests: $[0, 2, 3]$.

Step 2: build the send-count matrix, $C[r, r']$ = rows $r$ sends to $r'$:

| src ↓ / dst → | 0 | 1 | 2 | 3 |
|---|---|---|---|---|
| device 0 | 1 | 1 | 1 | 0 |
| device 1 | 1 | 1 | 0 | 1 |
| device 2 | 0 | 1 | 1 | 1 |
| device 3 | 1 | 0 | 1 | 1 |

Row sums are all 3 (every device sends out all 3 of its tokens). Column sums are also all 3 (every device receives exactly 3 tokens). That column-sum equality is exactly the balance precondition — if column 0 had been 5 and column 2 had been 1, device 0 would have to do 5× the work of device 2 this step.

Step 3: the all-to-all itself. Every device simultaneously sends the chunks described by row $C[r, :]$ and receives the chunks described by column $C[:, r]$. After the collective, device 0's inbox is:

  • 1 token from device 0 (its own local token that picked expert 0)
  • 1 token from device 1 (the token that picked expert 1)
  • 0 tokens from device 2
  • 1 token from device 3 (the token that picked expert 0)

Total: 3 tokens, two destined for expert 0 and one for expert 1 — the two experts device 0 happens to own.

Step 4: local expert compute on device 0. Run expert 0 once on its packed $(2, d)$ mini-batch and expert 1 once on its $(1, d)$ mini-batch. Nothing crosses devices.

Step 5: combine all-to-all. The transpose $C^\top$ tells device 0 to send 1 row back to device 0, 1 to device 1, 0 to device 2, and 1 to device 3: exactly the reverse permutation. Each output lands back at its original row index on the originating device, gates are applied locally, and one MoE layer is done.

The takeaway. Two collectives, no cross-device compute, and all data motion is a pure permutation of rows. Scale this from 4 devices to dozens of EP ranks, and from 3 tokens per device to 4096 × 8 dispatch rows per device, and the picture is identical, just with bigger arrays.
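The hand walkthrough can be replayed in a few lines of NumPy. This sketch reproduces Steps 1, 2, 3, and 5 for the exact routing decisions listed above:

```python
import numpy as np

W, E = 4, 8
per = E // W                       # two experts per device
picks = np.array([[5, 2, 0],       # device 0's tokens (top-1 expert ids)
                  [7, 3, 1],       # device 1
                  [4, 6, 2],       # device 2
                  [0, 5, 7]])      # device 3

dests = picks // per                                              # Step 1: owner(e)
C = np.stack([np.bincount(row, minlength=W) for row in dests])   # Step 2
print(dests.tolist())   # [[2, 1, 0], [3, 1, 0], [2, 3, 1], [0, 2, 3]]
print(C.tolist())       # [[1, 1, 1, 0], [1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1, 1]]
print(C.sum(axis=1), C.sum(axis=0))   # [3 3 3 3] [3 3 3 3]

# Step 3: device 0's inbox as (source device, expert id) pairs.
inbox0 = [(s, int(e)) for s in range(W) for e in picks[s] if e // per == 0]
print(inbox0)           # [(0, 0), (1, 1), (3, 0)]

# Step 5: device 0's combine sends follow column 0 of C (row 0 of C transposed).
print(C.T[0].tolist())  # [1, 1, 0, 1]
```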

Visualizing the Dispatch

Picture one layer in motion, with each token colored by its chosen expert id. During Dispatch the tokens leave their home GPU; during Compute they sit next to the matching expert on the destination GPU; during Combine they travel back. Only some tokens actually cross the network: the rest picked a local expert by luck of the router.


Three observations to anchor. First, every GPU sends and receives in the same collective; that is what "all-to-all" means and why it is bandwidth-efficient. Second, counting the rows that cross the network exposes the load-balance pressure: in a uniformly balanced batch, every device sends and receives the same number of rows. Third, the compute step happens entirely locally: once tokens land, the experts run normally, identical to the single-device MoE from earlier sections.

Plain Python: A Four-Device Simulator

Before touching torch.distributed, here is the entire mechanism in NumPy. We simulate four devices as four Python lists, route every device's tokens locally, build the send buffers, and assemble each device's inbox by hand. The collective is the dictionary rearrangement in the DISPATCH block, which is literally what an all-to-all does.

expert_parallel_numpy.py, annotated:

  • Cluster shape. World size W = 4 devices, E = 8 experts, top-k = 1, so each token has exactly one destination. experts_per_device = 2 means devices 0, 1, 2, 3 host experts [0, 1], [2, 3], [4, 5], [6, 7] respectively.
  • Expert-to-device map. owner(e) is the device id that physically holds expert e. With a contiguous block layout, owner(5) = 2. Every device must agree on this map; it is part of the model-parallel plan.
  • Local batch per device. Each device has a 3-token slice of the global 12-token batch (local_x[0].shape = (3, 4)). In real data-parallel training this slice arrives from the data loader; the global batch is never materialized on one device.
  • Router is replicated. The router is tiny (W_router.shape = (E, d) = (8, 4)), so every device gets a full copy. Routing decisions are made locally; no communication is needed to decide where tokens go.
  • Experts are SHARDED, not replicated. This is the entire point of expert parallelism. Each device allocates only the parameters of its own experts. The dict comprehension here builds device 0's experts only; the same code on device 2 would build experts 4 and 5.
  • Route each token locally. x @ W_router.T computes (T, E) affinity logits (shape (3, 8) here), and argmax along the expert axis is the top-1 pick, so local_expert_ids[s] has shape (3,). There is no cross-device communication: every device routes its own tokens.
  • Translate expert id → destination device. local_dst[s][i] is the device that owns the expert chosen by token i on device s, i.e. the address that token must be mailed to in the all-to-all. For example, if local_expert_ids[0] = [5, 2, 0], then local_dst[0] = [owner(5), owner(2), owner(0)] = [2, 1, 0].
  • Build the (src → dst) send buffers. Group local rows by destination device. send[(s, t)] is the list of rows device s must hand to device t, each tagged with its source device, source row, and expert id. This is the data the all-to-all collective will move.
  • The all-to-all = inbox per device. After the collective, device t's inbox is the concatenation of send[(s, t)] over every device s, including t itself. No reduction happens: all-to-all is a pure permutation of rows.
  • Run only the local experts. Device 0 walks its inbox, looks up the expert weight in its local dict, and applies the FFN. It never touches the other devices' expert weights; those do not live in this process.
  • Tag each output with (src device, src row). The combine step must mail each result back to the originating device and row index. Production code typically packs this metadata into a small integer tensor sent alongside the activations.

```python
import numpy as np

# 4 simulated devices, 8 experts (2 per device), top-k = 1.
W, E, k = 4, 8, 1
experts_per_device = E // W

def owner(expert_id):
    """Device that physically holds expert_id (contiguous block layout)."""
    return int(expert_id) // experts_per_device

# Each device holds 3 tokens of dimension d=4 (its local batch slice).
d = 4
rng = np.random.default_rng(0)
local_x = [rng.standard_normal((3, d)) for _ in range(W)]

# Router weights are replicated on every device.
W_router = rng.standard_normal((E, d)) * 0.1

# Expert FFNs are SHARDED: device 0 only stores its own experts.
# (In real code each rank builds only its own dict; there is no global one.)
local_experts = {
    e: rng.standard_normal((d, d)) * 0.1
    for e in range(E) if owner(e) == 0
}

def route(x):
    """Pick the top-1 expert id for each row of x."""
    logits = x @ W_router.T            # (T, E)
    return logits.argmax(axis=1)       # (T,)

# ---- 1. ROUTE (local, no comms) ----
local_expert_ids = [route(x) for x in local_x]
local_dst = [np.array([owner(e) for e in ids]) for ids in local_expert_ids]

# ---- 2. DISPATCH = all-to-all of tokens ----
# For each (src, dst) pair, collect the rows src must send to dst, tagged
# with (src device, src row, expert id) so the combine can route them home.
send = {(s, t): [] for s in range(W) for t in range(W)}
for s in range(W):
    for row, dst in enumerate(local_dst[s]):
        eid = int(local_expert_ids[s][row])
        send[(s, int(dst))].append((s, row, eid, local_x[s][row]))

# Each device's inbox = concat of what every device sent IT.
inbox = [[pkt for s in range(W) for pkt in send[(s, t)]] for t in range(W)]

# ---- 3. EXPERT COMPUTE (local on each device) ----
# We simulate device 0 only; the others run the same code with their experts.
outputs_dev0 = []
for src_dev, src_row, eid, x_row in inbox[0]:
    W_e = local_experts[eid]           # device 0 owns this expert
    y = np.maximum(0, W_e @ x_row)
    outputs_dev0.append((src_dev, src_row, y))  # return address for combine

print("inbox[0] size:", len(inbox[0]),
      "of", sum(len(b) for b in inbox))
```

Notice the one piece that could not exist in real code: the comprehension that builds device 0's expert dict by filtering. In real expert parallelism, that filter happens at process creation, and each rank only ever allocates its own experts. There is no global dict and no other device's weights waiting to be filtered out.

Sanity check. Set $W = 1$: every token's destination is device 0, the send dict has one entry, the all-to-all is a no-op, and we recover the single-device MoE from section 5.1. Set $W = E$ (one expert per device): a token stays local only if it happens to pick the single expert its own device hosts (about $1/E$ of tokens under uniform routing), so nearly every row crosses the network and the bandwidth tax approaches its maximum.
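The $W = E$ case generalizes: under uniform top-1 routing with tokens spread evenly over devices, about $1/W$ of the rows stay local. A quick Monte Carlo sketch; the local_fraction helper and the uniform-routing assumption are ours, for illustration:

```python
import numpy as np

def local_fraction(W, E, n_tokens=100_000, seed=0):
    """Fraction of top-1 picks whose expert lives on the token's own device,
    assuming tokens spread evenly over devices and uniform random routing."""
    rng = np.random.default_rng(seed)
    src = rng.integers(0, W, size=n_tokens)   # token's home device
    eid = rng.integers(0, E, size=n_tokens)   # uniformly routed expert
    return np.mean(eid // (E // W) == src)

for W in (1, 2, 4, 8):
    print(W, round(local_fraction(W, E=8), 3))   # close to 1/W
```

Real routers are not uniform, so measured local fractions drift from 1/W, but the trend holds: the larger the EP group, the more of the traffic is genuinely cross-device.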

PyTorch: all_to_all_single in Anger

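A straightforward PyTorch implementation uses torch.distributed.all_to_all_single with split sizes. The module below is a complete expert-parallel MoE layer (it needs an initialized NCCL process group with one GPU per rank), short enough to read end to end. It is a teaching version: production systems typically replace the collectives and the per-expert loop with fused custom kernels, but the data flow is the same.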

expert_parallel_pytorch.py, annotated:

  • An expert-parallel module is rank-aware. Unlike a normal nn.Module, ExpertParallelMoE knows about the distributed world. ep_group is the process group across which experts are sharded, typically a subset of all ranks (the EP group is one axis of a multi-dimensional parallelism mesh).
  • Only local experts are constructed. self.local_E = num_experts // W. On rank r, the experts list holds the FFNs for expert ids r·local_E through (r+1)·local_E − 1. The other (W − 1)·local_E experts simply do not exist in this process; that is how memory is saved.
  • The router is replicated, so it runs locally. self.router is a normal nn.Linear; every rank holds a copy, and the data-parallel gradient all-reduce keeps the copies in sync. Routing decisions need no communication.
  • x is already a token-parallel slice. By the time the MoE layer runs, the global batch has been split across data-parallel ranks: x.shape = (T, D) with T = global_batch · seq_len / dp_world_size.
  • Top-k routing, same as section 5.2. Standard top-k logits and gates, with topi.shape = gates.shape = (T, k). Nothing EP-specific yet: gating is identical to single-device MoE.
  • Explode (token, k) into one row per dispatch. Each token travels to k experts, so the (T, k) plan is flattened into T·k rows: flat_x.shape = (T·k, D) and flat_dst.shape = (T·k,), where flat_dst[i] is the rank that will receive row i.
  • Sort by destination so the all-to-all can use counts. all_to_all_single moves one contiguous block per destination. Sorting rows by destination rank lets us describe the dispatch with one integer per rank (the block size) instead of a per-row permutation.
  • send_counts[r] = rows this rank sends to rank r. torch.bincount with minlength=W produces a length-W vector that sums to T·k; it supplies input_split_sizes for the all-to-all.
  • Exchange counts (tiny but mandatory). Before the inbox can be allocated, every rank must know how many rows it will receive. One small all-to-all of W integers provides that: afterwards, recv_counts[s] is the number of rows this rank will receive from rank s.
  • Allocate the inbox. The receive buffer must be exactly the right size, because all_to_all_single does not allocate. Its first dimension is the total incoming row count just learned from the counts exchange.
  • The big activation all-to-all. This is the bandwidth-heavy collective. Every rank sends send_counts[r] rows of shape (D,) to each rank r while receiving recv_counts[s] rows from each source s, all in one collective.
  • Also ship the expert ids. The receiver needs to know which expert to apply to each incoming row, so a second, small all-to-all carries the int64 ids. Tuned implementations often fuse this metadata with the activation transfer.
  • Local expert id space. recv_eid lives in the global [0, E) id space. Subtracting rank · local_E maps it into [0, local_E), so self.experts can be indexed directly. The data is untouched; it is only a coordinate shift.
  • One packed matmul per local expert. Mask the rows that picked expert e, run the FFN once on that packed mini-batch, and scatter the results back. This is the same gather/scatter pattern as the single-device MoE, except the rows now come from anywhere in the cluster.
  • The reverse all-to-all = combine. Roles swap: the input split sizes become the output split sizes. Outputs travel back to the rank that sent each input row, and every rank ends up with its own T·k output rows in send order.
  • Undo the sort and apply the gates. The argsort of the sort permutation is its inverse and restores the original (T·k) ordering. Multiplying by flat_g applies the softmax weights; the final reshape + sum collapses the k contributions per token into one output row.

```python
import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_fn  # autograd-aware collectives
import torch.nn as nn
import torch.nn.functional as F

class ExpertParallelMoE(nn.Module):
    """One expert-parallel MoE layer. EP world size W must divide num_experts."""
    def __init__(self, d_model, d_ff, num_experts, k, ep_group):
        super().__init__()
        self.k = k
        self.E = num_experts
        self.group = ep_group
        self.W = dist.get_world_size(ep_group)
        self.rank = dist.get_rank(ep_group)
        assert num_experts % self.W == 0
        self.local_E = num_experts // self.W

        # Only LOCAL experts are instantiated on this rank.
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(),
                          nn.Linear(d_ff, d_model))
            for _ in range(self.local_E)
        ])
        # The router is replicated: every rank holds the full router weight.
        self.router = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x):  # x: (T, D), this rank's flat token slice
        T, D = x.shape
        logits = self.router(x)                          # (T, E)
        topw, topi = logits.topk(self.k, dim=-1)         # (T, k)
        gates = F.softmax(topw, dim=-1)                  # (T, k)

        # Flatten (token, slot) -> one row per (token, chosen expert).
        flat_x   = x.unsqueeze(1).expand(T, self.k, D).reshape(-1, D)
        flat_eid = topi.reshape(-1)                      # (T*k,)
        flat_g   = gates.reshape(-1, 1)                  # (T*k, 1)
        flat_dst = flat_eid // self.local_E              # owner rank

        # --- Sort rows by destination rank so all-to-all can use counts ---
        order = flat_dst.argsort()
        send_x   = flat_x[order]
        send_eid = flat_eid[order]
        send_counts = torch.bincount(flat_dst, minlength=self.W)

        # --- Exchange counts so every rank knows how many rows it will receive ---
        recv_counts = torch.empty_like(send_counts)
        dist.all_to_all_single(recv_counts, send_counts, group=self.group)
        send_splits = send_counts.tolist()
        recv_splits = recv_counts.tolist()
        n_recv = sum(recv_splits)

        # --- Dispatch: activations (differentiable) + expert ids (no grad) ---
        recv_x = dist_fn.all_to_all_single(
            x.new_empty((n_recv, D)), send_x,
            output_split_sizes=recv_splits,
            input_split_sizes=send_splits,
            group=self.group,
        )
        recv_eid = torch.empty(n_recv, dtype=torch.long, device=x.device)
        dist.all_to_all_single(
            recv_eid, send_eid,
            output_split_sizes=recv_splits,
            input_split_sizes=send_splits,
            group=self.group,
        )

        # --- LOCAL expert compute: one packed matmul per local expert ---
        local_eid = recv_eid - self.rank * self.local_E  # 0..local_E-1
        local_out = recv_x.new_zeros(recv_x.shape)
        for e in range(self.local_E):
            mask = local_eid == e
            if mask.any():
                local_out[mask] = self.experts[e](recv_x[mask])

        # --- Combine = inverse all-to-all (differentiable) to send results home ---
        return_x = dist_fn.all_to_all_single(
            torch.empty_like(send_x), local_out,
            output_split_sizes=send_splits,
            input_split_sizes=recv_splits,
            group=self.group,
        )

        # Unsort, apply gates, and sum the k contributions per token.
        unsort = order.argsort()
        out_flat = return_x[unsort] * flat_g
        return out_flat.view(T, self.k, D).sum(dim=1)
```

Three subtleties that bite first-time EP implementers:

  1. The counts exchange is mandatory. all_to_all_single requires a pre-sized receive buffer, and no rank can know its inbox size without first learning the per-source counts. That is why a small counts all-to-all runs before the activations move. Split sizes that disagree across ranks typically end in an error or a hang, a classic failure in hand-written EP code.
  2. Sorting by destination is the cheap enabler. all_to_all_single moves one contiguous block per destination rank, so rows must be grouped by destination before the call. A single argsort does the grouping, and the argsort of that permutation is its inverse, which puts the returned rows back in their original order. The argsort + argsort-of-argsort idiom is worth memorizing.
  3. Autograd does not flow through the plain collective. torch.distributed.all_to_all_single writes into a preallocated buffer outside autograd, so gradients stop at any dispatch or combine built on it. The activation transfers need the differentiable torch.distributed.nn.functional.all_to_all_single, whose backward is another all-to-all with the split sizes swapped: the gradient of a dispatch is a combine, and vice versa. The counts and expert-id exchanges carry no gradient, so the plain version is fine there. With that in place you write only the forward, and autograd produces a correctly distributed backward; hand-rolling the transfers with raw dist.send / dist.recv is almost never worth it.
Production grouped GEMM. The per-local-expert for e in range(self.local_E) loop is the readable version. At production scale it is typically replaced by a single grouped-GEMM kernel (built with CUTLASS or Triton, for example) that runs all local experts in one launch, with rows packed by expert offset. Same semantics, far fewer kernel launches.

What Changes at Massive Scale

At cluster scale, expert parallelism does not stand alone; it is one axis of a multi-dimensional parallelism plan. DeepSeek-V3's reported training setup (a cluster of 2048 H800 GPUs) is illustrative:

| Parallelism axis | Group size | What it shards |
|---|---|---|
| Data parallelism (DP) | ZeRO-1, over the remaining factor | the batch (each replica sees a different microbatch slice) plus the optimizer state |
| Pipeline parallelism (PP) | 16 stages | model layers (each stage owns a contiguous block) |
| Expert parallelism (EP) | 64 ranks, spanning 8 nodes | the routed experts within one MoE layer |
| Tensor parallelism (TP) | not used | would shard within a single matmul; DeepSeek reports that its memory-saving techniques made it unnecessary |

These axes do not multiply into a separate GPU count. One consistent reading of the numbers: 2048 GPUs split into 16 pipeline stages leaves 128 GPUs per stage, enough for two 64-rank EP groups. Within a group, each MoE layer's 256 routed experts are sharded 4 per GPU (32 per 8-GPU node). Every EP group performs its own pair of all-to-alls per MoE layer, independently of and in parallel with the others.

Node-limited routing

Cross-node InfiniBand bandwidth is several times lower than intra-node NVLink; the DeepSeek-V3 report cites about 50 GB/s for IB versus 160 GB/s for NVLink on its H800 cluster. DeepSeek-V3 therefore introduces node-limited routing: each token may be sent to experts on at most $M = 4$ of the 8 nodes in its EP group. Nodes are ranked by the sum of the highest $k/M$ affinity scores among the experts they host, and the router's top-$k$ is restricted to experts on the $M$ best nodes. This halves the worst-case cross-node fan-out (from 8 nodes to 4) with little quality loss, a beautifully practical compromise.
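A NumPy sketch of node-limited top-$k$ selection, following the description above: rank nodes by the sum of their top $k/M$ expert scores, keep the best $M$ nodes, and pick the top-$k$ experts among those. Assumptions: experts are laid out in contiguous per-node blocks, the affinities are random, and the real router's sigmoid scores and balancing bias are omitted; node_limited_topk is a hypothetical name.

```python
import numpy as np

def node_limited_topk(scores, experts_per_node, k, M):
    """Top-k expert ids per token, restricted to experts on at most M nodes.

    Nodes are ranked by the sum of their top (k // M) expert scores.
    scores: (T, E) affinities; experts sit in contiguous per-node blocks.
    """
    T, E = scores.shape
    n_nodes = E // experts_per_node
    per_node = scores.reshape(T, n_nodes, experts_per_node)
    node_score = np.sort(per_node, axis=-1)[..., -(k // M):].sum(axis=-1)  # (T, n_nodes)
    top_nodes = np.argsort(-node_score, axis=-1)[:, :M]                    # (T, M)
    allowed = np.zeros((T, n_nodes), dtype=bool)
    np.put_along_axis(allowed, top_nodes, True, axis=-1)
    # Expand the node mask to experts and hide every expert on a disallowed node.
    expert_ok = np.repeat(allowed, experts_per_node, axis=-1)              # (T, E)
    masked = np.where(expert_ok, scores, -np.inf)
    return np.argsort(-masked, axis=-1)[:, :k]                             # (T, k)

rng = np.random.default_rng(0)
scores = rng.random((5, 256))          # 8 nodes x 32 experts, random affinities
ids = node_limited_topk(scores, experts_per_node=32, k=8, M=4)
nodes_touched = [np.unique(row // 32).size for row in ids]
print(nodes_touched)                   # every entry is at most 4
```

Because the mask is applied before the final top-$k$, the router can never pick an expert on a fifth node, so the cap on cross-node fan-out holds by construction.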

Capacity factor and token dropping

Even with good load balancing, routed traffic is never perfectly uniform. To bound the receive buffer ahead of time, many MoE implementations set a capacity factor $C$: each expert reserves space for at most $C \cdot (\text{average tokens per expert})$ rows per step. Tokens beyond the cap are dropped for that expert: its contribution is zero, and the token continues through the shared experts and the residual stream. DeepSeek-V2 used a device-level form of this dropping during training, while the DeepSeek-V3 report states that its load balancing is effective enough that it drops no tokens in either training or inference.
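A minimal capacity-factor sketch for one routing step. Assumptions: capacity is $\lceil C \cdot n / E \rceil$ for $n$ routed rows, and rows are admitted first come, first served in row order (other policies exist, such as dropping the lowest gate scores first); apply_capacity is a hypothetical helper. The per-expert drop rate at the end is the metric worth logging:

```python
import math
import numpy as np

def apply_capacity(expert_ids, E, C):
    """Keep-mask for routed rows under capacity factor C.

    Each expert accepts at most ceil(C * n_rows / E) rows, first come,
    first served in row order; the rest are dropped for that expert.
    """
    n = len(expert_ids)
    cap = math.ceil(C * n / E)
    keep = np.zeros(n, dtype=bool)
    used = np.zeros(E, dtype=int)
    for i, e in enumerate(expert_ids):
        if used[e] < cap:
            used[e] += 1
            keep[i] = True
    return keep, cap

ids = np.array([0, 0, 0, 1, 2, 0, 3, 1])       # expert 0 is over-subscribed
keep, cap = apply_capacity(ids, E=4, C=1.0)
print(cap, keep.tolist())
# 2 [True, True, False, True, True, False, True, True]

# Per-expert drop rate (guard against experts that received no rows).
drop_rate = (np.bincount(ids[~keep], minlength=4)
             / np.maximum(np.bincount(ids, minlength=4), 1))
print(drop_rate.tolist())                       # [0.5, 0.0, 0.0, 0.0]
```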

Communication-compute overlap

The bandwidth bill can be partly hidden. While device $r$ runs its local experts on the inbox, rows whose expert has already finished could in principle start traveling back for the combine, and real implementations pipeline such overlap explicitly. DeepSeek-V3's DualPipe schedule, for example, is built so that all-to-all communication overlaps with computation from another chunk of work. The shared experts help too: they depend only on local tokens, so they can run during the dispatch all-to-all and hide part of the network latency behind useful compute.
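The benefit of overlapping the shared experts with dispatch shows up even in a toy timing model. Everything here is an assumption: the function is hypothetical and the millisecond values are made up for illustration, not measured.

```python
def moe_layer_time(t_dispatch, t_shared, t_routed, t_combine, overlap=True):
    """Toy per-layer timing model (an assumption, not a measured profile).

    With overlap, the shared experts run while the dispatch all-to-all is in
    flight, so only the longer of the two sits on the critical path.
    """
    if overlap:
        return max(t_dispatch, t_shared) + t_routed + t_combine
    return t_dispatch + t_shared + t_routed + t_combine

# Hypothetical timings in milliseconds, for illustration only.
print(moe_layer_time(4.0, 1.5, 3.0, 4.0, overlap=False))  # 12.5
print(moe_layer_time(4.0, 1.5, 3.0, 4.0, overlap=True))   # 11.0
```

The saving is capped by the smaller of the two overlapped terms, which is why schedules such as DualPipe look for larger chunks of independent compute to hide communication behind.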

| Quantity | DeepSeek-V3 value | Why it matters |
|---|---|---|
| Routed experts per MoE layer | 256 | deep specialization pool |
| EP group size | 64 GPUs (8 nodes) | 4 experts per GPU |
| Top-k | 8 | every token visits 8 routed experts per layer |
| Node limit M | 4 nodes | caps cross-node fan-out at 4 of 8 nodes |
| Token dropping | none reported | load balancing replaces a capacity cap |
| All-to-all bytes / GPU / layer / step | ≈ 0.94 GB (illustrative: BF16, $T_r = 4096$) | the bandwidth bill |
| MoE layers | 58 | total bill ≈ 54.5 GB / GPU / step |

Engineering Reality: Capacity, Overlap, and Topology

Expert parallelism is the easiest piece of MoE to implement at toy scale and the easiest to get wrong at production scale. Four failure modes are worth carrying in your head:

  1. Imbalanced routing → straggler GPU. If one device owns the most popular experts, the entire EP group waits for it on every step. Symptoms: low SM utilization on most ranks, high on one; all-to-all latency climbing over training. Fix: aggressive load-balancing (chapter 6) and capacity factor > 1.
  2. EP group placed across the wrong topology. The more nodes an EP group spans, the larger the share of every all-to-all that must cross InfiniBand instead of NVLink: a 32-rank group spread over 8 nodes (4 GPUs each) sends more of its traffic cross-node than the same group packed into 4 full 8-GPU nodes. NVLink and InfiniBand differ by several times in bandwidth (about 3.2× on DeepSeek's H800 cluster). Place EP groups along NVLink-dense axes, not pipeline-dense ones.
  3. Token-drop cascade in mixed precision. With $C = 1.0$ and BF16 accumulation, a tiny drift in router logits can flip a token from "just under the cap" to "dropped". A dropped token gets zero gradient through that expert for that step. Stack hundreds of such events per batch and you can get a silent quality regression that the training loss curve may not reveal and only eval metrics catch. Periodically log per-expert drop rates.
  4. All-to-all on the wrong stream. If the dispatch collective shares a CUDA stream with the local expert compute, the two serialize even though they could overlap. Production EP code puts collectives on dedicated streams and uses CUDA events to synchronize only where strictly required; the throughput difference can be substantial.
Where we go from here. The next section (Implementing DeepSeekMoE) glues everything together: shared + routed experts, top-$k$ gating, expert parallelism, and capacity factor, all in a single PyTorch module you could drop into a transformer block today. Chapter 6 then attacks the elephant we kept dodging here: how to make the router actually balance the load without a quality-degrading auxiliary loss.

The one sentence to carry forward: expert parallelism turns MoE's structural sparsity into memory savings and pays for it with two all-to-all collectives per layer per step. Everything else in massive-scale MoE engineering (load balancing, capacity factors, node-limited routing, communication-compute overlap) exists to keep that all-to-all cheap.
