FastAPI开发AI应用教程:新增用户历史消息

本教程通过前端会话 ID 管理、后端历史消息接口和流式对话上下文传递三个核心技术,实现了支持多助手切换和历史记录持久化的 AI 聊天应用。​

本文将深入介绍如何在 FastAPI AI 聊天应用中实现用户历史消息功能,当用户切换助手,刷新页面时,都可以保留当前会话历史消息。

图片

本项目已经开源至 Github,项目地址:https://github.com/wayn111/fastapi-ai-chat-demo

温馨提示:本文全文约一万字,看完约需 15 分钟。

文章概述

重点讲解每个助手区分 sessionid、获取历史消息接口以及发送消息时携带上下文信息的核心技术实现。通过本教程,你将掌握构建智能聊天应用中消息持久化和上下文管理的关键技术。

核心功能

  • 多助手会话隔离:每个 AI 助手(智能助手、AI 老师、编程专家)都有独立的会话历史
  • 智能会话管理:自动生成和管理 sessionid,确保会话的唯一性和持久性
  • 历史消息加载:快速加载和展示用户的历史对话记录
  • 上下文传递:发送消息时自动携带历史上下文,保持对话连贯性
  • 数据持久化:支持 Redis 和内存两种存储方式

技术栈

  • 后端框架:FastAPI(高性能异步 Web 框架)
  • 数据存储:Redis(主要)+ 内存存储(备用)
  • 前端技术:原生 JavaScript + HTML5 + CSS3
  • 数据格式:JSON(消息序列化和传输)
  • 会话管理:UUID + 时间戳(会话 ID 生成)

核心架构设计

🏗️ 数据模型设计

在实现历史消息功能之前,我们需要设计合理的数据模型来存储和管理消息数据:

