From Local Script to Production-Grade API: Wrapping bce-embedding-base_v1 with FastAPI to Put Your AI Capabilities Within Reach

[Free download] bce-embedding-base_v1 — project page: https://ai.gitcode.com/mirrors/maidalun1020/bce-embedding-base_v1

Introduction: Still Struggling to Ship Your Embedding Model?

Have you run into any of these problems: an embedding model you worked hard to train, stuck on the last mile of productionization? A script that runs fine locally but breaks down in production? Wanting to give your team an easy-to-use API but not knowing where to start? This article walks through solving these pain points by wrapping the bce-embedding-base_v1 model with FastAPI into a high-performance, extensible, production-grade API service.

By the end of this article, you will have:

  • A complete solution for exposing an embedding model as an API
  • A high-performance asynchronous API design and implementation
  • The monitoring, logging, and error-handling machinery production demands
  • Best practices for containerized deployment and performance tuning
  • Code and configuration examples you can reuse directly

1. Background and Technology Choices

1.1 The bce-embedding-base_v1 Model

bce-embedding-base_v1 is a high-performance bilingual and cross-lingual embedding model developed by NetEase Youdao. Built on the XLM-RoBERTa architecture, it offers:

  • Chinese-English bilingual and cross-lingual semantic representation
  • 768-dimensional output vectors, 279M model parameters
  • Optimization for RAG scenarios, with no need for hand-crafted instructions
  • Multiple calling conventions: Sentence Transformers, Transformers, and more
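To build intuition for how the 768-dimensional vectors are consumed downstream (e.g. in RAG retrieval), here is a minimal standalone sketch of cosine similarity over normalized vectors, using toy 3-dimensional vectors in place of real model output:

```python
import math

def normalize(vec):
    """Scale a vector to unit length, as normalize_embeddings=True does."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def cosine_similarity(a, b):
    """For unit vectors, cosine similarity reduces to the dot product."""
    return sum(x * y for x, y in zip(a, b))

# Toy vectors standing in for 768-dim model outputs.
query = normalize([0.2, 0.1, 0.9])
doc = normalize([0.25, 0.05, 0.95])
unrelated = normalize([0.9, -0.4, 0.1])

assert cosine_similarity(query, doc) > cosine_similarity(query, unrelated)
```

This is why the API below defaults to `NORMALIZE_EMBEDDINGS = True`: with unit vectors, ranking by similarity is a cheap dot product.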

1.2 Why FastAPI?

Framework   Performance   Ease of Use   Docs      Async Support    Type Hints
FastAPI     ★★★★★         ★★★★★         ★★★★★     Native           Strong
Flask       ★★★☆☆         ★★★★☆         ★★★★☆     Via extensions   Weak
Django      ★★★☆☆         ★★★☆☆         ★★★★★     Partial          Medium
Tornado     ★★★★☆         ★★☆☆☆         ★★★☆☆     Native           Medium

With its excellent performance, auto-generated API documentation, native async support, and type hints, FastAPI is an ideal choice for building high-performance APIs.

1.3 System Architecture

(The Mermaid architecture diagram from the original post is not reproduced here.)

2. Environment Setup and Dependencies

2.1 System Requirements

  • Python 3.8+
  • CUDA 11.0+ (recommended, for GPU acceleration)
  • At least 4 GB of RAM (needed to load the model)
  • Network access (to download the model weights)

2.2 Create a Virtual Environment

# Create a virtual environment
conda create --name bce-env python=3.10 -y
conda activate bce-env

# Or use venv
python -m venv bce-env
source bce-env/bin/activate  # Linux/Mac
bce-env\Scripts\activate     # Windows

2.3 Install Dependencies

# Core dependencies
pip install fastapi uvicorn python-multipart pydantic-settings

# Model dependencies
pip install torch transformers sentence-transformers BCEmbedding

# Performance dependencies
pip install uvloop httptools

# Monitoring and logging dependencies
pip install prometheus-client prometheus-fastapi-instrumentator python-json-logger

# Cache dependency
pip install redis

# Deployment dependency
pip install gunicorn

2.4 Download and Prepare the Model

# Clone the model repository
git clone https://gitcode.com/mirrors/maidalun1020/bce-embedding-base_v1
cd bce-embedding-base_v1

# Verify the model files are present
ls -l | grep -E "pytorch_model.bin|config.json|tokenizer.json"
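The ls check above can also be done programmatically before starting the service; a minimal sketch (the file list mirrors the grep pattern in the command above):

```python
from pathlib import Path
import tempfile

REQUIRED_FILES = ["pytorch_model.bin", "config.json", "tokenizer.json"]

def missing_model_files(model_dir: str) -> list:
    """Return the required model files absent from the given directory."""
    root = Path(model_dir)
    return [name for name in REQUIRED_FILES if not (root / name).exists()]

# Demo against a temporary directory containing only config.json.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "config.json").touch()
    missing = missing_model_files(tmp)

assert missing == ["pytorch_model.bin", "tokenizer.json"]
```

Running such a check at startup fails fast on an incomplete download instead of surfacing a confusing load error later.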

3. API Design and Implementation

3.1 Project Structure

bce-embedding-api/
├── app/
│   ├── __init__.py
│   ├── main.py             # FastAPI application entry point
│   ├── models/             # Data model definitions
│   │   ├── __init__.py
│   │   └── request.py      # Request/response models
│   ├── api/                # API routes
│   │   ├── __init__.py
│   │   ├── v1/             # API v1
│   │   │   ├── __init__.py
│   │   │   └── endpoints/
│   │   │       ├── __init__.py
│   │   │       └── embedding.py  # Embedding endpoints
│   ├── core/               # Core functionality
│   │   ├── __init__.py
│   │   ├── config.py       # Settings management
│   │   ├── model.py        # Model loading and management
│   │   ├── cache.py        # Cache management
│   │   ├── metrics.py      # Monitoring metrics
│   │   └── logging.py      # Logging configuration
│   └── utils/              # Utilities
│       ├── __init__.py
│       └── helpers.py      # Helper functions
├── tests/                  # Unit tests
├── Dockerfile              # Docker configuration
├── docker-compose.yml      # Docker Compose configuration
├── requirements.txt        # Dependency list
└── .env                    # Environment variables

3.2 Settings

First, create the settings module app/core/config.py:

from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    # API settings
    API_V1_STR: str = "/api/v1"
    PROJECT_NAME: str = "bce-embedding-api"
    DESCRIPTION: str = "FastAPI wrapper for bce-embedding-base_v1 model"
    VERSION: str = "1.0.0"
    
    # Model settings
    MODEL_PATH: str = "/data/web/disk1/git_repo/mirrors/maidalun1020/bce-embedding-base_v1"
    DEVICE: str = "cuda"  # "cuda" or "cpu"
    BATCH_SIZE: int = 32
    MAX_SEQ_LENGTH: int = 512
    NORMALIZE_EMBEDDINGS: bool = True
    
    # Cache settings
    USE_CACHE: bool = True
    REDIS_URL: str = "redis://localhost:6379/0"
    CACHE_TTL: int = 3600  # cache expiry in seconds
    
    # Concurrency settings
    WORKERS: int = 4
    MAX_TASKS_IN_PROGRESS: int = 100
    
    # Logging settings
    LOG_LEVEL: str = "INFO"
    
    class Config:
        case_sensitive = True
        env_file = ".env"

settings = Settings()
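pydantic-settings resolves each field from the process environment first, then the `.env` file, then the coded default above. That precedence can be sketched without the library (a simplified stand-in with hypothetical `BCE_`-prefixed names, not the real `Settings` class):

```python
import os

# Hypothetical names, prefixed to avoid clashing with real environment variables.
DEFAULTS = {"BCE_DEVICE": "cuda", "BCE_BATCH_SIZE": "32"}

def resolve_setting(name: str, dotenv: dict) -> str:
    """Environment variable beats .env entry, which beats the coded default."""
    if name in os.environ:
        return os.environ[name]
    if name in dotenv:
        return dotenv[name]
    return DEFAULTS[name]

os.environ.pop("BCE_DEVICE", None)       # clean slate for the demo
dotenv_file = {"BCE_DEVICE": "cpu"}      # as if parsed from .env
os.environ["BCE_BATCH_SIZE"] = "64"      # as if exported in the shell

assert resolve_setting("BCE_DEVICE", dotenv_file) == "cpu"     # .env overrides the default
assert resolve_setting("BCE_BATCH_SIZE", dotenv_file) == "64"  # environment overrides both
```

This is what lets the docker-compose file later override MODEL_PATH, DEVICE, and REDIS_URL without touching the code.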

3.3 Logging

Create the logging module app/core/logging.py:

import logging
import sys
from pythonjsonlogger import jsonlogger
from app.core.config import settings

def setup_logging() -> None:
    """Configure the logging system."""
    logger = logging.getLogger()
    
    # Remove any existing handlers
    if logger.hasHandlers():
        logger.handlers = []
    
    log_level = logging.getLevelName(settings.LOG_LEVEL.upper())
    
    # JSON-formatted handler on stdout
    handler = logging.StreamHandler(sys.stdout)
    formatter = jsonlogger.JsonFormatter(
        "%(asctime)s %(levelname)s %(name)s %(module)s %(funcName)s %(lineno)d %(message)s"
    )
    handler.setFormatter(formatter)
    
    logger.addHandler(handler)
    logger.setLevel(log_level)
    
    # Quiet down third-party libraries
    logging.getLogger("transformers").setLevel(logging.WARNING)
    logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
    logging.getLogger("fastapi").setLevel(logging.INFO)
    logging.getLogger("uvicorn").setLevel(logging.INFO)

3.4 Model Loading and Management

Create the model manager app/core/model.py:

import logging
from typing import List
import torch
from sentence_transformers import SentenceTransformer
from app.core.config import settings

logger = logging.getLogger(__name__)

class EmbeddingModel:
    """Wrapper around the embedding model."""
    _instance = None
    _model = None
    
    def __new__(cls):
        """Singleton."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    
    def __init__(self):
        """Load the model on first construction."""
        if self._model is None:
            self.load_model()
    
    def load_model(self) -> None:
        """Load the model."""
        logger.info(f"Loading model from {settings.MODEL_PATH}")
        try:
            self._model = SentenceTransformer(
                settings.MODEL_PATH,
                device=settings.DEVICE
            )
            self._model.max_seq_length = settings.MAX_SEQ_LENGTH
            logger.info(f"Model loaded successfully on {settings.DEVICE}")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise
    
    def encode(self, sentences: List[str], **kwargs) -> List[List[float]]:
        """
        Encode texts into vectors.
        
        Args:
            sentences: list of texts
            **kwargs: extra encode parameters
        
        Returns:
            list of vectors
        """
        if not self._model:
            self.load_model()
            
        params = {
            "normalize_embeddings": settings.NORMALIZE_EMBEDDINGS,
            "batch_size": settings.BATCH_SIZE,
            **kwargs
        }
        # max_seq_length is a model attribute, not an encode() argument
        max_seq_length = params.pop("max_seq_length", None)
        if max_seq_length is not None:
            self._model.max_seq_length = max_seq_length
        
        with torch.no_grad():
            embeddings = self._model.encode(sentences, **params)
        
        # Convert to plain lists
        return embeddings.tolist()

# Global model instance
embedding_model = EmbeddingModel()
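The `__new__` override above is a plain singleton, ensuring the multi-hundred-megabyte model is loaded only once per process. Stripped of the model-loading, the pattern looks like this (illustrative stand-in class):

```python
class Singleton:
    _instance = None

    def __new__(cls):
        # Reuse the one shared instance instead of allocating a new one.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

a = Singleton()
b = Singleton()
assert a is b  # every "construction" yields the same object
```

One caveat: `__init__` still runs on every call, which is exactly why `EmbeddingModel.__init__` guards on `self._model is None` before loading.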

3.5 Request and Response Models

Create the request/response models app/models/request.py (using pydantic v2 style, matching the pinned pydantic-settings 2.x):

from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing import List, Optional

class EmbeddingRequest(BaseModel):
    """Embedding request model."""
    sentences: List[str] = Field(..., min_length=1, max_length=1000, description="Texts to encode")
    normalize: Optional[bool] = Field(None, description="Whether to normalize the vectors")
    batch_size: Optional[int] = Field(None, ge=1, le=1024, description="Batch size")
    max_seq_length: Optional[int] = Field(None, ge=16, le=512, description="Maximum sequence length")
    
    @field_validator('sentences')
    @classmethod
    def check_sentences_length(cls, v):
        """Reject over-long sentences."""
        for i, s in enumerate(v):
            if len(s) > 10000:
                raise ValueError(f"Sentence at index {i} is too long (max 10000 characters)")
        return v

class EmbeddingResponse(BaseModel):
    """Embedding response model."""
    embeddings: List[List[float]] = Field(..., description="One vector per input text")
    model: str = Field(..., description="Model name used")
    dimensions: int = Field(..., description="Vector dimensionality")
    processing_time: float = Field(..., description="Processing time in seconds")
    normalized: bool = Field(..., description="Whether the vectors are normalized")

class HealthCheckResponse(BaseModel):
    """Health-check response model."""
    model_config = ConfigDict(protected_namespaces=())  # allow the model_loaded field name
    status: str = Field(..., description="Service status")
    model_loaded: bool = Field(..., description="Whether the model loaded successfully")
    version: str = Field(..., description="API version")
    uptime: float = Field(..., description="Service uptime in seconds")
    device: str = Field(..., description="Device the model runs on")
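The per-sentence length check in `check_sentences_length` can be exercised on its own; a minimal sketch of the same rule outside pydantic:

```python
MAX_CHARS = 10000

def check_sentences_length(sentences):
    """Raise ValueError if any sentence exceeds the per-sentence limit."""
    for i, s in enumerate(sentences):
        if len(s) > MAX_CHARS:
            raise ValueError(f"Sentence at index {i} is too long (max {MAX_CHARS} characters)")
    return sentences

assert check_sentences_length(["short", "also short"]) == ["short", "also short"]

try:
    check_sentences_length(["x" * (MAX_CHARS + 1)])
    raised = False
except ValueError:
    raised = True
assert raised
```

Validating input size here, before it reaches the model, keeps a single oversized request from monopolizing the GPU.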

3.6 API Routes

Create the routes app/api/v1/endpoints/embedding.py (note that get_uptime is defined in app/core/metrics.py, so it is imported from there):

import logging
import time
from fastapi import APIRouter, HTTPException, status
from app.models.request import EmbeddingRequest, EmbeddingResponse, HealthCheckResponse
from app.core.model import embedding_model
from app.core.config import settings
from app.core.metrics import EMBEDDING_COUNT, EMBEDDING_LATENCY, get_uptime

logger = logging.getLogger(__name__)
router = APIRouter()

@router.post("/embeddings", response_model=EmbeddingResponse, summary="Generate text embeddings")
async def create_embeddings(request: EmbeddingRequest):
    """
    Encode a list of input texts into dense vectors.
    
    - Supports batch processing of multiple texts
    - Normalization, batch size, etc. are customizable
    - Returns vectors in the same order as the inputs
    """
    start_time = time.time()
    logger.info(f"Received embedding request with {len(request.sentences)} sentences")
    
    # Update metrics
    EMBEDDING_COUNT.inc(len(request.sentences))
    
    try:
        # Collect per-request overrides
        params = {}
        if request.normalize is not None:
            params["normalize_embeddings"] = request.normalize
        if request.batch_size is not None:
            params["batch_size"] = request.batch_size
        if request.max_seq_length is not None:
            params["max_seq_length"] = request.max_seq_length
        
        # Run the model
        embeddings = embedding_model.encode(request.sentences, **params)
        
        # Processing time
        processing_time = time.time() - start_time
        
        # Latency metric
        EMBEDDING_LATENCY.observe(processing_time)
        
        # Build the response
        response = EmbeddingResponse(
            embeddings=embeddings,
            model="bce-embedding-base_v1",
            dimensions=len(embeddings[0]) if embeddings else 0,
            processing_time=round(processing_time, 4),
            normalized=request.normalize if request.normalize is not None else settings.NORMALIZE_EMBEDDINGS
        )
        
        logger.info(f"Successfully processed embedding request in {processing_time:.4f}s")
        return response
        
    except Exception as e:
        logger.error(f"Error processing embedding request: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating embeddings: {str(e)}"
        )

@router.get("/health", response_model=HealthCheckResponse, summary="Health check")
async def health_check():
    """
    Check service health.
    
    Returns service status, model load state, version, and uptime.
    """
    try:
        # Is the model loaded?
        model_loaded = hasattr(embedding_model, '_model') and embedding_model._model is not None
        
        # Build the response
        return HealthCheckResponse(
            status="healthy",
            model_loaded=model_loaded,
            version=settings.VERSION,
            uptime=round(get_uptime(), 2),
            device=settings.DEVICE
        )
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Service unhealthy: {str(e)}"
        )

3.7 Application Entry Point

Create the application entry point app/main.py:

import logging
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_swagger_ui_html
from prometheus_fastapi_instrumentator import Instrumentator

from app.core.config import settings
from app.core.logging import setup_logging
from app.core.model import embedding_model
from app.api.v1 import api_router as api_v1_router

# Configure logging
setup_logging()
logger = logging.getLogger(__name__)

# Create the FastAPI application
app = FastAPI(
    title=settings.PROJECT_NAME,
    description=settings.DESCRIPTION,
    version=settings.VERSION,
    docs_url=None,  # disable the default docs; a custom page is served below
    redoc_url=None
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict to specific domains in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# GZip compression middleware
app.add_middleware(
    GZipMiddleware,
    minimum_size=1000,
    compresslevel=5
)

# Custom Swagger UI
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return get_swagger_ui_html(
        openapi_url=f"{settings.API_V1_STR}/openapi.json",
        title=f"{settings.PROJECT_NAME} - API Docs",
        oauth2_redirect_url="/docs/oauth2-redirect",
        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.15.5/swagger-ui-bundle.js",
        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.15.5/swagger-ui.css",
    )

# Register the API routes
app.include_router(api_v1_router, prefix=settings.API_V1_STR)

# Prometheus instrumentation
instrumentator = Instrumentator().instrument(app)

@app.on_event("startup")
async def startup_event():
    """Startup hook."""
    logger.info(f"Starting {settings.PROJECT_NAME} v{settings.VERSION}")
    logger.info(f"Model path: {settings.MODEL_PATH}")
    logger.info(f"Using device: {settings.DEVICE}")
    
    # Load the model at startup
    try:
        if embedding_model._model is None:
            embedding_model.load_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model during startup: {str(e)}")
    
    # Expose the metrics endpoint
    instrumentator.expose(app, endpoint="/metrics")

@app.on_event("shutdown")
async def shutdown_event():
    """Shutdown hook."""
    logger.info(f"Shutting down {settings.PROJECT_NAME}")

4. Caching and Performance Optimization

4.1 Redis Cache

Create the cache manager app/core/cache.py:

import logging
import hashlib
import json
from typing import Optional, List
from redis import Redis, RedisError
from app.core.config import settings

logger = logging.getLogger(__name__)

class CacheManager:
    """Cache manager."""
    def __init__(self):
        """Initialize the cache manager."""
        self.enabled = settings.USE_CACHE
        self.redis = None
        
        if self.enabled:
            try:
                self.redis = Redis.from_url(settings.REDIS_URL)
                # Test the connection
                self.redis.ping()
                logger.info("Redis cache connected successfully")
            except RedisError as e:
                logger.error(f"Failed to connect to Redis: {str(e)}")
                self.enabled = False
    
    def _generate_key(self, text: str, **params) -> str:
        """Build a cache key."""
        # Sort the parameters so the key is deterministic
        sorted_params = sorted(params.items())
        key_data = f"{text}||{json.dumps(sorted_params, sort_keys=True)}"
        return hashlib.md5(key_data.encode()).hexdigest()
    
    def get_embeddings(self, sentences: List[str], **params) -> List[Optional[List[float]]]:
        """
        Fetch embeddings for multiple sentences from the cache.
        
        Args:
            sentences: list of sentences
            **params: encode parameters
            
        Returns:
            Cached results; None marks a miss.
        """
        if not self.enabled or not self.redis:
            return [None] * len(sentences)
            
        try:
            # Build all keys
            keys = [self._generate_key(s, **params) for s in sentences]
            
            # Batch fetch
            results = self.redis.mget(keys)
            
            # Parse the results
            embeddings = []
            for result in results:
                if result:
                    embeddings.append(json.loads(result))
                else:
                    embeddings.append(None)
                    
            return embeddings
            
        except RedisError as e:
            logger.warning(f"Redis error during get: {str(e)}")
            return [None] * len(sentences)
    
    def set_embeddings(self, sentences: List[str], embeddings: List[List[float]], **params) -> None:
        """
        Store embeddings for multiple sentences in the cache.
        
        Args:
            sentences: list of sentences
            embeddings: corresponding embeddings
            **params: encode parameters
        """
        if not self.enabled or not self.redis or len(sentences) != len(embeddings):
            return
            
        try:
            # Batch the writes through a pipeline
            pipe = self.redis.pipeline()
            
            for s, emb in zip(sentences, embeddings):
                key = self._generate_key(s, **params)
                pipe.setex(
                    key, 
                    settings.CACHE_TTL, 
                    json.dumps(emb)
                )
                
            # Execute the pipeline
            pipe.execute()
            
        except RedisError as e:
            logger.warning(f"Redis error during set: {str(e)}")

# Global cache instance
cache_manager = CacheManager()
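The `_generate_key` scheme hashes the text together with the sorted encode parameters, so the same text with the same parameters always maps to the same Redis key regardless of keyword order; a standalone sketch of the same scheme:

```python
import hashlib
import json

def generate_key(text: str, **params) -> str:
    """MD5 over text plus sorted params: deterministic and order-independent."""
    sorted_params = sorted(params.items())
    key_data = f"{text}||{json.dumps(sorted_params, sort_keys=True)}"
    return hashlib.md5(key_data.encode()).hexdigest()

k1 = generate_key("hello", normalize_embeddings=True, batch_size=32)
k2 = generate_key("hello", batch_size=32, normalize_embeddings=True)
k3 = generate_key("hello", batch_size=64, normalize_embeddings=True)

assert k1 == k2       # keyword order does not matter
assert k1 != k3       # changed parameters change the key
assert len(k1) == 32  # hex MD5 digest
```

MD5 is fine here because the key only needs to be collision-resistant enough for a cache; it carries no security guarantee.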

4.2 Wiring the Cache into the API

Modify the encode method in app/core/model.py (and import cache_manager from app.core.cache):

def encode(self, sentences: List[str], **kwargs) -> List[List[float]]:
    """
    Encode texts into vectors, serving repeats from the cache.
    
    Args:
        sentences: list of texts
        **kwargs: extra encode parameters
    
    Returns:
        list of vectors
    """
    if not self._model:
        self.load_model()
        
    # Assemble the parameters
    params = {
        "normalize_embeddings": settings.NORMALIZE_EMBEDDINGS,
        "batch_size": settings.BATCH_SIZE,
        **kwargs
    }
    # max_seq_length is a model attribute, not an encode() argument
    max_seq_length = params.pop("max_seq_length", None)
    if max_seq_length is not None:
        self._model.max_seq_length = max_seq_length
    
    # Key the cache only on parameters that change the output
    # (batch_size affects throughput, not the vectors themselves)
    cache_params = {
        "normalize_embeddings": params["normalize_embeddings"],
        "max_seq_length": self._model.max_seq_length,
    }
    
    # Look up the cache
    cache_results = cache_manager.get_embeddings(sentences, **cache_params)
    
    # Collect the indices and sentences of the cache misses
    miss_indices = []
    miss_sentences = []
    
    for i, (sentence, result) in enumerate(zip(sentences, cache_results)):
        if result is None:
            miss_indices.append(i)
            miss_sentences.append(sentence)
    
    # All hits: return directly
    if not miss_indices:
        return cache_results
    
    # Encode only the misses
    with torch.no_grad():
        miss_embeddings = self._model.encode(miss_sentences, **params)
    
    # Convert to plain lists
    miss_embeddings = miss_embeddings.tolist()
    
    # Populate the cache
    cache_manager.set_embeddings(miss_sentences, miss_embeddings, **cache_params)
    
    # Merge hits and misses back into input order
    final_embeddings = []
    miss_idx = 0
    
    for i, result in enumerate(cache_results):
        if result is not None:
            final_embeddings.append(result)
        else:
            final_embeddings.append(miss_embeddings[miss_idx])
            miss_idx += 1
    
    return final_embeddings
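The hit/miss bookkeeping above can be isolated and tested without Redis or the model; a sketch with a dict cache and a fake encoder standing in:

```python
def encode_with_cache(sentences, cache, encoder):
    """Serve hits from `cache`, encode only the misses, merge in input order."""
    results = [cache.get(s) for s in sentences]
    miss_indices = [i for i, r in enumerate(results) if r is None]
    if miss_indices:
        miss_embeddings = encoder([sentences[i] for i in miss_indices])
        for i, emb in zip(miss_indices, miss_embeddings):
            cache[sentences[i]] = emb
            results[i] = emb
    return results

calls = []
def fake_encoder(batch):
    calls.append(list(batch))
    return [[float(len(s))] for s in batch]  # toy 1-dim "embedding"

cache = {"a": [1.0]}
out = encode_with_cache(["a", "bb", "ccc"], cache, fake_encoder)

assert out == [[1.0], [2.0], [3.0]]  # input order preserved, "a" served from cache
assert calls == [["bb", "ccc"]]      # only the misses reach the model
```

The key property is the last assertion: the model is called once, with only the uncached sentences, yet the caller still sees one vector per input in the original order.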

4.3 性能优化策略

1.** 异步处理 :利用FastAPI的异步特性处理并发请求 2. 批处理优化 :合理设置批处理大小,充分利用GPU 3. 缓存热点数据 :使用Redis缓存频繁请求的文本向量 4. 模型优化 **:

  • 使用半精度推理(FP16)
  • 启用模型并行(适用于大型模型)
  • 考虑使用TensorRT等优化工具 5.** API优化 **:
  • 启用GZip压缩
  • 合理设置超时时间
  • 实现请求限流
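The batch-tuning point above boils down to chunking the input so each forward pass matches the batch size; a minimal sketch of the split:

```python
def chunk(items, batch_size):
    """Split a list into consecutive batches of at most batch_size items."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

batches = chunk(list(range(7)), 3)
assert batches == [[0, 1, 2], [3, 4, 5], [6]]
assert sum(len(b) for b in batches) == 7  # nothing is dropped
```

SentenceTransformer's `encode` already performs this internally via its `batch_size` argument; the sketch only makes explicit what tuning that knob controls.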

5. Monitoring and Logging

5.1 Metrics

Create the metrics module app/core/metrics.py (the middleware subclasses Starlette's BaseHTTPMiddleware so that app.add_middleware works):

from prometheus_client import Counter, Histogram, Gauge
from starlette.middleware.base import BaseHTTPMiddleware
import time

# Request count
REQUEST_COUNT = Counter(
    "api_requests_total", 
    "Total number of API requests",
    ["endpoint", "method", "status_code"]
)

# Request latency
REQUEST_LATENCY = Histogram(
    "api_request_latency_seconds", 
    "API request latency in seconds",
    ["endpoint", "method"]
)

# Embedding count
EMBEDDING_COUNT = Counter(
    "embedding_total", 
    "Total number of embeddings generated",
)

# Embedding latency
EMBEDDING_LATENCY = Histogram(
    "embedding_latency_seconds", 
    "Embedding generation latency in seconds",
)

# Cache hit rate
CACHE_HIT_RATE = Gauge(
    "embedding_cache_hit_rate", 
    "Embedding cache hit rate",
)

# Service start time
start_time = time.time()

def get_uptime() -> float:
    """Return the service uptime in seconds."""
    return time.time() - start_time

class MetricsMiddleware(BaseHTTPMiddleware):
    """Per-request metrics middleware."""
    async def dispatch(self, request, call_next):
        """Handle the request and record metrics."""
        # Record the request start time
        request_start = time.time()
        
        # Endpoint and method
        endpoint = request.url.path
        method = request.method
        
        # Handle the request
        response = await call_next(request)
        
        # Count the request
        REQUEST_COUNT.labels(
            endpoint=endpoint,
            method=method,
            status_code=response.status_code
        ).inc()
        
        # Record the latency
        latency = time.time() - request_start
        REQUEST_LATENCY.labels(
            endpoint=endpoint,
            method=method
        ).observe(latency)
        
        return response
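The `CACHE_HIT_RATE` gauge above still needs a value to be set to; a sketch of the running computation it would report (hypothetical tracker, not part of the file above):

```python
class HitRateTracker:
    """Running cache hit-rate, the value a gauge like CACHE_HIT_RATE would be set to."""
    def __init__(self):
        self.hits = 0
        self.total = 0

    def record(self, hit: bool) -> None:
        self.total += 1
        if hit:
            self.hits += 1

    def rate(self) -> float:
        # Guard the empty case so a fresh tracker never divides by zero.
        return self.hits / self.total if self.total else 0.0

tracker = HitRateTracker()
for hit in [True, True, False, True]:
    tracker.record(hit)

assert tracker.rate() == 0.75
assert HitRateTracker().rate() == 0.0
```

Calling `tracker.record(...)` per sentence in the cache lookup and `CACHE_HIT_RATE.set(tracker.rate())` afterward would complete the wiring.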

5.2 Wiring Metrics into the Application

Modify app/main.py to add the metrics middleware:

# Add to the imports
from app.core.metrics import MetricsMiddleware

# Add after the app is created
app.add_middleware(MetricsMiddleware)

6. Deployment and Operations

6.1 Docker

Create the Dockerfile:

FROM python:3.10-slim

# Working directory
WORKDIR /app

# Environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=off \
    PIP_DISABLE_PIP_VERSION_CHECK=on

# System dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    git \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Copy the dependency file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --upgrade pip && \
    pip install -r requirements.txt

# Create a non-root user
COPY . .

# Non-root user
RUN useradd -m appuser && chown -R appuser:appuser /app
USER appuser

# Expose the port
EXPOSE 8000

# Start command
CMD ["gunicorn", "app.main:app", "--workers", "4", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000"]

Create docker-compose.yml:

version: '3.8'

services:
  api:
    build: .
    restart: always
    ports:
      - "8000:8000"
    environment:
      - MODEL_PATH=/models/bce-embedding-base_v1
      - DEVICE=cuda
      - BATCH_SIZE=32
      - USE_CACHE=True
      - REDIS_URL=redis://redis:6379/0
      - LOG_LEVEL=INFO
    volumes:
      - ./models:/models
    depends_on:
      - redis
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  redis:
    image: redis:7-alpine
    restart: always
    ports:
      - "6379:6379"
    volumes:
      - redis-data:/data
    command: redis-server --maxmemory 4gb --maxmemory-policy allkeys-lru

volumes:
  redis-data:

6.2 Starting and Stopping the Service

# Build the image
docker-compose build

# Start the services
docker-compose up -d

# Tail the logs
docker-compose logs -f

# Stop the services
docker-compose down

# Stop the services and delete the volumes
docker-compose down -v

6.3 Load Testing

Use Locust (pip install locust) for load testing:

# locustfile.py
from locust import HttpUser, task, between
import random

class EmbeddingUser(HttpUser):
    wait_time = between(0.1, 0.5)
    
    sentences = [
        "This is a test sentence",
        "FastAPI is a high-performance API framework",
        "bce-embedding-base_v1 is an excellent embedding model",
        "Load testing matters for API development",
        "Redis can cache frequently accessed data",
        # Add more test sentences...
    ]
    
    @task(1)
    def single_embedding(self):
        """Test single-sentence embedding."""
        sentence = random.choice(self.sentences)
        self.client.post(
            "/api/v1/embeddings",
            json={"sentences": [sentence]}
        )
    
    @task(2)
    def batch_embedding(self):
        """Test batched embedding."""
        # Sample size must not exceed the pool, or random.sample raises
        batch_size = random.randint(2, len(self.sentences))
        batch_sentences = random.sample(self.sentences, batch_size)
        self.client.post(
            "/api/v1/embeddings",
            json={"sentences": batch_sentences}
        )
    
    @task(1)
    def health_check(self):
        """Test the health endpoint."""
        self.client.get("/api/v1/health")

Start the load test:

locust -f locustfile.py --host http://localhost:8000

7. Error Handling and Recovery

7.1 Error Handling

Add global exception handlers to app/main.py:

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler."""
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"detail": "An unexpected error occurred"}
    )

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """HTTP exception handler."""
    logger.warning(f"HTTP exception: {exc.status_code} - {exc.detail}")
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail}
    )

7.2 Recovery Strategies

1. Automatic model reload: add a reload path in app/core/model.py

def encode(self, sentences: List[str], **kwargs) -> List[List[float]]:
    """Encode texts into vectors."""
    try:
        # Attempt to encode
        # ... existing code ...
    except RuntimeError as e:
        if "out of memory" in str(e):
            logger.error(f"GPU out of memory error: {str(e)}")
            # Free cached GPU memory
            torch.cuda.empty_cache()
            
            # Reload the model
            logger.info("Reloading model due to OOM error")
            self.load_model()
            
            # Retry the encode
            return self.encode(sentences, **kwargs)
        else:
            raise

2. Service degradation: shed load automatically when the system is under pressure

# Add to app/core/config.py
MAX_QUEUE_SIZE: int = 1000  # maximum request queue size
HIGH_LOAD_THRESHOLD: float = 0.8  # high-load threshold (CPU utilization)

# Add to the embeddings handler
import psutil

@router.post("/embeddings")
async def create_embeddings(request: EmbeddingRequest):
    # Check the system load
    cpu_usage = psutil.cpu_percent(interval=0.1)
    
    if cpu_usage > settings.HIGH_LOAD_THRESHOLD * 100:
        logger.warning(f"High system load detected: {cpu_usage}%")
        # Automatically shrink the batch size
        if request.batch_size is None:
            request.batch_size = max(1, settings.BATCH_SIZE // 2)
            logger.info(f"Reduced batch size to {request.batch_size} due to high load")
    
    # ... existing code ...
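The adjustment rule in the degradation snippet can be factored into a pure function and tested without psutil (a sketch mirroring the logic above):

```python
def adjusted_batch_size(requested, default, cpu_usage, threshold=0.8):
    """Halve the default batch size under high load, unless the client pinned one."""
    if requested is not None:
        return requested                 # explicit client choice wins
    if cpu_usage > threshold * 100:
        return max(1, default // 2)      # degrade under load, never below 1
    return default

assert adjusted_batch_size(None, 32, cpu_usage=95.0) == 16  # degraded
assert adjusted_batch_size(None, 32, cpu_usage=40.0) == 32  # normal load
assert adjusted_batch_size(8, 32, cpu_usage=95.0) == 8      # client override kept
assert adjusted_batch_size(None, 1, cpu_usage=99.0) == 1    # floor of 1
```

Keeping the policy in a pure function like this makes the degradation behavior unit-testable, independent of the live CPU reading.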

8. Conclusion and Outlook

8.1 What We Built

This article walked through wrapping the bce-embedding-base_v1 model into a high-performance, production-grade API service with FastAPI. The main results:

1. A complete API service: the full path from model loading to endpoint design
2. Enterprise features: caching, monitoring, logging, and error handling for production
3. Containerized deployment: convenient rollout with Docker and Docker Compose
4. Performance optimization: batching, caching, and async handling for higher throughput

8.2 Future Work

1. Multi-model support: load several model versions or types side by side
2. A/B testing: compare new models against the incumbent in production
3. Autoscaling: scale on load with Kubernetes
4. Quantization: INT8 quantization to cut the model's memory footprint further
5. Distributed inference: multi-node inference for higher throughput

8.3 Closing

With this approach, bce-embedding-base_v1 goes from a local script to an enterprise-grade API service: the engineering hurdles of model deployment are solved, and a series of optimizations keeps the service fast and reliable. The same recipe applies beyond embedding models to the API deployment of other NLP models.

Appendix: Full Code Listing

requirements.txt

fastapi==0.104.1
uvicorn==0.24.0
python-multipart==0.0.6
pydantic-settings==2.0.3
torch==2.1.0
transformers==4.35.0
sentence-transformers==2.2.2
BCEmbedding==0.1.1
redis==4.6.0
prometheus-client==0.17.1
prometheus-fastapi-instrumentator==6.1.0
python-json-logger==2.0.7
psutil==5.9.6
gunicorn==21.2.0
httptools==0.6.0
uvloop==0.19.0
Full Directory Structure

bce-embedding-api/
├── app/
│   ├── __init__.py
│   ├── main.py
│   ├── models/
│   │   ├── __init__.py
│   │   └── request.py
│   ├── api/
│   │   ├── __init__.py
│   │   └── v1/
│   │       ├── __init__.py
│   │       └── endpoints/
│   │           ├── __init__.py
│   │           └── embedding.py
│   ├── core/
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── model.py
│   │   ├── cache.py
│   │   ├── metrics.py
│   │   └── logging.py
│   └── utils/
│       ├── __init__.py
│       └── helpers.py
├── tests/
├── Dockerfile
├── docker-compose.yml
├── requirements.txt
└── .env

With the steps above, we have built a fully featured, high-performance embedding API service that gives enterprise applications strong semantic representation capabilities. Whether you are building a RAG system, semantic search, or a recommender, this API provides reliable and efficient vector generation.


Disclosure: parts of this article were produced with AI assistance (AIGC) and are for reference only.
