verl数据预处理:训练数据集构建与清洗

verl数据预处理:训练数据集构建与清洗

【免费下载链接】verl verl: Volcano Engine Reinforcement Learning for LLMs 【免费下载链接】verl 项目地址: https://gitcode.com/GitHub_Trending/ve/verl

概述

verl(Volcano Engine Reinforcement Learning)是一个专为大语言模型(LLM)设计的强化学习框架,数据预处理是其训练流程中的关键环节。本文将深入探讨verl的数据预处理机制,涵盖数据集构建、格式转换、质量清洗等核心内容,帮助开发者高效准备训练数据。

数据预处理架构

verl的数据预处理遵循统一的架构模式,主要包含以下核心组件:

mermaid

核心数据字段解析

1. data_source字段

标识数据集的来源,用于在RewardModel中索引对应的奖励函数。

data_source = "openai/gsm8k"  # 数据集标识

2. prompt字段

遵循HuggingFace chat_template格式,支持多轮对话结构:

prompt = [
    {
        "role": "user",
        "content": "数学问题内容 Let's think step by step and output the final answer after '####'."
    }
]

3. ability字段

定义任务类别,用于区分不同的能力维度:

能力类型描述适用数据集
math数学推理能力GSM8K, MATH
alignment对齐能力Full_hh_rlhf
commonsense常识推理Hellaswag

4. reward_model字段

包含奖励模型相关配置,目前主要使用ground_truth字段:

reward_model = {
    "style": "rule",  # 或 "model"
    "ground_truth": "42"  # 提取的正确答案
}

5. extra_info字段

记录额外的元数据信息:

extra_info = {
    "split": "train",  # 数据分割
    "index": 123,      # 数据索引
    "question": "原始问题",  # 原始问题文本
    "answer": "完整答案"    # 完整答案文本
}

数据预处理实战

GSM8K数据集处理

GSM8K是一个经典的数学推理数据集,verl提供了完整的预处理脚本:

import re
import os
import datasets
from verl.utils.hdfs_io import copy, makedirs

def extract_solution(solution_str):
    """从答案字符串中提取最终解"""
    solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
    assert solution is not None
    final_solution = solution.group(0)
    final_solution = final_solution.split("#### ")[1].replace(",", "")
    return final_solution

def make_map_fn(split):
    """构建数据映射函数"""
    def process_fn(example, idx):
        question_raw = example.pop("question")
        question = question_raw + " Let's think step by step and output the final answer after '####'."
        
        answer_raw = example.pop("answer")
        solution = extract_solution(answer_raw)
        
        return {
            "data_source": "openai/gsm8k",
            "prompt": [{"role": "user", "content": question}],
            "ability": "math",
            "reward_model": {"style": "rule", "ground_truth": solution},
            "extra_info": {
                "split": split,
                "index": idx,
                "answer": answer_raw,
                "question": question_raw
            }
        }
    return process_fn

Search-R1类数据集处理

对于需要工具调用的复杂任务,数据预处理需要额外的工具参数配置:

def process_single_row(row, current_split_name, row_index):
    """处理Search-R1格式数据"""
    question = row.get("question", "")
    
    # 构建工具调用参数
    tools_kwargs = {
        "search": {
            "create_kwargs": {
                "ground_truth": ground_truth,
                "question": question,
                "data_source": "searchR1_dataset"
            }
        }
    }
    
    return {
        "data_source": "searchR1_tagged",
        "prompt": prompt_structure,
        "ability": row.get("ability"),
        "reward_model": reward_model_data,
        "extra_info": {
            "index": row_index,
            "need_tools_kwargs": True,
            "tools_kwargs": tools_kwargs
        }
    }

数据质量清洗策略

1. 答案提取验证

使用正则表达式确保答案格式正确:

def validate_solution_extraction(solution_str):
    """验证答案提取的有效性"""
    pattern = r"#### (\-?[0-9\.\,]+)"
    match = re.search(pattern, solution_str)
    if not match:
        raise ValueError(f"Invalid solution format: {solution_str}")
    return match.group(1).replace(",", "")

2. 数据完整性检查

确保所有必需字段都存在:

REQUIRED_FIELDS = ["data_source", "prompt", "ability", "reward_model"]

