Skip to content

13. PPO 详解

Proximal Policy Optimization (PPO) 是 RLHF 在工业界落地的"标准件"。InstructGPT、ChatGPT、Claude、LLaMA-2、GPT-4 都用 PPO 完成 RL 阶段。本章从策略梯度定理开始一步步推导,覆盖 GAE、重要性采样、PPO-Clip、4 模型架构和工程稳定性技巧。这是整个对齐部分数学密度最高的一章——但只要走完一遍推导,后续所有 RL 类对齐算法(GRPO、RLOO、ReMax 等)都是它的变体。


13.1 策略梯度定理

13.1.1 RL 目标

强化学习的核心目标:找到使期望累积奖励最大的策略:

J(θ)=Eτπθ[R(τ)]=Eτπθ[t=0Tγtrt]

其中 τ=(s0,a0,r0,s1,a1,r1,,sT) 是一条由策略 πθ 与环境交互产生的轨迹,γ(0,1] 是折扣因子。

我们想要:

θ=argmaxθJ(θ)

最自然的想法是梯度上升

θk+1=θk+αθJ(θk)

但这里有一个根本困难:期望对 θ 的依赖——轨迹的分布 πθ 本身就含 θ

13.1.2 推导 ∇J

设轨迹概率:

P(τ|θ)=p(s0)t=0T1πθ(at|st)p(st+1|st,at)

其中 p(s0)p(st+1|st,at)θ 无关。所以:

θlogP(τ|θ)=t=0T1θlogπθ(at|st)

注意:环境动力学梯度被消掉了,这是 RL 的关键好处。

利用 log-derivative trick

θJ(θ)=θP(τ|θ)R(τ)dτ=θP(τ|θ)R(τ)dτ=P(τ|θ)θP(τ|θ)P(τ|θ)R(τ)dτ=P(τ|θ)θlogP(τ|θ)R(τ)dτ=Eτπθ[θlogP(τ|θ)R(τ)]

代入 logP

θJ(θ)=Eτπθ[t=0T1θlogπθ(at|st)R(τ)]

这就是 策略梯度定理 (Policy Gradient Theorem) 的雏形(Williams, 1992 的 REINFORCE 形式)。

13.1.3 时间因果性:reward-to-go

注意到 R(τ)=t=0Trt 包含了 t<t 的奖励,但 at 不可能影响过去的奖励。利用 E[logπθ(at|st)rt]=0(当 t<t,因为 rtat 之前已经发生),可以只保留未来奖励

θJ(θ)=Eτπθ[t=0T1θlogπθ(at|st)t=tTγttrtGt=reward-to-go]

定义状态-动作值函数 Qπθ(s,a)=E[Gt|st=s,at=a],则更紧凑的形式:

θJ(θ)=E[tθlogπθ(at|st)Qπθ(st,at)]

13.1.4 引入基线降低方差

REINFORCE 的方差非常高。可以减去任意只依赖 st基线 (baseline) b(st)

J=E[tlogπθ(at|st)(Qπθ(st,at)b(st))]

为什么减基线不改变期望

对任意 b(st)

Eatπθ(|st)[θlogπθ(at|st)b(st)]=b(st)E[θlogπθ(at|st)]

而:

Eaπθ[logπθ(a|s)]=πθ(a|s)πθ(a|s)πθ(a|s)da=πθ(a|s)da=(1)=0

所以减基线不改变期望,但减小方差。

最优基线

理论上最优基线是 b(s)=Ea[logπ(a|s)2Q(s,a)]/Ea[logπ(a|s)2],但实际中常用 b(s)=Vπθ(s)(状态值函数)作为近似最优。此时 QV=A优势函数 (advantage)

θJ(θ)=Eτπθ[tθlogπθ(at|st)Aπθ(st,at)]

这是策略梯度的实用形式。所有现代算法(A2C/A3C/TRPO/PPO/GRPO)都基于此式。


13.2 重要性采样:从 on-policy 到 off-policy

13.2.1 问题:数据浪费

