user_service

user_service

这个模块主要处理用户的身份验证、权限管理以及用户相关的业务逻辑,包括创建用户、生成 JWT 令牌、获取用户权限等。

以下是代码的逐行解析:

import functools
import json
from base64 import b64decode
from typing import List

import rsa
from fastapi import Depends, HTTPException, Request
from fastapi_jwt_auth import AuthJWT

from bisheng.api.JWT import ACCESS_TOKEN_EXPIRE_TIME
from bisheng.api.errcode.base import UnAuthorizedError
from bisheng.api.errcode.user import (
    UserLoginOfflineError,
    UserNameAlreadyExistError,
    UserNeedGroupAndRoleError
)
from bisheng.api.utils import md5_hash
from bisheng.api.v1.schemas import CreateUserReq
from bisheng.cache.redis import redis_client
from bisheng.database.models.assistant import Assistant, AssistantDao
from bisheng.database.models.flow import Flow, FlowDao, FlowRead
from bisheng.database.models.knowledge import Knowledge, KnowledgeDao, KnowledgeRead
from bisheng.database.models.role import AdminRole
from bisheng.database.models.role_access import AccessType, RoleAccessDao
from bisheng.database.models.user import User, UserDao
from bisheng.database.models.user_group import UserGroupDao
from bisheng.database.models.user_role import UserRoleDao
from bisheng.settings import settings
from bisheng.utils.constants import RSA_KEY, USER_CURRENT_SESSION

1. 导入模块

标准库导入

  • functools:提供了高阶函数和可调用对象的操作工具。
  • json:用于处理 JSON 数据的序列化和反序列化。
  • base64:用于处理 Base64 编码和解码。
  • typing.List:用于类型注解,表示列表类型。

第三方库导入

  • rsa:用于 RSA 加密和解密。

  • fastapi

    模块:

    • Depends:用于依赖注入。
    • HTTPException:用于抛出 HTTP 异常。
    • Request:表示请求对象。
  • fastapi_jwt_auth.AuthJWT:用于处理 JWT 认证。

项目内部模块导入

  • bisheng.api.JWT.ACCESS_TOKEN_EXPIRE_TIME:JWT 访问令牌的过期时间设置。

  • bisheng.api.errcode

    模块:

    • UnAuthorizedError:未授权错误。
    • UserLoginOfflineError:用户被迫下线错误。
    • UserNameAlreadyExistError:用户名已存在错误。
    • UserNeedGroupAndRoleError:用户需要分组和角色错误。
  • bisheng.api.utils.md5_hash:用于生成 MD5 哈希值。

  • bisheng.api.v1.schemas.CreateUserReq:用户创建请求的数据模型。

  • bisheng.cache.redis.redis_client:Redis 客户端实例。

  • bisheng.database.models

    模块:

    • assistant:助手相关的模型和 DAO。
    • flow:流程相关的模型和 DAO。
    • knowledge:知识库相关的模型和 DAO。
    • role:角色相关的模型。
    • role_access:角色权限相关的模型和 DAO。
    • user:用户相关的模型和 DAO。
    • user_group:用户组相关的 DAO。
    • user_role:用户角色相关的 DAO。
  • bisheng.settings.settings:项目的配置信息。

  • bisheng.utils.constants:定义了一些常量,如 RSA_KEYUSER_CURRENT_SESSION

2. UserPayload

