Chapter 11

Tensor Parallelism (TP)

Distributed Training: DualPipe and the Parallelism Stack

A single H800 has 80 GB of HBM. A 70 B-parameter dense model in BF16 needs 140 GB just for the weights. Data parallelism cannot help: every replica still has to hold the whole model. The fix is to cut the model itself in half, then in quarters, then in eighths, distributing each matmul across GPUs and stitching the answer back together with a single all-reduce per attention or MLP block. That is tensor parallelism. The previous section showed how to scale by replicating; this section shows how to scale by splitting.

The Real Problem: One Layer No Longer Fits

Data parallelism (Section 11.2) gets you more throughput by running the same model on many GPUs with different mini-batches, then averaging gradients. It works beautifully, until the model itself stops fitting on a single GPU. For LLaMA-70B, the model and optimiser state alone (BF16 weights and gradients, FP32 master weights, and two FP32 Adam moments, before counting any activations) costs roughly $(2 + 2 + 4 + 8) \cdot 70 \cdot 10^{9}\ \text{bytes} = 1.12\,\text{TB}$, 14× larger than an 80 GB H800. No amount of data parallelism rescues you. The model has to be split.
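The 16-bytes-per-parameter figure is easy to check by writing the bookkeeping out. A minimal sketch, assuming the mixed-precision Adam layout described above and ignoring activations entirely (the dictionary keys are just labels):

```python
# Bytes per parameter for mixed-precision Adam training (activations excluded).
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam moments (m and v)": 8,
}


def training_state_bytes(n_params: float) -> float:
    """Model + optimiser state in bytes for a model with n_params parameters."""
    return n_params * sum(BYTES_PER_PARAM.values())


total = training_state_bytes(70e9)      # LLaMA-70B-sized model
hbm_per_gpu = 80e9                      # one 80 GB H800

print(f"state: {total / 1e12:.2f} TB")                        # state: 1.12 TB
print(f"ratio to one GPU's HBM: {total / hbm_per_gpu:.0f}x")  # ratio to one GPU's HBM: 14x
```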

There are three ways to split it:

  1. Pipeline parallelism (PP) — give different layers to different GPUs. Easy memory math, but introduces pipeline bubbles: when GPU 0 is on micro-batch 1, GPUs 1-7 are idle. We cover this in Section 11.4.
  2. Tensor parallelism (TP) — give a slice of each layer to every GPU. Every GPU is busy on every layer, but every layer pays a collective communication tax. The topic of this section.
  3. Expert parallelism (EP) — give different MoE experts to different GPUs. Specific to mixture-of-experts; Section 11.6.

Tensor parallelism is the most surgical of the three. It does not change what the model computes — it changes where each piece of the computation lives. Forward and backward of a TP-sharded layer are mathematically identical to the unsharded layer, up to floating-point rounding. The cost is a collective communication operation per layer.

TP is not optional for frontier-scale dense models. A 70 B dense Transformer's MLP has weight matrices of roughly $d \cdot d_{ff} \approx 8192 \cdot 28672 \approx 235\,\text{M}$ parameters each. In BF16 that is 470 MB per matrix; with two matrices per layer (up and down; LLaMA's SwiGLU MLP adds a third, gate, matrix of the same size) and 80 layers, that is about 75 GB of MLP weights alone. Add attention QKV, KV cache, gradients, optimiser state and activations, and the model cannot fit on one GPU. TP splits each matrix across (typically) the 8 GPUs in a node, so every GPU holds only a slice of every layer.
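The same arithmetic, extended to show what TP buys per GPU. This sketch assumes the two-matrix (up/down) MLP used throughout this section, BF16 weights, and the 70 B shape quoted above; it counts MLP weights only.

```python
d, d_ff, n_layers = 8192, 28672, 80   # 70 B-class dense shape quoted in the text
BYTES_BF16 = 2
MATRICES_PER_MLP = 2                  # up + down only; a SwiGLU MLP would add a third


def mlp_weight_bytes_per_gpu(tp: int) -> float:
    """BF16 MLP weight bytes one GPU holds when every MLP matrix is split tp ways."""
    per_matrix = d * d_ff * BYTES_BF16
    return per_matrix * MATRICES_PER_MLP * n_layers / tp


for tp in (1, 2, 4, 8):
    gb = mlp_weight_bytes_per_gpu(tp) / 1e9
    print(f"TP={tp}: {gb:5.1f} GB of MLP weights per GPU")   # 75.2 GB at TP=1 ... 9.4 GB at TP=8
```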

Intuition: Split the Matmul, Not the Data

Imagine you are computing $Y = XW$ where $X$ is a small matrix of inputs and $W$ is enormous, too big to hold on one machine. You have two coworkers. Two ways to share the work:

  1. Cut W into vertical strips (columns). You take the left half of W's columns, your coworker takes the right half. Each of you multiplies the full X by your strip. You get the left half of Y, your coworker gets the right half. No talking required — you each independently computed a piece of the answer. This is column-parallel.
  2. Cut W into horizontal strips (rows). Now X also has to be split: you take the left columns of X (matching W's top rows), your coworker takes the right columns. You each compute a result with the FULL shape of Y, but each result is a partial: it only contains the contribution from your rows. To get the real Y, you have to add the partials together. This is row-parallel, and the addition step is the all-reduce.

Neither approach alone is enough. Column-parallel leaves you with a sharded output; row-parallel demands a sharded input. The Megatron-LM insight is to compose them: a column-parallel layer feeds its sharded output directly into a row-parallel layer, which all-reduces at the end. The composition maps a replicated input to a replicated output, so it is a drop-in replacement for an unsharded MLP, and it needs only ONE collective per pair.

Why the sandwich works. The MLP is $Y = (X W_{\text{up}}) W_{\text{down}}$, with an activation between the two matmuls. If we make $W_{\text{up}}$ column-parallel and $W_{\text{down}}$ row-parallel, the sharded intermediate flows from one to the other without any collective in between: the activation (GELU, or SwiGLU with its gate and up projections sharded the same way) is element-wise, so it commutes with sharding. Only the final output needs an all-reduce. Attention works the same way: column-parallel QKV, attention computed on each rank's own heads, then a row-parallel output projection followed by an all-reduce. That is two all-reduces per transformer layer in the forward pass, and two more in the backward pass.
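A quick NumPy check of the claim that an element-wise activation commutes with column sharding. The sketch uses a tanh-approximation GELU and, as a counter-example, a softmax over the d_ff axis (a hypothetical non-element-wise "activation" chosen only to show what breaks); the sum over ranks stands in for the all-reduce.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU: element-wise, each entry is transformed independently
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def softmax_over_dff(x):
    # NOT element-wise: every output depends on the whole d_ff row (counter-example)
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
B, d, d_ff, T = 3, 4, 8, 2
shard = d_ff // T
X = rng.standard_normal((B, d))
W_up = rng.standard_normal((d, d_ff))
W_down = rng.standard_normal((d_ff, d))


def tp_mlp(act):
    """Column-parallel up, act on the local shard, row-parallel down, then sum."""
    partials = []
    for r in range(T):
        cols = slice(r * shard, (r + 1) * shard)
        h_r = act(X @ W_up[:, cols])              # rank r never sees the other shards
        partials.append(h_r @ W_down[cols, :])
    return sum(partials)                          # stands in for the all-reduce


print(np.allclose(tp_mlp(gelu), gelu(X @ W_up) @ W_down))                          # True
print(np.allclose(tp_mlp(softmax_over_dff), softmax_over_dff(X @ W_up) @ W_down))  # False
```

The softmax case fails because each rank normalises over only its own shard of d_ff, which is exactly why TP places collectives around non-element-wise operations.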

The Math: Column-Parallel and Row-Parallel

Let $W \in \mathbb{R}^{d \times m}$ be the weight matrix. Partition the inner dimension $m$ into $T$ equal chunks across $T$ TP ranks. We write $W = [W^{(1)} \mid W^{(2)} \mid \cdots \mid W^{(T)}]$ where each block $W^{(r)} \in \mathbb{R}^{d \times m/T}$ lives on rank $r$.

Column-parallel

The forward pass on rank rr is:

$$Y^{(r)} = X\, W^{(r)} \in \mathbb{R}^{B \times m/T}$$

with $X$ replicated. The output is sharded along the inner dimension, and no communication is required. The full unsharded output $Y = [Y^{(1)} \mid Y^{(2)} \mid \cdots \mid Y^{(T)}]$ exists implicitly across the ranks but is never materialised.
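A minimal NumPy illustration of the column-parallel equation, with arbitrary toy sizes (assumed, not from the text): each rank's block needs only the replicated $X$ and its own slice of $W$.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d, m, T = 4, 6, 12, 3
X = rng.standard_normal((B, d))            # replicated on every rank
W = rng.standard_normal((d, m))

W_blocks = np.split(W, T, axis=1)          # W^(r): (d, m/T), one block per rank
Y_blocks = [X @ W_r for W_r in W_blocks]   # Y^(r): (B, m/T), no communication needed

# Concatenating is done here only to check the math; real TP never materialises Y.
Y = np.concatenate(Y_blocks, axis=1)
print(Y.shape, np.allclose(Y, X @ W))      # (4, 12) True
```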

Row-parallel

Now suppose $W \in \mathbb{R}^{m \times d}$ and we partition the row dimension $m$: $W = [W_{(1)}^\top \mid W_{(2)}^\top \mid \cdots \mid W_{(T)}^\top]^\top$. The input must already be sharded along $m$, which is exactly what column-parallel produced, so on rank $r$ we have $X^{(r)} \in \mathbb{R}^{B \times m/T}$. The local matmul is:

$$P^{(r)} = X^{(r)}\, W_{(r)} \in \mathbb{R}^{B \times d}$$

Each $P^{(r)}$ has the FULL output shape but is a partial: it holds only the contribution from rank $r$'s slice of $m$. The actual output is the sum:

$$Y = \sum_{r=1}^{T} P^{(r)} = \text{AllReduce}(P^{(r)})$$
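The AllReduce can be made concrete with a single-process simulation of a ring all-reduce (reduce-scatter followed by all-gather), fed with row-parallel partials. This is an illustrative sketch, not NCCL's implementation; `ring_all_reduce` is a hypothetical helper that also counts how many elements each rank sends, which can be compared against the $2(T-1)/T$ factor in the communication-volume formula.

```python
import numpy as np


def ring_all_reduce(tensors):
    """Single-process simulation of a ring all-reduce over len(tensors) ranks.

    Returns (per-rank results, elements sent per rank). Each rank splits its
    tensor into T chunks; reduce-scatter then all-gather each take T-1 steps.
    """
    T = len(tensors)
    shape = tensors[0].shape
    chunks = [np.array_split(t.astype(float).ravel(), T) for t in tensors]
    sent = [0] * T

    # Reduce-scatter: at step s, rank r sends chunk (r - s) % T to rank r + 1, which adds
    # it to its own copy. Afterwards rank r owns the full sum of chunk (r + 1) % T.
    for s in range(T - 1):
        outgoing = [((r - s) % T, chunks[r][(r - s) % T].copy()) for r in range(T)]
        for r in range(T):
            sent[r] += outgoing[r][1].size
        for r in range(T):
            c, data = outgoing[(r - 1) % T]          # message from the left neighbour
            chunks[r][c] = chunks[r][c] + data

    # All-gather: at step s, rank r forwards its complete chunk (r + 1 - s) % T.
    for s in range(T - 1):
        outgoing = [((r + 1 - s) % T, chunks[r][(r + 1 - s) % T].copy()) for r in range(T)]
        for r in range(T):
            sent[r] += outgoing[r][1].size
        for r in range(T):
            c, data = outgoing[(r - 1) % T]
            chunks[r][c] = data

    return [np.concatenate(ch).reshape(shape) for ch in chunks], sent


# Row-parallel partials P^(r) = X^(r) @ W_(r), then all-reduce.
rng = np.random.default_rng(0)
B, m, d, T = 8, 16, 8, 4
X = rng.standard_normal((B, m))
W = rng.standard_normal((m, d))
X_shards = np.split(X, T, axis=1)    # X^(r): (B, m/T), what column-parallel hands over
W_rows = np.split(W, T, axis=0)      # W_(r): (m/T, d)
partials = [X_r @ W_r for X_r, W_r in zip(X_shards, W_rows)]   # each (B, d), partial

results, sent = ring_all_reduce(partials)
print(all(np.allclose(Y, X @ W) for Y in results))   # True: every rank now holds Y
print(sent)                                           # [96, 96, 96, 96]
```

Each rank sends $2(T-1)$ chunks of $B \cdot d / T$ elements, i.e. $2 \cdot 3/4 \cdot 64 = 96$ here.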

Communication volume

A ring all-reduce of a tensor of $B \cdot d$ elements moves roughly $2(T-1)/T \cdot B \cdot d$ elements through each rank's links. The total per-layer communication volume (attention output projection plus MLP down-projection, both row-parallel) is:

$$V_{\text{layer}} = 2 \cdot 2 \cdot \frac{T-1}{T} \cdot B \cdot d \cdot \text{bytes/elt}$$

The factor $(T-1)/T$ approaches 1 as the TP degree grows, so doubling TP from 4 to 8 increases per-layer comm volume by only about 17% (from 3/4 to 7/8). But it halves the local compute, so the ratio $V/\text{FLOPs}$ a little more than doubles (about 2.3×). Converted to time (comm time divided by compute time), that ratio is the crucial number: when it climbs above roughly 0.3 (comm taking 30% as long as compute), tensor cores start starving and wall-clock throughput drops sharply.
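The 4-to-8 arithmetic can be checked directly. This sketch assumes the forward-pass volume formula above and counts only the MLP's two GEMMs as compute; because all per-rank compute scales as $1/T$, only the growth factors matter here, not the absolute ratio.

```python
def comm_elements_per_rank(T, B, d):
    """Forward-pass ring traffic for one layer's two row-parallel all-reduces."""
    return 2 * 2 * (T - 1) / T * B * d


def mlp_flops_per_rank(T, B, d, d_ff):
    """Two GEMMs (up and down), 2 FLOPs per multiply-add, inner dim split T ways."""
    return 2 * 2 * B * d * d_ff / T


B, d, d_ff = 8192, 8192, 28672
v4, v8 = comm_elements_per_rank(4, B, d), comm_elements_per_rank(8, B, d)
r4 = v4 / mlp_flops_per_rank(4, B, d, d_ff)
r8 = v8 / mlp_flops_per_rank(8, B, d, d_ff)

print(f"comm volume growth, TP 4 -> 8: {v8 / v4 - 1:.1%}")   # 16.7%
print(f"comm/FLOP ratio growth:        {r8 / r4:.2f}x")     # 2.33x
```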

Manual Numerical Walkthrough

We walk through a 2-rank TP forward pass on a 2×4-wide MLP, by hand. Tiny enough to verify every number; structurally identical to a 70 B-model layer.

Worked example: TP=2 MLP forward pass by hand

Setup. One token, hidden dimension $d = 2$, FFN inner dimension $d_{ff} = 4$, TP degree $T = 2$. Ignore the activation for clarity.

Tensors. Input $X = [1, 2]$ (shape $1 \times 2$). The full unsharded weights are:

$$W_{\text{up}} = \begin{bmatrix} 1 & 0 & 1 & 2 \\ 0 & 1 & 1 & -1 \end{bmatrix} \;(2 \times 4), \qquad W_{\text{down}} = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \\ -1 & 1 \end{bmatrix} \;(4 \times 2).$$

Reference output. First $X W_{\text{up}} = [1, 2, 3, 0]$, then $[1, 2, 3, 0] \cdot W_{\text{down}} = [1 + 0 + 3 + 0,\ 0 + 2 + 3 + 0] = [4, 5]$. So $Y_{\text{ref}} = [4, 5]$.

TP=2: rank 0. Owns the first two columns of $W_{\text{up}}$ and the first two rows of $W_{\text{down}}$:

$$W_{\text{up}}^{(0)} = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}, \quad W_{\text{down},(0)} = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}$$

Column-parallel: $Y_1^{(0)} = X W_{\text{up}}^{(0)} = [1, 2]$. Row-parallel: $P^{(0)} = Y_1^{(0)}\, W_{\text{down},(0)} = [1 \cdot 1 + 2 \cdot 0,\ 1 \cdot 0 + 2 \cdot 1] = [1, 2]$.

TP=2: rank 1. Owns the last two columns of $W_{\text{up}}$ and the last two rows of $W_{\text{down}}$:

$$W_{\text{up}}^{(1)} = \begin{bmatrix} 1 & 2 \\ 1 & -1 \end{bmatrix}, \quad W_{\text{down},(1)} = \begin{bmatrix} 1 & 1 \\ -1 & 1 \end{bmatrix}$$

Column-parallel: $Y_1^{(1)} = X W_{\text{up}}^{(1)} = [1 \cdot 1 + 2 \cdot 1,\ 1 \cdot 2 + 2 \cdot (-1)] = [3, 0]$. Row-parallel: $P^{(1)} = [3, 0]\, W_{\text{down},(1)} = [3 \cdot 1 + 0 \cdot (-1),\ 3 \cdot 1 + 0 \cdot 1] = [3, 3]$.

All-reduce. $Y = P^{(0)} + P^{(1)} = [1, 2] + [3, 3] = [4, 5]$, exactly the reference. Mathematical equivalence verified.
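The hand calculation can be reproduced in a few lines of NumPy, using exactly the matrices above; the slicing mirrors which rank owns which columns and rows.

```python
import numpy as np

X = np.array([[1, 2]])
W_up = np.array([[1, 0, 1, 2],
                 [0, 1, 1, -1]])
W_down = np.array([[1, 0],
                   [0, 1],
                   [1, 1],
                   [-1, 1]])

Y_ref = (X @ W_up) @ W_down
P0 = (X @ W_up[:, :2]) @ W_down[:2, :]   # rank 0: first two columns of W_up, rows of W_down
P1 = (X @ W_up[:, 2:]) @ W_down[2:, :]   # rank 1: last two columns of W_up, rows of W_down
print(Y_ref, P0, P1, P0 + P1)            # [[4 5]] [[1 2]] [[3 3]] [[4 5]]
```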

Memory bookkeeping. Rank 0 stores 4 floats of $W_{\text{up}}$ and 4 of $W_{\text{down}}$: 8 total, vs 16 unsharded. Two-rank TP halves the weight footprint, as expected. Each rank contributes a $B \cdot d = 2$-element partial to the all-reduce; that is the constant comm tax for using TP.

What scales. Replace $d = 2$ with 8 192, $d_{ff} = 4$ with 28 672, $T = 2$ with 8, and $B = 1$ with 8 192 (a typical micro-batch × sequence). The algorithm is line-for-line the same. The weight shard shrinks from $d \cdot d_{ff}$ to $d \cdot d_{ff} / T \approx 29\,\text{M}$ elements per rank, and each all-reduce sums a $B \cdot d \approx 67\,\text{M}$-element tensor per rank, twice per layer (attention + MLP).
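Plugging in the production-scale numbers from the paragraph above, under the assumption of BF16 activations (2 bytes per element) and the ring-traffic factor $2(T-1)/T$:

```python
d, d_ff, T, tokens = 8192, 28672, 8, 8192
BYTES_BF16 = 2

shard_elems = d * d_ff // T            # elements of one weight matrix held per rank
allreduce_elems = tokens * d           # size of the tensor each all-reduce sums
ring_bytes = 2 * (T - 1) / T * allreduce_elems * BYTES_BF16  # traffic per rank, per call

print(f"weight shard per rank: {shard_elems / 1e6:.1f} M elements")   # 29.4
print(f"all-reduce tensor: {allreduce_elems / 1e6:.1f} M elements")   # 67.1
print(f"ring traffic per rank per call: {ring_bytes / 1e6:.0f} MB")   # 235
```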

Visualizing Sharding and Communication Cost

Slide TP degree, hidden dim, FFN inner dim, batch×seq, and NVLink bandwidth. The diagram redraws the matmul split; the stat cards retally per-GPU memory, FLOPs, and all-reduce volume. Watch the bottom callout — comm/compute ratio is the number that decides whether TP is profitable.

Three settings worth trying. (1) Set TP=8, hidden 4096, d_ff 14336, batch×seq 8192, NVLink 600 GB/s: the standard intra-node TP=8 setup applied to an 8B-class layer shape. Comm/compute lands around 0.05-0.10; comfortable. (2) Keep TP at 8 but drop the link bandwidth to 50 GB/s (simulating cross-node InfiniBand). Comm/compute jumps past 1.0, and TP is now actively slowing you down. This is why TP almost never crosses node boundaries. (3) Hold TP=8 but push batch×seq to 32 768. Compute and comm both scale linearly with the token count, so the ratio stays roughly flat: TP's overhead is shape-invariant in the batch dimension, which is what makes it composable with DP.
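The three settings can also be approximated offline with a back-of-the-envelope cost model. This is a hypothetical sketch, not the lab's model: it assumes 400 TFLOP/s of sustained BF16 throughput, counts only the MLP block's two GEMMs, and assumes no overlap between communication and compute. Absolute values depend heavily on those assumptions; the relative behaviour across the three settings is what it illustrates.

```python
def comm_compute_ratio(tp, d, d_ff, tokens, link_gb_s,
                       sustained_tflops=400.0, bytes_per_elt=2):
    """Comm time / compute time for one MLP block's forward pass on one rank.

    Assumptions (hypothetical): only the two MLP GEMMs count as compute, the
    GPU sustains `sustained_tflops`, one ring all-reduce of the (tokens, d)
    output runs at `link_gb_s`, and comm does not overlap compute.
    """
    compute_s = (2 * 2 * tokens * d * d_ff / tp) / (sustained_tflops * 1e12)
    comm_bytes = 2 * (tp - 1) / tp * tokens * d * bytes_per_elt
    comm_s = comm_bytes / (link_gb_s * 1e9)
    return comm_s / compute_s


nvlink = comm_compute_ratio(8, 4096, 14336, 8192, link_gb_s=600)    # setting (1)
ib = comm_compute_ratio(8, 4096, 14336, 8192, link_gb_s=50)         # setting (2)
big = comm_compute_ratio(8, 4096, 14336, 32768, link_gb_s=600)      # setting (3)

print(f"(1) NVLink 600 GB/s : {nvlink:.2f}")
print(f"(2) 50 GB/s link    : {ib:.2f}  ({ib / nvlink:.0f}x setting 1)")
print(f"(3) 4x more tokens  : {big:.2f}  (same as setting 1)")
```

Whatever constants you plug in, setting (2) is exactly 12× setting (1) (600/50), and setting (3) equals setting (1), because both comm and compute are linear in the token count.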

Plain Python: TP From Scratch with NumPy

Before we bring in NCCL and CUDA, prove that the algebra works on a single CPU process. Below, we simulate TP=4 by literally slicing the weight matrices into 4 shards, computing each rank's partial forward pass in a loop, and reconstructing the answer with a sum. The output has to match the unsharded reference exactly — if it doesn't, your sharding is wrong long before you spend a minute on distributed debugging.

tp_mlp_numpy.py
The MLP block (header comment): the target of tensor parallelism

Every transformer layer contains two giant matrix multiplications inside the feed-forward block: an up-projection from hidden dimension d to a wider inner dimension d_ff, and a down-projection back to d. For a 70 B model, d ≈ 8 192 and d_ff ≈ 28 672, so each of these weight matrices is roughly 235 M parameters. The MLP typically accounts for two-thirds or more of the model's parameters. If we are going to split the model across GPUs, the MLP is where the win lives.

Shapes: X is (B·S, d), W_up is (d, d_ff), W_down is (d_ff, d).
Toy dimensions (B, S, d, d_ff, TP): chosen so we can verify by hand

Tiny shapes (B=2, S=4, d=8, d_ff=16, TP=4) keep every intermediate tensor small enough to print. The algorithm is identical at d=8 192 / d_ff=28 672 / TP=8; only the numbers change. The key relationship is d_ff % TP == 0: the inner dimension must be evenly divisible by the TP degree so each rank gets an equal-sized shard.

State: B*S = 8 tokens, d = 8 (hidden), d_ff = 16 (FFN inner), TP = 4 ranks.
Y_ref: the answer every rank must produce

Compute the unsharded MLP on one process. This is what the TP forward pass has to reconstruct exactly. If our sharded algebra ever drifts from Y_ref by more than floating-point noise, TP is broken, and any divergence here is a bug that compounds across 60 layers and millions of training steps.

State: Y_ref has shape (B*S, d) = (8, 8).
shard: the slice each rank owns

Each rank gets a contiguous slice of size d_ff // TP = 4 columns of W_up (and the matching 4 rows of W_down). Megatron-LM does exactly this: the column indices [0:4], [4:8], [8:12], [12:16] go to ranks 0, 1, 2, 3. The split is along the INNER dimension d_ff, never along d. That is the whole trick: splitting along the inner dimension means the outer dimensions of inputs and outputs are unchanged, so X stays replicated and Y stays the right shape.

State: shard = 4 columns/rows per rank.
Column-parallel up-projection (W_up_r, Y1_r): no communication

Each rank multiplies the full input X by its column slice of W_up. The result Y1_r has the full token batch but only 'shard' columns of d_ff: it is sharded along the inner dimension. Critically, no rank needs data from any other rank to compute its Y1_r. This stage runs at full GEMM throughput with zero collective communication, which is why column-parallel is the cheap half of the sandwich.

Shapes: W_up_r is (d, shard) = (8, 4); Y1_r is (B*S, shard) = (8, 4).
Row-parallel down-projection (W_down_r, Y_r): produces partial sums

Now each rank multiplies its sharded Y1_r by its row slice of W_down. The output Y_r has the full output shape (B*S, d), but it is a PARTIAL sum: it only contains the contribution from the rank's slice of d_ff. Mathematically, Y = Σ_r (Y1_r @ W_down_r). The next step has to actually do that sum across ranks, which is the all-reduce.

Shapes: W_down_r is (shard, d) = (4, 8); Y_r is (B*S, d) = (8, 8), a partial sum.
All-reduce (Y_tp = sum(partials)): the unavoidable communication

Sum the per-rank partials. On real hardware this is an all-reduce: every rank ends up holding Y. The tensor being reduced has B*S * d elements on every rank regardless of TP degree, and ring traffic per rank is 2(TP-1)/TP times that, so doubling TP barely changes comm volume. But doubling TP halves compute per rank, so the ratio (comm / compute) roughly doubles. That ratio is what eventually stops TP from scaling beyond about 8.

State: Y_tp is (B*S, d), the final output, held by every rank.
Correctness check (the two print calls)

The max abs diff should be on the order of 1e-13 (just float64 rounding). If we got the slicing right, sharded TP and the unsharded reference are mathematically identical. This is the property that lets TP work invisibly inside autograd: forward and backward of a TP layer match the unsharded layer up to floating-point rounding.

Expected output: max abs diff ≈ 1e-13; np.allclose returns True.
Memory: 4× smaller weights per rank

Each rank stores 32 elements of W_up (vs 128 unsharded) and 32 of W_down (vs 128). In float64 that is 256 bytes per shard vs 1 024 bytes for the full matrix, which is what the print statements report. That is the headline win: TP=4 cuts MLP weight memory per GPU by 4×. At 70 B-model scale, each BF16 MLP matrix drops from ~470 MB to ~117 MB per GPU at TP=4; across two matrices per layer and 80 layers, that frees roughly 56 GB of HBM per GPU for activations and KV cache. The price tag is the all-reduce: each rank contributes its Y_r, 64 elements (512 bytes in float64) here.

State: W_up per rank = 32 elements (256 B, 4× smaller); W_down per rank = 32 elements (256 B, 4× smaller); all-reduce input = 64 elements (512 B) per rank per layer.
```python
import numpy as np

# A transformer MLP block:  Y = (X @ W_up) @ W_down       (ignore activation for clarity)
#   X     : (batch*seq, d)
#   W_up  : (d, d_ff)
#   W_down: (d_ff, d)
# Tensor parallelism splits W_up by COLUMNS and W_down by ROWS across TP ranks.
# After the row-parallel matmul each rank holds a PARTIAL sum of Y; an all-reduce
# sums them. We simulate the whole thing on one CPU process: TP=4 just means
# we cut the weight matrices into 4 shards and check the algebra reconstructs Y.

np.random.seed(0)
B, S, d, d_ff = 2, 4, 8, 16          # tiny: batch 2, seq 4, hidden 8, FFN 16
TP = 4                                 # shard across 4 virtual GPUs

# Inputs and full weights (the unsharded reference).
X      = np.random.randn(B * S, d)
W_up   = np.random.randn(d, d_ff)
W_down = np.random.randn(d_ff, d)

# Reference: the answer every rank must agree on after all-reduce.
Y_ref = (X @ W_up) @ W_down            # shape (B*S, d)

# --- TP forward pass, rank by rank ---
shard = d_ff // TP                     # each rank owns 'shard' columns of W_up
                                       # and the matching 'shard' rows of W_down
partials = []
for rank in range(TP):
    # Column-parallel up-projection: rank owns columns [rank*shard : (rank+1)*shard]
    W_up_r   = W_up[:, rank*shard : (rank+1)*shard]   # (d, shard)
    Y1_r     = X @ W_up_r                              # (B*S, shard)  -- no comm yet

    # Row-parallel down-projection: rank owns rows [rank*shard : (rank+1)*shard]
    W_down_r = W_down[rank*shard : (rank+1)*shard, :]  # (shard, d)
    Y_r      = Y1_r @ W_down_r                          # (B*S, d) partial sum
    partials.append(Y_r)

# All-reduce: sum the per-rank partial outputs. Now every rank holds Y.
Y_tp = sum(partials)

# Did we recover the reference?
print("max abs diff:", np.abs(Y_tp - Y_ref).max())
print("identical to ~1e-12:", np.allclose(Y_tp, Y_ref))

# Memory accounting per rank:
print(f"\nW_up per rank:   {W_up_r.nbytes:>5} bytes  (full {W_up.nbytes} bytes)")
print(f"W_down per rank: {W_down_r.nbytes:>5} bytes  (full {W_down.nbytes} bytes)")
print(f"all-reduce bytes per call: {Y_r.nbytes} (one Y per rank, summed in-place)")
```

PyTorch: Megatron-Style TP with all_reduce

Now on real silicon. Same algorithm, two custom nn.Module classes, ColumnParallelLinear and RowParallelLinear, wired together with one torch.distributed.all_reduce per MLP block. This is a stripped-down version of what Megatron-LM and NVIDIA TransformerEngine implement at much greater scale.

tp_mlp_pytorch.py
Process group setup (init_dist): NCCL on intra-node NVLink

torchrun spawns one process per GPU. dist.init_process_group with backend='nccl' wires those processes into a collective communication group that runs over NVLink/NVSwitch inside a node and over InfiniBand across nodes. TP traffic is dominated by all-reduce, and NCCL all-reduce over NVLink is the fastest collective on the box: 600 GB/s of bidirectional NVLink bandwidth per GPU on A100, 900 GB/s on H100, and about 400 GB/s on the export-restricted H800. This bandwidth is the whole reason TP works at all.

State: rank = 0 .. world_size-1; world_size = TP degree (4 here).
ColumnParallelLinear: owns a column slice of the output

The key line is the Parameter shape, torch.empty(d_in, self.shard): d_in stays full, and the OUTPUT dimension is sharded. Rank 0 owns columns 0..shard-1 of the full (d_in, d_out) weight, rank 1 owns the next shard, and so on. Allocating only the shard is what makes weights fit on each GPU. Megatron-LM uses exactly the same class name; the abstraction is universal.

State: self.shard = d_out // world_size; self.weight.shape = (d_in, shard).
Column-parallel forward: pure matmul, zero comm

x is replicated identically on every rank (broadcast-replicated, or arrived replicated from the previous all-reduce). Multiply x by the rank's column shard of W and you get a sharded output of shape (B*S, shard). No NCCL call. The activation that follows runs on the sharded slice, which means the activation memory is ALSO sharded, a non-trivial extra saving on top of the weight savings.

Shapes: x is (B*S, d_in), replicated; x @ weight is (B*S, shard), sharded.
RowParallelLinear: owns a row slice of the input

The mirror image: the INPUT dimension d_in is sharded. Each rank holds rows [rank*shard, (rank+1)*shard) of the full (d_in, d_out) weight. The input x to this layer must be sharded along d_in (which is exactly what ColumnParallelLinear's output gives us). Multiplying a sharded input by a row-sharded weight produces a PARTIAL sum of the full output shape; the partials live on different ranks and must be summed.

State: self.weight.shape = (shard, d_out), row-sharded.
all_reduce: the one collective per TP block

dist.all_reduce(partial, op=dist.ReduceOp.SUM) sums in place across the group; NCCL chooses the algorithm, with ring being the classic choice for large tensors. In one call, every rank's 'partial' tensor is combined element-wise with the same tensor from every other rank, and the summed result lands on every rank. With a ring, the wire volume per rank is roughly 2·(world_size-1)/world_size · |tensor| bytes, slightly less than 2× the tensor size. This is the all-reduce that bounds TP scaling: it runs in every layer, twice (once after the attention output projection, once after the MLP down-projection).

State: partial (before) = (B*S, d_out), a PARTIAL sum on this rank; partial (after) = (B*S, d_out), the FULL Y, identical on every rank.
TPMlp: the sandwich

Column-parallel up, GELU on the sharded slice, then row-parallel down plus all-reduce. This is the canonical Megatron-LM MLP. The sandwich pattern is what lets the two linear layers share a single all-reduce per MLP block: the column-parallel output flows straight into the row-parallel input with no comm in between. Building it any other way (e.g. an all-gather in the middle) adds a second collective per block.

Driver: TP=4 on a single node

world_size=4 on one DGX-class node with NVLink. The MLP weights (4096×14336 + 14336×4096 ≈ 117 M elements) take ~235 MB in BF16 unsharded and ~59 MB per rank at TP=4. These are 8B-class dimensions (Llama-3-8B, for example, uses d=4096, d_ff=14336 and 32 layers); across 32 layers, TP=4 saves ~5.6 GB of MLP weight memory per GPU compared with no sharding.

State: y.shape = (8, 4096), replicated after the final all-reduce; per-rank up weight ≈ 14.7 M elements (vs 58.7 M unsharded).
```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn

# torchrun --nproc-per-node 4 tp_mlp_pytorch.py  -- single node, 4 GPUs, TP=4
def init_dist():
    dist.init_process_group(backend="nccl")
    rank       = dist.get_rank()
    world_size = dist.get_world_size()
    # torchrun sets LOCAL_RANK; bind each process to its own GPU on this node.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    return rank, world_size

class ColumnParallelLinear(nn.Module):
    """Y = X @ W where W is split along its OUTPUT dimension across ranks."""
    def __init__(self, d_in, d_out, world_size, rank):
        super().__init__()
        assert d_out % world_size == 0
        self.shard = d_out // world_size
        # Each rank holds only its shard of the full (d_in, d_out) weight.
        self.weight = nn.Parameter(
            torch.empty(d_in, self.shard, device="cuda", dtype=torch.bfloat16)
        )
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        # X is replicated on every rank. No comm needed in forward.
        return x @ self.weight        # (B*S, d_in) @ (d_in, shard) -> (B*S, shard)

class RowParallelLinear(nn.Module):
    """Y = X @ W where W is split along its INPUT dimension; output is all-reduced."""
    def __init__(self, d_in, d_out, world_size, rank):
        super().__init__()
        assert d_in % world_size == 0
        self.shard = d_in // world_size
        self.weight = nn.Parameter(
            torch.empty(self.shard, d_out, device="cuda", dtype=torch.bfloat16)
        )
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        # x is sharded on input dim: shape (B*S, shard). Local matmul produces a
        # partial sum of full-shape output; all-reduce sums partials across ranks.
        # NOTE: dist.all_reduce is not tracked by autograd, so this layer is
        # forward-only as written; training needs Megatron-style f/g wrappers.
        partial = x @ self.weight                          # (B*S, d_out) PARTIAL
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)     # in-place all-reduce
        return partial                                     # (B*S, d_out) full

class TPMlp(nn.Module):
    def __init__(self, d, d_ff, world_size, rank):
        super().__init__()
        self.up   = ColumnParallelLinear(d,    d_ff, world_size, rank)
        self.down = RowParallelLinear   (d_ff, d,    world_size, rank)

    def forward(self, x):
        h = self.up(x)                  # (B*S, d_ff / world_size)  -- sharded
        h = torch.nn.functional.gelu(h) # activation runs on sharded slice
        return self.down(h)             # (B*S, d)  -- replicated after all-reduce

# --- driver ---
rank, world_size = init_dist()
mlp  = TPMlp(d=4096, d_ff=14336, world_size=world_size, rank=rank)
x    = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)  # B*S=8 tokens
dist.broadcast(x, src=0)        # guarantee the input really is replicated on every rank
with torch.no_grad():           # forward-only demo (see RowParallelLinear.forward)
    y = mlp(x)                  # (8, 4096)

if rank == 0:
    print(f"TP={world_size}  y.shape={tuple(y.shape)}  y.dtype={y.dtype}")
    print(f"per-rank up   weight: {mlp.up.weight.numel():>10,} elts")
    print(f"per-rank down weight: {mlp.down.weight.numel():>10,} elts")

dist.destroy_process_group()
```

At Massive Scale: Why TP Almost Always Stops at 8

The recurring number in frontier-training papers (Megatron-Turing 530B, LLaMA-70B, GPT-3) is TP = 8. Not 16, not 4. Three constraints conspire to make 8 the sweet spot.

  1. NVLink lives inside one node. A DGX-class node has 8 GPUs connected by NVLink/NVSwitch at 600+ GB/s. Cross-node links (InfiniBand or RoCE) are 50-200 GB/s — 3-12× slower. TP's all-reduce runs every layer, twice. At BF16, a 70B model at TP=8 moves ~270 MB per layer per all-reduce; on NVLink that is ~0.45 ms, on InfiniBand it would be ~5 ms — and you have 80 layers. TP that crosses node boundaries is dead on arrival.
  2. Comm per rank does not shrink with TP; compute does. The per-layer all-reduce moves $O(B \cdot d)$ elements per rank whatever the TP degree, while the per-rank FLOPs are $O(B \cdot d \cdot d_{ff} / T)$. The ratio is therefore $O(T / d_{ff})$: each doubling of TP halves per-rank compute while comm stays roughly flat, so the ratio roughly doubles. Past TP=8 on a single node, the ratio enters the danger zone where tensor cores idle on comm.
  3. TP=8 divides the head count. Attention heads (e.g. 64 for LLaMA-70B, 128 for DeepSeek-V3) are evenly divisible by 8, so column-parallel QKV gives each rank an integer number of heads. LLaMA-70B also uses grouped-query attention with 8 key/value heads, so TP=8 gives each rank exactly one KV head; going to TP=16 would force KV heads to be replicated or split, which is possible but messier and rarely worth it.
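A small check of the head-divisibility argument, assuming LLaMA-70B-style attention with 64 query heads and 8 key/value heads (grouped-query attention); `heads_per_rank` is an illustrative helper.

```python
def heads_per_rank(n_heads: int, tp: int) -> float:
    """Heads each TP rank owns when heads are split evenly across tp ranks."""
    return n_heads / tp


N_QUERY_HEADS, N_KV_HEADS = 64, 8    # LLaMA-70B-style grouped-query attention

for tp in (4, 8, 16):
    q, kv = heads_per_rank(N_QUERY_HEADS, tp), heads_per_rank(N_KV_HEADS, tp)
    note = "" if kv >= 1 else "  <- fewer than one KV head per rank: replicate KV heads"
    print(f"TP={tp:2d}: {q:g} query heads, {kv:g} KV heads per rank{note}")
```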

The 3D stack: DP × TP × PP

TP=8 alone gets you a 70B model running on 8 GPUs; the real run uses hundreds or thousands. The standard composition is:

$$\text{total GPUs} = \text{DP} \times \text{TP} \times \text{PP}$$

For a 2048-GPU DeepSeek-V3 run on H800s, a typical configuration is $\text{TP} = 1$ (DeepSeek famously avoids TP; see Section 11.7), $\text{PP} = 16$, $\text{EP} = 64$, $\text{DP} = 2$; for MoE models, EP is an extra axis carved out of data parallelism, so the product is $16 \times 64 \times 2 = 2048$. For a dense LLaMA-style 70B, $\text{TP} = 8$, $\text{PP} = 4$, $\text{DP} = 64$ on 2048 GPUs is canonical ($8 \times 4 \times 64 = 2048$). TP carves up the layer, PP carves up the layer stack, DP scales throughput. Each axis has a different cost structure; the engineering is in finding the configuration that minimises wall-clock time for the target FLOP budget.
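One way to see why TP=8, PP=4, DP=64 falls out of the arithmetic is to enumerate factorisations of 2048 GPUs and filter by memory. This is a deliberately crude sketch under stated assumptions: 16 bytes of model and optimiser state per parameter split evenly over the TP × PP group (no ZeRO), a hypothetical 40 GB state budget per 80 GB GPU to leave room for activations, TP capped at the 8 GPUs of one node, and a tie-break that prefers TP over deeper pipelines.

```python
N_PARAMS = 70e9
BYTES_PER_PARAM = 16          # BF16 weights + grads, FP32 master, two FP32 Adam moments
STATE_BUDGET = 40e9           # hypothetical: keep ~half of an 80 GB GPU for activations
TOTAL_GPUS = 2048


def candidate_configs(total_gpus):
    """Yield (TP, PP, DP) with TP * PP * DP == total_gpus, TP capped at one 8-GPU node."""
    for tp in (1, 2, 4, 8):
        rest = total_gpus // tp
        pp = 1
        while pp <= rest:
            if rest % pp == 0:
                yield tp, pp, rest // pp
            pp *= 2


def fits(tp, pp):
    # model + optimiser state split evenly over the TP x PP model-parallel group (no ZeRO)
    return N_PARAMS * BYTES_PER_PARAM / (tp * pp) <= STATE_BUDGET


fitting = [(tp, pp, dp) for tp, pp, dp in candidate_configs(TOTAL_GPUS) if fits(tp, pp)]
# Tie-break: smallest model-parallel group first (largest DP), then prefer TP over PP.
best = min(fitting, key=lambda c: (c[0] * c[1], -c[0]))
print(best)   # (8, 4, 64)
```

Real planners also model pipeline bubbles, activation memory and interconnect bandwidth; this sketch only shows how the memory constraint and the node-size cap interact.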

| Parallelism | What it splits | Comm pattern | Bandwidth need |
|---|---|---|---|
| Data (DP) | mini-batch | all-reduce on gradients, once per step | low (per step, not per layer) |
| Tensor (TP) | weights along the inner dim | all-reduce on activations, twice per layer | very high (NVLink only) |
| Pipeline (PP) | layer stack | point-to-point activation send/recv at stage boundaries | moderate |
| Expert (EP) | MoE experts | all-to-all on token routing, twice per MoE layer | high (cross-node OK) |

Engineering Reality: Costs, Failure Modes, and the 3D Stack

TP is mature. Megatron-LM has shipped it since 2019, every frontier lab uses some flavour of it, and the abstractions (ColumnParallel / RowParallel) are stable. That does not make it free.

  1. Hardware lock-in to fast intra-node fabric. TP assumes NVLink-class bandwidth between every pair of TP ranks. On consumer GPUs without NVLink, or on PCIe-only nodes, TP throughput falls off a cliff; you are better off with PP or with ZeRO-style sharding and offloading. DeepSeek-V3's decision to avoid TP entirely (Section 11.7) is partly because their H800s have NVLink throttled to ~400 GB/s vs the H100's 900 GB/s: the comm/compute ratio tipped just far enough to prefer ZeRO-style memory sharding.
  2. Subtle bugs in the gradient direction. The forward pass of a TP layer is intuitive; the backward pass requires all-reduces in places you would not expect. A column-parallel layer's input gradient needs an all-reduce before it is passed to the previous layer, because each rank only computes its own shard's contribution; forget this and every rank backpropagates a partial gradient, and training slowly diverges over thousands of steps. Megatron-LM's f / g trick (f: identity in forward, all-reduce in backward; g: all-reduce in forward, identity in backward) automates this, but only inside the framework; custom TP kernels routinely ship with this bug.
  3. RNG and dropout coordination. Activations are replicated outside TP-sharded ops and sharded inside them. Dropout needs the SAME mask on every rank in the replicated regions, but an independent mask per shard in the sharded regions. Get this wrong and either the supposedly replicated activations silently drift apart across ranks, or every shard reuses the same dropout pattern so the masks are correlated instead of independent. Either way it is another silent failure mode that only shows up as a slowly worse loss curve.
  4. Activation memory is sharded inside the block but replicated at the boundary. The sharded intermediate between column-parallel and row-parallel is $O(B \cdot S \cdot d_{ff} / T)$: TP=8 gives you 8× less activation memory in the FFN inner region. But the block's INPUT and OUTPUT are replicated, costing $O(B \cdot S \cdot d)$ on every rank. For activation checkpointing math, the right number to count is the maximum, not the sum.
  5. Composing with FP8 needs care. The amax scaling used in FP8 (Section 10.3) is computed per tensor, but with TP the "tensor" is sharded across ranks. A naive per-shard amax produces different scales on different ranks for what is logically one tensor, and the matmul then silently loses precision. The fix is an all-reduce with the MAX operator on amax before the FP8 cast: one extra (tiny) collective per TP-FP8 layer.
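Items 2 and 5 can both be demonstrated on one process with NumPy. A minimal sketch with toy sizes: the first half shows that a column-parallel layer's input gradient is only correct after summing over ranks (the backward all-reduce that Megatron's f provides); the second shows per-shard FP8 amax values disagreeing, with Python's `max` standing in for an all-reduce using the MAX operator. The 448 constant is the largest finite value of the FP8 E4M3 format; the 10× scaling of one shard is contrived to make the disagreement obvious.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d, d_ff, T = 4, 6, 8, 2
shard = d_ff // T
cols = [slice(r * shard, (r + 1) * shard) for r in range(T)]

# --- Item 2: the column-parallel input gradient needs a backward all-reduce ---
X = rng.standard_normal((B, d))
W_up = rng.standard_normal((d, d_ff))
G = rng.standard_normal((B, d_ff))          # upstream gradient dL/dH for H = X @ W_up

dX_ref = G @ W_up.T                          # unsharded dL/dX
# Rank r holds only its columns of W_up and of G, so it gets one term of the sum.
dX_partial = [G[:, c] @ W_up[:, c].T for c in cols]
print(np.allclose(dX_partial[0], dX_ref))    # False: one rank's term alone is wrong
print(np.allclose(sum(dX_partial), dX_ref))  # True: summing over ranks (f's backward all-reduce)

# --- Item 5: FP8 amax must be reduced with MAX across the TP group ---
FP8_E4M3_MAX = 448.0                         # largest finite FP8 E4M3 value
A = rng.standard_normal((B, d_ff)) * np.repeat([1.0, 10.0], shard)  # contrived; assumes T == 2
local_amax = [np.abs(A[:, c]).max() for c in cols]
local_scales = [FP8_E4M3_MAX / a for a in local_amax]
global_amax = max(local_amax)                # stands in for all_reduce(amax, op=MAX)
shared_scale = FP8_E4M3_MAX / global_amax
print(local_scales[0] != local_scales[1])    # True: per-shard scales disagree
print(shared_scale)                          # one scale for the logical tensor
```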
The deepest lesson. Data parallelism trades GPUs for throughput. Tensor parallelism trades intra-node bandwidth for the ability to run models that no single GPU can hold. Pipeline parallelism trades wall-clock time (bubbles) for the ability to run models that no single NODE can hold. None of them is free; all of them are necessary. The art of scaling training is composing the three so that no single bottleneck (HBM, NVLink, or InfiniBand) becomes the binding constraint. The next section unpacks pipeline parallelism; Section 11.5 shows how DeepSeek's DualPipe algorithm bends the pipeline-bubble curve, and Section 11.7 explains why their final answer was to abandon TP entirely.