Transformer——Q154 RLHF中PPO目标函数的梯度推导

该问题归类到Transformer架构问题集——前沿扩展。请参考LLM数学推导——Transformer架构问题集

1. 问题背景

在大语言模型(LLM)领域,随着模型参数量不断增加,模型生成能力显著提升,但也面临一个棘手问题:模型生成内容可能不符合人类价值观、缺乏事实准确性或在对话场景中表现不佳 。例如,直接训练的 LLM 在被询问 “如何制作炸弹” 时,可能会给出详细步骤;在多轮对话中,可能出现前后逻辑矛盾的情况。为解决这些问题,基于人类反馈的强化学习(Reinforcement Learning from Human Feedback,RLHF)应运而生。

RLHF 通过引入人类反馈来指导模型优化,让模型学会生成符合人类期望的内容。近端策略优化(Proximal Policy Optimization,PPO)算法因其高效性和稳定性,成为 RLHF 中常用的优化方法。PPO 目标函数的设计与优化直接影响到 RLHF 的效果,而目标函数的梯度推导则是理解其优化过程的关键,它能帮助我们明确模型参数如何更新才能更好地符合人类反馈,进而提升模型性能。

2. 技术原理或数学理论解析

2.1 强化学习基础概念

在深入探讨 PPO 目标函数之前,先回顾强化学习的基础概念。强化学习中,智能体(Agent)在环境(Environment)中不断执行动作(Action),并根据环境反馈的奖励(Reward)来调整自身策略(Policy),以最大化长期累计奖励。

  • 策略:用 \pi(a|s) 表示,即给定状态 s 时采取动作 a 的概率。
  • 价值函数:分为状态价值函数 V^{\pi}(s) 和动作价值函数 Q^{\pi}(s, a)V^{\pi}(s) 表示从状态 s 开始,遵循策略 \pi 所能获得的期望累计奖励; Q^{\pi}(s, a) 表示在状态 \(s\) 执行动作 a 后,遵循策略 \pi 所能获得的期望累计奖励。
  • 奖励:环境给予智能体的反馈信号,用于指导策略优化。

2.2 PPO 算法概述

PPO 算法是一种基于策略梯度的强化学习算法,旨在优化策略以最大化期望累计奖励。它通过限制新旧策略之间的差异,避免策略更新过于激进,从而提高算法的稳定性和收敛性。

2.3 PPO 目标函数推导

2.3.1 策略梯度公式

策略梯度算法的核心思想是通过计算目标函数关于策略参数 \theta 的梯度,来更新策略参数,使得目标函数值增大。目标函数通常定义为期望累计奖励,即:

J(\theta)=\mathbb{E}_{\tau\sim\pi_{\theta}}[R(\tau)]

其中, \tau=(s_0, a_0, r_0, s_1, a_1, r_1, \cdots) 表示一个轨迹, \pi_{\theta} 是参数为 \theta 的策略, R(\tau) 是轨迹 \tau 的累计奖励。

根据策略梯度定理,目标函数 J(\theta) 关于 \theta 的梯度为:

\nabla_{\theta}J(\theta)=\mathbb{E}_{\tau\sim\pi_{\theta}}\left[\sum_{t = 0}^{T}\nabla_{\theta}\log\pi_{\theta}(a_t|s_t)A^{\pi_{\theta}}(s_t, a_t)\right]

其中, A^{\pi_{\theta}}(s_t, a_t) 是优势函数(Advantage Function),表示在状态 s_t 执行动作 a_t 相对于平均价值的优势程度,可通过 A^{\pi_{\theta}}(s_t, a_t)=Q^{\pi_{\theta}}(s_t, a_t)-V^{\pi_{\theta}}(s_t) 计算 ,也可使用广义优势估计(Generalized Advantage Estimation,GAE)进行近似。

2.3.2 PPO 目标函数设计

PPO 算法通过引入重要性采样(Importance Sampling),使用旧策略 \pi_{\theta_{old}} 收集数据来更新新策略 \pi_{\theta} ,以提高数据利用效率。此时,目标函数可改写为:

J^{IS}(\theta)=\mathbb{E}_{s_t, a_t\sim\pi_{\theta_{old}}}\left[\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}A^{\pi_{\theta_{old}}}(s_t, a_t)\right]

