trl中的PPO代码解析(炒冷饭版)

不说其他的解释,上来就看代码。建议先对PPO的整体流程有了解。

trl的版本为0.4.0,注:【新版的trl中代码更复杂,如果只是想读懂PPO具体怎么用trl实现的,0.4.0版本即可】
在这里插入图片描述

step1: rollout

ppo_trainer.generate()函数使用policy model生成rollout
在这里插入图片描述

step2:evaluate

使用reward model对step1产生的rollout进行evaluate,获得一个标量的score,这个score并不是rewards,step4计算得到的才是最终的rewards
在这里插入图片描述

step3: logprobs

从old policy model和ref model中获得rollout的logits, values等值,用于后续计算rewards。
在这里插入图片描述

对应的代码部分为:
在这里插入图片描述

step4: rewards

注意:这里产生的变量中,score变成了rewards。
在这里插入图片描述
PPO中,为了防止policy model过度偏离ref model,会在计算rewards过程中额外增加一项KL散度,
r e w a r d s = s c o r e − λ K L ( π θ ( a ∣ s ) ∣ ∣ π θ r e f ( a ∣ s ) ) rewards = score - \lambda KL(\pi_{\theta}(a|s)||\pi_{\theta_{ref}}(a|s)) rewards=scoreλKL(πθ(as)∣∣πθref(as))
对应的代码部分为:
在这里插入图片描述

step5: train_minibatch

注意,这里的logprobs, vpreds, 与old_logprobs, old_values均是policy LM产生的,但是参数不一样。
在这里,产生logprobs, vpreds的policy LM的参数是会按照mini_batch_size进行不断更新的,所以每个mini_batch_size对于的new policy LM的参数是不一样的。而产生old_logprobs, old_values的old policy LM的参数对于每个mini_batch_size是不变的。

可以按照一般的训练神经网络的过程理解:产生old_logprobs, old_values的old policy LM的参数是按照epoch更新的,而产生logprobs, vpreds的new policy LM是按照step更新的。
在这里插入图片描述

对应的代码部分为:
在这里插入图片描述

step6: advantages

根据old_values, rewards,计算优势,在进一步计算出returns
在这里插入图片描述
对应的代码部分为(代码中的values为old_values):
在这里插入图片描述

step7: critic_loss

critic loss通常是通过均方误差(MSE)来计算。对于每一个状态,我们都有一个由critic网络预测的预期回报 v p r e d s vpreds vpreds,以及一个真实的回报 r e t u r n s returns returns,critic_loss是二者的平方差。
对应的代码部分为:
在这里插入图片描述

step8: actor loss

actor loss是基于策略梯度的损失函数,用于优化policy。在ppo中,通常使用一种称为重要性采样(importance sampling)的技术来计算策略梯度。
m a x i m i z e θ    E π θ ′ [ m i n ( r t ( θ ) A π θ o l d ( s , a ) ,   c l i p ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A π θ o l d ( s , a )   ) ] maximize_{\theta} \ \ E_{\pi_{\theta^{'}}}[min( r_{t}(\theta) A^{\pi_{\theta_{old}}}(s,a),\ clip(r_{t}(\theta), 1-\epsilon, 1+\epsilon)A^{\pi_{\theta_{old}}}(s,a)\ )] maximizeθ  Eπθ[min(rt(θ)Aπθold(s,a), clip(rt(θ),1ϵ,1+ϵ)Aπθold(s,a) )]
其中, r t ( θ ) = π θ ( a ∣ s ) π θ o l d ( a ∣ s ) r_{t}(\theta) = {\pi_{\theta}(a|s) \over \pi_{\theta_{old}}(a|s)} rt(θ)=πθold(as)πθ(as),这一项是新旧策略的比率, A π θ o l d ( s , a ) A^{\pi_{\theta_{old}}}(s,a) Aπθold(s,a)是优势函数,clip是裁剪函数,将其裁剪到 [ 1 − ϵ , 1 + ϵ ] [1-\epsilon,1+\epsilon] [1ϵ,1+ϵ]之间。这个损失函数的目标是,最大化新策略的期望回报,同时限制新旧策略之间的差异。
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值