关于DPO 的介绍,我们经常能看的说法是: DPO 发现了一种数学变换,可以直接利用人类偏好数据来优化策略,而不需要显式地训练一个独立的奖励模型。它把奖励模型隐含地表达在了策略优化目标中。DPO 更简单、更稳定,效果常常媲美甚至超过 PPO+RM。只要准备好高质量三元组数据和基础模型,利用llama-factory这个开源工具就可以轻松启动DPO训练!
但是关键信息“把奖励模型隐含地表达在了策略优化目标中”却很少有简单易懂的博文能讲明白。这里我们假设一个业务场景NL2SQL,我们的目标是希望通过微调大模型,让大模型能够理解业务领域中的库表结构和元数据,实现用户通过自然语言询问,模型生成高质量可执行的SQL语句。(NL2SQL这个业务场景挺常见的,虽然也可以通过RAG和Agent的方式实现,尤其是当业务问题比较简单,仅从几张大宽表就能解决的时候。但强化学习微调大模型确实是个提升准确率的有效手段)
1. DPO的训练过程
(1)数据集示例:
{
"prompt": "统计2023年各部门销售额",
"chosen": "SELECT d.dept_name, SUM(s.amount) FROM sales s JOIN departments d ON s.dept_id=d.id WHERE s.year=2023 GROUP BY d.dept_name",
"rejected": "SELECT dept_name, SUM(amount) FROM sales WHERE year=2023" // 缺少JOIN导致结果错误
}
(2)2个关键模型
-
待训练的策略模型 (
π_θ
): 这就是我们想要训练好的最终模型。它接收用户prompt
(自然语言问题),输出SQL语句。 -
参考模型 (
π_ref
): 通常是一个微调过的基础模型(例如,用指令微调或SFT微调过的模型,使其初步理解SQL任务)。它代表了训练前的“基准行为”。在DPO中,它固定不变,用于提供“锚点”,防止模型偏离太远或走捷径。这里需要强调一下,这个模型兼具: 掌握基础SQL语法能力。理解业务领域特定的库表结构、元数据关系(最重要!)。初步建立自然语言问题到SQL结构的映射。
(3)前向计算与反向迭代
假设我们现在只有那一条训练样本。
-
前向传播 (Forward Pass):
-
将
prompt
(“统计2023年各部门销售额”) 同时输入给待训练模型 (π_θ
) 和 参考模型 (π_ref
)。 -
模型的任务不是生成完整的SQL,而是计算给定输出序列(
chosen
SQL 和rejected
SQL)的概率(或更准确地说,对数似然)。 -
计算:
-
π_θ(chosen | prom
-
-