# database.py
import re
from typing import Optional, Dict, List, Tuple, AsyncGenerator
from datetime import date, datetime, timedelta, timezone
from fastapi import logger
from fastapi.concurrency import asynccontextmanager
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy import text
import os
from dotenv import load_dotenv
# Load variables from a .env file into the process environment before reading them.
load_dotenv()
# -------------------------- Core database configuration --------------------------
# The connection URL comes from the DATABASE_URL environment variable; the
# literal below is only the fallback used when that variable is unset.
# NOTE(review): the fallback embeds a root password — presumably meant for local
# development only; confirm deployments always set DATABASE_URL.
DATABASE_URL = os.getenv(
"DATABASE_URL",
"mysql+asyncmy://root:123456@localhost/ai_roleplay?charset=utf8mb4"
)
# Shared async engine for the whole module.
# pool_pre_ping tests a pooled connection before handing it out; pool_recycle
# replaces connections older than 3600 seconds.
engine = create_async_engine(
DATABASE_URL,
echo=False,
pool_pre_ping=True,
pool_size=10,
max_overflow=20,
pool_recycle=3600
)
# Session factory producing AsyncSession objects bound to the engine above.
# expire_on_commit=False keeps loaded attributes readable after commit;
# autoflush=False disables automatic flushes before queries.
AsyncSessionLocal = sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False
)
# -------------------------- 通用工具函数 --------------------------
def get_conversation_table_name(user_id: str) -> str:
    """Build the per-user conversation table name.

    The id is converted with ``str()`` and every character that is neither
    alphanumeric nor an underscore is dropped before it is appended to the
    ``conversations_`` prefix.
    """
    kept = [ch for ch in str(user_id) if ch == "_" or ch.isalnum()]
    return "conversations_" + "".join(kept)
def is_valid_table_name(table_name: str) -> bool:
    """Return True when *table_name* is ``conversations_`` plus ASCII letters, digits or underscores.

    ``re.fullmatch`` is used instead of ``re.match`` with ``^...$`` because
    ``$`` also matches just before a trailing newline, which would let
    ``"conversations_1\\n"`` pass a check whose result is interpolated into SQL.
    """
    return re.fullmatch(r'conversations_[a-zA-Z0-9_]+', table_name) is not None
