项目地址: GRPO vs GSPO
GRPO原文: Group Relative Policy Optimization
GSPO原文: Group Sequence Policy Optimization
- 数据集: GSM8K
- 参考模型: qwen2.5-1.5B-Instruct
- 目标模型: qwen2.5-1.5B-Instruct
- 硬件配置: 3 × AutoDL vGPU-32G (GPU0/1用于训练, GPU2用于采样)
- 训练步数: 200 steps (60min)

准确率评估包含答案和格式两部分:
- GSPO算法在50个训练步左右基本稳定并到达峰值, 答案准确率为0.6左右, 格式准确率为0.99左右
- GRPO算法在120个训练步左右基本稳定并到达峰值, 答案准确率为0.6左右, 格式准确率为0.99左右
从结果来看GSPO训练速度明显优于GRPO, 消耗更少的时间达到稳定状态. 从模型特性来解释, GSPO模型训练时方差更小, 在矫正输出分布时有更强的确定性能够快速调整, 宏观上体现为更快得收敛至稳定值. 训练至200步后两种方法训练的结果基本接近, 应该是达到模型极限.
GRPO算法
目标函数:
J
G
R
P
O
(
θ
)
=
E
[
1
G
∑
i
=
1
G
1
∣
y
i
∣
∑
t
=
1
∣
y
i
∣
min
(
w
i
,
t
(
θ
)
A
^
i
,
t
,
clip
(
w
i
,
t
(
θ
)
,
1
−
ϵ
,
1
+
ϵ
)
A
^
i
,
t
)
]
J_{GRPO}(\theta)=E\left[\frac{1}{G} \sum_{i=1}^G \frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \min \left(w_{i, t}(\theta) \hat{A}_{i, t}, \text{clip}\left(w_{i, t}(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_{i, t}\right)\right]
JGRPO(θ)=E
G1i=1∑G∣yi∣1t=1∑∣yi∣min(wi,t(θ)A^i,t,clip(wi,t(θ),1−ϵ,1+ϵ)A^i,t)
w
i
,
t
(
θ
)
=
π
θ
(
y
i
,
t
∣
x
,
y
i
<
t
)
π
θ
o
l
d
(
y
i
,
t
∣
x
,
y
i
<
t
)
w_{i, t}(\theta) = \frac{\pi_\theta(y_{i,t} \mid x, y_{i<t})}{\pi_{\theta_{old}}(y_{i,t} \mid x, y_{i<t})}
wi,t(θ)=πθold(yi,t∣x,yi<t)πθ(yi,t∣x,yi<t)
A
^
i
,
t
=
A
^
i
=
r
(
x
,
y
i
)
−
mean
(
{
r
(
x
,
y
i
)
}
i
=
1
G
)
std
(
{
r
(
x
,
y
i
)
}
i
=
1
G
)
\hat{A}_{i, t}=\hat{A}_i=\frac{r(x, y_i)-\text{mean}(\{r(x, y_i)\}_{i=1}^G)}{\text{std}(\{r(x, y_i)\}_{i=1}^G)}
A^i,t=A^i=std({r(x,yi)}i=1G)r(x,yi)−mean({r(x,yi)}i=1G)
其中 w i , t w_{i,t} wi,t表示token级别的重要性采样, A ^ i \hat{A}_{i} A^i表示序列的组内回报值:
代码实现:
ref_policy_log_probs_ = ref_policy_log_probs[:, prefix_len-1:] # 参考策略概率分布
old_policy_log_probs_ = old_policy_log_probs[:, prefix_len-1:] # 旧策略概率分布
new_policy_log_probs_ = new_policy_log_probs[:, prefix_len-1:] # 新策略概率分布
attention_mask_ = attention_mask[:, prefix_len:]
importance_ratio = torch.exp(new_policy_log_probs_ - old_policy_log_probs_) # 重要性采样
cliped_ratio = torch.clip(importance_ratio, 1 - clip_epsilon, 1 + clip_epsilon) # 相似度裁剪
importance_term = importance_ratio * advantages
clip_term = cliped_ratio * advantages
kl_term = torch.exp(ref_policy_log_probs_ - new_policy_log_probs_) - (ref_policy_log_probs_ - new_policy_log_probs_) - 1 # kl散度
objective_function = torch.min(importance_term, clip_term) - kl_beta * kl_term # 目标函数
per_token_loss = -objective_function # loss函数
loss = ((per_token_loss * attention_mask_).sum(dim=1) / attention_mask_.sum(dim=1)).mean() # batch的均值作为最终loss(只统计有效token的loss)
GSPO算法
目标函数:
J
G
S
P
O
(
θ
)
=
E
[
1
G
∑
i
=
1
G
min
(
s
i
(
θ
)
A
^
i
,
clip
(
s
i
(
θ
)
,
1
−
ϵ
,
1
+
ϵ
)
A
^
i
)
]
J_{GSPO}(\theta)=E\left[\frac{1}{G} \sum_{i=1}^G \min \left(s_i(\theta) \hat{A}_i, \text{clip}\left(s_i(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_i\right)\right]
JGSPO(θ)=E[G1i=1∑Gmin(si(θ)A^i,clip(si(θ),1−ϵ,1+ϵ)A^i)]
s
i
(
θ
)
=
(
π
θ
(
y
i
∣
x
)
π
θ
old
(
y
i
∣
x
)
)
1
∣
y
i
∣
=
exp
(
1
∣
y
i
∣
∑
t
=
1
∣
y
i
∣
log
π
θ
(
y
i
,
t
∣
x
,
y
i
,
<
t
)
π
θ
old
(
y
i
,
t
∣
x
,
y
i
,
<
t
)
)
s_i(\theta)=\left(\frac{\pi_\theta\left(y_i \mid x\right)}{\pi_{\theta_{\text {old }}}\left(y_i \mid x\right)}\right)^{\frac{1}{\left|y_i\right|}}=\exp \left(\frac{1}{\left|y_i\right|} \sum_{t=1}^{\left|y_i\right|} \log \frac{\pi_\theta\left(y_{i, t} \mid x, y_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(y_{i, t} \mid x, y_{i,<t}\right)}\right)
si(θ)=(πθold (yi∣x)πθ(yi∣x))∣yi∣1=exp
∣yi∣1t=1∑∣yi∣logπθold (yi,t∣x,yi,<t)πθ(yi,t∣x,yi,<t)
其中 s i ( θ ) s_i(\theta) si(θ)表示序列重要性采样, 与序列组内回报 A ^ i \hat{A}_i A^i颗粒度是对齐的.
代码实现:
batch_size = ref_policy_log_probs.shape[0]
# 取生成部分的概率分布
ref_policy_log_probs_ = ref_policy_log_probs[:, prefix_len-1:] # token_0裁剪了, 因此需要裁剪的长度为prefix_len-1
old_policy_log_probs_ = old_policy_log_probs[:, prefix_len-1:]
new_policy_log_probs_ = new_policy_log_probs[:, prefix_len-1:]
attention_mask_ = attention_mask[:, prefix_len:] # attention_mask维度中token_0的位置没裁剪, 因此需要裁剪的长度为prefix_len
# 计算有效序列, 遮掩pad_token
valid_seq_len = attention_mask_.sum(dim=1)
new_old_log_probs_ = (new_policy_log_probs_ - old_policy_log_probs_) * attention_mask_
ref_new_log_probs_ = (ref_policy_log_probs_ - new_policy_log_probs_) * attention_mask_
# 序列级别的重要性采样
importance_ratio = torch.exp(new_old_log_probs_.sum(dim=1) / valid_seq_len).view(batch_size, 1) # batch_size * 1
cliped_ratio = torch.clip(importance_ratio, 1 - clip_epsilon, 1 + clip_epsilon) # batch_size * 1
importance_term = importance_ratio * advantages # batch_size * 1
clip_term = cliped_ratio * advantages # batch_size * 1
kl_term = torch.exp(ref_new_log_probs_.sum(dim=1) / valid_seq_len) - (ref_new_log_probs_.sum(dim=1) / valid_seq_len) - 1
kl_term = kl_term.view(batch_size, 1)
objective_function = torch.min(importance_term, clip_term) - kl_beta * kl_term
sequence_loss = -objective_function
# 批次平均损失作为总损失
loss = sequence_loss.mean()
1730

被折叠的 条评论
为什么被折叠?



