verl数据预处理:训练数据集构建与清洗
概述
verl(Volcano Engine Reinforcement Learning)是一个专为大语言模型(LLM)设计的强化学习框架,数据预处理是其训练流程中的关键环节。本文将深入探讨verl的数据预处理机制,涵盖数据集构建、格式转换、质量清洗等核心内容,帮助开发者高效准备训练数据。
数据预处理架构
verl的数据预处理遵循统一的架构模式,主要包含以下核心组件:
核心数据字段解析
1. data_source字段
标识数据集的来源,用于在RewardModel中索引对应的奖励函数。
data_source = "openai/gsm8k" # 数据集标识
2. prompt字段
遵循HuggingFace chat_template格式,支持多轮对话结构:
prompt = [
{
"role": "user",
"content": "数学问题内容 Let's think step by step and output the final answer after '####'."
}
]
3. ability字段
定义任务类别,用于区分不同的能力维度:
| 能力类型 | 描述 | 适用数据集 |
|---|---|---|
| math | 数学推理能力 | GSM8K, MATH |
| alignment | 对齐能力 | Full_hh_rlhf |
| commonsense | 常识推理 | Hellaswag |
4. reward_model字段
包含奖励模型相关配置,目前主要使用ground_truth字段:
reward_model = {
"style": "rule", # 或 "model"
"ground_truth": "42" # 提取的正确答案
}
5. extra_info字段
记录额外的元数据信息:
extra_info = {
"split": "train", # 数据分割
"index": 123, # 数据索引
"question": "原始问题", # 原始问题文本
"answer": "完整答案" # 完整答案文本
}
数据预处理实战
GSM8K数据集处理
GSM8K是一个经典的数学推理数据集,verl提供了完整的预处理脚本:
import re
import os
import datasets
from verl.utils.hdfs_io import copy, makedirs
def extract_solution(solution_str):
"""从答案字符串中提取最终解"""
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
assert solution is not None
final_solution = solution.group(0)
final_solution = final_solution.split("#### ")[1].replace(",", "")
return final_solution
def make_map_fn(split):
"""构建数据映射函数"""
def process_fn(example, idx):
question_raw = example.pop("question")
question = question_raw + " Let's think step by step and output the final answer after '####'."
answer_raw = example.pop("answer")
solution = extract_solution(answer_raw)
return {
"data_source": "openai/gsm8k",
"prompt": [{"role": "user", "content": question}],
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {
"split": split,
"index": idx,
"answer": answer_raw,
"question": question_raw
}
}
return process_fn
Search-R1类数据集处理
对于需要工具调用的复杂任务,数据预处理需要额外的工具参数配置:
def process_single_row(row, current_split_name, row_index):
"""处理Search-R1格式数据"""
question = row.get("question", "")
# 构建工具调用参数
tools_kwargs = {
"search": {
"create_kwargs": {
"ground_truth": ground_truth,
"question": question,
"data_source": "searchR1_dataset"
}
}
}
return {
"data_source": "searchR1_tagged",
"prompt": prompt_structure,
"ability": row.get("ability"),
"reward_model": reward_model_data,
"extra_info": {
"index": row_index,
"need_tools_kwargs": True,
"tools_kwargs": tools_kwargs
}
}
数据质量清洗策略
1. 答案提取验证
使用正则表达式确保答案格式正确:
def validate_solution_extraction(solution_str):
"""验证答案提取的有效性"""
pattern = r"#### (\-?[0-9\.\,]+)"
match = re.search(pattern, solution_str)
if not match:
raise ValueError(f"Invalid solution format: {solution_str}")
return match.group(1).replace(",", "")
2. 数据完整性检查
确保所有必需字段都存在:
REQUIRED_FIELDS = ["data_source", "prompt", "ability", "reward_model"]
def validate_data_integrity(data_item):
"""验证数据完整性"""
missing_fields = [field for field in REQUIRED_FIELDS if field not in data_item]
if missing_fields:
raise ValueError(f"Missing required fields: {missing_fields}")
# 验证prompt格式
if not isinstance(data_item["prompt"], list) or len(data_item["prompt"]) == 0:
raise ValueError("Prompt must be a non-empty list")
3. 异常值处理
处理数值异常和格式问题:
def sanitize_numerical_value(value):
"""清理数值数据"""
try:
# 移除逗号并转换为浮点数
cleaned = float(str(value).replace(",", ""))
# 检查是否为有限数值
if not math.isfinite(cleaned):
return None
return cleaned
except (ValueError, TypeError):
return None
分布式存储支持
verl支持本地和HDFS分布式存储,提供统一的数据访问接口:
HDFS操作示例
from verl.utils.hdfs_io import copy, makedirs, exists
# 创建HDFS目录
makedirs("hdfs://path/to/dataset")
# 检查文件是否存在
if exists("hdfs://path/to/dataset/train.parquet"):
print("Dataset already exists in HDFS")
# 复制数据到HDFS
copy("~/data/gsm8k", "hdfs://path/to/dataset")
存储格式对比
| 存储方式 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| 本地存储 | 访问速度快,配置简单 | 容量有限,不易共享 | 单机开发测试 |
| HDFS存储 | 分布式,高可用,易扩展 | 配置复杂,网络依赖 | 生产环境,多节点训练 |
多数据集支持
verl目前支持的主流数据集预处理:
| 数据集 | 任务类型 | 预处理脚本 | 输出格式 |
|---|---|---|---|
| GSM8K | 数学推理 | gsm8k.py | Parquet |
| MATH | 数学推理 | math_dataset.py | Parquet |
| Hellaswag | 常识推理 | hellaswag.py | Parquet |
| Full_hh_rlhf | 对齐训练 | full_hh_rlhf.py | Parquet |
| Search-R1 | 工具调用 | preprocess_search_r1_dataset.py | Parquet |
最佳实践指南
1. 数据分区策略
# 按能力类型分区
ABILITY_PARTITIONS = {
"math": ["gsm8k", "math"],
"alignment": ["full_hh_rlhf"],
"commonsense": ["hellaswag"]
}
# 按数据来源分区
SOURCE_PARTITIONS = {
"openai": ["gsm8k"],
"meta": ["hellaswag"],
"huggingface": ["full_hh_rlhf"]
}
2. 内存优化处理
对于大规模数据集,使用迭代处理避免内存溢出:
def process_large_dataset(dataset, batch_size=1000):
"""分批处理大规模数据集"""
results = []
for i in range(0, len(dataset), batch_size):
batch = dataset.select(range(i, min(i + batch_size, len(dataset))))
processed_batch = batch.map(processing_function, batched=True)
results.append(processed_batch)
return concatenate_datasets(results)
3. 错误处理与重试机制
import time
from functools import wraps
def retry_on_failure(max_retries=3, delay=1):
"""失败重试装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except Exception as e:
if attempt == max_retries - 1:
raise
time.sleep(delay * (2 ** attempt))
return func(*args, **kwargs)
return wrapper
return decorator
@retry_on_failure()
def safe_data_processing(data_item):
"""安全的数据处理函数"""
# 处理逻辑
pass
性能优化技巧
1. 并行处理
from multiprocessing import Pool
import datasets
def parallel_process_dataset(dataset, num_processes=4):
"""并行处理数据集"""
with Pool(num_processes) as pool:
results = pool.map(processing_function, dataset)
return datasets.Dataset.from_list(results)
2. 缓存机制
import diskcache
cache = diskcache.Cache('~/.verl_cache')
@cache.memoize()
def expensive_processing(data):
"""昂贵的处理操作,使用缓存"""
# 复杂处理逻辑
return processed_data
3. 增量处理
对于持续更新的数据集,支持增量处理:
def incremental_processing(existing_data, new_data):
"""增量数据处理"""
# 检查新数据是否已存在
existing_ids = {item['extra_info']['index'] for item in existing_data}
new_items = [item for item in new_data if item['extra_info']['index'] not in existing_ids]
return existing_data + new_items
总结
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