@asynccontextmanager
async def get_default_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session whose work is committed on success and rolled back on error.

    Usage: ``async with get_default_db() as db: ...``. When the ``with`` body
    finishes normally the session is committed; when it raises an
    ``Exception`` the session is rolled back and the same exception
    propagates to the caller.
    """
    # The session's own ``async with`` closes it on exit, so no explicit
    # close() is needed here.
    async with AsyncSessionLocal() as db:
        try:
            yield db
            await db.commit()
        except Exception:
            await db.rollback()
            # Bare raise keeps the original traceback intact.
            raise
# -------------------------- 用户基础操作 --------------------------
async def get_user_by_account(account: str) -> Optional[Dict]:
    """Look up a user by account name.

    Returns:
        A dict with id, account, password, role, department_id and
        created_at, or None when no row matches.
    """
    query = text("""
        SELECT id, account, password, role, department_id, created_at
        FROM users
        WHERE account = :account
    """)
    async with get_default_db() as session:
        found = (await session.execute(query, {"account": account})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
async def get_user_by_id(user_id: str) -> Optional[Dict]:
    """Look up a user by primary key; the password column is not selected.

    Returns:
        A dict with id, account, role, department_id and created_at, or
        None when no row matches.
    """
    query = text("""
        SELECT id, account, role, department_id, created_at
        FROM users
        WHERE id = :user_id
    """)
    async with get_default_db() as session:
        found = (await session.execute(query, {"user_id": user_id})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
async def get_user_detail(user_id: str) -> Optional[Dict]:
    """Look up a user together with the name of their department.

    The department is LEFT JOINed, so ``dept_name`` is None for users
    without a matching department row.

    Returns:
        The user dict including ``dept_name``, or None when the id is unknown.
    """
    query = text("""
        SELECT u.id, u.account, u.role, u.department_id, u.created_at, d.name AS dept_name
        FROM users u
        LEFT JOIN departments d ON u.department_id = d.id
        WHERE u.id = :user_id
    """)
    async with get_default_db() as session:
        found = (await session.execute(query, {"user_id": user_id})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
# -------------------------- 用户创建与更新 --------------------------
async def create_user(
    account: str,
    password: str,
    role: str = "user",
    department_id: Optional[int] = None,
) -> Optional[Dict]:
    """Insert a user, create that user's conversation table, and return the new row.

    Args:
        account: Account name to store.
        password: Value stored verbatim in the ``password`` column
            (presumably already hashed by the caller — verify).
        role: Role string; defaults to ``"user"``.
        department_id: Optional department foreign key.

    Returns:
        The row as returned by ``get_user_by_id`` (None if it cannot be read back).

    Raises:
        ValueError: If the derived conversation table name fails validation.
    """
    async with get_default_db() as db:
        # Insert the user row.
        result = await db.execute(
            text("""
                INSERT INTO users (account, password, role, department_id, created_at)
                VALUES (:account, :password, :role, :dept_id, NOW())
            """),
            {
                "account": account,
                "password": password,
                "role": role,
                "dept_id": department_id,
            },
        )
        user_id = result.lastrowid

        # Create the user's dedicated conversation table. The name is
        # interpolated into DDL, so it is validated first.
        table_name = get_conversation_table_name(user_id)
        if not is_valid_table_name(table_name):
            raise ValueError(f"Invalid user ID: {user_id}")
        await db.execute(text(f"""
            CREATE TABLE IF NOT EXISTS `{table_name}` (
                id INT AUTO_INCREMENT PRIMARY KEY,
                character_id INT NOT NULL,
                user_message TEXT NOT NULL,
                ai_message TEXT NOT NULL,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (character_id) REFERENCES characters(id) ON DELETE CASCADE
            ) ENGINE=InnoDB CHARSET=utf8mb4;
        """))

    # get_user_by_id opens its own session, so the read-back must happen only
    # after the block above has committed.
    return await get_user_by_id(str(user_id))
async def update_user(
    user_id: str,
    update_params: Dict,
) -> None:
    """Update whitelisted columns of one user.

    Only ``password``, ``role`` and ``department_id`` are honoured; any other
    key in *update_params* is ignored. When nothing remains to update the
    function returns without opening a database session.

    Args:
        user_id: Primary key of the user to update.
        update_params: Mapping of column name to new value.
    """
    allowed_fields = ("password", "role", "department_id")
    # Column names are interpolated into the SQL text, so they must come from
    # the whitelist; only the whitelisted values are bound.
    changes = {k: v for k, v in (update_params or {}).items() if k in allowed_fields}
    if not changes:
        return

    set_clause = ", ".join(f"{column} = :{column}" for column in changes)
    async with get_default_db() as db:
        await db.execute(
            text(f"UPDATE users SET {set_clause} WHERE id = :user_id"),
            {**changes, "user_id": user_id},
        )
# -------------------------- 用户列表与统计 --------------------------
async def get_users_list(
    page: int = 1,
    size: int = 10,
    role: Optional[str] = None,
    dept_id: Optional[int] = None,
) -> Tuple[int, List[Dict]]:
    """Return one page of users, newest first, with optional filters.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.
        role: When truthy, restrict to this role.
        dept_id: When not None, restrict to this department.

    Returns:
        ``(total, users)`` where *total* counts all rows matching the filters.
    """
    # MySQL rejects negative LIMIT arguments, so clamp rather than fail.
    page = max(page, 1)
    size = max(size, 0)

    conditions: List[str] = []
    filters: Dict = {}
    if role:
        conditions.append("role = :role")
        filters["role"] = role
    if dept_id is not None:
        conditions.append("department_id = :dept_id")
        filters["dept_id"] = dept_id
    where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    async with get_default_db() as db:
        total_result = await db.execute(
            text(f"SELECT COUNT(*) AS total FROM users {where_sql}"),
            filters,
        )
        total = total_result.scalar()

        data_result = await db.execute(
            text(f"""
                SELECT id, account, role, department_id, created_at
                FROM users {where_sql}
                ORDER BY created_at DESC
                LIMIT :offset, :limit
            """),
            {**filters, "offset": (page - 1) * size, "limit": size},
        )
        users = [dict(row._mapping) for row in data_result.fetchall()]
    return total, users
async def get_user_count_by_dept(dept_id: int) -> int:
    """Count the users whose ``department_id`` equals *dept_id*."""
    count_sql = text("SELECT COUNT(*) FROM users WHERE department_id = :dept_id")
    async with get_default_db() as session:
        outcome = await session.execute(count_sql, {"dept_id": dept_id})
        return outcome.scalar()
# -------------------------- 登录校验 --------------------------
async def check_users(account: str, password: str) -> Optional[Tuple[str, str]]:
    """Fetch the id and stored password for an account.

    The *password* argument is not used in this function; no comparison is
    performed here (presumably the caller verifies it — confirm).

    Returns:
        ``(str(id), stored_password)`` or None when the account is unknown.
    """
    lookup_sql = text("SELECT id, password FROM users WHERE account = :account")
    async with get_default_db() as session:
        match = (await session.execute(lookup_sql, {"account": account})).fetchone()
        if not match:
            return None
        return (str(match.id), match.password)
# -------------------------- 院系操作 --------------------------
async def create_department(name: str, description: Optional[str] = None) -> Optional[Dict]:
    """Insert a department and return the stored row.

    Args:
        name: Department name.
        description: Optional free-text description.

    Returns:
        The row as returned by ``get_department_by_id``.
    """
    async with get_default_db() as db:
        result = await db.execute(
            text("""
                INSERT INTO departments (name, description, created_at)
                VALUES (:name, :desc, NOW())
            """),
            {"name": name, "desc": description},
        )
        dept_id = result.lastrowid

    # The lookup uses a separate session, so it must run after the insert has
    # been committed by leaving the block above.
    return await get_department_by_id(dept_id)
async def get_department_by_id(dept_id: int) -> Optional[Dict]:
    """Return id, name, description and created_at of a department, or None."""
    query = text("SELECT id, name, description, created_at FROM departments WHERE id = :dept_id")
    async with get_default_db() as session:
        found = (await session.execute(query, {"dept_id": dept_id})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
async def get_department_by_name(name: str) -> Optional[Dict]:
    """Return the id and name of the department called *name*, or None."""
    query = text("SELECT id, name FROM departments WHERE name = :name")
    async with get_default_db() as session:
        found = (await session.execute(query, {"name": name})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
async def get_departments_with_user_count() -> List[Dict]:
    """List all departments, newest first, each with its ``user_count``.

    Users are LEFT JOINed, so departments without users appear with a
    count of 0.
    """
    query = text("""
        SELECT
            d.id, d.name, d.description, d.created_at,
            COUNT(u.id) AS user_count
        FROM departments d
        LEFT JOIN users u ON d.id = u.department_id
        GROUP BY d.id
        ORDER BY d.created_at DESC
    """)
    async with get_default_db() as session:
        fetched = (await session.execute(query)).fetchall()
        departments = []
        for record in fetched:
            departments.append(dict(record._mapping))
        return departments
async def update_department(dept_id: int, update_params: Dict) -> None:
    """Update whitelisted columns of one department.

    Only ``name`` and ``description`` are honoured; other keys are ignored.
    When nothing remains to update the function returns without opening a
    database session.

    Args:
        dept_id: Primary key of the department.
        update_params: Mapping of column name to new value.
    """
    allowed_fields = ("name", "description")
    # Column names end up in the SQL text, so they are restricted to the
    # whitelist; only the whitelisted values are bound.
    changes = {k: v for k, v in (update_params or {}).items() if k in allowed_fields}
    if not changes:
        return

    set_clause = ", ".join(f"{column} = :{column}" for column in changes)
    async with get_default_db() as db:
        await db.execute(
            text(f"UPDATE departments SET {set_clause} WHERE id = :dept_id"),
            {**changes, "dept_id": dept_id},
        )
async def delete_department(dept_id: int) -> None:
    """Delete the department row whose id equals *dept_id*."""
    removal = text("DELETE FROM departments WHERE id = :dept_id")
    async with get_default_db() as session:
        await session.execute(removal, {"dept_id": dept_id})
# -------------------------- 聊天室操作 --------------------------
async def create_room(
    name: str,
    type: str,
    creator_id: str,
    dept_id: Optional[int] = None,
    ai_character_id: Optional[int] = None,
    description: Optional[str] = None,
) -> Optional[Dict]:
    """Insert a chat room and return the stored row.

    Args:
        name: Room name.
        type: Room type string stored in ``rooms.type``.
        creator_id: Id of the creating user.
        dept_id: Optional department the room belongs to.
        ai_character_id: Optional AI character attached to the room.
        description: Optional description.

    Returns:
        The row as returned by ``get_room_by_id``.
    """
    async with get_default_db() as db:
        result = await db.execute(
            text("""
                INSERT INTO rooms (
                    name, type, dept_id, ai_character_id, description, creator_id, created_at
                )
                VALUES (:name, :type, :dept_id, :ai_char_id, :desc, :creator_id, NOW())
            """),
            {
                "name": name,
                "type": type,
                "dept_id": dept_id,
                "ai_char_id": ai_character_id,
                "desc": description,
                "creator_id": creator_id,
            },
        )
        room_id = result.lastrowid

    # The lookup uses a separate session, so it must run after the insert has
    # been committed by leaving the block above.
    return await get_room_by_id(room_id)
async def get_room_by_id(room_id: int) -> Optional[Dict]:
    """Fetch a room with its creator's account and AI character name.

    The creator is INNER JOINed and the character LEFT JOINed, so
    ``ai_char_name`` may be None.

    Returns:
        The room dict, or None when no row matches.
    """
    query = text("""
        SELECT
            r.*, u.account AS creator_account,
            c.name AS ai_char_name
        FROM rooms r
        JOIN users u ON r.creator_id = u.id
        LEFT JOIN characters c ON r.ai_character_id = c.id
        WHERE r.id = :room_id
    """)
    async with get_default_db() as session:
        found = (await session.execute(query, {"room_id": room_id})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
async def get_rooms(
    type: Optional[str] = None,
    dept_id: Optional[int] = None,
    page: int = 1,
    size: int = 10,
) -> Tuple[int, List[Dict]]:
    """Return one page of rooms, newest first, with member counts.

    Args:
        type: When truthy, restrict to this room type.
        dept_id: When not None, restrict to this department.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.

    Returns:
        ``(total, rooms)`` where *total* counts the rooms matching the filters.
    """
    # MySQL rejects negative LIMIT arguments, so clamp rather than fail.
    page = max(page, 1)
    size = max(size, 0)

    conditions: List[str] = []
    filters: Dict = {}
    if type:
        conditions.append("r.type = :type")
        filters["type"] = type
    if dept_id is not None:
        conditions.append("r.dept_id = :dept_id")
        filters["dept_id"] = dept_id
    where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    async with get_default_db() as db:
        total_result = await db.execute(
            text(f"""
                SELECT COUNT(*) AS total
                FROM rooms r {where_sql}
            """),
            filters,
        )
        total = total_result.scalar()

        data_result = await db.execute(
            text(f"""
                SELECT
                    r.id, r.name, r.type, r.description,
                    r.creator_id, u.account AS creator_account,
                    r.created_at, COUNT(rm.user_id) AS member_count
                FROM rooms r
                JOIN users u ON r.creator_id = u.id
                LEFT JOIN room_members rm ON r.id = rm.room_id
                {where_sql}
                GROUP BY r.id
                ORDER BY r.created_at DESC
                LIMIT :offset, :limit
            """),
            {**filters, "offset": (page - 1) * size, "limit": size},
        )
        rooms = [dict(row._mapping) for row in data_result.fetchall()]
    return total, rooms
async def get_room_messages(
    room_id: int,
    page: int = 1,
    size: int = 20,
    order_by: str = "sent_at DESC"
) -> Tuple[int, List[Dict]]:
    """Return one page of a room's message history.

    Args:
        room_id: Chat room id.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; values outside 1..100 fall back to 20.
        order_by: ``"sent_at ASC"`` or ``"sent_at DESC"`` (surrounding
            whitespace is tolerated); anything else falls back to
            ``"sent_at DESC"``.

    Returns:
        ``(total, messages)`` where *total* counts all messages in the room.
    """
    # 1. Normalise paging arguments.
    if page < 1:
        page = 1
    if size < 1 or size > 100:
        size = 20  # guard against oversized pages

    # 2. Whitelist the sort expression, because it is interpolated into the
    #    SQL text. Strip *before* the membership test so that padded input
    #    such as " sent_at ASC " is accepted rather than silently replaced.
    valid_order = ("sent_at ASC", "sent_at DESC")
    candidate = order_by.strip() if isinstance(order_by, str) else ""
    order_sql = candidate if candidate in valid_order else "sent_at DESC"

    async with get_default_db() as db:
        # 3. Total number of messages in the room.
        total_result = await db.execute(
            text("SELECT COUNT(*) AS total FROM room_messages WHERE room_id = :room_id"),
            {"room_id": room_id}
        )
        total = total_result.scalar()

        # 4. Page of messages joined with the sender's account name. The sort
        #    column is qualified with the room_messages alias to keep it
        #    unambiguous across the join.
        data_result = await db.execute(
            text(f"""
                SELECT
                    rm.id,
                    rm.content,
                    rm.sent_at,
                    u.id AS sender_id,
                    u.account AS sender_account
                FROM room_messages rm
                JOIN users u ON rm.sender_id = u.id
                WHERE rm.room_id = :room_id
                ORDER BY rm.{order_sql}
                LIMIT :offset, :limit
            """),
            {
                "room_id": room_id,
                "offset": (page - 1) * size,
                "limit": size
            }
        )
        messages = [dict(row._mapping) for row in data_result.fetchall()]
    return total, messages
async def create_room_message(
    room_id: int,
    sender_id: str,
    content: str,
    sent_at: Optional[datetime] = None
) -> Dict:
    """Store a chat room message and return it with the sender's account.

    Args:
        room_id: Target room.
        sender_id: Id of the sending user.
        content: Message text; leading/trailing whitespace is stripped.
        sent_at: Timestamp to store; defaults to the current UTC time.

    Returns:
        Dict with id, content, sent_at, sender_id and sender_account.
    """
    insert_sql = text("""
        INSERT INTO room_messages (room_id, sender_id, content, sent_at)
        VALUES (:room_id, :sender_id, :content, :sent_at)
    """)
    detail_sql = text("""
        SELECT
            rm.id, rm.content, rm.sent_at,
            u.id AS sender_id, u.account AS sender_account
        FROM room_messages rm
        JOIN users u ON rm.sender_id = u.id
        WHERE rm.id = :msg_id
    """)
    async with get_default_db() as session:
        if not sent_at:
            sent_at = datetime.now(timezone.utc)
        # Write the message.
        inserted = await session.execute(
            insert_sql,
            {
                "room_id": room_id,
                "sender_id": sender_id,
                "content": content.strip(),
                "sent_at": sent_at,
            },
        )
        new_id = inserted.lastrowid
        # Read it back (same session) together with the sender's account name.
        detail = await session.execute(detail_sql, {"msg_id": new_id})
        stored = detail.fetchone()
        return dict(stored._mapping)
async def check_room_member(room_id: int, user_id: str) -> bool:
    """Return True when a ``room_members`` row exists for this room and user."""
    probe = text("""
        SELECT 1 FROM room_members
        WHERE room_id = :room_id AND user_id = :user_id
    """)
    async with get_default_db() as session:
        outcome = await session.execute(probe, {"room_id": room_id, "user_id": user_id})
        hit = outcome.scalar()
        return hit is not None
async def add_room_member(room_id: int, user_id: str) -> None:
    """Add a user to a room using ``INSERT IGNORE`` with ``joined_at`` set to NOW()."""
    statement = text("""
        INSERT IGNORE INTO room_members (room_id, user_id, joined_at)
        VALUES (:room_id, :user_id, NOW())
    """)
    binds = {"room_id": room_id, "user_id": user_id}
    async with get_default_db() as session:
        await session.execute(statement, binds)
async def remove_room_member(room_id: int, user_id: str) -> None:
    """Delete the ``room_members`` row for this room and user, if any."""
    statement = text("""
        DELETE FROM room_members
        WHERE room_id = :room_id AND user_id = :user_id
    """)
    binds = {"room_id": room_id, "user_id": user_id}
    async with get_default_db() as session:
        await session.execute(statement, binds)
# -------------------------- 分享操作 --------------------------
async def create_share(
    title: str,
    content: str,
    author_id: str,
    is_public: bool = True,
    type: str = "public",
    ai_character_id: Optional[int] = None,
    created_at: Optional[datetime] = None,
) -> Optional[Dict]:
    """Insert a share with zeroed counters and return the stored row.

    Args:
        title: Share title.
        content: Share body.
        author_id: Id of the authoring user.
        is_public: Visibility flag.
        type: Share type string; defaults to ``"public"``.
        ai_character_id: Optional AI character the share relates to.
        created_at: Creation timestamp; defaults to the current UTC time.

    Returns:
        The row as returned by ``get_share_by_id``.
    """
    created_at = created_at or datetime.now(timezone.utc)
    async with get_default_db() as db:
        result = await db.execute(
            text("""
                INSERT INTO shares (
                    title, content, author_id, is_public, type,
                    ai_character_id, view_count, like_count, comment_count, created_at
                )
                VALUES (
                    :title, :content, :author_id, :is_public, :type,
                    :ai_char_id, 0, 0, 0, :created_at
                )
            """),
            {
                "title": title,
                "content": content,
                "author_id": author_id,
                "is_public": is_public,
                "type": type,
                "ai_char_id": ai_character_id,
                "created_at": created_at,
            },
        )
        share_id = result.lastrowid

    # The lookup uses a separate session, so it must run after the insert has
    # been committed by leaving the block above.
    return await get_share_by_id(share_id)
async def get_character_by_id(char_id: int) -> Optional[Dict]:
    """Return the id and name of an AI character, or None when it does not exist."""
    query = text("SELECT id, name FROM characters WHERE id = :char_id")
    async with get_default_db() as session:
        found = (await session.execute(query, {"char_id": char_id})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
async def update_share(share_id: int, update_data: Dict) -> None:
    """Update whitelisted columns of one share.

    Honoured keys: ``title``, ``content``, ``is_public``, ``ai_character_id``,
    ``view_count``, ``like_count`` and ``comment_count``; other keys are
    ignored. When nothing remains to update the function returns without
    opening a database session.

    Args:
        share_id: Primary key of the share.
        update_data: Mapping of column name to new value.
    """
    allowed_fields = (
        "title", "content", "is_public", "ai_character_id",
        "view_count", "like_count", "comment_count",
    )
    # Column names end up in the SQL text, so they are restricted to the
    # whitelist; only the whitelisted values are bound.
    changes = {k: v for k, v in (update_data or {}).items() if k in allowed_fields}
    if not changes:
        return

    set_clause = ", ".join(f"{column} = :{column}" for column in changes)
    async with get_default_db() as db:
        await db.execute(
            text(f"UPDATE shares SET {set_clause} WHERE id = :share_id"),
            {**changes, "share_id": share_id},
        )
async def delete_share(share_id: int) -> None:
    """Delete a share together with its likes and comments.

    The three deletes run in one session, dependants first: likes, then
    comments, then the share row itself.
    """
    cleanup_statements = (
        "DELETE FROM share_likes WHERE share_id = :share_id",
        "DELETE FROM comments WHERE share_id = :share_id",
        "DELETE FROM shares WHERE id = :share_id",
    )
    binds = {"share_id": share_id}
    async with get_default_db() as session:
        for statement in cleanup_statements:
            await session.execute(text(statement), binds)
async def get_share_by_id(share_id: int) -> Optional[Dict]:
    """Fetch a share with its author's account/department and AI character name.

    Returns:
        The share dict (all ``shares`` columns plus author_account,
        department_id and ai_char_name), or None when no row matches.
    """
    query = text("""
        SELECT
            s.*, u.account AS author_account, u.department_id,
            c.name AS ai_char_name
        FROM shares s
        JOIN users u ON s.author_id = u.id
        LEFT JOIN characters c ON s.ai_character_id = c.id
        WHERE s.id = :share_id
    """)
    async with get_default_db() as session:
        found = (await session.execute(query, {"share_id": share_id})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
async def get_shares(
    is_public: Optional[bool] = None,
    author_id: Optional[str] = None,
    type: Optional[str] = None,
    order_by: str = "created_at DESC",
    page: int = 1,
    size: int = 10,
) -> Tuple[int, List[Dict]]:
    """Return one page of shares with author account and AI character name.

    Args:
        is_public: When not None, restrict to this visibility.
        author_id: When truthy, restrict to this author.
        type: When truthy, restrict to this share type.
        order_by: One of ``created_at DESC``, ``created_at ASC``,
            ``like_count DESC``, ``view_count DESC``; anything else falls
            back to ``created_at DESC``.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.

    Returns:
        ``(total, shares)`` where *total* counts the shares matching the filters.
    """
    # MySQL rejects negative LIMIT arguments, so clamp rather than fail.
    page = max(page, 1)
    size = max(size, 0)

    conditions: List[str] = []
    filters: Dict = {}
    if is_public is not None:
        conditions.append("s.is_public = :is_public")
        filters["is_public"] = is_public
    if author_id:
        conditions.append("s.author_id = :author_id")
        filters["author_id"] = author_id
    if type:
        conditions.append("s.type = :type")
        filters["type"] = type
    where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    # The sort expression is interpolated into the SQL, so it is whitelisted.
    valid_order = ("created_at DESC", "created_at ASC", "like_count DESC", "view_count DESC")
    order_sql = order_by if order_by in valid_order else "created_at DESC"

    async with get_default_db() as db:
        total_result = await db.execute(
            text(f"SELECT COUNT(*) AS total FROM shares s {where_sql}"),
            filters,
        )
        total = total_result.scalar()

        # Every whitelisted sort column belongs to shares; qualify it with
        # the alias so the join (users also has created_at) stays unambiguous.
        data_result = await db.execute(
            text(f"""
                SELECT
                    s.*, u.account AS author_account,
                    c.name AS ai_char_name
                FROM shares s
                JOIN users u ON s.author_id = u.id
                LEFT JOIN characters c ON s.ai_character_id = c.id
                {where_sql}
                ORDER BY s.{order_sql}
                LIMIT :offset, :limit
            """),
            {**filters, "offset": (page - 1) * size, "limit": size},
        )
        shares = [dict(row._mapping) for row in data_result.fetchall()]
    return total, shares
# -------------------------- 搜索记录操作(重点修复)--------------------------
async def add_search_record(
    keyword: str,
    user_id: Optional[str] = None,
    search_time: Optional[datetime] = None,
) -> None:
    """Record one search.

    Args:
        keyword: Search term; leading/trailing whitespace is stripped.
        user_id: Searching user, or None.
        search_time: Timestamp to store; defaults to the current UTC time.
    """
    statement = text("""
        INSERT INTO search_records (keyword, user_id, search_time)
        VALUES (:keyword, :user_id, :search_time)
    """)
    async with get_default_db() as session:
        if not search_time:
            search_time = datetime.now(timezone.utc)
        binds = {
            "keyword": keyword.strip(),
            "user_id": user_id,
            "search_time": search_time,
        }
        await session.execute(statement, binds)
async def get_hot_searches(date: Optional[date] = None, limit: int = 10) -> List[Dict]:
    """Return the most searched keywords for one calendar day.

    Args:
        date: Day to report on; defaults to the current UTC date.
        limit: Maximum number of keywords returned.

    Returns:
        Dicts with ``keyword`` and ``search_count``, most frequent first.
    """
    # add_search_record stamps rows with datetime.now(timezone.utc) by
    # default, so "today" must be derived from UTC as well; a naive local
    # date would select the wrong day whenever the local and UTC calendar
    # dates differ.
    date = date or datetime.now(timezone.utc).date()
    async with get_default_db() as db:
        result = await db.execute(
            text("""
                SELECT keyword, COUNT(*) AS search_count
                FROM search_records
                WHERE DATE(search_time) = :date
                GROUP BY keyword
                ORDER BY search_count DESC
                LIMIT :limit
            """),
            {"date": date, "limit": limit},
        )
        return [dict(row._mapping) for row in result.fetchall()]
async def get_hot_search_keywords(
    start_time: datetime,
    limit: int = 10,
) -> List[Dict]:
    """Return the most searched keywords since *start_time* (inclusive).

    Returns:
        Dicts with ``keyword`` and ``search_count``, most frequent first,
        at most *limit* entries.
    """
    query = text("""
        SELECT keyword, COUNT(*) AS search_count
        FROM search_records
        WHERE search_time >= :start_time
        GROUP BY keyword
        ORDER BY search_count DESC
        LIMIT :limit
    """)
    binds = {"start_time": start_time, "limit": limit}
    async with get_default_db() as session:
        outcome = await session.execute(query, binds)
        return [dict(entry) for entry in outcome.mappings().all()]
async def get_user_search_history(
    user_id: str,
    limit: int = 10,
    days: int = 30,
) -> List[Dict]:
    """Return a user's recent searches, newest first.

    Args:
        user_id: User whose history is requested.
        limit: Maximum number of records.
        days: Look-back window, measured back from the current UTC time.

    Returns:
        Dicts with id, keyword and search_time.
    """
    query = text("""
        SELECT id, keyword, search_time
        FROM search_records
        WHERE user_id = :user_id AND search_time >= :cutoff_time
        ORDER BY search_time DESC
        LIMIT :limit
    """)
    async with get_default_db() as session:
        window_start = datetime.now(timezone.utc) - timedelta(days=days)
        binds = {"user_id": user_id, "cutoff_time": window_start, "limit": limit}
        outcome = await session.execute(query, binds)
        return [dict(entry) for entry in outcome.mappings().all()]
async def get_search_history_by_id(history_id: int) -> Optional[Dict]:
    """Return one search record (id, user_id, keyword, search_time), or None."""
    query = text("""
        SELECT id, user_id, keyword, search_time
        FROM search_records
        WHERE id = :history_id
    """)
    async with get_default_db() as session:
        found = (await session.execute(query, {"history_id": history_id})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
async def delete_search_history(history_id: int) -> None:
    """Delete the search record whose id equals *history_id*."""
    removal = text("DELETE FROM search_records WHERE id = :history_id")
    async with get_default_db() as session:
        await session.execute(removal, {"history_id": history_id})
async def clear_user_search_history(user_id: str) -> None:
    """Delete every search record belonging to *user_id*."""
    removal = text("DELETE FROM search_records WHERE user_id = :user_id")
    async with get_default_db() as session:
        await session.execute(removal, {"user_id": user_id})
# -------------------------- 搜索功能 --------------------------
async def search_shares(
    keyword: str,
    is_public: bool = True,
    author_id: Optional[str] = None,
    page: int = 1,
    size: int = 10,
) -> Tuple[int, List[Dict]]:
    """Search shares whose title or content contains *keyword*.

    Args:
        keyword: Substring to look for (wrapped in ``%`` for LIKE; LIKE
            wildcards inside the keyword are not escaped).
        is_public: Visibility the shares must have.
        author_id: When truthy, restrict to this author.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.

    Returns:
        ``(total, shares)``, newest first.
    """
    # MySQL rejects negative LIMIT arguments, so clamp rather than fail.
    page = max(page, 1)
    size = max(size, 0)

    filters: Dict = {"keyword": f"%{keyword}%", "is_public": is_public}
    author_clause = ""
    if author_id:
        author_clause = "AND s.author_id = :author_id"
        filters["author_id"] = author_id

    async with get_default_db() as db:
        total_result = await db.execute(
            text(f"""
                SELECT COUNT(*) AS total
                FROM shares s
                WHERE s.is_public = :is_public
                AND (s.title LIKE :keyword OR s.content LIKE :keyword)
                {author_clause}
            """),
            filters,
        )
        total = total_result.scalar()

        data_result = await db.execute(
            text(f"""
                SELECT
                    s.*, u.account AS author_account,
                    c.name AS ai_char_name
                FROM shares s
                JOIN users u ON s.author_id = u.id
                LEFT JOIN characters c ON s.ai_character_id = c.id
                WHERE s.is_public = :is_public
                AND (s.title LIKE :keyword OR s.content LIKE :keyword)
                {author_clause}
                ORDER BY s.created_at DESC
                LIMIT :offset, :limit
            """),
            {**filters, "offset": (page - 1) * size, "limit": size},
        )
        shares = [dict(row._mapping) for row in data_result.fetchall()]
    return total, shares
async def search_rooms(
    keyword: str,
    user_id: str,
    department_id: int,
    is_admin: bool = False,
    page: int = 1,
    size: int = 10,
) -> Tuple[int, List[Dict]]:
    """Search rooms whose name or description contains *keyword*.

    Non-admin callers only see rooms that are (public or in their
    department) and of which they are a member; admins see every match.
    NOTE(review): the two non-admin conditions are ANDed — confirm that
    membership is meant to be required for public rooms as well.

    Args:
        keyword: Substring to look for (wrapped in ``%`` for LIKE).
        user_id: Calling user, used for the membership check.
        department_id: Calling user's department.
        is_admin: Skip the visibility restrictions when True.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.

    Returns:
        ``(total, rooms)``, newest first, each room with ``member_count``.
    """
    # MySQL rejects negative LIMIT arguments, so clamp rather than fail.
    page = max(page, 1)
    size = max(size, 0)

    filters: Dict = {"keyword": f"%{keyword}%"}
    where_clauses: List[str] = []
    if not is_admin:
        where_clauses.append("(r.type = 'public' OR r.dept_id = :dept_id)")
        where_clauses.append("""
            EXISTS (
                SELECT 1 FROM room_members rm
                WHERE rm.room_id = r.id AND rm.user_id = :user_id
            )
        """)
        # These binds only appear in the non-admin clauses.
        filters["dept_id"] = department_id
        filters["user_id"] = user_id
    where_sql = " AND ".join(where_clauses)
    if where_sql:
        where_sql = "AND " + where_sql

    async with get_default_db() as db:
        total_result = await db.execute(
            text(f"""
                SELECT COUNT(*) AS total
                FROM rooms r
                WHERE (r.name LIKE :keyword OR r.description LIKE :keyword)
                {where_sql}
            """),
            filters,
        )
        total = total_result.scalar()

        data_result = await db.execute(
            text(f"""
                SELECT
                    r.id, r.name, r.type, r.description, r.creator_id, u.account AS creator_account,
                    r.created_at, COUNT(rm.user_id) AS member_count
                FROM rooms r
                JOIN users u ON r.creator_id = u.id
                LEFT JOIN room_members rm ON r.id = rm.room_id
                WHERE (r.name LIKE :keyword OR r.description LIKE :keyword)
                {where_sql}
                GROUP BY r.id
                ORDER BY r.created_at DESC
                LIMIT :offset, :limit
            """),
            {**filters, "offset": (page - 1) * size, "limit": size},
        )
        rooms = [dict(row._mapping) for row in data_result.fetchall()]
    return total, rooms
async def search_users(
    keyword: str,
    current_user_id: str,
    is_admin: bool = False,
    page: int = 1,
    size: int = 10,
) -> Tuple[int, List[Dict]]:
    """Search users whose account contains *keyword*.

    Non-admin callers never see themselves in the result.

    Args:
        keyword: Substring to look for (wrapped in ``%`` for LIKE).
        current_user_id: Calling user; excluded from non-admin results.
        is_admin: Include the caller in the results when True.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.

    Returns:
        ``(total, users)``, newest first, each user with ``dept_name``.
    """
    # MySQL rejects negative LIMIT arguments, so clamp rather than fail.
    page = max(page, 1)
    size = max(size, 0)

    select_fields = "u.id, u.account, u.role, u.department_id, u.created_at, d.name AS dept_name"
    from_join = "FROM users u LEFT JOIN departments d ON u.department_id = d.id"
    where_clause = "(u.account LIKE :keyword)"
    filters: Dict = {"keyword": f"%{keyword}%"}
    if not is_admin:
        where_clause += " AND u.id != :current_user_id"
        filters["current_user_id"] = current_user_id

    async with get_default_db() as db:
        total_result = await db.execute(
            text(f"SELECT COUNT(*) AS total {from_join} WHERE {where_clause}"),
            filters,
        )
        total = total_result.scalar()

        data_result = await db.execute(
            text(f"""
                SELECT {select_fields} {from_join} WHERE {where_clause}
                ORDER BY u.created_at DESC LIMIT :offset, :limit
            """),
            {**filters, "offset": (page - 1) * size, "limit": size},
        )
        users = [dict(row._mapping) for row in data_result.fetchall()]
    return total, users
# -------------------------- 推荐功能 --------------------------
async def recommend_shares_by_keywords(
    keywords: List[str],
    limit: int = 5,
    exclude_user_id: Optional[str] = None,
    department_id: Optional[int] = None,
) -> List[Dict]:
    """Recommend public shares whose title or content matches any keyword.

    Keywords are reduced to ASCII letters, digits and CJK characters
    (U+4E00..U+9FA5) and truncated to 20 characters; keywords that end up
    empty are dropped, because ``LIKE '%%'`` would match every share.
    Results are ranked by the first matching keyword, then by like count.

    Args:
        keywords: Candidate keywords, most important first.
        limit: Maximum number of shares returned.
        exclude_user_id: When truthy, shares by this author are skipped.
        department_id: Accepted for interface compatibility; not used by
            the query.

    Returns:
        Share dicts with ``author_account`` and ``ai_char_name``; an empty
        list when no usable keyword remains.
    """
    # Sanitise first so the empty cases return without opening a session.
    cleaned = (
        re.sub(r"[^a-zA-Z0-9\u4e00-\u9fa5]", "", kw)[:20]
        for kw in (keywords or []) if kw
    )
    # dict.fromkeys removes duplicates while keeping the first-seen order.
    safe_keywords = list(dict.fromkeys(kw for kw in cleaned if kw))
    if not safe_keywords:
        return []

    count = len(safe_keywords)
    params: Dict = {f"k{i}": f"%{kw}%" for i, kw in enumerate(safe_keywords)}
    params["limit"] = limit

    conditions = " OR ".join(
        f"s.title LIKE :k{i} OR s.content LIKE :k{i}" for i in range(count)
    )

    filters = ""
    if exclude_user_id:
        filters = " AND s.author_id != :exclude_user_id"
        params["exclude_user_id"] = exclude_user_id

    # Rank expression: reuse the bound :k{i} parameters rather than
    # interpolating keyword text into the SQL. WHEN order is kept per
    # keyword (title, then content) as before.
    rank_parts = ["CASE"]
    for i in range(count):
        rank_parts.append(f"WHEN s.title LIKE :k{i} THEN {i}")
        rank_parts.append(f"WHEN s.content LIKE :k{i} THEN {i + count}")
    # The ELSE rank must sort after every real rank (max is 2 * count - 1).
    rank_parts.append(f"ELSE {max(99, 2 * count)} END")
    order_by = " ".join(rank_parts)

    async with get_default_db() as db:
        result = await db.execute(text(f"""
            SELECT s.*, u.account AS author_account, c.name AS ai_char_name
            FROM shares s
            JOIN users u ON s.author_id = u.id
            LEFT JOIN characters c ON s.ai_character_id = c.id
            WHERE ({conditions}) AND s.is_public = TRUE {filters}
            ORDER BY {order_by}, s.like_count DESC
            LIMIT :limit
        """), params)
        return [dict(row._mapping) for row in result.fetchall()]
async def get_dept_hot_shares(
    dept_id: int,
    limit: int = 5,
    exclude_user_id: Optional[str] = None,
) -> List[Dict]:
    """Return the most liked/viewed public shares written by one department's users.

    Args:
        dept_id: Department of the authors.
        limit: Maximum number of shares.
        exclude_user_id: When truthy, shares by this author are skipped.
    """
    async with get_default_db() as session:
        binds = {"dept_id": dept_id, "limit": limit}
        if exclude_user_id:
            exclude_clause = " AND s.author_id != :exclude_user_id"
            binds["exclude_user_id"] = exclude_user_id
        else:
            exclude_clause = ""
        query = text(f"""
            SELECT s.*, u.account AS author_account
            FROM shares s
            JOIN users u ON s.author_id = u.id
            WHERE u.department_id = :dept_id AND s.is_public = TRUE {exclude_clause}
            ORDER BY s.like_count DESC, s.view_count DESC
            LIMIT :limit
        """)
        outcome = await session.execute(query, binds)
        return [dict(entry) for entry in outcome.mappings().all()]
async def get_global_hot_shares(
    limit: int = 5,
    exclude_user_id: Optional[str] = None,
) -> List[Dict]:
    """Return the most liked/viewed public shares across all users.

    Args:
        limit: Maximum number of shares.
        exclude_user_id: When truthy, shares by this author are skipped.
    """
    async with get_default_db() as session:
        binds = {"limit": limit}
        if exclude_user_id:
            exclude_clause = " AND s.author_id != :exclude_user_id"
            binds["exclude_user_id"] = exclude_user_id
        else:
            exclude_clause = ""
        query = text(f"""
            SELECT s.*, u.account AS author_account
            FROM shares s
            JOIN users u ON s.author_id = u.id
            WHERE s.is_public = TRUE {exclude_clause}
            ORDER BY s.like_count DESC, s.view_count DESC
            LIMIT :limit
        """)
        outcome = await session.execute(query, binds)
        return [dict(entry) for entry in outcome.mappings().all()]
# -------------------------- 管理员统计 --------------------------
async def get_user_stats(start_date: date, end_date: date) -> Dict:
    """Collect user statistics for the admin dashboard.

    Args:
        start_date: First day (inclusive) of the "new users" window.
        end_date: Last day (inclusive) of the "new users" window.

    Returns:
        Dict with ``total_user``, ``new_user``, ``role_distribution`` and
        ``department_distribution`` (the ten departments with the most users).
    """
    new_users_sql = text("""
        SELECT COUNT(*) AS new_count
        FROM users
        WHERE DATE(created_at) BETWEEN :start AND :end
    """)
    per_dept_sql = text("""
        SELECT d.name AS dept_name, COUNT(u.id) AS user_count
        FROM departments d
        LEFT JOIN users u ON d.id = u.department_id
        GROUP BY d.id
        ORDER BY user_count DESC
        LIMIT 10
    """)
    async with get_default_db() as session:
        overall = (await session.execute(text("SELECT COUNT(*) AS total FROM users"))).scalar()

        window = {"start": start_date, "end": end_date}
        created_in_window = (await session.execute(new_users_sql, window)).scalar()

        by_role = await session.execute(text("SELECT role, COUNT(*) AS count FROM users GROUP BY role"))
        roles = [dict(entry) for entry in by_role.mappings().all()]

        by_dept = await session.execute(per_dept_sql)
        departments = [dict(entry) for entry in by_dept.mappings().all()]

        stats = {}
        stats["total_user"] = overall
        stats["new_user"] = created_in_window
        stats["role_distribution"] = roles
        stats["department_distribution"] = departments
        return stats
async def get_share_stats(start_date: date, end_date: date) -> Dict:
    """Collect share statistics for the admin dashboard.

    Args:
        start_date: First day (inclusive) of the "new shares" window.
        end_date: Last day (inclusive) of the "new shares" window.

    Returns:
        Dict with ``total_share``, ``new_share``, ``type_distribution``,
        ``ai_character_distribution`` (ten most shared characters) and
        ``total_interaction`` (``total_like`` / ``total_comment``, 0 when
        there are no shares).
    """
    async with get_default_db() as db:
        total_result = await db.execute(text("SELECT COUNT(*) AS total FROM shares"))
        total = total_result.scalar()

        new_result = await db.execute(
            text("""
                SELECT COUNT(*) AS new_count
                FROM shares
                WHERE DATE(created_at) BETWEEN :start AND :end
            """),
            {"start": start_date, "end": end_date},
        )
        new_count = new_result.scalar()

        type_result = await db.execute(text("SELECT type, COUNT(*) AS count FROM shares GROUP BY type"))
        type_dist = [dict(row._mapping) for row in type_result.fetchall()]

        ai_result = await db.execute(text("""
            SELECT c.name AS ai_char_name, COUNT(s.id) AS share_count
            FROM characters c
            LEFT JOIN shares s ON c.id = s.ai_character_id
            WHERE s.ai_character_id IS NOT NULL
            GROUP BY c.id
            ORDER BY share_count DESC
            LIMIT 10
        """))
        ai_dist = [dict(row._mapping) for row in ai_result.fetchall()]

        # SUM() over zero rows is NULL; COALESCE keeps the totals numeric so
        # consumers never receive None for an empty table.
        interact_result = await db.execute(
            text(
                "SELECT COALESCE(SUM(like_count), 0) AS total_like, "
                "COALESCE(SUM(comment_count), 0) AS total_comment FROM shares"
            )
        )
        interact = dict(interact_result.fetchone()._mapping)

        return {
            "total_share": total,
            "new_share": new_count,
            "type_distribution": type_dist,
            "ai_character_distribution": ai_dist,
            "total_interaction": interact,
        }
# -------------------------- 评论功能 --------------------------
async def create_comment(
    share_id: int,
    commenter_id: str,
    content: str,
    parent_id: Optional[int] = None,
    created_at: Optional[datetime] = None
) -> Dict:
    """Create a comment (optionally a reply) and return it with the commenter's account.

    Args:
        share_id: Share being commented on.
        commenter_id: Id of the commenting user.
        content: Comment text; leading/trailing whitespace is stripped.
        parent_id: Id of the comment being replied to, or None.
        created_at: Timestamp to store; defaults to the current UTC time.

    Returns:
        All ``comments`` columns plus ``commenter_account``.
    """
    insert_sql = text("""
        INSERT INTO comments (share_id, commenter_id, content, parent_id, created_at)
        VALUES (:share_id, :commenter_id, :content, :parent_id, :created_at)
    """)
    detail_sql = text("""
        SELECT c.*, u.account AS commenter_account
        FROM comments c
        JOIN users u ON c.commenter_id = u.id
        WHERE c.id = :comment_id
    """)
    async with get_default_db() as session:
        if not created_at:
            created_at = datetime.now(timezone.utc)
        inserted = await session.execute(
            insert_sql,
            {
                "share_id": share_id,
                "commenter_id": commenter_id,
                "content": content.strip(),
                "parent_id": parent_id,
                "created_at": created_at,
            },
        )
        new_id = inserted.lastrowid
        # Read the comment back within the same session.
        detail = await session.execute(detail_sql, {"comment_id": new_id})
        stored = detail.fetchone()
        return dict(stored._mapping)
async def get_comment_by_id(comment_id: int) -> Optional[Dict]:
    """Return all columns of one comment, or None when it does not exist."""
    query = text("""
        SELECT * FROM comments WHERE id = :comment_id
    """)
    async with get_default_db() as session:
        found = (await session.execute(query, {"comment_id": comment_id})).fetchone()
        if not found:
            return None
        return dict(found._mapping)
async def get_share_comments(
    share_id: int,
    page: int = 1,
    size: int = 20,
    order_by: str = "created_at ASC"
) -> Tuple[int, List[Dict]]:
    """Return one page of the comments (replies included) on a share.

    Args:
        share_id: Share whose comments are requested.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.
        order_by: ``"created_at ASC"`` or ``"created_at DESC"``; anything
            else falls back to ``"created_at ASC"``.

    Returns:
        ``(total, comments)`` where *total* counts all comments on the share.
    """
    # MySQL rejects negative LIMIT arguments, so clamp rather than fail.
    page = max(page, 1)
    size = max(size, 0)

    # The sort expression is interpolated into the SQL, so it is whitelisted.
    valid_order = ("created_at ASC", "created_at DESC")
    order_sql = order_by if order_by in valid_order else "created_at ASC"

    async with get_default_db() as db:
        # Total number of comments on the share.
        total_result = await db.execute(
            text("SELECT COUNT(*) AS total FROM comments WHERE share_id = :share_id"),
            {"share_id": share_id}
        )
        total = total_result.scalar()

        # Page of comments with the commenter's account. The sort column is
        # qualified because the joined users table also has created_at.
        data_result = await db.execute(
            text(f"""
                SELECT
                    c.id, c.content, c.created_at, c.parent_id,
                    u.id AS commenter_id, u.account AS commenter_account
                FROM comments c
                JOIN users u ON c.commenter_id = u.id
                WHERE c.share_id = :share_id
                ORDER BY c.{order_sql}
                LIMIT :offset, :limit
            """),
            {"share_id": share_id, "offset": (page - 1) * size, "limit": size}
        )
        comments = [dict(row._mapping) for row in data_result.fetchall()]
    return total, comments
# -------------------------- 点赞功能 --------------------------
async def check_share_like(share_id: int, user_id: str) -> bool:
    """Return True when the user has already liked the share."""
    probe = text("""
        SELECT 1 FROM share_likes
        WHERE share_id = :share_id AND user_id = :user_id
    """)
    async with get_default_db() as session:
        outcome = await session.execute(probe, {"share_id": share_id, "user_id": user_id})
        hit = outcome.scalar()
        return hit is not None
async def add_share_like(share_id: int, user_id: str) -> None:
    """Record a like using ``INSERT IGNORE`` with ``created_at`` set to NOW()."""
    statement = text("""
        INSERT IGNORE INTO share_likes (share_id, user_id, created_at)
        VALUES (:share_id, :user_id, NOW())
    """)
    binds = {"share_id": share_id, "user_id": user_id}
    async with get_default_db() as session:
        await session.execute(statement, binds)
async def remove_share_like(share_id: int, user_id: str) -> None:
    """Remove the user's like from the share, if present."""
    statement = text("""
        DELETE FROM share_likes
        WHERE share_id = :share_id AND user_id = :user_id
    """)
    binds = {"share_id": share_id, "user_id": user_id}
    async with get_default_db() as session:
        await session.execute(statement, binds)