REINFORCE 是严格 on-policy 的:用 πθ 采样的数据估计 J(θ)。一旦更新一步,旧数据就废弃。在 LLM 场景下:

  • 每次 rollout 要生成上千 token,开销极大;
  • 数据只用一次太浪费;
  • 我们希望"采一次样,更新很多次"。

解决思路:重要性采样 (Importance Sampling)

13.2.2 重要性比率

考虑要估计 Exp[f(x)],但只能从 q 采样。引入比率 w(x)=p(x)/q(x)

Exp[f(x)]=p(x)f(x)dx=q(x)p(x)q(x)f(x)dx=Exq[p(x)q(x)f(x)]

应用到策略梯度,把"目标策略 πθ"对应 p,"行为策略 πθold"对应 q

θJ(θ)=Eaπθold[πθ(a|s)πθold(a|s)θlogπθ(a|s)A^]

利用 logπθ=πθ/πθ

πθπθoldlogπθ=πθπθoldπθπθ=πθπθold=θ(πθπθold)

所以可以把目标写成:

JIS(θ)=Eaπθold[πθ(a|s)πθold(a|s)A^]

重要性比率

rt(θ)=πθ(at|st)πθold(at|st)

显然 rt(θold)=1

13.2.3 重要性采样的危险

理论上 IS 无偏,但实际上方差可能爆炸。当 πθ 远离 πθold 时:

  • 比率 rt 可能极大(如 100)或极小(如 0.001);
  • 极少数高比率样本主导梯度,方差飙升;
  • 等价于"训练不稳定"。

为此需要约束 πθ 不能离 πθold 太远——这就是 TRPO/PPO 的核心思想。


13.3 TRPO 与 PPO

13.3.1 TRPO:硬 KL 约束

Trust Region Policy Optimization (Schulman et al., 2015) 在 θold 周围设置一个"信任域":

maxθE[rt(θ)A^t]s.t.E[KL(πθold(|s)πθ(|s))]δ

求解需要二阶方法(计算 Fisher 信息矩阵 + 共轭梯度),实现复杂。

13.3.2 PPO-Penalty:软 KL 惩罚

把 KL 当作惩罚项加进目标:

JKL-PEN(θ)=E[rt(θ)A^t]βE[KL(πθoldπθ)]

并用 adaptive β(每 epoch 后根据实际 KL 调整)。简单但调参敏感。

13.3.3 PPO-Clip:裁剪比率(主流)

Schulman et al. (2017) 提出的 PPO-Clip 直接裁剪比率[1ε,1+ε]

JCLIP(θ)=Et[min(rt(θ)A^t,clip(rt(θ),1ε,1+ε)A^t)]

典型 ε=0.2

直觉:为什么这个 min 能限制更新

考察单个 (st,at),记 r=rt(θ)A=A^t

情况 1:A>0(这个动作好,应该提升其概率)

区域rAclip(r)Amin
r<1ε(1ε)ArA(不限制下降)
1εr1+εrArArA
r>1+ε(1+ε)A(1+ε)A限制上升

情况 2:A<0(这个动作差,应该降低其概率)

区域rAclip(r)Amin
r<1ε大(rA 大,因 A<0r 小)(1ε)A(1ε)A限制下降
1εr1+εrArArA
r>1+εrA(1+ε)ArA(不限制上升)

结论

  • A>0 时,限制 r 不能涨过 1+ε
  • A<0 时,限制 r 不能跌过 1ε
  • 反方向不限:好动作的 r 可以无上限地跌(梯度向上推),坏动作的 r 可以无上限地涨(梯度向下压);这是"悲观下界 (pessimistic bound)"——只在偏离扩大时才生效。

图示

A > 0 时 J^CLIP 关于 r 的图:
J

│         ___________  ← (1+ε)A 之后被截
│        /
│       /
│      /
│_____/_______________ r
     1-ε  1   1+ε

A < 0 时 J^CLIP 关于 r 的图:
J
│________________
│                 \
│                  \
│                   \  ← 1-ε 之前被截
│___________________\___ r
                   1-ε  1   1+ε