def validate_data_integrity(data_item):
    """验证数据完整性"""
    missing_fields = [field for field in REQUIRED_FIELDS if field not in data_item]
    if missing_fields:
        raise ValueError(f"Missing required fields: {missing_fields}")
    
    # 验证prompt格式
    if not isinstance(data_item["prompt"], list) or len(data_item["prompt"]) == 0:
        raise ValueError("Prompt must be a non-empty list")

3. 异常值处理

处理数值异常和格式问题:

def sanitize_numerical_value(value):
    """清理数值数据"""
    try:
        # 移除逗号并转换为浮点数
        cleaned = float(str(value).replace(",", ""))
        # 检查是否为有限数值
        if not math.isfinite(cleaned):
            return None
        return cleaned
    except (ValueError, TypeError):
        return None

分布式存储支持

verl支持本地和HDFS分布式存储,提供统一的数据访问接口:

HDFS操作示例

from verl.utils.hdfs_io import copy, makedirs, exists

# 创建HDFS目录
makedirs("hdfs://path/to/dataset")

# 检查文件是否存在
if exists("hdfs://path/to/dataset/train.parquet"):
    print("Dataset already exists in HDFS")

# 复制数据到HDFS
copy("~/data/gsm8k", "hdfs://path/to/dataset")

存储格式对比

存储方式优点缺点适用场景
本地存储访问速度快,配置简单容量有限,不易共享单机开发测试
HDFS存储分布式,高可用,易扩展配置复杂,网络依赖生产环境,多节点训练

多数据集支持

verl目前支持的主流数据集预处理:

数据集任务类型预处理脚本输出格式
GSM8K数学推理gsm8k.pyParquet
MATH数学推理math_dataset.pyParquet
Hellaswag常识推理hellaswag.pyParquet
Full_hh_rlhf对齐训练full_hh_rlhf.pyParquet
Search-R1工具调用preprocess_search_r1_dataset.pyParquet

最佳实践指南

1. 数据分区策略

# 按能力类型分区
ABILITY_PARTITIONS = {
    "math": ["gsm8k", "math"],
    "alignment": ["full_hh_rlhf"],
    "commonsense": ["hellaswag"]
}

# 按数据来源分区
SOURCE_PARTITIONS = {
    "openai": ["gsm8k"],
    "meta": ["hellaswag"],
    "huggingface": ["full_hh_rlhf"]
}

2. 内存优化处理

对于大规模数据集,使用迭代处理避免内存溢出:

def process_large_dataset(dataset, batch_size=1000):
    """分批处理大规模数据集"""
    results = []
    for i in range(0, len(dataset), batch_size):
        batch = dataset.select(range(i, min(i + batch_size, len(dataset))))
        processed_batch = batch.map(processing_function, batched=True)
        results.append(processed_batch)
    return concatenate_datasets(results)

3. 错误处理与重试机制

import time
from functools import wraps

def retry_on_failure(max_retries=3, delay=1):
    """失败重试装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == max_retries - 1:
                        raise
                    time.sleep(delay * (2 ** attempt))
            return func(*args, **kwargs)
        return wrapper
    return decorator

@retry_on_failure()
def safe_data_processing(data_item):
    """安全的数据处理函数"""
    # 处理逻辑
    pass

性能优化技巧

1. 并行处理

from multiprocessing import Pool
import datasets

def parallel_process_dataset(dataset, num_processes=4):
    """并行处理数据集"""
    with Pool(num_processes) as pool:
        results = pool.map(processing_function, dataset)
    return datasets.Dataset.from_list(results)

2. 缓存机制

import diskcache

cache = diskcache.Cache('~/.verl_cache')

@cache.memoize()
def expensive_processing(data):
    """昂贵的处理操作,使用缓存"""
    # 复杂处理逻辑
    return processed_data

3. 增量处理

对于持续更新的数据集,支持增量处理:

def incremental_processing(existing_data, new_data):
    """增量数据处理"""
    # 检查新数据是否已存在
    existing_ids = {item['extra_info']['index'] for item in existing_data}
    new_items = [item for item in new_data if item['extra_info']['index'] not in existing_ids]
    
    return existing_data + new_items

总结

【免费下载链接】verl verl: Volcano Engine Reinforcement Learning for LLMs 【免费下载链接】verl 项目地址: https://gitcode.com/GitHub_Trending/ve/verl

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值