# -------------------------- AI聊天对话记录操作 --------------------------
async def save_conversation(
    user_id: int,
    character_id: int,
    user_message: str,
    ai_reply: str,
    timestamp: Optional[datetime] = None
) -> None:
    """Store one user/AI exchange in the user's dedicated conversation table.

    Args:
        user_id: User id (integer); determines the table name.
        character_id: AI character id.
        user_message: Text the user sent.
        ai_reply: Text the AI answered.
        timestamp: Time of the exchange; defaults to the current UTC time.

    Raises:
        ValueError: If the derived table name fails validation.
        Exception: Database errors are logged and re-raised unchanged.
    """
    # The module-level ``logger`` name comes from ``from fastapi import logger``,
    # which binds the ``fastapi.logger`` *module*; calling ``.error`` on it
    # raises AttributeError and would mask the real database error. Use a
    # standard-library logger instead.
    import logging
    log = logging.getLogger(__name__)

    async with get_default_db() as db:
        # 1. Resolve and validate the per-user table name (it is interpolated
        #    into the SQL below).
        table_name = get_conversation_table_name(user_id)
        if not is_valid_table_name(table_name):
            raise ValueError(f"Invalid user ID: {user_id}")

        # 2. Make sure the table exists (defensive check).
        try:
            await db.execute(text(f"""
                CREATE TABLE IF NOT EXISTS `{table_name}` (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    character_id INT NOT NULL,
                    user_message TEXT NOT NULL,
                    ai_message TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (character_id) REFERENCES characters(id) ON DELETE CASCADE
                ) ENGINE=InnoDB CHARSET=utf8mb4;
            """))
        except Exception as e:
            log.error(f"Failed to ensure table {table_name} exists: {e}")
            raise

        # 3. Insert the conversation record.
        try:
            await db.execute(
                text(f"""
                    INSERT INTO `{table_name}`
                    (character_id, user_message, ai_message, timestamp)
                    VALUES (:char_id, :user_msg, :ai_msg, :ts)
                """),
                {
                    "char_id": character_id,
                    "user_msg": user_message,
                    "ai_msg": ai_reply,
                    "ts": timestamp or datetime.now(timezone.utc)
                }
            )
        except Exception as e:
            log.error(f"Failed to insert into {table_name}: {e}")
            raise
async def get_all_characters() -> List[Dict]:
    """List every AI character, newest first.

    Returns:
        List[Dict]: dicts with id, name, trait, avatar_url and created_at.
    """
    query = text("""
        SELECT
            id,
            name,
            trait,
            avatar_url,
            created_at
        FROM characters
        ORDER BY created_at DESC
    """)
    async with get_default_db() as session:
        outcome = await session.execute(query)
        characters = []
        for record in outcome.fetchall():
            characters.append(dict(record._mapping))
        return characters
你新给的数据库代码和原来的代码相比,是不是少了很多函数?请把它们补全,重新给我一份简化后的代码。
最新发布