# stream_multi_qa_async.py
# RAG 系统主程序 | GPU 安全模式 | 支持并发控制与流式输出
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
import os
from typing import List, Tuple, AsyncGenerator
from langchain_core.documents import Document
from FlagEmbedding import FlagReranker
from langchain_community.embeddings import HuggingFaceEmbeddings
import asyncio
from asyncio import Semaphore
import psutil
import torch
from openai import AsyncOpenAI
# =================== Model path configuration ===================
# Local ModelScope checkout of the BGE embedding model (bge-large-zh-v1.5).
BGE_EMBEDDING_MODEL_PATH = "/data/models/modelscope/models/BAAI/bge-large-zh-v1___5"
# Local ModelScope checkout of the BGE cross-encoder reranker.
BGE_RERANKER_MODEL_PATH = "/data/models/modelscope/models/BAAI/bge-reranker-large"
# Model name exposed by the local vLLM OpenAI-compatible server.
QA_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
# =================== 资源管理器(全局单例)===================
class RAGResourceManager:
    """Process-wide singleton holding shared RAG resources.

    Owns the concurrency semaphores and the lazily-created model objects:
    a dummy embedding shim (used only to load FAISS indexes) and the GPU
    reranker.  Instantiate once at module import; share everywhere.
    """

    class _ZeroEmbeddings:
        """Placeholder embedding object satisfying the interface FAISS needs.

        Bug fix: the previous code built a real ``HuggingFaceEmbeddings`` and
        then assigned ``embed_documents`` / ``embed_query`` onto the instance.
        ``HuggingFaceEmbeddings`` is a pydantic model, so that assignment
        raises::

            "HuggingFaceEmbeddings" object has no field "embed_documents"

        This plain class provides the two methods directly — no model weights
        are loaded at all (faster startup, zero encoder memory).

        NOTE(review): FAISS also calls ``embed_query`` during similarity
        search, so queries are encoded as zero vectors and the vector-distance
        scores are degenerate — confirm this is the intended trade-off.
        """

        def __init__(self, dim: int = 1024):
            # 1024 matches bge-large-zh-v1.5's embedding dimension.
            self._dim = dim

        def embed_documents(self, texts):
            """Return zero placeholder vectors, one per input text."""
            return [[0.0] * self._dim for _ in texts]

        def embed_query(self, text):
            """Return a single zero placeholder vector."""
            return [0.0] * self._dim

    def __init__(self):
        self.cpu_cores = psutil.cpu_count(logical=True)
        print(f"🖥️ CPU 核心数: {self.cpu_cores}")
        # Cap total concurrent requests (recommended <= 2).
        self.total_concurrent_semaphore = Semaphore(2)
        # All GPU inference shares ONE semaphore: a single GPU task at a time
        # prevents concurrent requests from OOM-ing the card.
        self.gpu_task_semaphore = Semaphore(1)
        # Lazily-created singletons (avoid repeated loading).
        self._embeddings_for_load = None  # dummy embeddings for FAISS loading
        self._reranker = None             # GPU reranker

    @property
    def embeddings_for_loading(self):
        """Return the dummy embedding shim used only to load FAISS indexes."""
        if self._embeddings_for_load is None:
            print("🔄 初始化 FAISS 加载专用嵌入模型(虚拟编码,不耗 GPU)...")
            self._embeddings_for_load = self._ZeroEmbeddings(dim=1024)
            print("✅ FAISS 加载模型准备就绪(使用虚拟编码)")
        return self._embeddings_for_load

    async def get_reranker(self):
        """Lazily load the BGE reranker onto the GPU (single global instance).

        Loading runs in the default executor so the event loop stays
        responsive.

        Raises:
            RuntimeError: if the model fails to load.
        """
        if self._reranker is None:
            print("🔄 正在加载 BGE 重排序模型到 GPU...")
            # get_running_loop() is the non-deprecated call inside a coroutine.
            loop = asyncio.get_running_loop()

            def load_reranker():
                return FlagReranker(
                    BGE_RERANKER_MODEL_PATH,
                    use_fp16=True,   # half precision to save VRAM
                    device='cuda'
                )

            try:
                self._reranker = await loop.run_in_executor(None, load_reranker)
                print("✅ 重排序模型加载完成 (GPU)")
            except Exception as e:
                print(f"❌ 重排序模型加载失败: {e}")
                # Chain the cause so the original traceback is preserved.
                raise RuntimeError(f"Failed to load reranker: {e}") from e
        return self._reranker