class UserPayload:
    def __init__(self, **kwargs):
        self.user_id = kwargs.get('user_id')
        self.user_role = kwargs.get('role')
        if self.user_role != 'admin':
            # 非管理员用户,需要获取他的角色列表
            roles = UserRoleDao.get_user_roles(self.user_id)
            self.user_role = [one.role_id for one in roles]
        self.user_name = kwargs.get('user_name')

    def is_admin(self):
        if self.user_role == 'admin':
            return True
        if isinstance(self.user_role, list):
            for one in self.user_role:
                if one == AdminRole:
                    return True
        return False

    @staticmethod
    def wrapper_access_check(func):
        """
        权限检查的装饰器
        如果是 admin 用户则不执行后续具体的检查逻辑
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if args[0].is_admin():
                return True
            return func(*args, **kwargs)
        return wrapper

    @wrapper_access_check
    def access_check(self, owner_user_id: int, target_id: str, access_type: AccessType) -> bool:
        """
        检查用户是否有某个资源的权限
        """
        # 判断是否属于本人资源
        if self.user_id == owner_user_id:
            return True
        # 判断授权
        if RoleAccessDao.judge_role_access(self.user_role, target_id, access_type):
            return True
        return False

    @wrapper_access_check
    def check_group_admin(self, group_id: int) -> bool:
        """
        检查用户是否是某个组的管理员
        """
        # 判断是否是用户组的管理员
        user_group = UserGroupDao.get_user_admin_group(self.user_id)
        if not user_group:
            return False
        for one in user_group:
            if one.group_id == group_id:
                return True
        return False

    @wrapper_access_check
    def check_groups_admin(self, group_ids: List[int]) -> bool:
        """
        检查用户是否是用户组列表中的管理员,有一个就是 True
        """
        user_groups = UserGroupDao.get_user_admin_group(self.user_id)
        for one in user_groups:
            if one.is_group_admin and one.group_id in group_ids:
                return True
        return False

功能概述

UserPayload 类用于封装用户的基本信息和权限校验方法,包括用户 ID、用户名、用户角色等。它还提供了权限检查的方法,判断用户是否有权访问某些资源。

属性

  • user_id:用户的唯一标识符。
  • user_role:用户的角色,可以是 'admin' 或角色 ID 列表。
  • user_name:用户名。

方法

__init__
  • 功能:初始化用户信息。
  • 逻辑:
    • 获取 user_idrole(用户角色)。
    • 如果角色不是 'admin',则通过 UserRoleDao.get_user_roles 获取用户的角色列表,赋值给 self.user_role
is_admin
  • 功能:判断用户是否为管理员。
  • 返回值True 表示是管理员,False 表示不是。
  • 逻辑:
    • 如果 self.user_role'admin',直接返回 True
    • 如果 self.user_role 是列表,遍历角色列表,如果存在 AdminRole,返回 True
wrapper_access_check
  • 功能:权限检查的装饰器。
  • 逻辑:
    • 如果用户是管理员,直接返回 True,跳过后续的权限检查。
    • 否则,执行被装饰的函数。
access_check
  • 功能:检查用户是否有权限访问某个资源。
  • 参数:
    • owner_user_id:资源所有者的用户 ID。
    • target_id:目标资源的 ID。
    • access_type:访问类型,枚举类型 AccessType
  • 逻辑:
    • 如果用户是资源的所有者,返回 True
    • 通过 RoleAccessDao.judge_role_access 检查用户角色是否有权限访问目标资源。
check_group_admin
  • 功能:检查用户是否是某个用户组的管理员。
  • 参数:
    • group_id:用户组的 ID。
  • 逻辑:
    • 获取用户管理的用户组列表。
    • 如果用户在指定的用户组中,并且是管理员,返回 True
check_groups_admin
  • 功能:检查用户是否是多个用户组中的管理员。
  • 参数:
    • group_ids:用户组 ID 的列表。
  • 逻辑:
    • 获取用户管理的用户组列表。
    • 如果用户是任意一个指定用户组的管理员,返回 True

总结

UserPayload 类通过封装用户信息和权限检查方法,方便在项目的其他地方进行用户权限的验证和操作。


3. UserService

class UserService:
    @classmethod
    def decrypt_md5_password(cls, password: str):
        if value := redis_client.get(RSA_KEY):
            private_key = value[1]
            password = md5_hash(rsa.decrypt(b64decode(password), private_key).decode('utf-8'))
        else:
            password = md5_hash(password)
        return password

    @classmethod
    def create_user(cls, request: Request, login_user: UserPayload, req_data: CreateUserReq):
        """
        创建用户
        """
        exists_user = UserDao.get_user_by_username(req_data.user_name)
        if exists_user:
            # 抛出异常
            raise UserNameAlreadyExistError.http_exception()
        user = User(
            user_name=req_data.user_name,
            password=cls.decrypt_md5_password(req_data.password),
        )
        group_ids = []
        role_ids = []
        for one in req_data.group_roles:
            group_ids.append(one.group_id)
            role_ids.extend(one.role_ids)
        if not group_ids or not role_ids:
            raise UserNeedGroupAndRoleError.http_exception()
        user = UserDao.add_user_with_groups_and_roles(user, group_ids, role_ids)
        return user

功能概述

UserService 类负责处理与用户相关的业务逻辑,例如创建用户、密码加密等。

方法

decrypt_md5_password
  • 功能:解密并加密用户密码。
  • 参数:
    • password:加密的密码字符串。
  • 逻辑:
    • 从 Redis 中获取 RSA 私钥(RSA_KEY)。
      • 如果存在私钥,使用 RSA 解密密码,然后计算 MD5 哈希值。
      • 如果没有私钥,直接计算密码的 MD5 哈希值。
  • 返回值:加密后的密码(MD5 哈希值)。
create_user
  • 功能:创建新用户。
  • 参数:
    • request:请求对象。
    • login_user:当前登录的用户信息。
    • req_data:创建用户的请求数据(CreateUserReq)。
  • 逻辑:
    • 检查用户名是否已存在,如果存在,抛出 UserNameAlreadyExistError 异常。
    • 创建 User 实例,设置用户名和加密后的密码。
    • 解析用户的分组和角色信息:
      • 遍历 req_data.group_roles,收集 group_idsrole_ids
    • 检查是否提供了分组和角色,如果没有,抛出 UserNeedGroupAndRoleError 异常。
    • 调用 UserDao.add_user_with_groups_and_roles,将用户信息、分组和角色保存到数据库。
  • 返回值:创建的用户对象。

4. 其他函数

sso_login

def sso_login():
    pass
  • 功能:单点登录的占位函数,目前未实现具体逻辑。

gen_user_role

def gen_user_role(db_user: User):
    # 查询用户的角色列表
    db_user_role = UserRoleDao.get_user_roles(db_user.user_id)
    role = ''
    role_ids = []
    for user_role in db_user_role:
        if user_role.role_id == 1:
            # 是管理员,忽略其他的角色
            role = 'admin'
        else:
            role_ids.append(user_role.role_id)
    if role != 'admin':
        # 判断是否是用户组管理员
        db_user_groups = UserGroupDao.get_user_admin_group(db_user.user_id)
        if len(db_user_groups) > 0:
            role = 'group_admin'
        else:
            role = role_ids
    # 获取用户的菜单栏权限列表
    web_menu = RoleAccessDao.get_role_access(role_ids, AccessType.WEB_MENU)
    web_menu = list(set([one.third_id for one in web_menu]))
    return role, web_menu
  • 功能:生成用户的角色和菜单权限列表。
  • 参数:
    • db_user:数据库中的用户对象。
  • 逻辑:
    • 获取用户的角色列表。
      • 如果角色 ID 为 1,表示是管理员,设置 role'admin',忽略其他角色。
      • 否则,将角色 ID 添加到 role_ids 列表中。
    • 如果不是管理员:
      • 检查用户是否是用户组管理员。
        • 如果是,设置 role'group_admin'
        • 否则,rolerole_ids 列表。
    • 获取用户的菜单栏权限列表。
      • 通过 RoleAccessDao.get_role_access 获取角色的菜单权限(AccessType.WEB_MENU)。
      • 去重后,得到 web_menu 列表。
  • 返回值role(用户角色)和 web_menu(菜单权限列表)。

gen_user_jwt

def gen_user_jwt(db_user: User):
    if 1 == db_user.delete:
        raise HTTPException(status_code=500, detail='该账号已被禁用,请联系管理员')
    # 查询角色
    role, web_menu = gen_user_role(db_user)
    # 生成 JWT 令牌
    payload = {'user_name': db_user.user_name, 'user_id': db_user.user_id, 'role': role}
    # Create the tokens and passing to set_access_cookies or set_refresh_cookies
    access_token = AuthJWT().create_access_token(subject=json.dumps(payload),
                                                 expires_time=ACCESS_TOKEN_EXPIRE_TIME)

    refresh_token = AuthJWT().create_refresh_token(subject=db_user.user_name)

    # Set the JWT cookies in the response
    return access_token, refresh_token, role, web_menu
  • 功能:生成用户的 JWT 访问令牌和刷新令牌。
  • 参数:
    • db_user:数据库中的用户对象。
  • 逻辑:
    • 检查用户是否被禁用(delete 字段为 1),如果被禁用,抛出 HTTP 异常。
    • 调用 gen_user_role 获取用户的角色和菜单权限列表。
    • 创建 JWT 负载(payload),包括用户名、用户 ID 和角色。
    • 使用 AuthJWT生成访问令牌和刷新令牌。
      • 访问令牌:包含用户信息,设置过期时间。
      • 刷新令牌:用于刷新访问令牌。
  • 返回值access_tokenrefresh_tokenroleweb_menu

get_knowledge_list_by_access

def get_knowledge_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
    count_filter = []
    if name:
        count_filter.append(Knowledge.name.like('%{}%'.format(name)))

    db_role_access = KnowledgeDao.get_knowledge_by_access(role_id, page_num, page_size)
    total_count = KnowledgeDao.get_count_by_filter(count_filter)
    # 补充用户名
    user_ids = [access[0].user_id for access in db_role_access]
    db_users = UserDao.get_user_by_ids(user_ids)
    user_dict = {user.user_id: user.user_name for user in db_users}

    return {
        'data': [
            KnowledgeRead.validate({
                'name': access[0].name,
                'user_name': user_dict.get(access[0].user_id),
                'user_id': access[0].user_id,
                'update_time': access[0].update_time,
                'id': access[0].id
            }) for access in db_role_access
        ],
        'total':
            total_count
    }
  • 功能:根据角色 ID 获取用户有权限访问的知识库列表。
  • 参数:
    • role_id:角色 ID。
    • name:知识库名称,支持模糊查询。
    • page_num:分页页码。
    • page_size:每页显示的数量。
  • 逻辑:
    • 根据名称构建过滤条件。
    • 调用 KnowledgeDao.get_knowledge_by_access 获取用户有权限的知识库列表。
    • 获取总数,便于分页。
    • 获取相关的用户信息,补充用户名。
    • 构建返回的数据结构,包括知识库信息和总数。
  • 返回值:包含 data(知识库列表)和 total(总数)的字典。

get_flow_list_by_access

def get_flow_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
    count_filter = []
    if name:
        count_filter.append(Flow.name.like('%{}%'.format(name)))

    db_role_access = FlowDao.get_flow_by_access(role_id, name, page_num, page_size)
    total_count = FlowDao.get_count_by_filters(count_filter)
    # 补充用户名
    user_ids = [access[0].user_id for access in db_role_access]
    db_users = UserDao.get_user_by_ids(user_ids)
    user_dict = {user.user_id: user.user_name for user in db_users}

    return {
        'data': [
            FlowRead.validate({
                'name': access[0].name,
                'user_name': user_dict.get(access[0].user_id),
                'user_id': access[0].user_id,
                'update_time': access[0].update_time,
                'id': access[0].id
            }) for access in db_role_access
        ],
        'total':
            total_count
    }
  • 功能:根据角色 ID 获取用户有权限访问的流程列表。
  • 逻辑:与 get_knowledge_list_by_access 类似,区别在于处理的对象是流程(Flow)。

get_assistant_list_by_access

def get_assistant_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
    count_filter = []
    if name:
        count_filter.append(Assistant.name.like('%{}%'.format(name)))

    db_role_access = AssistantDao.get_assistants_by_access(role_id, name, page_size, page_num)
    total_count = AssistantDao.get_count_by_filters(count_filter)
    # 补充用户名
    user_ids = [access[0].user_id for access in db_role_access]
    db_users = UserDao.get_user_by_ids(user_ids)
    user_dict = {user.user_id: user.user_name for user in db_users}

    return {
        'data': [{
            'name': access[0].name,
            'user_name': user_dict.get(access[0].user_id),
            'user_id': access[0].user_id,
            'update_time': access[0].update_time,
            'id': access[0].id
        } for access in db_role_access],
        'total':
            total_count
    }
  • 功能:根据角色 ID 获取用户有权限访问的助手列表。
  • 逻辑:与前两个函数类似,处理的对象是助手(Assistant)。

5. 获取当前登录用户

get_login_user

async def get_login_user(authorize: AuthJWT = Depends()) -> UserPayload:
    """
    获取当前登录的用户
    """
    # 校验是否过期,过期则直接返回 http 状态码的 401
    authorize.jwt_required()

    current_user = json.loads(authorize.get_jwt_subject())
    user = UserPayload(**current_user)

    # 判断是否允许多点登录
    if not settings.get_system_login_method().allow_multi_login:
        # 获取 access_token
        current_token = redis_client.get(USER_CURRENT_SESSION.format(user.user_id))
        # 登录被挤下线了,http 状态码是 200, status_code 是特殊 code
        if current_token != authorize._token:
            raise UserLoginOfflineError.http_exception()
    return user
  • 功能:获取当前登录的用户信息,进行权限验证和多点登录检查。
  • 参数:
    • authorize:依赖注入的 AuthJWT 实例,用于处理 JWT 认证。
  • 逻辑:
    • 调用 authorize.jwt_required() 验证 JWT 是否存在和有效,若无效则返回 401 状态码。
    • 从 JWT 中获取用户信息,创建 UserPayload 实例。
    • 检查系统设置是否允许多点登录。
      • 如果不允许,检查当前用户的会话是否与 Redis 中存储的令牌一致。
      • 如果不一致,说明用户被挤下线,抛出 UserLoginOfflineError 异常。
  • 返回值UserPayload 实例,包含当前用户的信息。

get_admin_user

async def get_admin_user(authorize: AuthJWT = Depends()) -> UserPayload:
    """
    获取超级管理员账号,非超级管理员用户,抛出异常
    """
    login_user = await get_login_user(authorize)
    if not login_user.is_admin():
        raise UnAuthorizedError.http_exception()
    return login_user
  • 功能:获取当前登录的超级管理员用户,如果不是管理员,抛出未授权异常。
  • 逻辑:
    • 调用 get_login_user 获取当前登录用户。
    • 检查用户是否为管理员。
      • 如果不是,抛出 UnAuthorizedError 异常。
  • 返回值UserPayload 实例,包含管理员用户的信息。

6. 总结

这个 user_service.py 模块主要处理了以下功能:

  • 用户权限管理:通过 UserPayload 类和相关的权限检查方法,管理用户的权限和角色。
  • 用户创建:提供创建用户的功能,包括密码加密、用户名重复检查、分组和角色的分配等。
  • JWT 认证:生成 JWT 令牌,验证用户的身份,支持多点登录的控制。
  • 资源访问:根据用户的角色,获取用户有权限访问的知识库、流程和助手列表。
  • 异常处理:在权限不足、用户被禁用、用户名重复等情况下,抛出相应的异常,提供友好的错误信息。

manager

好的,我来为您详细讲解 manager.py 这个模块的代码。这个模块主要涉及到观察者模式的实现和一个缓存管理器的设计。以下是对代码的详细解析:

导入模块

from contextlib import contextmanager
from typing import Any, Awaitable, Callable, List, Optional

import pandas as pd
from bisheng.utils.util import get_cache_key
from PIL import Image
  • contextlib.contextmanager:用于创建上下文管理器的装饰器,方便使用 with 语句。

  • typing 模块

    :提供类型注解支持。

    • Any:任意类型。
    • Awaitable:可等待对象,通常用于异步函数。
    • Callable:可调用对象类型。
    • List:列表类型。
    • Optional:可选类型,可能为 None
  • pandas as pd:用于数据处理的库,简称为 pd

  • bisheng.utils.util.get_cache_key:项目内部的工具函数,用于生成缓存键值。

  • PIL.Image:Python 图像库,处理图像相关操作。


1. Subject

class Subject:
    """实现观察者模式的基础类。"""

    def __init__(self):
        self.observers: List[Callable[[], None]] = []

    def attach(self, observer: Callable[[], None]):
        """向主题添加一个观察者。"""
        self.observers.append(observer)

    def detach(self, observer: Callable[[], None]):
        """从主题中移除一个观察者。"""
        self.observers.remove(observer)

    def notify(self):
        """通知所有观察者一个事件发生。"""
        for observer in self.observers:
            if observer is None:
                continue
            observer()

功能概述

  • Subject 类是实现观察者模式的基础类。
  • 它维护了一个观察者列表,当主题发生变化时,通知所有的观察者。

方法解析

  • __init__
    • 初始化方法,创建一个空的观察者列表 self.observers
  • attach
    • 添加一个观察者到观察者列表。
    • 参数:
      • observer: 一个不接受任何参数且返回 None 的可调用对象(函数)。
  • detach
    • 从观察者列表中移除一个观察者。
  • notify
    • 通知所有观察者,调用它们的回调函数。
    • 遍历 self.observers 列表,逐个调用每个观察者。
    • 如果观察者为 None,则跳过。

2. AsyncSubject

class AsyncSubject:
    """实现异步观察者模式的基础类。"""

    def __init__(self):
        self.observers: List[Callable[[], Awaitable]] = []

    def attach(self, observer: Callable[[], Awaitable]):
        """向主题添加一个异步观察者。"""
        self.observers.append(observer)

    def detach(self, observer: Callable[[], Awaitable]):
        """从主题中移除一个异步观察者。"""
        self.observers.remove(observer)

    async def notify(self):
        """异步通知所有观察者一个事件发生。"""
        for observer in self.observers:
            if observer is None:
                continue
            await observer()

功能概述

  • AsyncSubject 类是针对异步场景的观察者模式实现。
  • 它允许观察者是异步函数,支持 await 调用。

方法解析

  • __init__
    • 初始化方法,创建一个空的异步观察者列表 self.observers
  • attach
    • 添加一个异步观察者到观察者列表。
    • 参数:
      • observer: 一个不接受任何参数且返回 Awaitable 的可调用对象(异步函数)。
  • detach
    • 从观察者列表中移除一个异步观察者。
  • notify
    • 异步通知所有观察者。
    • 使用 await 调用每个异步观察者。

3. CacheManager

class CacheManager(Subject):
    """管理不同客户端的缓存,并在发生变化时通知观察者。"""

    def __init__(self):
        super().__init__()
        self._cache = {}
        self.current_client_id = None
        self.current_chat_id = None
        self.current_cache = {}

    @contextmanager
    def set_client_id(self, client_id: str, chat_id: str):
        """
        上下文管理器,用于设置当前的 client_id 和关联的缓存。

        参数:
            client_id (str): 客户端标识符。
            chat_id (str): 聊天标识符。
        """
        previous_client_id = self.current_client_id
        previous_chat_id = self.current_chat_id
        self.current_client_id = client_id
        self.current_chat_id = chat_id
        self.current_cache = self._cache.setdefault(get_cache_key(client_id, chat_id), {})
        try:
            yield
        finally:
            self.current_client_id = previous_client_id
            self.current_chat_id = previous_chat_id
            self.current_cache = self._cache.get(get_cache_key(self.current_client_id, self.current_chat_id), {})
    
    def add(self, name: str, obj: Any, obj_type: str, extension: Optional[str] = None):
        """
        向当前客户端的缓存中添加一个对象。

        参数:
            name (str): 缓存的键名。
            obj (Any): 要缓存的对象。
            obj_type (str): 对象的类型。
            extension (Optional[str]): 文件扩展名,默认为 None。
        """
        object_extensions = {
            'image': 'png',
            'pandas': 'csv',
        }
        if obj_type in object_extensions:
            _extension = object_extensions[obj_type]
        else:
            _extension = type(obj).__name__.lower()
        self.current_cache[name] = {
            'obj': obj,
            'type': obj_type,
            'extension': extension or _extension,
        }
        self.notify()
    
    def add_pandas(self, name: str, obj: Any):
        """
        向当前客户端的缓存中添加一个 pandas DataFrame 或 Series。

        参数:
            name (str): 缓存的键名。
            obj (Any): pandas DataFrame 或 Series 对象。
        """
        if isinstance(obj, (pd.DataFrame, pd.Series)):
            self.add(name, obj.to_csv(), 'pandas', extension='csv')
        else:
            raise ValueError('Object is not a pandas DataFrame or Series')
    
    def add_image(self, name: str, obj: Any, extension: str = 'png'):
        """
        向当前客户端的缓存中添加一个 PIL Image。

        参数:
            name (str): 缓存的键名。
            obj (Any): PIL Image 对象。
            extension (str): 文件扩展名,默认为 'png'。
        """
        if isinstance(obj, Image.Image):
            self.add(name, obj, 'image', extension=extension)
        else:
            raise ValueError('Object is not a PIL Image')
    
    def get(self, name: str):
        """
        从当前客户端的缓存中获取一个对象。

        参数:
            name (str): 缓存的键名。

        返回:
            与给定键名关联的缓存对象。
        """
        return self.current_cache[name]
    
    def get_last(self):
        """
        获取当前客户端缓存中最后添加的对象。

        返回:
            缓存中最后添加的对象。
        """
        return list(self.current_cache.values())[-1]

功能概述

  • CacheManager 类继承自 Subject,因此具备观察者模式的功能。
  • 该类管理不同客户端(client)的缓存数据,并在缓存发生变化时通知观察者。
  • 支持对缓存的添加、获取操作,以及对特定类型对象的专门处理(如 pandas 数据和图像)。

属性

  • _cache:存储所有客户端的缓存数据,结构为 {client_cache_key: client_cache_dict}
  • current_client_id:当前操作的客户端 ID。
  • current_chat_id:当前操作的聊天 ID。
  • current_cache:当前客户端的缓存字典。

方法解析

__init__
  • 初始化方法,调用父类的初始化方法。
  • 初始化缓存相关的属性。
set_client_id
  • 上下文管理器,用于设置当前的客户端 ID 和聊天 ID,以及关联的缓存。
  • 参数:
    • client_id:客户端标识符。
    • chat_id:聊天标识符。
  • 功能:
    • 在进入上下文时,保存之前的 client_idchat_id
    • 设置新的 current_client_idcurrent_chat_id
    • 获取或创建对应的 current_cache
    • 在退出上下文时,恢复之前的 client_idchat_idcurrent_cache

使用示例

with cache_manager.set_client_id('client1', 'chat1'):
    # 在这个上下文中,current_client_id 和 current_chat_id 被设置为 'client1' 和 'chat1'
    # 可以进行缓存操作
    cache_manager.add('key1', 'value1', 'string')
add
  • 向当前客户端的缓存中添加一个对象。
  • 参数:
    • name:缓存的键名。
    • obj:要缓存的对象。
    • obj_type:对象的类型,字符串形式。
    • extension:文件扩展名,可选。
  • 功能:
    • 根据 obj_type,确定对象的扩展名。如果未提供 extension,则使用预定义的或对象类型的名称。
    • 将对象存储在 current_cache 中,键为 name,值为一个包含对象信息的字典。
    • 调用 self.notify(),通知所有观察者缓存已更新。
add_pandas
  • 专门处理 pandas 数据的添加方法。
  • 参数:
    • name:缓存的键名。
    • obj:pandas DataFrame 或 Series 对象。
  • 功能:
    • 检查对象是否为 pandas DataFrame 或 Series。
    • 将对象转换为 CSV 格式的字符串。
    • 调用 self.add 方法,将其添加到缓存中,类型设为 'pandas',扩展名为 'csv'
  • 异常处理:
    • 如果对象不是 pandas 数据类型,抛出 ValueError
add_image
  • 专门处理图像数据的添加方法。
  • 参数:
    • name:缓存的键名。
    • obj:PIL Image 对象。
    • extension:文件扩展名,默认为 'png'
  • 功能:
    • 检查对象是否为 PIL Image 类型。
    • 调用 self.add 方法,将其添加到缓存中,类型设为 'image'
  • 异常处理:
    • 如果对象不是 PIL Image 类型,抛出 ValueError
get
  • 从当前客户端的缓存中获取指定键名的对象。
  • 参数:
    • name:缓存的键名。
  • 返回:
    • 缓存中对应键名的对象信息字典。
get_last
  • 获取当前客户端缓存中最后添加的对象。
  • 返回:
    • 缓存中最后一个添加的对象信息字典。

4. 模块级实例

cache_manager = CacheManager()
  • 在模块级别创建了一个 CacheManager 的实例 cache_manager,供全局使用。

5. 总结

  • 观察者模式:
    • SubjectAsyncSubject 实现了同步和异步的观察者模式基础类。
    • 提供了 attachdetachnotify 方法,用于管理观察者和通知事件。
  • 缓存管理器:
    • CacheManager 继承自 Subject,实现了对不同客户端缓存的管理。
    • 支持上下文管理器,通过 set_client_id 方法设置当前操作的客户端和聊天会话。
    • 提供了 addadd_pandasadd_imagegetget_last 方法,方便缓存对象的添加和获取。
    • 在添加对象时,会通知所有已注册的观察者,便于在缓存更新时触发相应的处理。

使用场景

  • 缓存管理:
    • 在一个多用户、多会话的环境中,需要对不同的客户端和聊天会话进行缓存管理。
    • CacheManager 可以根据 client_idchat_id 进行区分,确保缓存的隔离性。
  • 观察者通知:
    • 当缓存发生变化时,可能需要更新界面、日志记录或触发其他操作。
    • 通过观察者模式,CacheManager 可以通知所有的观察者,执行相应的回调函数。

实际应用示例

# 定义一个观察者函数
def cache_updated():
    print("Cache has been updated.")

# 将观察者函数附加到缓存管理器
cache_manager.attach(cache_updated)

# 使用缓存管理器添加对象
with cache_manager.set_client_id('user123', 'chat456'):
    # 添加一个字符串对象
    cache_manager.add('greeting', 'Hello, World!', 'string')
    
    # 添加一个 pandas DataFrame
    df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
    cache_manager.add_pandas('dataframe', df)
    
    # 添加一个图像对象
    image = Image.new('RGB', (100, 100), color='red')
    cache_manager.add_image('red_square', image)

# 由于缓存发生了更新,观察者函数会被调用,输出:
# "Cache has been updated."

注意事项

  • 线程安全性:
    • 目前的实现未考虑线程安全性。如果在多线程环境中使用,需要添加线程锁来保护共享数据。
  • 错误处理:
    • 在获取缓存对象时,如果键名不存在,会抛出 KeyError。在实际使用中,应该添加相应的异常处理。
  • 缓存清理:
    • 目前未提供缓存清理或过期机制。如果缓存数据较多,可能会占用较多内存。

llm

导入模块

import json
from typing import List, Optional

from fastapi import Request
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel
from loguru import logger

from bisheng.api.errcode.base import NotFoundError
from bisheng.api.errcode.llm import (
    ServerExistError,
    ModelNameRepeatError,
    ServerAddError,
    ServerAddAllError
)
from bisheng.api.services.user_service import UserPayload
from bisheng.api.v1.schemas import (
    LLMServerInfo,
    LLMModelInfo,
    KnowledgeLLMConfig,
    AssistantLLMConfig,
    EvaluationLLMConfig,
    AssistantLLMItem,
    LLMServerCreateReq
)
from bisheng.database.models.config import ConfigDao, ConfigKeyEnum, Config
from bisheng.database.models.llm_server import (
    LLMDao,
    LLMServer,
    LLMModel,
    LLMModelType
)
from bisheng.interface.importing import import_by_type
from bisheng.interface.initialize.loading import instantiate_llm, instantiate_embedding

标准库导入

  • json:用于处理 JSON 数据的序列化和反序列化。
  • typing 模块:提供类型注解支持。
    • List:列表类型注解。
    • Optional:可选类型,可能为 None

第三方库导入

  • fastapi.Request:FastAPI 的请求对象。
  • langchain_core.embeddings.Embeddings:LangChain 中的嵌入向量模型类。
  • langchain_core.language_models.BaseChatModel:LangChain 中的基础聊天模型类。
  • loguru.logger:用于日志记录的库。

项目内部模块导入

  • 错误码模块:
    • NotFoundError:未找到错误。
    • ServerExistError:服务已存在错误。
    • ModelNameRepeatError:模型名称重复错误。
    • ServerAddError:添加服务错误。
    • ServerAddAllError:添加服务全部失败错误。
  • 用户服务模块:
    • UserPayload:用户信息载体类。
  • 数据模型和请求体定义:
    • LLMServerInfo:LLM 服务信息的数据模型。
    • LLMModelInfo:LLM 模型信息的数据模型。
    • KnowledgeLLMConfig:知识库 LLM 配置的数据模型。
    • AssistantLLMConfig:助手 LLM 配置的数据模型。
    • EvaluationLLMConfig:评估 LLM 配置的数据模型。
    • AssistantLLMItem:助手 LLM 项目的数据模型。
    • LLMServerCreateReq:创建 LLM 服务的请求体数据模型。
  • 配置相关的 DAO 和枚举:
    • ConfigDao:配置数据访问对象。
    • ConfigKeyEnum:配置键的枚举类型。
    • Config:配置的数据模型。
  • LLM 服务相关的 DAO 和模型:
    • LLMDao:LLM 数据访问对象。
    • LLMServer:LLM 服务的数据模型。
    • LLMModel:LLM 模型的数据模型。
    • LLMModelType:LLM 模型类型的枚举。
  • 动态导入和实例化模块:
    • import_by_type:根据类型导入模块的函数。
    • instantiate_llm:实例化 LLM 模型的函数。
    • instantiate_embedding:实例化嵌入模型的函数。

LLMService

class LLMService:
    ...

功能概述

LLMService 类是一个用于管理大型语言模型(LLM)服务的类,提供了对 LLM 服务和模型的增删改查,以及配置默认模型等功能。

方法列表

  1. get_all_llm:获取所有的 LLM 服务信息。
  2. get_one_llm:获取单个 LLM 服务的详细信息。
  3. add_llm_server:添加一个新的 LLM 服务。
  4. add_llm_server_hook:添加 LLM 服务后的后续处理。
  5. delete_llm_server:删除一个 LLM 服务。
  6. update_llm_server:更新 LLM 服务信息。
  7. update_model_online:更新模型的上线状态。
  8. get_knowledge_llm:获取知识库相关的默认模型配置。
  9. get_knowledge_source_llm:获取知识库溯源的默认模型实例。
  10. get_knowledge_similar_llm:获取知识库相似问的默认模型实例。
  11. update_knowledge_llm:更新知识库的默认模型配置。
  12. get_assistant_llm:获取助手相关的默认模型配置。
  13. update_assistant_llm:更新助手的默认模型配置。
  14. get_evaluation_llm:获取评测功能的默认模型配置。
  15. get_evaluation_llm_object:获取评测功能的默认模型实例。
  16. get_bisheng_llm:获取必胜的 LLM 模型实例。
  17. get_bisheng_embedding:获取必胜的嵌入模型实例。
  18. update_evaluation_llm:更新评测功能的默认模型配置。
  19. get_assistant_llm_list:获取助手可选的模型列表。
  20. set_default_model:设置默认的模型配置。

下面我们逐一解析其中的重要方法。


1. get_all_llm

@classmethod
def get_all_llm(cls, request: Request, login_user: UserPayload) -> List[LLMServerInfo]:
    """ 获取所有的模型数据,不包含 key 等敏感信息 """
    llm_servers = LLMDao.get_all_server()
    ret = []
    server_ids = []
    for one in llm_servers:
        server_ids.append(one.id)
        ret.append(LLMServerInfo(**one.model_dump(exclude={'config'})))

    llm_models = LLMDao.get_model_by_server_ids(server_ids)
    server_dicts = {}
    for one in llm_models:
        if one.server_id not in server_dicts:
            server_dicts[one.server_id] = []
        server_dicts[one.server_id].append(one.model_dump(exclude={'config'}))

    for one in ret:
        one.models = server_dicts.get(one.id, [])
    return ret
功能
  • 获取所有的 LLM 服务信息,包括每个服务下的模型列表。
  • 不包含敏感的配置信息,如密钥(config 字段)。
逻辑
  1. 获取所有 LLM 服务
    • 调用 LLMDao.get_all_server() 获取所有的 LLM 服务记录。
  2. 构建服务信息列表
    • 初始化返回列表 ret 和服务 ID 列表 server_ids
    • 遍历所有的 LLM 服务,提取每个服务的 ID,并使用 LLMServerInfo 数据模型(排除 config 字段)构建服务信息对象,添加到 ret 列表。
  3. 获取每个服务下的模型列表
    • 调用 LLMDao.get_model_by_server_ids(server_ids) 获取所有服务对应的模型列表。
    • 初始化一个字典 server_dicts,以服务 ID 为键,模型列表为值。
    • 遍历模型列表,将模型信息添加到对应的服务下(排除 config 字段)。
  4. 组装最终返回结果
    • 遍历服务信息列表 ret,为每个服务对象添加 models 属性,对应其模型列表。
  5. 返回结果
    • 返回包含所有服务和其模型列表的 LLMServerInfo 对象列表。

2. get_one_llm

@classmethod
def get_one_llm(cls, request: Request, login_user: UserPayload, server_id: int) -> LLMServerInfo:
    """ 获取一个服务提供方的详细信息,包含了 key 等敏感的配置信息 """
    llm = LLMDao.get_server_by_id(server_id)
    if not llm:
        raise NotFoundError.http_exception()

    models = LLMDao.get_model_by_server_ids([server_id])
    models = [LLMModelInfo(**one.model_dump()) for one in models]
    return LLMServerInfo(**llm.model_dump(), models=models)
功能
  • 获取指定的 LLM 服务的详细信息,包括敏感配置信息和模型列表。
逻辑
  1. 获取指定的 LLM 服务
    • 调用 LLMDao.get_server_by_id(server_id) 获取服务信息。
    • 如果服务不存在,抛出 NotFoundError 异常。
  2. 获取服务下的模型列表
    • 调用 LLMDao.get_model_by_server_ids([server_id]) 获取该服务下的模型列表。
    • 使用 LLMModelInfo 数据模型构建模型信息列表。
  3. 返回结果
    • 构建 LLMServerInfo 对象,包括服务信息和模型列表,返回给调用者。

3. add_llm_server

@classmethod
def add_llm_server(cls, request: Request, login_user: UserPayload, server: LLMServerCreateReq) -> LLMServerInfo:
    """ 添加一个服务提供方 """
    exist_server = LLMDao.get_server_by_name(server.name)
    if exist_server:
        raise ServerExistError.http_exception()

    # 尝试实例化下对应的模型组件是否可以成功初始化
    model_dict = {}
    for one in server.models:
        if one.model_name not in model_dict:
            model_dict[one.model_name] = LLMModel(**one.dict(), user_id=login_user.user_id)
        else:
            raise ModelNameRepeatError.http_exception()

    db_server = LLMServer(**server.dict(exclude={'models'}))
    db_server.user_id = login_user.user_id

    db_server = LLMDao.insert_server_with_models(db_server, list(model_dict.values()))

    ret = cls.get_one_llm(request, login_user, db_server.id)
    success_models = []
    success_msg = ''
    failed_models = []
    failed_msg = ''
    # 尝试实例化对应的模型,有报错的话删除
    for one in ret.models:
        try:
            if one.model_type == LLMModelType.LLM.value:
                cls.get_bisheng_llm(model_id=one.id, ignore_online=True)
            elif one.model_type == LLMModelType.EMBEDDING.value:
                cls.get_bisheng_embedding(model_id=one.id, ignore_online=True)
            success_msg += f'{one.model_name},'
            success_models.append(one)
        except Exception as e:
            logger.exception("init_model_error")
            # 模型初始化失败的话,不添加到模型列表里
            failed_msg += f'<{one.model_name}>添加失败,失败原因:{str(e)}\n'
            failed_models.append(one)

    # 说明模型全部添加失败了
    if len(success_models) == 0 and failed_msg:
        LLMDao.delete_server_by_id(ret.id)
        raise ServerAddAllError.http_exception(failed_msg)
    elif len(success_models) > 0 and failed_msg:
        # 部分模型添加成功了, 删除失败的模型信息
        ret.models = success_models
        LLMDao.delete_model_by_ids(model_ids=[one.id for one in failed_models])
        cls.add_llm_server_hook(request, login_user, ret)
        raise ServerAddError.http_exception(f"<{success_msg.rstrip(',')}>添加成功,{failed_msg}")

    cls.add_llm_server_hook(request, login_user, ret)
    return ret
功能
  • 添加一个新的 LLM 服务和其模型列表。
  • 尝试初始化模型,如果模型初始化失败,进行相应的错误处理。
逻辑
  1. 检查服务名称是否已存在
    • 调用 LLMDao.get_server_by_name(server.name) 检查服务名称是否已存在。
    • 如果存在,抛出 ServerExistError 异常。
  2. 构建模型字典
    • 初始化 model_dict,用于存储模型信息,键为模型名称。
    • 遍历 server.models,检查模型名称是否重复,如果重复,抛出 ModelNameRepeatError 异常。
    • 创建 LLMModel 实例,添加到 model_dict
  3. 创建 LLM 服务
    • 创建 LLMServer 实例(排除 models 字段),设置 user_id
    • 调用 LLMDao.insert_server_with_models,将服务和模型列表插入数据库。
  4. 获取新添加的服务信息
    • 调用 cls.get_one_llm 获取新添加的服务和模型信息。
  5. 尝试初始化模型
    • 遍历服务的模型列表,尝试实例化每个模型。
      • 如果模型类型是 LLM,调用 cls.get_bisheng_llm 实例化。
      • 如果模型类型是 EMBEDDING,调用 cls.get_bisheng_embedding 实例化。
    • 如果实例化成功,记录成功信息。
    • 如果实例化失败,记录失败信息,将模型添加到 failed_models 列表。
  6. 处理初始化结果
    • 如果所有模型都初始化失败,删除刚添加的服务,抛出 ServerAddAllError 异常。
    • 如果部分模型初始化失败,删除失败的模型记录,抛出 ServerAddError 异常,返回成功和失败的模型信息。
    • 如果所有模型初始化成功,调用 cls.add_llm_server_hook 进行后续处理,返回服务信息。

4. add_llm_server_hook

@classmethod
def add_llm_server_hook(cls, request: Request, login_user: UserPayload, server: LLMServerInfo) -> bool:
    """ 添加一个服务提供方后的后续动作 """
    handle_types = []
    for one in server.models:
        if one.model_type in handle_types:
            continue
        handle_types.append(one.model_type)
        model_info = LLMDao.get_model_by_type(LLMModelType(one.model_type))
        # 判断是否是首个 LLM 或 Embedding 模型
        if model_info.id == one.id:
            cls.set_default_model(request, login_user, model_info)
    return True
功能
  • 添加 LLM 服务后,检查是否需要设置默认的模型配置。
  • 如果新添加的模型是系统中的首个同类型模型(如首个 LLM 或首个 Embedding 模型),则设置为默认模型。
逻辑
  1. 初始化处理类型列表
    • 初始化 handle_types,用于记录已处理的模型类型。
  2. 遍历服务的模型列表
    • 遍历 server.models,对于每个模型:
      • 如果模型类型已处理,跳过。
      • 将模型类型添加到 handle_types
  3. 检查是否是首个模型
    • 调用 LLMDao.get_model_by_type,获取数据库中指定类型的模型。
    • 如果模型的 ID 与新添加的模型 ID 相同,说明这是首个同类型的模型。
  4. 设置默认模型
    • 调用 cls.set_default_model,将新模型设置为默认模型配置。

5. set_default_model

@classmethod
def set_default_model(cls, request: Request, login_user: UserPayload, model: LLMModel):
    """ 设置默认的模型配置 """
    if model.model_type == LLMModelType.LLM.value:
        # 设置知识库的默认模型配置
        knowledge_llm = cls.get_knowledge_llm()
        knowledge_change = False
        if not knowledge_llm.extract_title_model_id:
            knowledge_llm.extract_title_model_id = model.id
            knowledge_change = True
        if not knowledge_llm.source_model_id:
            knowledge_llm.source_model_id = model.id
            knowledge_change = True
        if not knowledge_llm.qa_similar_model_id:
            knowledge_llm.qa_similar_model_id = model.id
            knowledge_change = True
        if knowledge_change:
            cls.update_knowledge_llm(request, login_user, knowledge_llm)

        # 设置评测的默认模型配置
        evaluation_llm = cls.get_evaluation_llm()
        if not evaluation_llm.model_id:
            evaluation_llm.model_id = model.id
            cls.update_evaluation_llm(request, login_user, evaluation_llm)

        # 设置助手的默认模型配置
        assistant_llm = cls.get_assistant_llm()
        assistant_change = False
        if not assistant_llm.auto_llm:
            assistant_llm.auto_llm = AssistantLLMItem(model_id=model.id)
            assistant_change = True
        if not assistant_llm.llm_list:
            assistant_change = True
            assistant_llm.llm_list = [
                AssistantLLMItem(model_id=model.id, default=True)
            ]
        if assistant_change:
            cls.update_assistant_llm(request, login_user, assistant_llm)

    elif model.model_type == LLMModelType.EMBEDDING.value:
        knowledge_llm = cls.get_knowledge_llm()
        if not knowledge_llm.embedding_model_id:
            knowledge_llm.embedding_model_id = model.id
            cls.update_knowledge_llm(request, login_user, knowledge_llm)
功能
  • 设置默认的模型配置,包括知识库、评测、助手等模块。
逻辑
  1. 处理 LLM 模型类型

    • 如果模型类型是 LLM,执行以下操作:
      • 知识库模块
        • 获取当前的知识库 LLM 配置 knowledge_llm
        • 初始化标志 knowledge_changeFalse
        • 检查知识库配置中的模型 ID,如果为空,则设置为当前模型的 ID,设置 knowledge_changeTrue
        • 如果配置有变化,调用 cls.update_knowledge_llm 更新配置。
      • 评测模块
        • 获取当前的评测 LLM 配置 evaluation_llm
        • 如果模型 ID 为空,设置为当前模型的 ID。
        • 调用 cls.update_evaluation_llm 更新配置。
      • 助手模块
        • 获取当前的助手 LLM 配置 assistant_llm
        • 初始化标志 assistant_changeFalse
        • 检查助手配置中的自动 LLM,如果为空,设置为当前模型。
        • 检查助手配置中的 LLM 列表,如果为空,添加当前模型为默认模型。
        • 如果配置有变化,调用 cls.update_assistant_llm 更新配置。
  2. 处理 Embedding 模型类型

    • 如果模型类型是 EMBEDDING,执行以下操作:

      • 知识库模块

        • 获取当前的知识库 LLM 配置 knowledge_llm
        • 如果嵌入模型 ID 为空,设置为当前模型的 ID。
        • 调用 cls.update_knowledge_llm 更新配置。

6. get_bisheng_llmget_bisheng_embedding

@classmethod
def get_bisheng_llm(cls, **kwargs) -> BaseChatModel:
    """ 获取必胜的 LLM 模型实例 """
    class_object = import_by_type(_type='llms', name='BishengLLM')
    return instantiate_llm('BishengLLM', class_object, kwargs)

@classmethod
def get_bisheng_embedding(cls, **kwargs) -> Embeddings:
    """ 获取必胜的嵌入模型实例 """
    class_object = import_by_type(_type='embeddings', name='BishengEmbedding')
    return instantiate_embedding(class_object, kwargs)
功能
  • 动态导入和实例化必胜的 LLM 模型和嵌入模型。
逻辑
  1. get_bisheng_llm
    • 调用 import_by_type,根据类型 'llms' 和名称 'BishengLLM' 导入 LLM 类。
    • 调用 instantiate_llm,传入类名 'BishengLLM'、类对象和参数 kwargs,实例化 LLM 模型。
  2. get_bisheng_embedding
    • 调用 import_by_type,根据类型 'embeddings' 和名称 'BishengEmbedding' 导入嵌入类。
    • 调用 instantiate_embedding,传入类对象和参数 kwargs,实例化嵌入模型。

7. get_assistant_llm_list

@classmethod
def get_assistant_llm_list(cls, request: Request, login_user: UserPayload) -> List[LLMServerInfo]:
    """ 获取助手可选的模型列表 """
    assistant_llm = cls.get_assistant_llm()
    if not assistant_llm.llm_list:
        return []
    model_list = LLMDao.get_model_by_ids([one.model_id for one in assistant_llm.llm_list])
    if not model_list:
        return []

    model_dict = {}
    for one in model_list:
        if one.server_id not in model_dict:
            model_dict[one.server_id] = []
        model_dict[one.server_id].append(LLMModelInfo(**one.dict(exclude={'config'})))
    server_list = LLMDao.get_server_by_ids(list(model_dict.keys()))

    ret = []
    for one in server_list:
        ret.append(LLMServerInfo(**one.dict(exclude={'config'}), models=model_dict[one.id]))

    return ret
功能
  • 获取助手模块可选的模型列表,供用户选择。
逻辑
  1. 获取助手 LLM 配置
    • 调用 cls.get_assistant_llm() 获取助手的 LLM 配置 assistant_llm
  2. 检查是否有可用的模型列表
    • 如果 assistant_llm.llm_list 为空,返回空列表。
  3. 获取模型列表
    • 根据模型 ID 列表,调用 LLMDao.get_model_by_ids 获取模型列表 model_list
    • 如果模型列表为空,返回空列表。
  4. 构建模型字典
    • 初始化 model_dict,以服务 ID 为键,模型信息列表为值。
    • 遍历模型列表,将模型信息添加到对应的服务下(排除 config 字段)。
  5. 获取服务列表
    • 根据服务 ID 列表,调用 LLMDao.get_server_by_ids 获取服务列表。
  6. 组装返回结果
    • 遍历服务列表,构建 LLMServerInfo 对象,包含服务信息和模型列表。
    • 返回服务信息列表。

总结

LLMService 类提供了对 LLM 服务和模型的管理功能,包括添加、删除、更新服务和模型,获取默认模型配置,设置默认模型,以及获取可选的模型列表等。

  • 模型初始化和错误处理:在添加模型时,尝试实例化模型,处理可能的初始化失败情况,保证系统的稳定性。
  • 默认模型配置:当添加新的模型时,如果是系统中的首个同类型模型,自动设置为默认模型,方便后续模块使用。
  • 动态实例化模型:通过动态导入和实例化机制,支持不同类型的模型,增强系统的灵活性。

message

导入模块

from datetime import datetime
from typing import Dict, List, Optional, Tuple
from uuid import UUID

from sqlalchemy.sql import not_
from sqlalchemy import JSON, Column, DateTime, String, Text, case, func, or_, text, update
from sqlmodel import Field, delete, select
from pydantic import BaseModel
from loguru import logger

from bisheng.database.base import session_getter
from bisheng.database.models.base import SQLModelSerializable

标准库导入

  • datetime:用于处理日期和时间。
  • typing:提供类型提示支持。
    • Dict:字典类型。
    • List:列表类型。
    • Optional:可选类型,可能为 None
    • Tuple:元组类型。
  • uuid.UUID:用于处理 UUID 类型的数据。

第三方库导入

  • sqlalchemy 模块:用于与数据库进行交互。

    • func:SQL 函数,如 countmax 等。
    • selectupdatedeleteor_not_:构建 SQL 查询语句的函数。
    • Column:定义数据库列的属性。
    • DateTimeStringTextJSON:定义数据库列的数据类型。
    • text:用于编写原生 SQL 语句。
    • case:SQL 中的 CASE 表达式,用于条件判断。
  • sqlmodel

    :基于 SQLAlchemy 的 ORM 库,简化数据库操作。

    • Field:用于定义模型字段。
  • pydantic.BaseModel:用于数据验证和序列化的基类。

  • loguru.logger:用于日志记录。

项目内部模块导入

  • bisheng.database.base.session_getter:用于获取数据库会话的函数。
  • bisheng.database.models.base.SQLModelSerializable:自定义的基类,继承自 SQLModel,用于序列化。

1. MessageBase

class MessageBase(SQLModelSerializable):
    is_bot: bool = Field(index=False, description='聊天角色')
    source: Optional[int] = Field(index=False, description='是否支持溯源')
    mark_status: Optional[int] = Field(index=False, default=1, description='标记状态')
    mark_user: Optional[int] = Field(index=False, description='标记用户')
    mark_user_name: Optional[str] = Field(index=False, description='标记用户')
    message: Optional[str] = Field(sa_column=Column(Text), description='聊天消息')
    extra: Optional[str] = Field(sa_column=Column(String(length=4096)), description='连接信息等')
    type: str = Field(index=False, description='消息类型')
    category: str = Field(index=False, description='消息类别, question等')
    flow_id: UUID = Field(index=True, description='对应的技能id')
    chat_id: Optional[str] = Field(index=True, description='chat_id, 前端生成')
    user_id: Optional[str] = Field(index=True, description='用户id')
    liked: Optional[int] = Field(index=False, default=0, description='用户是否喜欢 0未评价/1 喜欢/2 不喜欢')
    solved: Optional[int] = Field(index=False, default=0, description='用户是否喜欢 0未评价/1 解决/2 未解决')
    copied: Optional[int] = Field(index=False, default=0, description='用户是否复制 0:未复制 1:已复制')
    sender: Optional[str] = Field(index=False, default='', description='autogen 的发送方')
    receiver: Optional[Dict] = Field(index=False, default=None, description='autogen 的接收方')
    intermediate_steps: Optional[str] = Field(sa_column=Column(Text), description='过程日志')
    files: Optional[str] = Field(sa_column=Column(String(length=4096)), description='上传的文件等')
    remark: Optional[str] = Field(sa_column=Column(String(length=4096)),
                                  description='备注。break_answer: 中断的回复不作为history传给模型')
    create_time: Optional[datetime] = Field(
        sa_column=Column(DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP')))
    update_time: Optional[datetime] = Field(
        sa_column=Column(DateTime,
                         nullable=False,
                         index=True,
                         server_default=text('CURRENT_TIMESTAMP'),
                         onupdate=text('CURRENT_TIMESTAMP')))

功能概述

  • MessageBase 类是所有消息类的基类,继承自 SQLModelSerializable
  • 定义了与聊天消息相关的公共字段和属性。
  • 使用 FieldColumn 来定义数据库字段的属性和类型。

字段解析

  • is_botbool 类型,表示消息是否来自机器人。
  • sourceOptional[int],表示是否支持溯源。
  • mark_statusOptional[int],标记状态,默认值为 1
  • mark_userOptional[int],标记用户的 ID。
  • mark_user_nameOptional[str],标记用户的名称。
  • messageOptional[str],聊天消息内容,使用 Text 类型存储较长的文本。
  • extraOptional[str],额外信息,如连接信息等,限制长度为 4096 个字符。
  • typestr,消息类型。
  • categorystr,消息类别,如 question 等。
  • flow_idUUID,对应的技能 ID。
  • chat_idOptional[str],聊天会话的 ID,由前端生成。
  • user_idOptional[str],用户的 ID。
  • likedOptional[int],用户对消息的喜欢状态,0 未评价,1 喜欢,2 不喜欢。
  • solvedOptional[int],问题是否解决,0 未评价,1 解决,2 未解决。
  • copiedOptional[int],消息是否被复制,0 未复制,1 已复制。
  • senderOptional[str],发送方(针对自动生成的消息)。
  • receiverOptional[Dict],接收方(针对自动生成的消息)。
  • intermediate_stepsOptional[str],过程日志,存储消息生成过程的详细信息。
  • filesOptional[str],上传的文件信息。
  • remarkOptional[str],备注信息,如 break_answer 表示中断的回复不作为历史记录传给模型。
  • create_timeOptional[datetime],消息的创建时间,默认值为当前时间。
  • update_timeOptional[datetime],消息的更新时间,默认值为当前时间,并在更新时自动更新。

2. ChatMessage

class ChatMessage(MessageBase, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    receiver: Optional[Dict] = Field(default=None, sa_column=Column(JSON))

功能概述

  • ChatMessage 类继承自 MessageBase,并且指定了 table=True,表示这是一个数据库表。
  • 增加了主键 id 和重新定义了 receiver 字段。

字段解析

  • idOptional[int],消息的主键,自增的整数。
  • receiverOptional[Dict],接收方信息,使用 JSON 类型存储。

3. 其他模型类

ChatMessageRead

class ChatMessageRead(MessageBase):
    id: Optional[int]
  • 用于读取消息时的数据模型,包含了 id 字段。

ChatMessageQuery

class ChatMessageQuery(BaseModel):
    id: Optional[int]
    flow_id: str
    chat_id: str
  • 用于查询消息的请求体数据模型,包含 idflow_idchat_id

ChatMessageCreate

class ChatMessageCreate(MessageBase):
    pass
  • 用于创建消息的请求体数据模型,继承自 MessageBase

4. MessageDao

class MessageDao(MessageBase):
    @classmethod
    def static_msg_liked(cls, liked: int, flow_id: str, create_time_begin: datetime,
                         create_time_end: datetime):
        base_condition = select(func.count(ChatMessage.id)).where(ChatMessage.liked == liked)

        if flow_id:
            base_condition = base_condition.where(ChatMessage.flow_id == flow_id)

        if create_time_begin and create_time_end:
            base_condition = base_condition.where(ChatMessage.create_time > create_time_begin,
                                                  ChatMessage.create_time < create_time_end)

        with session_getter() as session:
            return session.scalar(base_condition)

功能概述

  • MessageDao 类提供了对消息数据的统计功能。
  • static_msg_liked 方法用于统计指定条件下的消息数量。

方法解析

static_msg_liked
  • 参数
    • liked:用户对消息的喜欢状态(0、1、2)。
    • flow_id:技能 ID,用于过滤特定技能的消息。
    • create_time_begincreate_time_end:开始和结束时间,用于时间范围过滤。
  • 逻辑
    1. 构建基础查询条件,统计满足 liked 状态的消息数量。
    2. 如果提供了 flow_id,则添加 flow_id 的过滤条件。
    3. 如果提供了时间范围,添加 create_time 的过滤条件。
    4. 使用数据库会话执行查询,返回统计结果。

5. ChatMessageDao

class ChatMessageDao(MessageBase):
    ...

功能概述

  • ChatMessageDao 类提供了对 ChatMessage 表的常用数据库操作方法,包括查询、插入、更新、删除等。

方法列表

  1. get_latest_message_by_chatid:获取指定聊天会话的最新一条消息。
  2. get_latest_message_by_chat_ids:获取多个聊天会话的最新消息。
  3. get_messages_by_chat_id:获取指定聊天会话的消息列表。
  4. get_last_msg_by_flow_id:获取指定技能 ID 的最后一条消息。
  5. get_msg_by_chat_id:获取指定聊天会话的所有消息。
  6. get_msg_by_flow:获取指定技能 ID 的消息列表。
  7. get_msg_by_flows:获取多个技能 ID 的消息列表。
  8. delete_by_user_chat_id:根据用户 ID 和聊天会话 ID 删除消息。
  9. delete_by_message_id:根据用户 ID 和消息 ID 删除消息。
  10. insert_one:插入一条新的消息。
  11. insert_batch:批量插入消息。
  12. get_message_by_id:根据消息 ID 获取消息。
  13. update_message:更新指定消息的内容。
  14. update_message_model:更新消息对象。
  15. update_message_copied:更新消息的复制状态。
  16. update_message_mark:更新消息的标记状态。

详细方法解析

get_latest_message_by_chatid
@classmethod
def get_latest_message_by_chatid(cls, chat_id: str):
    with session_getter() as session:
        res = session.exec(
            select(ChatMessage).where(ChatMessage.chat_id == chat_id).limit(1)).all()
        if res:
            return res[0]
        else:
            return None
  • 功能:获取指定聊天会话的最新一条消息。
  • 逻辑:
    • 使用 select 查询指定 chat_id 的消息,限制结果数量为 1
    • 如果查询结果存在,返回第一条消息,否则返回 None
get_latest_message_by_chat_ids
@classmethod
def get_latest_message_by_chat_ids(cls, chat_ids: list[str], category: str = None):
    statement = select(ChatMessage.chat_id,
                       func.max(ChatMessage.id)).where(ChatMessage.chat_id.in_(chat_ids))
    if category:
        statement = statement.where(ChatMessage.category == category)
    statement = statement.group_by(ChatMessage.chat_id)
    with session_getter() as session:
        res = session.exec(statement).all()
        ids = [one[1] for one in res]
        statement = select(ChatMessage).where(ChatMessage.id.in_(ids))
        return session.exec(statement).all()
  • 功能:获取多个聊天会话的最新消息。
  • 参数:
    • chat_ids:聊天会话 ID 列表。
    • category:可选,消息类别过滤。
  • 逻辑:
    1. 构建查询,获取每个 chat_id 的最大 id,即最新的消息 ID。
    2. 如果提供了 category,添加过滤条件。
    3. 执行查询,得到最新消息的 ID 列表。
    4. 根据消息 ID 列表,查询消息的具体内容并返回。
get_messages_by_chat_id
@classmethod
def get_messages_by_chat_id(cls, chat_id: str, category_list: list = None, limit: int = 10):
    with session_getter() as session:
        statement = select(ChatMessage).where(ChatMessage.chat_id == chat_id)
        if category_list:
            statement = statement.where(ChatMessage.category.in_(category_list))
        statement = statement.limit(limit).order_by(ChatMessage.create_time.asc())
        return session.exec(statement).all()
  • 功能:获取指定聊天会话的消息列表。
  • 参数:
    • chat_id:聊天会话 ID。
    • category_list:可选,消息类别列表。
    • limit:限制返回的消息数量,默认为 10
  • 逻辑:
    1. 构建查询,筛选指定 chat_id 的消息。
    2. 如果提供了 category_list,添加类别过滤条件。
    3. 限制返回的消息数量,按照创建时间升序排序。
    4. 执行查询并返回结果。
insert_one
@classmethod
def insert_one(cls, message: ChatMessage) -> ChatMessage:
    with session_getter() as session:
        session.add(message)
        session.commit()
        session.refresh(message)
    return message
  • 功能:插入一条新的消息记录。
  • 逻辑:
    1. 使用数据库会话添加消息对象。
    2. 提交事务。
    3. 刷新消息对象,获取数据库生成的字段(如自增的 id)。
    4. 返回消息对象。
update_message
@classmethod
def update_message(cls, message_id: int, user_id: int, message: str):
    with session_getter() as session:
        statement = update(ChatMessage).where(ChatMessage.id == message_id).where(
            ChatMessage.user_id == user_id).values(message=message)
        session.exec(statement)
        session.commit()
  • 功能:更新指定消息的内容。
  • 参数:
    • message_id:消息 ID。
    • user_id:用户 ID,确保只有消息的所有者才能更新消息。
    • message:新的消息内容。
  • 逻辑:
    1. 构建更新语句,筛选 iduser_id 匹配的消息。
    2. 设置新的消息内容。
    3. 执行更新语句并提交事务。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值