verl SGLang Multi-Turn Code Walkthrough: An End-to-End Implementation Deep Dive
Introduction
In reinforcement learning training for large language models (LLMs), multi-turn conversation is a key technique for enabling complex reasoning and tool-calling capabilities. The verl (Volcano Engine Reinforcement Learning) framework integrates the SGLang inference engine to provide powerful multi-turn rollout capabilities. This article walks through the end-to-end implementation of SGLang multi-turn conversation in verl, helping developers understand its core architecture and how it works.
Architecture Overview
verl's SGLang multi-turn conversation system uses a layered architecture. The main components are:
- SGLangRollout: the rollout worker that owns the SGLang engine, the tool system, and the interaction system
- AsyncRolloutRequest: the per-request object whose state machine drives each conversation
- BaseTool and its subclasses: the pluggable tool system
- The interaction system (interaction_map): produces simulated user turns via generate_response
- FunctionCallParser: detects and parses tool calls in generated text
Core Class Relationships
The Multi-Turn Conversation State Machine
At the core of SGLang multi-turn rollout is a carefully designed state machine that manages the complete lifecycle of a conversation:
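For orientation, here is a minimal sketch of the request states this walkthrough refers to. The member names follow the code excerpts and the summary below; the concrete string values are illustrative, and the actual enum in verl may carry additional states or metadata.

```python
from enum import Enum

class AsyncRolloutRequestStateEnum(str, Enum):
    """Lifecycle states of a rollout request (simplified sketch; values are illustrative)."""
    PENDING = "pending"            # request created, not yet submitted to the engine
    RUNNING = "running"            # generating (or about to generate) the next assistant turn
    TOOL_CALLING = "tool_calling"  # assistant output contained tool calls; executing them
    INTERACTING = "interacting"    # asking the interaction (simulated user) for the next turn
    COMPLETED = "completed"        # terminal: finish reason reached or turn budget exhausted
```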
Core Implementation Walkthrough
1. SGLangRollout Initialization
The SGLangRollout constructor establishes the distributed environment, validates the configuration, and initializes the inference engine and the tool system:
```python
def __init__(self, actor_module, config, processing_class, model_hf_config, **kwargs):
    super().__init__()
    self.config = config
    # Note: this excerpt is abridged; device_mesh, trust_remote_code and port
    # are derived from config / kwargs in the full implementation.
    # Initialize the tool system
    (self._tool_schemas, self._tool_map, self._tool_call_parser_type,
     self._sgl_tools, self._function_call_parser) = self._initialize_tools(config, processing_class)
    # Initialize the interaction system
    self.interaction_map = self._initialize_interactions(config)
    # Set up the distributed environment
    self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs)
    # Validate the model configuration
    self._verify_config(model_hf_config=model_hf_config)
    # Initialize the SGLang inference engine
    self._init_inference_engine(trust_remote_code, actor_module, port)
    # Initialize sampling parameters
    self._init_sampling_params(**kwargs)
```
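Each prompt is wrapped into an AsyncRolloutRequest that carries per-conversation state and exposes helpers such as add_assistant_message, add_user_message, and add_tool_response_messages. A simplified stand-in showing only the fields referenced in the rest of this walkthrough (the real class in verl tracks considerably more, e.g. token ids and loss masks) might look like:

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class AsyncRolloutRequestSketch:
    """Simplified stand-in for verl's AsyncRolloutRequest (illustration only)."""
    request_id: str                        # unique id, also passed to tools as the instance id
    state: AsyncRolloutRequestStateEnum    # see the state enum sketched above
    messages: list[Any] = field(default_factory=list)                 # chat history (system/user/assistant/tool)
    tools_kwargs: dict[str, Any] = field(default_factory=dict)        # per-tool kwargs, e.g. {"tool_name": {"execute_kwargs": {...}}}
    interaction_kwargs: dict[str, Any] = field(default_factory=dict)  # arguments for the interaction handler, e.g. {"name": "gsm8k", ...}
```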
2. The Core Multi-Turn Loop
_async_rollout_a_request is the heart of multi-turn rollout and implements the complete state-machine logic:
```python
async def _async_rollout_a_request(self, req, do_sample=True, is_validate=False, **kwargs):
    current_turns = 0
    user_turns = 0
    user_turn_rewards = []
    # request_sampling_params is assembled from self.sampling_params and kwargs (omitted here)
    while current_turns < self.config.multi_turn.max_assistant_turns:
        if req.state == AsyncRolloutRequestStateEnum.PENDING:
            await self._handle_pending_state(req)
            req.state = AsyncRolloutRequestStateEnum.RUNNING
        elif req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING:
            # Execute all tool calls from the last assistant message concurrently
            parsed_tool_calls = req.messages[-1].tool_calls
            tool_call_results = await asyncio.gather(*[
                self._tool_map[tool_call.function.name].execute(
                    req.request_id,
                    tool_call.function.arguments,
                    **req.tools_kwargs.get(tool_call.function.name, {}).get("execute_kwargs", {})
                )
                for tool_call in parsed_tool_calls
            ])
            # Append the tool responses to the conversation and go back to generation
            req.add_tool_response_messages(self.processing_class, [resp for resp, _, _ in tool_call_results])
            req.state = AsyncRolloutRequestStateEnum.RUNNING
        elif req.state == AsyncRolloutRequestStateEnum.RUNNING:
            # Generate the next assistant turn with the SGLang engine
            output = await self._handle_engine_call(req, request_sampling_params)
            content = output["text"]
            # finish_reason_type (e.g. STOP / LENGTH) also bounds the loop in the full implementation
            finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"])
            current_turns += 1
            if self._function_call_parser and self._function_call_parser.has_tool_call(content):
                # Tool call detected: parse it and switch to TOOL_CALLING
                req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING
                normed_content, tool_calls = self._function_call_parser.parse_non_stream(content)
                req.add_assistant_message(self.processing_class, normed_content, tool_calls=tool_calls)
            else:
                req.add_assistant_message(self.processing_class, content)
                # Simplified: if an interaction is configured for this request, hand the
                # turn to the simulated user; otherwise the conversation is finished.
                if self.interaction_map and req.interaction_kwargs:
                    req.state = AsyncRolloutRequestStateEnum.INTERACTING
                else:
                    break
        elif req.state == AsyncRolloutRequestStateEnum.INTERACTING:
            # Ask the interaction (simulated user) for the next user turn
            user_turns += 1
            interaction_name = req.interaction_kwargs.get("name", "gsm8k")
            interaction = self.interaction_map[interaction_name]
            messages = [msg.model_dump() for msg in req.messages]  # chat history as plain dicts (simplified)
            should_terminate_sequence, content, reward, metrics = await interaction.generate_response(
                req.request_id, messages, **req.interaction_kwargs
            )
            user_turn_rewards.append(reward)
            if should_terminate_sequence:
                break
            else:
                req.add_user_message(self.processing_class, content)
                req.state = AsyncRolloutRequestStateEnum.RUNNING
```
3. The Tool System
verl's tool system is built around the abstract BaseTool class and is designed for easy extension; each tool instance goes through a create → execute → calc_reward → release lifecycle:
```python
from uuid import uuid4

class BaseTool:
    """Base class defining the standard tool interface.
    ToolResponse is verl's structured tool-output type (text / image / video fields)."""

    async def create(self, instance_id=None, **kwargs) -> tuple[str, ToolResponse]:
        """Create a tool instance."""
        if instance_id is None:
            return str(uuid4()), ToolResponse()
        return instance_id, ToolResponse()

    async def execute(self, instance_id, parameters, **kwargs) -> tuple[ToolResponse, float, dict]:
        """Execute a tool call."""
        return ToolResponse(text="Updated the tool state."), 0.0, {}

    async def calc_reward(self, instance_id, **kwargs) -> float:
        """Compute the tool reward."""
        return 0.0

    async def release(self, instance_id, **kwargs) -> None:
        """Release the tool instance."""
        pass
```
A concrete implementation, using the GSM8K math tool as an example:
```python
from verl.utils.reward_score import gsm8k

class Gsm8kTool(BaseTool):
    """Evaluation tool for GSM8K math problems."""

    async def execute(self, instance_id, parameters, **kwargs) -> tuple[ToolResponse, float, dict]:
        # Normalize the answer to GSM8K's "#### <answer>" format
        answer = parameters.get("answer", "")
        if answer.startswith("#### "):
            self._instance_dict[instance_id]["response"] = answer
        else:
            self._instance_dict[instance_id]["response"] = "#### " + answer
        reward = await self.calc_reward(instance_id)
        # Penalize tool calls that do not improve on the best reward seen so far
        tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05
        self._instance_dict[instance_id]["reward"] = reward
        return ToolResponse(text=f"Current parsed {answer=} {reward=}"), tool_reward, {}

    async def calc_reward(self, instance_id, **kwargs) -> float:
        return gsm8k.compute_score(
            self._instance_dict[instance_id]["response"],
            self._instance_dict[instance_id]["ground_truth"],
            method="flexible",
            format_score=0.0,
            score=1.0,
        )
```
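Writing a custom tool therefore only requires subclassing BaseTool and overriding execute (plus create, calc_reward, or release where needed). Below is a minimal, hypothetical sketch; the tool name, its constructor signature, and the word-counting logic are made up for illustration, and it reuses the ToolResponse type and uuid4 import from the BaseTool excerpt above.

```python
class WordCountTool(BaseTool):
    """Hypothetical tool that counts the words in a text snippet submitted by the model."""

    def __init__(self, config, tool_schema):
        # Keep per-request state keyed by instance_id, mirroring the pattern above
        self.config = config
        self.tool_schema = tool_schema
        self._instance_dict = {}

    async def create(self, instance_id=None, **kwargs):
        instance_id = instance_id or str(uuid4())
        self._instance_dict[instance_id] = {"last_count": 0}
        return instance_id, ToolResponse()

    async def execute(self, instance_id, parameters, **kwargs):
        text = parameters.get("text", "")
        count = len(text.split())
        self._instance_dict[instance_id]["last_count"] = count
        # Return (tool response, step reward, metrics)
        return ToolResponse(text=f"word_count={count}"), 0.0, {"word_count": count}

    async def release(self, instance_id, **kwargs):
        self._instance_dict.pop(instance_id, None)
```

Such a tool would then be registered through the tools_config_file YAML shown in the configuration section below.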
4. Distributed Processing and Performance Optimization
verl applies several performance optimizations to multi-turn rollout:
Request-Level Parallelism
```python
def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
    """Request-level parallel generation with tool-calling support (abridged excerpt)."""
    if self._tp_rank == 0:
        # Only TP rank 0 drives the async rollout; the other ranks wait for the broadcast below
        req_list = self._preprocess_prompt_to_async_rollout_requests(prompts)
        # Training mode: over-sampling allows stragglers to be cancelled once
        # target_completion requests have finished
        total_requests = len(req_list)
        target_completion = int(total_requests * (1 - self.config.get("over_sample_rate", 0.0)))
        # Await all per-request coroutines concurrently on the event loop
        output_req_list = loop.run_until_complete(run_with_cancellation())
    # Broadcast the (sorted) results so every TP rank holds the same data
    [sorted_output_req_list] = broadcast_pyobj(...)
```
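The key idea is that every request is an independent coroutine, so a slow tool call or a long generation in one conversation does not block the others. The pattern can be illustrated with a small, standalone sketch in which process_one_request stands in for _async_rollout_a_request:

```python
import asyncio

async def process_one_request(req_id: int) -> str:
    # Stand-in for _async_rollout_a_request: awaiting the engine or a tool call
    # yields control, so the other requests keep making progress.
    await asyncio.sleep(0.01 * (req_id % 3))
    return f"request-{req_id} done"

async def rollout_all(num_requests: int) -> list[str]:
    tasks = [process_one_request(i) for i in range(num_requests)]
    return await asyncio.gather(*tasks)

if __name__ == "__main__":
    results = asyncio.run(rollout_all(8))
    print(results)
```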
Memory Optimization
```
# Multi-stage wake-up to reduce memory usage
actor_rollout_ref.rollout.multi_stage_wake_up = True
# Limit on the fraction of GPU memory the inference engine may use
actor_rollout_ref.rollout.gpu_memory_utilization = 0.85
# Memory optimization for weight updates (bucket size in MB)
actor_rollout_ref.rollout.update_weights_bucket_megabytes = 512
```
Configuration Examples and Best Practices
Multi-Turn Configuration
```yaml
actor_rollout_ref:
  rollout:
    name: sglang
    multi_turn:
      enable: True
      max_assistant_turns: 5
      tokenization_sanity_check_mode: "ignore_strippable"
      tool_kwargs:
        tools_config_file: examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml
```
Tool Configuration Example
```yaml
tools:
  - class_name: "verl.tools.gsm8k_tool.Gsm8kTool"
    config:
      type: native
    tool_schema:
      type: "function"
      function:
        name: "calc_gsm8k_reward"
        description: "A tool for evaluating answers to GSM8K math problems"
        parameters:
          type: "object"
          properties:
            answer:
              type: "string"
              description: "The model's answer to the GSM8K math problem"
          required: ["answer"]
```
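Conceptually, verl turns each entry of this file into a live tool instance by importing class_name and constructing it with the config and tool_schema blocks, keyed by the function name the model will call. The following is a rough, hypothetical sketch of that pattern, not verl's actual loader (whose details differ):

```python
import importlib
import yaml

def load_tools(tools_config_file: str) -> dict:
    """Instantiate tools from a YAML file following the format shown above (illustrative)."""
    with open(tools_config_file) as f:
        tool_config = yaml.safe_load(f)

    tool_map = {}
    for entry in tool_config["tools"]:
        module_name, class_name = entry["class_name"].rsplit(".", 1)
        tool_cls = getattr(importlib.import_module(module_name), class_name)
        tool = tool_cls(config=entry.get("config", {}), tool_schema=entry["tool_schema"])
        # Key by the callable function name, e.g. "calc_gsm8k_reward"
        tool_map[entry["tool_schema"]["function"]["name"]] = tool
    return tool_map
```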
Example Run Scripts
```bash
# 8-GPU training script
bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh
# 4-GPU training script
bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh
```
Key Technical Challenges and Solutions
1. Tokenization Consistency
Keeping tokenization consistent across turns is an important challenge in multi-turn conversation. verl uses a delta-based tokenization strategy:
```python
# Delta tokenization: only tokens newly produced by the assistant contribute to the loss
prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False)
curr = tokenizer.apply_chat_template(messages[:i + 1], add_generation_prompt=False, tokenize=False)
delta_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)
token_ids += delta_ids
loss_mask += [1] * len(delta_ids)  # mask in only the newly added assistant tokens
```
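Putting this together, a self-contained sketch of the delta strategy over a whole conversation might look like the following, where loss_mask is 1 only for assistant-generated tokens and 0 for everything injected by the user or a tool. This is an illustration of the idea rather than verl's exact code:

```python
def delta_tokenize(tokenizer, messages):
    """Incrementally tokenize a multi-turn chat; loss_mask is 1 only on assistant tokens."""
    token_ids, loss_mask = [], []
    for i, msg in enumerate(messages):
        if i == 0:
            prev = ""
        else:
            # For assistant turns, include the generation prompt in the prefix so the
            # role header is not attributed to the assistant itself.
            prev = tokenizer.apply_chat_template(
                messages[:i], add_generation_prompt=(msg["role"] == "assistant"), tokenize=False
            )
        curr = tokenizer.apply_chat_template(messages[:i + 1], add_generation_prompt=False, tokenize=False)
        delta_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)
        token_ids += delta_ids
        loss_mask += [1 if msg["role"] == "assistant" else 0] * len(delta_ids)
    return token_ids, loss_mask
```

Note that this relies on each rendered prefix being a string prefix of the next render, which chat templates do not always guarantee; the tokenization_sanity_check_mode option seen in the configuration above exists to help detect such mismatches.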
2. Multimodal Support
verl supports multimodal inputs such as images and video:
```python
async def execute(self, instance_id, parameters, **kwargs) -> tuple[ToolResponse, float, dict]:
    from verl.utils.dataset.vision_utils import process_image, process_video

    # img1 / video1 are produced by the tool's own logic (elided here)
    img1 = process_image(img1)
    video1 = process_video(video1)
    # Because vLLM uses the singular field names ("image" | "video") rather than ("images" | "videos")
    return ToolResponse(image=[img1, ...], video=[video1, ...], text="..."), 0, {}
```
3. Distributed Consistency
A broadcast ensures that every TP rank receives identical results:
```python
# Broadcast the processed requests from the source TP rank to all TP ranks
[sorted_output_req_list] = broadcast_pyobj(
    data=[sorted_output_req_list],
    rank=self._rank,
    dist_group=self._device_mesh_cpu["tp"].get_group(),
    src=self._device_mesh_cpu["tp"].mesh[0].item(),
    force_cpu_device=False,
)
```
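broadcast_pyobj here is a helper for broadcasting arbitrary Python objects across ranks. The same pattern can be reproduced with plain torch.distributed, as in this sketch, which assumes the default process group is already initialized and plays the role of the TP group:

```python
import torch.distributed as dist

def broadcast_rollout_results(results, src_rank: int = 0):
    """Share src_rank's Python object with every rank in the default process group.

    Only src_rank needs to pass real data; the other ranks pass a placeholder
    and receive the broadcast object in place.
    """
    obj_list = [results if dist.get_rank() == src_rank else None]
    dist.broadcast_object_list(obj_list, src=src_rank)
    return obj_list[0]
```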
Performance Tuning Tips
Memory-Related Settings
| Setting | Recommended value | Description |
|---|---|---|
| gpu_memory_utilization | 0.85 | GPU memory utilization of the inference engine |
| update_weights_bucket_megabytes | 512 | Bucket size (MB) for weight updates |
| multi_stage_wake_up | True | Multi-stage wake-up optimization |
Batching Settings
```yaml
data:
  train_batch_size: 256
  max_prompt_length: 1024
  max_response_length: 1024
  filter_overlong_prompts: True
actor_rollout_ref:
  actor:
    ppo_mini_batch_size: 256
    ppo_micro_batch_size_per_gpu: 32
  rollout:
    log_prob_micro_batch_size_per_gpu: 32
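```

As a rough sanity check on these numbers: the micro-batch size only controls how each mini-batch is split for gradient accumulation, so ppo_mini_batch_size should be divisible by ppo_micro_batch_size_per_gpu times the number of GPUs. A small, hypothetical calculation, assuming the 8-GPU setup from the run script above:

```python
n_gpus = 8                          # assumption: the 8-GPU example setup
train_batch_size = 256              # prompts collected per training step
ppo_mini_batch_size = 256           # global batch per PPO update
ppo_micro_batch_size_per_gpu = 32   # forward/backward chunk per GPU

mini_batches_per_step = train_batch_size // ppo_mini_batch_size                     # 1 PPO update per step
grad_accum_steps = ppo_mini_batch_size // (ppo_micro_batch_size_per_gpu * n_gpus)   # 256 // (32 * 8) = 1
print(mini_batches_per_step, grad_accum_steps)
```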
Summary
verl's SGLang multi-turn conversation system combines a carefully designed architecture with a pragmatic implementation to deliver strong multi-turn conversation and tool-calling capabilities. Its key features are:
- Complete state-machine management: supports the PENDING, RUNNING, TOOL_CALLING, INTERACTING, and COMPLETED states
- Flexible tool system: built on the BaseTool abstract class and open to arbitrary custom tools
- Distributed optimization: supports multi-node training and performance optimization
- Multimodal support: full support for image and video inputs
- Performance tuning: multiple memory and compute optimization knobs
With a solid understanding of how verl implements SGLang multi-turn rollout, developers can make better use of the framework for LLM reinforcement learning and build complex multi-turn reasoning and tool-calling behaviors.
Note: this article is based on an analysis of recent verl code; details may change across versions. Refer to the official documentation and example code for up-to-date information.