然而,直接使用上述公式可能导致新旧策略差异过大,使算法不稳定。为解决这个问题,PPO 引入截断(Clipping)机制,定义了 PPO 目标函数:

J^{PPO}(\theta)=\mathbb{E}_{s_t, a_t\sim\pi_{\theta_{old}}}\left[\min\left(\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}A^{\pi_{\theta_{old}}}(s_t, a_t), \text{clip}\left(\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, 1 - \epsilon, 1 + \epsilon\right)A^{\pi_{\theta_{old}}}(s_t, a_t)\right)\right]

其中, \epsilon 是截断参数,用于限制新旧策略概率比值的范围。当 \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} 超过 1 + \epsilon 或低于 1 - \epsilon 时,目标函数的值将被截断,避免策略更新过于剧烈。

2.3.3 PPO 目标函数梯度推导

为了更新策略参数 \theta ,需要计算 J^{PPO}(\theta) 关于 \theta 的梯度。根据期望的性质和链式法则:

\nabla_{\theta}J^{PPO}(\theta)=\mathbb{E}_{s_t, a_t\sim\pi_{\theta_{old}}}\left[\nabla_{\theta}\min\left(\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}A^{\pi_{\theta_{old}}}(s_t, a_t), \text{clip}\left(\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, 1 - \epsilon, 1 + \epsilon\right)A^{\pi_{\theta_{old}}}(s_t, a_t)\right)\right]

r_t(\theta)=\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} ,则:

\nabla_{\theta}J^{PPO}(\theta)=\mathbb{E}_{s_t, a_t\sim\pi_{\theta_{old}}}\left[\begin{cases} \nabla_{\theta}(r_t(\theta)A^{\pi_{\theta_{old}}}(s_t, a_t)), & \text{if } 1 - \epsilon\leq r_t(\theta)\leq 1 + \epsilon \\ \nabla_{\theta}((1 - \epsilon)A^{\pi_{\theta_{old}}}(s_t, a_t)), & \text{if } r_t(\theta)< 1 - \epsilon \\ \nabla_{\theta}((1 + \epsilon)A^{\pi_{\theta_{old}}}(s_t, a_t)), & \text{if } r_t(\theta)> 1 + \epsilon \end{cases}\right]

进一步计算:

1 - \epsilon\leq r_t(\theta)\leq 1 + \epsilon 时,根据乘积求导法则 (uv)^\prime = u^\prime v + uv^\prime ,其中 u = r_t(\theta)v = A^{\pi_{\theta_{old}}}(s_t, a_t)A^{\pi_{\theta_{old}}}(s_t, a_t)\theta 无关,其关于 \theta 的导数为 0 ):

\nabla_{\theta}(r_t(\theta)A^{\pi_{\theta_{old}}}(s_t, a_t))=A^{\pi_{\theta_{old}}}(s_t, a_t)\nabla_{\theta}r_t(\theta)

r_t(\theta)=\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} ,对其求导:

\nabla_{\theta}r_t(\theta)=\frac{\nabla_{\theta}\pi_{\theta}(a_t|s_t)\pi_{\theta_{old}}(a_t|s_t)-\pi_{\theta}(a_t|s_t)\nabla_{\theta}\pi_{\theta_{old}}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)^2}

由于 \pi_{\theta_{old}}(a_t|s_t) 是旧策略,其关于 \theta 的导数为 0 ,所以:

\nabla_{\theta}r_t(\theta)=\frac{\nabla_{\theta}\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}

\nabla_{\theta}(r_t(\theta)A^{\pi_{\theta_{old}}}(s_t, a_t))=\frac{A^{\pi_{\theta_{old}}}(s_t, a_t)\nabla_{\theta}\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}

r_t(\theta)< 1 - \epsilon 时, \nabla_{\theta}((1 - \epsilon)A^{\pi_{\theta_{old}}}(s_t, a_t)) = 0 ,因为 (1 - \epsilon)A^{\pi_{\theta_{old}}}(s_t, a_t)\theta 无关。

r_t(\theta)> 1 + \epsilon 时, \nabla_{\theta}((1 + \epsilon)A^{\pi_{\theta_{old}}}(s_t, a_t)) = 0 ,同理, (1 + \epsilon)A^{\pi_{\theta_{old}}}(s_t, a_t)\theta 无关。

