user_service
这个模块主要处理用户的身份验证、权限管理以及用户相关的业务逻辑,包括创建用户、生成 JWT 令牌、获取用户权限等。
以下是代码的逐行解析:
import functools
import json
from base64 import b64decode
from typing import List
import rsa
from fastapi import Depends, HTTPException, Request
from fastapi_jwt_auth import AuthJWT
from bisheng.api.JWT import ACCESS_TOKEN_EXPIRE_TIME
from bisheng.api.errcode.base import UnAuthorizedError
from bisheng.api.errcode.user import (
UserLoginOfflineError,
UserNameAlreadyExistError,
UserNeedGroupAndRoleError
)
from bisheng.api.utils import md5_hash
from bisheng.api.v1.schemas import CreateUserReq
from bisheng.cache.redis import redis_client
from bisheng.database.models.assistant import Assistant, AssistantDao
from bisheng.database.models.flow import Flow, FlowDao, FlowRead
from bisheng.database.models.knowledge import Knowledge, KnowledgeDao, KnowledgeRead
from bisheng.database.models.role import AdminRole
from bisheng.database.models.role_access import AccessType, RoleAccessDao
from bisheng.database.models.user import User, UserDao
from bisheng.database.models.user_group import UserGroupDao
from bisheng.database.models.user_role import UserRoleDao
from bisheng.settings import settings
from bisheng.utils.constants import RSA_KEY, USER_CURRENT_SESSION
1. 导入模块
标准库导入
functools:提供了高阶函数和可调用对象的操作工具。json:用于处理 JSON 数据的序列化和反序列化。base64:用于处理 Base64 编码和解码。typing.List:用于类型注解,表示列表类型。
第三方库导入
-
rsa:用于 RSA 加密和解密。 -
fastapi模块:
Depends:用于依赖注入。HTTPException:用于抛出 HTTP 异常。Request:表示请求对象。
-
fastapi_jwt_auth.AuthJWT:用于处理 JWT 认证。
项目内部模块导入
-
bisheng.api.JWT.ACCESS_TOKEN_EXPIRE_TIME:JWT 访问令牌的过期时间设置。 -
bisheng.api.errcode模块:
UnAuthorizedError:未授权错误。UserLoginOfflineError:用户被迫下线错误。UserNameAlreadyExistError:用户名已存在错误。UserNeedGroupAndRoleError:用户需要分组和角色错误。
-
bisheng.api.utils.md5_hash:用于生成 MD5 哈希值。 -
bisheng.api.v1.schemas.CreateUserReq:用户创建请求的数据模型。 -
bisheng.cache.redis.redis_client:Redis 客户端实例。 -
bisheng.database.models模块:
assistant:助手相关的模型和 DAO。flow:流程相关的模型和 DAO。knowledge:知识库相关的模型和 DAO。role:角色相关的模型。role_access:角色权限相关的模型和 DAO。user:用户相关的模型和 DAO。user_group:用户组相关的 DAO。user_role:用户角色相关的 DAO。
-
bisheng.settings.settings:项目的配置信息。 -
bisheng.utils.constants:定义了一些常量,如RSA_KEY、USER_CURRENT_SESSION。
2. UserPayload 类
class UserPayload:
def __init__(self, **kwargs):
self.user_id = kwargs.get('user_id')
self.user_role = kwargs.get('role')
if self.user_role != 'admin':
# 非管理员用户,需要获取他的角色列表
roles = UserRoleDao.get_user_roles(self.user_id)
self.user_role = [one.role_id for one in roles]
self.user_name = kwargs.get('user_name')
def is_admin(self):
if self.user_role == 'admin':
return True
if isinstance(self.user_role, list):
for one in self.user_role:
if one == AdminRole:
return True
return False
@staticmethod
def wrapper_access_check(func):
"""
权限检查的装饰器
如果是 admin 用户则不执行后续具体的检查逻辑
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if args[0].is_admin():
return True
return func(*args, **kwargs)
return wrapper
@wrapper_access_check
def access_check(self, owner_user_id: int, target_id: str, access_type: AccessType) -> bool:
"""
检查用户是否有某个资源的权限
"""
# 判断是否属于本人资源
if self.user_id == owner_user_id:
return True
# 判断授权
if RoleAccessDao.judge_role_access(self.user_role, target_id, access_type):
return True
return False
@wrapper_access_check
def check_group_admin(self, group_id: int) -> bool:
"""
检查用户是否是某个组的管理员
"""
# 判断是否是用户组的管理员
user_group = UserGroupDao.get_user_admin_group(self.user_id)
if not user_group:
return False
for one in user_group:
if one.group_id == group_id:
return True
return False
@wrapper_access_check
def check_groups_admin(self, group_ids: List[int]) -> bool:
"""
检查用户是否是用户组列表中的管理员,有一个就是 True
"""
user_groups = UserGroupDao.get_user_admin_group(self.user_id)
for one in user_groups:
if one.is_group_admin and one.group_id in group_ids:
return True
return False
功能概述
UserPayload 类用于封装用户的基本信息和权限校验方法,包括用户 ID、用户名、用户角色等。它还提供了权限检查的方法,判断用户是否有权访问某些资源。
属性
user_id:用户的唯一标识符。user_role:用户的角色,可以是'admin'或角色 ID 列表。user_name:用户名。
方法
__init__
- 功能:初始化用户信息。
- 逻辑:
- 获取
user_id和role(用户角色)。 - 如果角色不是
'admin',则通过UserRoleDao.get_user_roles获取用户的角色列表,赋值给self.user_role。
- 获取
is_admin
- 功能:判断用户是否为管理员。
- 返回值:
True表示是管理员,False表示不是。 - 逻辑:
- 如果
self.user_role是'admin',直接返回True。 - 如果
self.user_role是列表,遍历角色列表,如果存在AdminRole,返回True。
- 如果
wrapper_access_check
- 功能:权限检查的装饰器。
- 逻辑:
- 如果用户是管理员,直接返回
True,跳过后续的权限检查。 - 否则,执行被装饰的函数。
- 如果用户是管理员,直接返回
access_check
- 功能:检查用户是否有权限访问某个资源。
- 参数:
owner_user_id:资源所有者的用户 ID。target_id:目标资源的 ID。access_type:访问类型,枚举类型AccessType。
- 逻辑:
- 如果用户是资源的所有者,返回
True。 - 通过
RoleAccessDao.judge_role_access检查用户角色是否有权限访问目标资源。
- 如果用户是资源的所有者,返回
check_group_admin
- 功能:检查用户是否是某个用户组的管理员。
- 参数:
group_id:用户组的 ID。
- 逻辑:
- 获取用户管理的用户组列表。
- 如果用户在指定的用户组中,并且是管理员,返回
True。
check_groups_admin
- 功能:检查用户是否是多个用户组中的管理员。
- 参数:
group_ids:用户组 ID 的列表。
- 逻辑:
- 获取用户管理的用户组列表。
- 如果用户是任意一个指定用户组的管理员,返回
True。
总结
UserPayload 类通过封装用户信息和权限检查方法,方便在项目的其他地方进行用户权限的验证和操作。
3. UserService 类
class UserService:
@classmethod
def decrypt_md5_password(cls, password: str):
if value := redis_client.get(RSA_KEY):
private_key = value[1]
password = md5_hash(rsa.decrypt(b64decode(password), private_key).decode('utf-8'))
else:
password = md5_hash(password)
return password
@classmethod
def create_user(cls, request: Request, login_user: UserPayload, req_data: CreateUserReq):
"""
创建用户
"""
exists_user = UserDao.get_user_by_username(req_data.user_name)
if exists_user:
# 抛出异常
raise UserNameAlreadyExistError.http_exception()
user = User(
user_name=req_data.user_name,
password=cls.decrypt_md5_password(req_data.password),
)
group_ids = []
role_ids = []
for one in req_data.group_roles:
group_ids.append(one.group_id)
role_ids.extend(one.role_ids)
if not group_ids or not role_ids:
raise UserNeedGroupAndRoleError.http_exception()
user = UserDao.add_user_with_groups_and_roles(user, group_ids, role_ids)
return user
功能概述
UserService 类负责处理与用户相关的业务逻辑,例如创建用户、密码加密等。
方法
decrypt_md5_password
- 功能:解密并加密用户密码。
- 参数:
password:加密的密码字符串。
- 逻辑:
- 从 Redis 中获取 RSA 私钥(RSA_KEY)。
- 如果存在私钥,使用 RSA 解密密码,然后计算 MD5 哈希值。
- 如果没有私钥,直接计算密码的 MD5 哈希值。
- 从 Redis 中获取 RSA 私钥(RSA_KEY)。
- 返回值:加密后的密码(MD5 哈希值)。
create_user
- 功能:创建新用户。
- 参数:
request:请求对象。login_user:当前登录的用户信息。req_data:创建用户的请求数据(CreateUserReq)。
- 逻辑:
- 检查用户名是否已存在,如果存在,抛出
UserNameAlreadyExistError异常。 - 创建
User实例,设置用户名和加密后的密码。 - 解析用户的分组和角色信息:
- 遍历
req_data.group_roles,收集group_ids和role_ids。
- 遍历
- 检查是否提供了分组和角色,如果没有,抛出
UserNeedGroupAndRoleError异常。 - 调用
UserDao.add_user_with_groups_and_roles,将用户信息、分组和角色保存到数据库。
- 检查用户名是否已存在,如果存在,抛出
- 返回值:创建的用户对象。
4. 其他函数
sso_login
def sso_login():
pass
- 功能:单点登录的占位函数,目前未实现具体逻辑。
gen_user_role
def gen_user_role(db_user: User):
# 查询用户的角色列表
db_user_role = UserRoleDao.get_user_roles(db_user.user_id)
role = ''
role_ids = []
for user_role in db_user_role:
if user_role.role_id == 1:
# 是管理员,忽略其他的角色
role = 'admin'
else:
role_ids.append(user_role.role_id)
if role != 'admin':
# 判断是否是用户组管理员
db_user_groups = UserGroupDao.get_user_admin_group(db_user.user_id)
if len(db_user_groups) > 0:
role = 'group_admin'
else:
role = role_ids
# 获取用户的菜单栏权限列表
web_menu = RoleAccessDao.get_role_access(role_ids, AccessType.WEB_MENU)
web_menu = list(set([one.third_id for one in web_menu]))
return role, web_menu
- 功能:生成用户的角色和菜单权限列表。
- 参数:
db_user:数据库中的用户对象。
- 逻辑:
- 获取用户的角色列表。
- 如果角色 ID 为
1,表示是管理员,设置role为'admin',忽略其他角色。 - 否则,将角色 ID 添加到
role_ids列表中。
- 如果角色 ID 为
- 如果不是管理员:
- 检查用户是否是用户组管理员。
- 如果是,设置
role为'group_admin'。 - 否则,
role为role_ids列表。
- 如果是,设置
- 检查用户是否是用户组管理员。
- 获取用户的菜单栏权限列表。
- 通过
RoleAccessDao.get_role_access获取角色的菜单权限(AccessType.WEB_MENU)。 - 去重后,得到
web_menu列表。
- 通过
- 获取用户的角色列表。
- 返回值:
role(用户角色)和web_menu(菜单权限列表)。
gen_user_jwt
def gen_user_jwt(db_user: User):
if 1 == db_user.delete:
raise HTTPException(status_code=500, detail='该账号已被禁用,请联系管理员')
# 查询角色
role, web_menu = gen_user_role(db_user)
# 生成 JWT 令牌
payload = {'user_name': db_user.user_name, 'user_id': db_user.user_id, 'role': role}
# Create the tokens and passing to set_access_cookies or set_refresh_cookies
access_token = AuthJWT().create_access_token(subject=json.dumps(payload),
expires_time=ACCESS_TOKEN_EXPIRE_TIME)
refresh_token = AuthJWT().create_refresh_token(subject=db_user.user_name)
# Set the JWT cookies in the response
return access_token, refresh_token, role, web_menu
- 功能:生成用户的 JWT 访问令牌和刷新令牌。
- 参数:
db_user:数据库中的用户对象。
- 逻辑:
- 检查用户是否被禁用(
delete字段为1),如果被禁用,抛出 HTTP 异常。 - 调用
gen_user_role获取用户的角色和菜单权限列表。 - 创建 JWT 负载(
payload),包括用户名、用户 ID 和角色。 - 使用 AuthJWT生成访问令牌和刷新令牌。
- 访问令牌:包含用户信息,设置过期时间。
- 刷新令牌:用于刷新访问令牌。
- 检查用户是否被禁用(
- 返回值:
access_token、refresh_token、role、web_menu。
get_knowledge_list_by_access
def get_knowledge_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
count_filter = []
if name:
count_filter.append(Knowledge.name.like('%{}%'.format(name)))
db_role_access = KnowledgeDao.get_knowledge_by_access(role_id, page_num, page_size)
total_count = KnowledgeDao.get_count_by_filter(count_filter)
# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
db_users = UserDao.get_user_by_ids(user_ids)
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'data': [
KnowledgeRead.validate({
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
}) for access in db_role_access
],
'total':
total_count
}
- 功能:根据角色 ID 获取用户有权限访问的知识库列表。
- 参数:
role_id:角色 ID。name:知识库名称,支持模糊查询。page_num:分页页码。page_size:每页显示的数量。
- 逻辑:
- 根据名称构建过滤条件。
- 调用
KnowledgeDao.get_knowledge_by_access获取用户有权限的知识库列表。 - 获取总数,便于分页。
- 获取相关的用户信息,补充用户名。
- 构建返回的数据结构,包括知识库信息和总数。
- 返回值:包含
data(知识库列表)和total(总数)的字典。
get_flow_list_by_access
def get_flow_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
count_filter = []
if name:
count_filter.append(Flow.name.like('%{}%'.format(name)))
db_role_access = FlowDao.get_flow_by_access(role_id, name, page_num, page_size)
total_count = FlowDao.get_count_by_filters(count_filter)
# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
db_users = UserDao.get_user_by_ids(user_ids)
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'data': [
FlowRead.validate({
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
}) for access in db_role_access
],
'total':
total_count
}
- 功能:根据角色 ID 获取用户有权限访问的流程列表。
- 逻辑:与
get_knowledge_list_by_access类似,区别在于处理的对象是流程(Flow)。
get_assistant_list_by_access
def get_assistant_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
count_filter = []
if name:
count_filter.append(Assistant.name.like('%{}%'.format(name)))
db_role_access = AssistantDao.get_assistants_by_access(role_id, name, page_size, page_num)
total_count = AssistantDao.get_count_by_filters(count_filter)
# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
db_users = UserDao.get_user_by_ids(user_ids)
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'data': [{
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
} for access in db_role_access],
'total':
total_count
}
- 功能:根据角色 ID 获取用户有权限访问的助手列表。
- 逻辑:与前两个函数类似,处理的对象是助手(Assistant)。
5. 获取当前登录用户
get_login_user
async def get_login_user(authorize: AuthJWT = Depends()) -> UserPayload:
"""
获取当前登录的用户
"""
# 校验是否过期,过期则直接返回 http 状态码的 401
authorize.jwt_required()
current_user = json.loads(authorize.get_jwt_subject())
user = UserPayload(**current_user)
# 判断是否允许多点登录
if not settings.get_system_login_method().allow_multi_login:
# 获取 access_token
current_token = redis_client.get(USER_CURRENT_SESSION.format(user.user_id))
# 登录被挤下线了,http 状态码是 200, status_code 是特殊 code
if current_token != authorize._token:
raise UserLoginOfflineError.http_exception()
return user
- 功能:获取当前登录的用户信息,进行权限验证和多点登录检查。
- 参数:
authorize:依赖注入的AuthJWT实例,用于处理 JWT 认证。
- 逻辑:
- 调用
authorize.jwt_required()验证 JWT 是否存在和有效,若无效则返回 401 状态码。 - 从 JWT 中获取用户信息,创建
UserPayload实例。 - 检查系统设置是否允许多点登录。
- 如果不允许,检查当前用户的会话是否与 Redis 中存储的令牌一致。
- 如果不一致,说明用户被挤下线,抛出
UserLoginOfflineError异常。
- 调用
- 返回值:
UserPayload实例,包含当前用户的信息。
get_admin_user
async def get_admin_user(authorize: AuthJWT = Depends()) -> UserPayload:
"""
获取超级管理员账号,非超级管理员用户,抛出异常
"""
login_user = await get_login_user(authorize)
if not login_user.is_admin():
raise UnAuthorizedError.http_exception()
return login_user
- 功能:获取当前登录的超级管理员用户,如果不是管理员,抛出未授权异常。
- 逻辑:
- 调用
get_login_user获取当前登录用户。 - 检查用户是否为管理员。
- 如果不是,抛出
UnAuthorizedError异常。
- 如果不是,抛出
- 调用
- 返回值:
UserPayload实例,包含管理员用户的信息。
6. 总结
这个 user_service.py 模块主要处理了以下功能:
- 用户权限管理:通过
UserPayload类和相关的权限检查方法,管理用户的权限和角色。 - 用户创建:提供创建用户的功能,包括密码加密、用户名重复检查、分组和角色的分配等。
- JWT 认证:生成 JWT 令牌,验证用户的身份,支持多点登录的控制。
- 资源访问:根据用户的角色,获取用户有权限访问的知识库、流程和助手列表。
- 异常处理:在权限不足、用户被禁用、用户名重复等情况下,抛出相应的异常,提供友好的错误信息。
manager
好的,我来为您详细讲解 manager.py 这个模块的代码。这个模块主要涉及到观察者模式的实现和一个缓存管理器的设计。以下是对代码的详细解析:
导入模块
from contextlib import contextmanager
from typing import Any, Awaitable, Callable, List, Optional
import pandas as pd
from bisheng.utils.util import get_cache_key
from PIL import Image
-
contextlib.contextmanager:用于创建上下文管理器的装饰器,方便使用with语句。 -
typing模块:提供类型注解支持。
Any:任意类型。Awaitable:可等待对象,通常用于异步函数。Callable:可调用对象类型。List:列表类型。Optional:可选类型,可能为None。
-
pandas as pd:用于数据处理的库,简称为pd。 -
bisheng.utils.util.get_cache_key:项目内部的工具函数,用于生成缓存键值。 -
PIL.Image:Python 图像库,处理图像相关操作。
1. Subject 类
class Subject:
"""实现观察者模式的基础类。"""
def __init__(self):
self.observers: List[Callable[[], None]] = []
def attach(self, observer: Callable[[], None]):
"""向主题添加一个观察者。"""
self.observers.append(observer)
def detach(self, observer: Callable[[], None]):
"""从主题中移除一个观察者。"""
self.observers.remove(observer)
def notify(self):
"""通知所有观察者一个事件发生。"""
for observer in self.observers:
if observer is None:
continue
observer()
功能概述
Subject类是实现观察者模式的基础类。- 它维护了一个观察者列表,当主题发生变化时,通知所有的观察者。
方法解析
__init__- 初始化方法,创建一个空的观察者列表
self.observers。
- 初始化方法,创建一个空的观察者列表
attach- 添加一个观察者到观察者列表。
- 参数:
observer: 一个不接受任何参数且返回None的可调用对象(函数)。
detach- 从观察者列表中移除一个观察者。
notify- 通知所有观察者,调用它们的回调函数。
- 遍历
self.observers列表,逐个调用每个观察者。 - 如果观察者为
None,则跳过。
2. AsyncSubject 类
class AsyncSubject:
"""实现异步观察者模式的基础类。"""
def __init__(self):
self.observers: List[Callable[[], Awaitable]] = []
def attach(self, observer: Callable[[], Awaitable]):
"""向主题添加一个异步观察者。"""
self.observers.append(observer)
def detach(self, observer: Callable[[], Awaitable]):
"""从主题中移除一个异步观察者。"""
self.observers.remove(observer)
async def notify(self):
"""异步通知所有观察者一个事件发生。"""
for observer in self.observers:
if observer is None:
continue
await observer()
功能概述
AsyncSubject类是针对异步场景的观察者模式实现。- 它允许观察者是异步函数,支持
await调用。
方法解析
__init__- 初始化方法,创建一个空的异步观察者列表
self.observers。
- 初始化方法,创建一个空的异步观察者列表
attach- 添加一个异步观察者到观察者列表。
- 参数:
observer: 一个不接受任何参数且返回Awaitable的可调用对象(异步函数)。
detach- 从观察者列表中移除一个异步观察者。
notify- 异步通知所有观察者。
- 使用
await调用每个异步观察者。
3. CacheManager 类
class CacheManager(Subject):
"""管理不同客户端的缓存,并在发生变化时通知观察者。"""
def __init__(self):
super().__init__()
self._cache = {}
self.current_client_id = None
self.current_chat_id = None
self.current_cache = {}
@contextmanager
def set_client_id(self, client_id: str, chat_id: str):
"""
上下文管理器,用于设置当前的 client_id 和关联的缓存。
参数:
client_id (str): 客户端标识符。
chat_id (str): 聊天标识符。
"""
previous_client_id = self.current_client_id
previous_chat_id = self.current_chat_id
self.current_client_id = client_id
self.current_chat_id = chat_id
self.current_cache = self._cache.setdefault(get_cache_key(client_id, chat_id), {})
try:
yield
finally:
self.current_client_id = previous_client_id
self.current_chat_id = previous_chat_id
self.current_cache = self._cache.get(get_cache_key(self.current_client_id, self.current_chat_id), {})
def add(self, name: str, obj: Any, obj_type: str, extension: Optional[str] = None):
"""
向当前客户端的缓存中添加一个对象。
参数:
name (str): 缓存的键名。
obj (Any): 要缓存的对象。
obj_type (str): 对象的类型。
extension (Optional[str]): 文件扩展名,默认为 None。
"""
object_extensions = {
'image': 'png',
'pandas': 'csv',
}
if obj_type in object_extensions:
_extension = object_extensions[obj_type]
else:
_extension = type(obj).__name__.lower()
self.current_cache[name] = {
'obj': obj,
'type': obj_type,
'extension': extension or _extension,
}
self.notify()
def add_pandas(self, name: str, obj: Any):
"""
向当前客户端的缓存中添加一个 pandas DataFrame 或 Series。
参数:
name (str): 缓存的键名。
obj (Any): pandas DataFrame 或 Series 对象。
"""
if isinstance(obj, (pd.DataFrame, pd.Series)):
self.add(name, obj.to_csv(), 'pandas', extension='csv')
else:
raise ValueError('Object is not a pandas DataFrame or Series')
def add_image(self, name: str, obj: Any, extension: str = 'png'):
"""
向当前客户端的缓存中添加一个 PIL Image。
参数:
name (str): 缓存的键名。
obj (Any): PIL Image 对象。
extension (str): 文件扩展名,默认为 'png'。
"""
if isinstance(obj, Image.Image):
self.add(name, obj, 'image', extension=extension)
else:
raise ValueError('Object is not a PIL Image')
def get(self, name: str):
"""
从当前客户端的缓存中获取一个对象。
参数:
name (str): 缓存的键名。
返回:
与给定键名关联的缓存对象。
"""
return self.current_cache[name]
def get_last(self):
"""
获取当前客户端缓存中最后添加的对象。
返回:
缓存中最后添加的对象。
"""
return list(self.current_cache.values())[-1]
功能概述
CacheManager类继承自Subject,因此具备观察者模式的功能。- 该类管理不同客户端(client)的缓存数据,并在缓存发生变化时通知观察者。
- 支持对缓存的添加、获取操作,以及对特定类型对象的专门处理(如 pandas 数据和图像)。
属性
_cache:存储所有客户端的缓存数据,结构为{client_cache_key: client_cache_dict}。current_client_id:当前操作的客户端 ID。current_chat_id:当前操作的聊天 ID。current_cache:当前客户端的缓存字典。
方法解析
__init__
- 初始化方法,调用父类的初始化方法。
- 初始化缓存相关的属性。
set_client_id
- 上下文管理器,用于设置当前的客户端 ID 和聊天 ID,以及关联的缓存。
- 参数:
client_id:客户端标识符。chat_id:聊天标识符。
- 功能:
- 在进入上下文时,保存之前的
client_id和chat_id。 - 设置新的
current_client_id和current_chat_id。 - 获取或创建对应的
current_cache。 - 在退出上下文时,恢复之前的
client_id、chat_id和current_cache。
- 在进入上下文时,保存之前的
使用示例:
with cache_manager.set_client_id('client1', 'chat1'):
# 在这个上下文中,current_client_id 和 current_chat_id 被设置为 'client1' 和 'chat1'
# 可以进行缓存操作
cache_manager.add('key1', 'value1', 'string')
add
- 向当前客户端的缓存中添加一个对象。
- 参数:
name:缓存的键名。obj:要缓存的对象。obj_type:对象的类型,字符串形式。extension:文件扩展名,可选。
- 功能:
- 根据
obj_type,确定对象的扩展名。如果未提供extension,则使用预定义的或对象类型的名称。 - 将对象存储在
current_cache中,键为name,值为一个包含对象信息的字典。 - 调用
self.notify(),通知所有观察者缓存已更新。
- 根据
add_pandas
- 专门处理 pandas 数据的添加方法。
- 参数:
name:缓存的键名。obj:pandas DataFrame 或 Series 对象。
- 功能:
- 检查对象是否为 pandas DataFrame 或 Series。
- 将对象转换为 CSV 格式的字符串。
- 调用
self.add方法,将其添加到缓存中,类型设为'pandas',扩展名为'csv'。
- 异常处理:
- 如果对象不是 pandas 数据类型,抛出
ValueError。
- 如果对象不是 pandas 数据类型,抛出
add_image
- 专门处理图像数据的添加方法。
- 参数:
name:缓存的键名。obj:PIL Image 对象。extension:文件扩展名,默认为'png'。
- 功能:
- 检查对象是否为 PIL Image 类型。
- 调用
self.add方法,将其添加到缓存中,类型设为'image'。
- 异常处理:
- 如果对象不是 PIL Image 类型,抛出
ValueError。
- 如果对象不是 PIL Image 类型,抛出
get
- 从当前客户端的缓存中获取指定键名的对象。
- 参数:
name:缓存的键名。
- 返回:
- 缓存中对应键名的对象信息字典。
get_last
- 获取当前客户端缓存中最后添加的对象。
- 返回:
- 缓存中最后一个添加的对象信息字典。
4. 模块级实例
cache_manager = CacheManager()
- 在模块级别创建了一个
CacheManager的实例cache_manager,供全局使用。
5. 总结
- 观察者模式:
Subject和AsyncSubject实现了同步和异步的观察者模式基础类。- 提供了
attach、detach和notify方法,用于管理观察者和通知事件。
- 缓存管理器:
CacheManager继承自Subject,实现了对不同客户端缓存的管理。- 支持上下文管理器,通过
set_client_id方法设置当前操作的客户端和聊天会话。 - 提供了
add、add_pandas、add_image、get和get_last方法,方便缓存对象的添加和获取。 - 在添加对象时,会通知所有已注册的观察者,便于在缓存更新时触发相应的处理。
使用场景
- 缓存管理:
- 在一个多用户、多会话的环境中,需要对不同的客户端和聊天会话进行缓存管理。
CacheManager可以根据client_id和chat_id进行区分,确保缓存的隔离性。
- 观察者通知:
- 当缓存发生变化时,可能需要更新界面、日志记录或触发其他操作。
- 通过观察者模式,
CacheManager可以通知所有的观察者,执行相应的回调函数。
实际应用示例
# 定义一个观察者函数
def cache_updated():
print("Cache has been updated.")
# 将观察者函数附加到缓存管理器
cache_manager.attach(cache_updated)
# 使用缓存管理器添加对象
with cache_manager.set_client_id('user123', 'chat456'):
# 添加一个字符串对象
cache_manager.add('greeting', 'Hello, World!', 'string')
# 添加一个 pandas DataFrame
df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
cache_manager.add_pandas('dataframe', df)
# 添加一个图像对象
image = Image.new('RGB', (100, 100), color='red')
cache_manager.add_image('red_square', image)
# 由于缓存发生了更新,观察者函数会被调用,输出:
# "Cache has been updated."
注意事项
- 线程安全性:
- 目前的实现未考虑线程安全性。如果在多线程环境中使用,需要添加线程锁来保护共享数据。
- 错误处理:
- 在获取缓存对象时,如果键名不存在,会抛出
KeyError。在实际使用中,应该添加相应的异常处理。
- 在获取缓存对象时,如果键名不存在,会抛出
- 缓存清理:
- 目前未提供缓存清理或过期机制。如果缓存数据较多,可能会占用较多内存。
llm
导入模块
import json
from typing import List, Optional
from fastapi import Request
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel
from loguru import logger
from bisheng.api.errcode.base import NotFoundError
from bisheng.api.errcode.llm import (
ServerExistError,
ModelNameRepeatError,
ServerAddError,
ServerAddAllError
)
from bisheng.api.services.user_service import UserPayload
from bisheng.api.v1.schemas import (
LLMServerInfo,
LLMModelInfo,
KnowledgeLLMConfig,
AssistantLLMConfig,
EvaluationLLMConfig,
AssistantLLMItem,
LLMServerCreateReq
)
from bisheng.database.models.config import ConfigDao, ConfigKeyEnum, Config
from bisheng.database.models.llm_server import (
LLMDao,
LLMServer,
LLMModel,
LLMModelType
)
from bisheng.interface.importing import import_by_type
from bisheng.interface.initialize.loading import instantiate_llm, instantiate_embedding
标准库导入
json:用于处理 JSON 数据的序列化和反序列化。typing模块:提供类型注解支持。List:列表类型注解。Optional:可选类型,可能为None。
第三方库导入
fastapi.Request:FastAPI 的请求对象。langchain_core.embeddings.Embeddings:LangChain 中的嵌入向量模型类。langchain_core.language_models.BaseChatModel:LangChain 中的基础聊天模型类。loguru.logger:用于日志记录的库。
项目内部模块导入
- 错误码模块:
NotFoundError:未找到错误。ServerExistError:服务已存在错误。ModelNameRepeatError:模型名称重复错误。ServerAddError:添加服务错误。ServerAddAllError:添加服务全部失败错误。
- 用户服务模块:
UserPayload:用户信息载体类。
- 数据模型和请求体定义:
LLMServerInfo:LLM 服务信息的数据模型。LLMModelInfo:LLM 模型信息的数据模型。KnowledgeLLMConfig:知识库 LLM 配置的数据模型。AssistantLLMConfig:助手 LLM 配置的数据模型。EvaluationLLMConfig:评估 LLM 配置的数据模型。AssistantLLMItem:助手 LLM 项目的数据模型。LLMServerCreateReq:创建 LLM 服务的请求体数据模型。
- 配置相关的 DAO 和枚举:
ConfigDao:配置数据访问对象。ConfigKeyEnum:配置键的枚举类型。Config:配置的数据模型。
- LLM 服务相关的 DAO 和模型:
LLMDao:LLM 数据访问对象。LLMServer:LLM 服务的数据模型。LLMModel:LLM 模型的数据模型。LLMModelType:LLM 模型类型的枚举。
- 动态导入和实例化模块:
import_by_type:根据类型导入模块的函数。instantiate_llm:实例化 LLM 模型的函数。instantiate_embedding:实例化嵌入模型的函数。
LLMService 类
class LLMService:
...
功能概述
LLMService 类是一个用于管理大型语言模型(LLM)服务的类,提供了对 LLM 服务和模型的增删改查,以及配置默认模型等功能。
方法列表
get_all_llm:获取所有的 LLM 服务信息。get_one_llm:获取单个 LLM 服务的详细信息。add_llm_server:添加一个新的 LLM 服务。add_llm_server_hook:添加 LLM 服务后的后续处理。delete_llm_server:删除一个 LLM 服务。update_llm_server:更新 LLM 服务信息。update_model_online:更新模型的上线状态。get_knowledge_llm:获取知识库相关的默认模型配置。get_knowledge_source_llm:获取知识库溯源的默认模型实例。get_knowledge_similar_llm:获取知识库相似问的默认模型实例。update_knowledge_llm:更新知识库的默认模型配置。get_assistant_llm:获取助手相关的默认模型配置。update_assistant_llm:更新助手的默认模型配置。get_evaluation_llm:获取评测功能的默认模型配置。get_evaluation_llm_object:获取评测功能的默认模型实例。get_bisheng_llm:获取必胜的 LLM 模型实例。get_bisheng_embedding:获取必胜的嵌入模型实例。update_evaluation_llm:更新评测功能的默认模型配置。get_assistant_llm_list:获取助手可选的模型列表。set_default_model:设置默认的模型配置。
下面我们逐一解析其中的重要方法。
1. get_all_llm
@classmethod
def get_all_llm(cls, request: Request, login_user: UserPayload) -> List[LLMServerInfo]:
""" 获取所有的模型数据,不包含 key 等敏感信息 """
llm_servers = LLMDao.get_all_server()
ret = []
server_ids = []
for one in llm_servers:
server_ids.append(one.id)
ret.append(LLMServerInfo(**one.model_dump(exclude={'config'})))
llm_models = LLMDao.get_model_by_server_ids(server_ids)
server_dicts = {}
for one in llm_models:
if one.server_id not in server_dicts:
server_dicts[one.server_id] = []
server_dicts[one.server_id].append(one.model_dump(exclude={'config'}))
for one in ret:
one.models = server_dicts.get(one.id, [])
return ret
功能
- 获取所有的 LLM 服务信息,包括每个服务下的模型列表。
- 不包含敏感的配置信息,如密钥(
config字段)。
逻辑
- 获取所有 LLM 服务:
- 调用
LLMDao.get_all_server()获取所有的 LLM 服务记录。
- 调用
- 构建服务信息列表:
- 初始化返回列表
ret和服务 ID 列表server_ids。 - 遍历所有的 LLM 服务,提取每个服务的 ID,并使用
LLMServerInfo数据模型(排除config字段)构建服务信息对象,添加到ret列表。
- 初始化返回列表
- 获取每个服务下的模型列表:
- 调用
LLMDao.get_model_by_server_ids(server_ids)获取所有服务对应的模型列表。 - 初始化一个字典
server_dicts,以服务 ID 为键,模型列表为值。 - 遍历模型列表,将模型信息添加到对应的服务下(排除
config字段)。
- 调用
- 组装最终返回结果:
- 遍历服务信息列表
ret,为每个服务对象添加models属性,对应其模型列表。
- 遍历服务信息列表
- 返回结果:
- 返回包含所有服务和其模型列表的
LLMServerInfo对象列表。
- 返回包含所有服务和其模型列表的
2. get_one_llm
@classmethod
def get_one_llm(cls, request: Request, login_user: UserPayload, server_id: int) -> LLMServerInfo:
""" 获取一个服务提供方的详细信息,包含了 key 等敏感的配置信息 """
llm = LLMDao.get_server_by_id(server_id)
if not llm:
raise NotFoundError.http_exception()
models = LLMDao.get_model_by_server_ids([server_id])
models = [LLMModelInfo(**one.model_dump()) for one in models]
return LLMServerInfo(**llm.model_dump(), models=models)
功能
- 获取指定的 LLM 服务的详细信息,包括敏感配置信息和模型列表。
逻辑
- 获取指定的 LLM 服务:
- 调用
LLMDao.get_server_by_id(server_id)获取服务信息。 - 如果服务不存在,抛出
NotFoundError异常。
- 调用
- 获取服务下的模型列表:
- 调用
LLMDao.get_model_by_server_ids([server_id])获取该服务下的模型列表。 - 使用
LLMModelInfo数据模型构建模型信息列表。
- 调用
- 返回结果:
- 构建
LLMServerInfo对象,包括服务信息和模型列表,返回给调用者。
- 构建
3. add_llm_server
@classmethod
def add_llm_server(cls, request: Request, login_user: UserPayload, server: LLMServerCreateReq) -> LLMServerInfo:
""" 添加一个服务提供方 """
exist_server = LLMDao.get_server_by_name(server.name)
if exist_server:
raise ServerExistError.http_exception()
# 尝试实例化下对应的模型组件是否可以成功初始化
model_dict = {}
for one in server.models:
if one.model_name not in model_dict:
model_dict[one.model_name] = LLMModel(**one.dict(), user_id=login_user.user_id)
else:
raise ModelNameRepeatError.http_exception()
db_server = LLMServer(**server.dict(exclude={'models'}))
db_server.user_id = login_user.user_id
db_server = LLMDao.insert_server_with_models(db_server, list(model_dict.values()))
ret = cls.get_one_llm(request, login_user, db_server.id)
success_models = []
success_msg = ''
failed_models = []
failed_msg = ''
# 尝试实例化对应的模型,有报错的话删除
for one in ret.models:
try:
if one.model_type == LLMModelType.LLM.value:
cls.get_bisheng_llm(model_id=one.id, ignore_online=True)
elif one.model_type == LLMModelType.EMBEDDING.value:
cls.get_bisheng_embedding(model_id=one.id, ignore_online=True)
success_msg += f'{one.model_name},'
success_models.append(one)
except Exception as e:
logger.exception("init_model_error")
# 模型初始化失败的话,不添加到模型列表里
failed_msg += f'<{one.model_name}>添加失败,失败原因:{str(e)}\n'
failed_models.append(one)
# 说明模型全部添加失败了
if len(success_models) == 0 and failed_msg:
LLMDao.delete_server_by_id(ret.id)
raise ServerAddAllError.http_exception(failed_msg)
elif len(success_models) > 0 and failed_msg:
# 部分模型添加成功了, 删除失败的模型信息
ret.models = success_models
LLMDao.delete_model_by_ids(model_ids=[one.id for one in failed_models])
cls.add_llm_server_hook(request, login_user, ret)
raise ServerAddError.http_exception(f"<{success_msg.rstrip(',')}>添加成功,{failed_msg}")
cls.add_llm_server_hook(request, login_user, ret)
return ret
功能
- 添加一个新的 LLM 服务和其模型列表。
- 尝试初始化模型,如果模型初始化失败,进行相应的错误处理。
逻辑
- 检查服务名称是否已存在:
- 调用
LLMDao.get_server_by_name(server.name)检查服务名称是否已存在。 - 如果存在,抛出
ServerExistError异常。
- 调用
- 构建模型字典:
- 初始化
model_dict,用于存储模型信息,键为模型名称。 - 遍历
server.models,检查模型名称是否重复,如果重复,抛出ModelNameRepeatError异常。 - 创建
LLMModel实例,添加到model_dict。
- 初始化
- 创建 LLM 服务:
- 创建
LLMServer实例(排除models字段),设置user_id。 - 调用
LLMDao.insert_server_with_models,将服务和模型列表插入数据库。
- 创建
- 获取新添加的服务信息:
- 调用
cls.get_one_llm获取新添加的服务和模型信息。
- 调用
- 尝试初始化模型:
- 遍历服务的模型列表,尝试实例化每个模型。
- 如果模型类型是
LLM,调用cls.get_bisheng_llm实例化。 - 如果模型类型是
EMBEDDING,调用cls.get_bisheng_embedding实例化。
- 如果模型类型是
- 如果实例化成功,记录成功信息。
- 如果实例化失败,记录失败信息,将模型添加到
failed_models列表。
- 遍历服务的模型列表,尝试实例化每个模型。
- 处理初始化结果:
- 如果所有模型都初始化失败,删除刚添加的服务,抛出
ServerAddAllError异常。 - 如果部分模型初始化失败,删除失败的模型记录,抛出
ServerAddError异常,返回成功和失败的模型信息。 - 如果所有模型初始化成功,调用
cls.add_llm_server_hook进行后续处理,返回服务信息。
- 如果所有模型都初始化失败,删除刚添加的服务,抛出
4. add_llm_server_hook
@classmethod
def add_llm_server_hook(cls, request: Request, login_user: UserPayload, server: LLMServerInfo) -> bool:
""" 添加一个服务提供方后的后续动作 """
handle_types = []
for one in server.models:
if one.model_type in handle_types:
continue
handle_types.append(one.model_type)
model_info = LLMDao.get_model_by_type(LLMModelType(one.model_type))
# 判断是否是首个 LLM 或 Embedding 模型
if model_info.id == one.id:
cls.set_default_model(request, login_user, model_info)
return True
功能
- 添加 LLM 服务后,检查是否需要设置默认的模型配置。
- 如果新添加的模型是系统中的首个同类型模型(如首个 LLM 或首个 Embedding 模型),则设置为默认模型。
逻辑
- 初始化处理类型列表:
- 初始化
handle_types,用于记录已处理的模型类型。
- 初始化
- 遍历服务的模型列表:
- 遍历 server.models,对于每个模型:
- 如果模型类型已处理,跳过。
- 将模型类型添加到
handle_types。
- 遍历 server.models,对于每个模型:
- 检查是否是首个模型:
- 调用
LLMDao.get_model_by_type,获取数据库中指定类型的模型。 - 如果模型的 ID 与新添加的模型 ID 相同,说明这是首个同类型的模型。
- 调用
- 设置默认模型:
- 调用
cls.set_default_model,将新模型设置为默认模型配置。
- 调用
5. set_default_model
@classmethod
def set_default_model(cls, request: Request, login_user: UserPayload, model: LLMModel):
""" 设置默认的模型配置 """
if model.model_type == LLMModelType.LLM.value:
# 设置知识库的默认模型配置
knowledge_llm = cls.get_knowledge_llm()
knowledge_change = False
if not knowledge_llm.extract_title_model_id:
knowledge_llm.extract_title_model_id = model.id
knowledge_change = True
if not knowledge_llm.source_model_id:
knowledge_llm.source_model_id = model.id
knowledge_change = True
if not knowledge_llm.qa_similar_model_id:
knowledge_llm.qa_similar_model_id = model.id
knowledge_change = True
if knowledge_change:
cls.update_knowledge_llm(request, login_user, knowledge_llm)
# 设置评测的默认模型配置
evaluation_llm = cls.get_evaluation_llm()
if not evaluation_llm.model_id:
evaluation_llm.model_id = model.id
cls.update_evaluation_llm(request, login_user, evaluation_llm)
# 设置助手的默认模型配置
assistant_llm = cls.get_assistant_llm()
assistant_change = False
if not assistant_llm.auto_llm:
assistant_llm.auto_llm = AssistantLLMItem(model_id=model.id)
assistant_change = True
if not assistant_llm.llm_list:
assistant_change = True
assistant_llm.llm_list = [
AssistantLLMItem(model_id=model.id, default=True)
]
if assistant_change:
cls.update_assistant_llm(request, login_user, assistant_llm)
elif model.model_type == LLMModelType.EMBEDDING.value:
knowledge_llm = cls.get_knowledge_llm()
if not knowledge_llm.embedding_model_id:
knowledge_llm.embedding_model_id = model.id
cls.update_knowledge_llm(request, login_user, knowledge_llm)
功能
- 设置默认的模型配置,包括知识库、评测、助手等模块。
逻辑
-
处理 LLM 模型类型:
- 如果模型类型是
LLM,执行以下操作:- 知识库模块:
- 获取当前的知识库 LLM 配置
knowledge_llm。 - 初始化标志
knowledge_change为False。 - 检查知识库配置中的模型 ID,如果为空,则设置为当前模型的 ID,设置
knowledge_change为True。 - 如果配置有变化,调用
cls.update_knowledge_llm更新配置。
- 获取当前的知识库 LLM 配置
- 评测模块:
- 获取当前的评测 LLM 配置
evaluation_llm。 - 如果模型 ID 为空,设置为当前模型的 ID。
- 调用
cls.update_evaluation_llm更新配置。
- 获取当前的评测 LLM 配置
- 助手模块:
- 获取当前的助手 LLM 配置
assistant_llm。 - 初始化标志
assistant_change为False。 - 检查助手配置中的自动 LLM,如果为空,设置为当前模型。
- 检查助手配置中的 LLM 列表,如果为空,添加当前模型为默认模型。
- 如果配置有变化,调用
cls.update_assistant_llm更新配置。
- 获取当前的助手 LLM 配置
- 知识库模块:
- 如果模型类型是
-
处理 Embedding 模型类型:
-
如果模型类型是
EMBEDDING,执行以下操作:-
知识库模块
:
- 获取当前的知识库 LLM 配置
knowledge_llm。 - 如果嵌入模型 ID 为空,设置为当前模型的 ID。
- 调用
cls.update_knowledge_llm更新配置。
- 获取当前的知识库 LLM 配置
-
-
6. get_bisheng_llm 和 get_bisheng_embedding
@classmethod
def get_bisheng_llm(cls, **kwargs) -> BaseChatModel:
""" 获取必胜的 LLM 模型实例 """
class_object = import_by_type(_type='llms', name='BishengLLM')
return instantiate_llm('BishengLLM', class_object, kwargs)
@classmethod
def get_bisheng_embedding(cls, **kwargs) -> Embeddings:
""" 获取必胜的嵌入模型实例 """
class_object = import_by_type(_type='embeddings', name='BishengEmbedding')
return instantiate_embedding(class_object, kwargs)
功能
- 动态导入和实例化必胜的 LLM 模型和嵌入模型。
逻辑
get_bisheng_llm:- 调用
import_by_type,根据类型'llms'和名称'BishengLLM'导入 LLM 类。 - 调用
instantiate_llm,传入类名'BishengLLM'、类对象和参数kwargs,实例化 LLM 模型。
- 调用
get_bisheng_embedding:- 调用
import_by_type,根据类型'embeddings'和名称'BishengEmbedding'导入嵌入类。 - 调用
instantiate_embedding,传入类对象和参数kwargs,实例化嵌入模型。
- 调用
7. get_assistant_llm_list
@classmethod
def get_assistant_llm_list(cls, request: Request, login_user: UserPayload) -> List[LLMServerInfo]:
""" 获取助手可选的模型列表 """
assistant_llm = cls.get_assistant_llm()
if not assistant_llm.llm_list:
return []
model_list = LLMDao.get_model_by_ids([one.model_id for one in assistant_llm.llm_list])
if not model_list:
return []
model_dict = {}
for one in model_list:
if one.server_id not in model_dict:
model_dict[one.server_id] = []
model_dict[one.server_id].append(LLMModelInfo(**one.dict(exclude={'config'})))
server_list = LLMDao.get_server_by_ids(list(model_dict.keys()))
ret = []
for one in server_list:
ret.append(LLMServerInfo(**one.dict(exclude={'config'}), models=model_dict[one.id]))
return ret
功能
- 获取助手模块可选的模型列表,供用户选择。
逻辑
- 获取助手 LLM 配置:
- 调用
cls.get_assistant_llm()获取助手的 LLM 配置assistant_llm。
- 调用
- 检查是否有可用的模型列表:
- 如果
assistant_llm.llm_list为空,返回空列表。
- 如果
- 获取模型列表:
- 根据模型 ID 列表,调用
LLMDao.get_model_by_ids获取模型列表model_list。 - 如果模型列表为空,返回空列表。
- 根据模型 ID 列表,调用
- 构建模型字典:
- 初始化
model_dict,以服务 ID 为键,模型信息列表为值。 - 遍历模型列表,将模型信息添加到对应的服务下(排除
config字段)。
- 初始化
- 获取服务列表:
- 根据服务 ID 列表,调用
LLMDao.get_server_by_ids获取服务列表。
- 根据服务 ID 列表,调用
- 组装返回结果:
- 遍历服务列表,构建
LLMServerInfo对象,包含服务信息和模型列表。 - 返回服务信息列表。
- 遍历服务列表,构建
总结
LLMService 类提供了对 LLM 服务和模型的管理功能,包括添加、删除、更新服务和模型,获取默认模型配置,设置默认模型,以及获取可选的模型列表等。
- 模型初始化和错误处理:在添加模型时,尝试实例化模型,处理可能的初始化失败情况,保证系统的稳定性。
- 默认模型配置:当添加新的模型时,如果是系统中的首个同类型模型,自动设置为默认模型,方便后续模块使用。
- 动态实例化模型:通过动态导入和实例化机制,支持不同类型的模型,增强系统的灵活性。
message
导入模块
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from uuid import UUID
from sqlalchemy.sql import not_
from sqlalchemy import JSON, Column, DateTime, String, Text, case, func, or_, text, update
from sqlmodel import Field, delete, select
from pydantic import BaseModel
from loguru import logger
from bisheng.database.base import session_getter
from bisheng.database.models.base import SQLModelSerializable
标准库导入
datetime:用于处理日期和时间。typing:提供类型提示支持。Dict:字典类型。List:列表类型。Optional:可选类型,可能为None。Tuple:元组类型。
uuid.UUID:用于处理 UUID 类型的数据。
第三方库导入
-
sqlalchemy模块:用于与数据库进行交互。func:SQL 函数,如count、max等。select、update、delete、or_、not_:构建 SQL 查询语句的函数。Column:定义数据库列的属性。DateTime、String、Text、JSON:定义数据库列的数据类型。text:用于编写原生 SQL 语句。case:SQL 中的 CASE 表达式,用于条件判断。
-
sqlmodel:基于 SQLAlchemy 的 ORM 库,简化数据库操作。
Field:用于定义模型字段。
-
pydantic.BaseModel:用于数据验证和序列化的基类。 -
loguru.logger:用于日志记录。
项目内部模块导入
bisheng.database.base.session_getter:用于获取数据库会话的函数。bisheng.database.models.base.SQLModelSerializable:自定义的基类,继承自SQLModel,用于序列化。
1. MessageBase 类
class MessageBase(SQLModelSerializable):
is_bot: bool = Field(index=False, description='聊天角色')
source: Optional[int] = Field(index=False, description='是否支持溯源')
mark_status: Optional[int] = Field(index=False, default=1, description='标记状态')
mark_user: Optional[int] = Field(index=False, description='标记用户')
mark_user_name: Optional[str] = Field(index=False, description='标记用户')
message: Optional[str] = Field(sa_column=Column(Text), description='聊天消息')
extra: Optional[str] = Field(sa_column=Column(String(length=4096)), description='连接信息等')
type: str = Field(index=False, description='消息类型')
category: str = Field(index=False, description='消息类别, question等')
flow_id: UUID = Field(index=True, description='对应的技能id')
chat_id: Optional[str] = Field(index=True, description='chat_id, 前端生成')
user_id: Optional[str] = Field(index=True, description='用户id')
liked: Optional[int] = Field(index=False, default=0, description='用户是否喜欢 0未评价/1 喜欢/2 不喜欢')
solved: Optional[int] = Field(index=False, default=0, description='用户是否喜欢 0未评价/1 解决/2 未解决')
copied: Optional[int] = Field(index=False, default=0, description='用户是否复制 0:未复制 1:已复制')
sender: Optional[str] = Field(index=False, default='', description='autogen 的发送方')
receiver: Optional[Dict] = Field(index=False, default=None, description='autogen 的接收方')
intermediate_steps: Optional[str] = Field(sa_column=Column(Text), description='过程日志')
files: Optional[str] = Field(sa_column=Column(String(length=4096)), description='上传的文件等')
remark: Optional[str] = Field(sa_column=Column(String(length=4096)),
description='备注。break_answer: 中断的回复不作为history传给模型')
create_time: Optional[datetime] = Field(
sa_column=Column(DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP')))
update_time: Optional[datetime] = Field(
sa_column=Column(DateTime,
nullable=False,
index=True,
server_default=text('CURRENT_TIMESTAMP'),
onupdate=text('CURRENT_TIMESTAMP')))
功能概述
MessageBase类是所有消息类的基类,继承自SQLModelSerializable。- 定义了与聊天消息相关的公共字段和属性。
- 使用
Field和Column来定义数据库字段的属性和类型。
字段解析
is_bot:bool类型,表示消息是否来自机器人。source:Optional[int],表示是否支持溯源。mark_status:Optional[int],标记状态,默认值为1。mark_user:Optional[int],标记用户的 ID。mark_user_name:Optional[str],标记用户的名称。message:Optional[str],聊天消息内容,使用Text类型存储较长的文本。extra:Optional[str],额外信息,如连接信息等,限制长度为 4096 个字符。type:str,消息类型。category:str,消息类别,如question等。flow_id:UUID,对应的技能 ID。chat_id:Optional[str],聊天会话的 ID,由前端生成。user_id:Optional[str],用户的 ID。liked:Optional[int],用户对消息的喜欢状态,0未评价,1喜欢,2不喜欢。solved:Optional[int],问题是否解决,0未评价,1解决,2未解决。copied:Optional[int],消息是否被复制,0未复制,1已复制。sender:Optional[str],发送方(针对自动生成的消息)。receiver:Optional[Dict],接收方(针对自动生成的消息)。intermediate_steps:Optional[str],过程日志,存储消息生成过程的详细信息。files:Optional[str],上传的文件信息。remark:Optional[str],备注信息,如break_answer表示中断的回复不作为历史记录传给模型。create_time:Optional[datetime],消息的创建时间,默认值为当前时间。update_time:Optional[datetime],消息的更新时间,默认值为当前时间,并在更新时自动更新。
2. ChatMessage 类
class ChatMessage(MessageBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
receiver: Optional[Dict] = Field(default=None, sa_column=Column(JSON))
功能概述
ChatMessage类继承自MessageBase,并且指定了table=True,表示这是一个数据库表。- 增加了主键
id和重新定义了receiver字段。
字段解析
id:Optional[int],消息的主键,自增的整数。receiver:Optional[Dict],接收方信息,使用 JSON 类型存储。
3. 其他模型类
ChatMessageRead
class ChatMessageRead(MessageBase):
id: Optional[int]
- 用于读取消息时的数据模型,包含了
id字段。
ChatMessageQuery
class ChatMessageQuery(BaseModel):
id: Optional[int]
flow_id: str
chat_id: str
- 用于查询消息的请求体数据模型,包含
id、flow_id、chat_id。
ChatMessageCreate
class ChatMessageCreate(MessageBase):
pass
- 用于创建消息的请求体数据模型,继承自
MessageBase。
4. MessageDao 类
class MessageDao(MessageBase):
@classmethod
def static_msg_liked(cls, liked: int, flow_id: str, create_time_begin: datetime,
create_time_end: datetime):
base_condition = select(func.count(ChatMessage.id)).where(ChatMessage.liked == liked)
if flow_id:
base_condition = base_condition.where(ChatMessage.flow_id == flow_id)
if create_time_begin and create_time_end:
base_condition = base_condition.where(ChatMessage.create_time > create_time_begin,
ChatMessage.create_time < create_time_end)
with session_getter() as session:
return session.scalar(base_condition)
功能概述
MessageDao类提供了对消息数据的统计功能。static_msg_liked方法用于统计指定条件下的消息数量。
方法解析
static_msg_liked
- 参数:
liked:用户对消息的喜欢状态(0、1、2)。flow_id:技能 ID,用于过滤特定技能的消息。create_time_begin和create_time_end:开始和结束时间,用于时间范围过滤。
- 逻辑:
- 构建基础查询条件,统计满足
liked状态的消息数量。 - 如果提供了
flow_id,则添加flow_id的过滤条件。 - 如果提供了时间范围,添加
create_time的过滤条件。 - 使用数据库会话执行查询,返回统计结果。
- 构建基础查询条件,统计满足
5. ChatMessageDao 类
class ChatMessageDao(MessageBase):
...
功能概述
ChatMessageDao类提供了对ChatMessage表的常用数据库操作方法,包括查询、插入、更新、删除等。
方法列表
get_latest_message_by_chatid:获取指定聊天会话的最新一条消息。get_latest_message_by_chat_ids:获取多个聊天会话的最新消息。get_messages_by_chat_id:获取指定聊天会话的消息列表。get_last_msg_by_flow_id:获取指定技能 ID 的最后一条消息。get_msg_by_chat_id:获取指定聊天会话的所有消息。get_msg_by_flow:获取指定技能 ID 的消息列表。get_msg_by_flows:获取多个技能 ID 的消息列表。delete_by_user_chat_id:根据用户 ID 和聊天会话 ID 删除消息。delete_by_message_id:根据用户 ID 和消息 ID 删除消息。insert_one:插入一条新的消息。insert_batch:批量插入消息。get_message_by_id:根据消息 ID 获取消息。update_message:更新指定消息的内容。update_message_model:更新消息对象。update_message_copied:更新消息的复制状态。update_message_mark:更新消息的标记状态。
详细方法解析
get_latest_message_by_chatid
@classmethod
def get_latest_message_by_chatid(cls, chat_id: str):
with session_getter() as session:
res = session.exec(
select(ChatMessage).where(ChatMessage.chat_id == chat_id).limit(1)).all()
if res:
return res[0]
else:
return None
- 功能:获取指定聊天会话的最新一条消息。
- 逻辑:
- 使用
select查询指定chat_id的消息,限制结果数量为1。 - 如果查询结果存在,返回第一条消息,否则返回
None。
- 使用
get_latest_message_by_chat_ids
@classmethod
def get_latest_message_by_chat_ids(cls, chat_ids: list[str], category: str = None):
statement = select(ChatMessage.chat_id,
func.max(ChatMessage.id)).where(ChatMessage.chat_id.in_(chat_ids))
if category:
statement = statement.where(ChatMessage.category == category)
statement = statement.group_by(ChatMessage.chat_id)
with session_getter() as session:
res = session.exec(statement).all()
ids = [one[1] for one in res]
statement = select(ChatMessage).where(ChatMessage.id.in_(ids))
return session.exec(statement).all()
- 功能:获取多个聊天会话的最新消息。
- 参数:
chat_ids:聊天会话 ID 列表。category:可选,消息类别过滤。
- 逻辑:
- 构建查询,获取每个
chat_id的最大id,即最新的消息 ID。 - 如果提供了
category,添加过滤条件。 - 执行查询,得到最新消息的 ID 列表。
- 根据消息 ID 列表,查询消息的具体内容并返回。
- 构建查询,获取每个
get_messages_by_chat_id
@classmethod
def get_messages_by_chat_id(cls, chat_id: str, category_list: list = None, limit: int = 10):
with session_getter() as session:
statement = select(ChatMessage).where(ChatMessage.chat_id == chat_id)
if category_list:
statement = statement.where(ChatMessage.category.in_(category_list))
statement = statement.limit(limit).order_by(ChatMessage.create_time.asc())
return session.exec(statement).all()
- 功能:获取指定聊天会话的消息列表。
- 参数:
chat_id:聊天会话 ID。category_list:可选,消息类别列表。limit:限制返回的消息数量,默认为10。
- 逻辑:
- 构建查询,筛选指定
chat_id的消息。 - 如果提供了
category_list,添加类别过滤条件。 - 限制返回的消息数量,按照创建时间升序排序。
- 执行查询并返回结果。
- 构建查询,筛选指定
insert_one
@classmethod
def insert_one(cls, message: ChatMessage) -> ChatMessage:
with session_getter() as session:
session.add(message)
session.commit()
session.refresh(message)
return message
- 功能:插入一条新的消息记录。
- 逻辑:
- 使用数据库会话添加消息对象。
- 提交事务。
- 刷新消息对象,获取数据库生成的字段(如自增的
id)。 - 返回消息对象。
update_message
@classmethod
def update_message(cls, message_id: int, user_id: int, message: str):
with session_getter() as session:
statement = update(ChatMessage).where(ChatMessage.id == message_id).where(
ChatMessage.user_id == user_id).values(message=message)
session.exec(statement)
session.commit()
- 功能:更新指定消息的内容。
- 参数:
message_id:消息 ID。user_id:用户 ID,确保只有消息的所有者才能更新消息。message:新的消息内容。
- 逻辑:
- 构建更新语句,筛选
id和user_id匹配的消息。 - 设置新的消息内容。
- 执行更新语句并提交事务。
- 构建更新语句,筛选

被折叠的 条评论
为什么被折叠?



