import logging
from datetime import datetime, timedelta, timezone
from pathlib import Path

from fastapi import APIRouter, Depends, Query, HTTPException, Request
from fastapi import Path as PathParam
from fastapi.responses import JSONResponse

from backend.jwt_handler import TokenData, get_current_user_token_data
from backend import database
# Router that groups every endpoint in this module under the /api/search prefix.
router = APIRouter(prefix="/api/search", tags=["搜索推荐"])
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@router.get("")
async def search(
    request: Request,
    keyword: str = Query(..., min_length=1, max_length=100),
    type: str = Query("all", pattern="^(all|share|room|user)$"),
    page: int = Query(1, ge=1),
    size: int = Query(10, ge=1, le=100),
    current_user: TokenData = Depends(get_current_user_token_data)
):
    """Search shares, chat rooms and users by keyword.

    The keyword is first recorded as a search record for the current user,
    then each category selected by ``type`` ("all", "share", "room" or
    "user") is queried with the same ``page``/``size``. Categories that are
    not requested are reported with a total of 0 and an empty item list.

    Returns:
        JSONResponse with ``success`` and ``data.results`` holding
        ``shares``, ``rooms`` and ``users``, each as ``{"total", "items"}``.

    Raises:
        HTTPException: 500 when any database call fails.

    NOTE(review): the ``database`` helpers called here (add_search_record,
    search_shares, search_rooms, search_users) are not defined in the part
    of the database module visible to this review; confirm they exist.
    NOTE(review): JSONResponse serialises with ``json.dumps``; if the rows
    contain ``datetime`` values serialisation will fail; verify the
    helpers return JSON-safe items.
    """
    # request.client is None when the server supplies no peer address.
    client_ip = request.client.host if request.client else "unknown"
    logger.info(
        f"🔍 用户发起搜索 | 用户ID:{current_user.user_id} | 关键词:{keyword} | "
        f"类型:{type} | 分页:第{page}页/每页{size}条 | IP:{client_ip}"
    )
    try:
        # Record the keyword together with a UTC timestamp.
        await database.add_search_record(
            keyword=keyword,
            user_id=current_user.user_id,
            search_time=datetime.now(timezone.utc)
        )
        is_admin = current_user.role == "admin"

        # Each result is a (total, items) pair; unrequested categories stay empty.
        share_results = (0, [])
        if type in ("all", "share"):
            share_results = await database.search_shares(
                keyword=keyword,
                is_public=True,
                author_id=current_user.user_id,
                page=page,
                size=size
            )

        room_results = (0, [])
        if type in ("all", "room"):
            room_results = await database.search_rooms(
                keyword=keyword,
                user_id=current_user.user_id,
                department_id=current_user.department_id,
                is_admin=is_admin,
                page=page,
                size=size
            )

        user_results = (0, [])
        if type in ("all", "user"):
            user_results = await database.search_users(
                keyword=keyword,
                current_user_id=current_user.user_id,
                is_admin=is_admin,
                page=page,
                size=size
            )

        logger.debug(
            f"✅ 搜索完成 | 用户ID:{current_user.user_id} | 分享结果:{share_results[0]}条 | "
            f"聊天室结果:{room_results[0]}条 | 用户结果:{user_results[0]}条"
        )
        return JSONResponse({
            "success": True,
            "data": {
                "keyword": keyword,
                "results": {
                    "shares": {
                        "total": share_results[0],
                        "items": share_results[1]
                    },
                    "rooms": {
                        "total": room_results[0],
                        "items": room_results[1]
                    },
                    "users": {
                        "total": user_results[0],
                        "items": user_results[1]
                    }
                }
            }
        })
    except Exception as e:
        logger.error(
            f"💥 搜索失败 | 用户ID:{current_user.user_id} | 关键词:{keyword} | "
            f"错误:{str(e)}",
            exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail=f"搜索失败:服务器内部错误({str(e)})"
        ) from e
@router.get("/hot")
async def get_hot_keywords(
    request: Request,
    limit: int = Query(10, ge=1, le=50),
    current_user: TokenData = Depends(get_current_user_token_data)
):
    """Return the hottest search keywords of the last 7 days.

    Each entry carries a 1-based ``rank`` (the position in the list returned
    by the database helper), the ``keyword``, its ``search_count`` and a
    ``trend`` that defaults to "stable" when the row has none.

    Raises:
        HTTPException: 500 when the database call or formatting fails.

    NOTE(review): ``database.get_hot_search_keywords`` is not defined in the
    part of the database module visible to this review; confirm it exists.
    """
    # request.client is None when the server supplies no peer address.
    client_ip = request.client.host if request.client else "unknown"
    logger.info(
        f"🔥 用户获取热搜 | 用户ID:{current_user.user_id} | 数量限制:{limit} | IP:{client_ip}"
    )
    try:
        # Window start: 7 days before now (UTC). timedelta lives in the
        # datetime module; `timezone` has no `timedelta` attribute.
        days_ago = datetime.now(timezone.utc) - timedelta(days=7)
        hot_keywords = await database.get_hot_search_keywords(
            start_time=days_ago,
            limit=limit
        )
        # Attach a 1-based rank to every keyword.
        formatted_hot = [
            {
                "rank": rank,
                "keyword": item["keyword"],
                "search_count": item["search_count"],
                "trend": item.get("trend", "stable")  # up / down / stable
            }
            for rank, item in enumerate(hot_keywords, start=1)
        ]
        logger.debug(f"✅ 获取热搜成功 | 数量:{len(formatted_hot)} | IP:{client_ip}")
        return JSONResponse({
            "success": True,
            "data": formatted_hot
        })
    except Exception as e:
        logger.error(
            f"💥 获取热搜失败 | 用户ID:{current_user.user_id} | 错误:{str(e)}",
            exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail=f"获取热搜失败:服务器内部错误({str(e)})"
        ) from e
@router.get("/recommend/shares")
async def recommend_shares(
    request: Request,
    limit: int = Query(5, ge=1, le=20),
    current_user: TokenData = Depends(get_current_user_token_data)
):
    """Recommend shares to the current user.

    Candidates are gathered in three stages, each one filling only what the
    previous stages left open:

    1. shares matching the user's search keywords of the last 14 days
       (at most ``limit // 2`` items);
    2. hot shares of the user's department;
    3. globally hot shares.

    The combined list is then de-duplicated by ``id`` and cut at ``limit``.
    Because de-duplication happens after the stages are filled, fewer than
    ``limit`` items can be returned when stages overlap.

    Raises:
        HTTPException: 500 when any database call fails.

    NOTE(review): the ``database`` helpers used here are not defined in the
    part of the database module visible to this review; confirm they exist.
    """
    # request.client is None when the server supplies no peer address.
    client_ip = request.client.host if request.client else "unknown"
    logger.info(
        f"📊 用户获取分享推荐 | 用户ID:{current_user.user_id} | 数量限制:{limit} | IP:{client_ip}"
    )
    try:
        recommendations = []

        # 1. Recommendations derived from the user's recent search history.
        user_history = await database.get_user_search_history(
            user_id=current_user.user_id,
            limit=5,
            days=14  # only the last 14 days
        )
        history_recommendations = []
        if user_history:
            history_recommendations = await database.recommend_shares_by_keywords(
                keywords=[h["keyword"] for h in user_history],
                limit=limit // 2,  # half of the quota
                exclude_user_id=current_user.user_id,
                department_id=current_user.department_id
            )
            recommendations.extend(history_recommendations)

        # 2. Top up with department-level hot shares.
        remaining = limit - len(recommendations)
        if remaining > 0:
            dept_hot = await database.get_dept_hot_shares(
                dept_id=current_user.department_id,
                limit=remaining,
                exclude_user_id=current_user.user_id
            )
            recommendations.extend(dept_hot)

        # 3. Top up with globally hot shares if still short.
        remaining = limit - len(recommendations)
        if remaining > 0:
            global_hot = await database.get_global_hot_shares(
                limit=remaining,
                exclude_user_id=current_user.user_id
            )
            recommendations.extend(global_hot)

        # Drop duplicates coming from different stages, keeping first occurrence.
        seen_ids = set()
        unique_recommendations = []
        for rec in recommendations:
            if rec["id"] not in seen_ids:
                seen_ids.add(rec["id"])
                unique_recommendations.append(rec)
                if len(unique_recommendations) >= limit:
                    break

        logger.debug(
            f"✅ 分享推荐完成 | 用户ID:{current_user.user_id} | 推荐数量:{len(unique_recommendations)} | "
            f"历史推荐:{len(history_recommendations)} | IP:{client_ip}"
        )
        return JSONResponse({
            "success": True,
            "data": unique_recommendations
        })
    except Exception as e:
        logger.error(
            f"💥 分享推荐失败 | 用户ID:{current_user.user_id} | 错误:{str(e)}",
            exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail=f"获取推荐失败:服务器内部错误({str(e)})"
        ) from e
