In streaming mode, how are the speculative decoding metrics in each chunk updated? When one request produces multiple chunks, how is the record in the final chunk updated? Please give an example.

# How speculative decoding metrics are updated during streaming generation

During streaming generation, the speculative decoding metrics accumulate chunk by chunk, and the last chunk carries the complete statistics for the request.
## Core update mechanism

### 1. Real-time updates to `spec_verify_ct`

After each speculative decoding verification pass, `spec_verify_ct` is incremented immediately [1]:

```python
# In the EAGLE speculative decoding verification step
req.spec_verify_ct += 1
req.spec_accepted_tokens += (sum(1 for idx in accept_index_row if idx != -1) - 1)
```
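To make the arithmetic concrete, here is a minimal sketch of one such update. The contents of `accept_index_row` are made up for illustration; rejected draft positions are marked with -1, and the `- 1` appears to exclude the one token that is always sampled by the target model itself (compare the comment on `num_guess_tokens` in the citations below):

```python
# Hypothetical verification result for one speculative step:
# accepted positions keep a token index, rejected positions are -1.
accept_index_row = [101, 102, 103, -1, -1]

spec_verify_ct = 0
spec_accepted_tokens = 0

# Mirror of the update above: one more verification pass,
# plus the number of accepted draft tokens in this pass.
spec_verify_ct += 1
spec_accepted_tokens += sum(1 for idx in accept_index_row if idx != -1) - 1

print(spec_verify_ct, spec_accepted_tokens)  # 1 2
```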
### 2. Metric propagation in streaming responses

In the `stream_output_generation` path, every chunk carries the metrics accumulated so far [2]:

```python
if not self.spec_algorithm.is_none():
    spec_verify_ct.append(req.spec_verify_ct)
    spec_accepted_tokens.append(req.spec_accepted_tokens)
```
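In other words, the output processor builds per-request lists that line up with the batch's request IDs. A toy sketch of that bookkeeping, using a plain stand-in object rather than the real `Req` class:

```python
class FakeReq:
    """Stand-in for the scheduler's Req object, holding only the spec counters."""
    def __init__(self, rid, spec_verify_ct, spec_accepted_tokens):
        self.rid = rid
        self.spec_verify_ct = spec_verify_ct
        self.spec_accepted_tokens = spec_accepted_tokens

reqs = [FakeReq("r0", 3, 6), FakeReq("r1", 5, 12)]

spec_verify_ct, spec_accepted_tokens = [], []
for req in reqs:
    spec_verify_ct.append(req.spec_verify_ct)
    spec_accepted_tokens.append(req.spec_accepted_tokens)

# These lists travel in the batch output; index i corresponds to rids[i].
print(spec_verify_ct, spec_accepted_tokens)  # [3, 5] [6, 12]
```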
## Concrete update example

Suppose a request is streamed as 3 chunks; the full sequence looks like this:

**Chunk 1 - after the first verification**

```python
chunk1 = {
    "text": "Hello",
    "meta_info": {
        "spec_verify_ct": 1,        # first verification pass
        "spec_accepted_tokens": 2,  # 2 tokens accepted
        "completion_tokens": 2,
        "finish_reason": None
    }
}
```

**Chunk 2 - after the second verification**

```python
chunk2 = {
    "text": "Hello, world",
    "meta_info": {
        "spec_verify_ct": 2,        # second verification pass
        "spec_accepted_tokens": 5,  # 5 tokens accepted cumulatively
        "completion_tokens": 5,
        "finish_reason": None
    }
}
```

**Chunk 3 - final verification (request finished)**

```python
chunk3 = {
    "text": "Hello, world!",
    "meta_info": {
        "spec_verify_ct": 3,        # final verification count
        "spec_accepted_tokens": 6,  # final accepted token count
        "completion_tokens": 6,
        "finish_reason": "stop"
    }
}
```
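Because the values are cumulative, a streaming consumer only needs the last chunk to obtain the final statistics. A minimal sketch (the chunk dicts simply restate the example above; a real client would receive them from the streaming API):

```python
# The three chunks from the example, in arrival order.
chunks = [
    {"meta_info": {"spec_verify_ct": 1, "spec_accepted_tokens": 2, "finish_reason": None}},
    {"meta_info": {"spec_verify_ct": 2, "spec_accepted_tokens": 5, "finish_reason": None}},
    {"meta_info": {"spec_verify_ct": 3, "spec_accepted_tokens": 6, "finish_reason": "stop"}},
]

def final_spec_metrics(chunks):
    """Return the spec metrics from the last chunk; values are cumulative,
    so the latest chunk supersedes all earlier ones."""
    last = None
    for chunk in chunks:
        last = chunk["meta_info"]
    return {k: last[k] for k in ("spec_verify_ct", "spec_accepted_tokens", "finish_reason")}

print(final_spec_metrics(chunks))
# {'spec_verify_ct': 3, 'spec_accepted_tokens': 6, 'finish_reason': 'stop'}
```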
## Computing `spec_accept_length`

`spec_accept_length` is computed once, when the request finishes, in `_calculate_spec_decoding_metrics` [3]:

```python
# Average acceptance length per verification pass
meta_info["spec_accept_length"] = (
    recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
)
```

For the example above, the final `spec_accept_length` = 6 / 3 = 2.0.
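The same function also derives `spec_accept_rate` from the accepted-token count and the number of draft tokens guessed per step. A small standalone sketch of the arithmetic, assuming `speculative_num_draft_tokens = 4` (a made-up server configuration, not something stated in the example):

```python
spec_verify_ct = 3
spec_accepted_tokens = 6
completion_tokens = 6
speculative_num_draft_tokens = 4  # assumed server configuration

num_guess_tokens = speculative_num_draft_tokens - 1     # draft tokens per step = 3
total_draft_tokens = spec_verify_ct * num_guess_tokens  # 3 * 3 = 9

spec_accept_rate = spec_accepted_tokens / total_draft_tokens  # 6 / 9 ≈ 0.67
spec_accept_length = completion_tokens / spec_verify_ct       # 6 / 3 = 2.0
print(spec_accept_rate, spec_accept_length)
```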
## Key properties

- Cumulative: the metrics in every chunk are running totals, not per-chunk increments
- Real-time: `spec_verify_ct` is updated immediately after each verification pass
- Final calculation: `spec_accept_length` is computed only when the request finishes
- Consistency: chunks belonging to different parallel samples each maintain their own metrics independently [4]

## Notes

- Speculative decoding metrics only appear when `--speculative-algorithm` is enabled
- Every chunk in the streaming response carries the full current state of the metrics
- The last chunk contains the final, complete statistics for the request
- These metrics are carried through the batch outputs via `SpeculativeDecodingMetricsMixin` [5]
## Citations
File: python/sglang/srt/speculative/eagle_info.py (L398-401)
```python
req.spec_verify_ct += 1
req.spec_accepted_tokens += (
    sum(1 for idx in accept_index_row if idx != -1) - 1
)
```
File: python/sglang/srt/managers/scheduler_output_processor_mixin.py (L873-876)
```python
if not self.spec_algorithm.is_none():
    spec_verify_ct.append(req.spec_verify_ct)
    spec_accepted_tokens.append(req.spec_accepted_tokens)
```
File: python/sglang/srt/managers/tokenizer_manager.py (L1837-1877)
```python
def _calculate_spec_decoding_metrics(
    self,
    meta_info: Dict[str, Any],
    recv_obj: Union[
        BatchStrOutput,
        BatchEmbeddingOutput,
        BatchMultimodalOutput,
        BatchTokenIDOutput,
    ],
    i: int,
) -> None:
    """Calculate speculative decoding metrics, such as acceptance rate and acceptance length metrics."""
    meta_info["spec_accept_rate"] = 0.0
    meta_info["spec_accept_length"] = 0
    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
    # The draft tokens per speculative step (excluding the target-sampled token).
    num_guess_tokens = self.server_args.speculative_num_draft_tokens - 1
    if (
        recv_obj.spec_verify_ct[i] > 0
        and num_guess_tokens is not None
        and not isinstance(recv_obj, BatchEmbeddingOutput)
        and hasattr(recv_obj, "spec_accepted_tokens")
        # Checks that `spec_accepted_tokens[i]` will exist.
        and len(recv_obj.spec_accepted_tokens) > i
    ):
        total_draft_tokens = recv_obj.spec_verify_ct[i] * num_guess_tokens
        accepted_tokens = recv_obj.spec_accepted_tokens[i]
        # Calculate per-request acceptance rate and average acceptance length.
        if total_draft_tokens > 0:
            # Calculate acceptance rate: accepted / (steps * lookahead)
            meta_info["spec_accept_rate"] = accepted_tokens / total_draft_tokens
            meta_info["spec_accept_length"] = (
                recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
            )
            meta_info["spec_accept_token_num"] = accepted_tokens
            meta_info["spec_draft_token_num"] = total_draft_tokens
            meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
```
File: python/sglang/srt/managers/io_struct.py (L97-111)
```python
@dataclass
class SpeculativeDecodingMetricsMixin:
    """
    Mixin class containing speculative decoding metrics.

    This class consolidates speculative decoding metrics that are shared across
    batch output types that support speculative decoding to avoid code duplication.
    """

    # Verify count: number of verification forward passes
    spec_verify_ct: List[int]
    # Accepted tokens: Number of accepted tokens during speculative decoding
    spec_accepted_tokens: List[int]
```
# The call relationship between `postprocess_output_ids` and `process_sglang_output`

These two functions are tightly linked through an asynchronous callback mechanism, forming the complete SGLang output-processing chain.

## Core call chain
```mermaid
sequenceDiagram
    participant SGLang as SGLang engine
    participant Process as process_sglang_output
    participant Callback as request_complete_callback
    participant ActorWorker as ActorWorker
    participant Scheduler as DynamicSamplingScheduler
    participant PostProcess as postprocess_output_ids

    SGLang->>Process: return chunks
    Process->>Process: process chunks, extract metrics
    Process->>Callback: request_complete_callback(data)
    Callback->>ActorWorker: request_complete(data)
    ActorWorker->>Scheduler: response_call_back_fn(data)
    Scheduler->>PostProcess: postprocess_output_ids(data)
    PostProcess->>Scheduler: return processed DataProto
```
## Detailed call sequence

### 1. SGLang output processing

`process_sglang_output` receives the chunks returned by SGLang and extracts the key metrics [1]:

```python
def process_sglang_output(chunks, meta_info):
    output_data = DataProto(meta_info=meta_info)
    # Extract the various metrics
    output_token_ids = [chunk.get("output_ids", []) for chunk in chunks]
    output_data.meta_info["output_token_ids"] = output_token_ids
    output_data.meta_info["spec_accept_rate"] = [chunk["meta_info"].get("spec_accept_rate") for chunk in chunks]
    # Key step: hand the data on through the callback
    request_complete_callback(data=output_data)
```
### 2. Callback hand-off

The callback is executed through `ActorWorker.request_complete` [2]:

```python
def request_complete(self, data: DataProto):
    # Attach tokenizer info
    data.meta_info["eos_token_id"] = self.tokenizer.eos_token_id
    data.meta_info["pad_token_id"] = self.tokenizer.pad_token_id
    # Look up and invoke the registered callback
    response_call_back_fn = self.response_call_back_fns.pop(data.meta_info["request_id"])
    self.response_callback_refs.append(response_call_back_fn(data))
```
### 3. Scheduler receives and processes

The scheduler's `report_response` method receives the data and calls `postprocess_output_ids` [3]:

```python
@ray.method(concurrency_group="multi_thread")
def report_response(self, data: DataProto):
    request_id = data.meta_info["request_id"]
    # Key step: convert the SGLang output via postprocess_output_ids
    batch = self.postprocess_output_ids(data)
    # Downstream processing...
```
### 4. Final data conversion

`postprocess_output_ids` converts the SGLang output into the standard format [4]:

```python
def postprocess_output_ids(self, data: DataProto) -> DataProto:
    request_id = data.meta_info["request_id"]
    request: DataProto = self.requests_buffers.pop(request_id)
    # Extract data from the SGLang output
    output_token_ids = data.meta_info["output_token_ids"]
    output_logprobs = data.meta_info.get("output_logprobs", None)
    # Process and return a DataProto in the standard format
    output: DataProto = postprocess_generate(...)
    return output
```
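The padding and concatenation inside this step can be pictured with plain `torch` operations. The snippet below only illustrates the idea (variable-length responses right-padded, then appended to the repeated prompt); it is not the actual `concatenate_input_and_output` implementation, and the token values are invented:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0
# Hypothetical outputs for two samples of the same request (different lengths).
output_token_ids = [[12, 34, 56], [13, 35]]
output_tokens = [torch.tensor(ids) for ids in output_token_ids]

# Right-pad the responses into a rectangular tensor.
output_tensor = pad_sequence(output_tokens, batch_first=True, padding_value=pad_token_id)

# Repeat the (left-padded) prompt once per sample and append the responses.
prompt_ids = torch.tensor([[0, 0, 7, 8]])  # one prompt, left-padded
prompt_repeated = prompt_ids.repeat(len(output_tokens), 1)
full_sequences = torch.cat([prompt_repeated, output_tensor], dim=1)
print(full_sequences)
# tensor([[ 0,  0,  7,  8, 12, 34, 56],
#         [ 0,  0,  7,  8, 13, 35,  0]])
```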
## Worked example

Suppose an SGLang request completes:

1. SGLang returns chunks:

```python
chunks = [
    {"output_ids": [12, 34, 56], "meta_info": {"spec_accept_rate": 0.85, "finish_reason": "length"}},
    {"output_ids": [13, 35, 57], "meta_info": {"spec_accept_rate": 0.92, "finish_reason": "eos"}},
]
```

2. `process_sglang_output` processes them:

```python
# Extract the metrics
output_data.meta_info["output_token_ids"] = [[12, 34, 56], [13, 35, 57]]
output_data.meta_info["spec_accept_rate"] = [0.85, 0.92]
# Invoke the callback
request_complete_callback(data=output_data)
```

3. `postprocess_output_ids` receives the data and converts it (a sketch of how such metrics can be attached follows below):

```python
# Data received from process_sglang_output; data.meta_info contains:
# {"spec_accept_rate": [0.85, 0.92], ...}
# Converted into the standard format and placed into non_tensor_batch:
output.non_tensor_batch["spec_accept_rate"] = np.array([0.85, 0.92], dtype=object)
```
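A minimal sketch of that attachment step, assuming a `data.meta_info` shaped like the example above. The helper name is hypothetical; the real `postprocess_output_ids` does much more (padding, concatenation with the prompt, repeating `non_tensor_batch`), so this only shows where a custom spec metric could be copied over:

```python
import numpy as np

def attach_spec_metrics(output, data, num_return_sequences):
    """Hypothetical helper: copy per-sample spec metrics from meta_info
    into non_tensor_batch so they travel with the rest of the batch."""
    for key in ("spec_accept_rate", "spec_accept_length", "spec_verify_ct"):
        values = data.meta_info.get(key)
        if values is None:
            continue
        assert len(values) == num_return_sequences, f"{key} must have one value per sample"
        output.non_tensor_batch[key] = np.array(values, dtype=object)
    return output
```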
## Key connection points

- Callback mechanism: `request_complete_callback` is the bridge between the two functions [5]
- Data passing: the metrics travel between the functions via `DataProto.meta_info`
- Asynchronous processing: the whole call chain is asynchronous and does not block the main training loop
## Notes

- The same call pattern also exists in the async scheduler [6]
- The v046post4_patch version simplifies this flow, but the callback mechanism stays the same [7]
- Any newly added spec metrics must be correctly propagated into `non_tensor_batch` inside `postprocess_output_ids`
## Citations
File: roll/third_party/sglang/async_engine.py (L59-73)
```python
output_token_ids = [chunk.get("output_ids", []) for chunk in chunks]
output_logprobs = [chunk["meta_info"].get("output_token_logprobs", None) for chunk in chunks]
has_logprobs = any(logprobs is not None for logprobs in output_logprobs)
if has_logprobs:
    lens = [min(len(ids), len(logprobs)) for ids, logprobs in zip(output_token_ids, output_logprobs)]
    output_token_ids = [ids[:l] for ids, l in zip(output_token_ids, lens)]
    output_logprobs = [logprobs[:l] for logprobs, l in zip(output_logprobs, lens)]
    output_logprobs = [[prob_info[0] for prob_info in logprobs] for logprobs in output_logprobs]
    output_data.meta_info["output_logprobs"] = output_logprobs
    assert all([len(ids) == len(logprobs) for ids, logprobs in zip(output_token_ids, output_logprobs)]), (
        "output_token_ids and output_logprobs length not match"
    )
output_data.meta_info["output_token_ids"] = output_token_ids
output_data.meta_info["finish_reasons"] = [chunk["meta_info"].get("finish_reason") for chunk in chunks]
request_complete_callback(data=output_data)
```
File: roll/distributed/scheduler/generate_scheduler.py (L581-593)
```python
@ray.method(concurrency_group="multi_thread")
def report_response(self, data: DataProto):
    """
    这里需要考虑多线程数据访问
    data 返回可能有多条的
    """
    try:
        request_id = data.meta_info["request_id"]
        prompt_id = self.request_id_2_prompt_id[request_id]
        num_return_sequences = self.generation_config["num_return_sequences"]
        batch = self.postprocess_output_ids(data)
        output_count = batch.batch.batch_size[0]
```
File: roll/distributed/scheduler/generate_scheduler.py (L711-739)
```python
def postprocess_output_ids(self, data: DataProto) -> DataProto:
    # postprocess_generate, input_ids, attention_mask, left pad
    request_id = data.meta_info["request_id"]
    request: DataProto = self.requests_buffers.pop(request_id)
    eos_token_id = data.meta_info["eos_token_id"]
    pad_token_id = data.meta_info["pad_token_id"]
    output_token_ids = data.meta_info["output_token_ids"]
    output_tokens = [torch.tensor(token_ids) for token_ids in output_token_ids]
    output_logprobs = data.meta_info.get("output_logprobs", None)
    output_tensor = pad_sequence(output_tokens, batch_first=True, padding_value=pad_token_id)
    output_tensor = concatenate_input_and_output(
        input_ids=request.batch["input_ids"], output_ids=output_tensor, num_return_sequences=len(output_tokens)
    )
    output: DataProto = postprocess_generate(
        prompts=request,
        output=output_tensor,
        num_return_sequences=len(output_tokens),
        sequence_length=self.sequence_length,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
        output_logprobs=output_logprobs,
    )
    request_repeat = request.repeat(repeat_times=len(output_tokens))
    output.non_tensor_batch = request_repeat.non_tensor_batch
    output.meta_info = request_repeat.meta_info
    return output
```
File: roll/distributed/scheduler/async_generate_scheduler.py (L649-649)
batch = self.postprocess_output_ids(data) if is_finished else self.postprocess_paused_data(data)
File: roll/third_party/sglang/v046post4_patch/async_engine.py (L63-63)
request_complete_callback(data=output_data)
# The relationship between `request_complete_callback` and `report_response`

`request_complete_callback` is the bridge between the SGLang async engine and the scheduler's `report_response` method; together they form the complete asynchronous request-processing chain.

## Core relationship

**`request_complete_callback`** is a callback reference that in practice points to the scheduler's `report_response` method [1](#18-0).

**`report_response`** is the scheduler method that handles completed responses; it receives and processes the data returned by SGLang [2](#18-1).
## Complete call chain

```mermaid
sequenceDiagram
    participant SGLang as SGLang engine
    participant Process as process_sglang_output
    participant ActorWorker as ActorWorker
    participant Scheduler as DynamicSamplingScheduler

    Note over SGLang: async generation finishes
    SGLang->>Process: return chunks
    Process->>Process: process the output data
    Process->>ActorWorker: request_complete_callback(data)
    ActorWorker->>ActorWorker: request_complete(data)
    ActorWorker->>Scheduler: response_call_back_fn(data)
    Scheduler->>Scheduler: report_response(data)
```
## Key connection points

### 1. Callback registration

When the request is dispatched, `report_response` is stored as the callback [3]:

```python
self.response_call_back_fns[data.meta_info["request_id"]] = data.meta_info.pop("response_callback_fn")
```

### 2. SGLang output triggers the callback

Once `process_sglang_output` has finished preparing the data, it invokes the callback [4]:

```python
request_complete_callback(data=output_data)
```

### 3. ActorWorker relays the call

`ActorWorker.request_complete` looks up and executes the callback [5]:

```python
response_call_back_fn = self.response_call_back_fns.pop(data.meta_info["request_id"])
self.response_callback_refs.append(response_call_back_fn(data))
```

### 4. Scheduler performs the final processing

`report_response` receives the data and carries out the downstream steps [6]:

```python
batch = self.postprocess_output_ids(data)
# Reward computation, filtering, and other downstream processing...
```
## Design pattern

This design implements a producer-consumer pattern (a minimal sketch follows the list):

- Producer: SGLang asynchronously produces generation results
- Consumer: the scheduler consumes and processes those results
- Bridge: `request_complete_callback` connects the two
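A minimal, self-contained sketch of that pattern using only the standard library; the class and method names are illustrative stand-ins for `ActorWorker` and `report_response`, not the actual ROLL implementation:

```python
from typing import Any, Callable, Dict

class CallbackBridge:
    """Stores one callback per request_id and fires it when the producer finishes."""

    def __init__(self) -> None:
        self._callbacks: Dict[str, Callable[[Any], None]] = {}

    def register(self, request_id: str, callback: Callable[[Any], None]) -> None:
        # Mirrors response_call_back_fns[request_id] = response_callback_fn
        self._callbacks[request_id] = callback

    def request_complete(self, request_id: str, data: Any) -> None:
        # Mirrors ActorWorker.request_complete: pop and invoke the stored callback
        callback = self._callbacks.pop(request_id)
        callback(data)

# Usage: the scheduler registers its handler, the engine later completes the request.
bridge = CallbackBridge()
bridge.register("req-1", lambda data: print("report_response got:", data))
bridge.request_complete("req-1", {"output_token_ids": [[12, 34, 56]]})
```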
## Notes

- The whole process is asynchronous and non-blocking; it does not stall the main training loop
- The callback mechanism supports independent handling of many concurrent requests
- The same pattern also exists in the async scheduler `AsyncDynamicSamplingScheduler` [7]
- Any newly added spec metrics reach the scheduler's `report_response` method from SGLang through this mechanism
## Citations
File: roll/distributed/scheduler/generate_scheduler.py (L581-675)
@ray.method(concurrency_group="multi_thread")
def report_response(self, data: DataProto):
"""
这里需要考虑多线程数据访问
data 返回可能有多条的
"""
try:
request_id = data.meta_info["request_id"]
prompt_id = self.request_id_2_prompt_id[request_id]
num_return_sequences = self.generation_config["num_return_sequences"]
batch = self.postprocess_output_ids(data)
output_count = batch.batch.batch_size[0]
with self.lock:
self.load_balance_coordinator[self.request_id_2_dp_rank[request_id]] -= 1
self.prompt_id_2_request_ids[prompt_id].remove(request_id)
domain = "default"
if "domain" in batch.non_tensor_batch.keys():
domain = batch.non_tensor_batch["domain"][0]
reward_worker = next(self.reward_worker_iters[domain])
if not self.running:
return
# call reward
# reward worker得能支持单条数据计算, dynamic sampling对需要batch计算reward的需要注意...
# 多域的时候,llm as judge, 需要单独为reward worker分配gpu
rewards: DataProto = ray.get(reward_worker.compute_rewards.remote(batch))
batch.union(rewards)
response_buffers: List[DataProto] = []
batch_expanded = [batch[[idx]] for idx in range(output_count)]
# response_filter, 不太需要response filter
for batch_item in batch_expanded:
if self.response_filter_fn(batch_item, self.pipeline_config):
response_buffers.append(batch_item)
else:
self.response_filter_count += 1
with self.lock:
self.response_cache[domain].extend(batch_expanded)
if len(response_buffers) == 0:
if len(self.prompt_id_2_request_ids[prompt_id]) == 0:
self.running_prompts -= 1
return
if len(self.completed_buffers[prompt_id]) > 0:
return
# expand batch to response
self.query_group_buffers[prompt_id].extend(response_buffers)
# query_filter, query has n responses
if len(self.query_group_buffers[prompt_id]) >= num_return_sequences:
if not self.query_filter_fn(self.query_group_buffers[prompt_id], self.pipeline_config):
self.query_filter_count += 1
del self.query_group_buffers[prompt_id]
self.abort_requests(self.prompt_id_2_request_ids[prompt_id])
return
assert len(self.query_group_buffers[prompt_id]) >= num_return_sequences, (
f"expect to generate {num_return_sequences} results from one prompt, "
f"but get {len(self.query_group_buffers[prompt_id])}."
)
self.completed_buffers[prompt_id] = self.query_group_buffers[prompt_id][:num_return_sequences]
self.progress_bar.update()
if int(os.environ.get("REPORT_LENGTH_AND_REWARDS", "0")):
# report response level rewards
response_level_rewards = [data.batch["response_level_rewards"] for data in self.query_group_buffers[prompt_id]]
response_rewards = torch.cat(response_level_rewards, dim=0).long().cpu().tolist()
prompt_hash = self.prompt_id_2_hash_str.pop(prompt_id)
prompt_response_proto = DataProto.concat(self.query_group_buffers[prompt_id][:num_return_sequences])
# report response level lengths
response_lengths = torch.sum(prompt_response_proto.batch["response_mask"], dim=1).cpu().tolist()
lengths_and_rewards = {
'domain': domain,
'prompt_hash': prompt_hash,
'response_lengths': response_lengths,
'response_rewards': response_rewards
}
length_dir = os.path.join(self.pipeline_config.length_profiler_dir, "length")
os.makedirs(length_dir, exist_ok=True)
filename = f"response-length-and-rewards-{domain}-ep{self.dataset_epoch}.jsonl"
length_file_path = os.path.join(length_dir, filename)
with open(length_file_path, "a") as f:
f.write(json.dumps(lengths_and_rewards) + "\n")
# abort uncompleted request
self.abort_requests(self.prompt_id_2_request_ids[prompt_id])
except Exception as e:
self.exception_queue.put(e)
File: roll/pipeline/base_worker.py (L382-382)
self.response_call_back_fns[data.meta_info["request_id"]] = data.meta_info.pop("response_callback_fn")
File: roll/pipeline/base_worker.py (L386-390)
```python
def request_complete(self, data: DataProto):
    data.meta_info["eos_token_id"] = self.tokenizer.eos_token_id
    data.meta_info["pad_token_id"] = self.tokenizer.pad_token_id
    response_call_back_fn = self.response_call_back_fns.pop(data.meta_info["request_id"])
    self.response_callback_refs.append(response_call_back_fn(data))
```
File: roll/third_party/sglang/async_engine.py (L73-73)
request_complete_callback(data=output_data)
# Speculative decoding metrics when one request has multiple samples

When parallel sampling is enabled (n > 1), each sample is treated as an independent request, and its speculative decoding metrics are computed separately.

## Key mechanisms

### 1. Independent request IDs and state

With parallel sampling, every sample gets its own request ID [1]:

```python
def _normalize_rid(self, num):
    if self.rid is None:
        self.rid = [uuid.uuid4().hex for _ in range(num)]
    elif isinstance(self.rid, str):
        new_rids = [f"{self.rid}_{i}" for i in range(num)]
        self.rid = new_rids
```
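For example, a request submitted with `rid="abc"` and n = 3 samples would be expanded into three independent request IDs, each backed by its own `Req` state and spec counters. A tiny illustration of the string branch above (the rid value is made up):

```python
rid, n = "abc", 3
expanded = [f"{rid}_{i}" for i in range(n)]
print(expanded)  # ['abc_0', 'abc_1', 'abc_2']
```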
### 2. Independent metric tracking

Each sample (request) tracks its speculative decoding metrics independently in its `Req` object [2]:

```python
# Per-request speculative decoding metrics
self.spec_verify_ct = 0
self.spec_accepted_tokens = 0
```
### 3. Metrics computed per sample

In `_handle_batch_output`, the loop iterates over all requests (samples) and computes the metrics separately for each finished one [3]:

```python
for i, rid in enumerate(recv_obj.rids):
    # ...
    state.finished = recv_obj.finished_reasons[i] is not None
    if state.finished:
        if self.server_args.speculative_algorithm:
            self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)
```
## Metric calculation example

Suppose a request has 3 samples and speculative decoding runs as follows (the acceptance-rate column assumes 4 draft tokens are guessed per verification step, i.e. `speculative_num_draft_tokens = 5`):

| Sample | spec_verify_ct | spec_accepted_tokens | Completion order | Computed metric |
|---|---|---|---|---|
| Sample 0 | 5 | 12 | finishes 2nd | spec_accept_rate = 12 / (5 × 4) = 0.60 |
| Sample 1 | 3 | 10 | finishes 1st | spec_accept_rate = 10 / (3 × 4) = 0.83 |
| Sample 2 | 7 | 14 | finishes 3rd | spec_accept_rate = 14 / (7 × 4) = 0.50 |

Each sample's metrics are computed independently and do not affect one another.
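A minimal sketch that reproduces the table, assuming `num_guess_tokens = 4` as above; each sample's rate is derived only from its own counters:

```python
samples = {
    "sample_0": {"spec_verify_ct": 5, "spec_accepted_tokens": 12},
    "sample_1": {"spec_verify_ct": 3, "spec_accepted_tokens": 10},
    "sample_2": {"spec_verify_ct": 7, "spec_accepted_tokens": 14},
}
num_guess_tokens = 4  # assumed: speculative_num_draft_tokens - 1

for name, s in samples.items():
    total_draft_tokens = s["spec_verify_ct"] * num_guess_tokens
    rate = s["spec_accepted_tokens"] / total_draft_tokens
    print(f"{name}: spec_accept_rate = {rate:.2f}")
# sample_0: 0.60, sample_1: 0.83, sample_2: 0.50
```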
## Notes

- Independence: each sample has its own `Req` object and its own speculative decoding metrics
- Asynchronous completion: samples may finish at different times, and metrics are computed immediately on completion
- No sharing: samples do not share speculative decoding state; each keeps its own counts
- Streaming output: in streaming mode, samples that finish earlier return their metric-carrying chunks earlier
## Citations
File: python/sglang/srt/managers/io_struct.py (L460-474)
```python
def _normalize_rid(self, num):
    """Normalize request IDs for batch processing."""
    if self.rid is None:
        self.rid = [uuid.uuid4().hex for _ in range(num)]
    elif isinstance(self.rid, str):
        new_rids = [f"{self.rid}_{i}" for i in range(num)]
        self.rid = new_rids
    elif isinstance(self.rid, list):
        # Note: the length of rid shall be the same as the batch_size,
        # as the rid would be expanded for parallel sampling in tokenizer_manager
        if len(self.rid) != self.batch_size:
            raise ValueError(
                "The specified rids length mismatch with the batch_size for batch processing."
            )
    else:
```
File: python/sglang/srt/managers/schedule_batch.py (L646-652)
```python
# The number of verification forward passes in the speculative decoding.
# This is used to compute the average acceptance length per request.
self.spec_verify_ct = 0
# The number of accepted tokens in speculative decoding for this request.
# This is used to compute the acceptance rate and average acceptance length per request.
self.spec_accepted_tokens = 0
```
File: python/sglang/srt/managers/tokenizer_manager.py (L1575-1664)
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
logger.error(
f"Received output for {rid=} but the state was deleted in TokenizerManager."
)
continue
# Build meta_info and return value
meta_info = {
"id": rid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
"weight_version": self.server_args.weight_version,
"total_retractions": recv_obj.retraction_counts[i],
}
if self.enable_metrics:
self._add_metric_if_present(recv_obj, "queue_time", meta_info, i)
self._add_metric_if_present(
recv_obj, "prefill_launch_delay", meta_info, i
)
self._add_metric_if_present(
recv_obj, "prefill_launch_latency", meta_info, i
)
if getattr(state.obj, "return_logprob", False):
self.convert_logprob_style(
meta_info,
state,
state.obj.top_logprobs_num,
state.obj.token_ids_logprob,
state.obj.return_text_in_logprobs
and not self.server_args.skip_tokenizer_init,
recv_obj,
i,
)
if not isinstance(recv_obj, BatchEmbeddingOutput):
meta_info.update(
{
"completion_tokens": recv_obj.completion_tokens[i],
"cached_tokens": recv_obj.cached_tokens[i],
}
)
if getattr(recv_obj, "output_hidden_states", None):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
if isinstance(recv_obj, BatchStrOutput):
state.text += recv_obj.output_strs[i]
if state.obj.stream:
state.output_ids.extend(recv_obj.output_ids[i])
output_token_ids = state.output_ids[state.last_output_offset :]
state.last_output_offset = len(state.output_ids)
else:
state.output_ids.extend(recv_obj.output_ids[i])
output_token_ids = state.output_ids.copy()
out_dict = {
"text": state.text,
"output_ids": output_token_ids,
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchTokenIDOutput):
if self.server_args.stream_output and state.obj.stream:
state.output_ids.extend(recv_obj.output_ids[i])
output_token_ids = state.output_ids[state.last_output_offset :]
state.last_output_offset = len(state.output_ids)
else:
state.output_ids.extend(recv_obj.output_ids[i])
output_token_ids = state.output_ids.copy()
out_dict = {
"output_ids": output_token_ids,
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchMultimodalOutput):
raise NotImplementedError("BatchMultimodalOut not implemented")
else:
assert isinstance(recv_obj, BatchEmbeddingOutput)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": meta_info,
}
state.finished = recv_obj.finished_reasons[i] is not None
if state.finished:
if self.server_args.speculative_algorithm:
self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)