# Global resource manager (process-wide singleton, created at import time).
rag_resource_manager = RAGResourceManager()
def print_gpu_memory(step=""):
    """Log current GPU memory usage, tagged with *step*.

    Silently does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    allocated_gb = torch.cuda.memory_allocated(0) / 1024 ** 3
    reserved_gb = torch.cuda.memory_reserved(0) / 1024 ** 3
    print(f"📊 [{step}] GPU 显存: 已分配={allocated_gb:.2f}GB, 已保留={reserved_gb:.2f}GB")
async def _search_single_db(query: str, db_dir: str, db_path: str, k1: int) -> List[Tuple[Document, float]]:
    """Asynchronously load one FAISS index and run a top-``k1`` similarity search.

    Args:
        query: search text.
        db_dir: parent directory holding the FAISS sub-directories.
        db_path: sub-directory name of this index.
        k1: number of results to retrieve from this index.

    Returns:
        ``(Document, score)`` tuples, or ``[]`` when the index files are
        missing or any step fails (errors are logged, never raised).
    """
    full_path = os.path.join(db_dir, db_path)
    faiss_file = os.path.join(full_path, "index.faiss")
    pkl_file = os.path.join(full_path, "index.pkl")
    if not (os.path.exists(faiss_file) and os.path.exists(pkl_file)):
        print(f"⚠️ 缺少文件: {full_path},跳过...")
        return []
    try:
        # Bug fix: get_event_loop() is deprecated inside a coroutine since
        # Python 3.10 — get_running_loop() is the correct call here.
        loop = asyncio.get_running_loop()
        embeddings = rag_resource_manager.embeddings_for_loading

        def load_faiss():
            # allow_dangerous_deserialization: index.pkl is a pickle — only
            # load indexes produced by this project (untrusted pickle is RCE).
            return FAISS.load_local(
                full_path,
                embeddings=embeddings,
                allow_dangerous_deserialization=True
            )

        # Both blocking steps run in the default executor to keep the loop free.
        db = await loop.run_in_executor(None, load_faiss)
        docs_with_scores = await loop.run_in_executor(
            None, db.similarity_search_with_score, query, k1
        )
        # list() instead of an identity comprehension.
        return list(docs_with_scores)
    except Exception as e:
        print(f"❌ 查询 {db_path} 失败: {str(e)}")
        return []
async def query_vectors_async(
    query: str,
    db_dir: str,
    db_paths: List[str],
    k1: int = 15,
    k2: int = 80,
    k3: int = 15
) -> List[Tuple[Document, float]]:
    """Main retrieval pipeline: parallel multi-index search → dedup → rerank.

    Args:
        query: user question.
        db_dir: directory containing one FAISS index per sub-directory.
        db_paths: sub-directory names to search.
        k1: top-k retrieved per index.
        k2: number of candidates fed to the GPU reranker.
        k3: number of documents returned after reranking.

    Returns:
        Up to ``k3`` ``(Document, vector_score)`` tuples, sorted descending by
        the reranker score stored in ``doc.metadata["relevance_score"]``.
    """
    print(f"🔍 开始检索 {len(db_paths)} 个向量库...")
    tasks = [_search_single_db(query, db_dir, path, k1) for path in db_paths]
    results_list = await asyncio.gather(*tasks, return_exceptions=True)

    all_results = []
    for i, result in enumerate(results_list):
        if isinstance(result, Exception):
            print(f"⚠️ 查询 {db_paths[i]} 出错: {result}")
        elif result:
            all_results.extend(result)
    print(f"🔍 共检索到 {len(all_results)} 个原始文档片段")

    # Deduplicate by exact page content (first occurrence wins).
    seen = set()
    unique_results = []
    for doc, score in all_results:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            unique_results.append((doc, score))
    print(f"✅ 去重后剩余 {len(unique_results)} 个唯一文档")
    if not unique_results:
        return []

    # Keep the top-k2 candidates for reranking.
    # NOTE(review): FAISS similarity_search_with_score returns L2 distances by
    # default (lower = more similar); reverse=True then prefers the *least*
    # similar docs — confirm the intended score convention.
    top_k2_docs = sorted(unique_results, key=lambda x: x[1], reverse=True)[:k2]
    candidate_docs = [doc for doc, _ in top_k2_docs]
    original_scores = {doc.page_content: score for doc, score in top_k2_docs}
    pairs = [[query, doc.page_content] for doc in candidate_docs]
    print(f"🔁 准备重排序: {len(pairs)} 个候选")

    # Serialize GPU access — prevents concurrent requests from OOM-ing the GPU.
    async with rag_resource_manager.gpu_task_semaphore:
        print_gpu_memory("进入重排序前")
        reranker = await rag_resource_manager.get_reranker()
        # get_running_loop() is the non-deprecated call inside a coroutine.
        loop = asyncio.get_running_loop()
        batch_size = 8  # small batches avoid transient allocation failures
        rerank_scores = []
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i + batch_size]

            def compute(b):
                return reranker.compute_score(b, batch_size=len(b))

            try:
                scores = await loop.run_in_executor(None, compute, batch)
                if isinstance(scores, float):
                    # compute_score returns a bare float for a single pair.
                    scores = [scores]
                rerank_scores.extend(scores)
            except Exception as e:
                print(f"❌ 批次 {i // batch_size} 重排序失败: {e}")
                # Bug fix: the old code padded with len(batch) - len(scores)
                # zeros, but `scores` is unbound (NameError on the first
                # failing batch) or stale from the previous iteration when
                # compute() raises.  Pad the whole failed batch with zeros.
                rerank_scores.extend([0.0] * len(batch))

    # Attach rerank scores to copies of the docs (originals stay untouched).
    final_results = []
    for doc, r_score in zip(candidate_docs, rerank_scores):
        new_meta = doc.metadata.copy()
        new_meta["relevance_score"] = r_score
        new_doc = Document(page_content=doc.page_content, metadata=new_meta)
        final_results.append((new_doc, original_scores[doc.page_content]))

    # Highest reranker score first.
    final_results.sort(key=lambda x: x[0].metadata["relevance_score"], reverse=True)
    return final_results[:k3]
def get_related_content(docs_with_scores: List[Tuple[Document, float]]) -> str:
    """Join document texts into a single newline-separated context string.

    Each document's content is stripped, double newlines are collapsed to
    single ones, and empty fragments are dropped.
    """
    cleaned = (
        doc.page_content.strip().replace("\n\n", "\n")
        for doc, _score in docs_with_scores
    )
    return "\n".join(text for text in cleaned if text)
def build_rag_prompt(docs: List[Tuple[Document, float]], question: str) -> str:
    """Fill the RAG prompt template with retrieved context and the question."""
    PROMPT_TEMPLATE = """
