import time
import os
import requests
import json
import re
from basereal import BaseReal
from logger import logger
from typing import Dict, List, Optional, Callable
import threading
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OllamaEmbeddings
class LocalKnowledgeBase:
    """Local knowledge base: indexes a directory of documents into a FAISS
    vector store (Ollama embeddings) and answers similarity queries.

    Requires a running local Ollama server for embeddings.
    """

    def __init__(self, model_name="qwen3:1.7b", knowledge_base_path=r"D:\coze\知识库(新)"):
        """
        :param model_name: Ollama embedding model tag. Default fixed from
            "qwen:1.7b" (not a valid tag) to "qwen3:1.7b", matching the chat
            model used elsewhere in this module.
        :param knowledge_base_path: directory scanned for documents
            (previously hard-coded; now a parameter with the same default).
        """
        self.knowledge_base_path = knowledge_base_path
        self.embedding_model = OllamaEmbeddings(model=model_name)
        self.vector_store = self._initialize_vector_store()

    def _initialize_vector_store(self):
        """Load every file under the knowledge-base directory and index it."""
        loader = DirectoryLoader(
            path=self.knowledge_base_path,
            glob="**/*.*",            # all file extensions
            recursive=True,           # descend into sub-directories
            show_progress=True,
            use_multithreading=True
        )
        documents = loader.load()
        # Build the local FAISS vector store from the loaded documents.
        return FAISS.from_documents(
            documents=documents,
            embedding=self.embedding_model
        )

    def query(self, question: str, max_results: int = 3):
        """Return up to *max_results* documents most similar to *question*.

        Fix: the previous filter={"source": self.knowledge_base_path} compared
        each document's per-file source path against the directory path, which
        never matches, so every query returned no results. Filter removed.
        """
        return self.vector_store.similarity_search(
            query=question,
            k=max_results,
        )
class ConversationHistory:
    """Bounded, chronological record of chat messages for one user."""

    def __init__(self, max_rounds=5):
        # One round = one user message plus one assistant message.
        self.max_rounds = max_rounds
        self.history = []

    def add_message(self, role: str, content: str):
        """Append a message, then trim to the newest max_rounds*2 entries."""
        self.history.append({"role": role, "content": content})
        limit = self.max_rounds * 2
        if len(self.history) > limit:
            # Drop the oldest entries in place; observable behavior matches
            # rebinding to a tail slice.
            del self.history[:-limit]

    def get_messages(self) -> List[Dict]:
        """Return a shallow copy so callers cannot mutate internal state."""
        return list(self.history)

    def clear(self):
        """Forget every stored message."""
        del self.history[:]