13.3.4 PPO-Clip 完整目标

实际 PPO 还包括值函数损失和(可选)熵奖励:

LPPO(θ,ψ)=JCLIP(θ)+cvEt[(Vψ(st)R^t)2]ceEt[H[πθ(|st)]]

其中:

  • R^t=A^t+Vψ(st) 是 GAE 派生的目标回报;
  • cv (典型 0.5 或 0.1):value loss 权重;
  • ce (典型 0.01):熵正则权重,鼓励探索;
  • H[π] 是策略熵。

13.4 Generalized Advantage Estimation (GAE)

13.4.1 优势估计的两个极端

回到 §13.1.4:我们需要估计 A^t。两种简单做法:

蒙特卡洛 (MC)

A^tMC=GtV(st)=l=0γlrt+lV(st)

无偏,但方差极高(依赖整条轨迹)。

1-step TD

A^t(1)=rt+γV(st+1)V(st)=δt

方差小,但 V 不准时偏差大。

更一般的 n-step TD:

A^t(n)=l=0n1γlrt+l+γnV(st+n)V(st)

n 越大越接近 MC(方差大),n 越小越接近 1-step(偏差大)。这是经典的偏差-方差权衡

13.4.2 GAE:指数加权平均

Schulman et al. (2016) 的 GAE 用指数加权融合所有 n-step:

A^tGAE(γ,λ)=l=0(γλ)lδt+l

其中 δt=rt+γV(st+1)V(st) 是 1-step TD 残差。

推导:为什么指数加权能融合所有 n-step

定义 A^t(n)=l=0n1γlδt+l(递归形式),则:

A^tGAE=(1λ)n=1λn1A^t(n)

代入 A^(n)=l=0n1γlδt+l

A^tGAE=(1λ)n=1λn1l=0n1γlδt+l

调换求和顺序(让 l 在外):

=(1λ)l=0γlδt+ln=l+1λn1=(1λ)l=0γlδt+lλl1λ=l=0(γλ)lδt+l

得证。

边界情况

  • λ=0A^t=δt(1-step TD,高偏低方);
  • λ=1A^t=l=0γlδt+l,可以化简为 lγlrt+lV(st)(蒙特卡洛减基线,无偏高方);
  • λ(0,1):在两者之间平滑插值。

13.4.3 LLM 中的 GAE

LLM 场景下回答有限长度(典型 ≤ 4K token),常取:

  • γ=1(不折现,关心整段回答的总质量);
  • λ=0.95(接近 1,偏 MC,因为 V 训练慢且稀疏 reward)。

实现:递推计算

GAE 可以反向递推 O(T) 计算:

A^t=δt+γλA^t+1
python
def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """
    rewards: [T] (per-token rewards)
    values:  [T+1] (V(s_t) for t=0..T,最后一项为 0 或 bootstrapping value)
    返回 advantages [T],returns [T]
    """
    T = len(rewards)
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t+1] - values[t]
        advantages[t] = last_gae = delta + gamma * lam * last_gae
    returns = advantages + values[:-1]
    return advantages, returns

13.5 LLM-PPO:4 模型架构

13.5.1 角色与显存

把上述 RL 框架搬到 LLM 上,需要同时持有 4 个大模型

模型符号角色是否更新备注
Actor / Policyπθ生成回答(rollout)+ 训练主梯度
Referenceπref计算 KL 惩罚❌(冻结 SFT)仅前向
Rewardrϕ给完整回答打分❌(冻结)仅前向
Critic / ValueVψ估计 V(st),给 GAE 用第二梯度

GPU 显存预算(以 7B actor 为例,FP16/BF16):

模型参数梯度优化器状态 (Adam)总显存
Actor (7B, BF16, AdamW)14 GB14 GB56 GB (FP32 m,v + master)~84 GB
Critic (7B)14 GB14 GB56 GB~84 GB
Reference (7B, frozen)14 GB--14 GB
Reward (7B, frozen)14 GB--14 GB
激活值 (rollout 时)---20-40 GB
合计~220-240 GB

