FastAPI在ChatGLM的使用代码(Open AI API格式)

#对于初学者FastAPI的使用有点晦涩难懂

首先导入模块

# coding=utf-8
# Implements API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for documents.


import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel
from sse_starlette.sse import ServerSentEvent, EventSourceResponse

1. 导入和依赖


首先就是各种导包 注意这里
from pydantic import BaseModel, Field很重要 没有基础的记得去看一下他的基本概念 不然很容易懵
Pydantic 是一个用于数据验证和数据解析的 Python 库,它基于 Python 的类型注解(type hints),通过自动生成模型的方式,确保输入数据符合预期的类型和结构。Pydantic 主要用于快速、可靠地验证数据输入,尤其在 Web 开发和数据科学中非常有用。
基本示例:

from pydantic import BaseModel, Field
from typing import List, Optional

class Item(BaseModel):
    name: str
    description: Optional[str] = None
    price: float
    tags: List[str] = []

# 创建一个 Item 对象
item = Item(name="Apple", price=1.2, tags=["fruit", "fresh"])

# 输出对象
print(item)
# 输出: name='Apple' description=None price=1.2 tags=['fruit', 'fresh']

# 验证数据类型
try:
    invalid_item = Item(name="Banana", price="free", tags="fruit")
except ValueError as e:
    print(e)


#BaseModel:所有 Pydantic 模型需要继承 BaseModel,它提供了数据验证和解析的功能。

#Field:你可以使用 Field 来为字段指定默认值、额外的验证规则等。

#Optional:表示字段是可选的,可以为 None。

#List:表示字段是一个列表,包含多个元素。

 2. 生命周期管理

@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
  • 作用: 定义 FastAPI 应用的生命周期。
  • yield 前: 可以放置启动时的初始化逻辑(这里为空)。
  • yield 后: 应用关闭时清理 GPU 内存,确保资源释放。

3. FastAPI 应用初始化 

app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  • FastAPI(lifespan=lifespan): 创建 FastAPI 实例,并绑定生命周期函数。
  • CORSMiddleware: 添加跨域资源共享(CORS)中间件,允许所有来源、方法和头部的请求,便于前端调用。

4. 数据模型定义

这些模型使用 Pydantic 定义,用于验证和序列化请求/响应数据。

模型卡片和模型列表

class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None

class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []
  • ModelCard: 表示单个模型的元数据。
  • ModelList: 表示模型列表,用于 /v1/models 接口。

消息和请求

class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str

class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False
  • ChatMessage: 表示单条消息,包含角色和内容。
  • DeltaMessage: 用于流式响应,表示增量更新。
  • ChatCompletionRequest: 用户请求的数据结构,包含模型名称、消息列表和可选参数。

响应  

class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]

class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]

class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
  • ChatCompletionResponseChoice: 非流式响应的选项。
  • ChatCompletionResponseStreamChoice: 流式响应的选项。
  • ChatCompletionResponse: 统一的响应结构。

5. API 端点

获取模型列表

@app.get("/v1/models", response_model=ModelList)
async def list_models():
    global model_args
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])
  • 路径: /v1/models
  • 作用: 返回支持的模型列表(这里硬编码为 gpt-3.5-turbo)。
  • 问题: 代码中 model_args 未定义,可能是一个遗漏。

