Improve Mathematical Reasoning in LanguageModels by Automated Process Supervision
arxiv: https://arxiv.org/abs/2406.06592
问题背景
COT和SC-COT对于模型推理能力的提升仍然有效,已有研究提出用一个验证器去帮助LLM提升推理能力。采用ORM结果验证器岁可以对最终结果生成一个信号,但是不能去奖励或者惩罚中间步骤。采用PRM可以对中间步骤在更细粒度的视角下,对中间步骤进行奖励或者惩罚。受到AlphaGo Zero的启发,本文提出了一个分而治之的蒙特卡洛树搜索算法OmegaPRM,来有效的收集高质量过程监督数据。
本文方法
这篇论文提出了一种名为OmegaPRM的新型分治风格蒙特卡罗树搜索(MCTS)算法,通过引入二分搜索算法来高效识别COT中的第一个错误快速定位错误位置,用于自动收集高质量的过程监督数据。
(1)蒙特卡洛过程标注方法
已有的方式是构建了一个『完成者』策略,接受一个问题q和一个包含前t步骤 x 1 : t x_{1:t} x1:t的前缀解决方案,并输出后续步骤的完成度。
在图里的(a)中,对于解决方案的任何步骤,可以使用更完备的策略从该步骤随机抽样k个rollout。然后,将这些rollout后得到的最终答案与正确答案进行比较,评估出一个前缀步骤的『正确性等级』,公式如下:
这种方式计算了从步骤t开始后所有rollout里得到正确答案的比例情况,这种方式需要从头到尾执行每个步骤的rollout,需要大量的调用,开销过大。
为了优化标注效率,本文提出了一种基于二分搜索的蒙特卡洛的方法。当解决方案中出现了第一个错误步骤时,这种数据就足以用来训练PRM。基于这种想法,本文的目标是有采用有效的方式定位第一个错误。主要是通过重复划分解决方案和执行rollout来实现这一点。
- 二分搜索定位错误大致流程:
假设我们的目标数据是真负例,首先从中点步骤m将其拆分,然后对前半步骤1:m执行rollout,当 c m > 0 c_m>0 cm>0时,表示前半步骤中至少有一个步骤是正确的可以得到正确的答案,错误步骤在后半部分。当 c m = 0 c_m=0 cm=0时,说明前半部分中很有可能有错误步骤,因为从该步骤往后进行rollout后的结果中没有一个是正确的。
以此方式不断迭代定位错误,直至到达停止条件(该解过程足够短,视为单个步骤)。从而将时间复杂度从O(kM),缩短至O(klogM)。