import time
import os
import requests
import json
from basereal import BaseReal
from logger import logger
from typing import Dict, List, Optional, Callable
import jwt
import hashlib
import hmac
import base64
from datetime import datetime, timedelta
import threading
class ConversationHistory:
    """Bounded in-memory chat history for a single user.

    Messages are stored as ``{"role": ..., "content": ...}`` dicts. At most
    ``max_rounds * 2`` messages are retained (one round is a user message
    plus an assistant message); older messages are discarded first.
    """

    def __init__(self, max_rounds=5):
        """Create an empty history.

        :param max_rounds: maximum number of rounds to keep; zero or a
            negative value keeps nothing.
        """
        self.max_rounds = max_rounds
        self.history = []

    def add_message(self, role: str, content: str):
        """Append a message and trim the history to the configured size."""
        self.history.append({"role": role, "content": content})
        limit = max(self.max_rounds, 0) * 2
        # A ``[-0:]`` slice would keep the whole list, so a zero limit has to
        # be handled explicitly.
        self.history = self.history[-limit:] if limit else []

    def get_messages(self) -> List[Dict]:
        """Return a copy of the stored messages, oldest first.

        Each message dict is copied as well, so callers cannot mutate the
        stored history through the returned list.
        """
        return [dict(message) for message in self.history]

    def clear(self):
        """Remove every stored message."""
        self.history = []
class CozeClient:
def __init__(self, app_id: str, private_key: str, kid: str, bot_id: str):
"""
初始化Coze客户端
:param app_id: 应用ID
:param private_key: 私钥
:param kid: Key ID
:param bot_id: 机器人ID
"""
self.last_activity = {}
self.app_id = app_id
self.private_key = private_key
self.kid = kid
self.bot_id = bot_id
self.access_token = None
self.token_expire_time = None
self.lock = threading.Lock()
self.conversation_histories = {} # 用户对话历史存储
self.user_conversations = {} # {user_id: conversation_id}
self.session_file = "coze_sessions.json"
self._load_sessions() # 初始化时加载保存的会话
def _load_sessions(self):
try:
with open(self.session_file, 'r') as f:
data = json.load(f)
self.user_conversations = data.get("user_conversations", {})
self.conversation_histories = {
uid: ConversationHistory(max_rounds=5)
for uid in data.get("conversation_histories", {})
}
except (FileNotFoundError, json.JSONDecodeError):
self.user_conversations = {}
self.conversation_histories = {}
def _save_sessions(self):
data = {
"user_conversations": self.user_conversations,
"conversation_histories": {
uid: hist.get_messages()
for uid, hist in self.conversation_histories.items()
}
}
with open(self.session_file, 'w') as f:
json.dump(data, f)
# 在CozeClient类中添加
CONVERSATION_TIMEOUT = 1800 # 30分钟
def get_conversation_id(self, user_id: str) -> Optional[str]:
conv_id = self.user_conversations.get(user_id)
if conv_id:
self.last_activity[user_id] = time.time() # 更新活动时间
return conv_id
def _create_jwt_token(self, expire_seconds: int = 3600) -> str:
"""创建JWT Token用于获取Access Token"""
ts = int(time.time())
exp = ts + expire_seconds
header = {
"alg": "RS256",
"typ": "JMT",
"kid": self.kid
}
payload = {
"iss": self.app_id,
"aud": "api.coze.cn",
"iat": ts,
"exp": exp,
"jti": self._get_random_string(4)
}
header_b64 = self._base64_url_encode(json.dumps(header).encode('utf-8'))
payload_b64 = self._base64_url_encode(json.dumps(payload).encode('utf-8'))
h_and_p = f"{header_b64}.{payload_b64}"
signature = self._rsa_sign(self.private_key, h_and_p)
signature_b64 = self._base64_url_encode(signature)
return f"{h_and_p}.{signature_b64}"
def _get_access_token(self) -> str:
"""获取Access Token,优先使用缓存的token"""
if self.access_token and self.token_expire_time and time.time() < self.token_expire_time - 10:
return self.access_token
with self.lock:
if self.access_token and self.token_expire_time and time.time() < self.token_expire_time - 10:
return self.access_token
jwt_token = self._create_jwt_token(300)
data = {
"duration_seconds": 86399,
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer"
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {jwt_token}"
}
response = requests.post(
"https://api.coze.cn/api/permission/oauth2/token",
json=data,
headers=headers,
timeout=20
)
if response.status_code == 200:
result = response.json()
self.access_token = result.get("access_token")
expires_in = result.get("expires_in", 86399)
self.token_expire_time = time.time() + expires_in - 10
return self.access_token
else:
raise Exception(f"Failed to get access token: {response.text}")
def get_conversation_history(self, user_id: str, max_rounds: int = 5) -> List[Dict]:
"""获取用户对话历史"""
if user_id not in self.conversation_histories:
self.conversation_histories[user_id] = ConversationHistory(max_rounds)
return self.conversation_histories[user_id].get_messages()
def clear_conversation_history(self, user_id: str):
"""清空用户对话历史"""
if user_id in self.conversation_histories:
self.conversation_histories[user_id].clear()
def stream_chat(
self,
conversation_id: Optional[str],
user_id: str,
messages: List[Dict],
on_message: Callable[[str, Dict], None],
cancellation_token=None,
max_history_rounds: int = 5,
nerfreal: Optional[BaseReal] = None,
):
# 超时检查(添加在函数开头)
if (conversation_id and
time.time() - self.last_activity.get(user_id, 0) > self.CONVERSATION_TIMEOUT):
self.clear_conversation_history(user_id)
conversation_id = None
logger.info(f"Conversation timeout, new session started")
# 在函数开始时添加日志
logger.info(f"Starting chat - User: {user_id}, Existing Conversation ID: {conversation_id}")
"""
流式对话(支持历史记录)
:param max_history_rounds: 最大历史对话轮数
:param nerfreal: BaseReal实例,用于消息输出
"""
token = self._get_access_token()
# 获取历史记录并合并新消息
history = self.get_conversation_history(user_id, max_history_rounds)
all_messages = history + messages
url = "https://api.coze.cn/v3/chat"
if conversation_id:
url += f"?conversation_id={conversation_id}"
data = {
"bot_id": self.bot_id,
"user_id": user_id,
"additional_messages": all_messages,
"stream": True
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {token}"
}
# 定义包装回调以处理历史记录
full_response = ""
buffer = ""
def wrapped_callback(event: str, msg: dict):
nonlocal full_response, buffer, conversation_id # 添加conversation_id到nonlocal
#logger.debug(f"Received event: {event}, data: {json.dumps(msg, ensure_ascii=False)}")
if event == "conversation.message.delta":
# 从delta消息中获取会话ID(如果有)
if msg.get("conversation_id"):
conversation_id = msg["conversation_id"]
self.user_conversations[user_id] = conversation_id
self._save_sessions() # 立即保存
#logger.info(f"Updated conversation ID: {conversation_id}")
#logger.info(f"Delta message - ID: {msg.get('id')}, "
# f"Conversation ID: {msg.get('conversation_id')}, "
# f"Content: {msg.get('content')}")
if msg.get("type") == "answer" and msg.get("content_type") == "text":
content = msg.get("content", "")
reasoning_content = msg.get("reasoning_content", "")
if reasoning_content:
logger.info(f"Thinking content: {reasoning_content}")
if content:
buffer += content
sentence_enders = ['.', '!', '?']
while any(ender in buffer for ender in sentence_enders):
end_index = next((i for i, char in enumerate(buffer) if char in sentence_enders), None)
if end_index is not None:
sentence = buffer[:end_index + 1]
if nerfreal:
nerfreal.put_msg_txt(self._filter_urls(sentence))
# 优化历史记录处理 - 减少文件IO
if os.path.exists('systemReplyArray.txt'):
with open('systemReplyArray.txt', 'r+', encoding='utf-8') as f:
previous_lines = f.read().splitlines()[:3]
f.seek(0)
f.write(sentence + '\n' + '\n'.join(previous_lines))
f.truncate()
logger.info(f"Processed sentence: {sentence}")
full_response += sentence
buffer = buffer[end_index + 1:]
elif event == "conversation.message":
# 保存会话ID并记录完整消息
if msg.get("conversation_id") and user_id in self.user_conversations:
self.user_conversations[user_id] = msg["conversation_id"]
#logger.info(f"New conversation established - "
# f"Conversation ID: {msg['conversation_id']}, "
# f"Message ID: {msg.get('id')}")
elif event == "error":
error_msg = msg.get("error", "Unknown error")
logger.error(f"Chat error occurred: {error_msg}")
# if nerfreal:
# nerfreal.put_msg_txt(f"对话出错: {error_msg}")
# full_response += f"对话出错: {error_msg}"
elif event == "done":
if buffer:
if nerfreal:
nerfreal.put_msg_txt(self._filter_urls(buffer))
logger.info(f"Final buffer content: {buffer}")
full_response += buffer
buffer = ""
# 记录完整对话历史
logger.info(f"Completed conversation - " f"Final response: {full_response}")
# 将完整回复加入历史记录
if user_id in self.conversation_histories:
self.conversation_histories[user_id].add_message("assistant", full_response)
# 写入完整回复到文件
with open('systemReply.txt', 'w', encoding='utf-8') as f:
f.write(full_response)
# 调用原始回调
on_message(event, msg)
try:
with requests.post(url, json=data, headers=headers, stream=True, timeout=300) as response:
if response.status_code != 200:
wrapped_callback("error", {"error": response.text})
return
for line in response.iter_lines():
if cancellation_token and cancellation_token.is_cancelled():
wrapped_callback("user_cancel", None)
break
if not line:
continue
line = line.decode('utf-8')
if line.startswith("event:"):
event = line[6:].strip()
elif line.startswith("data:"):
data = line[5:].strip()
if data == "[DONE]" or data == "\"[DONE]\"":
wrapped_callback("done", None)
break
try:
msg = json.loads(data)
wrapped_callback(event, msg)
except json.JSONDecodeError:
logger.error(f"Failed to parse message: {data}")
except Exception as e:
wrapped_callback("error", {"error": str(e)})
@staticmethod
def _filter_urls(text: str) -> str:
"""过滤掉文本中的HTTP/HTTPS链接"""
import re
url_pattern = re.compile(r'https?://\S+')
return url_pattern.sub('', text)
@staticmethod
def _base64_url_encode(data: bytes) -> str:
"""Base64 URL安全编码"""
return base64.urlsafe_b64encode(data).decode('utf-8').rstrip('=')
@staticmethod
def _rsa_sign(private_key: str, message: str) -> bytes:
"""RSA签名"""
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
private_key_obj = serialization.load_pem_private_key(
private_key.encode(),
password=None
)
signature = private_key_obj.sign(
message.encode(),
padding.PKCS1v15(),
hashes.SHA256()
)
return signature
@staticmethod
def _get_random_string(length: int) -> str:
"""生成随机字符串"""
import random
import string
return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
# Module-level singleton so the cached access token and the session state are
# shared by every call.
_coze_client_instance = None


def llm_response(message, nerfreal: BaseReal):
    """Send ``message`` to the Coze bot and stream the reply to ``nerfreal``.

    The client is created on first use. Credentials are taken from the
    environment variables ``COZE_APP_ID``, ``COZE_PRIVATE_KEY``, ``COZE_KID``
    and ``COZE_BOT_ID``; any that are unset fall back to the values embedded
    below, so existing deployments keep working.

    :param message: text of the user's message
    :param nerfreal: BaseReal instance that receives the streamed reply
    """
    start = time.perf_counter()
    global _coze_client_instance
    if _coze_client_instance is None:
        # NOTE(security): a private key committed to source control must be
        # treated as compromised. Rotate it, supply the new one through
        # COZE_PRIVATE_KEY, and then delete this fallback.
        embedded_private_key = (
            "-----BEGIN PRIVATE KEY-----\n"
            "MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQC2sewRl13lFCgB\n"
            "s+ypHkWelDwS4NRiiUgDngisx8EV8awSS5mnj5GX632ZFwwHqLQbeQAvKys1/fiK\n"
            "3nZduO6wSRKX3sqqxdWVhUcVJA2YURFLbVpkHozAJhY9s8wgug1Z7UjKgxOHupcx\n"
            "1Ia1aW00cM03EE97Gq1u/qCoMjJtba8f/Igwgpctfg1YKtlzTKI+2NNo9OfQaBY+\n"
            "WBilcTlkf8isw7zW/4g7f1/CpKPhAhsbjR5S4PPGZObi6m7Th49kUL8jbhcZGQir\n"
            "JsVwhqoTmFU/rTwDJKMxdYNN3Xd6r1HwNilOzJxLBp+ayhQ2FzSzrDZBWNlfnL3A\n"
            "eOqv2VDxAgMBAAECggEAPfIvLrnJ0wpWFFm7FY7XoVD225nTOcP9oIhdvaQPks66\n"
            "fwuWQov6HG5zTEzVveUUiLoq91NmV+zQ8NlEfjvd5vUn8knPIz9oT8X8l6z9VRer\n"
            "ywz9mLQJGn/vi6ViwfgD3emIhG4UWbHJYVKECJACQMU8t/52TLH7e4an19AJbDOx\n"
            "GyurdgvhtMrsKblbfMUShvDrhice5oeh2N/NNMw4TfUwIlWei4vdGS7Nh9itDSor\n"
            "Pxy8B9ezkkGFjskANIzwsfhvd8+c7TX6X5DnHmFVuL2M81AzcbVGPElx/GcbPkob\n"
            "4GyVVmzI8ugBEFqlJmGHPKOl1e4UB5mXa+ylKGmdhwKBgQDpFtL0HIldLvbHlCGO\n"
            "jdfSvzCPcV27koY8qYLp1ShAGx4VHlLWZOt/eV2Y0Fqq/X/QIIrCnkgzlDcWzISx\n"
            "itZUkl3KtoxCrzloprcc7dI3gLdAhl3vsoeF9DL/J4iemxnsBi97gU7dDuBxt8fb\n"
            "M0eYS43WsF8LATqCJLze0uGR3wKBgQDIpwpLGA22bfgh477y+51zdNuIFJ1Ll7qw\n"
            "Fbz+speWbxxnwIjRX3lR1lh3uBdmIesFFsMTPc06IM4e+bj37AyucgRy/bch6z4g\n"
            "L86kQCKKdGG0J+jMYdcFo5xKlYxN97hw8FQEl2JksbKkSf4fFAyWQjfLFcclT5cV\n"
            "Eo+JUzuXLwKBgCY8r0iKceJOdP9Shpq7HB+fa5jscQL3S3wiFq7DYAH8MNgoDFDN\n"
            "Z3CW+Uq7S1Rnl5MN85Vvn8qOUuczj8UMUJK5HBfIEIRT+Gf5iWp+fRDL1cQJBtnu\n"
            "gJrx73e6BYh3Sy5T6XAqS0SqTxl4m5mS9Pi/1DnW3xCQGAgHfNBU6dojAoGALlN6\n"
            "qenMyLDNGC332SvEp3J0eQ+hXWGTpbHvJ7LeEspmeYHXVNfBL+bYGBP1uwvbshoW\n"
            "QewD5QbL8BTh4sOqDeCfLFltnbQtbMr836k7EFJceHa6Ze208kVbAVFTynCGMfUa\n"
            "wNCe0/a+8vVuaYh8e3igXxARIYklraTSZPdFi9sCgYAh7RMUrRhv/AcFyJIPhfyA\n"
            "y9KdUSPbbGT4/JoDKNE3TO61/v/h+q6WHSruYxpReZQyJTDPprL8inCBpfPmoPXX\n"
            "vlDUehuRQ74xOkT4u+Xd9YYjR2V3zqthrydXrY+8aZxLi/ZumgBqRx18HxPyYkj/\n"
            "2ASVX3TOEGLOC8dvq706AQ==\n"
            "-----END PRIVATE KEY-----"
        )
        private_key = os.environ.get("COZE_PRIVATE_KEY")
        if private_key:
            # Environment variables often carry a PEM with literal "\n"
            # sequences instead of real line breaks.
            private_key = private_key.replace("\\n", "\n")
        else:
            private_key = embedded_private_key
            logger.info("COZE_PRIVATE_KEY is not set; using the private key embedded in source")
        _coze_client_instance = CozeClient(
            app_id=os.environ.get("COZE_APP_ID") or "1173801711558",
            private_key=private_key,
            kid=os.environ.get("COZE_KID") or "AT0Q-GegCst7M3PcJz_icpUwDMrdXogc5q4k2SqTAXI",
            bot_id=os.environ.get("COZE_BOT_ID") or "7546524863025463347"
        )
    coze_client = _coze_client_instance

    # Every caller currently shares one session. Replace this with a real
    # per-user id (e.g. from the web session) to keep users apart.
    user_id = "fixed_user_id_for_session"
    if user_id not in coze_client.user_conversations:
        coze_client.user_conversations[user_id] = None
        logger.info(f"New user session: {user_id}")
    conversation_id = coze_client.get_conversation_id(user_id)
    logger.info(f"Current conversation - User: {user_id}, ID: {conversation_id}")

    user_message = {
        "role": "user",
        "content": message
    }
    # Only the new message is passed in; stream_chat sends the stored history
    # along with it and forwards the reply to nerfreal.
    coze_client.stream_chat(
        conversation_id=conversation_id,
        user_id=user_id,
        messages=[user_message],
        on_message=lambda event, msg: None,  # all handling happens inside stream_chat
        max_history_rounds=5,
        nerfreal=nerfreal
    )
    end = time.perf_counter()
    logger.info(f"llm总耗时: {end-start}s")
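# TODO: refactoring note - replace the Coze backend with a locally deployed Ollama backend (no network access required) and configure it accordingly.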