需要 ZeRO-3 + offload 才能在多卡 H100 上跑。

13.5.2 训练流程

                                     ┌──────────────┐
                                     │  Prompt set  │
                                     └──────┬───────┘


                            ┌──────────────────────────────┐
                            │  Stage 1: Rollout (生成阶段)  │
                            │                                │
                            │   π_θ  ─── generate ───►  y    │
                            │                                │
                            │   π_ref ─── forward ───►  log p_ref(y|x)
                            │                                │
                            │   r_φ   ─── forward ───►  r(x,y)
                            │                                │
                            │   V_ψ   ─── forward ───►  V(s_t)
                            │                                │
                            │   compute per-token rewards    │
                            │   compute GAE → Â_t, R̂_t       │
                            └──────────────┬─────────────────┘


                            ┌──────────────────────────────┐
                            │  Stage 2: Optimize (更新阶段) │
                            │                                │
                            │   for ppo_epoch in K:          │
                            │     for minibatch B:           │
                            │       L = L_clip + cv·L_v - ce·H│
                            │       θ ← θ - α·∇L            │
                            │       ψ ← ψ - α·∇L            │
                            └──────────────┬─────────────────┘


                                  (loop back to next iter)

13.5.3 Token-level reward 写法

把 RM 的稀疏标量奖励 rϕ(x,y) 与 token-level KL 惩罚结合:

