OpenAI o1复现:自动构造prm训练数据-OmegaPRM

openai o1复现中,有个比较关键的问题,怎么样自动化构造prm模型的训练数据?本文主要从代码层面,来解析OmegaPRM原理。

论文

Improve Mathematical Reasoning in Language…[1]

原理

Markov决策过程

OmegaPRM

State:对应Markov决策过程中的状态,rollout:对应Markov决策过程中的动作;

  • step1:初始化root节点state;每个state包含n个扩展rollouts,q+pa作为prompt,进行n次llm生成采样;基于bootstrap采样方法估计Monte Carlo模拟正确答案的概率mc;

  • step2:从所有节点中,基于UCB1(Explore&&Exploit方法)选取最优的“state和rollout”,添加到PRM训练集;Exploit:alpha ** (1 - mc) * beta ** (len® / L),其中:mc表示蒙特卡洛模拟正确答案概率、len®表示LLM生成的长度;Explore:c_puct * sqrt(N_sum) / (1 + s.v),其中:N_sum表示所有节点的访问次数,s.v表示当前节点的访问次数,c_puct控制MCTS树的探索程度;

  • step3:评估最优“state和rollout”,二分rollout的结果,将左半部分纳入到新的state中,并计算新的mc;mc=1,表示state完全包含正确答案,忽略;mc=0,表示state完全没有生成正确答案可能性,添加到叶子节点;mc>0,表示state作为继续探索的节点;

  • step4:重复step2、step3,直至“探索到足够的样本、无法继续探索”退出;

  • step5:将叶子节点全部添加到PRM训练集;

PRM模型训练效果

论文的base模型

基于OmegaPRM方法合成数据,在MATH数据集,相比base model51%的准确率,OmegaPRM准确率提高到69.4%;

其他PRM方法

OmegaPRM:gemini提到的方法;

AlphaMath:qwen提到的方法;

Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations[2]

AlphaMath Almost Zero: Process Supervision without Process[3]

源码来源

https://github.com/openreasoner/openr[4]

源码解析

数据结构

class State:       def __init__(self, q, pa, a):           self.q = q #问题           self.pa = pa #当前step的prompt           self.a = a #答案           self.mc = None #基于当前节点,生成正确答案的概率           self.v = 0 #被访问次数           self.rollouts = [] #扩展的子节点           self.rollout_was_visited = [] #扩展的子节点是否被访问

主流程

 `# Load the JSON data       data = load_json_file(json_file_path)              # Process each problem and its final answer       for i, item in enumerate(data):           problem = item.get('problem', 'No problem found')           final_answer = item.get('final_answer', 'No answer found')                      # Print to console           print(f"Problem {i + 1}: {problem}")           print(f"Final Answer: {final_answer}")                      # Log each problem and answer           logging.info(f"Processed Problem {i + 1}: {problem}")           logging.info(f"Final Answer: {final_answer}")                      # Call getrollout and handle the result           states = []           root = State(problem, "", final_answer)           max_roll_num = 20           rollouts, corrs = getrollouts(root, max_roll_num)           mcst = cal_mc_bs(root)           root.mc = mcst                         # 生成root节点           states.append(root)              if sum(corrs) > 0 and sum(corrs) < max_roll_num:                print("Process annotation ...\n")               filename = str(i+1) +'_states_list.json'               # 生成PRM训练数据               process_annotation(problem, final_answer, states, filename)`

蒙特卡洛采样

#针对节点s进行n次采样,基于LLM生成n个rollouts,并给出每个rollout是否包含正确答案;   def getrollouts(s, n = 5):     corrs = []     q = s.q     pa = s.pa     for i in range(n):       re = complete_answer(q, pa)       s.add_rollout(re)       #check the answer       a = s.a       if check_answer(a, re):         corrs.append(1)       else:         corrs.append(0)     return s.rollouts, corrs          #蒙特卡洛采样,并给出包含正确答案的概率     def cal_mc_bs(s, bs = 5):       n = len(s.rollouts)       subn = max(1,random.randint(n//2, n))       mc = 0       for i in range(bs):       corr = 0           sub = random.sample(s.rollouts, subn)           for r in sub:               if check_answer(s.a, r):                   corr += 1           mc += corr * 1.0 / len(sub)       return mc / bs           #针对问题problem,使用problem+partial_answer作为prompt,进行LLM生成     complete_answer(problem, partial_answer, checkpoint)     #LLM生成的response是否包含正确答案groundtruth_answer     check_answer(groundtruth_answer, response)

