GRPO与GSPO算法训练对比

Qwen3-8B

Qwen3-8B

文本生成
Qwen3

Qwen3 是 Qwen 系列中的最新一代大型语言模型,提供了一整套密集型和专家混合(MoE)模型。基于广泛的训练,Qwen3 在推理、指令执行、代理能力和多语言支持方面取得了突破性进展

项目地址: GRPO vs GSPO
GRPO原文: Group Relative Policy Optimization
GSPO原文: Group Sequence Policy Optimization

  • 数据集: GSM8K
  • 参考模型: qwen2.5-1.5B-Instruct
  • 目标模型: qwen2.5-1.5B-Instruct
  • 硬件配置: 3 × AutoDL vGPU-32G (GPU0/1用于训练, GPU2用于采样)
  • 训练步数: 200 steps (60min)

grpo_vs_gspo

准确率评估包含答案和格式两部分:

  • GSPO算法在50个训练步左右基本稳定并到达峰值, 答案准确率为0.6左右, 格式准确率为0.99左右
  • GRPO算法在120个训练步左右基本稳定并到达峰值, 答案准确率为0.6左右, 格式准确率为0.99左右

从结果来看GSPO训练速度明显优于GRPO, 消耗更少的时间达到稳定状态. 从模型特性来解释, GSPO模型训练时方差更小, 在矫正输出分布时有更强的确定性能够快速调整, 宏观上体现为更快得收敛至稳定值. 训练至200步后两种方法训练的结果基本接近, 应该是达到模型极限.

GRPO算法