rt={βlogπθ(yt|st)πref(yt|st),t<|y|rϕ(x,y)βlogπθ(y|y||s|y|)πref(y|y||s|y|),t=|y|

每个 token 都有 KL 惩罚,只有最后一个 token 拿到 RM 奖励。

13.5.4 完整伪代码

python
def llm_ppo_iteration(prompts, π_θ, V_ψ, π_ref, r_φ, β, ε, γ, λ):
    """LLM-PPO 一次迭代"""

    # ---------- Stage 1: Rollout ----------
    π_θ_old = copy_params(π_θ)   # 保存 old policy 用于比率
    rollouts = []

    for x in prompts:
        # 用旧策略生成回答
        y, logp_old = π_θ_old.generate(x, return_logprob=True)

        # 三个冻结模型前向
        with torch.no_grad():
            logp_ref = π_ref.forward(x, y)        # [|y|]
            r_score = r_φ(x, y)                    # scalar
            v_t = V_ψ(x, y)                        # [|y|]

        # token 级 reward
        kl_per_token = β * (logp_old - logp_ref)  # 实现细节:用 logπ_θ vs logπ_ref
        rewards = -kl_per_token.clone()
        rewards[-1] += r_score                     # 最后 token 加 RM 分

        # GAE
        v_extended = torch.cat([v_t, torch.zeros(1)])
        Â, R̂ = compute_gae(rewards, v_extended, γ, λ)

        rollouts.append({
            "x": x, "y": y, "logp_old": logp_old,
            "Â": Â, "R̂": R̂
        })

    # advantage 归一化(whitening)
    all_Â = torch.cat([r["Â"] for r in rollouts])
    Â_mean, Â_std = all_Â.mean(), all_Â.std()
    for r in rollouts:
        r["Â"] = (r["Â"] - Â_mean) / (Â_std + 1e-8)

    # ---------- Stage 2: Optimize ----------
    for epoch in range(K):
        for batch in make_minibatches(rollouts, batch_size):
            logp_new = π_θ.forward(batch["x"], batch["y"])
            v_new = V_ψ(batch["x"], batch["y"])

            ratio = torch.exp(logp_new - batch["logp_old"])

            # PPO-Clip
            surr1 = ratio * batch["Â"]
            surr2 = torch.clamp(ratio, 1-ε, 1+ε) * batch["Â"]
            L_clip = -torch.min(surr1, surr2).mean()

            # Value loss (with optional clipping)
            v_clipped = batch["v_old"] + (v_new - batch["v_old"]).clamp(-ε, ε)
            L_v1 = (v_new - batch["R̂"]) ** 2
            L_v2 = (v_clipped - batch["R̂"]) ** 2
            L_v = 0.5 * torch.max(L_v1, L_v2).mean()

            # Entropy bonus
            entropy = -(logp_new * logp_new.exp()).sum(-1).mean()  # 简化

            L = L_clip + 0.1 * L_v - 0.01 * entropy

            optimizer.zero_grad()
            L.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimizer.step()

13.6 训练稳定性技巧

LLM-PPO 是出了名的"调参炼金"——这里整理工业实现中最有效的稳定性技巧。

13.6.1 Reward 处理

Reward whitening(批内归一化):

r=rmean(r)std(r)+ϵ

降低 RM 分布漂移带来的影响。

Reward clipping:限制极端值

r=clip(r,c,+c),c5

防止 outlier 拉爆梯度。InstructGPT 与 TRL 默认开启。

13.6.2 Advantage 处理

Advantage whitening(强烈推荐):

A^=A^mean(A^)std(A^)+ϵ

让 advantage 尺度与 clip 阈值 ε 解耦。

13.6.3 Value clipping

类似 policy clipping,对 value 也加裁剪:

Vclip=Vold+clip(VψVold,εv,+εv)Lvalue=12E[max((VψR^)2,(VclipR^)2)]

避免 critic 大幅震荡。OpenAI baselines 与 TRL 默认开启。

13.6.4 Reference model 周期更新

DeepSeek-V3 等大模型采用:每 N 步(典型 400)把 πrefπθ,等价于 trust region 周期重锚。好处:

  • πθ 已经远离原 SFT 时,KL(π_θ||π_ref_old) 变得过大、过紧;
  • 重锚后可以继续优化;
  • 类似 target network 在 DQN 中的作用。

13.6.5 Adaptive KL controller

InstructGPT 使用 PI 控制器锁定目标 KL:

python
class AdaptiveKLController:
    def __init__(self, init_β=0.2, target_kl=6.0, K_p=0.1):
        self= init_β
        self.target = target_kl
        self.K_p = K_p

    def update(self, current_kl, n_steps=1):
        proportional = (current_kl - self.target) / self.target
        proportional = max(-0.2, min(0.2, proportional))
        self*= 1 + self.K_p * proportional * n_steps

13.6.6 Mini-batch + 多 epoch

PPO 的核心优势是数据复用:

  • rollout batch:一次生成 256-1024 prompts × 各 N 个回答;
  • ppo_epochs:在同一批 rollout 上做 2-4 次更新;
  • mini_batch:每个 epoch 内分成更小的 batch(典型 32-64)多次更新。

经验:

  • ppo_epochs > 4 容易 overfit 当前 rollout;
  • mini_batch 太小则方差大;
  • 总 update steps = ppo_epochs × (rollout_size / mini_batch)。

13.6.7 Gradient clipping

python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

防止偶发大梯度毁掉模型。RLHF 必备。

13.6.8 Mixed precision 与 ZeRO

4 模型显存压力极大,需要:

  • BF16/FP16 训练;
  • ZeRO-3(参数、梯度、优化器状态分片);
  • Gradient checkpointing(用计算换显存);
  • CPU offload(把 reference / reward 模型放 CPU 或别的 GPU);
  • vLLM / SGLang for rollout:用专门推理引擎做生成阶段。

13.7 工程参考:TRL 的 PPOTrainer

Hugging Face TRL 是最流行的开源 RLHF 实现。PPOTrainer 关键默认值:

超参默认说明
learning_rate1.41e-5actor 学习率
mini_batch_size1per-GPU
batch_size256全局 rollout
ppo_epochs4每批 rollout 的更新次数
cliprange0.2ε
cliprange_value0.2value clip
gamma1.0折扣
lam0.95GAE
vf_coef0.1value loss 权重
init_kl_coef0.2β 初值
target_kl6.0adaptive controller 目标
whiten_rewardsTruereward 归一化

调用流程:

python
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

config = PPOConfig(
    model_name="meta-llama/Llama-2-7b-hf",
    learning_rate=1e-5,
    batch_size=128,
    ppo_epochs=4,
)

# Actor + Critic 共享 backbone(用 ValueHead 给 actor 加一个 value head)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "path/to/sft_model")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "path/to/sft_model")     # 参考模型
reward_model = ...           # 自己加载 RM

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)

