Chapter 15

Implementing GRPO from Scratch

GRPO: Group Relative Policy Optimisation

The Real Problem: PPO's Critic Is the Bottleneck

By the end of Section 15.2 we derived GRPO as PPO with the value head replaced by an in-group baseline. In Section 15.3 we pinned the hyperparameters DeepSeek used for R1. We now turn theory into code. The question is narrow and practical: given a rollout of G responses per prompt with a scalar reward each, what is the exact tensor program that updates the policy?

The four-model dance of PPO — policy, value head, reference, reward — is a memory monster at scale. At 70B parameters each model is roughly 140 GB in bf16. The value head alone, plus its gradients, plus its AdamW state, costs the same as a second policy replica. For DeepSeek V3 at 671B the value head was simply not affordable on the cluster they had. GRPO is the answer that fell out of that constraint, and its implementation is what makes the savings real.
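The memory claim above is plain arithmetic and can be checked in a few lines. This is a rough sketch that assumes bf16 (2 bytes) for weights, gradients and both AdamW moments, which is the accounting the text uses; trainers that keep fp32 optimizer state pay even more. The function name `training_footprint_gb` is illustrative, and activations are ignored.

```python
def training_footprint_gb(n_params: float, bytes_per_value: int = 2) -> dict:
    """Rough memory for ONE trainable model, ignoring activations."""
    weights = n_params * bytes_per_value / 1e9   # GB of parameters
    grads = weights                              # one gradient per weight
    adamw = 2 * weights                          # first and second moment
    return {"weights": weights, "grads": grads, "adamw": adamw,
            "total": weights + grads + adamw}

# A 70B critic under the assumptions above.
critic = training_footprint_gb(70e9)
print(critic)
```

Under these assumptions a trainable 70B critic costs about four times its weight footprint, which is the budget GRPO removes.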

The implementation question, sharpened: PPO needs per-token advantages from a learned $V(s_t)$. GRPO needs per-response advantages from per-group reward statistics. The rest of the PPO machinery (the clipped surrogate, the KL penalty, the gradient clip) survives unchanged. The entire delta is roughly fifteen lines of code, but every one of them is the kind of line where a wrong axis or a missing mask quietly destroys a run.

Intuition: Compare Each Response to Its Siblings

Picture a class of four students answering the same hard math problem. Each turns in a different attempt. You grade them and obtain rewards $r_1, r_2, r_3, r_4$. How do you tell each student whether their answer was good?

You could grade against an absolute rubric — this is PPO with a learned value head, where the rubric is a neural net that predicts “what reward would I expect from a typical attempt at this problem?” The trouble is that the rubric itself has to be learned, and at 70B+ the rubric is as expensive as the student.

Or you can grade relative to the other students in the class: this is GRPO. The baseline for student $i$ is the average reward of the whole class, student $i$ included (the group mean $\bar r$ defined below). If you outscored the class average you did well; if you fell below it you did poorly; if everyone scored the same, the class learned nothing from this problem and we move on.

Why this works: we do not need an absolute estimate of expected reward to know which response in a group was better. The gradient only cares about the SIGN and the relative magnitude of the advantages within each group. Subtracting any group-constant baseline is theoretically a free variance-reduction trick (it does not bias the gradient); GRPO just uses the most informative group-constant available — the group's own mean.
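The "free variance reduction" claim can be checked exactly on a tiny categorical policy. The sketch below is an illustration, not part of any trainer: it enumerates every action of a 3-way softmax policy and compares the expected policy gradient with and without a constant baseline. The logits, rewards and the helper `expected_policy_gradient` are assumptions made up for the demonstration, and the baseline is treated as a constant that does not depend on the sampled action.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def expected_policy_gradient(logits, rewards, baseline):
    """Exact E_a[(r(a) - b) * d log pi(a) / d logits] for a categorical policy."""
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for a, p_a in enumerate(probs):
        for k in range(len(logits)):
            dlogp = (1.0 if k == a else 0.0) - probs[k]   # d log pi(a) / d logit_k
            grad[k] += p_a * (rewards[a] - baseline) * dlogp
    return grad

logits = [1.1, -0.2, 0.3]
rewards = [0.9, 0.3, -0.1]          # one reward per possible action
g_no_baseline = expected_policy_gradient(logits, rewards, baseline=0.0)
g_with_baseline = expected_policy_gradient(logits, rewards, baseline=0.45)
print(g_no_baseline, g_with_baseline)
```

The two gradients agree up to floating-point error because the score function has zero mean under the policy, so the baseline term vanishes in expectation.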

The GRPO Objective in Equations

Given a prompt $q$ and a group of $G$ responses $\{o_1, \dots, o_G\}$ sampled from the old policy $\pi_{\theta_{\text{old}}}$, with scalar rewards $r_1, \dots, r_G$, the group-relative advantage of response $i$ is

$$A_i \;=\; \frac{r_i - \bar r}{\sigma_r}, \qquad \bar r = \frac{1}{G}\sum_{j=1}^G r_j, \qquad \sigma_r = \sqrt{\frac{1}{G}\sum_{j=1}^G (r_j - \bar r)^2}.$$

Here $A_i$ is a SINGLE scalar broadcast to every token of $o_i$: for any token at position $t$ inside response $o_i$, the advantage used in the gradient is $A_{i,t} = A_i$. There is no per-token credit assignment beyond what the PPO clip already does.
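As a minimal sketch of the definition above, the following computes $A_i$ for one group and repeats it over each response's tokens. It uses the population standard deviation (divide by $G$) to match the formula; the small `eps` guard, the helper names and the binary example rewards are assumptions for illustration.

```python
import statistics

def group_advantages(rewards, eps=1e-8):
    """A_i = (r_i - mean) / (population std + eps) for one prompt's group."""
    mean_r = statistics.fmean(rewards)
    std_r = statistics.pstdev(rewards)          # divides by G, as in the formula
    return [(r - mean_r) / (std_r + eps) for r in rewards]

def broadcast_to_tokens(advantages, lengths):
    """Repeat each response-level advantage once per response token."""
    return [[a] * n for a, n in zip(advantages, lengths)]

rewards = [1.0, 0.0, 0.0, 1.0]                  # e.g. a binary correctness reward
adv = group_advantages(rewards)
tok_adv = broadcast_to_tokens(adv, lengths=[3, 2, 2, 2])
print(adv)
print(tok_adv)
```

With a binary reward and half the group correct, the advantages come out at roughly +1 and -1, and every token of a response carries the same value.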

The per-token probability ratio is the standard off-policy correction:

$$\rho_{i,t}(\theta) \;=\; \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}.$$

The clipped surrogate is PPO's, applied per token:

$$L_i^{\text{surr}}(\theta) \;=\; \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min\!\big( \rho_{i,t}\,A_i, \; \text{clip}(\rho_{i,t},\,1-\varepsilon,\,1+\varepsilon)\,A_i \big).$$
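The min/clip combination has four regimes, depending on the sign of the advantage and the side on which the ratio leaves the trust region. The sketch below evaluates one token in each regime and reports whether the gradient still flows (it does only when the unclipped branch is the one selected by the min). The function name and the ratio values are illustrative assumptions.

```python
def clipped_term(ratio, adv, eps=0.2):
    """Return (objective, gradient_flows) for one token of the clipped surrogate."""
    clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
    surr1 = ratio * adv                  # unclipped branch, depends on theta
    surr2 = clipped_ratio * adv          # constant in theta once the ratio is clipped
    objective = min(surr1, surr2)
    gradient_flows = surr1 <= surr2      # the unclipped branch was selected
    return objective, gradient_flows

cases = {
    "A>0, ratio high": clipped_term(1.5, +1.0),   # clipped: upside is capped
    "A>0, ratio low": clipped_term(0.5, +1.0),    # unclipped
    "A<0, ratio high": clipped_term(1.5, -1.0),   # unclipped: full penalty
    "A<0, ratio low": clipped_term(0.5, -1.0),    # clipped: no further push
}
for name, (obj, flows) in cases.items():
    print(f"{name}: objective={obj:+.2f}, gradient flows={flows}")
```

Only two of the four regimes are clipped, which is why the clip is usually described as asymmetric.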

The per-token KL penalty against the frozen reference policy uses Schulman's $k_3$ estimator, which is unbiased and sample-wise non-negative:

$$\widehat{\mathrm{KL}}_{i,t} \;=\; \exp\!\big( \log \pi_{\text{ref}} - \log \pi_\theta \big) \;-\; \big( \log \pi_{\text{ref}} - \log \pi_\theta \big) \;-\; 1.$$
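Both properties of $k_3$ can be verified exactly on a small categorical distribution. The sketch below assumes the token is drawn from $\pi_\theta$, in which case the expectation of $k_3$ equals $\mathrm{KL}(\pi_\theta \,\|\, \pi_{\text{ref}})$; the two distributions are made-up values for illustration.

```python
import math

def k3(logp_theta, logp_ref):
    x = logp_ref - logp_theta               # log(pi_ref / pi_theta)
    return math.exp(x) - x - 1.0

pi_theta = [0.6, 0.3, 0.1]
pi_ref = [0.5, 0.3, 0.2]

# Exact KL(pi_theta || pi_ref) by enumeration.
true_kl = sum(p * math.log(p / q) for p, q in zip(pi_theta, pi_ref))

# Exact expectation of k3 when tokens are drawn from pi_theta.
expected_k3 = sum(p * k3(math.log(p), math.log(q)) for p, q in zip(pi_theta, pi_ref))

# Per-sample values: k3 is never negative, the naive log-ratio (k1) can be.
k3_samples = [k3(math.log(p), math.log(q)) for p, q in zip(pi_theta, pi_ref)]
k1_samples = [math.log(p / q) for p, q in zip(pi_theta, pi_ref)]
print(true_kl, expected_k3, k3_samples, k1_samples)
```

The third token, where the reference assigns more probability than the policy, gives a negative naive estimate but a positive $k_3$ value.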

Putting it together, vanilla GRPO maximises (equivalently, minimises the negative of)

$$\mathcal{J}_{\text{GRPO}}(\theta) \;=\; \frac{1}{G} \sum_{i=1}^{G} \Big[ L_i^{\text{surr}}(\theta) \;-\; \beta \cdot \tfrac{1}{|o_i|} \sum_{t=1}^{|o_i|} \widehat{\mathrm{KL}}_{i,t} \Big].$$

Symbols, in order: $q$ the prompt; $G$ the group size; $o_i$ the $i$-th sampled response; $r_i$ its scalar reward; $A_i$ its group-relative advantage; $\rho_{i,t}$ the per-token importance ratio of the live vs. old policy; $\varepsilon$ the PPO clip range (0.2 in DeepSeek R1); $\pi_{\text{ref}}$ the frozen reference (SFT) policy; $\beta$ the per-token KL coefficient (0.04 in DeepSeek R1).

Manual Numerical Walkthrough: One GRPO Update

We trace one GRPO step on a single prompt with $G = 4$ short responses. Every intermediate number is shown; the toy is set up so that no clipping happens for most tokens and exactly one clip event happens for response 3.

Step 1 — rewards and group statistics

Rewards from the reward function: $r = (0.9,\, 0.3,\, -0.1,\, 0.7)$. Group mean and group standard deviation:

$$\bar r = (0.9 + 0.3 - 0.1 + 0.7)/4 = 0.45, \qquad \sigma_r = \sqrt{(0.45^2 + 0.15^2 + 0.55^2 + 0.25^2)/4} = \sqrt{0.1475} \approx 0.384.$$

Step 2 — per-response advantages

| response | reward $r_i$ | $r_i$ − mean | advantage $A_i$ |
|---|---|---|---|
| $o_1$ | +0.90 | +0.45 | +1.172 |
| $o_2$ | +0.30 | -0.15 | -0.391 |
| $o_3$ | -0.10 | -0.55 | -1.432 |
| $o_4$ | +0.70 | +0.25 | +0.651 |

Notice the advantages sum to zero by construction. The most-rewarded response ($o_1$) gets the largest positive advantage; the least-rewarded ($o_3$) gets the largest negative advantage.

Step 3 — per-token ratios for one response

Take response 3, which has 2 tokens. Suppose the log-ratios (live vs old) are $\log\rho_{3,1} = +0.25,\;\log\rho_{3,2} = -0.30$. Exponentiating: $\rho_{3,1} \approx 1.284,\;\rho_{3,2} \approx 0.741$. Both lie OUTSIDE the trust region $[1-\varepsilon,\,1+\varepsilon] = [0.8,\,1.2]$.

Step 4 — the asymmetric clip in action

With $A_3 = -1.432$:

| token | ratio | surr1 = r·A | clip(r)·A | min (objective) | clipped? |
|---|---|---|---|---|---|
| t=1 | 1.284 | -1.838 | 1.2·(-1.432) = -1.718 | -1.838 | no (A<0, r>1+ε → full penalty) |
| t=2 | 0.741 | -1.061 | 0.8·(-1.432) = -1.146 | -1.146 | yes (A<0, r<1-ε → bounded) |

The mean objective for response 3 is $(-1.838 - 1.146)/2 = -1.492$. The clip did exactly what it should: it kept the full, gradient-carrying penalty when the live policy over-assigned probability to a bad token ($t=1$), and it froze the term at a constant when the live policy had already moved away from the bad token by more than $\varepsilon$ ($t=2$), so that token contributes no further gradient.

Step 5 — per-token KL (Schulman k3)

For response 1, token 1, suppose $\log\pi_\theta = -0.476,\;\log\pi_{\text{ref}} = -0.600$. Then $\log(\pi_{\text{ref}}/\pi_\theta) = -0.124$ and

$$\widehat{\mathrm{KL}} = e^{-0.124} - (-0.124) - 1 = 0.8834 + 0.124 - 1 = 0.0074.$$

Small, positive: the policy has drifted only slightly from the reference. Multiplied by $\beta = 0.04$ the KL penalty per token is $\approx 0.0003$, a gentle leash, not a tight one.

Step 6 — aggregate into one scalar loss

Compute the per-response mean objective for each of the four responses (with ratios mostly near 1), negate, add $\beta \cdot \overline{\mathrm{KL}}$, and average over the group. When every ratio is close to 1, each per-response objective is close to its advantage, and the advantages sum to zero within the group, so the loss VALUE is small. The gradient is not small: each token's log-prob is still pushed up or down in proportion to its advantage.

Sanity reading of the numbers: on a fresh rollout every ratio is very close to 1, every KL term is close to 0, and the loss is small. That is healthy. A first-step ratio with $|\rho - 1| > 0.1$ typically means the rollout and the trainer are computing log-probs differently; check sampling temperature, dtype, and the prompt/response mask alignment.
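The first-step sanity check described above is easy to automate. This is a minimal sketch under stated assumptions: the log-probs are flattened per-token lists for one response, `mask` is 1 on response tokens only, and the 0.1 threshold is the one quoted in the text. The helper name and the example values are hypothetical.

```python
import math

def max_ratio_deviation(new_logp, old_logp, mask):
    """Largest |rho - 1| over the response tokens (mask == 1)."""
    worst = 0.0
    for n, o, m in zip(new_logp, old_logp, mask):
        if m:
            worst = max(worst, abs(math.exp(n - o) - 1.0))
    return worst

# Hypothetical values: two prompt tokens followed by three response tokens.
old_logp = [0.0, 0.0, -0.50, -0.61, -0.92]
new_logp = [0.0, 0.0, -0.50, -0.62, -0.91]
mask = [0, 0, 1, 1, 1]

dev = max_ratio_deviation(new_logp, old_logp, mask)
mismatch_suspected = dev > 0.1      # threshold taken from the text above
print(dev, mismatch_suspected)
```

Running a check like this before the first optimizer step tends to catch temperature, dtype and mask-alignment mismatches early, before they show up as a slowly diverging run.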

Interactive: The Group-Relative Advantage

Drag the reward sliders below to see how a group of four responses gets turned into four advantages. Notice three things in particular:

  1. If you align all four rewards to the same value every advantage collapses to zero. This group contributes nothing to the gradient — GRPO learns from disagreement inside a group, not from absolute reward.
  2. Shifting all four rewards up or down by the same amount leaves the advantages unchanged. Only the spread inside the group matters.
  3. Toggle off “divide by std” to see the Dr.GRPO variant (Section 15.5): the advantages now scale with the spread, which avoids vanilla GRPO's subtle bias toward low-variance groups.
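The three observations above can also be reproduced without the widget. The sketch below is an illustration with made-up reward groups; `divide_by_std=False` stands in for the mean-only variant the list refers to, and the `eps` guard is an assumption borrowed from common implementations.

```python
import statistics

def advantages(rewards, divide_by_std=True, eps=1e-8):
    """Group-relative advantages; divide_by_std=False gives the mean-only variant."""
    mean_r = statistics.fmean(rewards)
    centred = [r - mean_r for r in rewards]
    if not divide_by_std:
        return centred
    std_r = statistics.pstdev(rewards)
    return [c / (std_r + eps) for c in centred]

base = [0.9, 0.3, -0.1, 0.7]
narrow = [0.50, 0.52, 0.48, 0.50]     # tiny spread inside the group
wide = [1.0, 0.0, -1.0, 0.0]          # large spread inside the group

identical = advantages([0.5, 0.5, 0.5, 0.5])       # point 1: all zeros
shifted = advantages([r + 10.0 for r in base])     # point 2: same as advantages(base)
# Point 3: with std division both groups peak at the same advantage;
# without it the peak tracks the raw spread.
peak_std = (max(advantages(narrow)), max(advantages(wide)))
peak_raw = (max(advantages(narrow, False)), max(advantages(wide, False)))
print(identical, shifted, peak_std, peak_raw)
```

With std division a 0.02 reward gap and a 1.0 reward gap produce the same peak advantage, which is the low-variance effect the third point describes.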

Plain Python: GRPO Loss from Scratch

We re-implement the GRPO objective in pure NumPy on the toy rollout from the walkthrough above. No autograd, no transformer — just the loss arithmetic. Swap the hard-coded logits for a real transformer's output and the code is unchanged.

GRPO loss — pure NumPy
🐍grpo_loss_plain.py
A group of G responses sampled from the OLD policy

📚 GRPO collects a GROUP of G responses to the same prompt by sampling the frozen old policy with a non-zero temperature. Here we use G = 4 toy responses with 2–3 tokens each. In a real run, G = 16 or 64 responses of up to 1024 tokens. actions[i] is the token IDs of response i; old_logp[i] is the log-probability the old policy assigned to each of those tokens at the moment of sampling. None of these tensors carries a gradient — they are precomputed rollout state.

EXAMPLE
G = 4, response lengths = [3, 2, 2, 2], total response tokens = 9
Reference log-probabilities for the per-token KL

📚 ref_logp is the log-probability the FROZEN reference (SFT) model assigns to the same tokens. GRPO adds a per-token KL penalty against this reference, exactly like RLHF-PPO does. The reference model is loaded once at the start of training and never updated; its log-probs can be precomputed during the rollout phase to avoid re-forwarding it during the optimisation phase.

EXAMPLE
ref_logp[0] = [-0.60, -0.71, -0.82] (one entry per response-1 token)
Scalar rewards: one number per WHOLE response

📚 Crucial difference from PPO: GRPO has no per-token reward and no value head. The reward function (rule-based for math, a learned RM for general chat, or a judge LLM for open-ended tasks) consumes the ENTIRE response and returns one scalar. The same scalar becomes the advantage of EVERY token in that response — credit assignment at the token level is delegated to PPO's clip mechanism and to the group-normalisation step on the next line.

EXAMPLE
rewards = [0.9, 0.3, -0.1, 0.7] — one scalar per response, not per token
Group-relative mean and std: GRPO's signature trick

📚 PPO uses a learned V(s) to predict the expected reward, then defines advantage = reward − V(s). GRPO does NOT learn a V(s). Instead it forms the baseline ON-THE-FLY from all G responses in the same group (the response being scored included). mean_r is the group's average reward (the baseline); std_r is the group's standard deviation (the scale). The +1e-8 in std_r is the canonical 'never divide by zero' guard for degenerate groups where every response received the same reward.

EXAMPLE
mean_r = (0.9 + 0.3 - 0.1 + 0.7) / 4 = 0.45; std_r = 0.384
Per-response advantage: one scalar broadcast over all tokens

📚 A_per_response[i] = (rewards[i] - mean_r) / std_r. This is a vector of length G. The advantage of response 1 (reward 0.9, above the group mean) is positive — every token of o_1 will be encouraged. The advantage of response 3 (reward -0.1, below the group mean) is the most negative — every token of o_3 will be discouraged. If all responses in the group earn the same reward, A_per_response becomes the zero vector and the group contributes nothing to the gradient.

EXAMPLE
A_per_response = [(0.9-0.45)/0.384, (0.3-0.45)/0.384, (-0.1-0.45)/0.384, (0.7-0.45)/0.384] = [+1.172, -0.391, -1.432, +0.651]
Live-policy logits at the same response tokens

📚 For each response, new_logits[i] has shape (T_i, vocab). In a real run we re-forward the response tokens through the live policy AFTER each gradient step — the logits drift as the policy trains. We hard-code plausible logits here so the walkthrough is reproducible. The vocab here is just 3 actions; in a real LLM the vocab is 32k–256k.

EXAMPLE
new_logits[0].shape = (3, 3) — three tokens, three-action vocab
Outer loop: one set of GRPO terms per response in the group

📚 GRPO is per-response, then per-token within each response. We iterate over the G responses and compute three things for each: per-token ratios r_t, the clipped surrogate, and the per-token KL k3 estimator. Because each response can have a different length T_i, padding-free per-response loops are clearer than the (B, S) tensorised version. The PyTorch version below recovers the tensorised form with masks.

Numerically stable log-softmax: the standard log-sum-exp trick

📚 Subtract the per-row max before exponentiating to keep exp() in a safe range, then take log of the normalised sum. Without this, exp() overflows to inf once a raw logit exceeds about 88 in float32 (about 709 in float64, about 11 in float16) and downstream values become NaN. log_pall[t, v] is now log pi_new(action v | state at token t) for every position of response i.

EXAMPLE
logits row [1.10, -0.20, 0.30] → log-probs row [-0.543, -1.843, -1.343] (the exponentials sum to 1 by construction)
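To make the overflow concrete, here is a pure-Python sketch of the naive and the max-shifted log-softmax. It is an illustration rather than the listing's NumPy code: Python's `math.exp` raises `OverflowError` where NumPy would return `inf` with a warning, and the +1000 offset is an artificial way to provoke the failure in float64.

```python
import math

def log_softmax_naive(logits):
    exps = [math.exp(z) for z in logits]        # overflows for large z
    log_total = math.log(sum(exps))
    return [z - log_total for z in logits]

def log_softmax_stable(logits):
    m = max(logits)
    shifted = [z - m for z in logits]           # largest entry becomes 0
    log_total = math.log(sum(math.exp(z) for z in shifted))
    return [z - log_total for z in shifted]

row = [1.10, -0.20, 0.30]
small = log_softmax_stable(row)

big = [z + 1000.0 for z in row]                 # same distribution, huge logits
try:
    log_softmax_naive(big)
    naive_ok = True
except OverflowError:
    naive_ok = False
shifted_result = log_softmax_stable(big)
print(small, naive_ok, shifted_result)
```

Adding a constant to every logit leaves the softmax unchanged, so the stable version returns the same log-probs for `row` and `big` while the naive one fails on `big`.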
Gather log pi_new at the actions that were ACTUALLY taken

📚 log_pall has shape (T_i, vocab). We pick out log_pall[t, actions[i][t]] for every t, exactly one log-prob per token. After this line, new_logp_i has shape (T_i,) and stores the new policy's log-probability of the token the old policy sampled at each position.

EXAMPLE
actions[0] = [0, 2, 1]  →  new_logp_i = [log_pall[0,0], log_pall[1,2], log_pall[2,1]] = [-0.543, -0.460, -0.453]
Probability ratio: the heart of off-policy correction

📚 r_t = pi_new(a_t | s_t) / pi_old(a_t | s_t) = exp(log pi_new − log pi_old). Compute the difference in LOG space first (numerically safe), then exponentiate the small scalar. r_t = 1 means the live policy is unchanged at this (s, a); r_t > 1 means it now assigns this action MORE probability than the old policy did; r_t < 1 means LESS. GRPO uses this ratio inside the same PPO clipped surrogate.

EXAMPLE
old_logp[0][0] = -0.50, new_logp_i[0] = -0.543  →  log_ratio[0] = -0.043  →  ratio[0] = exp(-0.043) = 0.958
The PPO clipped surrogate: exact same shape as PPO

📚 GRPO inherits PPO's clipped surrogate verbatim. surr1 = r_t · A is the unclipped objective; surr2 = clip(r_t, 1-ε, 1+ε) · A bounds the ratio inside the trust region BEFORE multiplying by A. The crucial difference from PPO is that A is a single scalar broadcast over every token of the response: the SAME number for every t in o_i. That makes A constant inside the inner loop and reduces the per-token computation to just the ratio and the clip.

EXAMPLE
For o_1 (A=+1.172): ratio=[0.958, 1.161, 1.594]  →  surr1=[1.122, 1.361, 1.868]; the third ratio exceeds 1+ε = 1.2, so surr2=[1.122, 1.361, 1.406] and obj = min = [1.122, 1.361, 1.406].
Take the MIN: PPO's pessimistic bound that GRPO inherits

📚 element-wise: obj[t] = min(surr1[t], surr2[t]). For A > 0 the min caps how much the policy can be rewarded for already-probable good actions; for A < 0 the min caps how much the policy can be penalised for already-rare bad actions. Same asymmetric safety brake as PPO — no math here is GRPO-specific. What IS GRPO-specific is that A came from the group mean, not from a value head.

EXAMPLE
For o_3 (A=-1.432), using the walkthrough's hypothetical ratios [1.284, 0.741] rather than the ones this listing produces: surr1=[-1.838, -1.061]; surr2=[1.2·(-1.432), 0.8·(-1.432)]=[-1.718, -1.146]; min = [-1.838, -1.146]. Token 0 is unclipped (A<0, r>1+ε → full penalty); token 1 is clipped (A<0, r<1-ε → bounded penalty).
Schulman k3 unbiased KL estimator: the standard GRPO KL

📚 The naive sample-based KL estimator log_ratio = log_new − log_ref is unbiased but has high variance. Schulman's k3 estimator k3 = exp(−log_ratio) − (−log_ratio) − 1 is also unbiased AND non-negative for every sample (the naive one can be negative on individual samples). DeepSeek R1 and trl's GRPOTrainer use k3. log_r_ref here is log(pi_ref) − log(pi_new), so the formula matches the standard form.

EXAMPLE
new_logp_i[0] = -0.543, ref_logp_i[0] = -0.60  →  log_r_ref = -0.057  →  k3 = exp(-0.057) - (-0.057) - 1 ≈ 0.0016
Per-response mean of objective and KL: token-uniform weighting

📚 Vanilla GRPO averages the surrogate and the KL over the tokens of EACH response, then averages those response-level means over the group. The .mean() inside the loop is the per-response token average. We negate obj (PPO objective is a thing to MAXIMISE; loss is the negative) and add the KL term with coefficient BETA_KL. BETA_KL = 0.04 is the DeepSeek R1 default. Stacking these per-response losses gives one number per response that we will aggregate next.

EXAMPLE
For o_1: obj.mean() = (1.122+1.361+1.406)/3 = 1.296; kl_tok.mean() ≈ 0.030  →  per_response_loss[0] = -1.296 + 0.04 * 0.030 = -1.295
Outer aggregate: average over the group of G responses

📚 The final scalar loss is the uniform mean over the per-response losses. Every response counts the same regardless of length. This is the vanilla GRPO objective from the DeepSeekMath paper. Dr.GRPO (Section 15.5) changes the length normalisation: it drops the per-response 1/T_i divisor in favour of a constant one. Under vanilla GRPO each token of a short response carries more weight than each token of a long one, because the response's surrogate is averaged over fewer tokens.

EXAMPLE
per_response_loss ≈ [-1.295, +0.537, +1.918, -0.740]  →  loss = mean ≈ 0.105. If every ratio were exactly 1 the per-response objectives would equal the advantages and cancel to zero; the hard-coded logits have drifted (ratios range from 0.86 to 2.09), so the clip and the min break that cancellation.
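To see why response length matters at this aggregation step, the sketch below computes the weight a single token ends up with in the final loss. It assumes the vanilla scheme (token mean per response, then uniform mean over the group) and, for contrast, a constant divisor of the kind the Dr.GRPO variant uses; `token_weights` and the choice of `max(lengths)` as the constant are illustrative assumptions.

```python
def token_weights(lengths, constant_divisor=None):
    """Weight one token of each response receives in the final scalar loss."""
    G = len(lengths)
    weights = []
    for T in lengths:
        divisor = constant_divisor if constant_divisor is not None else T
        weights.append(1.0 / (G * divisor))
    return weights

lengths = [3, 2, 2, 2]                                   # the toy group's response lengths
vanilla = token_weights(lengths)                         # 1 / (G * T_i)
constant = token_weights(lengths, constant_divisor=max(lengths))
print(vanilla)    # tokens of the 3-token response weigh less than the others
print(constant)   # every token weighs the same
```

Under the vanilla scheme a token in a 2-token response carries 1.5 times the weight of a token in the 3-token response; with a constant divisor the per-token weight no longer depends on the response's own length.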
```python
"""
GRPO objective: pure NumPy, no autograd, no value head.

For one PROMPT we collected a GROUP of G responses with the OLD policy.
Each response has been scored by a REWARD function (rule-based, judge,
or learned RM) into a single scalar. That scalar advantage will be
broadcast over every token of that response.

The 'policy' here is a single-token categorical over a 3-action vocab
so every number is readable. Swap in a transformer logits head and
the loss arithmetic is the same one GRPO trainers use.
"""

import numpy as np

# ---------------------------------------------------------------------------
# 1. One prompt, G = 4 responses sampled from the OLD policy.
#    Each response is a token sequence; rewards is one scalar per response.
# ---------------------------------------------------------------------------

#               o_1            o_2          o_3       o_4
actions = [[0, 2, 1],       [1, 0],      [2, 2],   [0, 1]]
old_logp = [
    [-0.50, -0.61, -0.92],  [-1.20, -0.41],  [-0.81, -0.69], [-0.51, -1.10],
]
ref_logp = [
    [-0.60, -0.71, -0.82],  [-1.10, -0.51],  [-0.91, -0.79], [-0.41, -1.00],
]
rewards = np.array([0.9, 0.3, -0.1, 0.7])   # one scalar per response
G = len(rewards)

# ---------------------------------------------------------------------------
# 2. Group-relative advantage: GRPO's signature move.
#    Replace V(s) with the in-group mean and std of rewards.
#    A_i is the SAME number for every token of response i.
# ---------------------------------------------------------------------------

mean_r = rewards.mean()
std_r  = rewards.std() + 1e-8
A_per_response = (rewards - mean_r) / std_r          # shape (G,)

# ---------------------------------------------------------------------------
# 3. The NEW (live) policy's logits at the SAME states. In a real run we
#    re-forward each response's tokens through the live model after every
#    gradient step. Here we hard-code plausible "slightly drifted" logits.
# ---------------------------------------------------------------------------

new_logits = [
    np.array([[ 1.10, -0.20,  0.30],     # o_1, token 0
              [-0.40,  0.80,  1.60],     # o_1, token 1
              [ 0.10,  1.40,  0.20]]),   # o_1, token 2
    np.array([[-0.20,  1.10,  0.40],     # o_2, token 0
              [ 0.90, -0.30,  0.10]]),   # o_2, token 1
    np.array([[ 0.10,  0.20,  1.30],     # o_3, token 0
              [-0.10,  0.40,  1.50]]),   # o_3, token 1
    np.array([[ 1.20, -0.40,  0.10],     # o_4, token 0
              [-0.30,  1.50,  0.20]]),   # o_4, token 1
]

# ---------------------------------------------------------------------------
# 4. Per-token GRPO terms for each response.
# ---------------------------------------------------------------------------

EPSILON   = 0.2          # PPO-style clip range
BETA_KL   = 0.04         # per-token KL coefficient (DeepSeek R1 default)

per_response_loss = []
per_response_kl   = []

for i in range(G):
    # log pi_new at each token via stable log-softmax
    z         = new_logits[i] - new_logits[i].max(axis=-1, keepdims=True)
    log_pall  = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    new_logp_i = log_pall[np.arange(len(actions[i])), actions[i]]   # (T_i,)

    old_logp_i = np.array(old_logp[i])
    ref_logp_i = np.array(ref_logp[i])

    # Probability ratio in log space, then exponentiate.
    log_ratio = new_logp_i - old_logp_i
    ratio     = np.exp(log_ratio)

    # PPO clipped surrogate: A is the SAME scalar across all tokens of o_i.
    A         = A_per_response[i]
    surr1     = ratio * A
    surr2     = np.clip(ratio, 1.0 - EPSILON, 1.0 + EPSILON) * A
    obj       = np.minimum(surr1, surr2)

    # Schulman k3 unbiased KL estimator vs the REFERENCE policy.
    # k3 = exp(log_ref - log_new) - (log_ref - log_new) - 1   >= 0
    log_r_ref = ref_logp_i - new_logp_i
    kl_tok    = np.exp(log_r_ref) - log_r_ref - 1.0

    # Per-RESPONSE mean: equal weight to every token of o_i.
    per_response_loss.append(-obj.mean() + BETA_KL * kl_tok.mean())
    per_response_kl.append(kl_tok.mean())

# ---------------------------------------------------------------------------
# 5. Final scalar loss = uniform mean over the GROUP of responses.
#    (Vanilla GRPO: every response gets the same weight regardless of
#    length. Dr.GRPO removes the per-response length normalisation; we
#    show that variant in the engineering-reality section below.)
# ---------------------------------------------------------------------------

loss = np.mean(per_response_loss)

print(f"advantages   = {A_per_response.round(3)}")
# float() keeps the printed list free of NumPy scalar wrappers.
print(f"per-resp KL  = {[round(float(k), 4) for k in per_response_kl]}")
print(f"loss         = {loss:.4f}")
```

PyTorch: A Production-Shaped GRPO Step

Now the full thing: tensorised across the (B, G, S) layout that every production GRPO trainer uses, with the group-relative normalisation as a single mean/std along the G axis, gradient clipping, and the diagnostics that go straight onto a training dashboard. Drop in any HuggingFace causal LM for policy and a frozen copy of the SFT checkpoint for ref_policy and this trains a real reasoning model.

GRPO step — clipped surrogate + per-token KL
🐍grpo_step_torch.py
The three hyperparameters that define a GRPO run

📚 EPSILON = 0.2 is the PPO clip range that GRPO inherits unchanged. BETA_KL = 0.04 is the per-token KL coefficient against the frozen reference (SFT) policy; DeepSeek R1 used exactly this value. MAX_GRAD = 1.0 is the usual grad-norm clip default; without it a single long, noisy sequence can produce an update large enough to destabilise the run. Notice what is MISSING from this list versus PPO: no VF_COEF (no value head to weight), no ENT_COEF (GRPO does not use an entropy bonus by default; the in-group reward diversity provides exploration pressure).

The GRPO step signature: minibatch, two models, one optimizer

📚 Only TWO models live in memory during the update: 'policy' (trainable, being updated) and 'ref_policy' (frozen SFT model, used only to compute ref_logp during rollout). A learned reward model, if the task uses one instead of a rule-based verifier, is needed only at rollout time. PPO needed FOUR (policy + value + reference + reward model). Eliminating the value head is the entire reason GRPO exists: at 70B+ that saves ~140 GB of weights, ~140 GB of value-head gradients, and ~280 GB of AdamW optimiser state (all counted in bf16), for a total of ~560 GB of GPU memory per replica. That is the practical reason DeepSeek picked GRPO for R1.

Three-dimensional rollout layout: (B, G, S) is the GRPO shape

📚 PPO operates on (B, S). GRPO inserts a G axis: B prompts × G responses per prompt × S tokens per response. This layout makes the group-relative computation a single mean/std along dim=1. action_mask is 1 on the response tokens and 0 everywhere else (prompt, padding). rewards has only two dims (B, G) — there is one scalar per response, not per token. Getting this shape right is the most common GRPO bug: the (B, G) vs (B*G,) collapse breaks group normalisation.

EXAMPLE
B = 8 prompts, G = 16 responses, S = 1024 tokens  →  input_ids.shape = (8, 16, 1024); rewards.shape = (8, 16)
Group-relative mean: across the G responses of EACH prompt

📚 dim=1 is the G axis. mean(dim=1, keepdim=True) gives a (B, 1) tensor: one mean per prompt, not one mean across the entire batch. This is the critical distinction from PPO's advantage normalisation, which is a SINGLE mean/std across the whole minibatch. GRPO's mean is LOCAL to the prompt; it is the prompt's own baseline. In a distributed run, keep all G responses of a prompt on one rank (or gather that group's rewards before normalising) so the reduction still runs along dim=1; the per-prompt mean must never be averaged across prompts.

EXAMPLE
rewards.shape = (8, 16); grp_mean.shape = (8, 1)  →  prompt 0's baseline is the mean of its own 16 responses
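The difference between per-prompt normalisation and the accidental (B*G,) collapse is easiest to see on a tiny batch. The sketch below uses nested Python lists in place of tensors; the reward values and helper name are hypothetical.

```python
import statistics

def normalise(values, eps=1e-8):
    mean_v = statistics.fmean(values)
    std_v = statistics.pstdev(values)
    return [(v - mean_v) / (std_v + eps) for v in values]

# rewards[b][g]: B = 2 prompts, G = 4 responses each (hypothetical values).
rewards = [
    [0.9, 0.8, 1.0, 0.7],     # easy prompt: every response scores high
    [0.1, 0.0, 0.2, 0.3],     # hard prompt: every response scores low
]

per_prompt = [normalise(group) for group in rewards]     # GRPO: statistics along G
flat = [r for group in rewards for r in group]
flat_norm = normalise(flat)                              # bug: one mean/std for B*G
collapsed = [flat_norm[0:4], flat_norm[4:8]]
print(per_prompt)
print(collapsed)
```

With per-prompt statistics each group contains both positive and negative advantages. After the collapse every response to the easy prompt is rewarded and every response to the hard prompt is punished, so the signal measures prompt difficulty rather than response quality.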
Group-relative std: the per-prompt scale

📚 .std(dim=1, keepdim=True).clamp(min=1e-8) handles the degenerate case where all G rewards on a prompt are identical (the reward function returned the same number for every response, e.g. all wrong on a math problem). Without the clamp, std would be zero and the advantage would be 0/0 = NaN. The clamp lets such a group contribute exactly zero gradient (every advantage is 0 / 1e-8 = 0), which is the desired behaviour. Note that torch.std applies Bessel's correction by default (it divides by G − 1), whereas the equation above and NumPy's .std() divide by G; pass correction=0 (unbiased=False on older versions) if the two implementations must match exactly.

EXAMPLE
For a math prompt where 16 of 16 responses are wrong (reward=0): mean=0, std=0 (clamped to 1e-8), every advantage = 0, the group contributes no gradient (as it should: there is nothing to learn from)
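One detail worth checking when porting between NumPy and PyTorch is which standard deviation is in use: `np.std` divides by G by default, while `torch.std` divides by G − 1 by default. The sketch below reproduces both with the standard library on the walkthrough's rewards; it is an illustration of the size of the discrepancy, not a statement about which convention a given trainer uses.

```python
import statistics

rewards = [0.9, 0.3, -0.1, 0.7]
G = len(rewards)

population_std = statistics.pstdev(rewards)   # divides by G,     like np.std's default
sample_std = statistics.stdev(rewards)        # divides by G - 1, like torch.std's default
scale = sample_std / population_std           # equals sqrt(G / (G - 1))
print(population_std, sample_std, scale)
```

For G = 4 the two differ by about 15%, which rescales every advantage by the same factor; the gap shrinks as G grows.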
Broadcast the response-level advantage over all tokens

📚 adv has shape (B, G). We .unsqueeze(-1) to (B, G, 1), then .expand to (B, G, S). expand is a view-only operation in PyTorch — no new memory is allocated, the last dimension just reads the same value S times. After this, adv[b, g, t] is the SAME scalar for every t — exactly what GRPO requires. (If you replace expand with repeat you allocate B·G·S floats; in production runs this matters.)

EXAMPLE
adv before: shape (8, 16); after unsqueeze + expand: shape (8, 16, 1024); adv[0, 3, :] is one repeated scalar = (rewards[0,3] - grp_mean[0]) / grp_std[0]
Forward the live policy on the FLATTENED (B*G) batch

📚 The transformer forward expects a 2-D (batch, seq) input. We flatten the (B, G) groups into B*G rows for the forward pass, then reshape the logits back to (B, G, S-1, V). use_cache=False is the right setting during training: a KV-cache only helps incremental decoding, wastes memory on a full-sequence forward pass, and is incompatible with gradient checkpointing in HuggingFace models. The logits slice [:, :-1, :] is the standard 'predict t+1 from <=t' alignment. With V=128k, S=1024, B=8, G=16 the logits tensor holds about 16.8 billion values, roughly 34 GB in bf16 regardless of model size; many implementations chunk or rematerialise the logits (for example via gradient checkpointing) to cut this peak.

EXAMPLE
input_ids.view(8*16, 1024) = (128, 1024)  →  out.logits = (128, 1024, 128256)  →  slice [:, :-1, :] and reshape to (8, 16, 1023, 128256)
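The size of the logits tensor follows directly from its shape. A quick back-of-the-envelope sketch, assuming 2 bytes per value (bf16) and the shapes from the example above; the helper name is illustrative.

```python
def logits_bytes(B, G, S, V, bytes_per_value=2):
    """Bytes held by the (B, G, S-1, V) logits tensor."""
    return B * G * (S - 1) * V * bytes_per_value

n_values = 8 * 16 * (1024 - 1) * 128256
size_gb = logits_bytes(B=8, G=16, S=1024, V=128256) / 1e9
print(n_values, round(size_gb, 1))
```

The count depends only on batch shape and vocabulary size, not on the parameter count of the model, and it doubles again if the log-softmax is materialised in fp32.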
log_softmax → gather: per-token log pi_new at the sampled actions

📚 Two-step standard. log_softmax should run in fp32: under torch.autocast it is up-cast automatically, otherwise call .float() on the logits (or pass dtype=torch.float32) first. It gives the full distribution at every position. gather with the response-token IDs picks out one log-prob per (b, g, t). After this line new_logp has shape (B, G, S-1) and stores the LIVE policy's log-probability of every response token. This same tensor is reused in the ratio, the surrogate, and the KL: one forward pass, three uses.

log-space ratio first, exponentiate second: never the other way

📚 log_ratio = log pi_new − log pi_old is bounded and numerically safe; small differences stay small. Exponentiating the difference gives the actual ratio centred around 1. Computing ratio = exp(log_new) / exp(log_old) directly can underflow both numerator and denominator to zero on long sequences (the log-probs themselves can be -100 or worse) and you lose precision exactly where ratio matters most — near 1.0.

EXAMPLE
Typical first-step values: log_ratio in [-0.05, +0.05]  →  ratio in [0.951, 1.051]
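The underflow described above can be reproduced with plain Python floats. Python uses float64, so the log-probs in this sketch are pushed to about -800 to trigger it; in fp32 or bf16 the same failure appears at much milder values. The specific numbers are made up for the demonstration.

```python
import math

log_new, log_old = -800.00, -800.01      # extreme log-probs chosen to underflow float64

num, den = math.exp(log_new), math.exp(log_old)   # both underflow to 0.0
try:
    naive_ratio = num / den
except ZeroDivisionError:
    naive_ratio = float("nan")           # array libraries typically yield nan for 0/0

safe_ratio = math.exp(log_new - log_old) # subtract first, exponentiate second
print(num, den, naive_ratio, safe_ratio)
```

The log-space version recovers a ratio of about 1.01, while the direct division has already lost all information.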
The clipped surrogate: torch.clamp + torch.min, element-wise

📚 surr1 = ratio * adv (unclipped); surr2 = clamp(ratio, 1-ε, 1+ε) * adv (clipped). torch.min picks the lower (more pessimistic) of the two element-wise. This is the same PPO clipped surrogate from §14.6, applied per token, with the advantage broadcast from the response level. The asymmetric safety brake still holds: for A > 0 we cap the upside (cannot reward a runaway ratio); for A < 0 we cap the downside (cannot reward a collapsing ratio).

EXAMPLE
Worked example: A = +1.172, ratio = 1.105 (in range)  →  surr1 = 1.295, surr2 = 1.295, obj = 1.295.  A = -1.432, ratio = 1.284 (above +ε)  →  surr1 = -1.838, surr2 = -1.718, obj = min = -1.838 (full penalty, unclipped)
Schulman k3 KL estimator: log(ref/new) form

📚 log_r_ref = log pi_ref − log pi_new = log(pi_ref / pi_new). The k3 estimator k3 = exp(log_r_ref) − log_r_ref − 1 is unbiased and always non-negative, which is desirable because the naive log-ratio estimator can be negative on individual samples (only unbiased in expectation). DeepSeek R1, trl, and OpenRLHF use this form. The KL is per-token and is summed/masked together with the policy loss in the next block.

EXAMPLE
If log_ref = -0.60, log_new = -0.476 (model slightly more confident than reference): log_r_ref = -0.124  →  k3 = exp(-0.124) - (-0.124) - 1 = 0.8834 + 0.124 - 1 = 0.0074
Token-count normalisation per response: the vanilla GRPO mean

📚 Sum the objective and the KL over the response's tokens (with the action_mask multiplied in to zero out prompt/pad positions), then divide by that response's true token count. This is per-RESPONSE mean weighted equally over tokens. Critically, the divisor is the SUM of the mask, not S — getting that count right is the single biggest pitfall, because a wrong divisor silently changes the effective KL coefficient by a factor of (S / token_count) and the run looks fine until step ~500.

EXAMPLE
For a response of true length 87 inside an S=1024 padded slot: tok_count = 87 (not 1024); obj_per_resp = (obj * am).sum() / 87
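The divisor pitfall is easy to demonstrate on one padded row. In this sketch `kl_tok` and `mask` are hypothetical per-token values for a slot of length S = 8 holding a 3-token response; `masked_mean` is an illustrative helper, not a library function.

```python
def masked_mean(values, mask):
    """Mean over positions where mask == 1; the divisor is sum(mask), not len(values)."""
    total = sum(v * m for v, m in zip(values, mask))
    count = max(sum(mask), 1)               # guard against an all-zero mask
    return total / count

S = 8
kl_tok = [0.0, 0.0, 0.02, 0.04, 0.03, 0.0, 0.0, 0.0]   # hypothetical per-token k3 values
mask = [0, 0, 1, 1, 1, 0, 0, 0]                         # 2 prompt, 3 response, 3 pad

correct = masked_mean(kl_tok, mask)                     # divides by 3
wrong = sum(v * m for v, m in zip(kl_tok, mask)) / S    # divides by 8
shrink = correct / wrong                                # = S / token_count
print(correct, wrong, shrink)
```

Dividing by S instead of the token count shrinks the KL term by S / token_count, which is the same as silently lowering the KL coefficient by that factor.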
Group + batch aggregation: uniform mean over (B, G)

📚 .mean() over the (B, G) tensor weights every response equally. Vanilla GRPO does this exact uniform mean — every prompt and every response inside each group contribute the same. The Dr.GRPO variant (Section 15.5) keeps this same outer mean but moves the per-response length normalisation away from a 1/T_i divisor to a 1/S_max divisor, removing a subtle bias that pushes vanilla GRPO toward longer responses. We discuss that one-line change in the engineering-reality block below.

Joint loss: policy + per-token KL, NO value loss, NO entropy bonus

📚 Compare to PPO §14.6 where the loss had four terms (policy + value + KL + entropy). GRPO has TWO. No value loss because there is no value head. No entropy bonus because the in-group reward variation already supplies exploration signal — every response in the group is a different rollout. Removing these two terms is what makes GRPO simpler to tune than PPO; the KL coefficient and the clip range are essentially the only knobs that matter.

EXAMPLE
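Illustrative early-training values: policy_loss ≈ -0.05, kl_loss ≈ 0.001  →  loss ≈ -0.05 + 0.04*0.001 = -0.04996. (On the very first update, with ratio = 1 everywhere and the policy equal to the reference, both terms are exactly zero because the normalised advantages sum to zero within each group.)
One gradient step + global grad-norm clip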

📚 loss.backward() populates parameter gradients (only on the policy: ref_policy was forwarded with no_grad during rollout and never appears in this graph). clip_grad_norm_ caps the global L2 norm at MAX_GRAD = 1.0. Without this, a single noisy minibatch can throw the policy into a divergent region from which it never recovers. optimizer.zero_grad(set_to_none=True) frees the gradient tensors rather than zeroing them in place, saving one tensor's worth of memory per parameter between steps. Since PyTorch 2.0 this is the default behaviour of zero_grad(); passing the flag explicitly keeps the code correct on older versions.
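What "global" means in the clip is easy to miss: one norm is computed over all parameters together, and every gradient is scaled by the same coefficient. The plain-Python sketch below mirrors that behaviour on lists of floats; the function name `clip_global_norm` and the small `eps` are assumptions made for illustration.

```python
import math

def clip_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by one shared coefficient so the joint L2 norm
    does not exceed max_norm. Returns (clipped_grads, norm_before_clip)."""
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    coef = min(1.0, max_norm / (total + eps))
    return [[g * coef for g in vec] for vec in grads], total

# Two "parameters" with made-up gradients; the joint norm is 13.
big = [[3.0, 4.0], [12.0]]
clipped_big, norm_big = clip_global_norm(big, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped_big for g in vec))

# Gradients already inside the limit are returned unchanged.
small = [[0.3, 0.4]]
clipped_small, norm_small = clip_global_norm(small, max_norm=1.0)
```

The returned pre-clip norm is the quantity worth logging: a value that sits far above the limit for many consecutive steps means nearly every update is being rescaled.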

Diagnostics: what every GRPO dashboard plots per step

📚 approx_kl is the simple log-ratio estimate of KL between new and old policies (the off-policy distance) — distinct from kl_loss above which is KL to the REFERENCE. clipfrac is the fraction of tokens whose ratio fell outside [1-ε, 1+ε]; the standard healthy range is 0.05–0.25. adv_std is the standard deviation of the broadcast advantage tensor across response tokens — useful for noticing when groups are all the same (adv_std collapses to ~0 → the model has saturated this prompt distribution). Every production GRPO run logs these three plus the joint loss every step.
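The two ratio-based diagnostics reduce to a few lines of masked arithmetic. This is a plain-Python sketch on four made-up tokens (the last one is padding); the function name `ratio_diagnostics` is hypothetical.

```python
import math

def ratio_diagnostics(old_logp, new_logp, mask, eps=0.2):
    """Return (approx_kl, clipfrac) over the unmasked tokens."""
    n = max(sum(mask), 1)
    approx_kl = sum((o - nw) * m for o, nw, m in zip(old_logp, new_logp, mask)) / n
    clipped = sum(
        m for o, nw, m in zip(old_logp, new_logp, mask)
        if abs(math.exp(nw - o) - 1.0) > eps
    )
    return approx_kl, clipped / n

old_logp = [-1.0, -2.0, -0.50, -3.0]
new_logp = [-1.0, -1.7, -0.55, -9.0]       # last position is padding garbage
mask     = [1,    1,    1,     0]

approx_kl, clipfrac = ratio_diagnostics(old_logp, new_logp, mask)
```

Only the second token has a ratio outside [0.8, 1.2] (exp(0.3) is about 1.35), so `clipfrac` is 1/3. The padded position would have dominated both numbers had the mask been omitted. Note that `approx_kl` comes out slightly negative here, which is the sample-level behaviour of the simple log-ratio estimate.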

"""
PyTorch GRPO step: the inner loop of every modern verifier-based RL
trainer (DeepSeek R1, Olmo 3, Qwen-Math-RL, AceMath-RL).

A rollout is a list of PROMPTS, each rolled out to G responses with
the OLD policy (frozen snapshot). For each token of each response we have:
   - old_log_prob   : log pi_old(a_t | s_t)            (no gradient)
   - ref_log_prob   : log pi_ref(a_t | s_t)            (no gradient)
For each response we have:
   - reward         : scalar from the reward function
GRPO turns rewards into advantages by group-mean / group-std within each
prompt's group, then runs PPO's clipped surrogate token-wise.
"""

import torch
import torch.nn.functional as F

EPSILON   = 0.2          # symmetric PPO clip range (DAPO's clip-higher variant: 0.2 lo, 0.28 hi)
BETA_KL   = 0.04         # per-token KL to reference policy
MAX_GRAD  = 1.0          # global grad-norm clip

def grpo_step(batch, policy, ref_policy, optimizer):
    """
    batch contains tensors laid out as (B, G, S) where:
        B = number of prompts in the minibatch
        G = number of sampled responses per prompt (G >= 2)
        S = max sequence length, prompt + response (right-padded)

    Fields:
        input_ids       : (B, G, S)   prompt + response tokens
        attn_mask       : (B, G, S)
        action_mask     : (B, G, S)   1 on RESPONSE tokens, 0 on prompt/pad
        old_logp        : (B, G, S)   log pi_old at each token
        ref_logp        : (B, G, S)   log pi_ref at each token
        rewards         : (B, G)      one scalar per response

    ref_policy is not called here: ref_logp was precomputed at rollout time.
    """
    input_ids   = batch["input_ids"]
    attn_mask   = batch["attn_mask"]
    action_mask = batch["action_mask"]
    old_logp    = batch["old_logp"]
    ref_logp    = batch["ref_logp"]
    rewards     = batch["rewards"].float()

    B, G, S = input_ids.shape

    # ---- 1. Group-relative advantage ------------------------------------
    # Mean and std are computed PER PROMPT, across the G responses.
    # torch.std defaults to the unbiased (n - 1) estimator.
    grp_mean = rewards.mean(dim=1, keepdim=True)                # (B, 1)
    grp_std  = rewards.std(dim=1, keepdim=True).clamp(min=1e-8)
    adv      = (rewards - grp_mean) / grp_std                   # (B, G)
    adv      = adv.unsqueeze(-1).expand(B, G, S)                # broadcast

    # ---- 2. Forward the LIVE policy on the full sequences ---------------
    # Flatten the (B, G) groups into a single batch for the forward pass.
    # reshape (not view) also works if the batch tensors are non-contiguous.
    ids_flat = input_ids.reshape(B * G, S)
    am_flat  = attn_mask.reshape(B * G, S)
    out      = policy(input_ids=ids_flat, attention_mask=am_flat,
                      use_cache=False)
    # Logits at position t predict token t+1, so drop the last position
    # and shift every per-token tensor left by one.
    logits   = out.logits[:, :-1, :].reshape(B, G, S - 1, -1)
    targets  = input_ids[:, :, 1:]                              # (B, G, S-1)
    am       = action_mask[:, :, 1:].float()
    adv      = adv[:, :, 1:]

    log_probs_all = F.log_softmax(logits, dim=-1)
    new_logp = log_probs_all.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # ---- 3. Clipped surrogate -------------------------------------------
    log_ratio = new_logp - old_logp[:, :, 1:]
    ratio     = log_ratio.exp()

    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1 - EPSILON, 1 + EPSILON) * adv
    obj   = torch.min(surr1, surr2)

    # ---- 4. Per-token KL to the reference policy (k3) -------------------
    log_r_ref = ref_logp[:, :, 1:] - new_logp                   # log(ref/new)
    kl_tok    = log_r_ref.exp() - log_r_ref - 1.0               # >= 0

    # ---- 5. Per-response means, then GROUP mean -------------------------
    # Sum within each response, divide by the response's token count.
    tok_count = am.sum(dim=-1).clamp(min=1.0)                   # (B, G)
    obj_per_resp = (obj * am).sum(dim=-1) / tok_count           # (B, G)
    kl_per_resp  = (kl_tok * am).sum(dim=-1) / tok_count        # (B, G)

    # Vanilla GRPO: uniform mean across the G responses, then across B.
    policy_loss = -(obj_per_resp).mean()
    kl_loss     =  (kl_per_resp).mean()

    loss = policy_loss + BETA_KL * kl_loss

    # ---- 6. One gradient step -------------------------------------------
    loss.backward()
    gn = torch.nn.utils.clip_grad_norm_(policy.parameters(), MAX_GRAD)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # ---- 7. Diagnostics --------------------------------------------------
    with torch.no_grad():
        approx_kl = ((old_logp[:, :, 1:] - new_logp) * am).sum() / am.sum()
        clipfrac  = (((ratio - 1.0).abs() > EPSILON).float() * am).sum() / am.sum()
        adv_std   = adv[am.bool()].std()

    return {
        "loss":       loss.detach(),
        "pg_loss":    policy_loss.detach(),
        "kl_ref":     kl_loss.detach(),
        "approx_kl":  approx_kl.detach(),
        "clipfrac":   clipfrac.detach(),
        "adv_std":    adv_std.detach(),
        "grad_norm":  gn.detach(),
    }
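The advantage step at the top of the function can be verified by hand before any model is involved. The sketch below reproduces it in plain Python for a single group, using the sample (n - 1) standard deviation to match the default of `torch.std`; the function name `group_advantages` and the reward values are made up for illustration.

```python
import statistics

def group_advantages(rewards, eps=1e-8):
    """Group-relative advantages for ONE prompt's G rewards (G >= 2)."""
    mean = statistics.fmean(rewards)
    std = max(statistics.stdev(rewards), eps)   # sample std, clamped like the listing
    return [(r - mean) / std for r in rewards]

mixed = group_advantages([1.0, 0.0, 0.0, 1.0])       # two right, two wrong
all_wrong = group_advantages([0.0, 0.0, 0.0, 0.0])   # degenerate group
```

For the mixed group the advantages are about +0.866 and -0.866 and sum to zero. For the all-wrong group the clamp prevents a division by zero and every advantage is exactly 0, so that group contributes no policy gradient at all.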

At Massive Scale: What Changes for a 70B+ Run

Every line of code above scales to a 70B-parameter GRPO run, but four things change in spirit.

Memory: GRPO's biggest gift to production training

At 70B parameters in bf16, the policy and the reference together are ~280 GB of weights. Adding a value head (PPO) adds a third replica and raises the weight total to ~420 GB, plus another ~140 GB of value-head gradients, plus ~280 GB of AdamW optimiser state for the value head (two bf16 moments): an extra ~560 GB relative to GRPO. An 8×H100 node with 80 GB per GPU has 640 GB of HBM in total, so that extra 560 GB is close to a full node spent on the value head alone. This is a large part of why DeepSeek picked GRPO for R1 and why post-R1 reasoning recipes (Qwen-Math-RL, AceMath-RL, Olmo 3) followed suit.
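The memory figures above follow from a short piece of arithmetic. The sketch below reproduces them under the same assumptions as the text: 70B parameters, 2 bytes per value for weights, gradients, and both AdamW moments. Real runs often keep optimiser state in fp32, which would make the critic's cost larger still.

```python
PARAMS_BILLION = 70
BYTES_PER_VALUE = 2                        # bf16, assumed for everything here

replica_gb = PARAMS_BILLION * BYTES_PER_VALUE        # one copy of the weights

policy_plus_ref_gb = 2 * replica_gb                  # weights shared by GRPO and PPO

# What a trainable value model adds on top.
critic_weights_gb = replica_gb
critic_grads_gb = replica_gb
critic_adamw_gb = 2 * replica_gb                     # first and second moments
critic_extra_gb = critic_weights_gb + critic_grads_gb + critic_adamw_gb

node_hbm_gb = 8 * 80                                 # one 8xH100 node

critic_share_of_node = critic_extra_gb / node_hbm_gb
```

Under these assumptions the critic costs 560 GB against 640 GB of node memory, or 87.5% of a node.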

The (B, G) axis lives across ranks

At scale, the natural FSDP sharding is to split $B \cdot G$ across data-parallel ranks (every rank holds the same parameters and a different slice of prompts) and shard the parameters across model-parallel ranks. The group-relative mean/std becomes a careful primitive: each rank computes partial sum and partial sum-of-squares along the G axis for its own prompts; an all-reduce is NOT needed for the mean and std as long as every prompt's entire group is on the same rank. The most common GRPO bug at scale is splitting a single prompt's group across ranks: each rank then sees a different baseline, the effective advantage scale varies across ranks, and the run becomes noisy in a way that can take a week to diagnose. A unit test asserting that all G responses for a prompt live on the same rank is cheap insurance.
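A minimal sketch of the "whole group on one rank" invariant, in plain Python with no distributed framework. Both function names are hypothetical, and the modulo assignment is just one simple policy that keeps groups intact.

```python
def assign_groups(prompt_ids, group_size, num_ranks):
    """Map each rank to a list of (prompt_id, response_idx) pairs,
    placing all responses of a prompt on rank prompt_id % num_ranks."""
    shards = {rank: [] for rank in range(num_ranks)}
    for pid in prompt_ids:
        shards[pid % num_ranks].extend((pid, g) for g in range(group_size))
    return shards

def groups_are_whole(shards, group_size):
    """True if every prompt's responses sit on exactly one rank
    and that rank holds all group_size of them."""
    owner, counts = {}, {}
    for rank, items in shards.items():
        for pid, _ in items:
            if owner.setdefault(pid, rank) != rank:
                return False               # same prompt seen on two ranks
            counts[pid] = counts.get(pid, 0) + 1
    return all(c == group_size for c in counts.values())

good = assign_groups(range(8), group_size=16, num_ranks=4)

# A careless split: prompt 7's sixteen responses land on two ranks.
bad = {0: [(7, g) for g in range(8)], 1: [(7, g) for g in range(8, 16)]}
```

Running `groups_are_whole` on every batch during a short smoke test is inexpensive, and it catches the split before any advantage is computed against the wrong baseline.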

Sampling-vs-training imbalance is even worse than PPO

GRPO requires $G\times$ as many rollout tokens per prompt as single-sample PPO. With $G = 16$ and responses of length 1024 tokens that is 16 384 tokens per prompt to sample before a single gradient step can happen. At 500 tok/s/GPU with vLLM, a 1024-prompt minibatch (about 16.8M tokens) costs roughly 9 GPU-hours of sampling, so it takes minutes to roll out even across dozens of GPUs. Production GRPO infrastructure runs the rollout cluster (many GPUs, prioritised for inference throughput with vLLM and continuous batching) entirely separately from the training cluster (FSDP, fewer GPUs, prioritised for memory bandwidth) and ships generations over the network. Weights sync after each optimisation phase.
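The rollout budget is worth computing before sizing the inference cluster. The sketch below uses the figures from the paragraph above and assumes, pessimistically, that every response runs to the full 1024 tokens and that throughput scales linearly with GPU count. The 64-GPU cluster size is an illustrative choice.

```python
G = 16                         # responses per prompt
RESPONSE_LEN = 1024            # tokens per response (worst case)
PROMPTS = 1024                 # prompts per minibatch
TOK_PER_S_PER_GPU = 500        # sampling throughput from the text

tokens_per_prompt = G * RESPONSE_LEN
total_tokens = PROMPTS * tokens_per_prompt
gpu_seconds = total_tokens / TOK_PER_S_PER_GPU
gpu_hours = gpu_seconds / 3600

def rollout_minutes(num_gpus: int) -> float:
    """Wall-clock minutes for one minibatch, assuming perfect scaling."""
    return gpu_seconds / num_gpus / 60

minutes_on_64 = rollout_minutes(64)
```

Under these assumptions one minibatch is 16 777 216 tokens, a little over 9 GPU-hours, or just under 9 minutes of wall-clock time on 64 sampling GPUs.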

Reward model bottleneck disappears for verifier rewards

DeepSeek R1 uses RULE-BASED rewards for math and code (the verifier is the unit test or the symbolic check), which are essentially free to compute. For tasks that still need a learned RM, the same rollout/score/train split as PPO applies, and the RM becomes the new cost bottleneck.
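To make "essentially free" concrete, here is a toy rule-based reward for math answers. It is a hypothetical sketch, not DeepSeek's verifier: it assumes the response puts its final answer in a LaTeX `\boxed{...}` with no nested braces, and it compares strings exactly rather than checking symbolic equivalence.

```python
import re

# Matches \boxed{...} with no nested braces inside (a simplifying assumption).
_BOXED = re.compile(r"\\boxed\{([^{}]*)\}")

def math_reward(response: str, gold_answer: str) -> float:
    """1.0 if the LAST boxed answer equals the gold answer, else 0.0."""
    answers = _BOXED.findall(response)
    if not answers:
        return 0.0                         # no parsable final answer
    return 1.0 if answers[-1].strip() == gold_answer.strip() else 0.0

right = math_reward("First try \\boxed{41}, corrected: \\boxed{ 42 }", "42")
wrong = math_reward("Therefore the answer is \\boxed{41}.", "42")
missing = math_reward("I think it is 42.", "42")
```

A check like this costs microseconds per response, which is why the scoring stage stops being a bottleneck once rewards are verifiable. A production verifier would normalise equivalent forms (for example 0.5 and 1/2) instead of comparing raw strings.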

Engineering Reality: The Knobs That Break Runs

After enough GRPO runs every team converges on the same set of failure modes and the same set of high-leverage fixes.

  • All-equal-reward groups dominate the batch. Symptom: loss is flat, gradient norm collapses, advantage std drops near zero. Cause: for hard math/code prompts every response in a group is wrong (reward = 0), and for easy ones every response is right, so every advantage is zero. The group contributes nothing. Fix: oversample these prompts with higher temperature, or skip them entirely from the gradient and log them separately for later curriculum work. DAPO's dynamic sampling filters such groups explicitly and resamples until the batch is full.
  • Group split across ranks. Symptom: loss noise is much higher than a single-GPU debug run. Cause: a careless DataLoader put 8 of a prompt's 16 responses on rank 0 and 8 on rank 1. The baseline computed on rank 0 differs from rank 1 and the advantages are inconsistent. Fix: a custom sampler that guarantees prompt_id % num_ranks == rank_id for the entire group. This is a one-time setup that prevents weeks of pain.
  • Length bias toward longer responses. Symptom: mean response length climbs steadily over training even though rewards do not require length. Cause: vanilla GRPO's per-response $1/|o_i|$ normalisation gives long responses a smaller per-token loss, so a wrong (negative-advantage) response is penalised less per token the longer it is, and the policy can game this by lengthening responses to dilute the gradient. Fix: switch to Dr.GRPO's $1/S_{\max}$ normalisation (Section 15.5), or apply an explicit length penalty inside the reward function.
  • KL coefficient too small. Symptom: reward climbs for 200 steps then the policy degenerates (repeats a single phrase, drops English, emits gibberish that scores high on a mis-calibrated RM). Cause: insufficient KL pressure against the reference SFT model. Fix: raise $\beta$ by 2× and restart from a pre-collapse checkpoint. Many teams use adaptive KL (raise $\beta$ when realised KL exceeds a target, lower it when below).
  • Rollout temperature mismatch. Symptom: ratio distribution is centred far from 1.0 on step 0 even with the freshly-cloned policy. Cause: the rollout sampled with a temperature other than 1.0 (say 0.7) but the trainer computes $\log\pi_{\text{new}}$ from a plain $\text{softmax}(\text{logits})$, which corresponds to temperature 1.0. The two distributions are different and the ratio is meaningless. Fix: apply the SAME temperature inside the trainer when computing per-token log-probs, or sample with temperature=1.0 throughout.
  • Action mask off by one. Symptom: training looks fine, reward inches up, but eval shows no change from the SFT model. Cause: the mask includes prompt tokens (which should contribute no gradient, since the policy did not choose them), or excludes the first response token. Fix: print one example's mask alongside its decoded tokens before launch and eyeball that mask=1 lines up with the response tokens only. The off-by-one between input_ids[:, :, :-1] and input_ids[:, :, 1:] is the most common variant.
  • Too many optimisation epochs on one rollout. Symptom: ratio drifts far from 1 after the second optimisation epoch on the same rollout, clipfrac spikes. Cause: old_logp must stay fixed at the sampling policy's values, because that is what makes the ratio a valid importance weight, so every extra epoch moves the live policy further from the data it is training on. GRPO typically uses 1 epoch per rollout for exactly this reason. Fix: stick with 1 epoch (the DeepSeek choice), or stop the epoch loop early once approx_kl passes a threshold. Do not overwrite old_logp with the live policy's log-probs between epochs: that resets the ratio to 1 and silently disables the clip.
  • Gradient norm explodes after a few hundred steps. Symptom: $|g|$ climbs from 0.5 to 50 in 100 steps. Cause: either KL coefficient too small (policy escaping to extreme outputs) or a reward model bug that gives unboundedly large rewards on some inputs. Fix: raise $\beta$ or clip the reward distribution at the source.
  • Reference accidentally aliased to the policy. Symptom: KL to the reference stays pinned at zero no matter how far the policy moves from the SFT model. Cause: the “reference” was initialised as a Python reference to the same tensors as the live policy. Fix: assert ref_policy is not policy at startup (and that the two share no parameter tensors); load the reference from a separate checkpoint.
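Two of the failure modes above lend themselves to checks that run in seconds before a launch. The sketch below, in plain Python with hypothetical helper names, flags zero-variance groups and shows how a temperature mismatch between sampling and training moves the ratio away from 1 on an unchanged policy. The logits and the 0.7 sampling temperature are made-up values.

```python
import math

def keep_mask(rewards_per_group, tol=1e-8):
    """True for groups whose rewards differ, False for all-equal groups."""
    return [max(group) - min(group) > tol for group in rewards_per_group]

def log_prob(logits, token, temperature=1.0):
    """log softmax(logits / temperature)[token], computed stably."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in scaled))
    return scaled[token] - log_sum_exp

# --- Degenerate groups: all wrong, mixed, all right ----------------------
mask = keep_mask([[0, 0, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])

# --- Temperature mismatch on an UNCHANGED policy -------------------------
logits = [2.0, 1.0, 0.0]
SAMPLING_TEMPERATURE = 0.7
old_logp = log_prob(logits, token=0, temperature=SAMPLING_TEMPERATURE)

mismatched_ratio = math.exp(log_prob(logits, 0, temperature=1.0) - old_logp)
matched_ratio = math.exp(log_prob(logits, 0, SAMPLING_TEMPERATURE) - old_logp)
```

Only the mixed group survives the filter. The mismatched ratio lands near 0.86 for a policy that has not taken a single step, while recomputing the log-prob at the sampling temperature restores a ratio of 1.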

The Mental Model That Unifies This Section

GRPO is PPO with one substitution and two deletions. Substitution: replace $A_t = R_t - V(s_t)$ with the group-normalised reward $A_i = (r_i - \bar r)/\sigma_r$, broadcast over every token of response $i$. Deletion 1: remove the value head, its gradient, and its optimiser state (saving ~4 model replicas' worth of memory at scale, the ~560 GB counted earlier). Deletion 2: remove the entropy bonus (the per-group reward diversity replaces it as an exploration signal). Everything else (the clipped surrogate, the per-token KL to the reference, the gradient clip, the diagnostics) is the PPO code unchanged. Get the (B, G, S) shape right, keep every group on one rank, and a 70B reasoning model can be RL-trained without paying for a critic at all.

With this section closed, Chapter 15 is complete. The next chapter opens DeepSeek R1-Zero — the pure-RL experiment that ran this exact GRPO loop with rule-based math rewards and discovered emergent chain-of-thought reasoning along the way.