@router.get("/history")
async def get_search_history(
    request: Request,
    limit: int = Query(10, ge=1, le=50),
    current_user: TokenData = Depends(get_current_user_token_data)
):
    """Return the current user's search history of the last 30 days.

    Each entry contains ``keyword``, ``id`` and ``search_time`` formatted as
    ``YYYY-MM-DD HH:MM:SS``.

    Raises:
        HTTPException: 500 when the database call or formatting fails.

    NOTE(review): ``database.get_user_search_history`` is not defined in the
    part of the database module visible to this review; confirm it exists
    and that ``search_time`` is a datetime.
    """
    # request.client is None when the server supplies no peer address.
    client_ip = request.client.host if request.client else "unknown"
    logger.info(
        f"📜 用户获取搜索历史 | 用户ID:{current_user.user_id} | 数量限制:{limit} | IP:{client_ip}"
    )
    try:
        history = await database.get_user_search_history(
            user_id=current_user.user_id,
            limit=limit,
            days=30  # only the last 30 days
        )
        # Render timestamps as strings so the payload is JSON-serialisable.
        formatted_history = [
            {
                "keyword": item["keyword"],
                "search_time": item["search_time"].strftime("%Y-%m-%d %H:%M:%S"),
                "id": item["id"]
            }
            for item in history
        ]
        logger.debug(f"✅ 获取搜索历史成功 | 数量:{len(formatted_history)} | IP:{client_ip}")
        return JSONResponse({
            "success": True,
            "data": formatted_history
        })
    except Exception as e:
        logger.error(
            f"💥 获取搜索历史失败 | 用户ID:{current_user.user_id} | 错误:{str(e)}",
            exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail=f"获取搜索历史失败:服务器内部错误({str(e)})"
        ) from e
@router.delete("/history/{history_id}")
async def delete_search_history(
    request: Request,
    # PathParam is FastAPI's path-parameter declaration; the bare name `Path`
    # in this module is pathlib.Path, which cannot declare a parameter.
    history_id: int = PathParam(..., ge=1),
    current_user: TokenData = Depends(get_current_user_token_data)
):
    """Delete one search-history record owned by the current user.

    Raises:
        HTTPException: 404 when the record does not exist, 403 when it
            belongs to another user, 500 when a database call fails.

    NOTE(review): ``database.get_search_history_by_id`` and
    ``database.delete_search_history`` are not defined in the part of the
    database module visible to this review; confirm they exist.
    """
    # request.client is None when the server supplies no peer address.
    client_ip = request.client.host if request.client else "unknown"
    logger.info(
        f"🗑️ 用户删除搜索历史 | 用户ID:{current_user.user_id} | 历史ID:{history_id} | IP:{client_ip}"
    )
    try:
        # Ownership check before deleting.
        history_item = await database.get_search_history_by_id(history_id)
        if not history_item:
            raise HTTPException(status_code=404, detail="搜索历史记录不存在")
        if history_item["user_id"] != current_user.user_id:
            raise HTTPException(status_code=403, detail="无权限删除该记录")

        await database.delete_search_history(history_id)
        logger.debug(f"✅ 删除搜索历史成功 | 历史ID:{history_id} | IP:{client_ip}")
        return JSONResponse({
            "success": True,
            "message": "搜索历史已删除"
        })
    except HTTPException:
        # Let the deliberate 404/403 responses through untouched.
        raise
    except Exception as e:
        logger.error(
            f"💥 删除搜索历史失败 | 用户ID:{current_user.user_id} | 历史ID:{history_id} | "
            f"错误:{str(e)}",
            exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail=f"删除搜索历史失败:服务器内部错误({str(e)})"
        ) from e
@router.delete("/history")
async def clear_search_history(
    request: Request,
    current_user: TokenData = Depends(get_current_user_token_data)
):
    """Delete every search-history record of the current user.

    Raises:
        HTTPException: 500 when the database call fails.

    NOTE(review): ``database.clear_user_search_history`` is not defined in
    the part of the database module visible to this review; confirm it exists.
    """
    # request.client is None when the server supplies no peer address.
    client_ip = request.client.host if request.client else "unknown"
    logger.info(
        f"🗑️ 用户清空搜索历史 | 用户ID:{current_user.user_id} | IP:{client_ip}"
    )
    try:
        await database.clear_user_search_history(current_user.user_id)
        logger.debug(f"✅ 清空搜索历史成功 | 用户ID:{current_user.user_id} | IP:{client_ip}")
        return JSONResponse({
            "success": True,
            "message": "所有搜索历史已清空"
        })
    except Exception as e:
        logger.error(
            f"💥 清空搜索历史失败 | 用户ID:{current_user.user_id} | 错误:{str(e)}",
            exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail=f"清空搜索历史失败:服务器内部错误({str(e)})"
        ) from e
search接口中有多个数据库调用与现有的数据库函数不匹配,请帮我找出来,分析哪些可以用现有的函数替代,并给出完整代码。
import re
from typing import Optional, Dict, List, Tuple, AsyncGenerator
from datetime import date, datetime
from fastapi import Depends
from fastapi.concurrency import asynccontextmanager
from sqlalchemy.ext.asyncio import create_async_engine,AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy import text
# -------------------------- Core database configuration --------------------------
# Prefer loading from .env (the original configuration was hard-coded; it now reads an environment variable)
import os
from dotenv import load_dotenv
load_dotenv()
# NOTE(review): the fallback URL embeds a username and password; consider
# failing fast when DATABASE_URL is unset instead of shipping credentials.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "mysql+asyncmy://root:123456@localhost/ai_roleplay?charset=utf8mb4" # default fallback
)
# Async engine configuration (connection-pool parameters tuned)
engine = create_async_engine(
    DATABASE_URL,
    echo=False, # keep False in production to avoid noisy logs
    pool_pre_ping=True, # validate a connection before use to guard against stale connections
    pool_size=10, # number of persistent connections
    max_overflow=20, # maximum number of temporary connections
    pool_recycle=3600 # recycle connections after one hour
)
# Async session factory bound to the engine above
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False, # objects stay usable after commit
    autoflush=False # disable autoflush to avoid unnecessary SQL
)
# -------------------------- 通用工具函数 --------------------------
def get_conversation_table_name(user_id: str) -> str:
    """Build the name of a user's dedicated conversation table.

    ``user_id`` is converted with ``str()``; every character that is neither
    alphanumeric nor an underscore is dropped, and the remainder is appended
    to the ``conversations_`` prefix. The filtering is meant as protection
    against SQL injection through the table name.
    """
    kept_chars = [ch for ch in str(user_id) if ch == "_" or ch.isalnum()]
    return "conversations_" + "".join(kept_chars)
def is_valid_table_name(table_name: str) -> bool:
    """Return True only for names of the exact form ``conversations_<suffix>``.

    The suffix must be one or more ASCII letters, digits or underscores.
    ``re.fullmatch`` is used because ``$`` in ``re.match`` also matches just
    before a trailing newline, which would let ``"conversations_x\\n"`` pass.
    """
    return re.fullmatch(r'conversations_[a-zA-Z0-9_]+', table_name) is not None
