InstructGPT高效实践——【DeepSpeed-Chat】源码详解(3/3):RLHF Finetuning

本文详细介绍了DeepSpeed-Chat框架在RLHF(ReinforcementLearningfromHumanFeedback)的第三阶段,即微调阶段的工作原理。重点讨论了数据处理、模型初始化、经验数据获取、PPO训练数据管理以及训练过程。其中,数据处理包括prompt数据和无监督数据的加载,以及DataCollatorRLHF的作用。模型初始化涉及actor、ref/SFT、critic和reward模型的DeepSpeed封装。经验数据获取涉及seq的生成、奖励和价值估计的计算。PPO训练过程涵盖优势和回报的计算,以及actor和critic的损失计算。最后,文章还提到了无监督训练的处理和一些相关参数的设置影响。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

  本篇为上中下三篇章的【下篇】,接续自【中篇】。主要针对整个DeepSpeed-Chat框架中最为复杂的第三阶段进行详解,其中涉及到部分InstructGPT所述相关原理的实践,基于其代码将更易于理解诸多原理中一笔带过的话题,如“用于经验采集的策略网络到底是SFT还是Actor”“Critic的迭代又是如何实现的”等等。
  尽管这是讨论DeepSpeed-Chat技术为主的文章,但还是不得不借用ColossalChat(另一个实现了RLHF Pipeline的开源项目,项目地址)绘制的流程示意图(下图),因为它所描绘的第三阶段训练流程,非常详细且与DeepSpeed-Chat的实现过程基本一致,而DeepSpeed-Chat本身给出的示意图实在太过简略(见【中篇】头图)。
ColossalChat Stage3
  相信结合这张示意图来学习phase3效果更佳。

3 phase-3: RLHF Finetuning

3.1 训练数据样例

3.1 基本数据

数据格式名称 说明 样例
prompt 对当前情境的描述,为模型生成提供指令输入信息,可以理解为通俗含义上的“问句”,适用于phase3 "Human: Please tell me about Microsoft in a few sentence? Assistant: "

3.2 经验数据

数据格式名称 说明 样例
prompt 对当前情境的描述,为模型生成提供指令输入信息,可以理解为通俗含义上的“问句”,适用于phase3。 "Human: Please tell me about Microsoft in a few sentence? Assistant: "(举文本例子是为了便于理解,实际上此处为input_ids)
seq actor基于prompt输入生成的完整对话序列。 "Human: Please tell me about Microsoft in a few sentence? Assistant: Microsoft is a world-renowned company."举文本例子是为了便于理解,实际上此处为input_ids)
logprobs actor基于seq输出的logits/策略对数。 shape: 本应为(seq_bs, max_seq_len, vocab_size),经过gather处理后仅取实际label token的log_logit值,为(seq_bs, max_seq_len, 1)。
ref_logprobs reference/SFT基于seq输出的logits/策略对数。 shape: 本应为(seq_bs, max_seq_len, vocab_size),经过gather处理后仅取实际label token的log_logit值,为(seq_bs, max_seq_len, 1)。
value critic基于seq输出的对序列每个位置的价值评估。 shape: (seq_bs, max_seq_len)
reward reward/RM基于seq输出的对整个对话的(环境)奖励。 shape: (seq_bs,)
attention_mask 用于滤掉非有效元素。 shape: (seq_bs, max_seq_len)

  各个框架对于经验数据的定义不完全相同,例如ColossalChat定义的经验数据还比此处多了项“adv”和“reward”(此reward非彼reward,ColossalChat的reward指的是“经过KL散度修正后的KL_Reward”),但本质上都是同理的,只是框定的范围不同,因为adv(优势函数Adventage)和KL_Reward完全可以由已有项logprobs、ref_logprobs、reward、value计算得到。

  从代码效率的角度来考量,ColossalChat的经验数据定义相对更严谨些,因为优势以及KL惩罚奖励完全可以由基本经验数据计算得到,在生成经验的阶段一步到位计算即可;而DeepSpeed-Chat中将其安排在训练阶段来计算,每次PPO迭代才计算,优势和KL惩罚奖励是基于基本经验数据计算得到的,而基本经验数据在生成经验阶段已经确定了,所以即使是在不同的PPO迭代中,优势和KL惩罚奖励也是不变的,因此DeepSpeed-Chat对adv以及KL惩罚奖励进行了重复计算,这个环节的计算顺序后续(编辑日期2023.05.19)相关团队应该会做出调整。

3.2 训练过程

在此简单讲述UML时序图的元素含义:
- 箭头表示信息传递:实线表示调用,虚线表示返回;
- alt表示假设分支,其后方“[]”中的内容表示“条件”;
- loop表示循环;
- 淡蓝色区域即为高亮部分。
main3.py utils.py data_utils.py rlhf_engine.py ppo_trainer.py load_hf_tokenizer() 1 tokenizer 2 create_dataset() 3 create_prompt_dataset() 4 prompt_train_dataset 5 get_unsupervised_data() 6 unsupervised_train_dataset 7 alt [unsupervised_training_enabled] DataCollatorRLHF() 8 data_collator 9 train_dataloader 10 DeepSpeedRLHFEngine() 11 rlhf_engine 12 ppo_trainer() 13 trainer 14 MiniDataset() 15 exp_mini_dataset, unsup_mini_dataset 16 unsup_mini_dataset.add() 17 unsup_dataset 18 trainer.generate_experience() 19 out 20 exp_mini_dataset.add() 21 exp_dataset 22 trainer.train_rlhf() 23 actor_loss, critic_loss 24 trainer.train_unsupervised() 25 unsup_loss 26 moving_average() 27 alt [enable_ema] alt [unsupervised_training_enabled] loop [ppo_step]
评论 25
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值