目录
前言
本篇为上中下三篇章的【下篇】,接续自【中篇】。主要针对整个DeepSpeed-Chat框架中最为复杂的第三阶段进行详解,其中涉及到部分InstructGPT所述相关原理的实践,基于其代码将更易于理解诸多原理中一笔带过的话题,如“用于经验采集的策略网络到底是SFT还是Actor”“Critic的迭代又是如何实现的”等等。
尽管这是讨论DeepSpeed-Chat技术为主的文章,但还是不得不借用ColossalChat(另一个实现了RLHF Pipeline的开源项目,项目地址)绘制的流程示意图(下图),因为它所描绘的第三阶段训练流程,非常详细且与DeepSpeed-Chat的实现过程基本一致,而DeepSpeed-Chat本身给出的示意图实在太过简略(见【中篇】头图)。
相信结合这张示意图来学习phase3效果更佳。
3 phase-3: RLHF Finetuning
3.1 训练数据样例
3.1 基本数据
数据格式名称 | 说明 | 样例 |
---|---|---|
prompt | 对当前情境的描述,为模型生成提供指令输入信息,可以理解为通俗含义上的“问句”,适用于phase3。 | "Human: Please tell me about Microsoft in a few sentence? Assistant: " |
3.2 经验数据
数据格式名称 | 说明 | 样例 |
---|---|---|
prompt | 对当前情境的描述,为模型生成提供指令输入信息,可以理解为通俗含义上的“问句”,适用于phase3。 | "Human: Please tell me about Microsoft in a few sentence? Assistant: "(举文本例子是为了便于理解,实际上此处为input_ids) |
seq | actor基于prompt输入生成的完整对话序列。 | "Human: Please tell me about Microsoft in a few sentence? Assistant: Microsoft is a world-renowned company."举文本例子是为了便于理解,实际上此处为input_ids) |
logprobs | actor基于seq输出的logits/策略对数。 | shape: 本应为(seq_bs, max_seq_len, vocab_size),经过gather处理后仅取实际label token的log_logit值,为(seq_bs, max_seq_len, 1)。 |
ref_logprobs | reference/SFT基于seq输出的logits/策略对数。 | shape: 本应为(seq_bs, max_seq_len, vocab_size),经过gather处理后仅取实际label token的log_logit值,为(seq_bs, max_seq_len, 1)。 |
value | critic基于seq输出的对序列每个位置的价值评估。 | shape: (seq_bs, max_seq_len) |
reward | reward/RM基于seq输出的对整个对话的(环境)奖励。 | shape: (seq_bs,) |
attention_mask | 用于滤掉非有效元素。 | shape: (seq_bs, max_seq_len) |
各个框架对于经验数据的定义不完全相同,例如ColossalChat定义的经验数据还比此处多了项“adv”和“reward”(此reward非彼reward,ColossalChat的reward指的是“经过KL散度修正后的KL_Reward”),但本质上都是同理的,只是框定的范围不同,因为adv(优势函数Adventage)和KL_Reward完全可以由已有项logprobs、ref_logprobs、reward、value计算得到。
从代码效率的角度来考量,ColossalChat的经验数据定义相对更严谨些,因为优势以及KL惩罚奖励完全可以由基本经验数据计算得到,在生成经验的阶段一步到位计算即可;而DeepSpeed-Chat中将其安排在训练阶段来计算,每次PPO迭代才计算,优势和KL惩罚奖励是基于基本经验数据计算得到的,而基本经验数据在生成经验阶段已经确定了,所以即使是在不同的PPO迭代中,优势和KL惩罚奖励也是不变的,因此DeepSpeed-Chat对adv以及KL惩罚奖励进行了重复计算,这个环节的计算顺序后续(编辑日期2023.05.19)相关团队应该会做出调整。
3.2 训练过程
在此简单讲述UML时序图的元素含义:
- 箭头表示信息传递:实线表示调用,虚线表示返回;
- alt表示假设分支,其后方“[]”中的内容表示“条件”;
- loop表示循环;
- 淡蓝色区域即为高亮部分。