@asynccontextmanager
async def get_default_db() -> AsyncGenerator[AsyncSession,None]:
    """Async context manager that provides a session with a managed lifecycle.

    A fresh session is created from ``AsyncSessionLocal``. When the managed
    block finishes without an exception the session is committed; on any
    ``Exception`` it is rolled back and the exception is re-raised. The
    session is closed in every case.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            # Reached only when the consumer did not raise.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
async def get_default_db_instance() -> AsyncGenerator[AsyncSession, None]:
    """Yield a managed session, for use as a FastAPI ``Depends`` target.

    The session comes from ``get_default_db()``, so it is committed when the
    consumer finishes normally, rolled back on error and always closed.

    The previous body called ``anext()`` on the object returned by
    ``get_default_db()``. That object is an async context manager, not an
    async iterator, so the call raised ``TypeError`` every time.

    NOTE(review): the functions in this module declare
    ``db = Depends(get_default_db_instance)`` as a default value. FastAPI
    only resolves ``Depends`` for path operations and their dependencies;
    when these functions are called directly without ``db``, the parameter
    is the ``Depends`` marker itself. Confirm that callers pass a session.
    """
    async with get_default_db() as session:
        yield session
# -------------------------- 用户基础操作 --------------------------
async def get_user_by_account(
    account: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Dict]:
    """Fetch a user row by account name (used for duplicate checks and login).

    Returns the row as a dict (including ``password``), or None when no user
    has that account.
    """
    query = text("""
SELECT id, account, password, role, department_id, created_at
FROM users
WHERE account = :account
""")
    outcome = await db.execute(query, {"account": account})
    record = outcome.fetchone()
    if not record:
        return None
    return dict(record._mapping)
async def get_user_by_id(
    user_id: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Dict]:
    """Fetch a user row by id (used for role changes and permission checks).

    The password column is not selected. Returns None when no row matches.
    """
    query = text("""
SELECT id, account, role, department_id, created_at
FROM users
WHERE id = :user_id
""")
    outcome = await db.execute(query, {"user_id": user_id})
    record = outcome.fetchone()
    if not record:
        return None
    return dict(record._mapping)
async def get_user_detail(
    user_id: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Dict]:
    """Fetch a user together with the department name (for the profile page).

    ``dept_name`` comes from a LEFT JOIN on ``departments``, so it is NULL
    for users without a matching department. Returns None when no user matches.
    """
    query = text("""
SELECT u.id, u.account, u.role, u.department_id, u.created_at, d.name AS dept_name
FROM users u
LEFT JOIN departments d ON u.department_id = d.id
WHERE u.id = :user_id
""")
    outcome = await db.execute(query, {"user_id": user_id})
    record = outcome.fetchone()
    if not record:
        return None
    return dict(record._mapping)
# -------------------------- 用户创建与更新 --------------------------
async def create_user(
    account: str,
    password: str,
    role: str = "user",
    department_id: Optional[int] = None,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Dict:
    """Create a user (registration) and the user's dedicated conversation table.

    Args:
        account: Account name to insert.
        password: Value stored in the ``password`` column as given.
        role: Role to store; defaults to "user".
        department_id: Optional department id.
        db: Session used for all statements.

    Returns:
        The freshly inserted user as returned by ``get_user_by_id``.

    Raises:
        ValueError: If the generated conversation table name fails validation.
    """
    # 1. Insert the user row.
    result = await db.execute(
        text("""
INSERT INTO users (account, password, role, department_id, created_at)
VALUES (:account, :password, :role, :dept_id, NOW())
"""),
        {
            "account": account,
            "password": password,
            "role": role,
            "dept_id": department_id
        }
    )
    user_id = result.lastrowid  # auto-increment id of the new row

    # 2. Create the per-user conversation table. The name is interpolated into
    #    the SQL text, so it is validated first.
    table_name = get_conversation_table_name(user_id)
    if not is_valid_table_name(table_name):
        raise ValueError(f"Invalid user ID for conversation table: {user_id}")
    await db.execute(text(f"""
CREATE TABLE IF NOT EXISTS `{table_name}` (
id INT AUTO_INCREMENT PRIMARY KEY,
character_id INT NOT NULL, # 关联 AI 角色表
user_message TEXT NOT NULL, # 用户消息
ai_message TEXT NOT NULL, # AI 回复
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (character_id) REFERENCES characters(id) ON DELETE CASCADE
) ENGINE=InnoDB CHARSET=utf8mb4;
"""))

    # 3. Return the stored user. get_user_by_id takes (user_id, db); the
    #    session is passed by keyword so the two cannot be swapped.
    return await get_user_by_id(user_id, db=db)
async def update_user(
    user_id: str,
    update_params: Dict,  # updatable keys: password, role, department_id
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Update selected columns of a user.

    Only ``password``, ``role`` and ``department_id`` are written to the SET
    clause; other keys in ``update_params`` never become column names.
    Nothing is executed when ``update_params`` is empty or contains no
    updatable key.
    """
    if not update_params:
        return

    updatable = ("password", "role", "department_id")
    assignments = [f"{field} = :{field}" for field in update_params if field in updatable]
    if not assignments:
        return

    set_clause = ", ".join(assignments)
    bind_values = dict(update_params)
    bind_values["user_id"] = user_id
    statement = text(f"UPDATE users SET {set_clause} WHERE id = :user_id")
    await db.execute(statement, bind_values)