基于mcts方法自动构造prm训练数据

#基于MCTS方法生成PRM训练数据   def process_annotation(q, a, states, filename = 'states_list.json'):      print("++++++")      it = 0      leaf_states = []      while True:          s, rollout, maxqu = select(states)          if s is not None and s.pa!='':              new_data = {                  "q": q,           # Ensure q is serializable                  "states": s.pa, # Ensure states is serializable                  "mcs": s.mc        # Ensure mcs is serializable              }              # Call the function to append the new data              append_to_json_file(filename, new_data)              it += 1              if it > 100:                  break          # all state-rolls pairs were exhausted          if s is None:              break          print()          print("[sel]")          print(s)          print("  roll=",rollout," || qu=", maxqu)                    s.add_visit()          div_roll_sts,leaf_sts = error_locate(s, rollout)          if len(div_roll_sts)==0:              continue                    states.extend([s for s in div_roll_sts if s!=None and s.pa != ''])          leaf_states.extend(leaf_sts)      #      ## add leaf states to data      for s in leaf_states:          new_data = {              "q": q,           # Ensure q is serializable              "states": s.pa, # Ensure states is serializable               "mcs": s.mc        # Ensure mcs is serializable          }          # Call the function to append the new data          append_to_json_file(filename, new_data)      print("++++++")

基于UCB1方法,选择最优的节点,纳入到训练集

#选择当前最优的节点   #exploitation:使用“更大的mc、更短的llm生成”节点;   #exploration:探索“未充分访问的、更大的树探索程度”节点;   def select(states):       best_st = None       best_roll_idx = -1       best_qu = -1       for s in states:           # mcs = cal_mc(s) if s.mc is None else s.mc           mcs = cal_mc_bs(s) if s.mc is None else s.mc           if mcs == 0 or mcs==1.0:               continue           for i,r in enumerate(s.rollouts):               if s.rollout_was_visited[i]:                   continue               q = Q(r, mcs)               u = U(s,states)               qu = q + u               if qu > best_qu:                   best_st = s                   best_roll_idx = i                   best_qu = qu                     #       if best_roll_idx != -1:           best_st.rollout_was_visited[best_roll_idx] = True       return best_st,best_st.rollouts[best_roll_idx],best_qu      #exploitation:倾向于选择已知表现好的状态和rollout;   #alpha ** (1 - mc) * beta ** (len(r) / L)   #1. 鼓励使用更大mc(生成包含正确答案可能性更大);   #2. 更短rollout(更短的生成,更可能推理出正确答案)的节点,   def Q(r, mc, alpha  = 0.5, beta = 0.9, L = 500):       part1 = alpha ** (1 - mc)       part2 = beta ** (len(r) / L)       Q_value = part1 * part2       return Q_value      #exploration:鼓励尝试未充分探索的选项,使用UCB1算法(Upper Confidence Bound 1);   #c_puct * sqrt(N_sum) / (1 + s.v)   #1. s.v:当前状态访问次数,鼓励探索访问次数较少的节点;   #2. N_sum:所有状态的访问次数总和,表示搜索过程的广度和深度,即鼓励更大的搜索树探索程度;   #3. c_puct:控制探索程度的常数;   def U(s, states, c_puct = 0.125):       N_sum = 0       for item in states:           N_sum += item.v       numerator = math.sqrt(N_sum)       denominator = 1 + s.v       U_value = c_puct * (numerator / denominator)       return U_value      def qu(i, r, mc, ncs):       q = Q(r, mc)       u = U(i, ncs)       return q+u

评估最优节点,是否继续探索?无法探索(完全错误)作为叶子节点,纳入到训练集

#评估最优“state和rollout”,二分rollout的结果,将左半部分纳入到新的state中,并计算新的mc;   def error_locate(s, rollout):       current_span = rollout       prev = ""       divide_roll_pos_st = []       leaf_st = []       while True:           word_count = len(current_span.split())           if word_count < 2:               break           np1, np2 = split_sentence_middle(current_span)           print("----")           print(" BS[l]=", np1)           print(" BS[r]=", np2)           #二分LLM生成结果rollout,新的prompt:已有生成结果+左半部分           st = State(s.q, prev + np1, s.a)           rollouts, corrs = getrollouts(st)           # mcst = cal_mc(st)           mcst = cal_mc_bs(st)           st.mc = mcst           # case 1: always correct (we are not interested in this kind of state)           if mcst == 1:            # leaf_st.append(st)               break           # case 2: right span(继续扩展节点)           elif mcst > 0:               current_span = np2               prev = prev + np1               divide_roll_pos_st.append(st)           # case 3: left span(这里LLM生成完全没有可能包含正确答案,因此节点扩展terminated)           elif mcst == 0:               current_span = np1               leaf_st.append(st)                  #       print("----")       return divide_roll_pos_st,leaf_st