综上, \nabla_{\theta}J^{PPO}(\theta)=\mathbb{E}_{s_t, a_t\sim\pi_{\theta_{old}}}\left[\begin{cases} \frac{A^{\pi_{\theta_{old}}}(s_t, a_t)\nabla_{\theta}\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, & \text{if } 1 - \epsilon\leq r_t(\theta)\leq 1 + \epsilon \\ 0, & \text{otherwise} \end{cases}\right]

2.4 根因分析

PPO 目标函数及其梯度推导的核心目的是在保证策略更新稳定性的前提下,高效优化策略。传统策略梯度算法直接使用新策略采样数据更新策略,数据利用率低,且可能因策略更新幅度过大导致算法发散。PPO 通过重要性采样复用旧策略数据,降低数据收集成本;引入截断机制,限制新旧策略差异,避免策略更新过于激进。而梯度推导则明确了参数更新方向,使得策略能够沿着最大化目标函数的方向逐步优化,从而在 RLHF 中有效利用人类反馈信号,引导模型生成更符合人类期望的内容。

3. 在 LLM 中的使用示例

3.1 对话内容优化

在智能客服场景中,LLM 最初可能给出机械、缺乏情感的回答,如用户询问 “我订单怎么还没到”,模型回答 “请等待”。通过 RLHF 结合 PPO 优化,收集人类标注员对不同回答的评分作为奖励信号,如更详细且友好的回答 “您好,由于近期物流高峰,您的订单可能会稍有延迟,我们会为您持续跟进” 获得高分。PPO 目标函数根据这些奖励信号,通过梯度推导更新模型参数,使模型学会生成更优质的对话内容,提升用户满意度。

3.2 内容真实性改进

LLM 在回答知识类问题时可能出现错误,例如被问到 “地球绕太阳公转一周需要多久”,错误回答 “300 天”。利用 RLHF - PPO,将正确答案作为奖励引导,当模型回答正确时给予高奖励,错误时给予低奖励。PPO 目标函数的梯度推导驱动模型参数调整,让模型逐渐学习到准确的知识,提高内容真实性。

3.3 价值观对齐

在生成文本时,模型可能输出不符合主流价值观的内容,如宣扬暴力。通过 RLHF - PPO,以符合价值观的内容为奖励导向,如生成倡导和平、友善的内容给予高奖励。PPO 通过目标函数梯度推导优化模型,使模型生成内容与人类价值观保持一致。

4. 优缺点分析

4.1 优点

  • 高效性:PPO 通过重要性采样复用旧策略数据,提高数据利用效率,减少训练所需数据量和时间。
  • 稳定性:截断机制有效限制新旧策略差异,避免策略更新过于剧烈,使算法更稳定,收敛性更好。
  • 灵活性:适用于多种强化学习场景,在 RLHF 中能灵活利用不同形式的人类反馈优化 LLM。

4.2 缺点

  • 超参数敏感:PPO 算法的性能对截断参数 \epsilon 等超参数较为敏感,需要仔细调整才能达到最佳效果。
  • 难以处理稀疏奖励:当奖励信号稀疏时,如在复杂任务中,模型可能难以学习到有效的策略,因为难以准确估计优势函数。
  • 计算复杂度较高:在计算目标函数梯度时,涉及期望计算和复杂的概率比值计算,尤其是在大规模 LLM 中,计算复杂度较高,对硬件要求高。

5. 优化策略分析

5.1 超参数优化

采用自动化超参数调整方法,如贝叶斯优化、随机搜索等,通过在验证集上评估模型性能,自动搜索最优的截断参数 \epsilon 及其他超参数,减少人工调参成本和时间。

5.2 奖励工程

针对稀疏奖励问题,设计更合理的奖励函数。例如,在长文本生成任务中,除了最终内容质量评分,还可根据生成过程中的中间状态给予奖励,如句子语法正确性、段落连贯性等,使奖励信号更密集,帮助模型更快学习。

5.3 分布式训练

由于 PPO 在大规模 LLM 上计算复杂度高,采用分布式训练技术,将计算任务分配到多个计算节点上并行计算,加快训练速度,降低对单个硬件设备的性能要求。

6. 代码示例(Python,基于 PyTorch)

import torch

import torch.nn as nn

import torch.optim as optim

from collections import deque

import numpy as np

# 定义策略网络

class PolicyNetwork(nn.Module):

def __init__(self, state_dim, action_dim):

super(PolicyNetwork, self).__init__()

self.fc1 = nn.Linear(state_dim, 128)

self.fc2 = nn.Linear(128, 128)

self.fc_mean = nn.Linear(128, action_dim)

self.fc_std = nn.Linear(128, action_dim)

def forward(self, x):

x = torch.relu(self.fc1(x))

x = torch.relu(self.fc2(x))

mean = self.fc_mean(x)

std = torch.clamp(self.fc_std(x), min=-20, max=2)

std = torch.exp(std)

dist = torch.distributions.Normal(mean, std)

action = dist.rsample()

log_prob = dist.log_prob(action).sum(-1, keepdim=True)

return action, log_prob

# 定义价值网络

class ValueNetwork(nn.Module):

def __init__(self, state_dim):

super(ValueNetwork, self).__init__()

self.fc1 = nn.Linear(state_dim, 128)

self.fc2 = nn.Linear(128, 128)

self.fc_value = nn.Linear(128, 1)

def forward(self, x):

x = torch.relu(self.fc1(x))

x = torch.relu(self.fc2(x))

value = self.fc_value(x)

return value

# PPO算法类

class PPO:

def __init__(self, state_dim, action_dim, lr_actor=3e-4, lr_critic=1e-3, gamma=0.99,

K_epochs=10, eps_clip=0.2):

self.policy = PolicyNetwork(state_dim, action_dim)

self.value = ValueNetwork(state_dim)

self.optimizer_policy = optim.Adam(self.policy.parameters(), lr=lr_actor)

self.optimizer_value = optim.Adam(self.value.parameters(), lr=lr_critic)

self.gamma = gamma

self.K_epochs = K_epochs

self.eps_clip = eps_clip

self.buffer_states = []

self.buffer_actions = []

self.buffer_log_probs = []

self.buffer_rewards = []

self.buffer_dones = []

def select_action(self, state):

state = torch.FloatTensor(state).unsqueeze(0)

action, log_prob = self.policy(state)

return action.detach().numpy()[0], log_prob.detach()

def store_transition(self, state, action, log_prob, reward, done):

self.buffer_states.append(state)

self.buffer_actions.append(action)

self.buffer_log_probs.append(log_prob)

self.buffer_rewards.append(reward)

self.buffer_dones.append(done)

def update(self):

gamma = self.gamma

K_epochs = self.K_epochs

eps_clip = self.eps_clip

buffer_states = torch.FloatTensor(np.array(self.buffer_states))

buffer_actions = torch.FloatTensor(np.array(self.buffer_actions))

buffer_log_probs = torch.FloatTensor(np.array(self.buffer_log_probs))

buffer_rewards = np.array(self.buffer_rewards)

buffer_dones = np.array(self.buffer_dones)

returns = []

discounted_return = 0

for r, d in zip(reversed(buffer_rewards), reversed(buffer_dones)):

if d:

discounted_return = 0

discounted_return = r + (gamma * discounted_return)

returns.insert(0, discounted_return)

returns = torch.FloatTensor(returns)

old_states = buffer_states

old_actions = buffer_actions

old_log_probs = buffer_log_probs

for _ in range(K_epochs):

values = self.value(old_states)

advantage = returns - values.detach()

new_action, new_log_probs = self.policy(old_states)

log_ratio = new_log_probs - old_log_probs

ratio = torch.exp(log_ratio)

surr1 = ratio * advantage

surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage

loss = -torch.min(surr1, surr2).mean()

self.optimizer_policy.zero_grad()

loss.backward()

self.optimizer_policy.step()

value_loss = nn.MSELoss()(values, returns)

self.optimizer_value.zero_grad()

value_loss.backward()

self.optimizer_value.step()

self.buffer_states = []

self.buffer_actions = []

self.buffer_log_probs = []

self.buffer_rewards = []

self.buffer_dones = []

# 示例环境交互

class ExampleEnv:

def __init__(self):

self.state_dim = 4

self.action_dim = 2

self.state = np.random.rand(self.state_dim)

def step(self, action):

reward = np.random.rand()

done = np.random.choice([True, False], p=[0.1, 0.9])

self.state = np.random.rand(self.state_dim)

return self.state, reward, done

def reset(self):

self.state = np.random.rand(self.state_dim)

return self.state

if __name__ == "__main__":

env = ExampleEnv()

state_dim = env.state_dim

action_dim = env.action_dim

ppo = PPO(state_dim, action_dim)

num_episodes = 100

for episode in range(num_episodes):

state = env.reset()

episode_reward = 0

while True:

action, log_prob = ppo.select_action(state)

next_state, reward, done = env.step(action)

ppo.store_transition(state, action, log_prob, reward, done)

state = next_state

episode_reward += reward

if done:

break

ppo.update()

print(f"Episode {episode + 1}, Reward: {episode_reward}")

7. 代码解读

  1. 网络定义部分
    • ​​​​​​​策略网络(PolicyNetwork:接受状态维度state_dim和动作维度action_dim作为输入,通过两个全连接层fc1和fc2对状态进行特征提取,接着分别由fc_mean和fc_std输出动作分布的均值和标准差。在forward方法中,对网络输出的标准差进行裁剪和指数变换,确保其为正数,然后基于均值和标准差构建正态分布dist,从中重新采样得到动作action,并计算该动作的对数概率log_prob返回。这部分代码实现了根据当前状态生成动作及其概率的功能,是 PPO 算法中策略生成的核心。
    • 价值网络(ValueNetwork:仅接受状态维度state_dim,通过两个全连接层处理后,由fc_value输出状态价值估计value。其作用是评估给定状态的价值,为计算优势函数提供基础。
  2. PPO 算法类(PPO
    • 初始化方法(__init__:实例化策略网络policy和价值网络value,并分别为它们创建优化器optimizer_policy和optimizer_value。同时,设置折扣因子gamma、更新轮数K_epochs、截断参数eps_clip等超参数,以及用于存储轨迹数据的缓冲区。
    • 选择动作方法(select_action:将输入状态转换为张量并增加维度,传入策略网络得到动作和对数概率,再将动作转换为 numpy 数组返回,方便与环境进行交互。
    • 存储过渡方法(store_transition:将每次环境交互的状态、动作、对数概率、奖励和是否结束的信息存储到相应的缓冲区中,为后续的策略更新提供数据。
    • 更新方法(update:首先将缓冲区中的数据转换为张量形式,计算累计回报returns。接着进入K_epochs轮更新,通过价值网络计算状态价值values,进而得到优势函数advantage。然后根据旧状态生成新动作及其对数概率,计算概率比值ratio。依据 PPO 目标函数,计算策略损失loss并更新策略网络参数;计算价值损失value_loss并更新价值网络参数。最后清空缓冲区,准备下一轮数据收集。
  3. 示例环境与主程序
    • 示例环境(ExampleEnv:定义了一个简单的模拟环境,包含状态维度state_dim、动作维度action_dim,以及step方法用于执行动作并返回下一个状态、奖励和是否结束,reset方法用于重置环境状态。
    • 主程序:实例化示例环境和 PPO 算法对象,进行num_episodes轮训练。每轮训练中,先重置环境,然后在环境中不断执行动作、存储过渡信息,直到该轮结束,最后调用update方法更新 PPO 算法的网络参数,并打印每轮的奖励。

8. 总结

RLHF 中的 PPO 目标函数及其梯度推导是优化大语言模型,使其生成内容贴合人类期望的关键。通过引入重要性采样和截断机制,PPO 算法在保证策略更新稳定性的同时,提升了数据利用效率。梯度推导明确了策略参数的更新方向,让模型能依据人类反馈信号进行有效优化 。在 LLM 的实际应用里,PPO 在对话优化、内容真实性改进、价值观对齐等任务中成绩斐然。不过,它也存在超参数敏感、处理稀疏奖励困难、计算复杂等不足。针对这些问题,超参数优化、奖励工程、分布式训练等策略可有效提升其性能。通过代码示例与解读,我们更直观地看到 PPO 算法的运行逻辑,未来,随着研究深入,PPO 有望在 RLHF 中发挥更大价值,推动 LLM 迈向新高度。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

墨顿

唵嘛呢叭咪吽

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值