Introduction
- 作者提出 O1-Pruner,通过将长度和精度同时纳入奖励函数,鼓励模型进行高效推理
Method
- Length Disharmony. 对于相同的问题,推理模型中采样得到的更长回复的正确率未必高于更短回复,说明推理步骤中存在大量冗余
- Length-Harmonizing Fine-Tuning (O1-Pruner). 优化目标如下所示,既要保证模型输出长度比 reference model 更短,又要保证精度更高
上述条件优化目标可以转化为如下形式:
其中, λ ≥ 0 \lambda\geq 0 λ≥0, A ( ⋅ ) A(\cdot) A(⋅) 根据回答正确与否返回 0 或 1, L ˉ r e f ( x ) = E y ′ ∼ π r e f ( y ∣ x ) L ( y ′ ) \bar L_{ref}(x)=\mathbb E_{y'\sim\pi_{ref}(y|x)}L(y') Lˉref(x)=Ey′∼πref(y∣x)L(y′) 和 A ˉ r e f ( x ) = E y ′ ∼ π r e f ( y ∣ x ) A ( x , y ′ ) \bar A_{ref}(x)=\mathbb E_{y'\sim\pi_{ref}(y|x)}A(x,y') Aˉref(x)=Ey′∼πref(y∣x)A(x,y′) 可以通过采样做近似:
最终的优化目标为:
为了降低训练开销,作者训练时采用 off-policy training,全部数据都提前采样自 reference model,将上述优化目标直接作为 advantage function R L H ( x , y ) R_{LH}(x,y) RLH(x,y) 使用 PPO-style loss 进行训练
其中,
Experiments
- Training Dataset. MATH (randomly sample 5k samples from 10k math problem of high school level)
- Baselines. (i) Fast-Solving Prompt. (ii) SFT: 对于每个问题,从采样回复中选择最短的两个正确回复组成数据集. (iii) DPO: 最短的两个正确回复作为正样本,最长回复作为负样本
- Evaluation Metric. 作者定义 Accuracy-Efficiency Score (AES)
其中, γ > β > 0 \gamma>\beta>0 γ>β>0 用于惩罚精度下降
- Main Results. 神奇的是,o1-pruner 训练后模型精度不降反增,有可能是因为这两个模型都是 unsaturated model,继续做强化学习本身就能提点,进而掩盖了回复长度减少带来的精度损失。应该在训练程度更高的一些模型上进行实验 (e.g. QwQ-32B)
- Ablation Study. (1) Ablation on Hyper-parameter Sensitivity.
(2) Ablation on Difficulty Levels. 在更难数据上训练有利于增加模型精度,但也会使得模型输出长度更长