你是一个专业的知识库问答AI助手,请严格按照以下规则回答问题:
【回答规则】
1. 知识库优先原则
- 若上下文存在明确答案:
* 回答格式:
「根据知识库有关内容可知:[具体回答内容]」
* 流程类问题分步骤说明(1. 第一步... 2. 第二步...)
- 若上下文无答案但属于以下情况:
* 数学公式/物理定律等基础科学知识
* 广泛认可的常识(如国家首都等)
* 回答格式:
「在已知知识库中未找到相关信息,根据通用知识回答如下:[回答内容]」
- 完全无法回答时:
* 回答格式:
「抱歉,未能找到相关信息,无法做出回答」
2. 严格性要求
- 禁止对知识库内容进行推测或扩展
- 非知识库内容必须明确声明
【上下文】
{context}
【问题】
{question}
"""
    template = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
    return template.format(context=get_related_content(docs), question=question)
async def get_vllm_stream_response_async(
    model_name: str,
    docs: List[Tuple[Document, float]],
    question: str
) -> AsyncGenerator[str, None]:
    """Stream an answer from the local vLLM OpenAI-compatible endpoint.

    Yields content deltas as they arrive.  On any failure a single
    error-message string is yielded instead of raising.
    """
    try:
        vllm_client = AsyncOpenAI(base_url="http://localhost:5000/v1", api_key="token-abc123")
        prompt = build_rag_prompt(docs, question)
        print(f"📝 提示词长度: {len(prompt)} 字符")
        stream = await vllm_client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=4096,
            stream=True
        )
        async for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                yield delta
    except Exception as e:
        error_msg = f"⚠️ vLLM 流式响应出错: {str(e)}"
        print(error_msg)
        yield error_msg
async def process_answer_with_stream_async(
    model_name: str,
    query: str,
    db_dir: str,
    db_paths: List[str],
    k1: int = 15,
    k2: int = 80,
    k3: int = 15,
    final_k: int = 18
) -> Tuple[AsyncGenerator[str, None], List[str]]:
    """Main entry point: question in → (streaming answer generator, cited files).

    Never raises: retrieval failures and internal errors are converted into a
    one-shot async generator yielding a user-facing message, paired with an
    empty citation list.
    """
    # Global request throttle (see RAGResourceManager).
    async with rag_resource_manager.total_concurrent_semaphore:
        print_gpu_memory("请求开始")
        try:
            # 1. Retrieve + rerank.
            results = await query_vectors_async(query, db_dir, db_paths, k1, k2, k3)
            if not results:
                print("📭 未检索到任何相关文档")
                async def fallback():
                    yield "抱歉,未能从知识库中找到与问题相关的内容,无法做出准确回答。"
                return fallback(), []
            # 2. Drop zero/negative-relevance docs, keep the top final_k.
            filtered = [
                (doc, vscore) for doc, vscore in results
                if doc.metadata.get("relevance_score", 0) > 0
            ]
            filtered.sort(key=lambda x: x[0].metadata["relevance_score"], reverse=True)
            final_results = filtered[:final_k]
            if not final_results:
                print("📭 过滤后无有效文档")
                async def fallback():
                    yield "根据相关性分析,未找到足够匹配的知识内容,无法回答该问题。"
                return fallback(), []
            # 3. Collect cited source files (deduplicated via set).
            use_files = list({
                doc.metadata.get("source", "unknown")
                for doc, _ in final_results
                if "source" in doc.metadata
            })
            # 4. Hand back the (not-yet-started) streaming generator.
            stream_generator = get_vllm_stream_response_async(model_name, final_results, query)
            return stream_generator, use_files
        except Exception as e:
            print(f"💥 处理过程发生错误: {str(e)}")
            # Bug fix: `e` is unbound once the except block exits (PEP 3110),
            # so a generator closing over `e` raised NameError when iterated.
            # Format the message eagerly and close over the string instead.
            error_msg = f"系统错误: {str(e)}"
            async def error_gen():
                yield error_msg
            return error_gen(), []
# =================== Synchronous compatibility interface ===================
def answer(model_name, query, db_dir, db_paths, k1=15, k2=80, k3=15, final_k=18):
    """Synchronous wrapper around process_answer_with_stream_async (legacy callers).

    Creates a fresh event loop, runs the async pipeline to completion, and
    returns its (stream_generator, cited_files) result.

    NOTE(review): the loop is closed in ``finally`` *before* the returned
    async generator is ever consumed — iterating the stream afterwards
    presumably needs another running loop; confirm with callers.
    """
    try:
        # Optional: allows nesting when a loop is already running in the caller.
        import nest_asyncio
        nest_asyncio.apply()
    except ImportError:
        # Best-effort — nest_asyncio is an optional dependency.
        pass
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            process_answer_with_stream_async(model_name, query, db_dir, db_paths, k1, k2, k3, final_k)
        )
    finally:
        loop.close()
# =================== Async interactive demo ===================
async def main_async():
    """Interactive CLI loop: read questions, stream answers, show sources."""
    db_dir = "../Faiss"
    k1, k2, k3, final_k = 15, 80, 15, 18
    if not os.path.exists(db_dir):
        print(f"❌ 向量库目录不存在: {db_dir}")
        return
    # Every sub-directory of db_dir is treated as one FAISS index.
    db_paths = [
        entry for entry in os.listdir(db_dir)
        if os.path.isdir(os.path.join(db_dir, entry))
    ]
    print(f"📁 发现 {len(db_paths)} 个向量库: {db_paths}")
    while True:
        print("\n" + "=" * 60)
        query = input("请输入你想提问的问题(输入'q'退出): ").strip()
        if query.lower() in ['q', 'quit', 'exit']:
            print("👋 程序结束,再见!")
            break
        if not query:
            print("❗ 问题不能为空,请重新输入!")
            continue
        print_gpu_memory("请求前")
        stream_gen, files = await process_answer_with_stream_async(
            QA_MODEL_NAME, query, db_dir, db_paths, k1, k2, k3, final_k
        )
        print(f"📎 引用文件: {files}")
        print("💬 回答内容:")
        # Print chunks as they arrive, collecting them for the final echo.
        chunks = []
        async for piece in stream_gen:
            print(piece, end="", flush=True)
            chunks.append(piece)
        full_response = "".join(chunks)
        print(f"\n📦 完整回答: {full_response}")
        print_gpu_memory("请求完成后")
if __name__ == "__main__":
    # Fragmentation guard: PYTORCH_CUDA_ALLOC_CONF is set at the top of this
    # file before torch is imported.
    print_gpu_memory("程序启动")
    asyncio.run(main_async())
# NOTE(review): the two lines below were pasted runtime-error output, not code;
# kept here as a comment so the file stays importable.
#   报错: self._embeddings_for_load = HuggingFaceEmbeddings(
#   ❌ 查询 0002 失败: "HuggingFaceEmbeddings" object has no field "embed_documents"
# Root cause: HuggingFaceEmbeddings is a pydantic model, so assigning
# embed_documents / embed_query onto an instance is rejected.  See
# RAGResourceManager.embeddings_for_loading, which must supply those methods
# without monkey-patching the pydantic object.