class OllamaClient:
    """Streaming chat client for a local Ollama server.

    Maintains per-user conversation history, persists sessions to disk, and
    augments "/think" prompts with local knowledge-base context. Completed
    sentences are streamed to an optional digital-human sink for TTS.
    """

    # Compiled once at class-definition time; used by _filter_urls.
    _URL_PATTERN = re.compile(r'https?://\S+')

    def __init__(self, model_name: str = "qwen3:1.7b", base_url: str = "http://localhost:11434"):
        """
        :param model_name: local Ollama model tag.
        :param base_url: Ollama HTTP service address.
        """
        self.last_activity = {}            # {user_id: unix timestamp of last use}
        self.model_name = model_name
        self.base_url = base_url
        self.lock = threading.Lock()
        self.conversation_histories = {}   # {user_id: ConversationHistory}
        self.user_conversations = {}       # {user_id: conversation_id}
        self.session_file = "ollama_sessions.json"
        self.knowledge_base = LocalKnowledgeBase()  # local RAG source
        self._load_sessions()              # restore persisted sessions
        self.CONVERSATION_TIMEOUT = 1800   # 30-minute inactivity timeout (seconds)

    def _load_sessions(self):
        """Restore session state from disk; start empty if missing/corrupt.

        Fix: previously the saved message lists were discarded and every
        history was recreated empty, so persistence silently did nothing.
        Saved messages are now replayed into each ConversationHistory.
        """
        try:
            with open(self.session_file, 'r') as f:
                data = json.load(f)
            self.user_conversations = data.get("user_conversations", {})
            self.conversation_histories = {}
            for uid, messages in data.get("conversation_histories", {}).items():
                hist = ConversationHistory(max_rounds=5)
                for msg in messages:
                    hist.add_message(msg.get("role", "user"), msg.get("content", ""))
                self.conversation_histories[uid] = hist
        except (FileNotFoundError, json.JSONDecodeError):
            self.user_conversations = {}
            self.conversation_histories = {}

    def _save_sessions(self):
        """Persist conversation ids and histories to self.session_file."""
        data = {
            "user_conversations": self.user_conversations,
            "conversation_histories": {
                uid: hist.get_messages()
                for uid, hist in self.conversation_histories.items()
            },
        }
        with open(self.session_file, 'w') as f:
            json.dump(data, f)

    def _update_thinking_mode(self, user_id: str, content: str) -> str:
        """Expand a "/think" command with knowledge-base context.

        Non-"/think" content is returned unchanged.
        """
        if content.startswith("/think"):
            query = content.replace("/think", "").strip()
            knowledge_context = self._query_knowledge_base(query)
            # Merge retrieved context into the user instruction.
            return f"/think {query}\n\n相关知识库内容:\n{knowledge_context}"
        return content

    def _query_knowledge_base(self, query: str) -> str:
        """Return a short bullet summary of knowledge-base hits for *query*.

        Never raises: failures are logged and reported as a status string.
        """
        try:
            results = self.knowledge_base.query(query, max_results=3)
            # Truncate each hit to 200 chars to keep the prompt small.
            context = "\n".join([f"- {doc.page_content[:200]}..." for doc in results])
            return context if context else "知识库中未找到相关信息"
        except Exception as e:
            logger.error(f"知识库查询错误: {str(e)}")
            return "知识库查询失败"

    def get_conversation_id(self, user_id: str) -> Optional[str]:
        """Return the user's conversation id (or None), refreshing activity."""
        conv_id = self.user_conversations.get(user_id)
        if conv_id:
            self.last_activity[user_id] = time.time()  # keep session alive
        return conv_id

    def get_conversation_history(self, user_id: str, max_rounds: int = 5) -> List[Dict]:
        """Return a copy of the user's history, creating it if absent."""
        if user_id not in self.conversation_histories:
            self.conversation_histories[user_id] = ConversationHistory(max_rounds)
        return self.conversation_histories[user_id].get_messages()

    def clear_conversation_history(self, user_id: str):
        """Drop all stored messages for *user_id* (entry itself is kept)."""
        if user_id in self.conversation_histories:
            self.conversation_histories[user_id].clear()

    def stream_chat(
        self,
        conversation_id: Optional[str],
        user_id: str,
        messages: List[Dict],
        on_message: Callable[[str, Dict], None],
        cancellation_token=None,
        max_history_rounds: int = 5,
        nerfreal: Optional[BaseReal] = None,
    ):
        """Stream a chat completion from the local Ollama model.

        Complete sentences are pushed to *nerfreal* (TTS) as they arrive;
        *on_message* receives every raw event ("message_chunk", "done",
        "error", "user_cancel").

        :param conversation_id: current session id, or None to start fresh.
        :param user_id: key for history and activity bookkeeping.
        :param messages: new messages for this turn (last one is the user's).
        :param on_message: callback(event, payload) invoked per stream event.
        :param cancellation_token: optional object with is_cancelled().
        :param max_history_rounds: history rounds retained per user.
        :param nerfreal: optional digital-human sink for spoken sentences.
        """
        # Reset the session after CONVERSATION_TIMEOUT seconds of inactivity.
        if (conversation_id and
            time.time() - self.last_activity.get(user_id, 0) > self.CONVERSATION_TIMEOUT):
            self.clear_conversation_history(user_id)
            conversation_id = None
            logger.info(f"会话超时,启动新会话")
        logger.info(f"开始Ollama对话 - 用户: {user_id}, 模型: {self.model_name}")
        # Snapshot history BEFORE appending the new turn to avoid duplicates.
        history = self.get_conversation_history(user_id, max_history_rounds)
        # Expand "/think" prompts with knowledge-base context.
        messages[-1]["content"] = self._update_thinking_mode(user_id, messages[-1]["content"])
        all_messages = history + messages
        # Fix: persist the user turn. Previously only assistant replies were
        # stored, so the model never saw the user's side of earlier rounds.
        self.conversation_histories[user_id].add_message("user", messages[-1]["content"])
        # Fix: mark activity at chat time so the timeout reflects real use.
        self.last_activity[user_id] = time.time()
        url = f"{self.base_url}/api/chat"
        data = {
            "model": self.model_name,
            "messages": all_messages,
            "stream": True,
            "options": {
                "temperature": 0.7,     # sampling randomness
                "num_ctx": 4096,        # context window length
                "top_p": 0.9,           # nucleus sampling probability
                "repeat_penalty": 1.2   # discourage repetition
            }
        }
        headers = {
            "Content-Type": "application/json"
        }
        full_response = ""
        buffer = ""

        def wrapped_callback(event: str, msg: dict):
            # Closure over the streaming state; forwards every event to
            # on_message after local processing.
            nonlocal full_response, buffer, conversation_id
            if event == "message_chunk":
                content = msg.get("message", {}).get("content", "")
                if content:
                    buffer += content
                    sentence_enders = ['.', '!', '?', '。', '!', '?']
                    # Flush every complete sentence accumulated so far.
                    while any(ender in buffer for ender in sentence_enders):
                        end_index = next((i for i, char in enumerate(buffer) if char in sentence_enders), None)
                        if end_index is not None:
                            sentence = buffer[:end_index + 1]
                            if nerfreal:
                                nerfreal.put_msg_txt(self._filter_urls(sentence))
                            # Keep the 4 most recent sentences in the reply log.
                            if os.path.exists('systemReplyArray.txt'):
                                with open('systemReplyArray.txt', 'r+', encoding='utf-8') as f:
                                    previous_lines = f.read().splitlines()[:3]
                                    f.seek(0)
                                    f.write(sentence + '\n' + '\n'.join(previous_lines))
                                    f.truncate()
                            logger.info(f"处理语句: {sentence}")
                            full_response += sentence
                            buffer = buffer[end_index + 1:]
            elif event == "done":
                if buffer:
                    # Flush the trailing fragment that has no terminator.
                    if nerfreal:
                        nerfreal.put_msg_txt(self._filter_urls(buffer))
                    logger.info(f"最终缓存内容: {buffer}")
                    full_response += buffer
                    buffer = ""
                logger.info(f"完成对话 - 最终回复: {full_response}")
                # Record the full assistant reply in the history.
                if user_id in self.conversation_histories:
                    self.conversation_histories[user_id].add_message("assistant", full_response)
                # Write the complete reply for external consumers.
                with open('systemReply.txt', 'w', encoding='utf-8') as f:
                    f.write(full_response)
                # Fix: _save_sessions existed but was never called, so session
                # state was never actually persisted between runs.
                self._save_sessions()
            elif event == "error":
                error_msg = msg.get("error", "未知错误")
                logger.error(f"对话发生错误: {error_msg}")
            # Forward the raw event to the caller's callback.
            on_message(event, msg)

        try:
            with requests.post(url, json=data, headers=headers, stream=True, timeout=300) as response:
                if response.status_code != 200:
                    wrapped_callback("error", {"error": response.text})
                    return
                for line in response.iter_lines():
                    if cancellation_token and cancellation_token.is_cancelled():
                        # Fix: pass an empty dict instead of None so on_message
                        # always receives a Dict as its signature promises.
                        wrapped_callback("user_cancel", {})
                        break
                    if not line:
                        continue
                    line = line.decode('utf-8')
                    try:
                        msg = json.loads(line)
                        if msg.get("done", False):
                            wrapped_callback("done", {"message": {"content": ""}})
                            break
                        else:
                            wrapped_callback("message_chunk", msg)
                    except json.JSONDecodeError:
                        logger.error(f"解析Ollama消息失败: {line}")
        except Exception as e:
            wrapped_callback("error", {"error": str(e)})

    @staticmethod
    def _filter_urls(text: str) -> str:
        """Strip http(s) URLs from *text* so they are not spoken by TTS.

        Fix: dropped the redundant function-local ``import re`` (the module
        already imports it) and reuse the class-level pre-compiled pattern.
        """
        return OllamaClient._URL_PATTERN.sub('', text)
