[RL] SGLang streaming responses

In streaming mode, how are the speculative decoding metrics updated per chunk? A single request returns multiple chunks; how is the record in the final chunk produced? Please give an example.

How speculative decoding metrics are updated during streaming generation

During streaming generation, the speculative decoding metrics accumulate chunk by chunk, so the final chunk carries the complete statistics for the request.

Core update mechanism

1. Real-time updates of spec_verify_ct

After every speculative decoding verification pass, spec_verify_ct is incremented immediately: [1]

# During EAGLE speculative decoding verification
req.spec_verify_ct += 1
# Count accepted draft tokens; the -1 excludes the token sampled by the target model.
req.spec_accepted_tokens += (sum(1 for idx in accept_index_row if idx != -1) - 1)
2. Metric propagation in streaming responses

In stream_output_generation, every chunk carries the metrics accumulated so far: [2]

if not self.spec_algorithm.is_none():
    spec_verify_ct.append(req.spec_verify_ct)
    spec_accepted_tokens.append(req.spec_accepted_tokens)

A concrete update example

Suppose a request produces its full output across three chunks:

Chunk 1 - after the first verification
chunk1 = {
    "text": "Hello",
    "meta_info": {
        "spec_verify_ct": 1,        # first verification pass
        "spec_accepted_tokens": 2,  # 2 tokens accepted
        "completion_tokens": 2,
        "finish_reason": None
    }
}
Chunk 2 - after the second verification
chunk2 = {
    "text": "Hello, world",
    "meta_info": {
        "spec_verify_ct": 2,        # second verification pass
        "spec_accepted_tokens": 5,  # 5 tokens accepted cumulatively
        "completion_tokens": 5,
        "finish_reason": None
    }
}
Chunk 3 - final verification (request finished)
chunk3 = {
    "text": "Hello, world!",
    "meta_info": {
        "spec_verify_ct": 3,        # final verification count
        "spec_accepted_tokens": 6,  # final accepted-token count
        "completion_tokens": 6,
        "finish_reason": "stop"
    }
}
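
For reference, here is a minimal client-side sketch of how such chunks could be consumed. It assumes the native `/generate` HTTP endpoint with `"stream": true`, which emits SSE-style `data: ...` lines; the URL and sampling parameters are placeholders, and the server is assumed to be launched with `--speculative-algorithm` so the spec fields appear in `meta_info`.

```python
import json
import requests

# Hypothetical local server; adjust URL and sampling params as needed.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Hello",
        "sampling_params": {"max_new_tokens": 32},
        "stream": True,
    },
    stream=True,
)

last_meta = {}
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue
    payload = line[len("data:"):].strip()
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    # Every chunk carries the cumulative metrics; keep the most recent one.
    last_meta = chunk["meta_info"]

# The last chunk holds the complete statistics for the request.
print(last_meta.get("spec_verify_ct"), last_meta.get("spec_accept_length"))
```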

Computing spec_accept_length

spec_accept_length is computed by the _calculate_spec_decoding_metrics function when the request finishes: [3]

# Compute the average acceptance length
meta_info["spec_accept_length"] = (
    recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
)

For the example above:

  • Final spec_accept_length = 6 / 3 = 2.0
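
The other metrics follow from the same final chunk. A minimal sketch of the arithmetic, mirroring _calculate_spec_decoding_metrics; the draft-token setting here is an assumption (speculative_num_draft_tokens = 4, i.e. 3 guessed tokens per verification step):

```python
# Values taken from chunk3 above; num_guess_tokens is an assumed server setting.
spec_verify_ct = 3
spec_accepted_tokens = 6
completion_tokens = 6
num_guess_tokens = 4 - 1  # speculative_num_draft_tokens - 1 (assumed to be 4)

total_draft_tokens = spec_verify_ct * num_guess_tokens         # 9
spec_accept_rate = spec_accepted_tokens / total_draft_tokens   # ~0.67
spec_accept_length = completion_tokens / spec_verify_ct        # 2.0
```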

Key characteristics

  1. Cumulative: the metrics in each chunk are running totals, not per-chunk increments
  2. Real-time: spec_verify_ct is updated immediately after each verification pass
  3. Deferred: spec_accept_length is only computed once the request finishes
  4. Per-sample: with parallel sampling, each sample's chunks maintain their own metrics independently [4]

Notes

  • The speculative decoding metrics only appear when --speculative-algorithm is enabled
  • Every chunk in the streaming response carries the full metric state at that point in time
  • The last chunk contains the final, complete statistics for the request
  • These metrics are carried through batch outputs via SpeculativeDecodingMetricsMixin [5]

Citations

File: python/sglang/srt/speculative/eagle_info.py (L398-401)

            req.spec_verify_ct += 1
            req.spec_accepted_tokens += (
                sum(1 for idx in accept_index_row if idx != -1) - 1
            )

File: python/sglang/srt/managers/scheduler_output_processor_mixin.py (L873-876)

                if not self.spec_algorithm.is_none():
                    spec_verify_ct.append(req.spec_verify_ct)
                    spec_accepted_tokens.append(req.spec_accepted_tokens)

File: python/sglang/srt/managers/tokenizer_manager.py (L1837-1877)

    def _calculate_spec_decoding_metrics(
        self,
        meta_info: Dict[str, Any],
        recv_obj: Union[
            BatchStrOutput,
            BatchEmbeddingOutput,
            BatchMultimodalOutput,
            BatchTokenIDOutput,
        ],
        i: int,
    ) -> None:
        """Calculate speculative decoding metrics, such as acceptance rate and acceptance length metrics."""
        meta_info["spec_accept_rate"] = 0.0
        meta_info["spec_accept_length"] = 0
        meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]

        # The draft tokens per speculative step (excluding the target-sampled token).
        num_guess_tokens = self.server_args.speculative_num_draft_tokens - 1

        if (
            recv_obj.spec_verify_ct[i] > 0
            and num_guess_tokens is not None
            and not isinstance(recv_obj, BatchEmbeddingOutput)
            and hasattr(recv_obj, "spec_accepted_tokens")
            # Checks that `spec_accepted_tokens[i]` will exist.
            and len(recv_obj.spec_accepted_tokens) > i
        ):
            total_draft_tokens = recv_obj.spec_verify_ct[i] * num_guess_tokens
            accepted_tokens = recv_obj.spec_accepted_tokens[i]

            # Calculate per-request acceptance rate and average acceptance length.
            if total_draft_tokens > 0:
                # Calculate acceptance rate: accepted / (steps * lookahead)
                meta_info["spec_accept_rate"] = accepted_tokens / total_draft_tokens
                meta_info["spec_accept_length"] = (
                    recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
                )
                meta_info["spec_accept_token_num"] = accepted_tokens
                meta_info["spec_draft_token_num"] = total_draft_tokens
                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]

File: python/sglang/srt/managers/io_struct.py (L97-111)

@dataclass
class SpeculativeDecodingMetricsMixin:
    """
    Mixin class containing speculative decoding metrics.

    This class consolidates speculative decoding metrics that are shared across
    batch output types that support speculative decoding to avoid code duplication.
    """

    # Verify count: number of verification forward passes
    spec_verify_ct: List[int]

    # Accepted tokens: Number of accepted tokens during speculative decoding
    spec_accepted_tokens: List[int]


# The call relationship between `postprocess_output_ids` and `process_sglang_output`

These two functions are connected through an asynchronous callback mechanism and together form the complete SGLang output-processing chain.

## Core call chain

```mermaid
sequenceDiagram
    participant SGLang as SGLang engine
    participant Process as process_sglang_output
    participant Callback as request_complete_callback
    participant ActorWorker as ActorWorker
    participant Scheduler as DynamicSamplingScheduler
    participant PostProcess as postprocess_output_ids

    SGLang->>Process: return chunks
    Process->>Process: process chunks, extract metrics
    Process->>Callback: request_complete_callback(data)
    Callback->>ActorWorker: request_complete(data)
    ActorWorker->>Scheduler: response_call_back_fn(data)
    Scheduler->>PostProcess: postprocess_output_ids(data)
    PostProcess->>Scheduler: return the processed DataProto
```

Detailed call sequence

1. SGLang output processing

process_sglang_output receives the chunks returned by SGLang and extracts the key metrics: [1]

def process_sglang_output(chunks, meta_info):
    output_data = DataProto(meta_info=meta_info)
    # Extract the metrics
    output_token_ids = [chunk.get("output_ids", []) for chunk in chunks]
    output_data.meta_info["output_token_ids"] = output_token_ids
    output_data.meta_info["spec_accept_rate"] = [chunk["meta_info"].get("spec_accept_rate") for chunk in chunks]
    # Key step: hand the data off through the callback
    request_complete_callback(data=output_data)

2. Callback hand-off

The callback is executed through the ActorWorker.request_complete method: [2]

def request_complete(self, data: DataProto):
    # Attach tokenizer information
    data.meta_info["eos_token_id"] = self.tokenizer.eos_token_id
    data.meta_info["pad_token_id"] = self.tokenizer.pad_token_id
    # Look up and invoke the registered callback
    response_call_back_fn = self.response_call_back_fns.pop(data.meta_info["request_id"])
    self.response_callback_refs.append(response_call_back_fn(data))

3. The scheduler receives and processes the data

The scheduler's report_response method receives the data and calls postprocess_output_ids: [3]

@ray.method(concurrency_group="multi_thread")
def report_response(self, data: DataProto):
    request_id = data.meta_info["request_id"]
    # Key step: call postprocess_output_ids on the data returned by SGLang
    batch = self.postprocess_output_ids(data)
    # Follow-up processing (rewards, filtering, ...)

4. Final data processing

postprocess_output_ids converts the SGLang output into the standard format: [4]

def postprocess_output_ids(self, data: DataProto) -> DataProto:
    request_id = data.meta_info["request_id"]
    request: DataProto = self.requests_buffers.pop(request_id)

    # Extract data from the SGLang output
    output_token_ids = data.meta_info["output_token_ids"]
    output_logprobs = data.meta_info.get("output_logprobs", None)

    # Convert to the standard format and return a DataProto
    output: DataProto = postprocess_generate(...)
    return output

A concrete example

Suppose an SGLang request finishes:

  1. SGLang returns the chunks

    chunks = [
        {"output_ids": [12, 34, 56], "meta_info": {"spec_accept_rate": 0.85, "finish_reason": "length"}},
        {"output_ids": [13, 35, 57], "meta_info": {"spec_accept_rate": 0.92, "finish_reason": "eos"}}
    ]

  2. process_sglang_output processes them

    # Extract the metrics
    output_data.meta_info["output_token_ids"] = [[12, 34, 56], [13, 35, 57]]
    output_data.meta_info["spec_accept_rate"] = [0.85, 0.92]
    # Invoke the callback
    request_complete_callback(data=output_data)

  3. postprocess_output_ids receives and converts the data

    # Data received from process_sglang_output:
    # data.meta_info contains {"spec_accept_rate": [0.85, 0.92], ...}
    # Convert to the standard format and write into non_tensor_batch
    output.non_tensor_batch["spec_accept_rate"] = np.array([0.85, 0.92], dtype=object)
    

Key connection points

  • Callback mechanism: request_complete_callback is the bridge between the two functions [5]
  • Data transfer: metric data is passed between the functions via DataProto.meta_info
  • Asynchronous: the whole chain runs asynchronously and does not block the main training loop

Notes

  • The asynchronous scheduler follows the same call pattern [6]
  • The v046post4_patch version simplifies this flow, but the callback mechanism stays the same [7]
  • Any newly added spec metrics must be propagated correctly into non_tensor_batch inside postprocess_output_ids (see the sketch below)
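
To illustrate that last note, here is a hedged sketch of how the extra fields could be carried through. The helper name attach_spec_metrics is hypothetical; the idea is simply to read the per-chunk lists from data.meta_info inside postprocess_output_ids and store them as object arrays on the output, in the same way spec_accept_rate is handled in the example above.

```python
import numpy as np

def attach_spec_metrics(data, output):
    """Copy per-sample spec metrics from the callback payload into the batch (sketch)."""
    for key in ("spec_accept_rate", "spec_accept_length", "spec_verify_ct"):
        values = data.meta_info.get(key)
        if values is not None:
            # One entry per returned sample, aligned with output_token_ids.
            output.non_tensor_batch[key] = np.array(values, dtype=object)
    return output
```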

Citations

File: roll/third_party/sglang/async_engine.py (L59-73)

        output_token_ids = [chunk.get("output_ids", []) for chunk in chunks]
        output_logprobs = [chunk["meta_info"].get("output_token_logprobs", None) for chunk in chunks]
        has_logprobs = any(logprobs is not None for logprobs in output_logprobs)
        if has_logprobs:
            lens = [min(len(ids), len(logprobs)) for ids, logprobs in zip(output_token_ids, output_logprobs)]
            output_token_ids = [ids[:l] for ids, l in zip(output_token_ids, lens)]
            output_logprobs = [logprobs[:l] for logprobs, l in zip(output_logprobs, lens)]
            output_logprobs = [[prob_info[0] for prob_info in logprobs] for logprobs in output_logprobs]
            output_data.meta_info["output_logprobs"] = output_logprobs
            assert all([len(ids) == len(logprobs) for ids, logprobs in zip(output_token_ids, output_logprobs)]), (
                "output_token_ids and output_logprobs length not match"
            )
        output_data.meta_info["output_token_ids"] = output_token_ids
        output_data.meta_info["finish_reasons"] = [chunk["meta_info"].get("finish_reason") for chunk in chunks]
        request_complete_callback(data=output_data)

File: roll/distributed/scheduler/generate_scheduler.py (L581-593)

    @ray.method(concurrency_group="multi_thread")
    def report_response(self, data: DataProto):
        """
        这里需要考虑多线程数据访问
        data 返回可能有多条的
        """
        try:
            request_id = data.meta_info["request_id"]
            prompt_id = self.request_id_2_prompt_id[request_id]
            num_return_sequences = self.generation_config["num_return_sequences"]

            batch = self.postprocess_output_ids(data)
            output_count = batch.batch.batch_size[0]

File: roll/distributed/scheduler/generate_scheduler.py (L711-739)

    def postprocess_output_ids(self, data: DataProto) -> DataProto:
        # postprocess_generate, input_ids, attention_mask, left pad
        request_id = data.meta_info["request_id"]
        request: DataProto = self.requests_buffers.pop(request_id)

        eos_token_id = data.meta_info["eos_token_id"]
        pad_token_id = data.meta_info["pad_token_id"]
        output_token_ids = data.meta_info["output_token_ids"]
        output_tokens = [torch.tensor(token_ids) for token_ids in output_token_ids]

        output_logprobs = data.meta_info.get("output_logprobs", None)

        output_tensor = pad_sequence(output_tokens, batch_first=True, padding_value=pad_token_id)
        output_tensor = concatenate_input_and_output(
            input_ids=request.batch["input_ids"], output_ids=output_tensor, num_return_sequences=len(output_tokens)
        )
        output: DataProto = postprocess_generate(
            prompts=request,
            output=output_tensor,
            num_return_sequences=len(output_tokens),
            sequence_length=self.sequence_length,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            output_logprobs=output_logprobs,
        )
        request_repeat = request.repeat(repeat_times=len(output_tokens))
        output.non_tensor_batch = request_repeat.non_tensor_batch
        output.meta_info = request_repeat.meta_info
        return output

File: roll/distributed/scheduler/async_generate_scheduler.py (L649-649)

            batch = self.postprocess_output_ids(data) if is_finished else self.postprocess_paused_data(data)

File: roll/third_party/sglang/v046post4_patch/async_engine.py (L63-63)

        request_complete_callback(data=output_data)

# The relationship between `request_complete_callback` and `report_response`

`request_complete_callback` is the bridge between the SGLang asynchronous engine and the scheduler's `report_response` method; together they form the complete asynchronous request-processing chain.

## Core relationship

**`request_complete_callback`** is a callback reference that, in practice, points to the scheduler's `report_response` method [1].

**`report_response`** is the scheduler method that handles a completed response, receiving and processing the data returned by SGLang [2].

## Full call chain

```mermaid
sequenceDiagram
    participant SGLang as SGLang engine
    participant Process as process_sglang_output
    participant ActorWorker as ActorWorker
    participant Scheduler as DynamicSamplingScheduler

    Note over SGLang: async generation finished
    SGLang->>Process: return chunks
    Process->>Process: process output data
    Process->>ActorWorker: request_complete_callback(data)
    ActorWorker->>ActorWorker: request_complete(data)
    ActorWorker->>Scheduler: response_call_back_fn(data)
    Scheduler->>Scheduler: report_response(data)
```

Key connection points

1. Callback registration

When the request is dispatched, report_response is stored as the callback: [3]

self.response_call_back_fns[data.meta_info["request_id"]] = data.meta_info.pop("response_callback_fn")

2. SGLang output triggers the callback

Once process_sglang_output has finished preparing the output data, it invokes the callback: [4]

request_complete_callback(data=output_data)

3. ActorWorker relays the call

ActorWorker.request_complete looks up the callback and executes it: [5]

response_call_back_fn = self.response_call_back_fns.pop(data.meta_info["request_id"])
self.response_callback_refs.append(response_call_back_fn(data))

4. Final handling in the scheduler

report_response receives the data and carries out the follow-up processing: [6]

batch = self.postprocess_output_ids(data)
# reward computation, filtering, and other follow-up steps

Design pattern

This design implements a producer-consumer pattern (a distilled sketch follows the list):

  • Producer: SGLang generates results asynchronously
  • Consumer: the scheduler processes the generated results
  • Bridge: request_complete_callback connects the two
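
The following self-contained sketch distills the pattern with simplified names (this is not the actual ROLL class): a callback is stored per request_id when the request is dispatched and popped exactly once when the engine reports completion.

```python
from typing import Any, Callable, Dict

class CallbackRelay:
    """Minimal stand-in for the ActorWorker-side callback bookkeeping."""

    def __init__(self) -> None:
        self.response_call_back_fns: Dict[str, Callable[[Any], Any]] = {}

    def register(self, request_id: str, callback: Callable[[Any], Any]) -> None:
        # Mirrors storing report_response when the request is sent out.
        self.response_call_back_fns[request_id] = callback

    def request_complete(self, request_id: str, data: Any) -> Any:
        # Popped exactly once, so each request is reported exactly once.
        callback = self.response_call_back_fns.pop(request_id)
        return callback(data)

# Usage: the producer (engine output handler) only knows the request_id;
# the consumer (the scheduler's report_response) was registered up front.
relay = CallbackRelay()
relay.register("req-0", lambda data: print("report_response got:", data))
relay.request_complete("req-0", {"output_token_ids": [[12, 34, 56]]})
```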

Notes

  • The whole process is asynchronous and non-blocking; it does not stall the main training loop
  • The callback mechanism supports independent handling of many concurrent requests
  • The asynchronous scheduler AsyncDynamicSamplingScheduler follows the same pattern [7]
  • Newly added spec metrics travel through this mechanism from SGLang to the scheduler's report_response method

Citations

File: roll/distributed/scheduler/generate_scheduler.py (L581-675)

    @ray.method(concurrency_group="multi_thread")
    def report_response(self, data: DataProto):
        """
        这里需要考虑多线程数据访问
        data 返回可能有多条的
        """
        try:
            request_id = data.meta_info["request_id"]
            prompt_id = self.request_id_2_prompt_id[request_id]
            num_return_sequences = self.generation_config["num_return_sequences"]

            batch = self.postprocess_output_ids(data)
            output_count = batch.batch.batch_size[0]
            with self.lock:
                self.load_balance_coordinator[self.request_id_2_dp_rank[request_id]] -= 1
                self.prompt_id_2_request_ids[prompt_id].remove(request_id)
                domain = "default"
                if "domain" in batch.non_tensor_batch.keys():
                    domain = batch.non_tensor_batch["domain"][0]
                reward_worker = next(self.reward_worker_iters[domain])

            if not self.running:
                return

            # call reward
            # reward worker得能支持单条数据计算, dynamic sampling对需要batch计算reward的需要注意...
            # 多域的时候,llm as judge, 需要单独为reward worker分配gpu
            rewards: DataProto = ray.get(reward_worker.compute_rewards.remote(batch))
            batch.union(rewards)

            response_buffers: List[DataProto] = []
            batch_expanded = [batch[[idx]] for idx in range(output_count)]

            # response_filter, 不太需要response filter
            for batch_item in batch_expanded:
                if self.response_filter_fn(batch_item, self.pipeline_config):
                    response_buffers.append(batch_item)
                else:
                    self.response_filter_count += 1

            with self.lock:
                self.response_cache[domain].extend(batch_expanded)

                if len(response_buffers) == 0:
                    if len(self.prompt_id_2_request_ids[prompt_id]) == 0:
                        self.running_prompts -= 1
                    return

                if len(self.completed_buffers[prompt_id]) > 0:
                    return

                # expand batch to response
                self.query_group_buffers[prompt_id].extend(response_buffers)

                # query_filter, query has n responses
                if len(self.query_group_buffers[prompt_id]) >= num_return_sequences:
                    if not self.query_filter_fn(self.query_group_buffers[prompt_id], self.pipeline_config):
                        self.query_filter_count += 1
                        del self.query_group_buffers[prompt_id]
                        self.abort_requests(self.prompt_id_2_request_ids[prompt_id])
                        return

                    assert len(self.query_group_buffers[prompt_id]) >= num_return_sequences, (
                        f"expect to generate {num_return_sequences} results from one prompt, "
                        f"but get {len(self.query_group_buffers[prompt_id])}."
                    )
                    self.completed_buffers[prompt_id] = self.query_group_buffers[prompt_id][:num_return_sequences]
                    self.progress_bar.update()

                    if int(os.environ.get("REPORT_LENGTH_AND_REWARDS", "0")):
                        # report response level rewards
                        response_level_rewards = [data.batch["response_level_rewards"] for data in self.query_group_buffers[prompt_id]]
                        response_rewards = torch.cat(response_level_rewards, dim=0).long().cpu().tolist()
                        prompt_hash = self.prompt_id_2_hash_str.pop(prompt_id)
                        prompt_response_proto = DataProto.concat(self.query_group_buffers[prompt_id][:num_return_sequences])
                        # report response level lengths
                        response_lengths = torch.sum(prompt_response_proto.batch["response_mask"], dim=1).cpu().tolist()

                        lengths_and_rewards = {
                            'domain': domain,
                            'prompt_hash': prompt_hash,
                            'response_lengths': response_lengths,
                            'response_rewards': response_rewards
                        }
                        length_dir = os.path.join(self.pipeline_config.length_profiler_dir, "length")
                        os.makedirs(length_dir, exist_ok=True)
                        filename = f"response-length-and-rewards-{domain}-ep{self.dataset_epoch}.jsonl"
                        length_file_path = os.path.join(length_dir, filename)
                        with open(length_file_path, "a") as f:
                            f.write(json.dumps(lengths_and_rewards) + "\n")

                    # abort uncompleted request
                    self.abort_requests(self.prompt_id_2_request_ids[prompt_id])
        except Exception as e:
            self.exception_queue.put(e)

File: roll/pipeline/base_worker.py (L382-382)

            self.response_call_back_fns[data.meta_info["request_id"]] = data.meta_info.pop("response_callback_fn")

File: roll/pipeline/base_worker.py (L386-390)

    def request_complete(self, data: DataProto):
        data.meta_info["eos_token_id"] = self.tokenizer.eos_token_id
        data.meta_info["pad_token_id"] = self.tokenizer.pad_token_id
        response_call_back_fn = self.response_call_back_fns.pop(data.meta_info["request_id"])
        self.response_callback_refs.append(response_call_back_fn(data))

File: roll/third_party/sglang/async_engine.py (L73-73)

        request_complete_callback(data=output_data)

Speculative decoding metrics with multiple samples per request

When parallel sampling is enabled (n > 1), each sample is treated as an independent request, and its speculative decoding metrics are computed separately.

Key mechanisms

1. Independent request IDs and state

With parallel sampling, each sample gets its own request ID: [1]

def _normalize_rid(self, num):
    if self.rid is None:
        self.rid = [uuid.uuid4().hex for _ in range(num)]
    elif isinstance(self.rid, str):
        new_rids = [f"{self.rid}_{i}" for i in range(num)]
        self.rid = new_rids
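
For illustration, the effect of this normalization with n = 3; this is a standalone re-implementation of the same logic (None and str cases only), not the library call itself:

```python
import uuid

def normalize_rid(rid, num):
    # Mirrors _normalize_rid: generate fresh ids, or expand a single id.
    if rid is None:
        return [uuid.uuid4().hex for _ in range(num)]
    if isinstance(rid, str):
        return [f"{rid}_{i}" for i in range(num)]
    return rid

print(normalize_rid("my-request", 3))
# ['my-request_0', 'my-request_1', 'my-request_2'] -> three independent Req objects
```
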
2. Independent metric tracking

Each sample (request) tracks its speculative decoding metrics in its own Req object: [2]

# Per-request speculative decoding metrics
self.spec_verify_ct = 0
self.spec_accepted_tokens = 0
3. Metrics computed separately

In _handle_batch_output, all requests (samples) are iterated, and the metrics are computed individually for each finished request: [3]

for i, rid in enumerate(recv_obj.rids):
    # ...
    state.finished = recv_obj.finished_reasons[i] is not None
    if state.finished:
        if self.server_args.speculative_algorithm:
            self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)

Execution flow

With a request for n = 3 samples, the flow between the client, the TokenizerManager, and the Scheduler is:

  1. The client sends a request with n = 3; generate_request(n=3) creates three independent requests rid_0, rid_1, rid_2.
  2. The three samples execute in parallel on the Scheduler.
  3. Sample 1 finishes first: its result is returned with finished_reasons[1] = "stop", and _calculate_spec_decoding_metrics(i=1) runs for it.
  4. Samples 0 and 2 keep running. When sample 0 finishes, finished_reasons[0] = "stop" and _calculate_spec_decoding_metrics(i=0) runs.
  5. When sample 2 finishes, finished_reasons[2] = "stop" and _calculate_spec_decoding_metrics(i=2) runs.

Metric calculation example

Suppose a request has 3 samples and speculative decoding runs as follows (the acceptance rates assume 4 draft tokens per verification step, i.e. speculative_num_draft_tokens = 5):

| Sample   | spec_verify_ct | spec_accepted_tokens | Finish order | Computed metric                    |
|----------|----------------|----------------------|--------------|------------------------------------|
| Sample 0 | 5              | 12                   | 2nd          | spec_accept_rate = 12/(5×4) = 0.60 |
| Sample 1 | 3              | 10                   | 1st          | spec_accept_rate = 10/(3×4) = 0.83 |
| Sample 2 | 7              | 14                   | 3rd          | spec_accept_rate = 14/(7×4) = 0.50 |

Each sample's metrics are computed independently and do not affect the others.
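
The per-sample numbers in the table can be reproduced with the same formula; a small sketch, again assuming speculative_num_draft_tokens = 5 (4 guessed tokens per verification step):

```python
# (spec_verify_ct, spec_accepted_tokens) per sample, as in the table above.
samples = {"sample0": (5, 12), "sample1": (3, 10), "sample2": (7, 14)}
num_guess_tokens = 5 - 1  # assumed speculative_num_draft_tokens - 1

for name, (verify_ct, accepted) in samples.items():
    rate = accepted / (verify_ct * num_guess_tokens)
    print(f"{name}: spec_accept_rate = {rate:.2f}")
# sample0: 0.60, sample1: 0.83, sample2: 0.50
```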

Notes

  • Independence: each sample has its own Req object and its own speculative decoding metrics
  • Asynchronous completion: samples can finish at different times; metrics are computed as soon as each one finishes
  • No sharing: samples do not share speculative decoding state; each keeps its own counters
  • Streaming output: in streaming mode, whichever sample finishes first returns its chunk with speculative decoding metrics first

Citations

File: python/sglang/srt/managers/io_struct.py (L460-474)

    def _normalize_rid(self, num):
        """Normalize request IDs for batch processing."""
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(num)]
        elif isinstance(self.rid, str):
            new_rids = [f"{self.rid}_{i}" for i in range(num)]
            self.rid = new_rids
        elif isinstance(self.rid, list):
            # Note: the length of rid shall be the same as the batch_size,
            # as the rid would be expanded for parallel sampling in tokenizer_manager
            if len(self.rid) != self.batch_size:
                raise ValueError(
                    "The specified rids length mismatch with the batch_size for batch processing."
                )
        else:

File: python/sglang/srt/managers/schedule_batch.py (L646-652)

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # The number of accepted tokens in speculative decoding for this request.
        # This is used to compute the acceptance rate and average acceptance length per request.
        self.spec_accepted_tokens = 0

File: python/sglang/srt/managers/tokenizer_manager.py (L1575-1664)

        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue

            # Build meta_info and return value
            meta_info = {
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[i],
                "prompt_tokens": recv_obj.prompt_tokens[i],
                "weight_version": self.server_args.weight_version,
                "total_retractions": recv_obj.retraction_counts[i],
            }

            if self.enable_metrics:
                self._add_metric_if_present(recv_obj, "queue_time", meta_info, i)
                self._add_metric_if_present(
                    recv_obj, "prefill_launch_delay", meta_info, i
                )
                self._add_metric_if_present(
                    recv_obj, "prefill_launch_latency", meta_info, i
                )

            if getattr(state.obj, "return_logprob", False):
                self.convert_logprob_style(
                    meta_info,
                    state,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    state.obj.return_text_in_logprobs
                    and not self.server_args.skip_tokenizer_init,
                    recv_obj,
                    i,
                )

            if not isinstance(recv_obj, BatchEmbeddingOutput):
                meta_info.update(
                    {
                        "completion_tokens": recv_obj.completion_tokens[i],
                        "cached_tokens": recv_obj.cached_tokens[i],
                    }
                )

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

            if isinstance(recv_obj, BatchStrOutput):
                state.text += recv_obj.output_strs[i]
                if state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids.copy()

                out_dict = {
                    "text": state.text,
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOutput):
                if self.server_args.stream_output and state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids.copy()

                out_dict = {
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOutput):
                raise NotImplementedError("BatchMultimodalOut not implemented")
            else:
                assert isinstance(recv_obj, BatchEmbeddingOutput)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[i] is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)