verl SGLang多轮代码走读:端到端实现解析

verl SGLang多轮代码走读:端到端实现解析

【免费下载链接】verl verl: Volcano Engine Reinforcement Learning for LLMs 【免费下载链接】verl 项目地址: https://gitcode.com/GitHub_Trending/ve/verl

引言

在大语言模型(LLM)的强化学习训练中,多轮对话(Multi-turn Conversation)是实现复杂推理和工具调用能力的关键技术。verl(Volcano Engine Reinforcement Learning)框架通过集成SGLang推理引擎,提供了强大的多轮对话rollout能力。本文将深入解析verl中SGLang多轮对话的端到端实现,帮助开发者理解其核心架构和工作原理。

整体架构概览

verl的SGLang多轮对话系统采用分层架构设计,主要包含以下核心组件:

mermaid

核心类关系

mermaid

多轮对话状态机

SGLang多轮对话的核心是一个精心设计的状态机,管理着对话的完整生命周期:

mermaid

核心代码实现解析

1. SGLangRollout初始化

SGLangRollout类的初始化过程负责建立分布式环境、验证配置、初始化推理引擎和工具系统:

def __init__(self, actor_module, config, processing_class, model_hf_config, **kwargs):
    super().__init__()
    self.config = config
    
    # 初始化工具系统
    (self._tool_schemas, self._tool_map, self._tool_call_parser_type, 
     self._sgl_tools, self._function_call_parser) = self._initialize_tools(config, processing_class)
    
    # 初始化交互系统
    self.interaction_map = self._initialize_interactions(config)
    
    # 建立分布式环境
    self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs)
    
    # 验证模型配置
    self._verify_config(model_hf_config=model_hf_config)
    
    # 初始化SGLang推理引擎
    self._init_inference_engine(trust_remote_code, actor_module, port)
    
    # 初始化采样参数
    self._init_sampling_params(**kwargs)

2. 多轮对话核心循环

_async_rollout_a_request方法是多轮对话的核心,实现了完整的状态机逻辑:

async def _async_rollout_a_request(self, req, do_sample=True, is_validate=False, **kwargs):
    current_turns = 0
    user_turns = 0
    user_turn_rewards = []
    
    while current_turns < self.config.multi_turn.max_assistant_turns:
        if req.state == AsyncRolloutRequestStateEnum.PENDING:
            await self._handle_pending_state(req)
            req.state = AsyncRolloutRequestStateEnum.RUNNING
            
        elif req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING:
            # 处理工具调用
            parsed_tool_calls = req.messages[-1].tool_calls
            tool_call_results = await asyncio.gather(*[
                self._tool_map[tool_call.function.name].execute(
                    req.request_id,
                    tool_call.function.arguments,
                    **req.tools_kwargs.get(tool_call.function.name, {}).get("execute_kwargs", {})
                )
                for tool_call in parsed_tool_calls
            ])
            req.add_tool_response_messages(self.processing_class, [resp for resp, _, _ in tool_call_results])
            req.state = AsyncRolloutRequestStateEnum.RUNNING
            
        elif req.state == AsyncRolloutRequestStateEnum.RUNNING:
            # 生成模型响应
            output = await self._handle_engine_call(req, request_sampling_params)
            content = output["text"]
            finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"])
            
            if self._function_call_parser and self._function_call_parser.has_tool_call(content):
                # 检测到工具调用
                req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING
                normed_content, tool_calls = self._function_call_parser.parse_non_stream(content)
                req.add_assistant_message(self.processing_class, normed_content, tool_calls=tool_calls)
            else:
                req.add_assistant_message(self.processing_class, content)
                
        elif req.state == AsyncRolloutRequestStateEnum.INTERACTING:
            # 处理用户交互
            user_turns += 1
            interaction_name = req.interaction_kwargs.get("name", "gsm8k")
            interaction = self.interaction_map[interaction_name]
            should_terminate_sequence, content, reward, metrics = await interaction.generate_response(
                req.request_id, messages, **req.interaction_kwargs
            )
            user_turn_rewards.append(reward)
            
            if should_terminate_sequence:
                break
            else:
                req.add_user_message(self.processing_class, content)
                req.state = AsyncRolloutRequestStateEnum.RUNNING

3. 工具系统实现

verl的工具系统基于抽象的BaseTool类,支持灵活的扩展:

class BaseTool:
    """工具基类,定义标准接口"""
    
    async def create(self, instance_id=None, **kwargs) -> tuple[str, ToolResponse]:
        """创建工具实例"""
        if instance_id is None:
            return str(uuid4()), ToolResponse()
        return instance_id, ToolResponse()

    async def execute(self, instance_id, parameters, **kwargs) -> tuple[ToolResponse, float, dict]:
        """执行工具调用"""
        return ToolResponse(text="Updated the tool state."), 0.0, {}

    async def calc_reward(self, instance_id, **kwargs) -> float:
        """计算奖励"""
        return 0.0

    async def release(self, instance_id, **kwargs) -> None:
        """释放工具实例"""
        pass

以GSM8K数学工具为例的具体实现:

class Gsm8kTool(BaseTool):
    """GSM8K数学问题评估工具"""
    
    async def execute(self, instance_id, parameters, **kwargs) -> tuple[ToolResponse, float, dict]:
        answer = parameters.get("answer", "")
        if answer.startswith("#### "):
            self._instance_dict[instance_id]["response"] = answer
        else:
            self._instance_dict[instance_id]["response"] = "#### " + answer
        
        reward = await self.calc_reward(instance_id)
        tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05
        self._instance_dict[instance_id]["reward"] = reward
        
        return ToolResponse(text=f"Current parsed {answer=} {reward=}"), tool_reward, {}

    async def calc_reward(self, instance_id, **kwargs) -> float:
        return gsm8k.compute_score(
            self._instance_dict[instance_id]["response"],
            self._instance_dict[instance_id]["ground_truth"],
            method="flexible",
            format_score=0.0,
            score=1.0,
        )

4. 分布式处理与性能优化

verl在多轮对话中采用了多种性能优化策略:

请求级并行处理
def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
    """请求级并行生成,支持工具调用控制"""
    if self._tp_rank == 0:
        req_list = self._preprocess_prompt_to_async_rollout_requests(prompts)
        
        # 训练模式:支持提前终止优化
        total_requests = len(req_list)
        target_completion = int(total_requests * (1 - self.config.get("over_sample_rate", 0.0)))
        
        # 使用asyncio.gather并行处理请求
        output_req_list = loop.run_until_complete(run_with_cancellation())
    
    # 分布式广播结果
    [sorted_output_req_list] = broadcast_pyobj(...)
内存优化策略
# 支持多阶段唤醒,减少内存占用
actor_rollout_ref.rollout.multi_stage_wake_up = True

# GPU内存利用率控制
actor_rollout_ref.rollout.gpu_memory_utilization = 0.85

# 权重更新内存优化
actor_rollout_ref.rollout.update_weights_bucket_megabytes = 512

配置示例与最佳实践

多轮对话配置

actor_rollout_ref:
  rollout:
    name: sglang
    multi_turn:
      enable: True
      max_assistant_turns: 5
      tokenization_sanity_check_mode: "ignore_strippable"
    tool_kwargs:
      tools_config_file: examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml

工具配置示例

tools:
  - class_name: "verl.tools.gsm8k_tool.Gsm8kTool"
    config: 
      type: native
    tool_schema:
      type: "function"
      function:
        name: "calc_gsm8k_reward"
        description: "GSM8K数学问题评估工具"
        parameters:
          type: "object"
          properties:
            answer:
              type: "string"
              description: "模型对GSM8K数学问题的答案"
          required: ["answer"]

运行脚本示例

# 8 GPU训练脚本
bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh

# 4 GPU训练脚本  
bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh

关键技术挑战与解决方案

1. 分词一致性挑战

多轮对话中的分词一致性是一个重要挑战。verl采用基于差值的分词策略

# 差值分词策略确保只对助手生成的内容计算损失
prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False)
curr = tokenizer.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False)
token_ids += tokenizer.encode(curr[len(prev):], add_special_tokens=False)
loss_mask += [1] * len(token_ids)  # 只掩码新的助手token

2. 多模态支持

verl支持图像和视频等多模态输入:

async def execute(self, instance_id, parameters, **kwargs) -> Tuple[str | Dict[str, Any], float, dict]:
    from verl.utils.dataset.vision_utils import process_image, process_video
    
    img1 = process_image(img1)
    video1 = process_video(video1)
    
    # 由于vllm中使用("image"|"video")而不是("images"|"videos")
    return ToolResponse(image=[img1, ...], video=[video1, ...], text="..."), 0, {}

3. 分布式一致性

通过广播机制确保所有TP rank获得一致的结果:

# 广播处理结果
[sorted_output_req_list] = broadcast_pyobj(
    data=[sorted_output_req_list],
    rank=self._rank,
    dist_group=self._device_mesh_cpu["tp"].get_group(),
    src=self._device_mesh_cpu["tp"].mesh[0].item(),
    force_cpu_device=False,
)

性能调优建议

内存优化配置

配置项推荐值说明
gpu_memory_utilization0.85GPU内存利用率
update_weights_bucket_megabytes512权重更新内存桶大小
multi_stage_wake_upTrue多阶段唤醒优化

批量处理优化

data:
  train_batch_size: 256
  max_prompt_length: 1024
  max_response_length: 1024
  filter_overlong_prompts: True

actor_rollout_ref:
  actor:
    ppo_mini_batch_size: 256
    ppo_micro_batch_size_per_gpu: 32
  rollout:
    log_prob_micro_batch_size_per_gpu: 32

总结

verl的SGLang多轮对话系统通过精心设计的架构和实现,提供了强大的多轮对话和工具调用能力。其核心特点包括:

  1. 完整的状态机管理:支持PENDING、RUNNING、TOOL_CALLING、INTERACTING、COMPLETED等多种状态
  2. 灵活的工具系统:基于BaseTool抽象类,支持各种自定义工具
  3. 分布式优化:支持多节点训练和性能优化
  4. 多模态支持:完整支持图像、视频等多模态输入
  5. 性能调优:提供多种内存和计算优化策略

通过深入理解verl SGLang多轮对话的实现细节,开发者可以更好地利用这一强大框架进行大语言模型的强化学习训练,实现复杂的多轮推理和工具调用能力。

注意:本文基于verl最新代码分析,具体实现可能随版本更新而变化。建议参考官方文档和示例代码获取最新信息。

【免费下载链接】verl verl: Volcano Engine Reinforcement Learning for LLMs 【免费下载链接】verl 项目地址: https://gitcode.com/GitHub_Trending/ve/verl

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值