for batch in dataloader:
    # Stage 1: rollout
    queries = batch["input_ids"]
    responses = ppo_trainer.generate(queries, **gen_kwargs)

    # Stage 2: 计算奖励
    rewards = reward_model(queries + responses)

    # Stage 3: PPO 更新
    stats = ppo_trainer.step(queries, responses, rewards)

    ppo_trainer.log_stats(stats, batch, rewards)

更复杂的多机训练用 OpenRLHF / veRL:通过 Ray 调度,把 rollout 用 vLLM、训练用 PyTorch FSDP,实现高吞吐。


13.8 PPO 的常见问题与诊断

现象可能原因解决
训练初期 loss 爆炸RM 输出尺度太大、未归一化reward whitening、clipping
KL 飙升不可控β 太小或 lr 太大adaptive KL、降 lr
Value loss 不下降critic 学习率太低 / 初始化差warm-up critic、提 vf_coef
输出退化(重复短语)reward hacking早停、加 KL、RM ensemble
训练显存 OOM4 模型 + 长序列ZeRO-3、offload、shorter rollout
Rollout 太慢单卡 generate 慢vLLM 接入、多卡推理
学不动ε 太小 / advantage 标准化问题ε=0.3、检查 Â 分布

经验法则:先看 KL,再看 value loss,再看 reward。KL 是 RLHF 的"温度计"。


13.9 PPO 的局限与替代

PPO 在 LLM 上虽然有效,但代价巨大:

问题程度替代方案
4 模型显存严重DPO(无 RM、无 critic)、GRPO(无 critic)
调参复杂(10+ 超参)严重DPO(仅 β)、SimPO
Rollout 慢(生成阶段)显著DPO(无需 rollout)
训练不稳定显著DPO/IPO 的 closed-form
Reward hacking严重DPO + iterative、DAPO 等

后续章节会逐一介绍这些替代方案。但需要强调:PPO 仍是当前最强对齐能力的代表,OpenAI、Anthropic 至今主用。在数据足够、调参得当时,PPO + 大 RM 仍优于 DPO。


本章小结

  • 策略梯度定理J=E[logπA],是所有 RL 算法的起点;
  • 重要性采样:让 PPO 能用旧策略采样的数据多次更新,但需要约束策略变化;
  • PPO-Clip:通过裁剪比率到 [1ε,1+ε] 实现"软信任域",简单高效;
  • GAE:用指数加权融合所有 n-step TD,平衡偏差-方差,实践 λ=0.95
  • 4 模型架构:actor / critic / reference / reward 同时驻留,显存压力是主要挑战;
  • 稳定性技巧:reward/advantage whitening、value clipping、adaptive KL、reference refresh、grad clip 缺一不可;
  • TRL/OpenRLHF/veRL 提供工业级实现,但调参依然是艺术。

思考题

  1. 为什么 PPO 用比率 rt=πθ/πθold 而不是直接优化 logπθA^(即 vanilla policy gradient)?θ=θoldθJISθJPG 等价吗?请通过链式法则验证。

  2. GAE 推导:完成 λ=1A^tGAE 化简为蒙特卡洛减基线的过程:

l=0γlδt+l=l=0γlrt+lV(st)

提示:注意 δt=rt+γV(st+1)V(st),连续两项的 V 项会形成望远镜求和。

  1. 工程题:某团队报告 PPO 训练中 KL 在前 100 步缓慢上升,第 150 步突然飙升 10×,loss 也炸了。请列举 3-5 种可能原因和对应的诊断/解决方法。如果你只能加一个监控量,你会选哪个?

基于 MIT 协议发布