Chapter 15

The Critic Bottleneck in PPO

GRPO: Group Relative Policy Optimisation

The Real Problem: PPO Costs Two Models

Proximal Policy Optimization is the algorithm that made RLHF practical. OpenAI used it for InstructGPT and ChatGPT; Anthropic and Meta used variants of it for Claude and Llama 2-Chat; nearly every commercial LLM deployed between 2022 and 2024 owes its alignment to a PPO loop. It works. It is stable. It is well-understood. And it has a problem that only becomes obvious when you try to scale it to a 70B-parameter policy: PPO trains two models at the same time.

The first model is the one you actually want: the policy $\pi_{\theta}$, which generates the responses. The second is the critic $V_{\phi}$, which predicts how much reward each intermediate state is going to accumulate. PPO needs the critic because it relies on a quantity called the advantage $A_t$ (how much better the action at step $t$ was than the critic expected) to do the variance-reduction trick that makes the gradient learnable. No advantage, no PPO.

In the original PPO experiments the critic was tiny: a separate MLP with two hidden layers for the MuJoCo tasks, and for Atari just an extra output head on a CNN trunk shared with the policy. Free, basically. In LLM RLHF the critic is a copy of the language model with a scalar head bolted on. A 70B policy comes with a 70B critic. Mixed precision and AdamW push the per-parameter training cost to ~16 bytes, so the critic alone needs over 1 TB of GPU memory (70B × 16 bytes ≈ 1.12 TB) before you have stored a single activation. That is the bottleneck this section is named for.

What this section is going to prove. The critic is not a polite optional ingredient of PPO — it is structurally necessary for the algorithm as written. Removing it changes the math. The next section (15.2) shows how GRPO replaces it with a group baseline that achieves the same variance-reduction effect without the model. Before we can earn that result, we have to understand exactly what the critic is doing and exactly what it costs.

Intuition: The Coach Who Must Be as Smart as the Player

Imagine a chess match. The player (the policy) makes moves. After the whole game, an external judge says "you won" or "you lost" — that is the reward signal from the reward model. The player wants to learn from this verdict, but there is a problem: the verdict applies to the entire game, not to any individual move. Was the queen sacrifice on move 12 brilliant or terrible? The scalar "you won" cannot tell us.

Plain policy gradient handles this by being naïve: it credits every move in a winning game with the full reward. This is unbiased but absurdly noisy — most of the moves had nothing to do with the win. Training is technically possible but takes oceans of data and is prone to wild swings.

PPO solves it by hiring a coach: the critic. After every move, the coach whispers: "you are now in a position I'd expect to win 73% of the time." The next move's coach-estimate is 81%. So the policy now has a clean signal: that move was a +8-point surprise, with far less variance than a blunt "you won". Every move now gets its own credit assignment, derived from how much it shifted the coach's opinion.

The catch: the coach has to actually understand the game. A coach who knows nothing about chess gives noisy whispers, and the whole scheme falls apart. For chess this is fine — winning probability has been studied for centuries and a small model can estimate it. For an LLM completing an arbitrary prompt under an arbitrary reward model, "winning probability from this partial response" is a problem on the same complexity scale as generating the response. The coach needs to be roughly as smart as the player.

The single sentence that motivates GRPO. For an LLM, the value function is approximately as hard to learn as the policy itself — so PPO is paying for two language models when you only get to ship one of them.

GRPO's insight (which we will derive properly in 15.2) is that there is another way to get a baseline: sample the same prompt many times, then use the group's mean reward as the baseline for each individual sample. A response that scored above the group mean gets positive advantage; below mean gets negative. No critic, no value function, no second neural network. The variance reduction is comparable, and the memory cost is gone.
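A quick numerical check of the variance claim, written as a minimal sketch under toy assumptions: one "prompt" with two possible responses, a uniform softmax policy, and hypothetical rewards (10.0 and 10.5 plus small noise) that share a large common offset, which is exactly where an un-baselined gradient is noisiest. It compares per-sample REINFORCE gradients for one logit with no baseline and with the group-mean baseline over G = 8 samples of the same prompt. The names `grad_samples` and `mean_reward` are illustrative, not from any library.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy prompt: two candidate responses under a uniform softmax policy.
# Hypothetical rewards with a large shared offset.
probs = np.array([0.5, 0.5])
mean_reward = np.array([10.0, 10.5])

def grad_samples(n_groups, group_size, use_baseline):
    """Per-sample REINFORCE gradients with respect to logit 0."""
    grads = []
    for _ in range(n_groups):
        actions = rng.choice(2, size=group_size, p=probs)
        rewards = mean_reward[actions] + rng.normal(0.0, 0.1, size=group_size)
        baseline = rewards.mean() if use_baseline else 0.0   # group-mean baseline
        score = (actions == 0).astype(float) - probs[0]       # d log pi(a) / d logit_0
        grads.extend(score * (rewards - baseline))
    return np.array(grads)

g_plain = grad_samples(2000, 8, use_baseline=False)
g_group = grad_samples(2000, 8, use_baseline=True)

# Exact gradient of expected reward w.r.t. logit 0: p0 * (r0 - E[r]) = -0.125.
print(f"mean gradient: {g_plain.mean():+.3f} (no baseline) vs {g_group.mean():+.3f} (group mean)")
print(f"variance:      {g_plain.var():.3f} (no baseline) vs {g_group.var():.4f} (group mean)")
```

Both estimators point the same way, but the group-mean version cuts the variance by more than two orders of magnitude in this toy. Because each sample's own reward is part of its group mean, the expected gradient here shrinks by a factor of (G − 1)/G; the exact GRPO formulation is left to Section 15.2.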

The Mathematics of PPO and Its Critic

We need a handful of equations to see the bottleneck. First, the policy gradient with a baseline:

$$\nabla_{\theta} J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ \sum_{t=0}^{T-1} \nabla_{\theta} \log \pi_{\theta}(a_t \mid s_t) \cdot A_t \right]$$

Here $\tau$ is a trajectory (in LLM-land, a full response of $T$ tokens), $a_t$ is the action at step $t$ (a sampled token), and $s_t$ is the state (the prompt plus all tokens generated so far). The advantage $A_t$ is what gets multiplied with the score function $\nabla_{\theta} \log \pi_{\theta}(a_t \mid s_t)$. Without a baseline it would just be the cumulative reward from $t$ onward, a random variable with enormous variance for sparse, terminal rewards.

The baseline-corrected advantage subtracts a state-conditional estimate of expected return, called the value function:

$$V^{\pi}(s_t) = \mathbb{E}_{\tau \sim \pi} \left[ \sum_{k=0}^{T-1-t} \gamma^k r_{t+k} \;\middle|\; s_t \right]$$

This is the quantity the critic $V_{\phi}$ is trained to predict. With it in hand, we can build the generalized advantage estimator (GAE) of $A_t$:

$$A_t = \sum_{k=0}^{T-1-t} (\gamma \lambda)^k \delta_{t+k}, \quad \delta_k = r_k + \gamma V_{\phi}(s_{k+1}) - V_{\phi}(s_k), \quad V_{\phi}(s_T) := 0$$

The term $\delta_k$ is the temporal-difference (TD) error: the difference between the immediate reward plus the critic's estimate of the next state and the critic's estimate of the current state. GAE compounds these errors with discount factor $\gamma$ (typically 1 for LLMs) and an interpolation knob $\lambda \in [0, 1]$ (typically 0.95). At $\lambda = 0$ the advantage reduces to a single TD error (low variance, biased by an inaccurate critic); at $\lambda = 1$ it collapses to the Monte-Carlo return minus $V_{\phi}(s_t)$ (unbiased but high variance).
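The two limits of $\lambda$ are easy to confirm numerically. The sketch below is a minimal NumPy version (the helper name `gae` is ours) that reuses the six-step trace from the manual walkthrough later in this section: terminal reward 1.2 and a hypothetical critic output `V`. It runs the backward GAE recursion at $\lambda = 0$ and $\lambda = 1$ and compares the results with the one-step TD errors and with the Monte-Carlo return minus $V$.

```python
import numpy as np

def gae(rewards, values, gamma=1.0, lam=0.95):
    """Backward GAE recursion. values[t] = V(s_t); the state after the last
    step is terminal, so its value is bootstrapped as 0."""
    v_next = np.append(values[1:], 0.0)
    deltas = rewards + gamma * v_next - values
    adv, running = np.zeros_like(deltas), 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        adv[t] = running
    return adv, deltas

r = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.2])
V = np.array([0.62, 0.71, 0.55, 0.68, 0.74, 0.80])

adv_td, deltas = gae(r, V, lam=0.0)          # lambda = 0
adv_mc, _ = gae(r, V, lam=1.0)               # lambda = 1
returns_to_go = np.cumsum(r[::-1])[::-1]     # sum of rewards from t onward (gamma = 1)

print(np.allclose(adv_td, deltas))               # True: a single TD error
print(np.allclose(adv_mc, returns_to_go - V))    # True: Monte-Carlo return minus V
```

With $\gamma = 1$ the $\lambda = 1$ case is a telescoping sum: every intermediate $V$ cancels, leaving the reward-to-go minus $V(s_t)$, so the critic then acts purely as a baseline.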

Finally, the PPO clipped surrogate loss:

$$\mathcal{L}_{\text{PPO}}(\theta) = -\mathbb{E}_{t} \left[ \min\!\left( \rho_t A_t,\, \mathrm{clip}(\rho_t, 1 - \epsilon, 1 + \epsilon) A_t \right) \right], \quad \rho_t = \frac{\pi_{\theta}(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$

And alongside it, the value-function loss that trains the critic:

$$\mathcal{L}_V(\phi) = \mathbb{E}_t \left[ \left( V_{\phi}(s_t) - (A_t + V_{\phi_{\text{old}}}(s_t)) \right)^2 \right]$$

The full PPO step is a weighted sum: $\mathcal{L} = \mathcal{L}_{\text{PPO}} + c_v \mathcal{L}_V - c_h \mathcal{H}$ with $c_v \approx 0.5$ and an entropy bonus $\mathcal{H}$ weighted by $c_h \approx 0.01$. Note where $\phi$ appears in this system: in $A_t$ (which enters the policy loss), in $\mathcal{L}_V$, and implicitly in the bootstrap target via $V_{\phi_{\text{old}}}$. If you delete $\phi$, you have to provide an alternative for every one of these, which is exactly what GRPO does.
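To see what the min and the clip actually do, here is a minimal per-token sketch of the surrogate (the quantity inside the expectation, before the leading minus sign). The helper name `clipped_surrogate` is ours, and the ratios and advantages of ±1 are chosen purely for illustration, with $\epsilon = 0.2$.

```python
import numpy as np

def clipped_surrogate(ratio, adv, eps=0.2):
    """Per-token PPO surrogate to be maximised: min(rho * A, clip(rho, 1-eps, 1+eps) * A)."""
    return np.minimum(ratio * adv, np.clip(ratio, 1 - eps, 1 + eps) * adv)

ratios = np.array([0.7, 1.0, 1.3])
print(clipped_surrogate(ratios, +1.0))   # values 0.7, 1.0, 1.2: no extra credit above 1 + eps
print(clipped_surrogate(ratios, -1.0))   # values -0.8, -1.0, -1.3: no extra credit below 1 - eps
```

For $A > 0$ the objective stops rewarding ratios above $1 + \epsilon$, yet a ratio that has fallen below $1 - \epsilon$ is still fully penalised; for $A < 0$ the mirror image holds. Clipping only removes the incentive to keep moving in the direction the advantage already favours, which is what makes the update conservative.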

Why is there no critic in DPO? Direct Preference Optimization (Section 14.4) bypasses both the reward model and the critic by reparameterising the RLHF objective so it can be optimized directly on preference pairs. It is a different escape from the same bottleneck. GRPO keeps the reward model (so it works for sparse verifiable rewards in reasoning) but drops only the critic — a narrower, more surgical removal that retains everything PPO was good at.

Manual Numerical Walkthrough

Let us compute one PPO step end-to-end on a trajectory of 6 tokens, by hand, to see exactly where the critic enters every line of the calculation.

Manual PPO step on a 6-token trajectory

Step 0 — Setup. We have a single response of 6 tokens. The reward model assigns a sparse, terminal reward: $r = [0, 0, 0, 0, 0, 1.2]$. The old policy log-probs (from rollout time) and the current log-probs (after a couple of gradient steps) are: $\log \pi_{\text{old}} = [-1.20, -0.80, -1.10, -0.90, -1.30, -1.00]$ and $\log \pi_{\theta} = [-1.05, -0.95, -1.15, -0.70, -1.35, -0.85]$. The critic predicts $V = [0.62, 0.71, 0.55, 0.68, 0.74, 0.80]$. Take $\gamma = 1$, $\lambda = 0.95$.

Step 1 — Bootstrap value. For each step we need $V(s_{t+1})$. Shift $V$ left and append 0 (terminal): $V_{\text{next}} = [0.71, 0.55, 0.68, 0.74, 0.80, 0.00]$.

Step 2 — TD errors. $\delta_t = r_t + V_{\text{next}}(t) - V(t)$. Token 0: $0 + 0.71 - 0.62 = 0.09$. Token 1: $0 + 0.55 - 0.71 = -0.16$. Token 2: $0 + 0.68 - 0.55 = 0.13$. Token 3: $0 + 0.74 - 0.68 = 0.06$. Token 4: $0 + 0.80 - 0.74 = 0.06$. Token 5 (terminal): $1.2 + 0 - 0.80 = 0.40$. So $\delta = [0.09, -0.16, 0.13, 0.06, 0.06, 0.40]$.

Step 3 — Run GAE backwards. With $\gamma\lambda = 0.95$:

  • $A_5 = 0.40$
  • $A_4 = 0.06 + 0.95 \cdot 0.40 = 0.440$
  • $A_3 = 0.06 + 0.95 \cdot 0.440 = 0.478$
  • $A_2 = 0.13 + 0.95 \cdot 0.478 = 0.584$
  • $A_1 = -0.16 + 0.95 \cdot 0.584 = 0.395$
  • $A_0 = 0.09 + 0.95 \cdot 0.395 = 0.465$

So $A = [0.465, 0.395, 0.584, 0.478, 0.440, 0.400]$. Notice that every token has a positive advantage even though only the last token actually received a reward. That is GAE doing its job: credit propagated backward through the critic's estimates.

Step 4 — Importance ratio. $\rho_t = \exp(\log\pi_\theta - \log\pi_{\text{old}})$. Differences: $[0.15, -0.15, -0.05, 0.20, -0.05, 0.15]$. Exponentiate: $\rho \approx [1.162, 0.861, 0.951, 1.221, 0.951, 1.162]$. With $\epsilon = 0.2$ the clip window is $[0.8, 1.2]$. Five of the six ratios sit inside it, but $\rho_3 \approx 1.221$ exceeds $1 + \epsilon$, and because $A_3 > 0$ the clip is active for that token.

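Step 5 — Policy loss. $\mathcal{L}_{\text{PPO}} = -\tfrac{1}{6} \sum_t \min\!\left(\rho_t A_t,\ \mathrm{clip}(\rho_t, 0.8, 1.2) A_t\right)$. Per-token terms: $[0.540, 0.340, 0.556, 0.574, 0.419, 0.465]$, where token 3 contributes the clipped $1.2 \times 0.478 = 0.574$ instead of the unclipped $1.221 \times 0.478 = 0.584$. Sum ≈ 2.893, mean ≈ 0.482, so $\mathcal{L}_{\text{PPO}} \approx -0.482$. The loss is negative simply because every advantage here is positive; what matters is its gradient, which raises the probability of tokens 0, 1, 2, 4 and 5. Token 3's clipped term is constant in $\theta$, so it contributes no gradient: the policy has already moved that token as far as the clip window allows.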

Step 6 — Value loss. Returns $R_t = A_t + V_t$: $R = [1.085, 1.105, 1.134, 1.158, 1.180, 1.200]$. Squared error per step $(V_t - R_t)^2$, which is just $A_t^2$ because $R_t - V_t = A_t$: $[0.216, 0.156, 0.341, 0.228, 0.194, 0.160]$. Mean: 0.216. So $\mathcal{L}_V \approx 0.216$. The critic is under-predicting the returns at every step (every advantage is positive); it's late to the party, which is typical critic behaviour and what the value loss is trying to correct.

Step 7 — Total loss. $\mathcal{L} = \mathcal{L}_{\text{PPO}} + 0.5 \cdot \mathcal{L}_V \approx -0.482 + 0.108 = -0.374$. Two networks contributed gradients to this number. Two networks received parameter updates. Two networks need to be evaluated on the next rollout. Multiply by every PPO step in a training run, then multiply by the parameter count of a frontier policy.

Step 8 — What GRPO will replace. Steps 1–3 (the entire GAE machinery) and step 6 (the value loss) all go away. In GRPO we would instead sample $G$ responses from the same prompt, give each its terminal reward, normalize those rewards across the group, and use the normalized scalar directly as $A_t$ for every token of that response. Steps 4 (ratio) and 5 (clipped surrogate) survive unchanged.
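The replacement described in Step 8 fits in a few lines. What follows is only a sketch under stated assumptions, not the GRPO objective itself (Section 15.2 derives that): it normalises with the population standard deviation, and `group_advantages`, the group size and the rewards are all hypothetical.

```python
import numpy as np

def group_advantages(group_rewards, response_lengths, eps=1e-8):
    """GRPO-style advantages, sketched: normalise each response's terminal
    reward against the other responses to the same prompt, then give every
    token of that response the same scalar."""
    r = np.asarray(group_rewards, dtype=float)
    a = (r - r.mean()) / (r.std() + eps)   # population std here; an assumption
    return [np.full(n, a_i) for a_i, n in zip(a, response_lengths)]

# Hypothetical group of G = 4 responses to one prompt, with different lengths.
adv = group_advantages([1.2, 0.3, 0.9, 0.0], response_lengths=[6, 4, 5, 3])
for tokens_adv in adv:
    print(np.round(tokens_adv, 3))
```

Nothing in this function touches a value network: the group itself supplies the baseline, and every token of a response shares one sequence-level advantage.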

Visualizing the Critic Bottleneck

Here is the GPU memory PPO consumes during one training step, tallied as a ledger. Trainable weights (policy and critic) cost ~16 bytes per parameter under mixed-precision AdamW: 2 bytes for the fp16 forward copy, 2 for the fp16 gradient, 4 for the fp32 master copy, and 4 + 4 for the AdamW first- and second-moment buffers. Frozen models (reference policy and reward model) cost only the 2 bytes of the fp16 forward copy. Switching from PPO to GRPO deletes the critic and, with it, one of the two largest entries in the bill.


Three things to internalize from this ledger. First, at every model size the critic costs exactly as much as the trainable policy: the two networks share an architecture, so each pays the same ~16 bytes per parameter for weights, gradients, and optimizer state, eight times what a frozen model costs. Second, the saving is not the 25% you might guess from deleting one model out of four: it is 50% of the trainable budget and ~40% of the total budget once frozen models and activations are counted. Third, memory is only one axis: the training-step compute (FLOPs for the critic's forward and backward) and the communication (gradient all-reduce volume) roughly double as well. The bottleneck is multi-dimensional.

The numbers are approximations, but the gap is real. Production frameworks add 5–15% overhead for kernel buffers, NCCL staging, ZeRO bucketing, and gradient accumulation. Different teams use slightly different critic sizes (some use a critic half the size of the policy, some a shared trunk with two heads). What does not vary across teams is the qualitative observation: PPO carries roughly twice the trainable state, and twice the training-step compute, of a method that does not require a value function.
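The ledger arithmetic can be scripted. The sketch below applies only the per-parameter byte counts stated above. `ppo_memory_gb` is a hypothetical helper, and `extra_gb` is a stand-in assumption for activations plus framework overhead, which in reality depend on batch size, sequence length and framework rather than on parameter count. The 200 GB used for the 70B row mirrors the activation and overhead rows of the 70B table later in this section.

```python
import math

def ppo_memory_gb(n_params, with_critic=True, extra_gb=0.0):
    """Rough GPU-memory ledger in GB (1 GB = 1e9 bytes): 16 B/param for each
    trainable model, 2 B/param for each frozen one, plus a flat extra_gb."""
    trainable_bytes = 16 * (2 if with_critic else 1)   # policy (+ critic)
    frozen_bytes = 2 + 2                               # reference policy + reward model
    return n_params * (trainable_bytes + frozen_bytes) / 1e9 + extra_gb

# Model state only (weights, gradients, optimizer state), no activations.
for n in (7e9, 13e9, 70e9):
    ppo = ppo_memory_gb(n)
    grpo = ppo_memory_gb(n, with_critic=False)
    print(f"{n / 1e9:>3.0f}B model state: PPO {ppo:6.0f} GB, GRPO {grpo:6.0f} GB, "
          f"saving {1 - grpo / ppo:.0%}")

# 70B with ~120 GB activations + ~80 GB overhead added as a flat 200 GB.
ppo_70 = ppo_memory_gb(70e9, extra_gb=200)
grpo_70 = ppo_memory_gb(70e9, with_critic=False, extra_gb=200)
print(f"70B total: PPO {ppo_70:.0f} GB, GRPO {grpo_70:.0f} GB")
print(f"minimum 80 GB H100s: {math.ceil(ppo_70 / 80)} vs {math.ceil(grpo_70 / 80)}")
```

The GPU counts are a floor that assumes perfect sharding and zero headroom. They are useful as a sanity check on any cluster plan, not as a sizing recommendation.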

Plain Python: PPO With a Tiny Critic

Here is the full PPO step in NumPy: six tokens, hand-coded GAE, the clipped surrogate loss, and the value loss. Its output reproduces the manual walkthrough above to the precision shown there.

ppo_with_critic.py
Sparse, terminal reward: the LLM RLHF setting

An RLHF episode is a full response of T tokens. The reward model scores only the completed response, so r_t = 0 for t < T-1 and r_{T-1} is the scalar reward. This sparsity is the entire reason we need a value function: pure REINFORCE would assign that one scalar reward to every token with no notion of which tokens were responsible, producing a gradient with crippling variance.

State:
T = 6
rewards = [0, 0, 0, 0, 0, 1.2]
Σ rewards = 1.2 (the reward-model score)
Old vs new log-probs: the importance ratio

PPO collects a rollout under the policy parameters that existed at the start of the epoch (old_logp), then does several gradient steps on that same data. After a few steps, the current policy assigns different probabilities to the same actions; the ratio between current and old is exp(new − old). PPO clips this ratio to prevent the inner-loop policy from drifting too far from the data-generating distribution. The same idea will reappear in GRPO unchanged.

State:
old_logp = [-1.20, -0.80, -1.10, -0.90, -1.30, -1.00]
new_logp = [-1.05, -0.95, -1.15, -0.70, -1.35, -0.85]
V_phi(s_t): the critic's per-step value prediction

This is the array we are going to delete in GRPO. V[t] is the critic's estimate of the expected return from state s_t under the current policy. In LLM RLHF, computing V[t] for every token of every response in a batch means running a second forward pass through a transformer that is typically the same size as the policy. The size of this array is small; the cost is in producing it.

State:
V = [0.62, 0.71, 0.55, 0.68, 0.74, 0.80]
V_next = [0.71, 0.55, 0.68, 0.74, 0.80, 0.00]
deltas: the TD error per step

δ_t = r_t + γ·V(s_{t+1}) − V(s_t) is the one-step temporal-difference error: how much better (or worse) state s_t turned out to be than the critic predicted. Without V_phi we cannot compute δ at all. In our trace, the early deltas are small (between −0.16 and +0.13; the critic thinks each successive state is about as valuable as the last), while the terminal delta is +0.40 (reward 1.2, plus a zero bootstrap, minus V(s_5) = 0.80).

State:
deltas = [0.09, -0.16, 0.13, 0.06, 0.06, 0.40]
Backward GAE recursion: variance reduction in one loop

GAE compounds the TD errors from the end of the trajectory backwards: A_{T-1} = δ_{T-1}, A_{T-2} = δ_{T-2} + γλ·A_{T-1}, and so on. The λ knob trades bias for variance: λ = 1 recovers the Monte Carlo return minus the baseline (unbiased, high variance), λ = 0 recovers single-step TD (biased, low variance). This recursion is the value function's payoff: it gives every action a low-variance credit assignment that the bare reward could never provide.

State:
advantages ≈ [0.465, 0.395, 0.584, 0.478, 0.440, 0.400] (computed by the reversed loop)
λ = 0.95
returns = advantages + V: the critic's training target

The value loss trains V_phi to predict GAE returns. This is a self-referential bootstrap: the critic's prediction shapes the advantage, which shapes the policy, which generates the next rollout, which generates the next critic target. When this loop is stable, PPO converges. When it is not (when the critic lags behind the policy or saturates on early-trajectory tokens), PPO diverges in spectacular ways.

State:
returns ≈ [1.085, 1.105, 1.134, 1.158, 1.180, 1.200] (per-step value targets)
The clipped surrogate objective

PPO's signature: compute the importance-weighted advantage two ways, unclipped (ratio · A) and clipped (clip(ratio, 1±ε) · A), and take the minimum. When A > 0 (good action), the clip caps how much we gain by pushing the probability up; when A < 0 (bad action), it caps how much we gain by pushing it down. Because of the min(·), clipping can only ever lower the objective, never raise it, which is what makes the update conservative.

State:
ε = 0.2
ratio = exp(new_logp - old_logp) ≈ [1.162, 0.861, 0.951, 1.221, 0.951, 1.162] (token 3 exceeds 1 + ε and is clipped)
Two losses, one optimizer: the joint training step

policy_loss updates θ (the policy weights) using PPO; value_loss updates φ (the critic weights) using MSE against the GAE returns. In a real implementation these are usually two separate networks (or two heads on one trunk) with separate optimizer states. The 0.5 coefficient on value_loss is a common default; production code also subtracts an entropy bonus to keep the policy from collapsing.

State:
policy_loss = -mean(min(ratio·A, clip(ratio)·A)) ≈ -0.4821
value_loss = mean((V - returns)²) ≈ 0.2159
import numpy as np

# A toy "language model" episode.
# In real PPO each timestep is one decoded token; here we use 6 timesteps
# of a single trajectory so we can compute every advantage by hand.
#
# rewards[t] is the per-token reward the reward model assigns at step t.
# In an LLM RLHF run, rewards[t] = 0 for every step except the last, and
# rewards[T-1] is the scalar score from the reward model on the full
# response. This sparsity is exactly why we need a value function in PPO.

T = 6
rewards = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.2])  # reward at end of sequence

# Old policy log-probs (collected during rollout) and current log-probs
# (recomputed after we have started updating theta during the PPO inner loop).
# In the very first inner-loop step these are identical; the ratio drifts as
# we take gradient steps. Here we fake a small drift to make the clip visible.
old_logp = np.array([-1.20, -0.80, -1.10, -0.90, -1.30, -1.00])
new_logp = np.array([-1.05, -0.95, -1.15, -0.70, -1.35, -0.85])

# Critic predictions V_phi(s_t) for each timestep -- this is the *whole reason*
# we need a separate value-function network. In a real LLM run, V_phi is
# produced by a separate 7B/70B transformer with a scalar head.
V = np.array([0.62, 0.71, 0.55, 0.68, 0.74, 0.80])
V_next = np.append(V[1:], 0.0)  # bootstrap with 0 at terminal state

# --- Generalized Advantage Estimation (GAE), Schulman et al. 2016 ----------
# A_t = delta_t + (gamma * lambda) * delta_{t+1} + (gamma * lambda)^2 * ...
# delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#
# GAE interpolates between high-bias TD(0) (lambda=0) and high-variance
# Monte Carlo (lambda=1). The critic V_phi appears twice: once to compute
# the bootstrap target r + gamma * V(s'), and once as the baseline V(s)
# we subtract to get the advantage. Without V_phi there is no GAE.

gamma = 1.0     # LLM RLHF typically uses gamma = 1 (full credit)
lam   = 0.95    # GAE lambda; standard choice

deltas = rewards + gamma * V_next - V
advantages = np.zeros(T)
gae = 0.0
for t in reversed(range(T)):
    gae = deltas[t] + gamma * lam * gae
    advantages[t] = gae

# The PPO value target is what the critic is *trained* to predict next step.
returns = advantages + V

# --- The PPO clipped objective ---------------------------------------------
# r_t(theta) = exp(new_logp_t - old_logp_t)            # importance ratio
# L_t        = min( r_t * A_t,  clip(r_t, 1-eps, 1+eps) * A_t )
# The critic gets a separate MSE loss against the returns:
# L_value    = (V_phi(s_t) - returns_t)^2
# Total loss = -mean(L_t) + c_v * mean(L_value)  -  c_h * entropy

epsilon = 0.2
ratio   = np.exp(new_logp - old_logp)
unclipped = ratio * advantages
clipped   = np.clip(ratio, 1 - epsilon, 1 + epsilon) * advantages
ppo_per_t = np.minimum(unclipped, clipped)

policy_loss = -ppo_per_t.mean()
value_loss  =  ((V - returns) ** 2).mean()
total_loss  =  policy_loss + 0.5 * value_loss

print(f"deltas      = {np.round(deltas, 3)}")
print(f"advantages  = {np.round(advantages, 3)}")
print(f"returns     = {np.round(returns, 3)}")
print(f"ratio       = {np.round(ratio, 3)}")
print(f"policy_loss = {policy_loss:.4f}")
print(f"value_loss  = {value_loss:.4f}")
print(f"total_loss  = {total_loss:.4f}")

Two implementation details to lock into memory. First, GAE is computed by a backward recursion over time, not a forward one. This is because the discount factor compounds away from the terminal state: only after we know $A_{t+1}$ can we compute $A_t$. Forgetting this and writing a forward loop is one of the most common bugs in homegrown PPO implementations. Second, the value loss is computed against $R_t = A_t + V_t$, not against the raw cumulative reward. The critic is being trained to predict a quantity that itself depends on the critic's current predictions. It is a bootstrap, and bootstraps can diverge when the policy is improving rapidly, which is precisely when you most want stable training.
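The forward-loop bug is easy to demonstrate. The sketch below (helper names are ours) takes the TD errors from the walkthrough and compares three things: the closed-form definition $A_t = \sum_k (\gamma\lambda)^k \delta_{t+k}$ summed directly, the backward recursion, and the same update mistakenly iterated forward in time.

```python
import numpy as np

deltas = np.array([0.09, -0.16, 0.13, 0.06, 0.06, 0.40])   # walkthrough TD errors
gamma_lam = 0.95

def gae_closed_form(deltas, gamma_lam):
    """A_t = sum_k (gamma*lam)^k * delta_{t+k}, summed directly (O(T^2))."""
    T = len(deltas)
    return np.array([sum(gamma_lam ** k * deltas[t + k] for k in range(T - t))
                     for t in range(T)])

def gae_backward(deltas, gamma_lam):
    """Correct O(T) recursion A_t = delta_t + gamma*lam * A_{t+1}, run from the end."""
    adv, running = np.zeros(len(deltas)), 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma_lam * running
        adv[t] = running
    return adv

def gae_forward_bug(deltas, gamma_lam):
    """The common bug: the same update run forwards, accumulating PAST deltas."""
    adv, running = np.zeros(len(deltas)), 0.0
    for t in range(len(deltas)):
        running = deltas[t] + gamma_lam * running
        adv[t] = running
    return adv

reference = gae_closed_form(deltas, gamma_lam)
print(np.round(reference, 3))
print(np.allclose(gae_backward(deltas, gamma_lam), reference))      # True
print(np.allclose(gae_forward_bug(deltas, gamma_lam), reference))   # False
```

The buggy loop still runs without error and still produces plausible-looking numbers (its first entry is just $\delta_0 = 0.09$), which is why this mistake tends to survive until someone checks it against the closed form.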

PyTorch: The Real PPO Loss

Now the production shape. Two transformers — one with a vocab-sized output head (the policy), one with a scalar head (the critic). Identical layer counts, identical hidden sizes, identical attention. The only architectural difference is the final linear layer.

ppo_transformer_critic.py
Policy and critic are identical transformers up to one head

The TransformerBlock is shared; Policy has a vocab-sized output head, Critic has a scalar head. Strip the heads and they are the same network. In a real LLM run, this means: same number of layers, same hidden size, same attention heads, same activation memory. Every microbatch must flow forward through both, and the critic must also backprop. Doubling the trainable params doubles everything that depends on trainable params: optimizer states, gradients, ZeRO bucket sizes, NCCL all-reduce volume.

State:
Policy params = 19,543,040 (4 layers × 256 hidden, plus an 8.2M-parameter vocab head)
Critic params = 11,351,297 (same trunk, scalar head)
The scalar head: one line that doubles your bill

nn.Linear(D, 1) is the only structural difference between policy and critic. Replacing it with nn.Linear(D, V_VOCAB, bias=False) would make them indistinguishable. This is the literal cost of advantage estimation: you pay for an entire second network so that this one final layer can produce a per-token scalar that GAE uses as a baseline. GRPO's claim is that this scalar is unnecessary if you have multiple rollouts from the same prompt: the group mean does the same variance-reduction job for free.

State:
head = nn.Linear(256, 1)
Two forward passes per training step

In a vanilla policy-gradient method you forward the policy and backward the policy. In PPO you forward policy AND critic, then backward both. For a 70B policy with a 70B critic this means: 2× FLOPs per training step, 2× activation memory (because both have to keep activations for backward), 2× pipeline/DP communication volume. At the scale of frontier RLHF this is a dominant share of the training-step cost, and it is a cost you eat for every gradient step of every epoch of every prompt.

State:
logits = shape (B, T, V): policy output
values = shape (B, T): critic output
GAE in PyTorch: identical recursion, batched over B

The reversed-time loop is the same as the NumPy version; we just vectorize over the batch dimension. The critical detail is `torch.no_grad()`: GAE is computed without gradient tracking, so the critic's values enter the advantages as constants and the critic's parameters are not part of the policy-loss graph. The critic gets its gradient only from the MSE value loss. Forgetting `no_grad` here is a classic bug: it routes the policy loss through the critic, scrambling both losses.

State:
adv = per-token advantage, one per action
returns = critic targets, same shape as adv
Advantage normalisation: a band-aid the critic forces

Production PPO almost always normalises advantages per batch (mean 0, std 1) before the clipped loss. In RLHF the critic's predictions are typically miscalibrated at the start of training, so the scale of the raw advantages can vary widely from batch to batch, and this normalisation doubles as a workaround for the critic being wrong. GRPO does an analogous step (divide by the group standard deviation), but the quantity being normalised already has a meaningful baseline (the group mean), so the normalisation is cleaner.

State:
adv (post-normalise): mean ≈ 0, std ≈ 1
Joint loss, but with separate gradients in practice

We add the losses for brevity, but in real PPO RLHF the policy and critic are usually updated by separate optimizers, often at different learning rates, sometimes on different microbatches. The reason: the critic has to track a moving target (the changing policy generates a changing return distribution), so its learning rate has to be tuned separately from the policy's. Coupling them through one optimizer makes that asymmetry hard to express cleanly, yet another piece of engineering complexity that disappears in GRPO.

State:
loss = policy_loss + 0.5 · value_loss
import torch
import torch.nn as nn
import torch.nn.functional as F

# The real PPO setup. We use a 4-layer transformer "policy" and an identical
# "critic" -- in production LLM RLHF the critic is a full copy of the LM with
# a scalar value head, which is the entire reason this section exists.

torch.manual_seed(0)
D, H, V_VOCAB, N_LAYERS = 256, 4, 32000, 4

class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(D, H, batch_first=True)
        self.mlp  = nn.Sequential(nn.Linear(D, 4*D), nn.GELU(), nn.Linear(4*D, D))
        self.n1, self.n2 = nn.LayerNorm(D), nn.LayerNorm(D)
    def forward(self, x):
        T = x.size(1)
        # Causal mask (True = may not attend): position t only sees tokens <= t.
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.n1(x)
        x = x + self.attn(h, h, h, attn_mask=mask, need_weights=False)[0]
        return x + self.mlp(self.n2(x))

class Policy(nn.Module):
    """The actor: outputs a distribution over the next token."""
    def __init__(self):
        super().__init__()
        self.embed  = nn.Embedding(V_VOCAB, D)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
        self.head   = nn.Linear(D, V_VOCAB, bias=False)
    def forward(self, tokens):                          # (B, T)
        h = self.embed(tokens)                          # (B, T, D)
        for blk in self.blocks: h = blk(h)
        return self.head(h)                             # (B, T, V) logits

class Critic(nn.Module):
    """The value function: scalar per token. SAME ARCHITECTURE as Policy."""
    def __init__(self):
        super().__init__()
        self.embed  = nn.Embedding(V_VOCAB, D)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
        self.head   = nn.Linear(D, 1)                   # <-- only difference
    def forward(self, tokens):                          # (B, T)
        h = self.embed(tokens)
        for blk in self.blocks: h = blk(h)
        return self.head(h).squeeze(-1)                 # (B, T) values

policy, critic = Policy(), Critic()
p_n = sum(p.numel() for p in policy.parameters())
c_n = sum(p.numel() for p in critic.parameters())
print(f"policy params: {p_n:>12,}")
print(f"critic params: {c_n:>12,}")
print(f"critic / policy = {c_n / p_n:.3f}")

# --- One PPO step on a tiny batch ------------------------------------------
B, T = 2, 16
tokens  = torch.randint(0, V_VOCAB, (B, T))            # stands in for sampled sequences
actions = tokens[:, 1:]                                # a_t = token t+1, chosen in state s_t
N = T - 1                                              # actions per sequence
rewards = torch.zeros(B, N)
rewards[:, -1] = torch.tensor([0.8, -0.3])             # terminal reward
gamma, lam, eps = 1.0, 0.95, 0.2

def action_logp(logits):
    """Log-prob of each action taken; logits at position t score token t+1."""
    logp = F.log_softmax(logits[:, :-1], dim=-1)       # (B, N, V)
    return logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)   # (B, N)

# 0) Rollout-time log-probs, stored without gradient. On the first inner-loop
#    step they equal the new log-probs (ratio = 1); they diverge as theta moves.
with torch.no_grad():
    old_logp = action_logp(policy(tokens))

# 1) Forward both networks. THIS IS THE EXPENSIVE STEP IN A REAL RUN.
logits = policy(tokens)                                # (B, T, V)
values = critic(tokens)[:, :-1]                        # (B, N): V(s_t) per action

# 2) New log-probs of the actions actually taken.
logp_new = action_logp(logits)                         # (B, N)

# 3) GAE -- vectorised over the batch, backwards over time.
with torch.no_grad():
    v_next = torch.cat([values[:, 1:], torch.zeros(B, 1)], dim=1)
    deltas = rewards + gamma * v_next - values
    adv = torch.zeros_like(deltas)
    running = torch.zeros(B)
    for t in reversed(range(N)):
        running = deltas[:, t] + gamma * lam * running
        adv[:, t] = running
    returns = adv + values
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)     # per-batch normalisation

# 4) PPO clipped surrogate loss.
ratio = (logp_new - old_logp).exp()
loss_unclipped = ratio * adv
loss_clipped   = ratio.clamp(1 - eps, 1 + eps) * adv
policy_loss    = -torch.min(loss_unclipped, loss_clipped).mean()

# 5) Critic MSE loss against GAE returns.
value_loss     = F.mse_loss(values, returns.detach())

# 6) JOINT loss -- but in production these go to two separate optimizers.
loss = policy_loss + 0.5 * value_loss
loss.backward()
print(f"policy_loss = {policy_loss.item():.4f}")
print(f"value_loss  = {value_loss.item():.4f}")

Pay attention to the two forward passes through identical transformer trunks. In the toy version (4 layers, 256 hidden, batch of 2 × 16 tokens) the wall-clock difference is invisible. At the scale of a 70B policy, those two lines dominate the cost of each training step, and they are unavoidable in PPO.
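The toy script above sums the two losses and calls `backward()` once. The two-optimizer layout described in the annotations can be sketched as follows, under stated assumptions: tiny linear layers stand in for the two transformers, the losses are placeholders rather than PPO's, and the learning rates are hypothetical values, not recommendations. The point it makes concrete is that the critic's optimizer holds its own AdamW moment buffers, two extra floats per critic parameter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D = 32
policy = nn.Linear(D, 100)   # stand-in for the policy (100-token toy vocabulary)
critic = nn.Linear(D, 1)     # stand-in for the critic (scalar value head)

# Separate optimizers, so each network can get its own learning rate.
# These learning rates are placeholders, not recommended values.
policy_opt = torch.optim.AdamW(policy.parameters(), lr=1e-6)
critic_opt = torch.optim.AdamW(critic.parameters(), lr=5e-6)

h = torch.randn(8, D)                     # fake hidden states
actions = torch.randint(0, 100, (8,))     # fake sampled tokens
adv = torch.randn(8)                      # fake advantages
returns = torch.randn(8)                  # fake value targets

# Placeholder losses; a real run would use the clipped surrogate and GAE returns.
policy_loss = -(torch.log_softmax(policy(h), dim=-1)[torch.arange(8), actions] * adv).mean()
value_loss = (critic(h).squeeze(-1) - returns).pow(2).mean()

policy_opt.zero_grad()
critic_opt.zero_grad()
(policy_loss + 0.5 * value_loss).backward()   # disjoint parameters: each loss
policy_opt.step()                             # only reaches its own network
critic_opt.step()

def adam_moment_floats(opt):
    """Elements held in AdamW's first- and second-moment buffers."""
    return sum(s["exp_avg"].numel() + s["exp_avg_sq"].numel() for s in opt.state.values())

n_critic = sum(p.numel() for p in critic.parameters())
print(n_critic, adam_moment_floats(critic_opt))   # 33 66
```

In fp32 those two buffers are the 4 + 4 bytes per parameter in the memory ledger, and a full-size critic carries a complete set of its own.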

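Sanity-check yourself. Print the parameter counts of policy and critic. In this toy they are far apart, about 19.5M versus 11.4M (a ratio of roughly 0.58), because the untied 32,000 × 256 vocab head (8.2M parameters) is a large fraction of a 4-layer, 256-wide model. At frontier scale the picture flips: for a 70B model with hidden size 8,192 and a 32,000-token vocabulary, the vocab head is about 262M parameters, well under 1% of the total, so the critic and policy counts agree to within a fraction of a percent. That near-equality is the bottleneck, expressed in a single number.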
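The arithmetic behind those two regimes can be checked without instantiating any model. Below is a small sketch (the name `toy_param_count` is ours) that counts parameters for the exact block design used in ppo_transformer_critic.py, then applies the same formula to a hypothetical 8,192-wide, 80-layer stack. That stack comes out near 65B rather than 70B because this toy block design differs from production LLMs. It is meant to show the vocab head's share, not to model any specific LLM.

```python
def toy_param_count(d, n_layers, vocab, scalar_head):
    """Parameters of the Policy / Critic classes in ppo_transformer_critic.py:
    token embedding + n_layers blocks + output head."""
    attn = 4 * d * d + 4 * d      # MultiheadAttention: in_proj (3d*d + 3d) + out_proj (d*d + d)
    mlp = 8 * d * d + 5 * d       # Linear(d, 4d) + Linear(4d, d), with biases
    norms = 4 * d                 # two LayerNorms, weight + bias each
    head = d + 1 if scalar_head else vocab * d   # Linear(d, 1) vs Linear(d, vocab, bias=False)
    return vocab * d + n_layers * (attn + mlp + norms) + head

toy_policy = toy_param_count(256, 4, 32000, scalar_head=False)
toy_critic = toy_param_count(256, 4, 32000, scalar_head=True)
print(f"toy: {toy_policy:,} vs {toy_critic:,}  ratio {toy_critic / toy_policy:.3f}")

# Hypothetical scale-up with the same block design (hidden 8192, 80 layers).
big_policy = toy_param_count(8192, 80, 32000, scalar_head=False)
big_critic = toy_param_count(8192, 80, 32000, scalar_head=True)
print(f"big: {big_policy / 1e9:.1f}B vs {big_critic / 1e9:.1f}B  ratio {big_critic / big_policy:.4f}")
```

The toy counts match what the PyTorch script prints. At the larger size the critic is about 99.6% of the policy: the head that makes the toy lopsided becomes a rounding error.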

At Massive Scale: The 70B PPO Memory Bill

Let us put real numbers on the page. A 70B parameter policy trained with PPO RLHF, mixed-precision AdamW, ZeRO-3 sharded across a multi-node cluster of H100s:

| Component | Bytes / param | Total (70B) |
|---|---|---|
| Policy fp16 weights | 2 | 140 GB |
| Critic fp16 weights | 2 | 140 GB |
| Reference policy fp16 (frozen) | 2 | 140 GB |
| Reward model fp16 (frozen) | 2 | 140 GB |
| AdamW state (policy): fp32 master + (m, v) + fp16 grad | 14 | 980 GB |
| AdamW state (critic): fp32 master + (m, v) + fp16 grad | 14 | 980 GB |
| Activations (batch 8, seq 4096, recompute on) | n/a | ~120 GB |
| NCCL buffers + framework overhead | n/a | ~80 GB |
| Total | | ~2.7 TB |

2.7 TB of GPU memory. An 8×H100 node provides 640 GB. ZeRO-3 shards the optimizer state, gradients, and parameters across data-parallel ranks, but activations live on every replica and frozen models are often kept replicated rather than sharded. Even with perfect sharding, 2.7 TB at 80 GB per H100 means at least 34 GPUs (five 8-GPU nodes) just to hold this state, and 1,120 GB of it, about 41%, is critic state that exists for the sole purpose of producing a scalar baseline.

Now look at the same ledger for GRPO. The critic row disappears (140 GB saved). The critic's AdamW state disappears (980 GB saved). The total drops to ~1.6 TB — a 40% reduction, and it is the cleanest 40% you will ever cut from a training bill because it requires changing exactly one part of the algorithm and leaves the rest of the RLHF pipeline (reference policy, KL anchor, reward model, sampling code, sequence-packing utilities) completely untouched.
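The ledger arithmetic is easy to reproduce. This sketch rebuilds the 70B table from bytes per parameter, drops the two critic rows for GRPO, and computes a lower bound on 80 GB GPUs. The activation and overhead figures are copied from the table rather than derived, and the GPU bound assumes perfect sharding with no fragmentation, so treat it as a floor, not a deployment plan.

```python
import math

GB = 1e9

def ledger_gb(n_params, with_critic):
    # Bytes per parameter from the 70B table: fp16 weights = 2 bytes;
    # fp32 master + Adam m + Adam v + fp16 grad = 4 + 4 + 4 + 2 = 14 bytes.
    rows = {
        "policy weights": 2,
        "reference weights (frozen)": 2,
        "reward model weights (frozen)": 2,
        "policy optimizer state + grads": 14,
    }
    if with_critic:
        rows["critic weights"] = 2
        rows["critic optimizer state + grads"] = 14
    ledger = {name: b * n_params / GB for name, b in rows.items()}
    # Table estimates, kept identical for both algorithms; in reality removing
    # the critic also removes its activations, so GRPO's figure is conservative.
    ledger["activations"] = 120.0
    ledger["NCCL + framework overhead"] = 80.0
    return ledger

def min_gpus(total_gb, gpu_gb=80):
    # Lower bound only: assumes perfect sharding with zero fragmentation.
    return math.ceil(total_gb / gpu_gb)

ppo_total = sum(ledger_gb(70e9, with_critic=True).values())
grpo_total = sum(ledger_gb(70e9, with_critic=False).values())
saving = 1 - grpo_total / ppo_total
print(f"PPO : {ppo_total:.0f} GB, >= {min_gpus(ppo_total)} x 80 GB GPUs")
print(f"GRPO: {grpo_total:.0f} GB, >= {min_gpus(grpo_total)} x 80 GB GPUs")
print(f"saving: {saving:.0%}")
```

The two totals come out at 2,720 GB and 1,600 GB, which is where the ~40% reduction (41% to be precise) comes from.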

  1. Memory. ~40% lower per-replica footprint, allowing larger batches or smaller GPU pools.
  2. Compute. ~50% fewer FLOPs in the update phase of each step (no critic forward, no critic backward); generation and the frozen reference and reward-model passes are unchanged.
  3. Communication. ~50% less gradient all-reduce volume, which is often the latency bottleneck on multi-node training.
  4. Engineering. One fewer model to checkpoint, monitor, debug, and tune. Eliminates the value-learning-rate hyperparameter, which was a perennial source of instability.
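The compute saving can be checked with back-of-envelope accounting. This sketch uses the common approximation that a forward pass costs about 2N FLOPs per token and a backward pass about 4N for an N-parameter model. The other assumptions are hypothetical simplifications: every model has N parameters, there is one optimization epoch per rollout, and generation costs one policy forward per token.

```python
# Rough per-token FLOPs for one RLHF step, using forward ~ 2N and backward ~ 4N
# FLOPs per token for an N-parameter model. Simplifying assumptions: all models
# share the same N, one update epoch per rollout, generation = one forward/token.

def step_flops_per_token(n_params, with_critic):
    fwd, bwd = 2 * n_params, 4 * n_params
    rollout = {
        "generate (policy fwd)": fwd,
        "reference fwd": fwd,
        "reward model fwd": fwd,
    }
    update = {"policy fwd+bwd": fwd + bwd}
    if with_critic:
        update["critic fwd+bwd"] = fwd + bwd
    return sum(rollout.values()), sum(update.values())

N = 70e9
ppo_roll, ppo_upd = step_flops_per_token(N, with_critic=True)
grpo_roll, grpo_upd = step_flops_per_token(N, with_critic=False)
update_saving = 1 - grpo_upd / ppo_upd
step_saving = 1 - (grpo_roll + grpo_upd) / (ppo_roll + ppo_upd)
print(f"update-phase saving: {update_saving:.0%}")
print(f"whole-step saving:   {step_saving:.0%}")
```

Under these assumptions the update phase halves (12N down to 6N per token) while the whole step drops by about a third (18N down to 12N), because generation and frozen-model scoring are untouched. This is per-token accounting and ignores how many responses are sampled per prompt.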

Engineering Reality: Three Failure Modes

The memory bill alone is enough to motivate GRPO, but the critic causes three more practical headaches that anyone who has tried to run PPO RLHF at scale will recognize.

1. The critic lags the policy

The critic is trained to predict the value function of the current policy, but the policy changes every step. If the policy improves faster than the critic can keep up, the critic's predictions become systematically wrong, producing biased advantages that push the policy in the wrong direction. The classic symptom: training loss looks fine for several hundred steps, then the reward score suddenly collapses. A common mitigation is to give the critic a higher learning rate than the policy, but the right ratio is dataset-dependent and is one of those hyperparameters you can only find by burning compute. GRPO does not have this knob because it does not have a critic.
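The lag mechanism shows up even in a deliberately tiny model. In this toy (an assumption for illustration, not an LLM critic), a single scalar "critic" takes SGD steps on a squared error toward a target value that drifts by a fixed amount each step, standing in for an improving policy. With learning rate η and drift d per step, the tracking error obeys e ← (1 − η)e + d and settles at d/η, so a faster-moving policy or a slower critic leaves a larger, persistent bias in every advantage computed against it.

```python
def critic_lag(drift, lr, steps=500, v0=0.0, c0=0.0):
    """Toy scalar critic tracking a linearly drifting value.

    Returns the final tracking error (true value minus critic prediction).
    Hypothetical illustration: real critics are networks over token states.
    """
    value, critic = v0, c0
    for _ in range(steps):
        critic += lr * (value - critic)  # one SGD step on 0.5 * (value - critic) ** 2
        value += drift                   # the policy improves, so its true value moves
    return value - critic

for lr in (0.1, 0.5):
    print(f"lr={lr}: steady-state lag ~ {critic_lag(0.01, lr):.4f}")
```

Raising the critic's learning rate from 0.1 to 0.5 shrinks the steady-state lag from 0.1 to 0.02 in this toy, which is the intuition behind the higher-critic-learning-rate mitigation; in a real network, a larger learning rate also brings its own instability.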

2. The critic saturates on early-trajectory tokens

For LLM RLHF, the first few tokens of every response are nearly deterministic given the prompt (think: "Sure, I can help with that..."). The critic quickly learns to predict their values with high confidence, and the value loss on those tokens, and therefore their gradient, shrinks toward zero. The last few tokens, where the reward is actually assigned, have much higher entropy, but the critic has spent its capacity on the easy beginnings. The result is poorly calibrated advantages near the end of the sequence, which is exactly where they matter most. Practitioners work around this with various attention-mask tricks and per-token loss weights. GRPO sidesteps the problem entirely by using a single sequence-level advantage.
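To make the per-token versus sequence-level contrast concrete, here is a plain-Python sketch of generalized advantage estimation (GAE) for a response whose only reward arrives at the final token, next to a single sequence-level advantage copied to every token. With γ = λ = 1, the GAE advantage telescopes to R − V_t, so each token's advantage inherits that token's critic error. The function names and the `baseline` argument are illustrative assumptions; how GRPO actually chooses its baseline is the subject of Section 15.2.

```python
def gae(rewards, values, gamma=1.0, lam=1.0):
    """Per-token GAE advantages for one response.

    rewards[t] is the reward after token t; values[t] is the critic's V(s_t).
    The episode ends after the last token, so the bootstrap value there is 0.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD error at token t
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages

def sequence_level_advantage(reward, baseline, n_tokens):
    """One scalar per response, copied to every token; no critic involved.
    `baseline` is a hypothetical placeholder for whatever scalar baseline is used."""
    return [reward - baseline] * n_tokens

# Terminal-only reward on a 4-token response.
rewards = [0.0, 0.0, 0.0, 1.0]
values = [0.2, 0.4, 0.5, 0.9]  # hypothetical critic predictions
print(gae(rewards, values))    # approximately [0.8, 0.6, 0.5, 0.1], i.e. 1 - V_t
print(sequence_level_advantage(1.0, 0.5, len(rewards)))  # [0.5, 0.5, 0.5, 0.5]
```

If the critic is miscalibrated on the late tokens, only the GAE advantages move; the sequence-level version never consults per-token values at all.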

3. Reward hacking through the critic

Because the critic's predictions feed into the advantages, and therefore into the policy gradient, the policy can, in long enough training runs, find ways to exploit the critic's errors and inflate advantages without actually improving the reward. This is Goodhart's Law applied to value learning: the moment your proxy (the critic) becomes part of the optimization loop, it stops being a clean measurement. Detecting this is hard because the proxy reward keeps going up; only the true reward (held-out evaluations) reveals the deception. GRPO is not immune to reward hacking, but it removes one entire surface, the critic, from the set of things the policy can learn to exploit.

The DeepSeek R1 result in one paragraph. DeepSeek-R1-Zero showed that you can train a frontier-class reasoning model with only GRPO on verifiable rewards — no SFT cold-start, no reward model, no critic. The model learns to produce long chains of thought, self-correct, and verify its own work, all through a training loop that fits in substantially less memory than the equivalent PPO setup. The next four sections of this chapter derive GRPO in detail, catalogue its hyperparameters, design rewards for it, and implement it from scratch. The critic bottleneck is the door we walked through to get here.

In Section 15.2 we will derive GRPO formally — show how the group-relative advantage falls out of the same policy-gradient objective, prove it is a valid baseline, and trace where every piece of PPO is replaced. The math, you will see, is shorter than what we just did.