# Module-level singleton client, created lazily on first llm_response call.
_ollama_client_instance = None
def llm_response(message, nerfreal: BaseReal):
    """Send *message* to the local Ollama model and stream the reply into
    *nerfreal*, which speaks each completed sentence.

    Uses a lazily created module-level OllamaClient singleton so the
    knowledge base and session state are initialized only once per process.

    :param message: raw user utterance.
    :param nerfreal: digital-human sink receiving filtered sentences.
    """
    start = time.perf_counter()
    global _ollama_client_instance
    if _ollama_client_instance is None:
        _ollama_client_instance = OllamaClient(
            model_name="qwen3:1.7b",              # Ollama model tag
            base_url="http://localhost:11434"     # Ollama service address
        )
    ollama_client = _ollama_client_instance
    # Single shared session id for now; replace with a real per-user id
    # once the caller can supply one. (Dead helper get_real_user_id and
    # commented-out code removed.)
    user_id = "fixed_user_id_for_session"
    if user_id not in ollama_client.user_conversations:
        ollama_client.user_conversations[user_id] = None  # initialize slot
        logger.info(f"新用户会话: {user_id}")
    conversation_id = ollama_client.get_conversation_id(user_id)
    logger.info(f"当前会话 - 用户: {user_id}, ID: {conversation_id}")
    # Strip stable-diffusion style "<lora:...>" tags leaked into the prompt.
    filtered_message = re.sub(r'<lora:[^>]+>', '', message)
    user_message = {
        "role": "user",
        # "/no_think" suppresses the model's chain-of-thought output.
        "content": f"/no_think {filtered_message}"
    }
    # Start the streaming conversation; sentence delivery happens inside
    # stream_chat via nerfreal, so no on_message handling is needed here.
    ollama_client.stream_chat(
        conversation_id=conversation_id,
        user_id=user_id,
        messages=[user_message],
        on_message=lambda event, msg: None,
        max_history_rounds=5,
        nerfreal=nerfreal
    )
    end = time.perf_counter()
    logger.info(f"Ollama总耗时: {end-start}s")
# NOTE(review): stray non-code text was left at the end of this file and broke
# parsing; preserved here as a comment: 把这个代码改回之前的,
# ("revert this code to the previous version")