# mcp_server.py
from datetime import datetime
import hashlib
import logging
import os

import faiss
from mcp.server.fastmcp import FastMCP
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from ollama_embeding import CustomEmbeding  # local wrapper around the Ollama embedding API
from langchain_community.document_loaders import (
    TextLoader,
    PyPDFLoader,
    Docx2txtLoader,
    UnstructuredPowerPointLoader,
    UnstructuredExcelLoader,
    CSVLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredEmailLoader,
)
# Configure the logger
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Create the FastMCP instance
mcp = FastMCP("VectorService")
class VectorService:
    def __init__(self):
        self.embedding_function = CustomEmbeding('shaw/dmeta-embedding-zh')
        self.docstore = InMemoryDocstore()
        # The dimension must match the embedding model's output (768 for dmeta-embedding-zh).
        self.index = faiss.IndexFlatL2(768)
        self.vector_store = None
        self.existing_index_path = "E:/llm_rag/faiss_index/index.faiss"
        self.existing_index_pkl_path = "E:/llm_rag/faiss_index/index.pkl"
        self.is_processing = False
        self.last_processed_count = 0
        self.is_initialized = False  # flag marking that initialization has finished
        self.load_or_init_vector_store()  # load or build the vector store
        self.is_initialized = True  # initialization complete
    def load_or_init_vector_store(self):
        if self.vector_store is not None:
            return self.vector_store  # already initialized
        if os.path.exists(self.existing_index_path) and os.path.exists(self.existing_index_pkl_path):
            vector_store = FAISS.load_local(
                "E:/llm_rag/faiss_index",
                embeddings=self.embedding_function,
                allow_dangerous_deserialization=True
            )
            logger.info("Loaded existing vector store.")
        else:
            vector_store = FAISS(
                embedding_function=self.embedding_function,
                index=self.index,
                docstore=self.docstore,
                index_to_docstore_id={}
            )
            logger.info("Initialized new vector store.")
        self.vector_store = vector_store
        return vector_store
    def get_id(self, file_path):
        """Generate a stable file id from the file path."""
        return hashlib.md5(file_path.encode()).hexdigest()
    def load_document(self, file_path: str):
        file_ext = file_path.split('.')[-1].lower()
        logger.info(f"Loading document from {file_path}")
        loader_map = {
            'txt': TextLoader,
            'pdf': PyPDFLoader,
            'docx': Docx2txtLoader,
            'pptx': UnstructuredPowerPointLoader,
            'xlsx': UnstructuredExcelLoader,
            'csv': CSVLoader,
            'html': UnstructuredHTMLLoader,
            'htm': UnstructuredHTMLLoader,
            'md': UnstructuredMarkdownLoader,
            'eml': UnstructuredEmailLoader,
            'msg': UnstructuredEmailLoader
        }
        if file_ext not in loader_map:
            logger.warning(f"Unsupported file type: {file_ext}")
            return None
        # The membership check above guarantees the lookup succeeds,
        # so no fallback loader is needed here.
        loader = loader_map[file_ext](file_path)
        try:
            documents = loader.load()
            logger.info(f"Loaded {len(documents)} documents from {file_path}")
            return documents
        except Exception as e:
            logger.error(f"Error loading {file_path}: {str(e)}")
            return None
    def _add_vector_metadata(self, file_name, file_name_path):
        """Load a file and attach per-document metadata."""
        docs = []
        metadatas = []
        try:
            file_stats = os.stat(file_name_path)
            file_size = file_stats.st_size
            res = self.load_document(file_name_path)
            if res:
                # Unique file id derived from a hash of the file path.
                file_id = self.get_id(file_name_path)
                for doc in res:
                    # Merge our metadata with the document's own metadata.
                    doc_metadata = doc.metadata.copy()
                    doc_metadata.update({
                        "source": file_name,
                        "file_path": file_name_path,
                        "id": file_id,
                        "upload_time": datetime.now().isoformat()
                    })
                    # Prepend the file name to the content so that file-name
                    # matches carry more weight in retrieval.
                    enhanced_content = f"文件名: {file_name}\n内容: {doc.page_content.strip()}"
                    docs.append(enhanced_content)
                    metadatas.append(doc_metadata)
                logger.info(f"Processed {file_name} ({file_size / (1024 * 1024.0):.2f} MB)")
        except Exception as e:
            logger.error(f"Error processing {file_name_path}: {str(e)}")
        return docs, metadatas
    def process_documents(self, data_path: str):
        """Vectorize every file under data_path in one batch, tagging each with a unique file id."""
        try:
            self.is_processing = True
            all_docs = []
            all_metadatas = []
            for root, dirs, files in os.walk(data_path):
                for file_name in files:
                    file_name_path = os.path.join(root, file_name)
                    logger.info(f"Processing file: {file_name_path}")
                    docs, metadatas = self._add_vector_metadata(
                        file_name=file_name,
                        file_name_path=file_name_path
                    )
                    # Accumulate results across files.
                    all_docs.extend(docs)
                    all_metadatas.extend(metadatas)
            # Persist the vectors for all files at once.
            self._save_data_vector(docs=all_docs, metadatas=all_metadatas)
            self.last_processed_count = len(all_docs)
            self.is_processing = False
            return {
                "status": "success",
                "message": "Documents processed successfully",
                "document_count": len(all_docs)
            }
        except Exception as e:
            logger.error(f"Error processing documents: {str(e)}")
            self.is_processing = False
            return {"status": "error", "message": str(e)}
    def _save_data_vector(self, docs, metadatas):
        """Embed the texts and persist the FAISS index to disk."""
        self.vector_store = self.load_or_init_vector_store()
        # Drop empty texts while keeping docs and metadatas aligned;
        # filtering docs alone would pair texts with the wrong metadata.
        pairs = [(d, m) for d, m in zip(docs, metadatas) if d]
        docs = [d for d, _ in pairs]
        metadatas = [m for _, m in pairs]
        try:
            logger.info("Starting embedding process...")
            self.vector_store.add_texts(texts=docs, metadatas=metadatas)
            logger.info("Embedding process completed.")
        except Exception as e:
            logger.error(f"An error occurred during embedding: {str(e)}")
        try:
            logger.info("Saving updated vector store...")
            self.vector_store.save_local("E:/llm_rag/faiss_index")
            logger.info("Updated vector store saved to E:/llm_rag/faiss_index.")
        except Exception as e:
            logger.error(f"An error occurred during saving: {str(e)}")
        return docs
    def check_process_status(self):
        """Report the current processing status."""
        if self.is_processing:
            return {
                "status": "processing",
                "message": "Documents are being processed"
            }
        if os.path.exists(self.existing_index_path) and os.path.exists(self.existing_index_pkl_path):
            if self.last_processed_count > 0:
                return {
                    "status": "success",
                    "message": "Vector data has been updated",
                    "last_processed_count": self.last_processed_count
                }
            return {
                "status": "ready",
                "message": "Vector store exists but no new data processed"
            }
        return {
            "status": "empty",
            "message": "No vector store exists"
        }
    def add_vector(self, new_file_name_path: str, new_file_name: str):
        """Vectorize a single file and add it to the store."""
        try:
            self.is_processing = True
            docs, metadatas = self._add_vector_metadata(
                file_name=new_file_name,
                file_name_path=new_file_name_path
            )
            self._save_data_vector(docs=docs, metadatas=metadatas)
            self.last_processed_count = len(docs)
            self.is_processing = False
            return {
                "status": "success",
                "message": "Vector added successfully"
            }
        except Exception as e:
            logger.error(f"Error adding vector: {str(e)}")
            self.is_processing = False
            return {
                "status": "error",
                "message": str(e)
            }
vector_service = VectorService()

@mcp.tool()
def process_documents(data_path: str):
    """Process every document under data_path and update the vector store."""
    logger.info(f"Starting to process documents in {data_path}")
    return vector_service.process_documents(data_path)

@mcp.tool()
def check_process_status():
    """Check the processing status."""
    logger.info("Checking process status")
    return vector_service.check_process_status()

@mcp.tool()
def add_vector(new_file_name_path: str, new_file_name: str):
    """Vectorize and add a single file."""
    logger.info(f"Adding vector for file: {new_file_name_path}")
    return vector_service.add_vector(new_file_name_path, new_file_name)
@mcp.tool(name="searchfile", description="Search files by keyword and return the matching content")
def search_answer(query: str):
    """
    Retrieve documents relevant to the query.
    :param query: the user's question
    :return: the retrieved documents
    """
    if not vector_service.is_initialized:
        logger.info("Server is not initialized yet. Please wait.")
        return {"status": "error", "message": "Server is not initialized yet. Please wait."}
    logger.info(f"Searching for relevant documents: {query}")
    try:
        retriever = FAISS.load_local(
            "E:/llm_rag/faiss_index",
            CustomEmbeding('shaw/dmeta-embedding-zh'),
            allow_dangerous_deserialization=True
        ).as_retriever(search_kwargs={"k": 10})
        docs = retriever.get_relevant_documents(query)
        logger.info(f"Found {len(docs)} relevant document chunks")
        logger.info(f"docs: {docs}")
        results = []
        for doc in docs:
            metadata = doc.metadata
            file_path = metadata.get("file_path", "")
            # Safety check: only expose files inside the allowed directory.
            allowed_dir = "E:\\llm_rag\\data\\"
            if file_path and file_path.startswith(allowed_dir):
                # Build a download URL from the path relative to the allowed directory.
                download_url = os.path.relpath(file_path, allowed_dir)
                results.append({
                    "content": doc.page_content,
                    "download_url": download_url
                })
        return results
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return {"status": "error", "message": str(e)}
if __name__ == "__main__":
    mcp.settings.port = 8880
    logger.info("Starting mcp server through MCP")
    mcp.run(transport="sse")  # serve over SSE (HTTP), not stdio

Running the server and sending a request raised the following exception:

+ Exception Group Traceback (most recent call last):
| File "E:\llm_rag\.venv\lib\site-packages\uvicorn\protocols\http\h11_impl.py", line 403, in run_asgi
| result = await app( # type: ignore[func-returns-value]
| File "E:\llm_rag\.venv\lib\site-packages\uvicorn\middleware\proxy_headers.py", line 60, in __call__
| return await self.app(scope, receive, send)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\applications.py", line 112, in __call__
| await self.middleware_stack(scope, receive, send)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\middleware\errors.py", line 187, in __call__
| raise exc
| File "E:\llm_rag\.venv\lib\site-packages\starlette\middleware\errors.py", line 165, in __call__
| await self.app(scope, receive, _send)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\middleware\exceptions.py", line 62, in __call__
| await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\_exception_handler.py", line 53, in wrapped_app
| raise exc
| File "E:\llm_rag\.venv\lib\site-packages\starlette\_exception_handler.py", line 42, in wrapped_app
| await app(scope, receive, sender)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\routing.py", line 714, in __call__
| await self.middleware_stack(scope, receive, send)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\routing.py", line 734, in app
| await route.handle(scope, receive, send)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\routing.py", line 288, in handle
| await self.app(scope, receive, send)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\routing.py", line 76, in app
| await wrap_app_handling_exceptions(app, request)(scope, receive, send)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\_exception_handler.py", line 53, in wrapped_app
| raise exc
| File "E:\llm_rag\.venv\lib\site-packages\starlette\_exception_handler.py", line 42, in wrapped_app
| await app(scope, receive, sender)
| File "E:\llm_rag\.venv\lib\site-packages\starlette\routing.py", line 73, in app
| response = await f(request)
| File "E:\llm_rag\.venv\lib\site-packages\mcp\server\fastmcp\server.py", line 747, in sse_endpoint
| return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage]
| File "E:\llm_rag\.venv\lib\site-packages\mcp\server\fastmcp\server.py", line 680, in handle_sse
| async with sse.connect_sse(
| File "C:\Users\raywe\AppData\Local\Programs\Python\Python310\lib\contextlib.py", line 217, in __aexit__
| await self.gen.athrow(typ, value, traceback)
| File "E:\llm_rag\.venv\lib\site-packages\mcp\server\sse.py", line 146, in connect_sse
| async with anyio.create_task_group() as tg:
| File "E:\llm_rag\.venv\lib\site-packages\anyio\_backends\_asyncio.py", line 772, in __aexit__
| raise BaseExceptionGroup(
| exceptiongroup.ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
+-+---------------- 1 ----------------
| Exception Group Traceback (most recent call last):
| File "E:\llm_rag\.venv\lib\site-packages\mcp\server\sse.py", line 165, in connect_sse
| yield (read_stream, write_stream)
| File "E:\llm_rag\.venv\lib\site-packages\mcp\server\fastmcp\server.py", line 685, in handle_sse
| await self._mcp_server.run(
| File "E:\llm_rag\.venv\lib\site-packages\mcp\server\lowlevel\server.py", line 500, in run
| async with AsyncExitStack() as stack:
| File "C:\Users\raywe\AppData\Local\Programs\Python\Python310\lib\contextlib.py", line 714, in __aexit__
| raise exc_details[1]
| File "C:\Users\raywe\AppData\Local\Programs\Python\Python310\lib\contextlib.py", line 217, in __aexit__
| await self.gen.athrow(typ, value, traceback)
| File "E:\llm_rag\.venv\lib\site-packages\mcp\server\lowlevel\server.py", line 125, in lifespan
| yield {}
| File "C:\Users\raywe\AppData\Local\Programs\Python\Python310\lib\contextlib.py", line 697, in __aexit__
| cb_suppress = await cb(*exc_details)
| File "E:\llm_rag\.venv\lib\site-packages\mcp\shared\session.py", line 223, in __aexit__
| return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
| File "E:\llm_rag\.venv\lib\site-packages\anyio\_backends\_asyncio.py", line 772, in __aexit__
| raise BaseExceptionGroup(
| exceptiongroup.ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
+-+---------------- 1 ----------------
| Traceback (most recent call last):
| File "E:\llm_rag\.venv\lib\site-packages\mcp\server\session.py", line 147, in _receive_loop
| await super()._receive_loop()
| File "E:\llm_rag\.venv\lib\site-packages\mcp\shared\session.py", line 374, in _receive_loop
| await self._received_request(responder)
| File "E:\llm_rag\.venv\lib\site-packages\mcp\server\session.py", line 175, in _received_request
| raise RuntimeError(
| RuntimeError: Received request before initialization was complete

How can this be resolved?
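The final error is raised in mcp/server/session.py's _received_request: the server got a tool request on a session before the MCP initialize handshake had completed. The most common trigger is a client that opens the SSE connection and posts tool calls without first completing that handshake (for example, hand-crafted HTTP requests). Below is a minimal sketch of a compliant client; port 8880 matches the server settings above, while the "/sse" endpoint path is an assumption based on FastMCP's default and may differ across versions.

# minimal_sse_client.py - sketch of a client that completes the MCP
# handshake before calling tools. Port 8880 comes from the server code
# above; the "/sse" path is an assumption (FastMCP's default).
import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client

async def main():
    async with sse_client("http://localhost:8880/sse") as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            # The initialize handshake must finish before any request is
            # sent; skipping it is what produces "Received request before
            # initialization was complete" on the server side.
            await session.initialize()
            result = await session.call_tool("searchfile", {"query": "测试"})
            print(result)

if __name__ == "__main__":
    asyncio.run(main())

If the error still appears with a client that does perform the handshake, it may come from a reconnect racing the session setup, so upgrading the mcp package is also worth trying. Note that the is_initialized guard in search_answer covers a different case: tools being called before the vector store is ready, not the protocol handshake itself.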