@dataclass
class AIMessage:
    """AI消息数据类"""
    role: str
    content: str
    timestamp: float
    image_data: Optional[str] = None  # Base64编码的图片数据
    image_type: Optional[str] = None  # 图片类型 (jpeg, png, gif)

    这个数据类定义了消息的基本结构,包含角色、内容、时间戳和可选的图片数据字段。

    🔑 会话 ID 管理策略

    会话 ID 是整个历史消息系统的核心,我们采用了前端生成、后端接收的管理策略:

    前端会话 ID 生成逻辑:

    // 前端生成会话ID的核心逻辑
    if (sessionId) {
        // 复用已存在的会话ID
        currentSessionId = sessionId;
    } else {
        // 生成新的会话ID:时间戳 + 随机数
        const timestamp = Date.now();
        const randomNum = Math.floor(Math.random() * 10000);
        sessionId = `session_${timestamp}_${randomNum}`;
        currentSessionId = sessionId;
        localStorage.setItem(sessionKey, sessionId);
    }

    后端键名管理:

    def get_conversation_key(user_id: str, session_id: str) -> str:
        """获取对话在Redis中的键名"""
        return f"conversation:{user_id}:{session_id}"
    
    def get_user_sessions_key(user_id: str) -> str:
        """获取用户会话列表在Redis中的键名"""
        return f"user_sessions:{user_id}"

      前端生成唯一的会话 ID 并传递给后端,后端使用这个 ID 构建 Redis 键名来存储对话数据。

      核心功能实现

      🎯 功能一:每个助手区分 sessionid

      前端实现:智能会话管理

      在前端,我们为每个助手类型维护独立的 sessionid,实现真正的会话隔离:

      /**
       * 选择智能助手类型
       * @param {string} assistantType - 助手类型
       */
      function selectAssistant(assistantType) {
          // 更新当前助手类型
          currentAssistantType = assistantType;
      
          // 移除所有助手项的active类
          document.querySelectorAll('.assistant-item').forEach(item => {
              item.classList.remove('active');
          });
      
          // 为当前选中的助手添加active类
          event.target.closest('.assistant-item').classList.add('active');
      
          // 更新所有现有的assistant消息头像
          updateAssistantAvatars(assistantType);
      
          // 从全局配置中获取角色信息
          const roleConfig = aiRolesConfig[assistantType];
          if (!roleConfig) {
              console.error('未找到角色配置:', assistantType);
              return;
          }
      
          // 更新选中模型信息显示
          updateSelectedModelInfo(assistantType);
      
          // 切换助手时处理sessionId
          const sessionKey = `${assistantType}_sessionId`;
          let sessionId = localStorage.getItem(sessionKey);
      
          if (sessionId) {
              // 如果该助手已有sessionId,使用之前的
              currentSessionId = sessionId;
          } else {
              // 如果没有sessionId,生成新的
              const timestamp = Date.now();
              const randomNum = Math.floor(Math.random() * 10000);
              sessionId = `session_${timestamp}_${randomNum}`;
              currentSessionId = sessionId;
              localStorage.setItem(sessionKey, sessionId);
          }
      
          // 根据当前助手的sessionId重新调用history接口
          loadAssistantHistory(assistantType);
      }

      这个函数负责切换助手时的会话管理,为每个助手类型维护独立的 sessionId,并从 localStorage 中获取或生成新的会话 ID。

      后端实现:接收会话 ID 并管理数据

      后端接收前端传来的会话 ID,通过 Redis 实现会话数据的持久化存储:

      async def save_message_to_redis(user_id: str, session_id: str, message: ChatMessage):
          """将消息保存到Redis或内存"""
          try:
              message_data = {
                  "role": message.role,
                  "content": message.content,
                  "timestamp": message.timestamp,
                  "image_data": getattr(message, 'image_data', None),
                  "image_type": getattr(message, 'image_type', None)
              }
      
              if REDIS_AVAILABLE and redis_client:
                  # Redis存储:高性能,支持数据过期
                  conversation_key = get_conversation_key(user_id, session_id)
                  redis_client.lpush(conversation_key, json.dumps(message_data))
                  redis_client.ltrim(conversation_key, 0, 19)  # 只保留最近20条消息
                  redis_client.expire(conversation_key, 86400 * 7)  # 7天过期
      
                  # 更新会话信息
                  sessions_key = get_user_sessions_key(user_id)
                  session_info = {
                      "session_id": session_id,
                      "last_message": message.content[:50] + "..."if len(message.content) > 50else message.content,
                      "last_timestamp": message.timestamp
                  }
                  redis_client.hset(sessions_key, session_id, json.dumps(session_info))
                  redis_client.expire(sessions_key, 86400 * 30)  # 30天过期
      
                  logger.info(f"消息已保存到Redis - 用户: {user_id}, 会话: {session_id[:8]}..., 角色: {message.role}")
              else:
                  # 内存存储:备用方案
                  if user_id notin MEMORY_STORAGE["conversations"]:
                      MEMORY_STORAGE["conversations"][user_id] = {}
                  if session_id notin MEMORY_STORAGE["conversations"][user_id]:
                      MEMORY_STORAGE["conversations"][user_id][session_id] = []
      
                  MEMORY_STORAGE["conversations"][user_id][session_id].append(message_data)
      
                  # 限制内存中的消息数量
                  if len(MEMORY_STORAGE["conversations"][user_id][session_id]) > 20:
                      MEMORY_STORAGE["conversations"][user_id][session_id] = \
                          MEMORY_STORAGE["conversations"][user_id][session_id][-20:]
      
                  logger.info(f"消息已保存到内存 - 用户: {user_id}, 会话: {session_id[:8]}..., 角色: {message.role}")
      
          except Exception as e:
              logger.error(f"保存消息失败 - 用户: {user_id}, 会话: {session_id[:8]}..., 错误: {e}")
              raise

      这个函数将消息保存到 Redis 或内存中,支持双重存储策略,并设置了消息数量限制和过期时间。

      🔍 功能二:获取历史消息接口

      图片

      后端 API 设计

      我们设计了一个高效的历史消息获取接口:

      @app.get("/chat/history")
      asyncdef get_chat_history(
          user_id: str = Query(..., descriptinotallow="用户ID"),
          session_id: str = Query(..., descriptinotallow="会话ID")
      ):
          """获取聊天历史"""
          logger.info(f"获取聊天历史 - 用户: {user_id}, 会话: {session_id[:8]}...")
      
          try:
              history = await get_conversation_history(user_id, session_id)
              logger.info(f"聊天历史获取成功 - 用户: {user_id}, 会话: {session_id[:8]}..., 消息数: {len(history)}")
              return {
                  "session_id": session_id,
                  "messages": history,
                  "total": len(history)
              }
          except Exception as e:
              logger.error(f"获取聊天历史失败 - 用户: {user_id}, 会话: {session_id[:8]}..., 错误: {e}")
              raise HTTPException(status_code=500, detail="获取聊天历史失败")
      
      asyncdef get_conversation_history(user_id: str, session_id: str) -> List[Dict[str, Any]]:
          """从Redis或内存获取对话历史"""
          try:
              if REDIS_AVAILABLE and redis_client:
                  # 从Redis获取
                  conversation_key = get_conversation_key(user_id, session_id)
                  messages = redis_client.lrange(conversation_key, 0, -1)
      
                  # 反转消息顺序(Redis中是倒序存储的)
                  messages.reverse()
      
                  history = [json.loads(msg) for msg in messages]
                  logger.info(f"从Redis获取对话历史 - 用户: {user_id}, 会话: {session_id[:8]}..., 消息数量: {len(history)}")
                  return history
              else:
                  # 从内存获取
                  if (user_id in MEMORY_STORAGE["conversations"] and
                      session_id in MEMORY_STORAGE["conversations"][user_id]):
                      history = MEMORY_STORAGE["conversations"][user_id][session_id]
                      logger.info(f"从内存获取对话历史 - 用户: {user_id}, 会话: {session_id[:8]}..., 消息数量: {len(history)}")
                      return history
                  else:
                      logger.info(f"未找到对话历史 - 用户: {user_id}, 会话: {session_id[:8]}...")
                      return []
      
          except Exception as e:
              logger.error(f"获取对话历史失败 - 用户: {user_id}, 会话: {session_id[:8]}..., 错误: {e}")
              return []
      前端历史消息加载

      前端通过异步请求加载和渲染历史消息:

      /**
       * 加载指定助手的历史消息
       * @param {string} assistantType - 助手类型
       */
      asyncfunction loadAssistantHistory(assistantType) {
          try {
              // 获取该助手的sessionId
              const sessionId = localStorage.getItem(`${assistantType}_sessionId`);
              if (!sessionId) {
                  // 如果没有sessionId,显示欢迎消息
                  showWelcomeMessage(assistantType);
                  return;
              }
      
              // 更新当前会话ID
              currentSessionId = sessionId;
      
              // 清空当前聊天消息
              const chatMessages = document.getElementById('chatMessages');
              chatMessages.innerHTML = '';
      
              // 显示加载提示
              const loadingMessage = document.createElement('div');
              loadingMessage.className = 'message assistant';
              loadingMessage.innerHTML = `
                  <div class="message-avatar">🤖</div>
                  <div class="message-content-wrapper">
                      正在加载历史消息...
                  </div>
              `;
              chatMessages.appendChild(loadingMessage);
      
              // 从后端获取历史消息
              const response = await fetch(`/chat/history?session_id=${sessionId}&user_id=${userId}`);
              if (response.ok) {
                  const data = await response.json();
      
                  // 清空加载提示
                  chatMessages.innerHTML = '';
      
                  // 渲染历史消息
                  if (data.messages && data.messages.length > 0) {
                      data.messages.forEach(message => {
                          renderHistoryMessage(message);
                      });
                      console.log(`加载了 ${data.messages.length}条历史消息`);
                  } else {
                      // 如果没有历史消息,显示欢迎消息
                      showWelcomeMessage(assistantType);
                  }
      
                  // 滚动到底部
                  scrollToBottom();
              } else {
                  console.error('加载历史消息失败:', response.statusText);
                  showWelcomeMessage(assistantType);
              }
          } catch (error) {
              console.error('加载助手历史失败:', error);
              showWelcomeMessage(assistantType);
          }
      }
      
      /**
       * 渲染历史消息
       * @param {Object} message - 消息对象
       */
      function renderHistoryMessage(message) {
          const chatMessages = document.getElementById('chatMessages');
          const messageDiv = document.createElement('div');
          messageDiv.className = `message ${message.role}`;
      
          // 创建头像
          const avatarDiv = document.createElement('div');
          avatarDiv.className = 'message-avatar';
      
          // 如果是assistant消息,设置助手图标
          if (message.role === 'assistant') {
              const icon = getAssistantIcon(currentAssistantType);
              avatarDiv.setAttribute('data-icon', icon);
          }
      
          const contentDiv = document.createElement('div');
          contentDiv.className = 'message-content-wrapper';
      
          // 处理消息内容
          if (message.role === 'assistant') {
              // 对于AI回复,使用Markdown渲染
              renderMarkdownContent(message.content, contentDiv);
          } else {
              // 对于用户消息,检查是否包含图片
              if (message.image_data) {
                  // 创建图片元素
                  const imageDiv = document.createElement('div');
                  imageDiv.className = 'message-image';
                  const img = document.createElement('img');
                  img.src = `data:${message.image_type};base64,${message.image_data}`;
                  img.alt = '用户上传的图片';
                  img.style.maxWidth = '300px';
                  img.style.borderRadius = '8px';
                  imageDiv.appendChild(img);
                  contentDiv.appendChild(imageDiv);
              }
      
              // 添加文本内容
              if (message.content && message.content.trim()) {
                  const textDiv = document.createElement('div');
                  textDiv.textContent = message.content;
                  contentDiv.appendChild(textDiv);
              }
          }
      
          messageDiv.appendChild(avatarDiv);
          messageDiv.appendChild(contentDiv);
          chatMessages.appendChild(messageDiv);
      }

      这个函数从后端获取指定助手的历史消息,并在前端进行渲染显示,支持文本和图片消息的完整展示。

      💬 功能三:发送消息时携带上下文信息

      后端流式对话实现

      发送消息时,我们需要获取历史上下文并传递给 AI 模型:

      1. 流式聊天接口
      @app.post("/chat/stream")
      asyncdef chat_stream(request: ChatRequest):
          """流式聊天接口"""
          # 设置默认值
          role = "assistant"
          provider = request.provider
          model = getattr(request, 'model', None)
      
          logger.info(f"流式聊天请求 - 用户: {request.user_id}, 会话: {request.session_id[:8]}..., 角色: {role}, 消息长度: {len(request.message)}, 提供商: {provider}")
      
          if role notin AI_ROLES:
              logger.warning(f"不支持的AI角色: {role}")
              raise HTTPException(status_code=400, detail="不支持的AI角色")
      
          return StreamingResponse(
              generate_streaming_response(request.user_id, request.session_id, request.message, role, provider, model, request.image_data, request.image_type),
              media_type="text/event-stream",
              headers={
                  "Cache-Control": "no-cache",
                  "Connection": "keep-alive",
                  "Access-Control-Allow-Origin": "*"
              }
          )

      这个接口是流式聊天的入口点:

      • 接收前端发送的 ChatRequest 对象,包含用户 ID、会话 ID、消息内容等
      • 设置默认的 AI 角色为 "assistant",从请求中获取 AI 提供商和模型信息
      • 验证 AI 角色是否在支持的角色列表中
      • 返回 StreamingResponse 对象,设置 SSE(Server-Sent Events)相关的响应头
      • 调用 generate_streaming_response 函数处理具体的流式响应逻辑
      2. 流式响应生成函数
      async def generate_streaming_response(user_id: str, session_id: str, user_message: str, role: str = "assistant", provider: Optional[str] = None, model: Optional[str] = None, image_data: Optional[str] = None, image_type: Optional[str] = None):
          """生成流式响应"""
          logger.info(f"开始流式响应 - 用户: {user_id}, 会话: {session_id[:8]}..., 角色: {role}, 消息长度: {len(user_message)}, 提供商: {provider}")
      
          try:
              # 1. 保存用户消息到Redis
              from ai_providers.base import AIMessage
              user_msg = AIMessage(
                  role="user",
                  cnotallow=user_message,
                  timestamp=time.time(),
                  image_data=image_data,
                  image_type=image_type
              )
              await save_message_to_redis(user_id, session_id, user_msg)
      
              # 2. 获取对话历史记录
              history = await get_conversation_history(user_id, session_id)
      
              # 3. 构建系统提示词
              system_prompt = AI_ROLES.get(role, AI_ROLES["assistant"])["prompt"]
      
              # 4. 构建AI消息对象列表
              ai_messages = []
      
              # 5. 添加历史消息(限制数量避免上下文过长)
              recent_messages = history[-config.MAX_HISTORY_MESSAGES:] if len(history) > config.MAX_HISTORY_MESSAGES else history
              for msg in recent_messages:
                  if msg["role"] in ["user", "assistant"]:
                      ai_messages.append(AIMessage(
                          role=msg["role"],
                          cnotallow=msg["content"],
                          timestamp=msg.get("timestamp", time.time()),
                          image_data=msg.get("image_data"),
                          image_type=msg.get("image_type")
                      ))
      
              # 6. 调用AI提供商的流式API
              logger.info(f"调用AI流式API - 消息数: {len(ai_messages)}, 提供商: {provider or '默认'}, 模型: {model or '默认'}")
      
              full_response = ""
              content_only_response = ""# 只保存 type: 'content' 的内容
              chunk_count = 0
      
              # 7. 处理流式响应
              asyncfor chunk in ai_manager.generate_streaming_response(
                  messages=ai_messages,
                  provider=provider,
                  model=model,
                  system_prompt=system_prompt
              ):
                  if chunk:
                      full_response += chunk
                      chunk_count += 1
      
                      # 8. 解析chunk数据,过滤出纯文本内容
                      try:
                          if chunk.startswith("data: "):
                              json_str = chunk[6:].strip()  # 移除 "data: " 前缀
                              if json_str:
                                  chunk_data = json.loads(json_str)
                                  # 只累积 type 为 'content' 的内容用于保存到Redis
                                  if chunk_data.get('type') == 'content'and'content'in chunk_data:
                                      content_only_response += chunk_data['content']
                      except (json.JSONDecodeError, KeyError) as e:
                          # 如果解析失败,按原来的方式处理(向后兼容)
                          logger.debug(f"解析chunk数据失败,使用原始内容: {e}")
                          content_only_response += chunk
      
                      # 9. 实时推送数据到前端
                      yield chunk
      
              logger.info(f"流式响应完成 - 用户: {user_id}, 会话: {session_id[:8]}..., 块数: {chunk_count}, 总长度: {len(full_response)}, 内容长度: {len(content_only_response)}")
      
              # 10. 保存AI响应到Redis(只保存纯文本内容)
              ai_msg = ChatMessage(
                  role="assistant",
                  cnotallow=content_only_response,  # 使用过滤后的内容
                  timestamp=time.time()
              )
              await save_message_to_redis(user_id, session_id, ai_msg)
      
              # 11. 发送结束信号
              yieldf"data: {json.dumps({'type': 'end', 'session_id': session_id})}\n\n"
      
          except Exception as e:
              logger.error(f"流式响应错误 - 用户: {user_id}, 会话: {session_id[:8]}..., 错误: {e}")
              error_msg = f"抱歉,服务出现错误:{str(e)}"
              yieldf"data: {json.dumps({'content': error_msg, 'type': 'error'})}\n\n"

      这个函数是流式响应的核心实现,主要包含以下步骤:

      1. 保存用户消息:将用户发送的消息(包括文本和图片)保存到 Redis 中
      2. 获取历史记录:根据用户 ID 和会话 ID 从 Redis 中获取完整的对话历史
      3. 构建系统提示:根据 AI 角色获取对应的系统提示词
      4. 构建消息列表:将历史消息转换为 AI 模型需要的格式
      5. 限制历史长度:只取最近的 N 条消息,避免上下文过长影响性能
      6. 调用 AI API:使用 AI 管理器调用指定提供商的流式 API
      7. 处理流式数据:逐块接收 AI 响应,实时推送给前端
      8. 数据过滤:从流式数据中提取纯文本内容,用于保存到数据库
      9. 实时推送:使用 yield 将数据块实时发送给前端
      10. 保存 AI 响应:将完整的 AI 回复保存到 Redis 中
      11. 发送结束信号:通知前端流式响应已完成

      通过这种设计,实现了带有完整上下文的流式对话功能,用户可以看到 AI 的实时回复,同时所有对话记录都会被持久化保存。

      总结

      本教程通过前端会话 ID 管理、后端历史消息接口和流式对话上下文传递三个核心技术,实现了支持多助手切换和历史记录持久化的 AI 聊天应用。

      AI大模型学习福利

      作为一名热心肠的互联网老兵,我决定把宝贵的AI知识分享给大家。 至于能学习到多少就看你的学习毅力和能力了 。我已将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。

      一、全套AGI大模型学习路线

      AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

      因篇幅有限,仅展示部分资料,需要点击文章最下方名片即可前往获取

      二、640套AI大模型报告合集

      这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

      因篇幅有限,仅展示部分资料,需要点击文章最下方名片即可前往获

      三、AI大模型经典PDF籍

      随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。


      因篇幅有限,仅展示部分资料,需要点击文章最下方名片即可前往获

      四、AI大模型商业化落地方案

      因篇幅有限,仅展示部分资料,需要点击文章最下方名片即可前往获

      作为普通人,入局大模型时代需要持续学习和实践,不断提高自己的技能和认知水平,同时也需要有责任感和伦理意识,为人工智能的健康发展贡献力量

      INFO: 127.0.0.1:58952 - "POST /query HTTP/1.1" 422 Unprocessable Entity 代码: from fastapi import FastAPI, HTTPException, BackgroundTasks from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from pydantic import BaseModel, Field from typing import List, Dict, Any, Optional import asyncio import logging import os import time # 添加time模块导入 import uuid # 添加uuid模块导入 from dotenv import load_dotenv import uvicorn from contextlib import asynccontextmanager # 在导入其他模块之前先加载环境变量 load_dotenv() # 导入RAG系统 from rag_main import RAGSystem, RAGConfig # 导入TTS服务 from text_to_speech import tts_service # 添加TTS服务导入 # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # 全局RAG系统实例 rag_system: Optional[RAGSystem] = None # Pydantic模型定义 class ImageData(BaseModel): """图片数据模型""" name: str = Field(..., description="图片文件名") data: str = Field(..., description="图片base64数据") class QueryRequest(BaseModel): """查询请求模型""" question: str = Field(..., description="用户问题", min_length=1, max_length=1000) include_history: bool = Field(True, description="是否包含对话历史") images: List[ImageData] = Field(default=[], description="上传的图片列表") user_id: Optional[str] = Field(None, description="用户ID") conversation_history: List[Dict[str, Any]] = Field(default=[], description="对话历史") class BatchQueryRequest(BaseModel): """批量查询请求模型""" questions: List[str] = Field(..., description="问题列表", min_items=1, max_items=10) class QueryResponse(BaseModel): """查询响应模型""" success: bool question: Optional[str] = None answer: Optional[str] = None sources: List[Dict[str, Any]] = [] context_used: Optional[int] = None retrieval_results: Optional[int] = None timing: Optional[Dict[str, float]] = None model_info: Optional[Dict[str, Any]] = None error: Optional[str] = None images: Optional[list] = [] # 新增此行,支持图片多模态返回 class SystemStatus(BaseModel): """系统状态模型""" initialized: bool chat_history_length: int files_exist: Dict[str, bool] config: Dict[str, Any] class ConfigUpdateRequest(BaseModel): """配置更新请求模型""" chunk_size: Optional[int] = Field(None, ge=100, le=2048) chunk_overlap: Optional[int] = Field(None, ge=0, le=500) retrieval_top_k: Optional[int] = Field(None, ge=1, le=50) final_top_k: Optional[int] = Field(None, ge=1, le=20) temperature: Optional[float] = Field(None, ge=0.0, le=2.0) max_tokens: Optional[int] = Field(None, ge=100, le=4096) enable_reranking: Optional[bool] = None enable_safety_check: Optional[bool] = None # 应用生命周期管理 @asynccontextmanager async def lifespan(app: FastAPI): """应用启动和关闭时的生命周期管理""" # 启动时初始化RAG系统 global rag_system try: logger.info("🚀 正在初始化RAG系统...") # 环境变量已经在模块级别加载了,这里不需要再次加载 # load_dotenv() # 删除这行 # 获取API密钥 siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY") volcengine_api_key = os.getenv("VOLCENGINE_API_KEY") if not siliconflow_api_key: raise ValueError("未找到SILICONFLOW_API_KEY环境变量") if not volcengine_api_key: raise ValueError("未找到VOLCENGINE_API_KEY环境变量") # 创建配置 config = RAGConfig( api_key=siliconflow_api_key, volcengine_api_key=volcengine_api_key, document_path="招股说明书1.pdf", output_dir="./", chunk_size=384, chunk_overlap=48, retrieval_top_k=20, final_top_k=8, enable_safety_check=False ) # 初始化RAG系统 rag_system = RAGSystem(config) if not await rag_system.initialize(): raise RuntimeError("RAG系统初始化失败") # 处理文档 if not await rag_system.process_document(): raise RuntimeError("文档处理失败") logger.info("✅ RAG系统初始化成功") except Exception as e: logger.error(f"❌ RAG系统初始化失败: {e}") rag_system = None raise yield # 关闭时清理资源 logger.info("🔄 正在关闭RAG系统...") rag_system = None # 创建FastAPI应用 app = FastAPI( title="RAG智能问答API", description="基于检索增强生成(RAG)的智能问答系统API", version="1.0.0", lifespan=lifespan ) # 添加CORS中间件 app.add_middleware( CORSMiddleware, allow_origins=["*"], # 生产环境中应该限制具体域名 allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # 辅助函数 def check_system_ready(): """检查系统是否就绪""" if rag_system is None or not rag_system.is_initialized: raise HTTPException( status_code=503, detail="RAG系统未初始化或不可用,请稍后重试" ) # API路由定义 @app.get("/", summary="根路径", description="API根路径,返回欢迎信息") async def root(): """根路径""" return { "message": "欢迎使用RAG智能问答API", "version": "1.0.0", "docs": "/docs", "status": "/status" } @app.get("/health", summary="健康检查", description="检查API服务健康状态") async def health_check(): """健康检查""" return { "status": "healthy", "rag_system_ready": rag_system is not None and rag_system.is_initialized } @app.get("/status", response_model=SystemStatus, summary="系统状态", description="获取RAG系统详细状态") async def get_system_status(): """获取系统状态""" check_system_ready() status = rag_system.get_system_status() return SystemStatus(**status) # 在现有的Pydantic模型定义后添加 class ImageDescriptionRequest(BaseModel): """图片描述请求模型""" images: List[ImageData] = Field(..., description="要描述的图片列表", min_items=1) detailed: bool = Field(True, description="是否生成详细描述") class ImageDescriptionResponse(BaseModel): """图片描述响应模型""" success: bool descriptions: List[Dict[str, str]] = [] total_images: int processing_time: Optional[float] = None error: Optional[str] = None # 在现有的API端点后添加新的图片描述端点 @app.post("/describe_image", response_model=ImageDescriptionResponse, summary="图片描述", description="上传图片并获取AI生成的详细描述") async def describe_images(request: ImageDescriptionRequest): """图片描述接口 - 专门用于处理用户上传的图片""" import time start_time = time.time() try: from image_captioning import describe_image import base64 descriptions = [] for img in request.images: try: # 提取base64数据 if img.data.startswith('data:image/'): base64_data = img.data.split(',')[1] else: base64_data = img.data # 获取图片格式 img_format = img.name.split('.')[-1].lower() if '.' in img.name else 'png' # 生成图片描述 description = describe_image(base64_data, img_format) descriptions.append({ "image_name": img.name, "description": description, "status": "success" }) logger.info(f"✓ 成功描述图片: {img.name}") except Exception as e: logger.error(f"处理图片 {img.name} 失败: {e}") descriptions.append({ "image_name": img.name, "description": f"无法生成描述: {str(e)}", "status": "error" }) processing_time = time.time() - start_time return ImageDescriptionResponse( success=True, descriptions=descriptions, total_images=len(request.images), processing_time=processing_time ) except Exception as e: logger.error(f"图片描述处理异常: {e}") return ImageDescriptionResponse( success=False, descriptions=[], total_images=len(request.images), error=str(e) ) # 修改现有的query_rag函数,增强图片处理逻辑 @app.post("/query", response_model=QueryResponse, summary="智能问答", description="提交问题并获取AI回答") async def query_rag(request: QueryRequest): """智能问答接口""" check_system_ready() try: # 处理图片描述 image_descriptions = [] if request.images: from image_captioning import describe_image import base64 for img in request.images: try: # 添加日志以检查 img.data 和 img.name 的值 logger.info(f"处理图片: 名称={img.name}, 数据前100字符={img.data[:100]}") # 提取base64数据 if img.data.startswith('data:image/'): base64_data = img.data.split(',')[1] else: base64_data = img.data # 获取图片格式 img_format = img.name.split('.')[-1].lower() if '.' in img.name else 'png' # 生成图片描述 description = describe_image(base64_data, img_format) image_descriptions.append(f"图片 {img.name}: {description}") logger.info(f"✓ 成功处理图片: {img.name}") except Exception as e: logger.error(f"处理图片 {img.name} 失败: {e}") image_descriptions.append(f"图片 {img.name}: 无法生成描述") # 判断是否为纯图片查询(没有文字问题或问题很简单) is_image_only_query = ( request.images and (not request.question.strip() or request.question.strip().lower() in ['这是什么', '描述一下', '看看这个', '这个图片', '图片内容']) ) if is_image_only_query and image_descriptions: # 纯图片查询,直接返回图片描述 answer = "根据您上传的图片,我看到:\n\n" + "\n\n".join(image_descriptions) # 在 query_rag 方法中,构建 images 列表 images = [] for idx, img in enumerate(request.images): images.append({ "url": f"/images/{img.name}", "description": image_descriptions[idx] if idx < len(image_descriptions) else "" }) # 在 QueryResponse 返回时加入 images 字段 return QueryResponse( success=True, question=request.question, answer=answer, sources=[], images=images, context_used=0, retrieval_results=0, timing={"total_time": 0.1}, model_info={"type": "image_description_only"} ) # 构建完整问题(包含图片描述) full_question = request.question if image_descriptions: full_question += "\n\n上传的图片信息:\n" + "\n".join(image_descriptions) # 调用RAG系统 result = await rag_system.query( user_question=full_question, include_history=request.include_history ) # 保存图片描述到对话历史 if image_descriptions: result['image_descriptions'] = image_descriptions return QueryResponse(**result) except Exception as e: logger.error(f"查询处理异常: {e}") raise HTTPException( status_code=500, detail=f"查询处理失败: {str(e)}" ) @app.post("/batch_query", response_model=List[QueryResponse], summary="批量问答", description="批量提交问题并获取AI回答") async def batch_query_rag(request: BatchQueryRequest): """批量查询接口""" check_system_ready() try: results = await rag_system.batch_query(request.questions) return [QueryResponse(**result) for result in results] except Exception as e: logger.error(f"批量查询处理异常: {e}") raise HTTPException( status_code=500, detail=f"批量查询处理失败: {str(e)}" ) @app.post("/clear_history", summary="清空历史", description="清空对话历史记录") async def clear_chat_history(): """清空对话历史""" check_system_ready() try: rag_system.clear_history() return {"success": True, "message": "对话历史已清空"} except Exception as e: logger.error(f"清空历史异常: {e}") raise HTTPException( status_code=500, detail=f"清空历史失败: {str(e)}" ) @app.post("/reprocess_document", summary="重新处理文档", description="强制重新处理文档") async def reprocess_document(background_tasks: BackgroundTasks): """重新处理文档""" check_system_ready() async def process_task(): try: await rag_system.process_document(force_reprocess=True) logger.info("文档重新处理完成") except Exception as e: logger.error(f"文档重新处理失败: {e}") background_tasks.add_task(process_task) return {"success": True, "message": "文档重新处理任务已启动,请稍后查看状态"} @app.get("/config", summary="获取配置", description="获取当前系统配置") async def get_config(): """获取当前配置""" check_system_ready() config = rag_system.config return { "embedding_model": config.embedding_model, "chunk_size": config.chunk_size, "chunk_overlap": config.chunk_overlap, "retrieval_top_k": config.retrieval_top_k, "final_top_k": config.final_top_k, "generation_model": config.generation_model, "temperature": config.temperature, "max_tokens": config.max_tokens, "enable_reranking": config.enable_reranking, "enable_safety_check": config.enable_safety_check } @app.put("/config", summary="更新配置", description="更新系统配置参数") async def update_config(request: ConfigUpdateRequest): """更新配置""" check_system_ready() try: config = rag_system.config updated_fields = [] # 更新配置字段 if request.chunk_size is not None: config.chunk_size = request.chunk_size updated_fields.append("chunk_size") if request.chunk_overlap is not None: config.chunk_overlap = request.chunk_overlap updated_fields.append("chunk_overlap") if request.retrieval_top_k is not None: config.retrieval_top_k = request.retrieval_top_k updated_fields.append("retrieval_top_k") if request.final_top_k is not None: config.final_top_k = request.final_top_k updated_fields.append("final_top_k") if request.temperature is not None: config.temperature = request.temperature updated_fields.append("temperature") if request.max_tokens is not None: config.max_tokens = request.max_tokens updated_fields.append("max_tokens") if request.enable_reranking is not None: config.enable_reranking = request.enable_reranking updated_fields.append("enable_reranking") if request.enable_safety_check is not None: config.enable_safety_check = request.enable_safety_check updated_fields.append("enable_safety_check") return { "success": True, "message": f"配置已更新: {', '.join(updated_fields)}", "updated_fields": updated_fields } except Exception as e: logger.error(f"配置更新异常: {e}") raise HTTPException( status_code=500, detail=f"配置更新失败: {str(e)}" ) # 错误处理 @app.exception_handler(Exception) async def global_exception_handler(request, exc): """全局异常处理""" logger.error(f"未处理的异常: {exc}") if isinstance(exc, HTTPException): return JSONResponse(status_code=exc.status_code, content={ "success": False, "error": exc.detail }) return { "success": False, "error": "服务器内部错误", "detail": str(exc) } # 主函数 import os if __name__ == "__main__": # 开发环境启用重载,生产环境禁用 reload_mode = os.getenv("UVICORN_RELOAD", "true").lower() == "true" uvicorn.run( "rag_api:app", host="0.0.0.0", port=8000, reload=reload_mode, log_level="info" ) from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles # 添加静态文件服务 app.mount("/images", StaticFiles(directory="images"), name="images") @app.get("/image/{image_name}") async def get_image(image_name: str): """获取图片文件""" image_path = os.path.join("images", image_name) if os.path.exists(image_path): return FileResponse(image_path) else: raise HTTPException(status_code=404, detail="图片未找到") from speech_to_text import process_speech_to_text from fastapi import UploadFile, File # 添加语音转文字接口 @app.post("/speech-to-text") async def speech_to_text(audio_file: UploadFile = File(...)): """语音转文字接口""" try: # 验证文件类型 if not audio_file.content_type.startswith('audio/'): raise HTTPException(status_code=400, detail="请上传音频文件") # 处理语音转文字 result = await process_speech_to_text(audio_file) return JSONResponse(content=result) except Exception as e: logger.error(f"语音转文字API错误: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/voice-query") async def voice_query(audio_file: UploadFile = File(...)): """语音查询接口:语音转文字 + RAG查询""" try: # 1. 语音转文字 speech_result = await process_speech_to_text(audio_file) if not speech_result["success"]: return JSONResponse(content=speech_result) transcribed_text = speech_result["transcribed_text"] # 添加调试日志 logger.info(f"原始转录结果: {speech_result['transcribed_text']}") logger.info(f"转录文本类型: {type(transcribed_text)}") # 如果转录结果是列表格式,需要提取文本 if isinstance(transcribed_text, list) and len(transcribed_text) > 0: if isinstance(transcribed_text[0], dict) and 'text' in transcribed_text[0]: transcribed_text = transcribed_text[0]['text'] elif isinstance(transcribed_text[0], str): transcribed_text = transcribed_text[0] logger.info(f"处理后的查询文本: {transcribed_text}") # 2. 使用转录文本进行RAG查询 if hasattr(rag_system, 'query_with_images'): query_result = await rag_system.query_with_images(transcribed_text) else: query_result = await rag_system.query(transcribed_text) # 3. 返回完整结果 return JSONResponse(content={ "success": True, "transcribed_text": transcribed_text, "audio_file_path": speech_result["audio_file_path"], "answer": query_result["answer"], "sources": query_result["sources"], "images": query_result.get("images", []), "processing_time": query_result.get("timing", {}).get("total_time", 0) }) except Exception as e: logger.error(f"语音查询API错误: {e}") @app.post("/query-with-tts") async def query_with_tts(request: QueryRequest): """ 智能问答并返回TTS音频 """ check_system_ready() try: start_time = time.time() # 执行RAG查询 result = await rag_system.query( user_question=request.question, include_history=request.include_history ) if not result.get('success', False): raise HTTPException(status_code=500, detail=result.get('error', '查询失败')) answer_text = result.get('answer', '') # 尝试生成TTS音频(增强错误处理) try: # 检查方法是否存在 if hasattr(tts_service, 'text_to_speech'): audio_data = tts_service.text_to_speech(answer_text) elif hasattr(tts_service, 'text_to_speech_stream'): # 使用流式方法作为备选 audio_stream = tts_service.text_to_speech_stream(answer_text) if audio_stream: # 处理流式数据 audio_data = b''.join(audio_stream) else: audio_data = None else: logger.error("No available TTS method found") audio_data = None except Exception as e: logger.error(f"TTS service error: {str(e)}") audio_data = None if audio_data: # 创建临时文件保存音频 audio_filename = f"audio_{uuid.uuid4().hex}.wav" audio_path = os.path.join("audio_files", audio_filename) # 确保audio_files目录存在 os.makedirs("audio_files", exist_ok=True) # 保存音频文件 if tts_service.save_audio_to_file(audio_data, audio_path): result['audio_file'] = audio_filename result['audio_url'] = f"/audio/{audio_filename}" logger.info(f"✓ TTS音频生成成功: {audio_filename}") else: logger.warning("音频文件保存失败") result['tts_error'] = "音频文件保存失败" else: logger.warning("TTS转换失败") result['tts_error'] = "TTS转换失败" processing_time = time.time() - start_time result['processing_time'] = processing_time return QueryResponse(**result) except HTTPException: raise except Exception as e: logger.error(f"Query with TTS error: {str(e)}") raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") # 添加音频文件服务端点 @app.get("/audio/{audio_filename}") async def get_audio(audio_filename: str): """ 获取生成的音频文件 """ audio_path = os.path.join("audio_files", audio_filename) if not os.path.exists(audio_path): raise HTTPException(status_code=404, detail="音频文件不存在") return FileResponse( audio_path, media_type="audio/wav", filename=audio_filename ) # 添加TTS独立端点 @app.post("/text-to-speech") async def text_to_speech_endpoint(text: str, voice: str = "Cherry"): """ 独立的文字转语音端点 """ try: audio_data = tts_service.text_to_speech_stream(text, voice) if not audio_data: raise HTTPException(status_code=500, detail="TTS转换失败") # 创建临时文件 audio_filename = f"tts_{uuid.uuid4().hex}.wav" audio_path = os.path.join("audio_files", audio_filename) # 确保目录存在 os.makedirs("audio_files", exist_ok=True) # 保存音频文件 if tts_service.save_audio_to_file(audio_data, audio_path): return { "success": True, "audio_file": audio_filename, "audio_url": f"/audio/{audio_filename}" } else: raise HTTPException(status_code=500, detail="音频文件保存失败") except Exception as e: logger.error(f"TTS endpoint error: {str(e)}") raise HTTPException(status_code=500, detail=f"TTS转换失败: {str(e)}") def build_multimodal_answer(answer_text, images): multimodal_part = "" for idx, img in enumerate(images, 1): multimodal_part += f"如图{idx}所示:{img['description']}\n【图片{idx}:{img['url']}, {img['description']}】\n" return answer_text + "\n" + multimodal_part if images else answer_text 检查代码错误
      08-22
      你的 FastAPI 项目整体结构清晰,功能完整,涵盖了 RAG 查询、图像描述、语音识别、TTS、配置管理、文件服务等多个模块。不过你遇到的错误: ``` INFO: 127.0.0.1:58952 - "POST /query HTTP/1.1" 422 Unprocessable Entity ``` 是一个 **请求体验证失败** 的错误,通常是由于请求体与定义的 Pydantic 模型不匹配导致的。我们来逐步分析这个问题。 --- ### 🔍 错误解析 `422 Unprocessable Entity` 是 FastAPI 返回的一种验证错误,表示客户端提交的数据格式不正确,**无法通过 Pydantic 模型的验证规则**。 --- ### ✅ 问题定位 你定义了 `QueryRequest` 模型如下: ```python class QueryRequest(BaseModel): question: str = Field(..., description="用户问题", min_length=1, max_length=1000) include_history: bool = Field(True, description="是否包含对话历史") images: List[ImageData] = Field(default=[], description="上传的图片列表") user_id: Optional[str] = Field(None, description="用户ID") conversation_history: List[Dict[str, Any]] = Field(default=[], description="对话历史") ``` FastAPI 会严格按照这个模型来验证请求体。如果请求体中字段名不匹配、类型错误、缺少必填字段(如 `question`)、字段值不满足约束(如长度超过 1000),就会导致 422 错误。 --- ### 🧪 常见错误原因 1. **字段名不一致** - 请求体中的字段名与模型定义不一致,例如 `question` 写成 `Question` 或 `query`。 2. **字段类型错误** - `include_history` 应该是布尔值,如果传入 `"true"`(字符串)或 `1`(整数),就会验证失败。 3. **缺少必填字段** - `question` 是必填项,如果请求体中没有提供,就会报错。 4. **字段值不符合约束** - `question` 长度超过 `max_length=1000`,也会导致验证失败。 5. **嵌套字段格式错误** - `images` 字段是 `List[ImageData]` 类型,每个元素必须是 `ImageData` 模型实例,如果格式不对(如缺少 `name` 或 `data` 字段)也会失败。 --- ### 🛠 示例错误请求体 ```json { "Question": "这是一道问题", // 错误:字段名应该是 question "include_history": "true", // 错误:应该是布尔值 "images": [ { "name": "image1.png", "data": "base64_data" // 如果 base64 数据不正确,也可能出错 } ], "conversation_history": [ { "role": "user", "content": "你好" // content 应该是 Dict 类型?还是 str?需要看实际定义 } ] } ``` --- ### ✅ 解决方案 #### ✅ 1. 确保请求体字段名一致 ```json { "question": "这是一个问题", "include_history": true, "images": [ { "name": "image1.png", "data": "base64_encoded_string" } ], "user_id": "user_123", "conversation_history": [ { "role": "user", "content": "你好" }, { "role": "assistant", "content": "你好!有什么我可以帮助你的吗?" } ] } ``` #### ✅ 2. 使用 `curl` 或 Postman 验证请求 ```bash curl -X POST "http://localhost:8000/query" \ -H "Content-Type: application/json" \ -d '{ "question": "这是一个问题", "include_history": true, "images": [ { "name": "image1.png", "data": "base64_encoded_string" } ], "user_id": "user_123", "conversation_history": [] }' ``` #### ✅ 3. 使用 Swagger UI 测试 访问 `http://localhost:8000/docs`,找到 `/query` 接口,使用提供的表单测试接口,确保输入格式正确。 --- ### ✅ 优化建议 - **添加详细的日志记录**:在 `/query` 接口中添加请求体日志输出,方便调试。 - **使用 `ValidationError` 捕获**:可以在全局异常处理器中捕获 `ValidationError`,返回更清晰的错误信息。 - **使用 `@app.exception_handler(RequestValidationError)`**:自定义请求验证错误响应。 ```python from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse @app.exception_handler(RequestValidationError) async def validation_exception_handler(request, exc): logger.error(f"请求验证失败: {exc}") return JSONResponse( status_code=422, content={ "success": False, "error": "请求格式错误", "detail": exc.errors() } ) ``` --- ###
      评论
      成就一亿技术人!
      拼手气红包6.0元
      还能输入1000个字符
       
      红包 添加红包
      表情包 插入表情
       条评论被折叠 查看
      添加红包

      请填写红包祝福语或标题

      红包个数最小为10个

      红包金额最低5元

      当前余额3.43前往充值 >
      需支付:10.00
      成就一亿技术人!
      领取后你会自动成为博主和红包主的粉丝 规则
      hope_wisdom
      发出的红包
      实付
      使用余额支付
      点击重新获取
      扫码支付
      钱包余额 0

      抵扣说明:

      1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
      2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

      余额充值