Chapter 14

PPO: The Standard RLHF Algorithm

Reward Modeling and RLHF

The Real Problem: A Reward Is Not a Loss

By the time you reach this section you already have a reward model (§ 14.3) that takes a (prompt, response) pair and returns a scalar — a single number that says how good the response was. You also have an SFT model (§ 13) that knows how to follow chat formatting. The question of this chapter is brutally simple: how do we use the reward model to actually train the SFT model into something better?

The naive answer feels obvious. We have a scalar score; just treat it as a loss and run backprop. Maximise reward = minimise negative reward, problem solved. This is wrong, and understanding why is the entire reason PPO exists.

Look at what stands between the model's parameters and the reward. The model produces logits. We sample a token from those logits (or take an argmax, or a top-p sample — every decoding strategy is sampling-shaped). We append the sampled token to the context and sample the next one. We repeat for hundreds of tokens until we have a full response. Only then does the reward model score the response. The reward depends on the response, the response depends on which tokens were sampled, and sampling is not differentiable. There is no chain rule from the reward back to the logits, because the argmax/categorical sampling operation breaks the graph.

This is the central obstacle: we have a scalar feedback signal, but the path from parameters to that signal goes through a non-differentiable step. Cross-entropy training side-steps this by always providing the correct token at each position, so there is no sampling in the loss. RLHF cannot do that, because the very thing we want to learn is which tokens the model should pick. The whole field of policy gradients was built to solve this one problem: how do you take a gradient through a non-differentiable sampling step?

And once we have a working policy gradient, a second problem immediately appears: policy updates that are too large, or that accumulate unchecked, can drag the model off the manifold of human-likely text. A few bad updates can turn a coherent SFT model into one that emits gibberish that happens to score well on the reward model. This is the famous failure mode called reward hacking. PPO as used in RLHF takes the policy gradient idea and adds two safety brakes to make it usable at scale: PPO's clipped surrogate objective, and an RLHF-specific per-token KL penalty to the SFT model.

The three jobs PPO has to do at once

One: push the policy in the direction of higher reward, despite sampling being non-differentiable. Two: never let a single update move the policy too far from where it started, so we can take several gradient steps on the same rollout data without going off-distribution. Three: stay close enough to the SFT model that the outputs remain human-like, even if the reward model could be tricked into rewarding non-human text. Every term in the PPO loss is one of these three jobs.

Intuition: Do More of What Works

Forget gradients for a moment. Imagine you sample 100 responses to a prompt from your model. The reward model scores each one. The intuition behind policy gradient is the simplest training rule in machine learning:

Take the responses that scored well, and increase their probability. Take the responses that scored badly, and decrease their probability.

That's it. That's the algorithm. Every formula in this section is a way to make that idea precise, low-variance, and safe enough to run for thousands of steps. The math will look intimidating — ratios, clips, advantages, GAE — but every symbol is just protecting that simple rule from one specific failure mode.

A useful physical analogy: imagine the policy as the centre of mass of a swarm of probability. The reward model is a landscape with peaks (good responses) and valleys (bad ones). At each step we sample a few points from the swarm, look up the landscape height under each, and gently nudge the swarm toward the higher points. The catch: if we nudge too hard the swarm fragments and slides off the manifold of sensible English text into a region where the reward model is unreliable. PPO is the leash on the swarm.

Mental model in one line: PPO = the policy gradient of REINFORCE + a baseline (so we know what counts as ‘good’) + a clip (so a single step can't move us too far) + a KL penalty to the SFT model (so we stay in the region of human-like text). Everything else is bookkeeping.

REINFORCE: The Naive Policy Gradient

Start with the simplest version that works in theory. The model is a policy $\pi_\theta(a \mid s)$, a probability distribution over actions $a$ given a state $s$. For an LLM, the action at each timestep is the next token and the state is the context so far. We want to maximise the expected reward of trajectories sampled from $\pi_\theta$:

$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ R(\tau) \right]$$

Here $\tau = (s_0, a_0, s_1, a_1, \ldots)$ is a full trajectory and $R(\tau)$ is the scalar reward at the end. The problem we just identified, sampling being non-differentiable, looks like it blocks us from computing $\nabla_\theta J$. The famous log-derivative trick rescues us:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau} \left[ R(\tau) \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t) \right]$$

Read this slowly. The sampling step has vanished into the expectation. The only gradient that appears is the gradient of the log-probability of the chosen action, which is fully differentiable. The reward $R(\tau)$ is just a scalar weight: multiply this trajectory's log-probability gradient by $R(\tau)$. If $R$ is large and positive, we strongly increase the log-probability of the actions we took. If it is negative, we decrease it. That is exactly the ‘do more of what works’ rule, written as a gradient.

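The proof of the log-derivative identity is short and worth knowing. $\nabla_\theta p_\theta = p_\theta \cdot \nabla_\theta \log p_\theta$ is the chain rule applied to $\log p_\theta$. Substituting it into the gradient of the expectation turns $\nabla_\theta \sum_\tau p_\theta(\tau) R(\tau)$ into an expectation under $p_\theta$, which can be estimated from samples. For a trajectory, the initial-state and transition probabilities do not depend on $\theta$, so $\nabla_\theta \log p_\theta(\tau)$ reduces to the sum of per-step terms $\nabla_\theta \log \pi_\theta(a_t \mid s_t)$. Every modern policy-gradient algorithm (REINFORCE, A2C, TRPO, PPO, GRPO) starts from this same identity.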
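Written out over a discrete set of trajectories, the derivation takes three lines. The symbols $\rho$ (initial-state distribution) and $P$ (transition probabilities) are introduced here only for the derivation. For an LLM, $\rho$ is the prompt distribution and $P$ deterministically appends the sampled token. Neither depends on $\theta$, and neither does $R(\tau)$, so only the policy terms survive the gradient:

$$
\begin{aligned}
\nabla_\theta J(\theta)
  &= \nabla_\theta \sum_{\tau} p_\theta(\tau)\, R(\tau)
   = \sum_{\tau} p_\theta(\tau)\, \nabla_\theta \log p_\theta(\tau)\, R(\tau)
   = \mathbb{E}_{\tau \sim p_\theta}\big[ R(\tau)\, \nabla_\theta \log p_\theta(\tau) \big], \\
\log p_\theta(\tau)
  &= \log \rho(s_0) + \sum_{t=0}^{T} \Big[ \log \pi_\theta(a_t \mid s_t) + \log P(s_{t+1} \mid s_t, a_t) \Big], \\
\nabla_\theta \log p_\theta(\tau)
  &= \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t).
\end{aligned}
$$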

In practice we never have the expectation — we have a finite batch of sampled trajectories. The Monte Carlo estimator is:

$$\hat{g} = \frac{1}{N} \sum_{i=1}^{N} R(\tau_i) \sum_{t} \nabla_\theta \log \pi_\theta(a_t^i \mid s_t^i)$$

This is REINFORCE. It is unbiased, but its variance is high. Plain REINFORCE with the raw sequence reward and no baseline is usually impractical for training a large LLM. The ideas that follow attack this from different angles. Baselines and advantages cut variance without adding bias. GAE trades a little bias for further variance reduction. Clipping keeps each update stable.
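The estimator can be checked on a toy problem. The sketch below assumes a hypothetical one-step "bandit" policy: a single categorical over three actions, with a fixed reward per action. In this setting the exact gradient $\pi_k (R_k - \mathbb{E}_\pi[R])$ is available in closed form for comparison. The function names are illustrative.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max()          # stabilise before exponentiating
    e = np.exp(z)
    return e / e.sum()


def reinforce_estimate(logits, rewards, n_samples, rng):
    """Monte Carlo REINFORCE gradient w.r.t. the logits of a one-step policy."""
    probs = softmax(logits)
    actions = rng.choice(len(logits), size=n_samples, p=probs)
    # grad of log softmax(logits)[a] w.r.t. the logits is onehot(a) - probs
    score = np.eye(len(logits))[actions] - probs
    # weight each sample's score by its reward R(tau), then average
    return (rewards[actions][:, None] * score).mean(axis=0)


def exact_gradient(logits, rewards):
    # J = sum_a pi(a) R(a)  =>  dJ/dlogit_k = pi_k * (R_k - E_pi[R])
    probs = softmax(logits)
    return probs * (rewards - probs @ rewards)


rng = np.random.default_rng(0)
logits = np.array([0.5, 0.0, -0.5])
rewards = np.array([1.0, 0.0, 2.0])

g_hat = reinforce_estimate(logits, rewards, 200_000, rng)
g_true = exact_gradient(logits, rewards)
print("REINFORCE estimate:", g_hat.round(3))
print("exact gradient:    ", g_true.round(3))
```

With 200,000 samples the two agree closely. With the handful of samples in a real minibatch, the estimate is far noisier, which is the variance problem described above.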

Baselines, Advantage, and GAE

The first variance reduction is the cleanest. Subtract a baseline $b(s)$ from the reward inside the sum:

$$\hat{g} = \frac{1}{N} \sum_i \sum_t \left( R(\tau_i) - b(s_t^i) \right) \nabla_\theta \log \pi_\theta(a_t^i \mid s_t^i)$$

Subtracting $b(s)$ does not bias the estimator. The proof rests on $\mathbb{E}_a[\nabla_\theta \log \pi_\theta(a \mid s)] = 0$, a direct consequence of the policy's probabilities summing to one. A well-chosen baseline does, however, dramatically reduce variance. The standard choice, close to the variance-minimising one, is the expected reward under the current policy: the value function $V^\pi(s) = \mathbb{E}_\pi[R \mid s_t = s]$. We approximate $V$ with a small value head, a second network (often a single linear layer on top of the shared transformer backbone) trained alongside the policy with mean-squared error against the observed returns. The result of subtracting the value baseline is called the advantage:

$$A(s_t, a_t) = R(\tau) - V(s_t)$$

Read it as “how much better than expected this action turned out to be at this state”. A positive advantage says ‘better than average; increase the probability of this action’. A negative advantage says ‘worse than average; decrease it’. The name is precise: the advantage measures the marginal value of taking $a_t$ compared with the policy's average behaviour at $s_t$.
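The variance claim can be made concrete with the same kind of toy one-step policy. Here the rewards share a large constant offset, an assumed, illustrative setup. In a one-step problem, $b = \mathbb{E}_\pi[R]$ equals $V$. Using it as the baseline leaves the average gradient unchanged but shrinks the spread of the per-sample estimates dramatically:

```python
import numpy as np

rng = np.random.default_rng(1)
logits = np.array([0.5, 0.0, -0.5])
rewards = np.array([10.0, 9.0, 11.0])       # every reward shares a large offset

z = np.exp(logits - logits.max())
probs = z / z.sum()
baseline = probs @ rewards                  # b = V = E_pi[R] for this one-step toy

actions = rng.choice(3, size=100_000, p=probs)
score = np.eye(3)[actions] - probs          # per-sample grad of log pi(a) w.r.t. logits

g_plain = rewards[actions][:, None] * score                 # R * grad log pi
g_base = (rewards[actions] - baseline)[:, None] * score     # (R - b) * grad log pi

print("mean, no baseline:  ", g_plain.mean(axis=0).round(3))
print("mean, with baseline:", g_base.mean(axis=0).round(3))
print("variance, no baseline:  ", g_plain.var(axis=0).sum().round(3))
print("variance, with baseline:", g_base.var(axis=0).sum().round(3))
```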

GAE: Generalised Advantage Estimation

For long sequences, a single end-of-episode reward gives every timestep the same return $R(\tau)$, so the only per-token differences come from the baseline. This is high-variance, and it ignores the fact that different tokens contributed differently to the outcome. Generalised Advantage Estimation (GAE) (Schulman 2015) interpolates between two extremes:

  1. Pure Monte Carlo: use the full return $A_t = \sum_{l \geq t} \gamma^{l-t} r_l - V(s_t)$. Low bias, high variance.
  2. One-step TD: use $A_t = r_t + \gamma V(s_{t+1}) - V(s_t)$. High bias (it depends on the value head being accurate), low variance.

GAE blends them with an exponentially weighted sum of TD errors, controlled by $\lambda$:

$$\hat{A}_t^{\text{GAE}} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}, \quad \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

With $\lambda = 0$ we recover one-step TD; with $\lambda = 1$ we recover Monte Carlo; intermediate values trade bias against variance. For language tasks, common open-source RLHF setups use $\gamma = 1.0$ and $\lambda = 0.95$. (The original PPO paper, which targeted game and robotics benchmarks, used $\gamma = 0.99$.) The reward-model score is placed on the final (EOS) token. Apart from the per-token KL penalty introduced later in this section, the per-token reward is otherwise zero. We will use these defaults for the rest of the chapter.
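The GAE sum is usually computed with the backward recursion $\hat{A}_t = \delta_t + \gamma\lambda \hat{A}_{t+1}$, which follows directly from the definition. Below is a minimal NumPy sketch for one finished response. It assumes $V(s_T) = 0$ after the last token and uses $\hat{A}_t + V(s_t)$ as the value targets, a common choice. The function name `gae` is illustrative.

```python
import numpy as np


def gae(rewards, values, gamma=1.0, lam=0.95):
    """GAE for one finished episode.

    rewards: r_0..r_{T-1}; values: V(s_0)..V(s_{T-1}).
    Assumes the episode terminates after step T-1, so V(s_T) = 0.
    Returns (advantages, returns) with returns = advantages + values,
    used here as the value-head regression target.
    """
    T = len(rewards)
    advantages = np.zeros(T)
    last = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_value - values[t]   # one-step TD error
        last = delta + gamma * lam * last                     # A_t = delta_t + gamma*lam*A_{t+1}
        advantages[t] = last
    return advantages, advantages + values


# RLHF-style toy: the reward-model score arrives only on the final token.
rewards = np.array([0.0, 0.0, 0.0, 1.0])
values = np.array([0.2, 0.4, 0.5, 0.7])
adv, ret = gae(rewards, values)
print(adv.round(4), ret.round(4))
```

Setting `lam=1.0` reproduces the Monte Carlo advantages (return minus value). Setting `lam=0.0` reproduces the one-step TD errors.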

Why Big Policy Steps Are Catastrophic

Suppose we have a low-variance advantage estimate. Why not just take a big gradient step in its direction? Two reasons, and PPO addresses both.

Reason one: the rollout becomes stale. The advantages we computed were under the OLD policy. If a single gradient step changes the policy a lot, the next minibatch we take from the same rollout is graded with advantages that no longer match the current policy's behaviour. The objective can then keep rewarding actions the new policy already takes more often, a positive feedback loop that can diverge. This is the classical off-policy (importance-sampling) problem. To take $K$ gradient steps on data sampled by the OLD policy, each step must be corrected by the probability ratio $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$.
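The role of the ratio is easiest to see in isolation. In the sketch below (toy distributions, illustrative only), actions are sampled from an "old" policy. The plain average estimates the expectation under the old policy. Weighting each sample by $r = \pi_{\text{new}} / \pi_{\text{old}}$ recovers the expectation under the new one:

```python
import numpy as np

rng = np.random.default_rng(2)
p_old = np.array([0.5, 0.3, 0.2])        # behaviour (old) policy
p_new = np.array([0.3, 0.3, 0.4])        # current policy after some updates
f = np.array([1.0, -2.0, 3.0])           # any per-action quantity, e.g. an advantage

actions = rng.choice(3, size=200_000, p=p_old)   # data sampled by the OLD policy
ratio = p_new[actions] / p_old[actions]          # r = pi_new / pi_old per sample

naive = f[actions].mean()                # estimates E_old[f], not E_new[f]
corrected = (ratio * f[actions]).mean()  # importance-weighted estimate of E_new[f]
exact_new = p_new @ f                    # exact E_new[f] = 0.9
print(round(naive, 3), round(corrected, 3), round(exact_new, 3))
```

The further `p_new` drifts from `p_old`, the larger and more variable the ratios become. That growing variance is one reason PPO keeps the ratio near 1.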

Reason two: the reward model is only locally trustworthy. It was trained on responses from the SFT distribution. Push the policy 50 nats of KL away from that distribution and the reward model's scores on that out-of-distribution text become essentially arbitrary. The policy will then gleefully optimise for those arbitrary scores. This is the textbook RLHF failure. After a few thousand steps, the model emits something like “Sure! I'd be happy to help with that great question!!!” on every prompt, because the reward model gave a slight upvote to enthusiasm and the policy exploited it without limit.

Both problems call for the same medicine: keep each update small in policy space, not just in parameter space. TRPO (PPO's predecessor) does this with a hard constraint $D_{\text{KL}}(\pi_{\theta_{\text{old}}} \,\|\, \pi_\theta) \leq \delta$ and a conjugate-gradient solver to find the largest safe step. PPO's contribution was the observation that a much simpler trick achieves a similar effect: bound the importance ratio $r_t$ directly, then take ordinary SGD steps and let the bound do the work.

The PPO Clipped Surrogate Objective

Define the per-step probability ratio:

$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$

The importance-weighted policy gradient says we should maximise $r_t \cdot A_t$ in expectation: a bigger ratio on good actions, a smaller one on bad. PPO's clipped surrogate is one line:

$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[ \min\left( r_t A_t, \; \text{clip}(r_t, 1-\varepsilon, 1+\varepsilon) \, A_t \right) \right]$$

Read it piece by piece. The first argument of the min is the unclipped importance-weighted advantage. The second is the same quantity with $r_t$ first restricted to the interval $[1-\varepsilon, 1+\varepsilon]$ (typically $\varepsilon = 0.2$). The outer min takes the smaller of the two, which is the pessimistic choice. The clip defines the trust region, and the min makes the objective a pessimistic lower bound. Together they give the clip its asymmetry.

The asymmetry is the cleverest part. Consider a positive advantage:

  1. If $r_t \leq 1 + \varepsilon$, the min equals $r_t A_t$. Inside the band both terms coincide; below it the unclipped term is the smaller one. Normal gradient: increase $\pi_\theta(a_t)$.
  2. If $r_t > 1 + \varepsilon$ (we have already increased the probability beyond the clip), the unclipped term is $r_t A_t$ (large and positive) and the clipped term is $(1+\varepsilon) A_t$ (capped). The min is the capped value, which is constant in $\theta$, so the gradient is zero. The policy stops being pushed further on this action even though the advantage is positive.

Now consider a negative advantage:

  1. If $r_t \geq 1 - \varepsilon$, the min equals $r_t A_t$. Inside the band both terms coincide; above it the unclipped term is the more negative one. Normal gradient: decrease $\pi_\theta(a_t)$.
  2. If $r_t < 1 - \varepsilon$ (we have already cut the probability below the clip), the unclipped term $r_t A_t$ is negative but small in magnitude, while the clipped term $(1-\varepsilon) A_t$ is more negative. The min picks the clipped value, which is constant in $\theta$, so the gradient is zero. The policy stops being pushed further down on this action. The clip caps how much the objective can gain from running away from a bad action, just as it caps the gain from chasing a good one.

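This is the asymmetry: the clip removes the gradient only when the ratio has already moved past the trust region in the direction the advantage favours ($r_t > 1 + \varepsilon$ for $A_t > 0$, $r_t < 1 - \varepsilon$ for $A_t < 0$). Sometimes the ratio moves the wrong way instead: a good action becomes less likely, or a bad action becomes more likely. Then the unclipped term is the smaller one, the min selects it, and the full gradient flows to undo the move. That is precisely the safety property we wanted. The objective never rewards a large move beyond the trust region, but it always penalises a large move that hurt.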
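The four cases can be checked mechanically. The sketch below is an illustrative helper, not any library's API. It evaluates the clipped objective for a single (s, a) together with its slope with respect to the ratio $r$; the gradient with respect to $\theta$ is this slope times $\partial r / \partial \theta$. A zero slope marks the no-gradient zone.

```python
import numpy as np


def clipped_surrogate(r, adv, eps=0.2):
    """PPO clipped objective for one (s, a) and its slope with respect to r."""
    unclipped = r * adv
    clipped = float(np.clip(r, 1.0 - eps, 1.0 + eps)) * adv
    objective = min(unclipped, clipped)
    # The unclipped branch has slope adv in r. Outside the band the clipped
    # branch is constant in r, so when it wins the min the slope is 0.
    slope = adv if unclipped <= clipped else 0.0
    return objective, slope


for adv in (+1.0, -1.0):
    for r in (0.5, 1.0, 1.5):
        obj, slope = clipped_surrogate(r, adv)
        print(f"A={adv:+.0f}  r={r:.1f}  objective={obj:+.2f}  d/dr={slope:+.0f}")
```

For $A = +1$ the slope is zero only at $r = 1.5$. For $A = -1$ it is zero only at $r = 0.5$. In both cases the dead zone sits where the ratio has already moved past the band in the direction the advantage favours.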

Figure: PPO clip visualiser (interactive).

In the visualiser, drag the sliders. With $A = +1$ and $\varepsilon = 0.2$, watch the orange line. It rises with the ratio up to 1.2, then goes flat for the rest of the domain; that flat region is the no-gradient zone. Flip the advantage to negative and the flat region moves to the LEFT of the band (ratios below 0.8), while the right side of the curve becomes the active gradient zone. Shrink $\varepsilon$ to 0.05 and the trust region narrows to almost nothing, so each update can barely change the policy. Push it to 0.5 and the clip rarely binds. The value $\varepsilon = 0.2$ performed best among the settings compared in the original PPO paper and is a common default in open-source RLHF libraries.

What the clip does NOT do

The clip is not a constraint on parameters and not a constraint on gradient magnitude. It acts on the importance ratio, which is a function of the probabilities. It only removes the incentive to push the ratio further and does not hard-bound it, so ratios can still end up outside the band, as in the walkthrough that follows. Two networks with very different parameter values can have identical ratios: the clip cares about behaviour, not weights. This is also why the clip cannot replace gradient-norm clipping (max_norm=1.0 is a common setting and remains standard practice). They protect against different failure modes.

Manual Numerical Walkthrough: One PPO Update

Take a rollout of four timesteps from a three-action policy. The old policy assigned the following log-probabilities to the actions it actually sampled, and GAE gave us the following advantages:

| $t$ | action $a_t$ | old $\log \pi(a_t \mid s_t)$ | advantage $A_t$ (raw) |
|---|---|---|---|
| 0 | 0 | −1.20 | +0.80 |
| 1 | 2 | −0.51 | +1.50 |
| 2 | 1 | −1.61 | −0.40 |
| 3 | 0 | −0.92 | −1.20 |

Step 1 — normalise the advantages. The mean is $(0.80 + 1.50 - 0.40 - 1.20) / 4 = 0.175$. We use the population standard deviation (ddof=0). This is NumPy's default; note that torch.std defaults to the unbiased estimator instead. It is $\sqrt{((0.80-0.175)^2 + (1.50-0.175)^2 + (-0.40-0.175)^2 + (-1.20-0.175)^2)/4} = \sqrt{1.091875} \approx 1.045$. Normalised: $\tilde A \approx [+0.598, +1.268, -0.550, -1.316]$.

Step 2 — compute the new log-probs. After one gradient step, suppose the new policy assigns log-probs $[-0.48, -1.28, -1.42, -0.40]$ to those same actions. These are hypothetical values chosen for the walkthrough. The plain-Python listing later in this section hard-codes different logits, so its numbers differ.

Step 3 — compute ratios, $r_t = \exp(\log \pi_{\text{new}} - \log \pi_{\text{old}})$:

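| $t$ | $\log \pi_{\text{new}}$ | $\log \pi_{\text{old}}$ | log ratio | ratio $r_t$ |
|---|---|---|---|---|
| 0 | −0.48 | −1.20 | +0.72 | 2.054 |
| 1 | −1.28 | −0.51 | −0.77 | 0.463 |
| 2 | −1.42 | −1.61 | +0.19 | 1.209 |
| 3 | −0.40 | −0.92 | +0.52 | 1.682 |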

Look at $r_0 = 2.054$ and $r_1 = 0.463$: both are well outside the trust region $[0.8, 1.2]$, and $r_2 = 1.209$ and $r_3 = 1.682$ lie outside it too. That is deliberate. The numbers will show exactly which timesteps the clip protects against and which it does not.

Step 4 — compute the two surrogates.

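| $t$ | $\tilde A_t$ | $r_t$ | surr1 $= r_t \tilde A_t$ | $\text{clip}(r_t)$ | surr2 | min |
|---|---|---|---|---|---|---|
| 0 | +0.598 | 2.054 | +1.229 | 1.200 | +0.718 | +0.718 |
| 1 | +1.268 | 0.463 | +0.587 | 0.800 | +1.014 | +0.587 |
| 2 | −0.550 | 1.209 | −0.665 | 1.200 | −0.660 | −0.665 |
| 3 | −1.316 | 1.682 | −2.213 | 1.200 | −1.579 | −2.213 |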

Read it row by row.

Row 0: positive advantage, ratio above the trust region. The clip binds. Surr2 (0.718) is smaller than surr1 (1.229), so the min picks surr2 and the gradient through $r_0$ is zero. The clip just stopped us from over-rewarding an action we have already pushed up too far.

Row 1: positive advantage, ratio below the trust region. The clip does NOT bind. Clip(0.463) = 0.8 makes surr2 (1.014) larger, so the min picks surr1 = 0.587. Gradient flows and pushes the probability of this good action back up.

Row 2: negative advantage, ratio just above $1 + \varepsilon$. The min picks surr1 = −0.665, the more negative value, because surr2 = −0.660 is smaller in magnitude. The clip does not bind, and the full gradient pushes the probability of this bad action down.

Row 3: negative advantage, ratio well above 1. Surr1 = −2.213 and surr2 = −1.579, so the min picks the more negative surr1 and the gradient flows fully, ignoring the clip.

Rows 2 and 3 show the asymmetry working as advertised: when the policy has drifted toward a bad action, the gradient is free to pull it back.

Step 5 — the loss. Mean of the min column: $(0.718 + 0.587 - 0.665 - 2.213) / 4 \approx -0.393$. Negate, because we maximise the objective and the loss is its negative: policy loss $\approx +0.393$. That is the value that gets backpropagated. Without the clip, the loss would have been $-(1.229 + 0.587 - 0.665 - 2.213) / 4 \approx +0.266$. The clip changes more than the number. Row 0 now contributes no gradient at all, so the model receives a different, more conservative update.
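The arithmetic in Steps 1 to 5 can be reproduced in a few lines of NumPy, using the hypothetical post-update log-probs from Step 2. The sketch also reports two fractions that are both called "clipfrac" in practice. The first is the share of ratios outside $[1-\varepsilon, 1+\varepsilon]$. The second is the share of tokens where the clipped term is actually selected.

```python
import numpy as np

old_logp = np.array([-1.20, -0.51, -1.61, -0.92])
new_logp = np.array([-0.48, -1.28, -1.42, -0.40])   # hypothetical post-update log-probs
advantages = np.array([0.80, 1.50, -0.40, -1.20])
eps = 0.2

adv = (advantages - advantages.mean()) / advantages.std()   # population std (ddof=0)
ratio = np.exp(new_logp - old_logp)
surr1 = ratio * adv
surr2 = np.clip(ratio, 1 - eps, 1 + eps) * adv
clipped_loss = -np.minimum(surr1, surr2).mean()
unclipped_loss = -surr1.mean()

outside_band = np.mean(np.abs(ratio - 1) > eps)   # ratio outside [1-eps, 1+eps]
clip_active = np.mean(surr2 < surr1)              # clipped term selected (zero gradient)

print("normalised A:", adv.round(3))
print("loss with clip:", round(clipped_loss, 3), "| without clip:", round(unclipped_loss, 3))
print("outside band:", outside_band, "| clipped term active:", clip_active)
```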

The headline number to take from this walkthrough: all four ratios lie outside the trust region, although the clipped term is actually selected (zero gradient) on only one token. Implementations differ on which of these two fractions they report as clipfrac. Either way, this is far too high for a production run; a common rule of thumb is to keep clipfrac roughly in [0.10, 0.30]. The high value here comes from deliberately taking a too-large step between ‘old’ and ‘new’ so that the walkthrough is readable. In practice you tune the learning rate and the number of PPO epochs per rollout to keep clipfrac in that band.

The Reference Model and KL Penalty

Everything above is pure PPO. The RLHF twist — the modification that turns ‘PPO for games’ into ‘PPO for LLMs’ — is the addition of a third model: the reference policy.

The reference policy $\pi_{\text{ref}}$ is a frozen copy of the SFT model (the model from § 13). It never receives gradient. At every step of PPO we compute the per-token KL divergence from the live policy to the reference and add it as a penalty:

$$L^{\text{RLHF}}(\theta) = -L^{\text{CLIP}}(\theta) + \beta_{\text{KL}} \cdot \mathbb{E}_t \left[ D_{\text{KL}}\big(\pi_\theta(\cdot \mid s_t) \,\|\, \pi_{\text{ref}}(\cdot \mid s_t)\big) \right]$$

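Why a separate KL term when PPO already has a clip? The clip limits drift relative to the rollout policy within a single update, but it cannot prevent small drifts from accumulating over thousands of steps into a model far from $\pi_{\text{ref}}$. The clip is a per-step trust region; the KL penalty is the long-term anchor. Implementations differ on where the penalty enters. Some add it to the loss as written here. Others subtract it from the per-token reward before GAE, as in the outer loop described below.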

The KL is computed token-wise. A common choice in code is the k3 estimator $\hat{D}_{\text{KL}} = (r - 1) - \log r$, evaluated on tokens sampled from $\pi_\theta$. Here $r = \pi_{\text{ref}}(a_t \mid s_t) / \pi_\theta(a_t \mid s_t)$; note the direction, reference over live policy. The k3 estimator is always non-negative. It is unbiased: its expectation under $\pi_\theta$ equals the true KL. It typically has lower variance than the naive estimator $\log(\pi_\theta / \pi_{\text{ref}})$.
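Those properties are quick to check numerically on a hypothetical three-token vocabulary at a single state. When tokens are sampled from the live policy, both the naive estimator (k1) and k3 average to the exact KL. k3 is never negative and, in this example, has much lower variance:

```python
import numpy as np

rng = np.random.default_rng(3)
pi = np.array([0.6, 0.3, 0.1])      # live policy pi_theta at one state (toy values)
ref = np.array([0.4, 0.4, 0.2])     # frozen reference pi_ref at the same state

true_kl = np.sum(pi * np.log(pi / ref))          # exact KL(pi_theta || pi_ref)

tokens = rng.choice(3, size=200_000, p=pi)       # tokens sampled from the LIVE policy
log_r = np.log(ref[tokens]) - np.log(pi[tokens]) # log r, with r = pi_ref / pi_theta
k1 = -log_r                                      # naive estimator: log(pi_theta / pi_ref)
k3 = np.expm1(log_r) - log_r                     # k3 = (r - 1) - log r, always >= 0

print(f"true KL {true_kl:.4f} | k1 mean {k1.mean():.4f} | k3 mean {k3.mean():.4f}")
print(f"k1 variance {k1.var():.4f} | k3 variance {k3.var():.4f}")
```

Flipping the ratio to $\pi_\theta / \pi_{\text{ref}}$ in this code breaks the unbiasedness. That is a quick way to see why the direction matters.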

The coefficient $\beta_{\text{KL}}$ is one of the most-tuned hyperparameters in RLHF. Too small, and the policy drifts into reward-hacked nonsense within a few thousand steps. Too large, and the policy barely moves from the SFT model, wasting the RLHF run. Open-source recipes commonly use $\beta_{\text{KL}} \approx 0.02$ to $0.1$. Some adapt it on the fly to keep the realised KL near a target value (e.g. 6 nats per response).

| KL coefficient $\beta$ | Behaviour (rough guide) |
|---|---|
| 0 | Pure PPO. The policy will eventually reward-hack. |
| 0.001–0.01 | Slow drift allowed. Use only when the reward model is very robust. |
| 0.02–0.05 | Standard range in many published recipes. |
| 0.1–0.5 | Conservative. The policy barely moves; use when the reward model is suspect. |
| > 1.0 | The policy is effectively frozen, little different from skipping RLHF. |
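The adaptive variant mentioned above can be implemented as a simple proportional controller. The sketch below follows the style of the adaptive KL penalty used in early RLHF work. The class name, the ±0.2 error clip and `k_beta` are illustrative assumptions, not a particular library's API.

```python
class AdaptiveKLController:
    """Proportional controller nudging beta so the observed KL tracks a target.

    Illustrative sketch: the class name, the +/-0.2 error clip and k_beta are
    assumptions, not a specific library's API.
    """

    def __init__(self, init_beta=0.05, target_kl=6.0, k_beta=0.1):
        self.beta = init_beta
        self.target_kl = target_kl   # e.g. 6 nats per response
        self.k_beta = k_beta

    def update(self, observed_kl):
        # Relative error, clipped so one noisy batch cannot swing beta wildly.
        error = min(max(observed_kl / self.target_kl - 1.0, -0.2), 0.2)
        self.beta *= 1.0 + self.k_beta * error
        return self.beta


too_far = AdaptiveKLController()
too_close = AdaptiveKLController()
print(too_far.update(observed_kl=12.0))   # KL above target: beta grows
print(too_close.update(observed_kl=3.0))  # KL below target: beta shrinks
```

Because each update changes $\beta$ by at most a few percent, the controller reacts to sustained drift rather than to a single noisy batch.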

The Four-Model Dance of RLHF

At this point you have all the ingredients. A single RLHF training step needs four models in memory at once, and tracking what each does is the difference between writing PPO and watching it crash:

| Model | Trainable? | Role |
|---|---|---|
| Policy $\pi_\theta$ | Yes, being trained | The model we are improving. Same architecture as the SFT model and initialised from it. Its forward pass produces logits and log-probs. |
| Value head $V_\phi$ | Yes, being trained | Small head (often one linear layer) on top of the policy's hidden states. Predicts $V(s_t)$ for the advantage baseline; trained with MSE against returns. |
| Reference $\pi_{\text{ref}}$ | No, frozen SFT model | Used to compute the KL penalty. Same architecture as the policy, but its weights never change. Forward pass only. |
| Reward model $R_\psi$ | No, frozen and separately trained | Takes a full (prompt, response) pair and outputs a scalar. Called once per rollout, not per gradient step. Typically a transformer with a scalar head, often from the same family as the policy. |

For a 70B RLHF run, memory is a central constraint: several large models must be resident at once, on top of the policy's gradients and optimiser state. Practical implementations share weights where possible. The value head often piggybacks on the policy backbone as one extra linear head. With LoRA, the reference can be the frozen base weights with the adapter switched off, so policy and reference share parameters. The reward model can also be smaller than the policy.

The orchestration follows a fixed loop, called the PPO outer loop:

  1. Rollout phase. Sample N prompts. For each prompt, generate a response with the OLD policy (a snapshot of π_θ before this rollout starts). Cache: input_ids, old_logp, ref_logp, old_values.
  2. Reward phase. Run the reward model once per (prompt, response). Combine its score with the per-token KL penalty to form a per-token reward stream r_t.
  3. Advantage phase. Run GAE on the r_t and old_values to produce advantages A_t and returns R_t.
  4. Optimisation phase. For K epochs (typically K=4), shuffle the rollout into minibatches and run the PPO step on each minibatch. Each minibatch step is the function we just analysed.
  5. Snapshot. Set π_old ← π_θ. Back to step 1.
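The reward phase (step 2) is short in code. The sketch below assumes the simple per-token KL estimate $\log \pi_{\text{old}} - \log \pi_{\text{ref}}$, computed from the cached old_logp and ref_logp. It adds the reward-model score on the final response token. The function name is hypothetical, and real implementations may use a different KL estimator or clamp it.

```python
import numpy as np


def per_token_rewards(score, old_logp, ref_logp, beta):
    """r_t for one response: -beta * KL estimate on every token, plus the
    reward-model score on the final token."""
    kl = np.asarray(old_logp, dtype=float) - np.asarray(ref_logp, dtype=float)
    rewards = -beta * kl
    rewards[-1] += score            # the RM scores the whole response once
    return rewards


r = per_token_rewards(score=1.0,
                      old_logp=[-1.0, -0.5, -2.0],
                      ref_logp=[-1.2, -0.5, -1.5],
                      beta=0.05)
print(r)
```

These rewards, together with the cached old_values, are what the advantage phase (step 3) feeds into GAE.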

Sampling in step 1 (the rollout) typically dominates wall-clock time, often 70–90% of an RLHF step. This is why production RLHF systems commonly use accelerated inference engines (vLLM, SGLang) for rollout and distributed training frameworks (for example PyTorch FSDP) for optimisation. Placing the two on separate GPU pools is a common design.

Plain Python: PPO Loss from Scratch

We re-implement the PPO clipped surrogate in pure NumPy on a toy rollout of four timesteps. No autograd, no transformer — just the loss arithmetic an RLHF trainer runs every minibatch. Swap the hard-coded logits for a real transformer's output and the code is unchanged.

PPO clipped surrogate, pure NumPy (ppo_loss_plain.py)
The rollout: what the old policy did and how well it worked

Three parallel arrays of length N = 4 are the entire rollout input the loss needs. actions are the discrete action indices that the OLD (frozen) policy sampled at four states. old_logp are the log-probabilities the OLD policy assigned to those actions: frozen scalars, and no gradient ever flows through them. advantages come from the GAE step (see Baselines, Advantage, and GAE above). For each (state, action) pair, they say how much better or worse that action was than the value baseline expected. Positive means 'do this more often'; negative means 'do this less'.

Example: actions = [0, 2, 1, 0]; old_logp = [-1.20, -0.51, -1.61, -0.92]; advantages = [+0.80, +1.50, -0.40, -1.20]
Advantage normalisation: one of the most important PPO tricks

Standardise advantages within the minibatch so that the loss scale does not depend on how large the rewards happen to be on this batch. Without it, the effective step size swings with the reward scale, and a few high-reward sequences can dominate the gradient. The +1e-8 in the denominator is the standard 'don't divide by zero' safeguard for batches where every advantage happens to be equal. Widely used implementations such as trl, OpenAI Spinning Up and cleanrl normalise advantages, some over the whole batch rather than each minibatch. Skipping it is a common reason a custom PPO loop fails to train.

Example: advantages = [+0.80, +1.50, -0.40, -1.20] → adv ≈ [+0.598, +1.268, -0.550, -1.316] (after centering and scaling by the population std)
Re-evaluate the rollout with the NEW (live) policy

PPO is an OFF-POLICY-ISH algorithm: it collects data with the OLD policy, then takes several gradient steps with the NEW policy on the same data. To do that we have to ask 'what would the CURRENT policy have done at those exact same states?'. new_logits has shape (N, vocab), where each row holds the unnormalised scores the new policy assigns to each action at that state. In a real LLM-RLHF run, this means re-running the rollout's tokens through the live model.

Example: new_logits.shape = (4, 3); each row is one state's per-action scores
Numerically stable log-softmax

Subtract the per-row max before exponentiating, then subtract the log of the normalising sum. This is the standard log-sum-exp trick. Without it, large logits (above roughly 88 in float32 or 709 in float64) would overflow exp() and produce inf/NaN. log_probs_all[i, v] is now log pi_new(action v | state i) for every state and every action.

Example: logits row [1.10, -0.20, 0.30] → log_probs row ≈ [-0.5434, -1.8434, -1.3434] (their exponentials sum to 1 by construction)
Gather log pi_new at the actions that were ACTUALLY taken

Fancy indexing: log_probs_all has shape (N, vocab), and we pick out log_probs_all[t, actions[t]] for every t, exactly one log-prob per timestep. This is the new policy's log-probability of the OLD policy's action. The shape collapses from (N, vocab) to (N,).

Example: actions[0] = 0, log_probs_all[0, 0] ≈ -0.5434 → new_logp[0] ≈ -0.5434
The probability ratio: the heart of off-policy correction

r_t = pi_new(a|s) / pi_old(a|s) = exp(log pi_new − log pi_old). We compute the difference in log space FIRST (numerically safe) and exponentiate the resulting scalar. r_t = 1 means the new policy is unchanged at this (s, a). r_t > 1 means it now assigns this action MORE probability than before, and r_t < 1 means LESS. PPO uses this ratio to decide whether to trust the surrogate objective or clip it.

Example: old_logp[0] = -1.20, new_logp[0] ≈ -0.5434 → ratio[0] = exp(-0.5434 - (-1.20)) = exp(0.6566) ≈ 1.928
The two surrogate objectives: unclipped and clipped

surr1 = r_t · A_t is the unclipped importance-weighted surrogate, the objective TRPO maximises under its KL constraint. surr2 = clip(r_t, 1-ε, 1+ε) · A_t bounds the ratio inside [0.8, 1.2] BEFORE multiplying by A. In the clipped version the ratio cannot move further than ε from 1 in either direction; it is the 'safe' objective.

Example: ratio[0] ≈ 1.928, adv[0] ≈ 0.598 → surr1[0] ≈ 1.153; clip(1.928, 0.8, 1.2) = 1.2 → surr2[0] ≈ 0.718
Take the MIN: the pessimistic bound at the core of PPO

The minimum of (surr1, surr2) is the PPO objective. The min is what makes the clip ASYMMETRIC. For a positive advantage, the objective for a single (s, a) is capped at (1+ε)·A, and pushing the ratio further has zero gradient. For a negative advantage, it is capped at (1-ε)·A: the SAME logic, flipped. The min ensures the clip only ever pulls the objective DOWN, never UP. That is the safety brake: you never reward a huge policy drift, but you always penalise one.

Example: surr1[0] ≈ 1.153, surr2[0] ≈ 0.718 → ppo[0] = min(1.153, 0.718) ≈ 0.718 (clipped: the gradient through ratio[0] is now zero)
Negate: gradient descent minimises, but we want to maximise

PyTorch optimisers (and gradient descent in general) minimise. The PPO OBJECTIVE is what we want to MAXIMISE, so the PPO LOSS is its negative. Taking the mean over timesteps (not the sum) keeps the loss scale independent of minibatch size. Forgetting the negation is a textbook bug: the policy moves against the advantage instead of with it, and the first sign is reward going DOWN steadily.

Example: ppo per step ≈ [+0.718, +1.333, -1.749, -2.252] → policy_loss = -mean ≈ +0.488
1"""
2PPO clipped surrogate objective — pure NumPy, no autograd.
3
4We compute the policy loss for one minibatch of ROLLOUT data.
5A 'rollout' is a list of (state, action, old_log_prob, advantage) tuples
6collected by the OLD policy. The new policy is what we are training.
7
8The 'policy' is a single-token categorical over a 3-action vocab so we
9can read every number. Swap in a transformer logits head and the loss
10code is byte-for-byte what PPO trainers use in production.
11"""
12
13import numpy as np
14
15# ---------------------------------------------------------------------------
16# 1. A rollout: 4 timesteps collected from the OLD policy.
17#    actions    : which action was sampled at each step
#    old_logp   : log pi_old(a|s) at each step  (frozen, no gradient)
#    advantages : how much better than the value baseline each action was
# ---------------------------------------------------------------------------

actions     = np.array([0, 2, 1, 0])
old_logp    = np.array([-1.20, -0.51, -1.61, -0.92])     # log pi_old
advantages  = np.array([+0.80, +1.50, -0.40, -1.20])     # from GAE
N           = len(actions)

# Standard PPO trick: normalise the advantages so the loss scale is stable
# across batches. Mean 0, std 1, within this minibatch only.
adv = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

# ---------------------------------------------------------------------------
# 2. The NEW policy's logits at the SAME states (gradient flows here).
#    A real PPO step re-forwards the rollout states through the live model
#    to recompute logits AFTER each gradient update. Here we hard-code
#    deliberately drifted logits, so that three of the four ratios land
#    outside [1 - EPSILON, 1 + EPSILON] and the clip is visible.
# ---------------------------------------------------------------------------

new_logits = np.array([
    [ 1.10, -0.20,  0.30],     # state 0
    [-0.40,  0.80,  1.60],     # state 1
    [ 0.10,  1.40,  0.20],     # state 2
    [ 1.30, -0.50,  0.10],     # state 3
])

# Numerically stable log-softmax: log pi_new(a|s) for every action.
z = new_logits - new_logits.max(axis=-1, keepdims=True)
log_probs_all = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

# Gather the log-prob of the action the OLD policy actually took.
new_logp = log_probs_all[np.arange(N), actions]

# ---------------------------------------------------------------------------
# 3. Probability ratio r_t = pi_new(a|s) / pi_old(a|s)
#    Subtract the log-probs first, then exponentiate once, so tiny
#    probabilities are never divided directly.
# ---------------------------------------------------------------------------

ratio = np.exp(new_logp - old_logp)

# ---------------------------------------------------------------------------
# 4. The PPO clipped surrogate.
# ---------------------------------------------------------------------------

EPSILON = 0.2

surr1 = ratio * adv                                          # unclipped
surr2 = np.clip(ratio, 1.0 - EPSILON, 1.0 + EPSILON) * adv   # clipped

# Take the MIN: this is the PPO pessimistic bound.
# For A > 0 the objective stops growing once ratio > 1 + EPSILON;
# for A < 0 it stops growing once ratio < 1 - EPSILON.
ppo_objective_per_step = np.minimum(surr1, surr2)

# The PPO LOSS is the negative of the OBJECTIVE (we MAXIMISE the objective).
policy_loss = -ppo_objective_per_step.mean()

print(f"ratios       = {ratio.round(3)}")
print(f"normalised A = {adv.round(3)}")
print(f"policy_loss  = {policy_loss:.4f}")

PyTorch: A Production-Shaped PPO Step

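Now the full thing: policy loss, clipped value loss, entropy bonus, per-token KL to a reference model, gradient clipping, and diagnostics. This mirrors the inner loop of PPO trainers such as trl, OpenRLHF, and DeepSpeed-Chat, reduced to its essentials. For 'policy', plug in a Hugging Face causal LM. For 'value_head', plug in a value model that maps (input_ids, attention_mask) to per-token values of shape (B, S). Feed the function rollout minibatches with the fields listed in its docstring and it performs one real PPO update.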

PPO step — policy clip + value clip + entropy + KL
ppo_step_torch.py
The five hyperparameters that define a PPO run

EPSILON=0.2 is the clip range, i.e. the trust-region radius. VF_COEF=0.5 weights the value loss into the joint loss; 0.5 is a common default in open-source PPO implementations. ENT_COEF=0.01 is the entropy bonus; higher values keep the policy from collapsing to a near-deterministic distribution too early. KL_COEF=0.05 is the per-token KL penalty against the reference (SFT) policy. This is the RLHF-specific addition that keeps the model from drifting away from human-likely text. MAX_GRAD=1.0 is a common clip-norm default. Raise it or drop the clip, and the run is exposed to gradient spikes that can end in NaNs.

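The PPO step signature: minibatch, three models, one optimizer

Three models appear in the signature:
- 'policy' is trainable and is the model being updated.
- 'value_head' is trainable and predicts V(s).
- 'ref_policy' is the frozen SFT model.

ref_policy is not called inside this function. Its only job is to produce ref_logp, which happens during the rollout phase, so the value arrives precomputed in the batch. At full scale these are FSDP-wrapped 70B models. At toy scale they can be tiny modules, as long as policy returns an object with .logits and value_head maps (input_ids, attention_mask) to a (B, S) tensor. The function returns a dict of scalars for logging: the per-step metrics a typical PPO dashboard plots.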

Unpacking the rollout minibatch

Every tensor has shape (B, S), where S is the prompt+response length. Crucially, old_logp, ref_logp, advantages, returns, and old_values are ALL precomputed during the rollout phase and frozen here; no gradient flows through them. action_mask is the per-token mask that separates prompt tokens (no loss) from response tokens (where the policy is being graded). Getting that mask wrong is the second most common PPO bug, after forgetting advantage normalisation.

EXAMPLE
input_ids.shape = (16, 1024); action_mask.sum() / total tokens ≈ 0.5 (half prompt, half response)
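Forward the live policy on the rollout (the only source of policy gradient)

use_cache=False is standard during training. A single full-sequence forward pass has no use for a KV cache, so skipping it avoids allocating one. Hugging Face models also turn the cache off themselves when gradient checkpointing is enabled. The (B, S, V) logits are large. With B=16, S=1024 and V=128,256, that is about 2.1 billion values, roughly 4.2 GB in bf16 per minibatch, and log_softmax materialises a second tensor of the same size. We slice [:, :-1, :] because the model at position t predicts the token at position t+1, so the last position has no successor to predict.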

EXAMPLE
logits.shape = (16, 1023, 128256); targets.shape = (16, 1023)
Per-token log-probability of the action that was actually taken

log_softmax followed by gather is the standard two-step. Gathering along the last dim with targets gives, for each (batch, time) cell, the log-prob of the token that was sampled at that position. Squeezing the last dim then collapses the shape to (B, S-1). After this line, new_logp[b, t] = log pi_new(a_t | s_t) for every token. new_logp feeds the ratio and the KL, while the entropy uses the full log_probs_all; all three reuse this one forward pass.

Advantage normalisation over the VALID positions only

We compute the mean and std of advantages on response tokens only (where action_mask=1). Prompt tokens are zero by collator construction, so including them would deflate the std and inflate the normalised advantages. The .clamp(min=1e-8) on the std is the standard 'never divide by zero' safeguard. Many subtle PPO bugs come from normalising over the wrong set of positions.

EXAMPLE
valid count = 8192 out of 16384; a_mean ≈ 0.0, a_std ≈ 1.3 → normalised adv has mean 0, std 1
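The following NumPy sketch shows why the mask matters for normalisation. The toy advantages, the three-prompt-token layout, and the zero-filled prompt positions are assumptions for illustration. NumPy's std is the population std (ddof=0), while torch.std defaults to the unbiased estimator. That shifts the numbers slightly but not the conclusion.

```python
import numpy as np

# Toy minibatch: 2 sequences x 6 positions. The first 3 positions of each row
# are prompt tokens whose advantages the collator leaves at 0 (assumption).
advantages = np.array([
    [0.0, 0.0, 0.0,  1.5, -0.5,  2.0],
    [0.0, 0.0, 0.0, -1.0,  0.5, -2.5],
])
action_mask = np.array([
    [0, 0, 0, 1, 1, 1],
    [0, 0, 0, 1, 1, 1],
], dtype=bool)

# Correct: statistics over response tokens only.
valid = advantages[action_mask]
adv_valid = (advantages - valid.mean()) / max(valid.std(), 1e-8)

# Buggy: statistics over every position, prompt zeros included.
adv_all = (advantages - advantages.mean()) / max(advantages.std(), 1e-8)

print("std over response tokens:", valid.std().round(3))
print("std over all positions:  ", advantages.std().round(3))
print("normalised response adv, correct:", adv_valid[action_mask].round(2))
print("normalised response adv, buggy:  ", adv_all[action_mask].round(2))
```

Here the advantages happen to have mean 0, so the buggy version is exactly the correct one scaled by √2. Half the positions are prompt zeros, which halves the variance.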
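Log-space ratio first, then exponentiate

log_ratio = log pi_new − log pi_old is a difference of two moderate numbers, so it is numerically safe; exponentiating it once gives the ratio. Never compute ratio = pi_new / pi_old from raw probabilities. For rare tokens, and for any sequence-level product of probabilities over a long response, both numerator and denominator can underflow to zero. Even without underflow, dividing two rounded probabilities (especially in bf16) loses precision exactly where the ratio matters most: near 1.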

EXAMPLE
Typical first step: log_ratio ≈ 0 ± 0.05  →  ratio ≈ 1.0 ± 0.05
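The following NumPy sketch shows the underflow case. It uses a sequence-level ratio, because that is where the problem bites hardest. The per-token log-probs are random made-up values and the 0.001 per-token shift is arbitrary. float32 is used because its exp() underflows to 0 below roughly −104.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up per-token log-probs for a 300-token response under the old policy,
# and a new policy that raised each token's log-prob by 0.001.
old_tok = rng.uniform(-1.0, -0.2, size=300).astype(np.float32)
new_tok = old_tok + np.float32(0.001)

old_seq = old_tok.sum()   # roughly -180: exp() of this underflows to 0 in float32
new_seq = new_tok.sum()

# Direct route: exponentiate each total, then divide -> 0 / 0 = nan.
with np.errstate(invalid="ignore"):
    ratio_direct = np.exp(new_seq) / np.exp(old_seq)

# Log-space route: subtract first, exponentiate once -> about exp(0.3).
ratio_log = np.exp(new_seq - old_seq)

print(old_seq, ratio_direct, ratio_log)
```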
The clipped surrogate: torch.clamp + torch.min on every token

Element-wise, surr1 = ratio * adv is the unclipped term and surr2 = clamp(ratio, 1-ε, 1+ε) * adv is the clipped one. torch.min picks the LOWER (more pessimistic) of the two at each position. The negative is the loss, and masked_mean averages it over response tokens only. Averaging over every position instead would mix in prompt-position terms that mean nothing. If those terms were zeroed first, the loss would instead be diluted by the prompt fraction.

EXAMPLE
On the first minibatch of each rollout, pi_new equals pi_old. So ratio = 1 everywhere, and policy_loss = -masked_mean(adv) ≈ 0, because the advantages were just normalised to mean 0. Over later minibatches and epochs, the loss drifts slightly negative. That happens as the policy raises the probability of positive-advantage tokens and lowers it for negative-advantage ones.
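The asymmetry of the clip is easiest to see in the gradient. The following PyTorch sketch differentiates the per-token clipped objective with respect to the ratio for four hand-picked (ratio, advantage) cases. It then computes the policy loss on a first minibatch, where the new and old policies still coincide. The advantage values are made up.

```python
import torch

EPSILON = 0.2

# One token per case: (ratio, advantage). The ratio is a leaf tensor, so
# .grad holds d(objective)/d(ratio) after backward().
cases = {
    "A>0, ratio inside trust region": (1.05, +1.0),
    "A>0, ratio above 1+eps":         (1.50, +1.0),
    "A<0, ratio below 1-eps":         (0.50, -1.0),
    "A<0, ratio above 1+eps":         (1.50, -1.0),
}
grads = {}
for name, (r, a) in cases.items():
    ratio = torch.tensor(r, requires_grad=True)
    adv = torch.tensor(a)
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1 - EPSILON, 1 + EPSILON) * adv
    objective = torch.min(surr1, surr2)
    objective.backward()
    grads[name] = ratio.grad.item()
    print(f"{name:32s} d(objective)/d(ratio) = {grads[name]:+.1f}")

# First minibatch of a rollout: new policy == old policy, so ratio == 1.
adv = torch.tensor([0.8, -1.2, 1.5, -1.1])
adv = (adv - adv.mean()) / adv.std()
ratio = torch.ones_like(adv)
surr = torch.min(ratio * adv, torch.clamp(ratio, 1 - EPSILON, 1 + EPSILON) * adv)
first_loss = -surr.mean()
print(f"policy_loss on the first minibatch: {first_loss.item():.6f}")
```

The zero gradients are the clip at work. A token whose ratio has already moved past the trust-region edge in the profitable direction contributes nothing further. A move in the unprofitable direction is still penalised in full.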
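Clipped value loss: mirrors the policy clip for stability

The value head predicts V(s). Its target is the return computed during the rollout, R_t = A_t + V_old(s_t), and the natural loss is MSE: (V(s) - R)^2. PPO additionally clips the VALUE update. v_clipped is the new value constrained to within ±ε of the old value, and we take the MAX of the two squared errors. The max makes the clip pessimistic. Suppose the new value has moved more than ε from V_old in the direction that reduces the error. Then the clipped term dominates, and that extra movement earns no further gradient. Whenever the unclipped error is the larger one, it is charged in full. Without this, the value head can shoot ahead of the policy and corrupt the next batch's advantages. The code reuses EPSILON for the value-clip range, even though values are in reward units; some implementations expose a separate value-clip setting.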

EXAMPLE
value_loss is typically 0.1 to 1.0 at the start of training and decays to 0.01 by the end
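The following NumPy sketch applies the per-token clipped value loss to three hand-picked tokens; all numbers are illustrative. It uses the same formula as the PyTorch step, with the value-clip range set to 0.2.

```python
import numpy as np

def clipped_value_loss(values, old_values, returns, eps=0.2):
    """Per-token PPO value loss 0.5 * max(unclipped SE, clipped SE),
    plus a flag saying where the clipped term was the larger one."""
    v_clipped = old_values + np.clip(values - old_values, -eps, eps)
    se_unclipped = (values - returns) ** 2
    se_clipped = (v_clipped - returns) ** 2
    return 0.5 * np.maximum(se_unclipped, se_clipped), se_clipped > se_unclipped

old_v = np.array([1.0, 1.0, 1.0])   # V_old from the rollout
ret   = np.array([2.0, 2.0, 0.0])   # returns R_t
new_v = np.array([1.1, 1.9, 1.9])   # value head after some updates

loss, clipped_term_wins = clipped_value_loss(new_v, old_v, ret)
print(loss.round(3), clipped_term_wins)
```

In the second token, the new value has moved 0.9 toward its return. The loss scores it as if it had moved only 0.2, so the overshoot gets no gradient. In the third token, the move goes away from the return, and the unclipped error is charged in full.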
Entropy bonus: keep the policy from collapsing

Entropy = -sum(p log p) over the vocab at each token. Higher entropy means a more uniform distribution, which means more exploration. We SUBTRACT it from the loss so the optimiser increases it. ENT_COEF is small (1e-2), so this is a gentle nudge. A too-large bonus prevents the policy from ever committing to a single best response, and reward stagnates. For LLM RLHF the entropy bonus is often set to 0, for two reasons. First, the KL penalty to the reference model already keeps the distribution from collapsing. Second, over a vocabulary of 100k+ tokens, an entropy bonus mostly rewards spreading probability onto implausible tokens. In game RL, where the action space is small, the bonus matters more.

EXAMPLE
entropy = average over (B, S-1) of -sum_v p_v log p_v; typical LLM entropy ≈ 2.0–6.0 nats per token
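k3 KL estimator: the RLHF-specific term

The reference model is the FROZEN SFT model. At every response token, we estimate the KL divergence from the live policy to it, D_KL(pi_new || pi_ref).

With log_r = log pi_ref − log pi_new, the 'k3' estimator (Schulman 2020) is kl ≈ exp(log_r) − log_r − 1. For tokens sampled from pi_new it has three useful properties:
- It is unbiased.
- It is always non-negative.
- It typically has lower variance than the naive −log_r estimator.

The sign of log_r matters. Computing it as log pi_new − log pi_ref gives a quantity that is still non-negative but estimates something else. Tokens here were sampled from pi_old, so the estimate is unbiased only while pi_new still equals pi_old, and approximately so afterwards.

This step adds the KL as a loss term. Many PPO-RLHF implementations instead subtract β times the per-token log-ratio from each token's reward before computing GAE. Either way, this penalty is the second safety brake of RLHF (the first is the PPO clip). Without it, the policy drifts into rewarded-but-non-human text. That is the textbook reward-hacking failure mode, where the model emits a string of '!!!' because the reward model happens to like punctuation.

EXAMPLE
kl_penalty is a per-token mean, so its logged values are small. For the KL summed over a whole response, 5–20 nats at convergence is typical. A summed KL above 50 means runaway drift: restart with a bigger KL_COEF.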
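The following NumPy sketch runs a Monte Carlo check of the k3 estimator on a made-up five-token vocabulary. Samples are drawn from the live policy, and log_r is taken as log p_ref − log p_new. The result is compared with the exact KL and with the same formula applied to the flipped log-ratio.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two made-up next-token distributions over a 5-token vocabulary.
p_new = np.array([0.50, 0.20, 0.15, 0.10, 0.05])   # live policy: samples come from here
p_ref = np.array([0.30, 0.30, 0.20, 0.15, 0.05])   # frozen reference

exact_kl = np.sum(p_new * np.log(p_new / p_ref))  # D_KL(p_new || p_ref)

tokens = rng.choice(len(p_new), size=200_000, p=p_new)
log_new = np.log(p_new[tokens])
log_ref = np.log(p_ref[tokens])

# k3 with log_r = log p_ref - log p_new: unbiased for D_KL(p_new || p_ref)
# when tokens are drawn from p_new.
log_r = log_ref - log_new
k3 = np.exp(log_r) - 1.0 - log_r

# Same formula with the log-ratio flipped: still >= 0, but its mean is
# E_new[p_new / p_ref] - 1 - KL, not the KL.
log_r_flipped = log_new - log_ref
k3_flipped = np.exp(log_r_flipped) - 1.0 - log_r_flipped

print(f"exact KL          = {exact_kl:.4f}")
print(f"k3, correct sign  = {k3.mean():.4f}")
print(f"k3, flipped sign  = {k3_flipped.mean():.4f}")
```

The flipped version stays non-negative, which is why the mistake is easy to miss in logs.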
The total loss: four signed terms

The loss is the sum of four signed terms:
- + policy_loss: already the negative of the objective.
- + VF_COEF * value_loss: positive, because we minimise MSE.
- − ENT_COEF * entropy_bonus: negative, because we MAXIMISE entropy.
- + KL_COEF * kl_penalty: positive, because we MINIMISE drift.

The signs are easy to get wrong. An end-to-end debug version always prints each component separately, so a sign error shows up immediately as one of the four going the wrong way.

Single combined backward + clip + step

loss.backward() populates .grad on every parameter of BOTH the policy and the value head. clip_grad_norm_ rescales them together so the joint L2 norm is at most MAX_GRAD=1.0. Why together? In some architectures the value head is bolted onto the policy's last hidden state, so the value gradient and the policy gradient share the same backbone. Clipping them separately would change their relative scales and the joint loss balance. In that shared-backbone case, pass each parameter to clip_grad_norm_ only once; a duplicated parameter is counted twice in the norm and has its gradient rescaled twice. The optimizer must likewise hold both parameter sets.

Diagnostics: the two numbers that tell you whether PPO is healthy

- approx_kl is the per-step empirical KL between old and new policy, i.e. the policy DRIFT on this minibatch. It is estimated as the mean of log pi_old − log pi_new over the sampled tokens.
- clipfrac is the fraction of tokens where |ratio − 1| > ε, i.e. where the ratio has left the trust region.

Both use the log-probs from this step's forward pass, before optimizer.step(). As a rough guide, a healthy PPO run holds approx_kl in [0.01, 0.05] and clipfrac in [0.10, 0.30]. If approx_kl > 0.1 for several steps, lower the LR or take fewer PPO epochs per rollout. If clipfrac > 0.5, the policy is changing too fast, and the fix is the same. These two numbers, more than the loss, are what an RLHF engineer watches.

EXAMPLE
approx_kl=0.018, clipfrac=0.21 → healthy step
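The following NumPy sketch computes the two diagnostics on a toy minibatch and encodes the rough thresholds above in a hypothetical advice helper. The helper is not part of any library and checks a single step only, for brevity. The sketch also computes the non-negative k3 form of approx_kl, mean((ratio − 1) − log_ratio), which some PPO implementations log instead. The plain mean of log pi_old − log pi_new is unbiased but can come out negative on a small batch.

```python
import numpy as np

EPSILON = 0.2

def ppo_diagnostics(old_logp, new_logp, mask, eps=EPSILON):
    """approx_kl (k1 and k3 forms) and clipfrac over the masked tokens."""
    m = mask.astype(np.float64)
    n = max(m.sum(), 1.0)
    log_ratio = new_logp - old_logp
    ratio = np.exp(log_ratio)
    kl_k1 = (-log_ratio * m).sum() / n                    # mean of log pi_old - log pi_new
    kl_k3 = (((ratio - 1.0) - log_ratio) * m).sum() / n   # always >= 0
    clipfrac = ((np.abs(ratio - 1.0) > eps) * m).sum() / n
    return kl_k1, kl_k3, clipfrac

def advice(approx_kl, clipfrac):
    """Hypothetical helper encoding the rough single-step thresholds."""
    if approx_kl > 0.1 or clipfrac > 0.5:
        return "too fast: lower the LR or run fewer PPO epochs per rollout"
    if clipfrac < 0.05:
        return "too slow: raise the LR carefully or widen the prompt set"
    return "healthy"

old_logp = np.array([-1.0, -2.0, -0.5, -3.0])
new_logp = old_logp + np.array([0.05, -0.10, 0.30, 0.0])
mask = np.array([1, 1, 1, 0])                              # last token is padding
kl_k1, kl_k3, clipfrac = ppo_diagnostics(old_logp, new_logp, mask)
print(kl_k1, kl_k3, clipfrac, advice(kl_k3, clipfrac))
```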
masked_mean: the helper that appears six times in the body

Sum the elements of x where mask=1, then divide by the count of mask=1 positions. The .clamp(min=1.0) on the denominator prevents division by zero on a degenerate minibatch with no response tokens. This single helper enforces the rule 'every average is taken over the action positions only': change it once and every loss term inherits the fix.

1"""
2PyTorch PPO step for RLHF — the inner loop of every modern RLHF trainer.
3
4A 'rollout' is a list of prompts, each rolled out to a full response with
5the OLD policy (frozen snapshot). For each token we have:
6   - old_log_prob   : log pi_old(a_t | s_t)            (no gradient)
7   - ref_log_prob   : log pi_ref(a_t | s_t)            (no gradient)
8   - advantage      : A_t from GAE                     (no gradient)
9   - return         : R_t = A_t + V_old(s_t)           (value target)
10
11We then do K epochs of mini-batched updates on this same rollout, each
12forwarding the rollout tokens through the LIVE policy and value head.
13"""
14
15import torch
16import torch.nn.functional as F
17
18EPSILON  = 0.2          # PPO clip
19VF_COEF  = 0.5          # weight on value loss
20ENT_COEF = 0.01         # weight on entropy bonus
21KL_COEF  = 0.05         # weight on per-token KL to reference policy
22MAX_GRAD = 1.0          # global grad-norm clip
23
24def ppo_step(batch, policy, value_head, ref_policy, optimizer):
25    """
26    batch is one minibatch from the rollout, already on device:
27      input_ids    : (B, S)        prompt + response tokens
28      attn_mask    : (B, S)
29      action_mask  : (B, S)        1 on RESPONSE tokens, 0 on prompt/pad
30      old_logp     : (B, S)        log pi_old at each token
31      ref_logp     : (B, S)        log pi_ref at each token (frozen SFT model)
32      advantages   : (B, S)        A_t from GAE
33      returns      : (B, S)        value targets
34      old_values   : (B, S)        V_old(s_t) for clipped value loss
35    """
36    input_ids   = batch["input_ids"]
37    attn_mask   = batch["attn_mask"]
38    action_mask = batch["action_mask"]
39    old_logp    = batch["old_logp"]
40    ref_logp    = batch["ref_logp"]
41    advantages  = batch["advantages"]
42    returns     = batch["returns"]
43    old_values  = batch["old_values"]
44
45    # --- Forward the LIVE policy -----------------------------------------
46    out = policy(input_ids=input_ids, attention_mask=attn_mask, use_cache=False)
47    logits = out.logits[:, :-1, :]              # predict t+1 from <= t
48    targets = input_ids[:, 1:]                  # the action that was taken
49    am      = action_mask[:, 1:].float()
50
51    log_probs_all = F.log_softmax(logits, dim=-1)
52    new_logp = log_probs_all.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
53
54    # --- Advantage normalisation -----------------------------------------
55    masked_adv = advantages[:, 1:]
56    valid = am.bool()
57    a_mean = masked_adv[valid].mean()
58    a_std  = masked_adv[valid].std().clamp(min=1e-8)
59    adv = (masked_adv - a_mean) / a_std
60
61    # --- The PPO clipped surrogate ---------------------------------------
62    log_ratio = new_logp - old_logp[:, 1:]
63    ratio     = log_ratio.exp()
64
65    surr1 = ratio * adv
66    surr2 = torch.clamp(ratio, 1 - EPSILON, 1 + EPSILON) * adv
67    policy_loss = -masked_mean(torch.min(surr1, surr2), am)
68
69    # --- Value loss (Bellman target, clipped) ----------------------------
70    values = value_head(input_ids, attn_mask)[:, 1:]
71    v_clipped = old_values[:, 1:] + (values - old_values[:, 1:]).clamp(
72        -EPSILON, +EPSILON
73    )
74    vf_loss_unc = (values    - returns[:, 1:]).pow(2)
75    vf_loss_clp = (v_clipped - returns[:, 1:]).pow(2)
76    value_loss = 0.5 * masked_mean(torch.max(vf_loss_unc, vf_loss_clp), am)
77
78    # --- Entropy bonus (encourages exploration) --------------------------
79    probs = log_probs_all.exp()
80    entropy = -(probs * log_probs_all).sum(-1)
81    entropy_bonus = masked_mean(entropy, am)
82
83    # --- Per-token KL to reference policy --------------------------------
84    # k3 estimator: ratio_ref - log(ratio_ref) - 1, unbiased and positive.
85    log_r = new_logp - ref_logp[:, 1:]
86    kl_per_token = (log_r.exp() - 1) - log_r        # always >= 0
87    kl_penalty = masked_mean(kl_per_token, am)
88
89    # --- Total loss ------------------------------------------------------
90    loss = (
91        policy_loss
92        + VF_COEF  * value_loss
93        - ENT_COEF * entropy_bonus
94        + KL_COEF  * kl_penalty
95    )
96
97    # --- One gradient step -----------------------------------------------
98    loss.backward()
99    gn = torch.nn.utils.clip_grad_norm_(
100        list(policy.parameters()) + list(value_head.parameters()),
101        MAX_GRAD,
102    )
103    optimizer.step()
104    optimizer.zero_grad(set_to_none=True)
105
106    # --- Diagnostics (log every step) ------------------------------------
107    with torch.no_grad():
108        approx_kl   = masked_mean(old_logp[:, 1:] - new_logp, am)
109        clipfrac    = masked_mean(((ratio - 1.0).abs() > EPSILON).float(), am)
110
111    return {
112        "loss":       loss.detach(),
113        "pg_loss":    policy_loss.detach(),
114        "vf_loss":    value_loss.detach(),
115        "entropy":    entropy_bonus.detach(),
116        "kl_ref":     kl_penalty.detach(),
117        "approx_kl":  approx_kl.detach(),
118        "clipfrac":   clipfrac.detach(),
119        "grad_norm":  gn.detach(),
120    }
121
122
123def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
124    """Average x only where mask == 1. mask is float (1.0/0.0)."""
125    return (x * mask).sum() / mask.sum().clamp(min=1.0)

At Massive Scale: What Changes for a 70B+ Run

The code above carries over to a 70B-parameter RLHF run with its structure unchanged, but four things change in spirit.

Memory: the four-model problem becomes the dominant cost

At 70B, each model is ~140 GB in bf16. Four of them (policy, value, reference, reward) come to 560 GB just in weights. Training the policy and the value model adds 280 GB of bf16 gradients. It also adds ~1.1 TB of AdamW state (fp32 m and v, 8 bytes per parameter per trained model). The total is ~2 TB before a single activation. Pretraining the same 70B model needs about 840 GB for weights, gradients and optimiser state, because it has one model instead of four and one trained model instead of two. So on a per-step basis, RLHF is more than twice as memory-hungry as pretraining. This is why production RLHF systems use one or more of these tricks:
- share the policy and value-head backbones;
- train the policy as a LoRA adapter, so the frozen base weights can double as the reference model;
- offload the reference and reward model to CPU/disk between rollouts.
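The following sketch reproduces the arithmetic behind those totals. It assumes bf16 weights and gradients (2 bytes per parameter), fp32 AdamW moments (8 bytes per parameter), and two trained models (policy and value). That is what the 280 GB gradient and ~1.1 TB optimiser figures add up to. Activations, fp32 master weights, sharding and offloading are ignored.

```python
GB = 1e9

def memory_gb(n_params, n_models, n_trained):
    """Weights + gradients + AdamW moments in GB, under the stated byte sizes."""
    weights = n_models * n_params * 2 / GB      # bf16 weights: 2 bytes/param
    grads = n_trained * n_params * 2 / GB       # bf16 gradients: 2 bytes/param
    adam = n_trained * n_params * 8 / GB        # fp32 m and v: 4 + 4 bytes/param
    return {"weights": weights, "grads": grads, "adam": adam,
            "total": weights + grads + adam}

rlhf = memory_gb(70e9, n_models=4, n_trained=2)      # policy, value, reference, reward
pretrain = memory_gb(70e9, n_models=1, n_trained=1)  # one model, trained
print(rlhf)
print(pretrain)
print(f"RLHF / pretraining: {rlhf['total'] / pretrain['total']:.2f}x")
```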

The sampling-vs-training imbalance

A 70B policy already needs two or more 80 GB H100s just for its bf16 weights. With vanilla Hugging Face generation it emits maybe 30 tokens/second; with vLLM's batched generation that rises to 200–500 tokens/second per GPU. A typical PPO step trains on ~256 prompts × ~512 response tokens = 131k tokens. At 500 tok/s/GPU, even an 8-GPU rollout takes ~33 seconds, while the optimisation phase that follows takes ~5 seconds. So roughly 87% of wall-clock time goes to rolling out. This is why production RLHF infrastructure separates two clusters:
- the rollout cluster: vLLM, many GPUs, prioritised for inference throughput;
- the training cluster: FSDP, fewer GPUs, prioritised for memory bandwidth.

The rollout cluster ships generations to the training cluster over the network, and weights are synced after each optimisation phase.
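The following function reproduces the same back-of-envelope throughput estimate. Its defaults are the paragraph's assumed figures (256 prompts, 512 response tokens, 500 tokens/s per GPU, 8 GPUs, 5 s of optimisation), not measurements. The last lines rerun it at the 30 tokens/s quoted for vanilla generation.

```python
def rollout_share(prompts=256, resp_tokens=512, tok_per_s_per_gpu=500.0,
                  gpus=8, train_s=5.0):
    """Tokens per PPO step, rollout seconds, and the rollout's share of wall-clock."""
    tokens = prompts * resp_tokens
    rollout_s = tokens / (tok_per_s_per_gpu * gpus)
    return tokens, rollout_s, rollout_s / (rollout_s + train_s)

tokens, rollout_s, share = rollout_share()
print(f"{tokens} tokens, rollout {rollout_s:.1f} s, {share:.0%} of wall-clock")

_, slow_s, slow_share = rollout_share(tok_per_s_per_gpu=30.0)   # vanilla generation
print(f"at 30 tok/s/GPU: rollout {slow_s:.0f} s, {slow_share:.0%} of wall-clock")
```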

Distributed advantage normalisation

Advantage normalisation is now an all-reduce. Each FSDP rank computes the count, sum and sum of squares of the advantages on its shard's response tokens. One all-reduce then produces the global mean and std, and every rank normalises locally with those globals. Forgetting the all-reduce and normalising per rank is a textbook bug. Each rank sees a different scale, so the relative weight of the policy loss against the KL and value terms varies across ranks. The run becomes noisy in a way that can take weeks to diagnose. This one all-reduce deserves its own unit test.
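The following sketch is a single-process NumPy simulation of that all-reduce. It assumes each rank holds a (B, S) advantage shard and a boolean response mask, and the ranks are given deliberately different advantage scales. The sum over ranks stands in for one torch.distributed.all_reduce with ReduceOp.SUM on a three-element tensor per rank.

```python
import numpy as np

def local_stats(adv, mask):
    """What one rank contributes: count, sum, sum of squares over response tokens."""
    v = adv[mask]
    return np.array([v.size, v.sum(), np.square(v).sum()], dtype=np.float64)

def normalise_across_ranks(shards, eps=1e-8):
    """Normalise every shard with the GLOBAL mean/std of all response tokens."""
    # Stand-in for the all-reduce: element-wise sum of each rank's 3-vector.
    n, s, ss = np.sum([local_stats(adv, mask) for adv, mask in shards], axis=0)
    mean = s / n
    std = np.sqrt(max(ss / n - mean ** 2, 0.0))    # population std, one pass
    return [(adv - mean) / max(std, eps) for adv, _ in shards]

rng = np.random.default_rng(0)
shards = []
for scale in (0.5, 1.0, 2.0, 4.0):                 # four "ranks", different scales
    adv = rng.normal(0.0, scale, size=(2, 64))
    mask = rng.random((2, 64)) < 0.5               # roughly half response tokens
    shards.append((adv, mask))

normed = normalise_across_ranks(shards)

# Reference: pool every rank's response tokens in one place and normalise.
raw = np.concatenate([adv[mask] for adv, mask in shards])
reference = (raw - raw.mean()) / raw.std()
distributed = np.concatenate([a[mask] for a, (_, mask) in zip(normed, shards)])
print(np.allclose(distributed, reference))

# The bug: each rank normalises with its own statistics.
per_rank = np.concatenate([((a - a[m].mean()) / a[m].std())[m] for a, m in shards])
print(np.allclose(per_rank, reference))            # per-rank scaling: no match
```

Shipping only the three sufficient statistics keeps the collective tiny. The one-pass variance formula can lose precision when the mean is large relative to the std. If that is a concern, a two-pass or Welford-style merge is the safer choice.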

Reward-model bottleneck

For a 70B policy with a 7B reward model, scoring is cheap. For a 405B policy with a 70B reward model, scoring is comparable in cost to a single optimisation epoch. Modern recipes batch-score responses asynchronously: the reward model runs on its own GPU pool, with a queue between it and the rollout cluster. Some recent work replaces the learned reward model with a verifier (rule-based, see § 14.4) or a generative judge (see § 14.5). DeepSeek-R1-Zero, for example, uses rule-based accuracy and format rewards instead of a neural reward model. Its authors cite reward hacking and the extra resources needed to retrain a reward model.
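The following toy, single-process sketch shows the queue between rollout and scoring, using Python threads. fake_reward is a hypothetical stand-in for a reward-model forward pass, and the batch size is arbitrary. A real system would put the scorer on its own GPU pool behind an RPC or message queue, not in a thread.

```python
import queue
import threading

def fake_reward(prompt, response):
    """Hypothetical stand-in for a reward-model forward pass: counts words."""
    return float(len(response.split()))

def reward_worker(inbox, outbox, batch_size=4):
    """Scorer loop: collect up to batch_size items, score them as one batch, repeat.

    Items are (idx, prompt, response); None is the shutdown sentinel.
    """
    while True:
        batch = [inbox.get()]                      # block for the first item
        while len(batch) < batch_size:
            try:
                batch.append(inbox.get_nowait())   # then take whatever is queued
            except queue.Empty:
                break
        stop = any(item is None for item in batch)
        for item in batch:
            if item is not None:
                idx, prompt, response = item
                outbox.put((idx, fake_reward(prompt, response)))
        if stop:
            return

inbox, outbox = queue.Queue(), queue.Queue()
scorer = threading.Thread(target=reward_worker, args=(inbox, outbox))
scorer.start()

# Rollout side: hand each finished response to the scorer and keep generating.
for i in range(10):
    inbox.put((i, f"prompt {i}", "word " * (i + 1)))
inbox.put(None)                                    # no more rollouts
scorer.join()

rewards = {}
while not outbox.empty():
    idx, reward = outbox.get()
    rewards[idx] = reward
print(rewards)
```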

Engineering Reality: The Knobs That Break Runs

After enough PPO runs every team converges on the same set of failure modes and the same set of high-leverage fixes.

  • The KL coefficient is too small. Symptom: reward climbs for 200 steps, then the policy collapses to a degenerate output (a fixed phrase, all-caps, emoji spam) that the reward model happens to score high. Cause: the per-token KL penalty isn't enough to keep the policy near the SFT manifold. Fix: raise β_KL by 2×, restart from a checkpoint before the collapse. Many teams use adaptive KL — increase β when realised KL exceeds a target, decrease when it's below.
  • Clipfrac is too high (> 0.5). Symptom: training loss looks noisy and reward grows slowly. Cause: each gradient step moves the policy too far in ratio space — most tokens hit the clip, gradient becomes erratic. Fix: lower LR by 2×, or reduce PPO epochs per rollout from 4 to 2. The clipfrac and approx_kl diagnostics are the right thing to watch, not the loss itself.
  • Clipfrac is too low (< 0.05). Symptom: reward barely moves. Cause: the policy isn't changing enough — the LR is too low or the dataset is too narrow. Fix: raise LR carefully (2×) or widen the prompt distribution.
  • Forgot to normalise advantages. Symptom: training loss has huge variance across batches and learning stalls. Cause: per-batch reward scale varies wildly. Fix: the one-line normalisation in the code above. This is the single most common ‘PPO doesn't train’ bug.
  • Action mask wrong. Symptom: reward and KL both look normal but downstream evals are unchanged from the SFT model. Cause: the mask covers prompt tokens instead of response tokens, or has an off-by-one against labels. Fix: print one example's mask alongside its input_ids before launch and eyeball that mask=1 lines up with response tokens.
  • Value head exploded. Symptom: value_loss climbing instead of falling, then NaN. Cause: value head LR not decoupled from policy LR, or value clip not applied. Fix: the clipped value loss in the code above; some teams use a smaller LR (typically 5×) for the value head than for the policy.
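  • Reference model is secretly the policy. Symptom: the logged KL to the reference stays near zero however far the policy moves across rollouts, so the KL brake does nothing. Cause: the 'reference' is actually the live policy. This happens when both names point at the same module, or when the reference was copied without a deep copy and still shares parameter storage. Fix: load the reference from a separate checkpoint, or deep-copy the policy before the first update, and freeze it. At startup, assert that ref_policy is not policy and that the two share no parameter tensors, then recompute ref_logp from that reference.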
  • Rollout samples with temperature T ≠ 1.0, but training computes log-probs from softmax(logits) without it. Symptom: the ratio distribution is centred far from 1.0 even on step 0. Cause: the rollout sampled from softmax(logits / T), but the trainer computes new_logp from softmax(logits). They are different distributions, so the ratio is meaningless. Fix: apply the SAME temperature inside the trainer when computing log-probs of the action.
  • Gradient norm explodes after a few hundred steps. Symptom: |g| climbs from 0.5 to 50 in 100 steps. Cause: either KL coefficient too small (policy escaping to extreme outputs) or entropy too low (policy collapsing to a near-delta distribution and getting huge per-token gradient magnitudes). Fix: raise β_KL or raise the entropy bonus.
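The following sketch implements the adaptive KL coefficient mentioned in the first item above. It is written as a proportional controller in the style of Ziegler et al. (2019). The target KL (per response), horizon and initial β are illustrative assumptions, and the class is hypothetical rather than any library's API.

```python
import numpy as np

class AdaptiveKLController:
    """Hypothetical proportional controller for the KL coefficient beta."""

    def __init__(self, init_beta=0.05, target_kl=6.0, horizon=10_000):
        self.beta = init_beta          # current KL coefficient
        self.target_kl = target_kl     # desired KL per response, in nats (illustrative)
        self.horizon = horizon         # larger horizon -> slower adaptation

    def update(self, measured_kl, n_steps):
        # Relative error, clipped to +/-20% so one noisy batch cannot swing beta wildly.
        error = np.clip(measured_kl / self.target_kl - 1.0, -0.2, 0.2)
        self.beta *= 1.0 + error * n_steps / self.horizon
        return self.beta

high = AdaptiveKLController()
beta_up = high.update(measured_kl=12.0, n_steps=256)    # KL above target: beta rises
low = AdaptiveKLController()
beta_down = low.update(measured_kl=3.0, n_steps=256)    # KL below target: beta falls
print(beta_up, beta_down)
```

The clip on the relative error and the horizon together bound how fast β can move. A burst of high-KL batches raises the penalty gradually instead of in one jump.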
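The following PyTorch sketch covers a few of the startup checks from the list above:
- building and eyeballing the action mask;
- computing action log-probs with the rollout's temperature;
- freezing a deep-copied reference and checking that it shares no storage with the policy;
- giving the value head its own learning rate.

All function names are hypothetical, and the toy tokens are made up. value_lr_scale=0.2 simply encodes the 5× rule of thumb from the value-head item; treat it as a knob.

```python
import copy

import torch
import torch.nn.functional as F

def build_action_mask(prompt_lens, seq_lens, S):
    """(B, S) float mask: 1.0 on response tokens [prompt_len, seq_len), else 0.0."""
    pos = torch.arange(S).unsqueeze(0)                 # (1, S)
    start = torch.tensor(prompt_lens).unsqueeze(1)     # (B, 1)
    end = torch.tensor(seq_lens).unsqueeze(1)          # (B, 1); padding starts here
    return ((pos >= start) & (pos < end)).float()

def show_alignment(tokens, mask_row):
    """Print one example's tokens beside its mask so a human can eyeball it."""
    for tok, m in zip(tokens, mask_row.tolist()):
        print(f"{int(m)}  {tok}")

def action_logp(logits, actions, temperature):
    """Log-prob of the sampled actions under softmax(logits / temperature)."""
    logp = F.log_softmax(logits / temperature, dim=-1)
    return logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

def make_reference(policy):
    """Frozen deep copy of the policy, so no parameter storage is shared."""
    ref = copy.deepcopy(policy).eval()
    ref.requires_grad_(False)
    return ref

def shares_storage(a, b):
    """True if any parameter of module b lives in the same memory as one of a's."""
    ptrs = {p.data_ptr() for p in a.parameters()}
    return any(p.data_ptr() in ptrs for p in b.parameters())

def make_optimizer(policy, value_head, lr=1e-6, value_lr_scale=0.2):
    """AdamW with separate param groups so the value head's LR is its own knob."""
    return torch.optim.AdamW([
        {"params": policy.parameters(), "lr": lr},
        {"params": value_head.parameters(), "lr": lr * value_lr_scale},
    ])

# --- Usage on toy inputs ------------------------------------------------
# Hypothetical tokenised example: 3 prompt tokens, 3 response tokens, 2 pads.
tokens = ["<user>", "Hi", "<assistant>", "Hello", "there", "<eos>", "<pad>", "<pad>"]
mask = build_action_mask([3], [6], S=len(tokens))
show_alignment(tokens, mask[0])

torch.manual_seed(0)
T = 0.7                                                   # rollout temperature
logits = torch.randn(5, 10)                               # 5 positions, vocab of 10
actions = torch.distributions.Categorical(logits=logits / T).sample()
old_logp = action_logp(logits, actions, T)                # what the rollout recorded
ratio_same_T = (action_logp(logits, actions, T) - old_logp).exp()     # all 1.0
ratio_no_T = (action_logp(logits, actions, 1.0) - old_logp).exp()     # drifts from 1.0

policy = torch.nn.Linear(4, 4)
value_head = torch.nn.Linear(4, 1)
ref_policy = make_reference(policy)
alias = policy                                            # the bug: no copy at all
opt = make_optimizer(policy, value_head)
print(shares_storage(policy, ref_policy), shares_storage(policy, alias))
```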
The mental model that unifies this section: PPO is REINFORCE with three modifications.
1. Subtract a value baseline, so positive and negative judgments are calibrated.
2. Clip the importance ratio, so a single optimisation step cannot move the policy outside a trust region in behaviour space.
3. Add a KL penalty against a frozen reference, so the cumulative drift over thousands of steps stays in check.

Almost every line of the PyTorch step above serves one of those three jobs. The rest (entropy bonus, gradient clipping, diagnostics) keeps the optimisation stable and observable. Get the masking right and the diagnostics right, and a 70B chat model can be RLHF'd in a few hundred PPO outer steps.