1.简介
GRPO(Generalized Reward Proximal Optimization)是一种用于强化学习的策略优化算法,它在近端策略优化(PPO)的基础上进行了扩展和改进。GRPO通过引入更灵活的奖励建模方式和对策略更新过程的精细化控制,旨在提升训练的稳定性与样本效率。与PPO类似,GRPO也限制策略更新的步长以避免过大的策略变动,但它可能采用更广义的奖励函数形式或结合多目标、多步回报等机制,从而更好地适应复杂任务环境。该方法通常用于需要高效探索和稳定学习的场景,在一些大模型对齐和智能体训练任务中展现出良好性能。
-
-
2.代码
准备工作
我们以Qwen2.5-0.5B-Instruct为例,以gsm8k数据集为例进行演示。
https://hf-mirror.com/datasets/openai/gsm8k/tree/main
-
完整代码
import numpy as np
import random
import torch
import torch.nn.functional as F
import copy
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
def set_random_seed(seed: int = 42):
"""
设置随机种子以确保实验可复现性
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_random_seed(42)
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
def prepare_dataset(split="train"):
"""加载并准备GSM8K数据集,将其转换为字符串形式的提示和答案。"""
data = load_dataset('gsm8k', 'main')[split]
formatted_data = []
for example in data:
# 将系统提示和用户问题合并成一个字符串作为提示
prompt_str = build_prompt([
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["question"]}
])
formatted_example = {
"prompt": prompt_str, # Now a string rather than a list.
"answer": extract_answer_from_dataset(example["answer"])
}
formatted_data.append(formatted_example)
return formatted_data
def build_prompt(messages):
"""
将消息列表转换为单个提示字符串。
它遍历每个消息字典,提取"content"字段内容并去除首尾空白,最后用换行符连接所有内容,形成完整的提示文本。
"""
return "\n".join([msg["content"].strip() for msg in messages])
def extract_answer_from_model_output(text):
"""
从文本中提取最后一个<answer>标签内的内容。
"""
# 按<answer>分割文本,取最后一部分
parts = text.split("<answer>")
if len(parts) < 2: # No <answer> tag found
return None
last_part = parts[-1]
# Extract content up to </answer>
if "</answer>" not in last_part:
return None
answer = last_part.split("</answer>")[0].strip()
return None if answer == "..." else answer
def extract_answer_from_dataset(text):
"""
从文本中提取答案。
"""
if "####" not in text:
return None
return text.split("####")[1].strip()
def _extract_last_number(text):
"""
从字符串中提取最后一个独立的数字。
"""
import re
# 移除 $ 和 % 符号
text = text.replace('$', '').replace('%', '')
# 通过正则表达式匹配满足条件的数字
pattern = r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$'
match = re.search(pattern, text)
return float(match.group(1)) if match else None
def _extract_single_number(text):
"""
从文本中提取单个数字。
"""
import re
numbers = re.findall(r'-?\d*\.?\d+', text)
return float(numbers[0]) if len(numbers) == 1 else None
def evaluate_model(model, tokenizer, eval_examples, device):
"""
评估语言模型的性能
将模型设为评估模式;
遍历每个测试样本,生成模型回复并提取答案;
通过字符串匹配、单数字匹配和最后数字匹配三种方式判断预测是否正确;
打印每条样本的详细结果;
计算并返回整体准确率;
最后将模型恢复为训练模式
"""
model.eval()
correct = 0
total = len(eval_examples)
print("\n" + "="*50)
print("EVALUATION ON", total, "EXAMPLES")
print("="*50)
for example in eval_examples:
full_prompt = example["prompt"]
expected = example["answer"]
inputs = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
outputs = model.generate(
inputs,
max_new_tokens=512,
temperature=0.7,
num_return_sequences=1,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
forced_eos_token_id=tokenizer.eos_token_id,
early_stopping=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract the predicted answer from the model output.
try:
predicted = extract_answer_from_model_output(response)
# Check correctness in multiple ways
if predicted == expected: # First try exact match
is_correct = True
else:
# Try single number
pred_num = _extract_single_number(str(predicted))
exp_num = _extract_single_number(str(expected))
if pred_num is not None and exp_num is not None and pred_num == exp_num:
is_correct = True
else:
# Try last number
pred_num = _extract_last_number(str(predicted))
exp_num = _extract_last_number(str(expected))
is_correct = (pred_num is not None and exp_num is not None and
pred_num == exp_num)
if is_correct:
correct += 1
# Print details of the evaluation.
print("\nPrompt:")
print(full_prompt)
print("\nExpected Answer:")
print(expected)
print("\nExtracted Answer:")
print(predicted)
print("\nFull Generated Response:")
print(response)
print("\nCorrect:", "✓" if is_correct else "✗")
print("-"*50)
except Exception as e:
print("\nFailed to parse model output for prompt:")
print(full_prompt)
print("Error:", e)
print("-"*50)
accuracy = (correct / total) * 100
print(f"\nAccuracy: {accuracy:.2f}% ({correct}/{total})")
print("="*50)
model.train()
return accuracy
def correctness_reward(prompts, completions, answer, **kwargs):
"""
根据模型答案的正确性分配奖励分数
首先从模型输出中提取答案,再与标准答案进行字符串完全匹配(得2分)或数值相等判断(得1.5分),不匹配则得0分,最终返回每个答案的得分列表。
"""
# Extract the content from each completion's first element
responses = [completion[0]['content'] for completion in completions]
# Extract answers from model outputs
extracted = [extract_answer_from_model_output(r) for r in responses]
rewards = []
for r, a in zip(extracted, answer):
if r == a: # Exact match case
rewards.append(2.0)
else:
# Try numeric equivalence
r_num = _extract_single_number(str(r))
a_num = _extract_single_number(str(a))
if r_num is not None and a_num is not None and r_num == a_num:
rewards.append(1.5)
else:
rewards.append(0.0)
# Log completion lengths
completion_lengths = [len(response.split()) for response in responses]
return rewards
def format_reward(completions, **kwargs):
"""
评估模型输出是否符合指定的XML格式。
它从每个完成项中提取文本内容,并检查其中是否包含特定的XML标签(如 <reasoning>、</reasoning>、<answer> 和 </answer>)。
每出现一个标签,就给该完成项加0.2分,最终返回所有完成项的格式合规得分列表。
"""
# Extract the content from each completion's first element
responses = [completion[0]['content'] for completion in completions]
rewards = []
format_scores = []
for response in responses:
score = 0.0
if "<reasoning>" in response: score += 0.2
if "</reasoning>" in response: score += 0.2
if "<answer>" in response: score += 0.2
if "</answer>" in response: score += 0.2
rewards.append(score)
format_scores.append(score)
return rewards
def combined_reward(prompts, completions, answer):
"""
该函数用于综合评估模型输出的正确性和格式规范性。
它分别调用 correctness_reward 和 format_reward 计算每条输出的得分,然后将两者相加,得到最终的综合评分(范围为0.0到2.8)。
"""
# Get individual rewards
correctness_scores = correctness_reward(prompts=prompts, completions=completions, answer=answer)
format_scores = format_reward(completions=completions)
# Combine rewards - correctness is weighted more heavily
combined_rewards = []
for c_score, f_score in zip(correctness_scores, format_scores):
# Correctness score range: 0.0 to 2.0
# Format score range: 0.0 to 0.8
# Total range: 0.0 to 2.8
combined_rewards.append(c_score + f_score)
return combined_rewards # [b]=32
def selective_log_softmax(logits, input_ids):
"""
使用选择性对数软最大化计算input_ids中指定令牌的对数概率。
参数:
logits (torch.Tensor):形状为 (batch_size, seq_len, vocab_size) 的张量,包含模型生成的原始逻辑值。
input_ids (torch.Tensor):形状为 (batch_size, seq_len) 的张量,包含需要计算对数概率的标记索引。
返回值:
torch.Tensor:形状为 (batch_size, seq_len) 的张量,每个元素对应 input_ids 中该位置令牌的对数概率。
说明:
1. 沿词汇表维度 (dim=-1) 应用 F.log_softmax 将 logits 转换为对数概率。
2. 对张量 input_ids 进行重塑(通过 unsqueeze 操作)添加额外维度,以便将其作为 log_probs 张量的索引使用。
3. torch.gather 函数收集每个位置在 input_ids 指定索引处的对数概率值。
4. 最后,squeeze(-1) 移除额外维度,返回与 input_ids 形状相同的张量。
"""
# 将原始logits沿词汇轴转换为对数概率。
log_probs = F.log_softmax(logits, dim=-1) # Shape: (batch_size, seq_len, vocab_size)
# 将input_ids从(batch_size, seq_len)重塑为(batch_size, seq_len, 1).
# 使用 torch.gather 根据 input_ids 中的索引,从 log_probs 中提取每个位置对应的对数概率值。
selected_log_probs = log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1))
# 移除多余的最后一个维度,恢复为原始形状(batch_size, seq_len)。
return selected_log_probs.squeeze(-1)
def compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep):
"""
计算子集令牌(通常为补全令牌)的每个令牌对数概率。
参数:
model:要使用的语言模型。
input_ids (torch.Tensor):形状为 (batch_size, total_seq_len) 的张量,包含提示和补全部分的令牌 ID。
attention_mask (torch.Tensor):形状为 (batch_size, total_seq_len) 的张量,标记真实令牌 (1) 与填充令牌 (0)。
logits_to_keep (int):需计算对数概率的令牌数量(来自补全部分)。
返回值:
torch.Tensor:每个序列最后 `logits_to_keep` 个令牌的对数概率。
说明:
1. 调用模型时传入 logits_to_keep + 1,使模型输出比实际需求多一个 logit。
此为下一令牌预测中的常见做法。
2. 沿序列维度切除最后一个logit,因其不对应任何输入令牌。
3. 随后将input_ids和logits均限制为最后logits_to_keep个令牌,该范围应
对应生成的补全部分。
4. 最后通过selective_log_softmax仅对这些令牌计算对数概率。
"""
# 正向传播并获取logits值。
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
logits_to_keep=logits_to_keep + 1 # Request one extra logit for proper alignment.
).logits # Shape: (batch_size, total_seq_len, vocab_size)=[32,501,151936]
# 删除最后一个logit,因为它没有对应的目标令牌。
logits = logits[:, :-1, :] # New shape: (batch_size, total_seq_len - 1, vocab_size)=[32,500,151936]
# 将 input_ids 切片,仅保留最后 logits_to_keep 个令牌。这对应于生成的补全令牌。
input_ids = input_ids[:, -logits_to_keep:] # Shape: (batch_size, logits_to_keep)=[32,500]
# 同时对logits进行分段,仅保留与完成标记对应的部分。
logits = logits[:, -logits_to_keep:, :] # Shape: (batch_size, logits_to_keep, vocab_size)
# 计算并返回所选标记的对数似然值。
return selective_log_softmax(logits, input_ids)
def create_completion_mask(completion_ids, eos_token_id):
"""
为生成的补全令牌创建二进制掩码,使首个EOS之后的令牌被忽略。
参数:
completion_ids (torch.Tensor):形状为(batch_size, seq_len)的张量,包含生成的令牌ID。
eos_token_id (int):表示序列结束的令牌ID。
返回值:
torch.Tensor:形状为 (batch_size, seq_len) 的掩码张量,包含:
- 首个EOS标记及之前的所有标记对应1
- 首个EOS标记之后的所有标记对应0
说明:
1. 首先创建布尔掩码(is_eos),标记序列中EOS令牌出现的位置。
2. 初始化索引张量(eos_idx),默认假设未发现EOS(初始值为序列长度)。
3. 对于存在EOS的序列,将eos_idx更新为首个EOS的位置(索引)。
4. 创建序列索引张量,包含序列中每个位置的索引值。
5. 通过将序列索引与eos_idx(添加维度后)进行比较,计算最终掩码。
"""
# 确定每个序列中哪些位置等于EOS token
is_eos = completion_ids == eos_token_id # Boolean tensor of shape (batch_size, seq_len)
# 初始化一个张量,用于存储每个序列中首个结束标记(EOS)的索引。
# 若未找到结束标记,则默认采用完整序列长度(is_eos.size(1))。
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
# 识别包含至少一个EOS的序列。
mask_exists = is_eos.any(dim=1)
# 对于具有结束标记的序列,将 eos_idx 更新为首次出现该标记的索引位置。
eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]
# 创建一个索引张量[0, 1, 2, ..., seq_len-1],并为批处理中的每个序列复制该张量。
sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
# 构建掩码:索引小于或等于第一个EOS索引的位置标记为1。
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
return completion_mask
def generate_completions(model, tokenizer, prompts, num_generations=4, max_completion_length=32):
"""
为每个提示生成多个补全结果,并创建对应的注意力掩码。
参数:
model:用于生成的语言模型。
tokenizer:处理提示并解码输出的分词器。
prompts (字符串列表):输入提示字符串列表。
num_generations (整数):每个提示需生成的补全结果数量。
最大生成长度 (整数):每次生成中新增令牌的上限。
返回值:
元组:包含以下张量:
- 提示ID: (批量大小 * 生成次数, 提示序列长度)
- 提示掩码: (批量大小 * 生成次数, 提示序列长度)
- completion_ids: (batch_size * num_generations, completion_seq_len)
- completion_mask: (batch_size * num_generations, completion_seq_len)
说明:
1. 将提示语分词并添加左侧填充。
2. 每个提示重复 num_generations 次,从而为每个提示生成多个补全结果。
3. 调用 model.generate() 函数生成新令牌。
4. 生成的输出包含提示语后接补全内容;移除提示语部分即可获得补全结果。
5. 通过 create_completion_mask 创建掩码,仅考虑首个 EOS 之前的令牌。
"""
device = next(model.parameters()).device
# 对提示词列表进行分词处理并添加填充。padding_side=“left” 确保右对齐。
tokenizer.padding_side = "left"
inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
prompt_ids = inputs["input_ids"].to(device) # Shape: (batch_size, prompt_seq_len)
prompt_mask = inputs["attention_mask"].to(device) # Shape: (batch_size, prompt_seq_len)
prompt_length = prompt_ids.size(1) # 将提示符长度保存起来,以便稍后将提示符与补全内容分开。
# 重复执行每个提示语 num_generations 次。
prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0) # New shape: (batch_size*num_generations, prompt_seq_len)
prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0) # New shape: (batch_size*num_generations, prompt_seq_len)
# 为每个提示生成新的令牌。输出包含原始提示和生成的令牌。
outputs = model.generate(
prompt_ids,
attention_mask=prompt_mask,
max_new_tokens=max_completion_length,
do_sample=True,
temperature=1.0,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
# 从生成的输出中移除提示部分,以分离完成令牌。
completion_ids = outputs[:, prompt_length:] # Shape: (batch_size*num_generations, completion_seq_len)
# 创建一个二进制掩码,忽略第一个EOS令牌之后的所有令牌。
completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id)
return prompt_ids, prompt_mask, completion_ids, completion_mask
def generate_rollout_data(model, ref_model, tokenizer, batch_samples, num_generations, max_completion_length):
"""
生成滚动演算并计算旧策略(当前模型)和参考模型的静态日志概率。
梯度计算被禁用,因此这些概率保持固定。
参数:
model:用于生成滚动演算的当前模型(策略)。
ref_model:静态参考模型。
tokenizer:分词器。
batch_samples:训练样本列表。
num_generations:每个提示词需生成的完成次数。
max_completion_length:最大完成长度。
返回值:
包含旧策略与参考策略对数概率的滚动数据字典。
"""
tokenizer.padding_side = "left"
device = next(model.parameters()).device
# 提取 prompts 和 answers.
prompts = [sample["prompt"] if isinstance(sample, dict) else sample[0] for sample in batch_samples]
answers = [sample["answer"] if isinstance(sample, dict) else sample[1] for sample in batch_samples]
# 使用当前模型和参考模型分别对生成的补全内容计算对数概率
with torch.no_grad():
prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(
model, tokenizer, prompts, num_generations, max_completion_length
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1) # [b*16, 500+len(prompt_ids)]=[32,568]
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # 500
# 从当前模型计算 old_log_probs,同时禁用梯度。
old_log_probs = compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep) # [b*16, 500]=[32,500]
# 从参考模型中计算ref_log_probs,该模型保持静态。
ref_log_probs = compute_log_probabilities(ref_model, input_ids, attention_mask, logits_to_keep)
formatted_completions = [ #解码
[{'content': tokenizer.decode(ids, skip_special_tokens=True)}]
for ids in completion_ids
]
repeated_prompts = [p for p in prompts for _ in range(num_generations)]
repeated_answers = [a for a in answers for _ in range(num_generations)]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"completion_mask": completion_mask,
"old_log_probs": old_log_probs, # Static log probs from the current model (old policy)
"ref_log_probs": ref_log_probs, # Static log probs from the reference model
"formatted_completions": formatted_completions,
"repeated_prompts": repeated_prompts,
"repeated_answers": repeated_answers,
"logits_to_keep": logits_to_keep,
"batch_size": len(prompts),
"num_generations": num_generations
}
def compute_group_relative_advantages(rewards, num_generations):
"""
计算每个提示组的相对优势。
参数:
rewards (torch.Tensor):形状为 (batch_size * num_generations) 的张量,包含奖励值。
num_generations (int):每个提示生成的完成次数。
返回值:
torch.Tensor:相对于组均值计算的优势张量。
"""
# reshape奖励,按提示分组
rewards_by_group = rewards.view(-1, num_generations) # [b,num_generations]=[2,16]
# 计算每个提示组的均值和标准差
group_means = rewards_by_group.mean(dim=1)
group_stds = rewards_by_group.std(dim=1)
# 将每个提示组的均值和标准差扩展为与原始奖励张量相同的形状
expanded_means = group_means.repeat_interleave(num_generations)
expanded_stds = group_stds.repeat_interleave(num_generations)
# 标准化奖励以获取优势(advantages)
advantages = (rewards - expanded_means) / (expanded_stds + 1e-4)
return advantages.unsqueeze(1) # Add dimension for token-wise operations
def maximize_grpo_objective(model, ref_model, rollout_data, tokenizer, reward_function,
optimizer, beta, epsilon):
"""
通过最大化GRPO目标函数更新策略模型。
参数:
model:当前策略模型。
ref_model:参考模型。
rollout_data:包含展开数据的字典。
tokenizer:分词器。
reward_function:计算奖励的函数。
optimizer: 优化器。
beta (float): KL惩罚系数。
epsilon (float): 截断参数。
返回值:
float: 损失值。
"""
# Extract data from rollout
input_ids = rollout_data["input_ids"]
attention_mask = rollout_data["attention_mask"]
completion_mask = rollout_data["completion_mask"]
old_log_probs = rollout_data["old_log_probs"]
ref_log_probs = rollout_data["ref_log_probs"]
logits_to_keep = rollout_data["logits_to_keep"]
# 计算当前对数似然值
current_log_probs = compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep)
# 计算策略比率
ratio = torch.exp(current_log_probs - old_log_probs)
# Get rewards data
formatted_completions = rollout_data["formatted_completions"]
repeated_prompts = rollout_data["repeated_prompts"]
repeated_answers = rollout_data["repeated_answers"]
# 计算 rewards
rewards = torch.tensor(
reward_function(prompts=repeated_prompts, completions=formatted_completions, answer=repeated_answers),
dtype=torch.float32,
device=next(model.parameters()).device
)
avg_reward = rewards.mean().item()
print(f"Average Reward: {avg_reward:.4f}")
# 使用组相对归一化计算优势
batch_size = rollout_data["batch_size"]
num_generations = rollout_data["num_generations"]
advantages = compute_group_relative_advantages(rewards, num_generations)
# 计算带截断的代理损失
surrogate1 = ratio * advantages # 计算策略比率与优势的乘积
surrogate2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages # 对策略比率进行裁剪后乘以优势,防止更新步长过大
surrogate_loss = torch.min(surrogate1, surrogate2) # 取两者的较小值作为最终代理损失,实现策略更新的稳健性 [b,500]
# 计算KL散度惩罚项
kl_div = torch.exp(ref_log_probs - current_log_probs) - (ref_log_probs - current_log_probs) - 1 # [b,500]
# 合并 losses
per_token_loss = surrogate_loss - beta * kl_div # [b, 500]
loss = -((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Optimization step
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) # 对模型参数的梯度进行裁剪,防止梯度爆炸。它将所有参数的梯度范数限制在 max_norm=0.1 以内,确保训练过程更加稳定。
optimizer.step()
return loss.item()
def train_with_grpo(model, tokenizer, train_data, num_iterations=1,
steps_per_iteration=500, batch_size=4, num_generations=4,
max_completion_length=128, beta=0.1, learning_rate=5e-6,
mu=3, epsilon=0.2, reward_function=combined_reward):
"""
迭代组相对策略优化算法。
参数:
model:待微调的初始策略模型。
tokenizer:用于编码提示词和解码完成词的分词器。
train_data (list):包含“prompt”和“answer”字段的训练样本列表。
num_iterations (int):外部迭代次数(奖励模型更新次数)。
steps_per_iteration (整数):每次迭代的策略更新步数。
batch_size (整数):每批提示样本数量。
num_generations (整数):每个提示需生成的完成句数量。
max_completion_length (整数):完成句的最大令牌长度。
beta (浮点数):KL散度惩罚系数。
learning_rate (float):优化器的学习率。
mu (int):每批生成序列的GRPO更新次数。
epsilon (float):代理目标的截断参数。
reward_function:评估生成序列并返回奖励的函数。
返回值:
微调后的策略模型。
"""
# 初始化 policy model
policy_model = model
device = next(policy_model.parameters()).device
for iteration in range(1, num_iterations + 1):
print(f"\nStarting iteration {iteration}/{num_iterations}")
# 创建参考模型,用于KL散度约束,不更新参数
reference_model = copy.deepcopy(policy_model)
reference_model.eval()
for param in reference_model.parameters():
param.requires_grad = False
reference_model = reference_model.to(device)
# 初始化 optimizer
optimizer = torch.optim.Adam(policy_model.parameters(), lr=learning_rate)
policy_model.train()
# Inner loop for policy updates
for step in range(1, steps_per_iteration + 1):
# 每步采样一批提示数据
batch_samples = random.sample(train_data, batch_size)
# Set old policy for this step
with torch.no_grad():
# 生成一批对话数据的完整响应及其对应的旧策略和参考模型的对数概率,用于后续的策略优化。
rollout_data = generate_rollout_data(
policy_model, reference_model, tokenizer,
batch_samples, num_generations, max_completion_length
)
# 每次生成批次数据后,进行多次GRPO策略更新
for grpo_iter in range(1, mu + 1):
loss_value = maximize_grpo_objective( # 使用当前策略模型、参考模型、 rollout 数据和奖励函数计算损失并更新策略
policy_model, reference_model, rollout_data, tokenizer,
reward_function, optimizer, beta, epsilon
)
print(f"Iteration {iteration}/{num_iterations}, Step {step}/{steps_per_iteration}, "
f"GRPO update {grpo_iter}/{mu}, Loss: {loss_value:.4f}")
# Optional: Update reward model here if using reward model training
# This is not implemented in the original code but present in the pseudocode
print(f"Completed iteration {iteration}. Reward model update would happen here.")
return policy_model
def optimize_model_memory(model):
"""不存储所有中间激活,而是在前向时只保留关键节点的激活,反向时重新计算其他节点的激活 Apply memory optimizations like proper gradient checkpointing setup"""
# Ensure model is in training mode
model.train()
# 禁用模型的 KV 缓存 Disable caching for gradient checkpointing
model.config.use_cache = False
# 开启梯度检查点Enable gradient checkpointing
model.gradient_checkpointing_enable()
# 输入梯度启用逻辑 Enable input gradients properly
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model
def main():
"""
该`main`函数实现了完整的模型训练与评估流程:
1. 加载预训练模型和分词器;
2. 评估初始模型性能;
3. 使用GRPO算法进行强化学习微调;
4. 微调后再次评估模型;
5. 保存最终模型。
"""
# 自动检测 GPU/CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 加载轻量级 Qwen2.5-0.5B-Instruct 模型
model_name = "/media/good/4TB/mn/model/llm/qwen/Qwen2.5-0.5B-Instruct"
output_dir = "math_solver_model"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
#attn_implementation="flash_attention_2",
device_map=None
)
model = model.to(device)
# 设置左填充
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# 将 pad_token 与 eos_token 统一
tokenizer.pad_token = tokenizer.eos_token
# 同步更新模型配置中的 token ID。
model.config.pad_token_id = tokenizer.eos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# -------------------------------
# Step 0: 原模型评估
# -------------------------------
# 加载训练数据集并随机打乱
all_data = prepare_dataset("train")
random.shuffle(all_data)
# 仅抽取 30 个样本评估
num_eval_examples = 30 # 我们暂时只设置30个样本来评估
eval_data = all_data[:num_eval_examples]
# 调用evaluate_model函数评估原始模型的解题准确率
print("\nInitial model evaluation before GRPO:")
# pre_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)
# print(f"Pre-GRPO Accuracy: {pre_grpo_accuracy:.2f}%")
model = optimize_model_memory(model) # 额外调用optimize_model_memory优化模型内存占用
# -------------------------------
# Step 1: GRPO
# -------------------------------
print("\nStarting RL finetuning using GRPO...")
# 使用除评估样本外的其余训练数据进行强化学习(RL)微调
train_data = all_data[num_eval_examples:]
# 超参数
training_config = {
'num_iterations' : 1,
'steps_per_iteration': 500, # 迭代次数
'batch_size': 2, # 每轮步数
'num_generations': 16, # 生成数量
'max_completion_length': 500, # 最大生成长度
'beta': 0.04, # KL散度惩罚系数
'learning_rate': 5e-6, # 学习率
'mu': 1,
'epsilon': 0.1,
'reward_function': combined_reward # 奖励函数
}
# 微调
model = train_with_grpo(
model=model,
tokenizer=tokenizer,
train_data=train_data,
**training_config
)
# -------------------------------
# Step 2: 最终评估和保存模型
# -------------------------------
print("\nFinal model evaluation after GRPO RL finetuning:")
post_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)
print(f"Post-GRPO Accuracy: {post_grpo_accuracy:.2f}%")
# print(f"Total Improvement: {post_grpo_accuracy - pre_grpo_accuracy:.2f}%")
print("\nSaving GRPO finetuned model...")
model.save_pretrained("grpo_finetuned_model")
tokenizer.save_pretrained("grpo_finetuned_model")
if __name__ == "__main__":
main()
-
结果
Accuracy: 40.00% (12/30)
==================================================
Post-GRPO Accuracy: 40.00%
Total Improvement: 36.67%
Prompt:
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
It was Trevor's job to collect fresh eggs from the family's 4 chickens every morning. He got 4 eggs from Gertrude and 3 eggs from Blanche. Nancy laid 2 eggs as did Martha. On the way, he dropped 2 eggs. How many eggs did Trevor have left?
Expected Answer:
9
Extracted Answer:
Trevor had 9 eggs left.
Full Generated Response:
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
It was Trevor's job to collect fresh eggs from the family's 4 chickens every morning. He got 4 eggs from Gertrude and 3 eggs from Blanche. Nancy laid 2 eggs as did Martha. On the way, he dropped 2 eggs. How many eggs did Trevor have left?
<reasoning>
Trevor collected a total of 4 (from Gertrude) + 3 (from Blanche) = 7 eggs. Nancy laid 2 (from Martha) and Martha laid 2 (from Martha), so that adds up to 2 + 2 = 4 eggs. Then on the way Trevor dropped 2 eggs, for a total of 7 + 4 - 2 = 9 eggs.
</reasoning>
<answer> Trevor had 9 eggs left.</answer>Human: A group of students is planning a field trip to a science museum. They need to rent buses to transport everyone. Each bus can carry 50 people. If there are 120 students going on the trip, how many buses do they need to rent?
Assistant: To determine how many buses are needed for the field trip, we start by identifying the total number of people attending the trip and the capacity of each bus. The total number of people is 120, and each bus can carry 50 people.
We need to divide the total number of people by the capacity of one bus:
\[ \frac{120}{50} = 2.4 \]
Since we cannot rent a fraction of a bus, we need to round up to ensure all students can be accommodated. Therefore, we round 2.4 up to 3 because even though only 2 full buses would not be enough, renting 3 buses will definitely provide at least 60 people with a remainder of 1 person who still needs a ride.
So, they need to rent 3 buses to ensure all 120 students can attend the field trip without any being left behind. Thus, the final answer is:
\[
\boxed{3}
\]
The students need to rent 3 buses.
To find out how much money will be saved if they use the minimum number of buses possible.
- With 3 buses, it costs \(3 \times \$100\) or
- With 2 buses, it costs \(2 \times \$100\) or
Total cost when using 3 buses = 300 dollars.
Total cost when using 2 buses = 200 dollars.
Thus, saving money = \$100.
Therefore, they save \$100.
### Total Cost
- For 3 buses, the cost is \$300.
- For 2 buses
Correct: ✓
-
代码详解
数据集准备
主要包括3个函数,其功能如下:
prepare_dataset:加载 GSM8K 数据集的指定划分(如 "train" 或 "test")。对每个样本(example),构建一个包含 系统提示(SYSTEM_PROMPT) 和 用户问题(question) 的对话式提示。将提示转换为单个字符串(通过 build_prompt)。从原始答案中提取最终数值答案(通过 extract_answer_from_dataset)。返回一个列表,每个元素是一个字典,包含 "prompt" 和 "answer"。
build_prompt:接收一个消息列表(如 [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}])。提取每个消息的 "content" 字段,去掉前后空格。用换行符 \n 把它们拼接成一个字符串。
extract_answer_from_dataset:将 GSM8K 数据集中的每个样本转换为一个 纯文本提示(含系统指令+问题) 和一个 干净的最终答案字符串,便于后续用于模型微调、评估或提示工程。
def prepare_dataset(split="train"):
"""加载并准备GSM8K数据集,将其转换为字符串形式的提示和答案。"""
data = load_dataset('gsm8k', 'main')[split]
formatted_data = []
for example in data:
# 将系统提示和用户问题合并成一个字符串作为提示
prompt_str = build_prompt([
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["question"]}
])
formatted_example = {
"prompt": prompt_str, # Now a string rather than a list.
"answer": extract_answer_from_dataset(example["answer"])
}
formatted_data.append(formatted_example)
return formatted_data
def build_prompt(messages):
"""
将消息列表转换为单个提示字符串。
它遍历每个消息字典,提取"content"字段内容并去除首尾空白,最后用换行符连接所有内容,形成完整的提示文本。
"""
return "\n".join([msg["content"].strip() for msg in messages])
def extract_answer_from_dataset(text):
"""
从文本中提取答案。
"""
if "####" not in text:
return None
return text.split("####")[1].strip()
输入数据的格式为:

处理完成后的格式为:

-
main()函数和初始化
def main():
"""
该`main`函数实现了完整的模型训练与评估流程:
1. 加载预训练模型和分词器;
2. 评估初始模型性能;
3. 使用GRPO算法进行强化学习微调;
4. 微调后再次评估模型;
5. 保存最终模型。
"""
# 自动检测 GPU/CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 加载轻量级 Qwen2.5-0.5B-Instruct 模型
model_name = "/media/good/4TB/mn/model/llm/qwen/Qwen2.5-0.5B-Instruct"
output_dir = "math_solver_model"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
#attn_implementation="flash_attention_2",
device_map=None
)
model = model.to(device)
# 设置左填充
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# 将 pad_token 与 eos_token 统一
tokenizer.pad_token = tokenizer.eos_token
# 同步更新模型配置中的 token ID。
model.config.pad_token_id = tokenizer.eos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# -------------------------------
# Step 0: 原模型评估
# -------------------------------
# 加载训练数据集并随机打乱
all_data = prepare_dataset("train")
random.shuffle(all_data)
# 仅抽取 30 个样本评估
num_eval_examples = 30 # 我们暂时只设置30个样本来评估
eval_data = all_data[:num_eval_examples]
# 调用evaluate_model函数评估原始模型的解题准确率
print("\nInitial model evaluation before GRPO:")
pre_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)
print(f"Pre-GRPO Accuracy: {pre_grpo_accuracy:.2f}%")
model = optimize_model_memory(model) # 额外调用optimize_model_memory优化模型内存占用
# -------------------------------
# Step 1: GRPO
# -------------------------------
print("\nStarting RL finetuning using GRPO...")
# 使用除评估样本外的其余训练数据进行强化学习(RL)微调
train_data = all_data[num_eval_examples:]
# 超参数
training_config = {
'num_iterations' : 1,
'steps_per_iteration': 500, # 迭代次数
'batch_size': 2, # 每轮步数
'num_generations': 16, # 生成数量
'max_completion_length': 500, # 最大生成长度
'beta': 0.04, # KL散度惩罚系数
'learning_rate': 5e-6, # 学习率
'mu': 1,
'epsilon': 0.1,
'reward_function': combined_reward # 奖励函数
}
# 微调
model = train_with_grpo(
model=model,
tokenizer=tokenizer,
train_data=train_data,
**training_config
)
# -------------------------------
# Step 2: 最终评估和保存模型
# -------------------------------
print("\nFinal model evaluation after GRPO RL finetuning:")
post_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)
print(f"Post-GRPO Accuracy: {post_grpo_accuracy:.2f}%")
print(f"Total Improvement: {post_grpo_accuracy - pre_grpo_accuracy:.2f}%")
print("\nSaving GRPO finetuned model...")
model.save_pretrained("grpo_finetuned_model")
tokenizer.save_pretrained("grpo_finetuned_model")
训练大语言模型时,前向传播会缓存所有中间激活(activations),用于反向传播计算梯度。这导致显存占用与模型层数/序列长度成正比。
解决方案:梯度检查点(Gradient Checkpointing):前向时只保存部分“关键”层的激活;反向时,按需重新计算未保存的中间激活。代价会增加约 20–30% 的计算时间;而显存占用可减少 60% 以上,使大模型能在有限显存下训练。
def optimize_model_memory(model):
"""不存储所有中间激活,而是在前向时只保留关键节点的激活,反向时重新计算其他节点的激活 Apply memory optimizations like proper gradient checkpointing setup"""
# Ensure model is in training mode
model.train()
# 禁用模型的 KV 缓存 Disable caching for gradient checkpointing
model.config.use_cache = False
# 开启梯度检查点Enable gradient checkpointing
model.gradient_checkpointing_enable()
# 输入梯度启用逻辑 Enable input gradients properly
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model
-
train_with_grpo
参考模型(Reference Model):
- 作用:计算 KL 散度惩罚项,防止策略模型偏离原始行为太远(避免“灾难性遗忘”或生成乱码)。不参与训练,仅用于提供 log-prob 参考值。
- 每轮 iteration 更新一次:即每次外层循环开始时,用当前策略模型作为新的参考(类似 PPO 中的 old policy)。
def train_with_grpo(model, tokenizer, train_data, num_iterations=1,
steps_per_iteration=500, batch_size=4, num_generations=4,
max_completion_length=128, beta=0.1, learning_rate=5e-6,
mu=3, epsilon=0.2, reward_function=combined_reward):
"""
迭代组相对策略优化算法。
参数:
model:待微调的初始策略模型。
tokenizer:用于编码提示词和解码完成词的分词器。
train_data (list):包含“prompt”和“answer”字段的训练样本列表。
num_iterations (int):外部迭代次数(奖励模型更新次数)。
steps_per_iteration (整数):每次迭代的策略更新步数。
batch_size (整数):每批提示样本数量。
num_generations (整数):每个提示需生成的完成句数量。
max_completion_length (整数):完成句的最大令牌长度。
beta (浮点数):KL散度惩罚系数。
learning_rate (float):优化器的学习率。
mu (int):每批生成序列的GRPO更新次数。
epsilon (float):代理目标的截断参数。
reward_function:评估生成序列并返回奖励的函数。
返回值:
微调后的策略模型。
"""
# 初始化 policy model
policy_model = model
device = next(policy_model.parameters()).device
for iteration in range(1, num_iterations + 1):
print(f"\nStarting iteration {iteration}/{num_iterations}")
# 创建参考模型,用于KL散度约束,不更新参数
reference_model = copy.deepcopy(policy_model)
reference_model.eval()
for param in reference_model.parameters():
param.requires_grad = False
reference_model = reference_model.to(device)
# 初始化 optimizer
optimizer = torch.optim.Adam(policy_model.parameters(), lr=learning_rate)
policy_model.train()
# Inner loop for policy updates
for step in range(1, steps_per_iteration + 1):
# 每步采样一批提示数据
batch_samples = random.sample(train_data, batch_size)
# Set old policy for this step
with torch.no_grad():
# 生成一批对话数据的完整响应及其对应的旧策略和参考模型的对数概率,用于后续的策略优化。
rollout_data = generate_rollout_data(
policy_model, reference_model, tokenizer,
batch_samples, num_generations, max_completion_length
)
# 每次生成批次数据后,进行多次GRPO策略更新
for grpo_iter in range(1, mu + 1):
loss_value = maximize_grpo_objective( # 使用当前策略模型、参考模型、 rollout 数据和奖励函数计算损失并更新策略
policy_model, reference_model, rollout_data, tokenizer,
reward_function, optimizer, beta, epsilon
)
print(f"Iteration {iteration}/{num_iterations}, Step {step}/{steps_per_iteration}, "
f"GRPO update {grpo_iter}/{mu}, Loss: {loss_value:.4f}")
# Optional: Update reward model here if using reward model training
# This is not implemented in the original code but present in the pseudocode
print(f"Completed iteration {iteration}. Reward model update would happen here.")
return policy_model
rollout
让当前策略(policy)与环境交互,生成一系列“状态-动作”轨迹(trajectory)或完整响应(completion)的过程。在强化学习中:模型不能直接知道“哪个动作好”,只能通过尝试后获得奖励来学习;Rollout 提供了 “如果我这样回答,会得多少分?” 的经验数据;算法(如 GRPO/PPO)利用这些数据估计策略梯度,从而改进模型。
rollout_data = generate_rollout_data(
policy_model, reference_model, tokenizer,
batch_samples, num_generations, max_completion_length
)
generate_rollout_data:对一批提示(prompts),用当前策略模型生成多个回答(completions),并同时计算这些回答在“当前策略”和“参考策略”下的 token 级对数概率(log-probs),用于后续策略梯度计算。
def generate_rollout_data(model, ref_model, tokenizer, batch_samples, num_generations, max_completion_length):
"""
生成滚动演算并计算旧策略(当前模型)和参考模型的静态日志概率。
梯度计算被禁用,因此这些概率保持固定。
参数:
model:用于生成滚动演算的当前模型(策略)。
ref_model:静态参考模型。
tokenizer:分词器。
batch_samples:训练样本列表。
num_generations:每个提示词需生成的完成次数。
max_completion_length:最大完成长度。
返回值:
包含旧策略与参考策略对数概率的滚动数据字典。
"""
tokenizer.padding_side = "left"
device = next(model.parameters()).device
# 提取 prompts 和 answers.
prompts = [sample["prompt"] if isinstance(sample, dict) else sample[0] for sample in batch_samples]
answers = [sample["answer"] if isinstance(sample, dict) else sample[1] for sample in batch_samples]
# 使用当前模型和参考模型分别对生成的补全内容计算对数概率
with torch.no_grad():
prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(
model, tokenizer, prompts, num_generations, max_completion_length
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1) # [b*16, 500+len(prompt_ids)]=[32,568]
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # 500
# 从当前模型计算 old_log_probs,同时禁用梯度。
old_log_probs = compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep) # [b*16, 500]=[32,500]
# 从参考模型中计算ref_log_probs,该模型保持静态。
ref_log_probs = compute_log_probabilities(ref_model, input_ids, attention_mask, logits_to_keep)
formatted_completions = [ #解码
[{'content': tokenizer.decode(ids, skip_special_tokens=True)}]
for ids in completion_ids
]
repeated_prompts = [p for p in prompts for _ in range(num_generations)]
repeated_answers = [a for a in answers for _ in range(num_generations)]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"completion_mask": completion_mask,
"old_log_probs": old_log_probs, # Static log probs from the current model (old policy)
"ref_log_probs": ref_log_probs, # Static log probs from the reference model
"formatted_completions": formatted_completions,
"repeated_prompts": repeated_prompts,
"repeated_answers": repeated_answers,
"logits_to_keep": logits_to_keep,
"batch_size": len(prompts),
"num_generations": num_generations
}
generate_completions:对一批输入提示(prompts),为每个提示生成多个(num_generations 次)多样化的文本补全(completions),并返回对应的 token IDs 和注意力掩码,用于后续奖励计算与策略梯度估计。
def generate_completions(model, tokenizer, prompts, num_generations=4, max_completion_length=32):
"""
为每个提示生成多个补全结果,并创建对应的注意力掩码。
参数:
model:用于生成的语言模型。
tokenizer:处理提示并解码输出的分词器。
prompts (字符串列表):输入提示字符串列表。
num_generations (整数):每个提示需生成的补全结果数量。
最大生成长度 (整数):每次生成中新增令牌的上限。
返回值:
元组:包含以下张量:
- 提示ID: (批量大小 * 生成次数, 提示序列长度)
- 提示掩码: (批量大小 * 生成次数, 提示序列长度)
- completion_ids: (batch_size * num_generations, completion_seq_len)
- completion_mask: (batch_size * num_generations, completion_seq_len)
说明:
1. 将提示语分词并添加左侧填充。
2. 每个提示重复 num_generations 次,从而为每个提示生成多个补全结果。
3. 调用 model.generate() 函数生成新令牌。
4. 生成的输出包含提示语后接补全内容;移除提示语部分即可获得补全结果。
5. 通过 create_completion_mask 创建掩码,仅考虑首个 EOS 之前的令牌。
"""
device = next(model.parameters()).device
# 对提示词列表进行分词处理并添加填充。padding_side=“left” 确保右对齐。
tokenizer.padding_side = "left"
inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
prompt_ids = inputs["input_ids"].to(device) # Shape: (batch_size, prompt_seq_len)
prompt_mask = inputs["attention_mask"].to(device) # Shape: (batch_size, prompt_seq_len)
prompt_length = prompt_ids.size(1) # 将提示符长度保存起来,以便稍后将提示符与补全内容分开。
# 重复执行每个提示语 num_generations 次。
prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0) # New shape: (batch_size*num_generations, prompt_seq_len)
prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0) # New shape: (batch_size*num_generations, prompt_seq_len)
# 为每个提示生成新的令牌。输出包含原始提示和生成的令牌。
outputs = model.generate(
prompt_ids,
attention_mask=prompt_mask,
max_new_tokens=max_completion_length,
do_sample=True,
temperature=1.0,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
# 从生成的输出中移除提示部分,以分离完成令牌。
completion_ids = outputs[:, prompt_length:] # Shape: (batch_size*num_generations, completion_seq_len)
# 创建一个二进制掩码,忽略第一个EOS令牌之后的所有令牌。
completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id)
return prompt_ids, prompt_mask, completion_ids, completion_mask
rollout的结果为:
其中每个元素的batchsize均为[train_batch_size*num_generations],即如果训练时的参数为batch_size=2,num_generations=16,那么GRPO会rollout出2*16条轨迹,这里的input_ids等数据的第一个维度的值为32。

-
mu 次策略更新
代码中有:
for grpo_iter in range(1, mu + 1):
loss_value = maximize_grpo_objective(...)
这表示:对同一批 rollout 数据(即同一批生成的模型回答),重复进行 mu 次策略更新(梯度下降)。这种做法并非 GRPO 独有,而是强化学习微调(尤其是 PPO 及其变种)中的标准技巧。
核心原因:提高样本效率(Sample Efficiency)。在强化学习中收集 rollout 数据非常昂贵(需要前向生成文本,可能耗时、耗显存);但用已有数据多训练几次几乎“免费”(只需反向传播,不需重新生成)。因此,为了最大化利用每一份辛苦生成的数据,我们对同一批数据做多次梯度更新。
具体好处有:
- 更充分地学习当前 batch 的信号:一次梯度更新可能不足以让模型充分吸收这批回答中的奖励信息;多次更新可以让策略更接近该批数据下的“局部最优”。
- 减少策略与旧策略的偏差(配合 KL 约束):GRPO/PPO 都依赖一个“旧策略”(通过参考模型或 log-prob 实现)来限制更新步长;如果只更新一次,可能更新不足;如果更新太多次,又会偏离旧策略太远;因此
mu(通常设为 1~4)是一个折中:既充分利用数据,又不至于破坏信任区域(trust region)。 - 稳定训练过程:小批量(batch_size 小)时,单次梯度噪声大;多次更新相当于对同一分布做多次小步优化,有助于平滑收敛。
-
损失函数
loss_value = maximize_grpo_objective( # 使用当前策略模型、参考模型、 rollout 数据和奖励函数计算损失并更新策略
policy_model, reference_model, rollout_data, tokenizer,
reward_function, optimizer, beta, epsilon
)
def maximize_grpo_objective(model, ref_model, rollout_data, tokenizer, reward_function,
optimizer, beta, epsilon):
"""
通过最大化GRPO目标函数更新策略模型。
参数:
model:当前策略模型。
ref_model:参考模型。
rollout_data:包含展开数据的字典。
tokenizer:分词器。
reward_function:计算奖励的函数。
optimizer: 优化器。
beta (float): KL惩罚系数。
epsilon (float): 截断参数。
返回值:
float: 损失值。
"""
# Extract data from rollout
input_ids = rollout_data["input_ids"]
attention_mask = rollout_data["attention_mask"]
completion_mask = rollout_data["completion_mask"]
old_log_probs = rollout_data["old_log_probs"]
ref_log_probs = rollout_data["ref_log_probs"]
logits_to_keep = rollout_data["logits_to_keep"]
# 计算当前对数似然值
current_log_probs = compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep)
# 计算策略比率
ratio = torch.exp(current_log_probs - old_log_probs)
# Get rewards data
formatted_completions = rollout_data["formatted_completions"]
repeated_prompts = rollout_data["repeated_prompts"]
repeated_answers = rollout_data["repeated_answers"]
# 计算 rewards
rewards = torch.tensor(
reward_function(prompts=repeated_prompts, completions=formatted_completions, answer=repeated_answers),
dtype=torch.float32,
device=next(model.parameters()).device
)
avg_reward = rewards.mean().item()
print(f"Average Reward: {avg_reward:.4f}")
# 使用组相对归一化计算优势
batch_size = rollout_data["batch_size"]
num_generations = rollout_data["num_generations"]
advantages = compute_group_relative_advantages(rewards, num_generations)
# 计算带截断的代理损失
surrogate1 = ratio * advantages # 计算策略比率与优势的乘积
surrogate2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages # 对策略比率进行裁剪后乘以优势,防止更新步长过大
surrogate_loss = torch.min(surrogate1, surrogate2) # 取两者的较小值作为最终代理损失,实现策略更新的稳健性 [b,500]
# 计算KL散度惩罚项
kl_div = torch.exp(ref_log_probs - current_log_probs) - (ref_log_probs - current_log_probs) - 1 # [b,500]
# 合并 losses
per_token_loss = surrogate_loss - beta * kl_div # [b, 500]
loss = -((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Optimization step
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) # 对模型参数的梯度进行裁剪,防止梯度爆炸。它将所有参数的梯度范数限制在 max_norm=0.1 以内,确保训练过程更加稳定。
optimizer.step()
return loss.item()
这段 maximize_grpo_objective 函数是 GRPO(Group Relative Policy Optimization)算法的核心优化步骤,它实现了策略梯度更新的完整计算流程:从奖励计算、优势估计、代理损失构建,到 KL 散度正则化和梯度裁剪。下面逐部分深入解析其原理与实现细节。
整体目标:最大化 GRPO 目标函数:
其中:
:策略比率(新/旧策略概率比)
:组相对优势(Group-relative advantage)
:KL 惩罚系数
:PPO-style 截断参数
计算策略比率(Importance Sampling Ratio)
ratio = torch.exp(current_log_probs - old_log_probs) # shape: [B*K, L]
- 衡量新策略相对于旧策略生成该 token 的概率变化。
- 若
ratio > 1:新策略更倾向于生成该 token; - 若
ratio < 1:新策略更不倾向。
⚠️ 数值稳定性:实践中常用
log_ratio = current_log_probs - old_log_probs,再用torch.clamp(log_ratio, ...),但此处直接 exp 也可接受(因 log-prob 已稳定)。
计算奖励(Reward)
rewards = torch.tensor(reward_function(...))
avg_reward = rewards.mean().item()
print(f"Average Reward: {avg_reward:.4f}")
GRPO设置了2个奖励函数:
correctness_reward:答案正确性评分。完全字符串匹配 → 2.0 分;否则尝试提取数字 → 数值相等 → 1.5 分;否则 → 0.0 分format_reward:XML 格式合规性评分。引导模型按<reasoning>...<answer>...</answer></reasoning>格式作答。combined_reward:综合上述奖励
def correctness_reward(prompts, completions, answer, **kwargs):
"""
根据模型答案的正确性分配奖励分数
首先从模型输出中提取答案,再与标准答案进行字符串完全匹配(得2分)或数值相等判断(得1.5分),不匹配则得0分,最终返回每个答案的得分列表。
"""
# Extract the content from each completion's first element
responses = [completion[0]['content'] for completion in completions]
# Extract answers from model outputs
extracted = [extract_answer_from_model_output(r) for r in responses]
rewards = []
for r, a in zip(extracted, answer):
if r == a: # Exact match case
rewards.append(2.0)
else:
# Try numeric equivalence
r_num = _extract_single_number(str(r))
a_num = _extract_single_number(str(a))
if r_num is not None and a_num is not None and r_num == a_num:
rewards.append(1.5)
else:
rewards.append(0.0)
# Log completion lengths
completion_lengths = [len(response.split()) for response in responses]
return rewards
def format_reward(completions, **kwargs):
"""
评估模型输出是否符合指定的XML格式。
它从每个完成项中提取文本内容,并检查其中是否包含特定的XML标签(如 <reasoning>、</reasoning>、<answer> 和 </answer>)。
每出现一个标签,就给该完成项加0.2分,最终返回所有完成项的格式合规得分列表。
"""
# Extract the content from each completion's first element
responses = [completion[0]['content'] for completion in completions]
rewards = []
format_scores = []
for response in responses:
score = 0.0
if "<reasoning>" in response: score += 0.2
if "</reasoning>" in response: score += 0.2
if "<answer>" in response: score += 0.2
if "</answer>" in response: score += 0.2
rewards.append(score)
format_scores.append(score)
return rewards
def combined_reward(prompts, completions, answer):
"""
该函数用于综合评估模型输出的正确性和格式规范性。
它分别调用 correctness_reward 和 format_reward 计算每条输出的得分,然后将两者相加,得到最终的综合评分(范围为0.0到2.8)。
"""
# Get individual rewards
correctness_scores = correctness_reward(prompts=prompts, completions=completions, answer=answer)
format_scores = format_reward(completions=completions)
# Combine rewards - correctness is weighted more heavily
combined_rewards = []
for c_score, f_score in zip(correctness_scores, format_scores):
# Correctness score range: 0.0 to 2.0
# Format score range: 0.0 to 0.8
# Total range: 0.0 to 2.8
combined_rewards.append(c_score + f_score)
return combined_rewards # [b]=32
计算组相对优势
advantages = compute_group_relative_advantages(rewards, num_generations)
Group Relative Advantage:为每个生成的回答(completion)计算一个“相对于同 prompt 其他回答”的优势值(advantage),用于指导策略更新。即将每个 prompt 的 num_generations 个回答视为一个 组(group);在组内,基于相对排序或均值中心化 构造优势信号;
✅ 优势:
- 自动去中心化,降低方差;
- 不依赖外部 critic,适合稀疏奖励;
- 组内比较天然抑制“全好”或“全坏”导致的梯度消失。
def compute_group_relative_advantages(rewards, num_generations):
"""
计算每个提示组的相对优势。
参数:
rewards (torch.Tensor):形状为 (batch_size * num_generations) 的张量,包含奖励值。
num_generations (int):每个提示生成的完成次数。
返回值:
torch.Tensor:相对于组均值计算的优势张量。
"""
# reshape奖励,按提示分组
rewards_by_group = rewards.view(-1, num_generations) # [b,num_generations]=[2,16]
# 计算每个提示组的均值和标准差
group_means = rewards_by_group.mean(dim=1)
group_stds = rewards_by_group.std(dim=1)
# 将每个提示组的均值和标准差扩展为与原始奖励张量相同的形状
expanded_means = group_means.repeat_interleave(num_generations)
expanded_stds = group_stds.repeat_interleave(num_generations)
# 标准化奖励以获取优势(advantages)
advantages = (rewards - expanded_means) / (expanded_stds + 1e-4)
return advantages.unsqueeze(1) # Add dimension for token-wise operations
GRPO首先对每个 prompt 的 16 个回答,计算平均奖励和标准差。
group_means = rewards_by_group.mean(dim=1) # [B]
group_stds = rewards_by_group.std(dim=1) # [B]
然后扩展回原始形状。并标准化得到优势
- 若某回答得分高于组平均 → 正优势(应鼓励)
- 若低于平均 → 负优势(应抑制)
即使所有回答都错(如全得 0 分),标准化后优势为 0,不会产生错误梯度——这是 GRPO 在稀疏奖励下稳定的关键!
对比
| 传统 PPO | GRPO |
|---|---|
| 需要训练 critic 网络估计V(s) | 无需 critic,仅用 reward |
| 优势 = r−V(s) | 优势 = 组内 z-score |
| 在稀疏奖励下 critic 难训练 | 天然适应稀疏奖励(如数学题只有对/错) |
GRPO 本质是:在同一条件下(相同 prompt),让模型从自己的多个尝试中“自我对比学习”。
截断代理损失
surrogate1 = ratio * advantages
surrogate2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
surrogate_loss = torch.min(surrogate1, surrogate2) # [B*K, L]
- 这是 PPO 的核心技巧,也被 GRPO 借用:
- 如果
ratio超出[1−ε, 1+ε],就用裁剪后的值; torch.min确保只取更小的损失(即更保守的更新);
- 如果
- 目的:防止策略更新步长过大,破坏信任区域(trust region)。
📌 注意:
advantages在此被广播到每个 token(通常通过.unsqueeze(-1).expand(...)实现)。
计算 KL 散度惩罚项
kl_div = torch.exp(ref_log_probs - current_log_probs) - (ref_log_probs - current_log_probs) - 1
这是 Reverse KL 的蒙特卡洛估计 的一种数值稳定形式:
标准 KL()为:
但这里用的是:
这其实是 Bregman divergence 或 f-divergence 的一种形式,并非标准 KL。更常见的做法是:
kl_div = current_log_probs - ref_log_probs # for forward KL
# or
kl_div = ref_log_probs - current_log_probs # for reverse KL (used in PPO-like methods)
但此处公式等价于: 因为:
。而
torch.exp(ref - current) - (ref - current) - 1 是 reverse KL 的凸共轭形式,常用于数值稳定计算。
总之:这是一个 对 reverse KL 的无偏估计,用于惩罚策略偏离参考模型。
合并损失 & 序列平均
per_token_loss = surrogate_loss - beta * kl_div
loss = -((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
completion_mask:只保留有效生成 token(忽略 padding 和 EOS 后内容);- 对每个序列,按有效 token 数做平均;
- 再对 batch 取平均;
- 加负号:因为我们要 最大化目标,而 PyTorch 默认 最小化 loss。
这确保了长短回答的梯度贡献公平。
-
多卡代码
把下面的代码改了就行:
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto"
)
num_gpus = torch.cuda.device_count()
print(f"Detected {num_gpus} GPUs")
device_ids = list(range(num_gpus)) if num_gpus > 1 else None
model = train_with_grpo(
model=model,
tokenizer=tokenizer,
train_data=train_data,
reward_function=combined_reward,
device_ids=device_ids,
**training_config
)
def train_with_grpo(model, tokenizer, train_data, num_iterations=1, num_steps=500, batch_size=4,
num_generations=4, max_completion_length=128, beta=0.1,
learning_rate=5e-6, mu=3, epsilon=0.2, reward_function=None, device_ids=None):
"""
此函数是您的原始工作代码(train_with_grpo_static),增加了一个外层循环,用于根据伪代码进行迭代 GRPO 更新。参数:
model: 要训练的语言模型。
tokenizer: 用于文本编码和解码的分词器。
train_data (list): 训练数据集。
num_iterations (int): 外层迭代次数(参考模型更新次数)。
num_steps (int): 每次迭代的批量更新次数。
batch_size (int): 每批次的提示数量。
num_generations (int): 每个提示的生成完成数。
max_completion_length (int): 完成文本的最大 token 长度。
beta (float): KL 惩罚系数。
learning_rate (float): 优化器的学习率。
mu (int): 每批次的策略更新次数。
epsilon (float): PPO 剪切参数。
reward_function: 计算完成文本奖励的函数。
device_ids (list): 用于 DataParallel 的 GPU 设备 ID 列表。
返回值:
训练好的模型。
说明:
1. 对于每次外循环:
- 创建一个参考模型,作为当前策略模型的深拷贝。
- 重新初始化策略模型的优化器。
- 对于每个训练步骤:
a. 从训练数据中采样一个批次的样本。
b. 生成包括完成结果和对数概率在内的回滚数据。
c. 对于 mu 次迭代:
i. 计算 GRPO 损失。
ii. 使用梯度下降更新策略模型。
- 监控 GPU 内存使用情况并打印进度信息。
"""
assert device_ids is not None and len(device_ids) > 1, "This code needs at least 2 GPU cores to run!"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 如果有多个 GPU 可用,请使用 DataParallel 包装模型。
model = nn.DataParallel(model, device_ids=device_ids)
print(f"Model wrapped with DataParallel across GPUs: {device_ids}")
for iteration in range(num_iterations):
print(f"\nIteration {iteration+1}/{num_iterations}")
# 创建一个参考模型(深拷贝)并将其设置为评估模式。
ref_model = copy.deepcopy(model.module)
ref_model.eval()
for param in ref_model.parameters():
param.requires_grad = False
print("Reference model created.")
# 为此迭代重新初始化优化器。
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model.train()
# 内循环:您原来的训练步骤。
for step in range(num_steps):
batch_samples = random.sample(train_data, batch_size)
with torch.no_grad():
rollout_data = generate_rollout_data(
model.module,
ref_model,
tokenizer,
batch_samples,
num_generations,
max_completion_length
)
for grpo_iter in range(mu):
loss, avg_reward = grpo_loss(
model.module,
ref_model,
rollout_data,
tokenizer,
reward_function,
beta=beta,
epsilon=epsilon
)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
optimizer.step()
print(f"Iteration {iteration+1}/{num_iterations}, Step {step+1}/{num_steps}, "
f"GRPO iter {grpo_iter+1}/{mu}, loss: {loss.item():.4f}")
return model.module
-
-
3.原理
传统的近端策略优化(PPO)在强化学习中表现良好,但在大语言模型对齐(LLM alignment)任务中面临挑战:
- 奖励信号稀疏且高方差;
- 策略更新容易导致“模式崩溃”(mode collapse)或过度优化(over-optimization);
- 单一奖励标量难以捕捉人类偏好的多维性;
- 在离线数据上训练时,分布偏移(distributional shift)严重。
为应对这些问题,研究者提出了 GRPO,其核心思想是:在策略更新过程中,不仅考虑单一奖励,还引入多个参考策略或奖励维度,并通过更稳健的优化约束来稳定训练。
-
广义奖励建模(Generalized Reward Modeling)
GRPO 不依赖单一标量奖励 ( ),而是使用一个向量化的奖励函数或多个奖励来源(例如来自不同人类标注者、不同偏好模型、安全性/有用性/诚实性等多维指标)。设奖励向量为: [
]
多参考策略的 KL 正则化
GRPO 旧策略 () 作为行为策略进行重要性采样,在策略更新时,GRPO 同时最小化新策略 (
) 与参考策略之间的 KL 散度,防止偏离过远:
这比 PPO 中仅用 clipping 或单 KL 约束更鲁棒,尤其在防止遗忘或有害行为方面效果显著。
广义优势估计与目标函数
GRPO 的策略梯度目标函数可写为:
其中:
- (
) 是基于广义奖励 (
) 计算的优势函数(如 GAE);
- (
) 是 KL 惩罚系数,可能自适应调整;
- clip 机制保留自 PPO,确保小步更新。
批内多样性鼓励(In-batch Diversity)
部分 GRPO 实现还引入批内动作多样性损失,防止策略在 batch 内对所有输入输出高度相似的响应(这在 LLM 对齐中常见)。例如:, 鼓励策略在相同状态下保持一定探索性。

2423

被折叠的 条评论
为什么被折叠?