创建聊天完成

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    prev_messages = request.messages[:-1]
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i+1].content])

    if request.stream:
        generate = predict(query, history, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response, _ = model.chat(tokenizer, query, history=history)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
  • 路径: /v1/chat/completions
  • 逻辑:
    1. 验证最后一条消息是否为用户消息。
    2. 提取用户查询(query)。
    3. 处理系统消息(若存在,拼接到 query)。
    4. 构建历史对话(成对的 user-assistant 消息)。
    5. 根据 stream 参数选择流式或非流式响应。
    • 非流式: 调用 model.chat 获取完整回答。
    • 流式: 调用 predict 函数,返回 SSE 流。

6. 流式预测函数

async def predict(query: str, history: List[List[str]], model_id: str):
    global model, tokenizer

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    current_length = 0
    for new_response, _ in model.stream_chat(tokenizer, query, history):
        if len(new_response) == current_length:
            continue
        new_text = new_response[current_length:]
        current_length = len(new_response)
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield '[DONE]'
  • 作用: 实现流式响应,逐步生成并返回模型输出。
  • 逻辑:
    1. 发送初始 chunk(角色信息)。
    2. 使用 model.stream_chat 迭代生成回答,每次只发送新增部分(new_text)。
    3. 结束时发送 finish_reason="stop" 和 [DONE]。

7.主程序

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
  • 加载模型: 从 Hugging Face 加载 ChatGLM2-6B,并移至 GPU。
  • 运行服务: 在 0.0.0.0:8000 启动 API。

完整示例

运行代码
  1. 确保安装依赖:
    pip install fastapi uvicorn torch transformers sse-starlette pydantic

    保存代码为 openai_api.py,然后运行:

  2. ​​​​​​​python openai_api.py

  3. 服务启动后,访问 http://localhost:8000/docs 查看 Swagger UI。
非流式请求

使用 curl 发送请求:

​​​​​​​curl -X POST "http://localhost:8000/v1/chat/completions" \ 
-H "Content-Type: application/json" \ 
-d '{ "model": "gpt-3.5-turbo", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello, how are you?"} ], "stream": false }'

响应(假设模型返回 "I'm fine, thanks!"):

​​​​​​​{ "model": "gpt-3.5-turbo", "object": "chat.completion", "choices": [ { "index": 0, "message": { "role": "assistant", "content": "I'm fine, thanks!" }, "finish_reason": "stop" } ], "created": 1711468800 }
流式请求
​​​​​​​curl -X POST "http://localhost:8000/v1/chat/completions" \
 -H "Content-Type: application/json" \
 -d '{ "model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a story"}], "stream": true }'

响应

{"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}],"object":"chat.completion.chunk"} {"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Once"},"finish_reason":null}],"object":"chat.completion.chunk"} {"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":" upon"},"finish_reason":null}],"object":"chat.completion.chunk"} {"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":" a"},"finish_reason":null}],"object":"chat.completion.chunk"} {"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":" time"},"finish_reason":null}],"object":"chat.completion.chunk"} {"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"object":"chat.completion.chunk"} [DONE]


细节补充

  1. 历史对话处理: 只支持成对的 user-assistant 消息,奇数条历史消息会被忽略。
  2. 参数限制: temperature、top_p 等参数定义了但未使用,需在 model.chat 中实现。
  3. 模型加载: 单 GPU 加载,多 GPU 需启用注释部分的代码。
  4. 错误处理: 仅验证了最后一条消息的角色,其他异常未全面处理。

问题:

1.最后的 yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))已经没有信息了 为什么还需要返回呢 返回的是什么呢?

这个 yield 的作用是通知客户端流式响应的结束。在流式传输(SSE)中:

  • 数据是分块发送的,每块代表一部分内容(比如文本的增量)。
  • 客户端需要知道什么时候整个响应完成了,而不是简单地停止接收数据。
  • 通过发送一个带有 finish_reason="stop" 的 chunk,客户端可以明确知道模型已经生成了所有内容,不会再有后续数据。

如果没有这个结束信号:

  • 客户端可能会认为连接中断或服务未完成,等待更多数据。

2.'[DONE]' 是一个特殊符号吗 我可以写成'[OK]'吗 ?

  • '[DONE]' 不是编程语言中的特殊符号,而是开发者定义的一个约定信号
  • 在 SSE(Server-Sent Events)中,服务端可以通过发送特定字符串来标记流的结束。
  • 在 OpenAI 的流式响应实现中,'[DONE]' 被用作流的结束标志(参考 OpenAI 的流式响应示例)。客户端(比如前端代码或工具)会识别这个字符串,知道数据流已完成。
  • 它本质上是一个自定义的标记,告诉客户端“所有数据都发送完了,可以关闭连接或停止处理”。
  • 虽然可以更改 但是不建议 因为这是OpenAI 的规范编写的。

3.for new_response, _ in model.stream_chat(tokenizer, query, history): if len(new_response) == current_length: continue 这里的continue是什么意思 是进入下一个段循环吗 这样的话不会进入死循环吗 为什么不是break

不会进入死循环,原因如下:

 model.stream_chat 是有限的生成器

model.stream_chat 会逐步生成文本,直到完成整个响应(比如 "Hello world" 生成完毕)。

如果生成结束,for 循环会自然退出,不会无限循环。

代码引用自:https://github.com/THUDM/ChatGLM2-6B/blob/main/openai_api.py

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值