目标函数:
J G R P O ( θ ) = E [ 1 G ∑ i = 1 G 1 ∣ y i ∣ ∑ t = 1 ∣ y i ∣ min ⁡ ( w i , t ( θ ) A ^ i , t , clip ( w i , t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ i , t ) ] J_{GRPO}(\theta)=E\left[\frac{1}{G} \sum_{i=1}^G \frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \min \left(w_{i, t}(\theta) \hat{A}_{i, t}, \text{clip}\left(w_{i, t}(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_{i, t}\right)\right] JGRPO(θ)=E G1i=1Gyi1t=1yimin(wi,t(θ)A^i,t,clip(wi,t(θ),1ϵ,1+ϵ)A^i,t)
w i , t ( θ ) = π θ ( y i , t ∣ x , y i < t ) π θ o l d ( y i , t ∣ x , y i < t ) w_{i, t}(\theta) = \frac{\pi_\theta(y_{i,t} \mid x, y_{i<t})}{\pi_{\theta_{old}}(y_{i,t} \mid x, y_{i<t})} wi,t(θ)=πθold(yi,tx,yi<t)πθ(yi,tx,yi<t)
A ^ i , t = A ^ i = r ( x , y i ) − mean ( { r ( x , y i ) } i = 1 G ) std ( { r ( x , y i ) } i = 1 G ) \hat{A}_{i, t}=\hat{A}_i=\frac{r(x, y_i)-\text{mean}(\{r(x, y_i)\}_{i=1}^G)}{\text{std}(\{r(x, y_i)\}_{i=1}^G)} A^i,t=A^i=std({r(x,yi)}i=1G)r(x,yi)mean({r(x,yi)}i=1G)

其中 w i , t w_{i,t} wi,t表示token级别的重要性采样, A ^ i \hat{A}_{i} A^i表示序列的组内回报值:

代码实现:

ref_policy_log_probs_ = ref_policy_log_probs[:, prefix_len-1:] # 参考策略概率分布
old_policy_log_probs_ = old_policy_log_probs[:, prefix_len-1:] # 旧策略概率分布
new_policy_log_probs_ = new_policy_log_probs[:, prefix_len-1:] # 新策略概率分布
attention_mask_       = attention_mask[:, prefix_len:]

importance_ratio = torch.exp(new_policy_log_probs_ - old_policy_log_probs_) # 重要性采样
cliped_ratio = torch.clip(importance_ratio, 1 - clip_epsilon, 1 + clip_epsilon) # 相似度裁剪
importance_term = importance_ratio * advantages
clip_term = cliped_ratio * advantages

kl_term = torch.exp(ref_policy_log_probs_ - new_policy_log_probs_) - (ref_policy_log_probs_ - new_policy_log_probs_) - 1 # kl散度

objective_function = torch.min(importance_term, clip_term) - kl_beta * kl_term # 目标函数
per_token_loss = -objective_function # loss函数

loss = ((per_token_loss * attention_mask_).sum(dim=1) / attention_mask_.sum(dim=1)).mean() # batch的均值作为最终loss(只统计有效token的loss)

GSPO算法

目标函数:
J G S P O ( θ ) = E [ 1 G ∑ i = 1 G min ⁡ ( s i ( θ ) A ^ i , clip ( s i ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ i ) ] J_{GSPO}(\theta)=E\left[\frac{1}{G} \sum_{i=1}^G \min \left(s_i(\theta) \hat{A}_i, \text{clip}\left(s_i(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_i\right)\right] JGSPO(θ)=E[G1i=1Gmin(si(θ)A^i,clip(si(θ),1ϵ,1+ϵ)A^i)]
s i ( θ ) = ( π θ ( y i ∣ x ) π θ old  ( y i ∣ x ) ) 1 ∣ y i ∣ = exp ⁡ ( 1 ∣ y i ∣ ∑ t = 1 ∣ y i ∣ log ⁡ π θ ( y i , t ∣ x , y i , < t ) π θ old  ( y i , t ∣ x , y i , < t ) ) s_i(\theta)=\left(\frac{\pi_\theta\left(y_i \mid x\right)}{\pi_{\theta_{\text {old }}}\left(y_i \mid x\right)}\right)^{\frac{1}{\left|y_i\right|}}=\exp \left(\frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_\theta\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right) si(θ)=(πθold (yix)πθ(yix))yi1=exp yi1t=1yilogπθold (yi,tx,yi,<t)πθ(yi,tx,yi,<t)

其中 s i ( θ ) s_i(\theta) si(θ)表示序列重要性采样, 与序列组内回报 A ^ i \hat{A}_i A^i颗粒度是对齐的.

代码实现:

batch_size = ref_policy_log_probs.shape[0]

# 取生成部分的概率分布
ref_policy_log_probs_ = ref_policy_log_probs[:, prefix_len-1:] # token_0裁剪了, 因此需要裁剪的长度为prefix_len-1
old_policy_log_probs_ = old_policy_log_probs[:, prefix_len-1:]
new_policy_log_probs_ = new_policy_log_probs[:, prefix_len-1:]
attention_mask_       = attention_mask[:, prefix_len:]         # attention_mask维度中token_0的位置没裁剪, 因此需要裁剪的长度为prefix_len

# 计算有效序列, 遮掩pad_token
valid_seq_len = attention_mask_.sum(dim=1)
new_old_log_probs_ = (new_policy_log_probs_ - old_policy_log_probs_) * attention_mask_
ref_new_log_probs_ = (ref_policy_log_probs_ - new_policy_log_probs_) * attention_mask_

# 序列级别的重要性采样
importance_ratio = torch.exp(new_old_log_probs_.sum(dim=1) / valid_seq_len).view(batch_size, 1) # batch_size * 1
cliped_ratio = torch.clip(importance_ratio, 1 - clip_epsilon, 1 + clip_epsilon) # batch_size * 1
importance_term = importance_ratio * advantages # batch_size * 1
clip_term = cliped_ratio * advantages # batch_size * 1

kl_term = torch.exp(ref_new_log_probs_.sum(dim=1) / valid_seq_len) - (ref_new_log_probs_.sum(dim=1) / valid_seq_len) - 1
kl_term = kl_term.view(batch_size, 1)

objective_function = torch.min(importance_term, clip_term) - kl_beta * kl_term
sequence_loss = -objective_function

# 批次平均损失作为总损失
loss = sequence_loss.mean()

您可能感兴趣的与本文相关的镜像

Qwen3-8B

Qwen3-8B

文本生成
Qwen3

Qwen3 是 Qwen 系列中的最新一代大型语言模型,提供了一整套密集型和专家混合(MoE)模型。基于广泛的训练,Qwen3 在推理、指令执行、代理能力和多语言支持方面取得了突破性进展

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值