GROP详解
GRPO算法是在PPO算法的基础上进化而来的,在搞清楚GRPO算法前,需要先了解PPO算法是如何在LLM的Post Training中应用的。
本文主要参考DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models。
论文链接:https://arxiv.org/abs/2402.03300
PPO概述
PPO的理论基础已在https://blog.youkuaiyun.com/ooblack/article/details/144198573给出,这里只简述PPO在LLM中的应用。
其中最常用用法如下,目标是最大化JPPO(θ)\mathcal{J}_{PPO}(\theta)JPPO(θ):
JPPO(θ)=E[q∼P(Q),o∼πθold(O∣q)]1∣o∣∑t=1∣o∣min(πθ(ot∣q,o<t)πθold(ot∣q,o<t),clip(πθ(ot∣q,o<t)πθold(ot∣q,o<t),1−ϵ,1+ϵ))At
\mathcal{J}_{PPO}(\theta) = \mathbb{E}[q \sim P(Q), o \sim \pi_{\theta_{old}}(O|q)] \frac{1}{|o|}\sum_{t=1}^{|o|} \min \left( \frac{\pi_{\theta}(o_t|q, o_{<t})}{\pi_{\theta_{old}}(o_t|q, o_{<t})}, clip \left( \frac{\pi_{\theta}(o_t|q, o_{<t})}{\pi_{\theta_{old}}(o_t|q, o_{<t})} , 1-\epsilon, 1+\epsilon \right) \right) A_t
JPPO(θ)=E[q∼P(Q),o∼πθold(O∣q)]∣o∣1t=1∑∣o∣min(πθold(ot∣q,o<t)πθ(ot∣q,o<t),clip(πθold(ot∣q,o<t)πθ(ot∣q,o<t),1−ϵ,1+ϵ))At
其中πθ{\pi}_\thetaπθ和πθold\pi_{\theta_{old}}πθold分别为当前和旧策略模型,qqq、ooo分别为问题数据集和旧策略πθold\pi_{\theta_{old}}πθold中采样的问题和输出,ϵ\epsilonϵ为PPO中引入的用于稳定训练的剪裁相关超参数。AtA_tAt是优势,它是通过应用广义优势估计GAE计算的。
使用PPO来更新模型参数,πθold\pi_{\theta_{old}}πθold指的是未更新参数前的的模型,πθ{\pi}_\thetaπθ指的是每一步更新后的模型。E[q∼P(Q),o∼πθold(O∣q)]\mathbb{E}[q \sim P(Q), o \sim \pi_{\theta_{old}}(O|q)]E[q∼P(Q),o∼πθold(O∣q)]可以理解为采样过程。后面那部分是非常常规的PPO算法。而对于优势函数AtA_tAt的计算,有如下计算公式。
AtGAE(γ,λ)=∑l=0∞(γλ)l(rt+l+γV(st+l+1)−V(st+l))
A_t^{\text{GAE}(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \left( r_{t+l} + \gamma V(s_{t+l+1}) - V(s_{t+l}) \right)
AtGAE(γ,λ)=l=0∑∞(γλ)l(rt+l+γV(st+l+1)−V(st+l))
γ\gammaγ 是折扣因子,λ\lambdaλ是GAEGAEGAE的超参数,rtr_trt是时间步 ttt的奖励,V(st)V(s_t)V(st)是状态 sts_tst的价值函数估计,而这个价值模型的体量一般与要训练的策略模型也就是LLM相当。
而rtr_trt的计算公式如下:
rt=rϕ(q,o≤t)−βlogπθ(ot∣q,o<t)πref(ot∣q,o<t)
r_t = r_\phi (q, o_{\leq t}) - \beta \log \frac{\pi_\theta (o_t | q, o_{< t})}{\pi_{ref}(o_t | q, o_{< t})}
rt=rϕ(q,o≤t)−βlogπref(ot∣q,o<t)πθ(ot∣q,o<t)
rϕ(q,o≤t)r_\phi (q, o_{\leq t})rϕ(q,o≤t)是专门训练的奖励模型给出的,而βlogπθ(ot∣q,o<t)πref(ot∣q,o<t)\beta \log \frac{\pi_\theta (o_t | q, o_{< t})}{\pi_{ref}(o_t | q, o_{< t})}βlogπref(ot∣q,o<t)πθ(ot∣q,o<t)则是对每一次奖励都计算一次KL散度进行约束。也就是说每个token生成都要计算一次KL散度。
综上,我们可以看出PPO算法在计算过程中有两个模型要训练,也就是之前提到的Actor网络和Critic网络,也就是LLM与V网络。然后在每个token生成的时候都需要计算KL散度比较浪费资源。GRPO针对这两个问题进行了优化。
GRPO过程
VVV函数的作用就是在计算优势函数AtA_tAt时是为了降低方差而被当做baseline,但是LLM的奖励模型的性质就决定了它只会为每个回答ooo的最后一个token分配奖励rrr而其余的token的奖励都是0,就是因为这个性质,我们很难在每个token处训练出准确的价值函数。
基于这个思想,GRPO决定不用VVV函数,而是在旧策略$ \pi_{\theta_{old}}$中采样多个输出,将输出的奖励平均值作为baseline来降低方差。
优化目标:
JGRPO(θ)=E[q∼P(Q),{oi}i=1G∼πθold(O∣q)]1G∑i=1G1∣oi∣∑t=1∣oi∣{min[πθ(oi,t∣q,oi,<t)πθ,id(oi,t∣q,oi,<t)A^i,t,clip(πθ(oi,t∣q,oi,<t)πθold(oi,t∣q,oi,<t),1−ϵ,1+ϵ)A^i,t]−βDKL[πθ∥πref]}
\begin{align*}
\mathcal{J}_{GRPO}(\theta) &= \mathbb{E}[q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{old}}(O|q)] \\
&\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left\{ \min \left[ \frac{\pi_{\theta}(o_{i,t}|q, o_{i,<t})}{\pi_{\theta, id}(o_{i,t}|q, o_{i,<t})} \hat{A}_{i,t}, \text{clip} \left( \frac{\pi_{\theta}(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right] - \beta \mathbb{D}_{KL} \left[ \pi_{\theta} \| \pi_{ref} \right] \right\}
\end{align*}
JGRPO(θ)=E[q∼P(Q),{oi}i=1G∼πθold(O∣q)]G1i=1∑G∣oi∣1t=1∑∣oi∣{min[πθ,id(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)A^i,t,clip(πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t),1−ϵ,1+ϵ)A^i,t]−βDKL[πθ∥πref]}
其中,ϵ\epsilonϵ 和 β\betaβ 是超参数,A^i,t\hat{A}_{i,t}A^i,t 是基于每个组内的相对回报计算的优势。GRPO 利用组相对方式计算优势,这与奖励模型的性质很吻合,因为奖励模型通常基于同一问题的输出比较的数据集进行训练。而KL散度也不再是添加到奖励函数里面了,而是直接添加在损失函数上,降低了优势函数A^i,t\hat{A}_{i,t}A^i,t计算的复杂性。
其中
DKL[πθ∣∣πref]=πref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−logπref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−1
\mathbb{D}_{KL}[\pi_{\theta}||\pi_{ref}]=\frac{\pi_{ref}(o_{i,t}|q,o_{i,<t})}{\pi_{\theta}(o_{i,t}|q,o_{i,<t})}-\log\frac{\pi_{ref}(o_{i,t}|q,o_{i,<t})}{\pi_{\theta}(o_{i,t}|q,o_{i,<t})}-1
DKL[πθ∣∣πref]=πθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−logπθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−1
与传统的KL散度计算方法不同,这里采用了无偏估计,使计算出来的惩罚项每一次都是正的。
其中A^i,t\hat{A}_{i,t}A^i,t计算方法如下:
A^i,t=r~i−ri−mean(r)std(r)
\hat{A}_{i,t} = \tilde{r}_i - \frac{r_i - \text{mean}(r)}{\text{std}(r)}
A^i,t=r~i−std(r)ri−mean(r)
相对于PPO来说精简了非常多,为了保证训练的稳定性,加了一个标准化。
GRPO的完整算法流程如下:
其中可以看到πref\pi_{ref}πref是最初的模型,在不会有变化,每次模型的变化都不能与πref\pi_{ref}πref差别过大,保证输出的质量。
总结
回顾开头给出的图片,黄色的是需要训练过程中更新参数的模型,蓝色的是不需要更新参数的模型,GRPO算法相对于PPO算法少训练了一个价值模型,而且大大简化了优势函数的计算,节约了计算资源。