# -------------------------- 用户列表与统计 --------------------------
async def get_users_list(
    page: int = 1,
    size: int = 10,
    role: Optional[str] = None,
    dept_id: Optional[int] = None,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Tuple[int, List[Dict]]:
    """Return one page of users, optionally filtered by role and department.

    Returns:
        ``(total, users)`` where ``total`` counts every row matching the
        filters and ``users`` is the requested page, newest first.
    """
    bind_values = {"offset": (page - 1) * size, "limit": size}
    filters = []
    if role:
        filters.append("role = :role")
        bind_values["role"] = role
    if dept_id is not None:
        filters.append("department_id = :dept_id")
        bind_values["dept_id"] = dept_id

    where_sql = ""
    if filters:
        where_sql = "WHERE " + " AND ".join(filters)

    # Total number of matching rows, for pagination.
    count_stmt = text(f"SELECT COUNT(*) AS total FROM users {where_sql}")
    total = (await db.execute(count_stmt, bind_values)).scalar()

    # The requested page.
    page_stmt = text(f"""
SELECT id, account, role, department_id, created_at
FROM users
{where_sql}
ORDER BY created_at DESC
LIMIT :offset, :limit
""")
    page_rows = (await db.execute(page_stmt, bind_values)).fetchall()
    users = [dict(record._mapping) for record in page_rows]
    return total, users
async def get_user_count_by_dept(
    dept_id: int,
    db: AsyncSession = Depends(get_default_db_instance)
) -> int:
    """Count the users that belong to a department (checked before deleting it)."""
    count_stmt = text("SELECT COUNT(*) FROM users WHERE department_id = :dept_id")
    outcome = await db.execute(count_stmt, {"dept_id": dept_id})
    return outcome.scalar()
# -------------------------- 原登录校验修正 --------------------------
async def check_users(
    account: str,
    password: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Tuple[str, str]]:
    """Look up the id and stored password for an account.

    Returns ``(str(id), stored_password)`` or None when the account does not
    exist. The ``password`` argument is not used by the query; comparing it
    against the stored value is left to the caller.
    """
    lookup = await db.execute(
        text("SELECT id, password FROM users WHERE account = :account"),
        {"account": account}
    )
    record = lookup.fetchone()
    if not record:
        return None
    return str(record.id), record.password
# -------------------------- 院系基础操作 --------------------------
async def create_department(
    name: str,
    description: Optional[str] = None,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Dict:
    """Create a department (admin only) and return the stored row.

    Args:
        name: Department name.
        description: Optional description.
        db: Session used for the insert and the follow-up lookup.

    Returns:
        The new department as returned by ``get_department_by_id``.
    """
    result = await db.execute(
        text("""
INSERT INTO departments (name, description, created_at)
VALUES (:name, :desc, NOW())
"""),
        {"name": name, "desc": description}
    )
    dept_id = result.lastrowid
    # get_department_by_id takes (dept_id, db); pass the session by keyword
    # so the two arguments cannot be swapped.
    return await get_department_by_id(dept_id, db=db)
async def get_department_by_id(
    dept_id: int,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Dict]:
    """Fetch a department by id; returns None when it does not exist."""
    query = text("SELECT id, name, description, created_at FROM departments WHERE id = :dept_id")
    outcome = await db.execute(query, {"dept_id": dept_id})
    record = outcome.fetchone()
    if not record:
        return None
    return dict(record._mapping)
async def get_department_by_name(
    name: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Dict]:
    """Fetch a department's id and name by name (duplicate check on creation)."""
    query = text("SELECT id, name FROM departments WHERE name = :name")
    outcome = await db.execute(query, {"name": name})
    record = outcome.fetchone()
    if not record:
        return None
    return dict(record._mapping)
# -------------------------- 院系列表与统计 --------------------------
async def get_departments_with_user_count(
    db: AsyncSession = Depends(get_default_db_instance)
) -> List[Dict]:
    """List every department with the number of users assigned to it.

    Departments without users are included with ``user_count`` 0 (LEFT JOIN).
    Rows are ordered by creation time, newest first.
    """
    query = text("""
SELECT
d.id, d.name, d.description, d.created_at,
COUNT(u.id) AS user_count
FROM departments d
LEFT JOIN users u ON d.id = u.department_id
GROUP BY d.id
ORDER BY d.created_at DESC
""")
    outcome = await db.execute(query)
    departments = []
    for record in outcome.fetchall():
        departments.append(dict(record._mapping))
    return departments
# -------------------------- 院系更新与删除 --------------------------
async def update_department(
    dept_id: int,
    update_params: Dict,  # updatable keys: name, description
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Update a department's name and/or description (admin only).

    Keys other than ``name`` and ``description`` never become column names.
    Nothing is executed when no updatable key is present.
    """
    updatable = ("name", "description")
    assignments = [f"{field} = :{field}" for field in update_params if field in updatable]
    if not assignments:
        return

    set_clause = ", ".join(assignments)
    bind_values = dict(update_params)
    bind_values["dept_id"] = dept_id
    statement = text(f"UPDATE departments SET {set_clause} WHERE id = :dept_id")
    await db.execute(statement, bind_values)
async def delete_department(
    dept_id: int,
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Delete a department row.

    This function does not check for users still assigned to the department;
    that check is left to the caller.
    """
    statement = text("DELETE FROM departments WHERE id = :dept_id")
    await db.execute(statement, {"dept_id": dept_id})
# -------------------------- 院系专属资源查询 --------------------------
async def get_dept_exclusive_rooms(
    dept_id: int,
    db: AsyncSession = Depends(get_default_db_instance)
) -> List[Dict]:
    """List the chat rooms of type 'dept' that belong to a department, newest first."""
    query = text("""
SELECT id, name, description, creator_id, created_at
FROM rooms
WHERE type = 'dept' AND dept_id = :dept_id
ORDER BY created_at DESC
""")
    outcome = await db.execute(query, {"dept_id": dept_id})
    rooms = []
    for record in outcome.fetchall():
        rooms.append(dict(record._mapping))
    return rooms
async def get_dept_exclusive_shares(
    dept_id: int,
    page: int = 1,
    size: int = 10,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Tuple[int, List[Dict]]:
    """Return one page of 'dept' shares written by users of a department.

    Returns:
        ``(total, shares)``; each share carries the author's account as
        ``author_account`` and the page is ordered newest first.
    """
    # Total number of matching shares.
    count_stmt = text("""
SELECT COUNT(*) AS total
FROM shares s
JOIN users u ON s.author_id = u.id
WHERE s.type = 'dept' AND u.department_id = :dept_id
""")
    total = (await db.execute(count_stmt, {"dept_id": dept_id})).scalar()

    # The requested page.
    page_stmt = text("""
SELECT s.*, u.account AS author_account
FROM shares s
JOIN users u ON s.author_id = u.id
WHERE s.type = 'dept' AND u.department_id = :dept_id
ORDER BY s.created_at DESC
LIMIT :offset, :limit
""")
    page_values = {
        "dept_id": dept_id,
        "offset": (page - 1) * size,
        "limit": size
    }
    page_rows = (await db.execute(page_stmt, page_values)).fetchall()
    shares = [dict(record._mapping) for record in page_rows]
    return total, shares
# -------------------------- AI角色操作 --------------------------
async def get_all_characters(
    db: AsyncSession = Depends(get_default_db_instance)
) -> List[Dict]:
    """List every AI character (id, name, trait, avatar_url), ordered by name."""
    query = text("SELECT id, name, trait, avatar_url FROM characters ORDER BY name ASC")
    outcome = await db.execute(query)
    characters = []
    for record in outcome.fetchall():
        characters.append(dict(record._mapping))
    return characters
async def get_character_by_id(
    character_id: int,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Dict]:
    """Fetch an AI character's id, name and trait; None when it does not exist."""
    query = text("SELECT id, name, trait FROM characters WHERE id = :character_id")
    outcome = await db.execute(query, {"character_id": character_id})
    record = outcome.fetchone()
    if not record:
        return None
    return dict(record._mapping)
# -------------------------- 对话历史操作 --------------------------
async def save_conversation(
    user_id: int,
    character_id: int,
    user_message: str,
    ai_message: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Store one user/AI exchange in the user's conversation table.

    Raises:
        ValueError: If the table name derived from ``user_id`` is invalid.
    """
    table_name = get_conversation_table_name(user_id)
    if not is_valid_table_name(table_name):
        raise ValueError(f"Invalid user ID: {user_id}")

    # The table name cannot be a bind parameter, hence the validation above.
    insert_stmt = text(f"""
INSERT INTO `{table_name}`
(character_id, user_message, ai_message)
VALUES (:char_id, :user_msg, :ai_msg)
""")
    row_values = {
        "char_id": character_id,
        "user_msg": user_message,
        "ai_msg": ai_message
    }
    await db.execute(insert_stmt, row_values)
async def load_conversation_history(
    user_id: str,
    character_id: Optional[int] = None,
    max_count: int = 10,
    db: AsyncSession = Depends(get_default_db_instance)
) -> List[Dict]:
    """Load the most recent exchanges of a user, oldest first.

    Args:
        user_id: Owner of the conversation table.
        character_id: When truthy, only exchanges with this character are returned.
        max_count: Maximum number of exchanges.
        db: Session used for the query.

    Returns:
        A list of ``{"user", "ai", "time"}`` dicts in chronological order;
        an empty list when the derived table name is invalid.
    """
    table_name = get_conversation_table_name(user_id)
    if not is_valid_table_name(table_name):
        return []

    bind_values = {"limit": max_count}
    where_clause = ""
    if character_id:
        where_clause = "WHERE character_id = :char_id"
        bind_values["char_id"] = character_id

    # Fetch the newest rows first so LIMIT keeps the most recent ones.
    query = text(f"""
SELECT user_message, ai_message, timestamp
FROM `{table_name}`
{where_clause}
ORDER BY timestamp DESC
LIMIT :limit
""")
    newest_first = (await db.execute(query, bind_values)).fetchall()

    # Flip back to chronological order while formatting.
    history = []
    for record in reversed(newest_first):
        history.append({
            "user": record.user_message,
            "ai": record.ai_message,
            "time": record.timestamp.strftime("%Y-%m-%d %H:%M:%S")
        })
    return history
# -------------------------- 用户个性化设定 --------------------------
async def get_user_profile(
    user_id: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Dict]:
    """Fetch a user's ``personality`` and ``role_setting``; None when no profile exists."""
    query = text("""
SELECT personality, role_setting
FROM user_profiles
WHERE user_id = :user_id
""")
    outcome = await db.execute(query, {"user_id": user_id})
    record = outcome.fetchone()
    if not record:
        return None
    return dict(record._mapping)
async def create_or_update_user_profile(
    user_id: str,
    personality: str,
    role_setting: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> bool:
    """Insert a user's profile, or update it when one already exists.

    Both text values are stripped of surrounding whitespace before storage.
    Always returns True once the statement has executed.
    """
    upsert_stmt = text("""
INSERT INTO user_profiles (user_id, personality, role_setting, updated_at)
VALUES (:user_id, :personality, :role_setting, NOW())
ON DUPLICATE KEY UPDATE
personality = VALUES(personality),
role_setting = VALUES(role_setting),
updated_at = NOW()
""")
    profile_values = {
        "user_id": user_id,
        "personality": personality.strip(),
        "role_setting": role_setting.strip()
    }
    await db.execute(upsert_stmt, profile_values)
    return True
# -------------------------- 聊天室基础操作 --------------------------
async def create_room(
    name: str,
    type: str,  # one of: public, dept, ai
    creator_id: str,
    dept_id: Optional[int] = None,
    ai_character_id: Optional[int] = None,
    description: Optional[str] = None,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Dict:
    """Create a chat room and return the stored row.

    Args:
        name: Room name.
        type: Room type ("public", "dept" or "ai").
        creator_id: Id of the creating user.
        dept_id: Department id, stored as given.
        ai_character_id: AI character id, stored as given.
        description: Optional description.
        db: Session used for the insert and the follow-up lookup.

    Returns:
        The new room as returned by ``get_room_by_id``.
    """
    result = await db.execute(
        text("""
INSERT INTO rooms (
name, type, dept_id, ai_character_id,
description, creator_id, created_at
)
VALUES (:name, :type, :dept_id, :ai_char_id, :desc, :creator_id, NOW())
"""),
        {
            "name": name,
            "type": type,
            "dept_id": dept_id,
            "ai_char_id": ai_character_id,
            "desc": description,
            "creator_id": creator_id
        }
    )
    room_id = result.lastrowid
    # get_room_by_id takes (room_id, db); pass the session by keyword so the
    # two arguments cannot be swapped.
    return await get_room_by_id(room_id, db=db)
async def get_room_by_id(
    room_id: int,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Dict]:
    """Fetch a chat room with its creator's account and AI character name.

    ``ai_char_name`` comes from a LEFT JOIN, so it is NULL when the room has
    no matching character. Returns None when the room does not exist.
    """
    query = text("""
SELECT
r.*, u.account AS creator_account,
c.name AS ai_char_name # 关联AI角色名称(若为AI专属)
FROM rooms r
JOIN users u ON r.creator_id = u.id
LEFT JOIN characters c ON r.ai_character_id = c.id
WHERE r.id = :room_id
""")
    outcome = await db.execute(query, {"room_id": room_id})
    record = outcome.fetchone()
    if not record:
        return None
    return dict(record._mapping)
# -------------------------- 聊天室列表查询 --------------------------
async def get_rooms(
    type: Optional[str] = None,
    dept_id: Optional[int] = None,
    page: int = 1,
    size: int = 10,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Tuple[int, List[Dict]]:
    """Return one page of chat rooms, optionally filtered by type and department.

    Returns:
        ``(total, rooms)``; each room carries ``creator_account`` and a
        ``member_count`` computed from ``room_members``. Newest rooms first.
    """
    bind_values = {"offset": (page - 1) * size, "limit": size}
    filters = []
    if type:
        filters.append("r.type = :type")
        bind_values["type"] = type
    if dept_id is not None:
        filters.append("r.dept_id = :dept_id")
        bind_values["dept_id"] = dept_id

    where_sql = ""
    if filters:
        where_sql = "WHERE " + " AND ".join(filters)

    # Total number of matching rooms.
    count_stmt = text(f"""
SELECT COUNT(*) AS total
FROM rooms r
{where_sql}
""")
    total = (await db.execute(count_stmt, bind_values)).scalar()

    # The requested page, with member counts.
    page_stmt = text(f"""
SELECT
r.id, r.name, r.type, r.description,
r.creator_id, u.account AS creator_account,
r.created_at, COUNT(rm.user_id) AS member_count
FROM rooms r
JOIN users u ON r.creator_id = u.id
LEFT JOIN room_members rm ON r.id = rm.room_id
{where_sql}
GROUP BY r.id
ORDER BY r.created_at DESC
LIMIT :offset, :limit
""")
    page_rows = (await db.execute(page_stmt, bind_values)).fetchall()
    rooms = [dict(record._mapping) for record in page_rows]
    return total, rooms
# -------------------------- 聊天室成员管理 --------------------------
async def check_room_member(
    room_id: int,
    user_id: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> bool:
    """Return True when the user has a membership row for the room."""
    query = text("""
SELECT 1 FROM room_members
WHERE room_id = :room_id AND user_id = :user_id
""")
    outcome = await db.execute(query, {"room_id": room_id, "user_id": user_id})
    membership = outcome.scalar()
    return membership is not None
async def add_room_member(
    room_id: int,
    user_id: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Add a user to a chat room (INSERT IGNORE, with the current time as join time)."""
    insert_stmt = text("""
INSERT IGNORE INTO room_members (room_id, user_id, joined_at)
VALUES (:room_id, :user_id, NOW())
""")
    await db.execute(insert_stmt, {"room_id": room_id, "user_id": user_id})
async def remove_room_member(
    room_id: int,
    user_id: str,
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Remove a user's membership row from a chat room."""
    delete_stmt = text("""
DELETE FROM room_members
WHERE room_id = :room_id AND user_id = :user_id
""")
    await db.execute(delete_stmt, {"room_id": room_id, "user_id": user_id})
# -------------------------- 聊天室消息操作 --------------------------
async def create_room_message(
    room_id: int,
    sender_id: str,
    content: str,
    sent_at: Optional[datetime] = None,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Dict:
    """Store a chat-room message and return it with the sender's account.

    When ``sent_at`` is not given, ``datetime.now()`` is used.

    Returns:
        Dict with ``id``, ``content``, ``sent_at`` and ``sender_account``.
    """
    if not sent_at:
        sent_at = datetime.now()

    insert_stmt = text("""
INSERT INTO room_messages (
room_id, sender_id, content, sent_at
)
VALUES (:room_id, :sender_id, :content, :sent_at)
""")
    message_values = {
        "room_id": room_id,
        "sender_id": sender_id,
        "content": content,
        "sent_at": sent_at
    }
    inserted = await db.execute(insert_stmt, message_values)
    msg_id = inserted.lastrowid

    # Read the message back together with the sender's account.
    detail_stmt = text("""
SELECT
rm.id, rm.content, rm.sent_at,
u.account AS sender_account
FROM room_messages rm
JOIN users u ON rm.sender_id = u.id
WHERE rm.id = :msg_id
""")
    detail = await db.execute(detail_stmt, {"msg_id": msg_id})
    stored = detail.fetchone()
    return dict(stored._mapping)
async def get_room_messages(
    room_id: int,
    page: int = 1,
    size: int = 20,
    order_by: str = "sent_at DESC",
    db: AsyncSession = Depends(get_default_db_instance)
) -> Tuple[int, List[Dict]]:
    """Return one page of a chat room's messages.

    ``order_by`` is interpolated into the SQL, so only "sent_at ASC" and
    "sent_at DESC" are accepted; any other value falls back to "sent_at DESC".

    Returns:
        ``(total, messages)`` where ``total`` counts all messages of the room.
    """
    # Total number of messages in the room.
    count_stmt = text("SELECT COUNT(*) AS total FROM room_messages WHERE room_id = :room_id")
    total = (await db.execute(count_stmt, {"room_id": room_id})).scalar()

    # Whitelist the ordering to prevent injection through ORDER BY.
    order_sql = "sent_at DESC"
    if order_by in ("sent_at ASC", "sent_at DESC"):
        order_sql = order_by

    page_stmt = text(f"""
SELECT
rm.id, rm.content, rm.sent_at,
u.id AS sender_id, u.account AS sender_account
FROM room_messages rm
JOIN users u ON rm.sender_id = u.id
WHERE rm.room_id = :room_id
ORDER BY {order_sql}
LIMIT :offset, :limit
""")
    page_values = {
        "room_id": room_id,
        "offset": (page - 1) * size,
        "limit": size
    }
    page_rows = (await db.execute(page_stmt, page_values)).fetchall()
    messages = [dict(record._mapping) for record in page_rows]
    return total, messages
# -------------------------- 聊天室删除 --------------------------
async def delete_room(
    room_id: int,
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Delete a chat room together with its members and messages.

    The dependent rows are removed first (members, then messages) and the
    room row last.
    """
    ordered_deletes = (
        "DELETE FROM room_members WHERE room_id = :room_id",
        "DELETE FROM room_messages WHERE room_id = :room_id",
        "DELETE FROM rooms WHERE id = :room_id",
    )
    for delete_sql in ordered_deletes:
        await db.execute(text(delete_sql), {"room_id": room_id})
# -------------------------- 分享基础操作 --------------------------
async def create_share(
    title: str,
    content: str,
    author_id: str,
    is_public: bool = True,
    type: str = "public",  # one of: public, private, dept
    ai_character_id: Optional[int] = None,
    created_at: Optional[datetime] = None,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Dict:
    """Publish a share and return the stored row.

    View, like and comment counters start at 0. When ``created_at`` is not
    given, ``datetime.now()`` is used.

    Args:
        title: Share title.
        content: Share body.
        author_id: Id of the author.
        is_public: Public flag stored with the share.
        type: Share type ("public", "private" or "dept").
        ai_character_id: Optional related AI character.
        created_at: Optional creation time.
        db: Session used for the insert and the follow-up lookup.

    Returns:
        The new share as returned by ``get_share_by_id``.
    """
    created_at = created_at or datetime.now()
    result = await db.execute(
        text("""
INSERT INTO shares (
title, content, author_id, is_public, type,
ai_character_id, view_count, like_count, comment_count,
created_at
)
VALUES (
:title, :content, :author_id, :is_public, :type,
:ai_char_id, 0, 0, 0, :created_at
)
"""),
        {
            "title": title,
            "content": content,
            "author_id": author_id,
            "is_public": is_public,
            "type": type,
            "ai_char_id": ai_character_id,
            "created_at": created_at
        }
    )
    share_id = result.lastrowid
    # get_share_by_id takes (share_id, db); pass the session by keyword so
    # the two arguments cannot be swapped.
    return await get_share_by_id(share_id, db=db)
async def get_share_by_id(
    share_id: int,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Optional[Dict]:
    """Fetch a share with its author's account/department and AI character name.

    ``ai_char_name`` comes from a LEFT JOIN, so it is NULL when the share has
    no matching character. Returns None when the share does not exist.
    """
    query = text("""
SELECT
s.*, u.account AS author_account, u.department_id,
c.name AS ai_char_name # 关联AI角色名称
FROM shares s
JOIN users u ON s.author_id = u.id
LEFT JOIN characters c ON s.ai_character_id = c.id
WHERE s.id = :share_id
""")
    outcome = await db.execute(query, {"share_id": share_id})
    record = outcome.fetchone()
    if not record:
        return None
    return dict(record._mapping)
# -------------------------- 分享列表查询 --------------------------
async def get_shares(
    is_public: Optional[bool] = None,
    author_id: Optional[str] = None,
    type: Optional[str] = None,
    order_by: str = "created_at DESC",
    page: int = 1,
    size: int = 10,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Tuple[int, List[Dict]]:
    """Return one page of shares, filtered by visibility, author and type.

    All active filters are combined with AND. ``order_by`` is interpolated
    into the SQL, so only a fixed set of orderings is accepted; anything
    else falls back to "created_at DESC".

    Returns:
        ``(total, shares)``; each share carries ``author_account`` and
        ``ai_char_name``.
    """
    bind_values = {"offset": (page - 1) * size, "limit": size}
    filters = []
    if is_public is not None:
        filters.append("s.is_public = :is_public")
        bind_values["is_public"] = is_public
    if author_id:
        filters.append("s.author_id = :author_id")
        bind_values["author_id"] = author_id
    if type:
        filters.append("s.type = :type")
        bind_values["type"] = type

    where_sql = ""
    if filters:
        where_sql = "WHERE " + " AND ".join(filters)

    # Total number of matching shares.
    count_stmt = text(f"SELECT COUNT(*) AS total FROM shares s {where_sql}")
    total = (await db.execute(count_stmt, bind_values)).scalar()

    # Whitelist the ordering to prevent injection through ORDER BY.
    allowed_orderings = ("created_at DESC", "created_at ASC", "like_count DESC", "view_count DESC")
    order_sql = "created_at DESC"
    if order_by in allowed_orderings:
        order_sql = order_by

    page_stmt = text(f"""
SELECT
s.*, u.account AS author_account,
c.name AS ai_char_name
FROM shares s
JOIN users u ON s.author_id = u.id
LEFT JOIN characters c ON s.ai_character_id = c.id
{where_sql}
ORDER BY {order_sql}
LIMIT :offset, :limit
""")
    page_rows = (await db.execute(page_stmt, bind_values)).fetchall()
    shares = [dict(record._mapping) for record in page_rows]
    return total, shares
# -------------------------- 分享更新与删除 --------------------------
async def update_share(
    share_id: int,
    update_params: Dict,  # updatable keys: title, content, is_public, type, ai_character_id, counters
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Update selected columns of a share.

    Only whitelisted keys are written to the SET clause; other keys never
    become column names. Nothing is executed when no updatable key is present.
    """
    updatable = (
        "title", "content", "is_public", "type",
        "ai_character_id", "view_count", "like_count", "comment_count"
    )
    assignments = [f"{field} = :{field}" for field in update_params if field in updatable]
    if not assignments:
        return

    set_clause = ", ".join(assignments)
    bind_values = dict(update_params)
    bind_values["share_id"] = share_id
    statement = text(f"UPDATE shares SET {set_clause} WHERE id = :share_id")
    await db.execute(statement, bind_values)
async def delete_share(
    share_id: int,
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Delete a share together with its likes and comments.

    The dependent rows are removed first (likes, then comments) and the
    share row last.
    """
    ordered_deletes = (
        "DELETE FROM share_likes WHERE share_id = :share_id",
        "DELETE FROM comments WHERE share_id = :share_id",
        "DELETE FROM shares WHERE id = :share_id",
    )
    for delete_sql in ordered_deletes:
        await db.execute(text(delete_sql), {"share_id": share_id})
# -------------------------- 评论操作 --------------------------
async def create_comment(
    share_id: int,
    commenter_id: str,
    content: str,
    parent_id: Optional[int] = None,
    created_at: Optional[datetime] = None,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Dict:
    """Store a comment (optionally a reply) and return it with the commenter's account.

    When ``created_at`` is not given, ``datetime.now()`` is used.

    Returns:
        Dict with ``id``, ``content``, ``parent_id``, ``created_at`` and
        ``commenter_account``.
    """
    if not created_at:
        created_at = datetime.now()

    insert_stmt = text("""
INSERT INTO comments (
share_id, commenter_id, parent_id, content, created_at
)
VALUES (:share_id, :commenter_id, :parent_id, :content, :created_at)
""")
    comment_values = {
        "share_id": share_id,
        "commenter_id": commenter_id,
        "parent_id": parent_id,
        "content": content,
        "created_at": created_at
    }
    inserted = await db.execute(insert_stmt, comment_values)
    comment_id = inserted.lastrowid

    # Read the comment back together with the commenter's account.
    detail_stmt = text("""
SELECT
c.id, c.content, c.parent_id, c.created_at,
u.account AS commenter_account
FROM comments c
JOIN users u ON c.commenter_id = u.id
WHERE c.id = :comment_id
""")
    detail = await db.execute(detail_stmt, {"comment_id": comment_id})
    stored = detail.fetchone()
    return dict(stored._mapping)
async def get_share_comments(
    share_id: int,
    page: int = 1,
    size: int = 20,
    order_by: str = "created_at DESC",
    db: AsyncSession = Depends(get_default_db_instance)
) -> Tuple[int, List[Dict]]:
    """Return one page of top-level comments for a share, with direct replies.

    Pagination applies to top-level comments (``parent_id IS NULL``). Each
    returned comment gets a ``children`` list holding the comments whose
    ``parent_id`` equals its id; deeper reply levels are not fetched.

    Note that ``total`` counts every comment of the share, replies included,
    not just the top-level ones being paged.

    Args:
        share_id: Share whose comments are listed.
        page: 1-based page number.
        size: Number of top-level comments per page.
        order_by: ``"created_at ASC"`` or ``"created_at DESC"``; any other
            value falls back to ``"created_at DESC"``.
        db: Async session used to run the queries (declared with a FastAPI
            ``Depends`` default).

    Returns:
        ``(total, parent_comments)``.
    """
    # Imported locally so this function does not depend on the module's
    # top-level import list.
    from sqlalchemy import bindparam

    # 1. Count all comments of the share.
    total_result = await db.execute(
        text("SELECT COUNT(*) AS total FROM comments WHERE share_id = :share_id"),
        {"share_id": share_id}
    )
    total = total_result.scalar()

    # 2. Map the caller's ordering onto a fixed, table-qualified expression.
    #    Only these literals are ever interpolated into SQL, and qualifying
    #    with "c." keeps the column unambiguous in the join with users.
    order_map = {
        "created_at ASC": "c.created_at ASC",
        "created_at DESC": "c.created_at DESC",
    }
    order_sql = order_map.get(order_by, "c.created_at DESC")

    # 3. Page of top-level comments.
    parent_result = await db.execute(
        text(f"""
SELECT
c.id, c.content, c.created_at,
u.id AS commenter_id, u.account AS commenter_account
FROM comments c
JOIN users u ON c.commenter_id = u.id
WHERE c.share_id = :share_id AND c.parent_id IS NULL
ORDER BY {order_sql}
LIMIT :offset, :limit
"""),
        {
            "share_id": share_id,
            "offset": (page - 1) * size,
            "limit": size
        }
    )
    parent_comments = [dict(row._mapping) for row in parent_result.fetchall()]
    parent_ids = [comm["id"] for comm in parent_comments]

    # 4. Direct replies to the parents on this page. An expanding bind
    #    parameter lets SQLAlchemy render the IN list itself rather than
    #    relying on driver-specific escaping of a tuple value.
    child_map: Dict = {}
    if parent_ids:
        child_stmt = text(f"""
SELECT
c.id, c.content, c.parent_id, c.created_at,
u.id AS commenter_id, u.account AS commenter_account
FROM comments c
JOIN users u ON c.commenter_id = u.id
WHERE c.share_id = :share_id AND c.parent_id IN :parent_ids
ORDER BY {order_sql}
""").bindparams(bindparam("parent_ids", expanding=True))
        child_result = await db.execute(
            child_stmt,
            {"share_id": share_id, "parent_ids": parent_ids}
        )
        # Group replies under their parent id, keeping query order.
        for row in child_result.fetchall():
            child = dict(row._mapping)
            child_map.setdefault(child["parent_id"], []).append(child)

    # 5. Attach the reply list (possibly empty) to each parent.
    for comm in parent_comments:
        comm["children"] = child_map.get(comm["id"], [])
    return total, parent_comments
async def get_comment_by_id(comment_id: int, db: AsyncSession = Depends(get_default_db_instance)) -> Optional[Dict]:
    """Look up a comment by primary key (e.g. to check a parent comment exists).

    Args:
        comment_id: ID of the comment to fetch.
        db: Async session used to run the query (declared with a FastAPI
            ``Depends`` default).

    Returns:
        Dict with ``id``, ``share_id``, ``commenter_id`` and ``parent_id``,
        or ``None`` when no row matches.
    """
    lookup_stmt = text("""
SELECT id, share_id, commenter_id, parent_id
FROM comments WHERE id = :comment_id
""")
    lookup_result = await db.execute(lookup_stmt, {"comment_id": comment_id})
    found = lookup_result.fetchone()
    if not found:
        return None
    return dict(found._mapping)
# -------------------------- 点赞操作 --------------------------
async def check_share_like(share_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> bool:
    """Report whether the user has already liked the share.

    Args:
        share_id: Share being checked.
        user_id: User whose like is looked up.
        db: Async session used to run the query (declared with a FastAPI
            ``Depends`` default).

    Returns:
        ``True`` if a matching ``share_likes`` row exists, else ``False``.
    """
    like_query = text("""
SELECT 1 FROM share_likes
WHERE share_id = :share_id AND user_id = :user_id
""")
    like_result = await db.execute(
        like_query,
        {"share_id": share_id, "user_id": user_id}
    )
    # scalar() yields None when the query produced no row.
    marker = like_result.scalar()
    return marker is not None
async def add_share_like(share_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None:
    """Record that a user liked a share.

    Uses ``INSERT IGNORE`` with the database's ``NOW()`` as the like time.

    Args:
        share_id: Share being liked.
        user_id: User giving the like.
        db: Async session used to run the statement (declared with a FastAPI
            ``Depends`` default).
    """
    like_insert = text("""
INSERT IGNORE INTO share_likes (share_id, user_id, liked_at)
VALUES (:share_id, :user_id, NOW())
""")
    like_values = {"share_id": share_id, "user_id": user_id}
    await db.execute(like_insert, like_values)
async def remove_share_like(share_id: int, user_id: str, db: AsyncSession = Depends(get_default_db_instance)) -> None:
    """Remove a user's like from a share.

    Args:
        share_id: Share being un-liked.
        user_id: User whose like is removed.
        db: Async session used to run the statement (declared with a FastAPI
            ``Depends`` default).
    """
    like_delete = text("""
DELETE FROM share_likes
WHERE share_id = :share_id AND user_id = :user_id
""")
    like_key = {"share_id": share_id, "user_id": user_id}
    await db.execute(like_delete, like_key)
# -------------------------- 搜索记录操作 --------------------------
async def add_search_record(
    keyword: str,
    user_id: Optional[str] = None,
    search_time: Optional[datetime] = None,
    db: AsyncSession = Depends(get_default_db_instance)
) -> None:
    """Store one search event (used for hot-search statistics).

    Args:
        keyword: Search term; surrounding whitespace is stripped before
            storing.
        user_id: ID of the searching user, or ``None``.
        search_time: Time of the search; falls back to ``datetime.now()``
            (naive local time) when falsy.
        db: Async session used to run the statement (declared with a FastAPI
            ``Depends`` default).
    """
    if not search_time:
        search_time = datetime.now()
    record_insert = text("""
INSERT INTO search_records (keyword, user_id, search_time)
VALUES (:keyword, :user_id, :search_time)
""")
    record_values = {
        "keyword": keyword.strip(),
        "user_id": user_id,
        "search_time": search_time
    }
    await db.execute(record_insert, record_values)
async def get_hot_searches(
    date: Optional[date] = None,
    limit: int = 10,
    db: AsyncSession = Depends(get_default_db_instance)
) -> List[Dict]:
    """Return the most-searched keywords for one day.

    Args:
        date: Day to aggregate; falls back to the current local date
            (``datetime.now().date()``) when falsy.
        limit: Maximum number of keywords returned.
        db: Async session used to run the query (declared with a FastAPI
            ``Depends`` default).

    Returns:
        Dicts with ``keyword`` and ``search_count``, highest count first.
    """
    if not date:
        date = datetime.now().date()
    ranking_stmt = text("""
SELECT
keyword, COUNT(*) AS search_count
FROM search_records
WHERE DATE(search_time) = :date
GROUP BY keyword
ORDER BY search_count DESC
LIMIT :limit
""")
    ranking_result = await db.execute(
        ranking_stmt,
        {"date": date, "limit": limit}
    )
    hot_keywords = []
    for ranked in ranking_result.fetchall():
        hot_keywords.append(dict(ranked._mapping))
    return hot_keywords
async def search_shares(
    keyword: str,
    is_public: bool = True,
    author_id: Optional[str] = None,
    page: int = 1,
    size: int = 10,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Tuple[int, List[Dict]]:
    """Search shares whose title or content contains ``keyword``.

    The keyword is matched literally: the LIKE wildcards ``%`` and ``_`` in
    user input are escaped, so searching for "%" no longer matches every row.

    Filters are combined with AND: a row must have the requested
    ``is_public`` value and, when ``author_id`` is truthy, belong to that
    author.
    NOTE(review): if a caller expects "public shares plus the caller's own
    private ones", this query does not do that - confirm the intended
    semantics with the callers.

    Args:
        keyword: Substring to look for in ``title`` / ``content``.
        is_public: Required value of ``shares.is_public``.
        author_id: Optional author restriction.
        page: 1-based page number.
        size: Page size.
        db: Async session used to run the queries (declared with a FastAPI
            ``Depends`` default).

    Returns:
        ``(total, shares)`` where ``total`` is the number of matching rows
        and ``shares`` is the requested page, newest first. Each dict holds
        the share columns plus ``author_account`` and ``ai_char_name``.
    """
    # Escape the escape character first, then the two LIKE wildcards. "!" is
    # used instead of a backslash so behavior does not depend on the server's
    # backslash-escape settings.
    escaped = (
        keyword
        .replace("!", "!!")
        .replace("%", "!%")
        .replace("_", "!_")
    )
    filter_params = {
        "keyword": f"%{escaped}%",
        "is_public": is_public,
    }
    author_clause = ""
    if author_id:
        filter_params["author_id"] = author_id
        author_clause = "AND s.author_id = :author_id"

    # 1. Total number of matches (no paging parameters needed here).
    total_result = await db.execute(
        text(f"""
SELECT COUNT(*) AS total
FROM shares s
WHERE s.is_public = :is_public
AND (s.title LIKE :keyword ESCAPE '!' OR s.content LIKE :keyword ESCAPE '!')
{author_clause}
"""),
        filter_params
    )
    total = total_result.scalar()

    # 2. Requested page, newest first.
    page_params = {
        **filter_params,
        "offset": (page - 1) * size,
        "limit": size,
    }
    data_result = await db.execute(
        text(f"""
SELECT
s.*, u.account AS author_account,
c.name AS ai_char_name
FROM shares s
JOIN users u ON s.author_id = u.id
LEFT JOIN characters c ON s.ai_character_id = c.id
WHERE s.is_public = :is_public
AND (s.title LIKE :keyword ESCAPE '!' OR s.content LIKE :keyword ESCAPE '!')
{author_clause}
ORDER BY s.created_at DESC
LIMIT :offset, :limit
"""),
        page_params
    )
    shares = [dict(row._mapping) for row in data_result.fetchall()]
    return total, shares
# -------------------------- 管理员统计操作 --------------------------
async def get_user_stats(
    start_date: date,
    end_date: date,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Dict:
    """Collect user statistics: total, new users in a range, distributions.

    Args:
        start_date: First day (inclusive) of the "new users" window.
        end_date: Last day (inclusive) of the "new users" window.
        db: Async session used to run the queries (declared with a FastAPI
            ``Depends`` default).

    Returns:
        Dict with ``total_user``, ``new_user``, ``role_distribution`` (list
        of ``role`` / ``count``) and ``department_distribution`` (top 10
        departments as ``dept_name`` / ``user_count``).
    """
    # 1. Overall user count.
    count_result = await db.execute(text("SELECT COUNT(*) AS total FROM users"))
    user_total = count_result.scalar()

    # 2. Users created inside the requested date window.
    window_stmt = text("""
SELECT COUNT(*) AS new_count
FROM users
WHERE DATE(created_at) BETWEEN :start AND :end
""")
    window_result = await db.execute(
        window_stmt,
        {"start": start_date, "end": end_date}
    )
    user_new = window_result.scalar()

    # 3. Users per role.
    role_stmt = text("""
SELECT role, COUNT(*) AS count
FROM users
GROUP BY role
""")
    role_rows = (await db.execute(role_stmt)).fetchall()
    roles = [dict(entry._mapping) for entry in role_rows]

    # 4. Users per department, ten largest.
    dept_stmt = text("""
SELECT
d.name AS dept_name, COUNT(u.id) AS user_count
FROM departments d
LEFT JOIN users u ON d.id = u.department_id
GROUP BY d.id
ORDER BY user_count DESC
LIMIT 10
""")
    dept_rows = (await db.execute(dept_stmt)).fetchall()
    departments = [dict(entry._mapping) for entry in dept_rows]

    stats = {
        "total_user": user_total,
        "new_user": user_new,
        "role_distribution": roles,
        "department_distribution": departments
    }
    return stats
async def get_share_stats(
    start_date: date,
    end_date: date,
    db: AsyncSession = Depends(get_default_db_instance)
) -> Dict:
    """Collect share statistics: totals, new shares, distributions, interaction.

    Args:
        start_date: First day (inclusive) of the "new shares" window.
        end_date: Last day (inclusive) of the "new shares" window.
        db: Async session used to run the queries (declared with a FastAPI
            ``Depends`` default).

    Returns:
        Dict with ``total_share``, ``new_share``, ``type_distribution``
        (list of ``type`` / ``count``), ``ai_character_distribution`` (top 10
        as ``ai_char_name`` / ``share_count``) and ``total_interaction``
        (``total_like`` / ``total_comment`` sums over all shares).
    """
    # 1. Overall share count.
    count_result = await db.execute(text("SELECT COUNT(*) AS total FROM shares"))
    share_total = count_result.scalar()

    # 2. Shares created inside the requested date window.
    window_stmt = text("""
SELECT COUNT(*) AS new_count
FROM shares
WHERE DATE(created_at) BETWEEN :start AND :end
""")
    window_result = await db.execute(
        window_stmt,
        {"start": start_date, "end": end_date}
    )
    share_new = window_result.scalar()

    # 3. Shares per type.
    type_stmt = text("""
SELECT type, COUNT(*) AS count
FROM shares
GROUP BY type
""")
    type_rows = (await db.execute(type_stmt)).fetchall()
    types = [dict(entry._mapping) for entry in type_rows]

    # 4. Shares per linked AI character, ten most used.
    character_stmt = text("""
SELECT
c.name AS ai_char_name, COUNT(s.id) AS share_count
FROM characters c
LEFT JOIN shares s ON c.id = s.ai_character_id
WHERE s.ai_character_id IS NOT NULL
GROUP BY c.id
ORDER BY share_count DESC
LIMIT 10
""")
    character_rows = (await db.execute(character_stmt)).fetchall()
    characters = [dict(entry._mapping) for entry in character_rows]

    # 5. Summed like and comment counters across all shares.
    interaction_stmt = text("""
SELECT
SUM(like_count) AS total_like,
SUM(comment_count) AS total_comment
FROM shares
""")
    interaction_result = await db.execute(interaction_stmt)
    interaction_row = interaction_result.fetchone()
    interaction = dict(interaction_row._mapping)

    stats = {
        "total_share": share_total,
        "new_share": share_new,
        "type_distribution": types,
        "ai_character_distribution": characters,
        "total_interaction": interaction
    }
    return stats