如何学习大模型 AI ?

由于新岗位的生产效率,要优于被取代岗位的生产效率,所以实际上整个社会的生产效率是提升的。

但是具体到个人,只能说是:

“最先掌握AI的人,将会比较晚掌握AI的人有竞争优势”。

这句话,放在计算机、互联网、移动互联网的开局时期,都是一样的道理。

我在一线互联网企业工作十余年里,指导过不少同行后辈。帮助很多人得到了学习和成长。

我意识到有很多经验和知识值得分享给大家,也可以通过我们的能力和经验解答大家在人工智能学习中的很多困惑,所以在工作繁忙的情况下还是坚持各种整理和分享。但苦于知识传播途径有限,很多互联网行业朋友无法获得正确的资料得到学习提升,故此将并将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。

在这里插入图片描述

第一阶段(10天):初阶应用

该阶段让大家对大模型 AI有一个最前沿的认识,对大模型 AI 的理解超过 95% 的人,可以在相关讨论时发表高级、不跟风、又接地气的见解,别人只会和 AI 聊天,而你能调教 AI,并能用代码将大模型和业务衔接。

  • 大模型 AI 能干什么?
  • 大模型是怎样获得「智能」的?
  • 用好 AI 的核心心法
  • 大模型应用业务架构
  • 大模型应用技术架构
  • 代码示例:向 GPT-3.5 灌入新知识
  • 提示工程的意义和核心思想
  • Prompt 典型构成
  • 指令调优方法论
  • 思维链和思维树
  • Prompt 攻击和防范

第二阶段(30天):高阶应用

该阶段我们正式进入大模型 AI 进阶实战学习,学会构造私有知识库,扩展 AI 的能力。快速开发一个完整的基于 agent 对话机器人。掌握功能最强的大模型开发框架,抓住最新的技术进展,适合 Python 和 JavaScript 程序员。

  • 为什么要做 RAG
  • 搭建一个简单的 ChatPDF
  • 检索的基础概念
  • 什么是向量表示(Embeddings)
  • 向量数据库与向量检索
  • 基于向量检索的 RAG
  • 搭建 RAG 系统的扩展知识
  • 混合检索与 RAG-Fusion 简介
  • 向量模型本地部署

第三阶段(30天):模型训练

恭喜你,如果学到这里,你基本可以找到一份大模型 AI相关的工作,自己也能训练 GPT 了!通过微调,训练自己的垂直大模型,能独立训练开源多模态大模型,掌握更多技术方案。

到此为止,大概2个月的时间。你已经成为了一名“AI小子”。那么你还想往下探索吗?

  • 为什么要做 RAG
  • 什么是模型
  • 什么是模型训练
  • 求解器 & 损失函数简介
  • 小实验2:手写一个简单的神经网络并训练它
  • 什么是训练/预训练/微调/轻量化微调
  • Transformer结构简介
  • 轻量化微调
  • 实验数据集的构建

第四阶段(20天):商业闭环

对全球大模型从性能、吞吐量、成本等方面有一定的认知,可以在云端和本地等多种环境下部署大模型,找到适合自己的项目/创业方向,做一名被 AI 武装的产品经理。

  • 硬件选型
  • 带你了解全球大模型
  • 使用国产大模型服务
  • 搭建 OpenAI 代理
  • 热身:基于阿里云 PAI 部署 Stable Diffusion
  • 在本地计算机运行大模型
  • 大模型的私有化部署
  • 基于 vLLM 部署大模型
  • 案例:如何优雅地在阿里云私有部署开源大模型
  • 部署一套开源 LLM 项目
  • 内容安全
  • 互联网信息服务算法备案

学习是一个过程,只要学习就会有挑战。天道酬勤,你越努力,就会成为越优秀的自己。

如果你能在15天内完成所有的任务,那你堪称天才。然而,如果你能完成 60-70% 的内容,你就已经开始具备成为一名大模型 AI 的正确特征了。

这份完整版的大模型 AI 学习资料已经上传优快云,朋友们如果需要可以微信扫描下方优快云官方认证二维码免费领取【保证100%免费

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值