# backend/database.py
import json
import os
import urllib.parse
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Optional, List, Dict, Any

from sqlalchemy import DateTime, Index, String, Text, text, select, update, delete, Column, func
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.sql.expression import true
from sqlmodel import SQLModel, Field, Column as SQLColumn
# Timezone-aware "now" helper used as the default_factory of timestamp fields.
def utc_now():
    """Return the current moment as an aware ``datetime`` in UTC."""
    return datetime.now(tz=timezone.utc)
# ==================== 数据库配置 ====================
# Connection DSN. Read from the DATABASE_URL environment variable so real
# credentials need not be committed; the fallback is the previous
# local-development value, so behaviour is unchanged when the variable is unset.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "mysql+asyncmy://root:123456@localhost/ai_roleplay?charset=utf8mb4",
)

engine = create_async_engine(
    DATABASE_URL,
    echo=False,  # set to True to log emitted SQL while debugging
    pool_pre_ping=True,  # test pooled connections before handing them out
    pool_size=5,
    max_overflow=10,
)

AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    # Do not expire loaded attributes on commit, so ORM objects stay readable
    # afterwards without another round-trip.
    expire_on_commit=False,
)
# ==================== 数据模型定义 ====================
class User(SQLModel, table=True):
    """Account row stored in the ``users`` table."""

    __tablename__ = "users"

    id: Optional[int] = Field(default=None, primary_key=True)
    account: str = Field(index=True, unique=True, min_length=3, max_length=50)
    password_hash: str = Field(min_length=60, max_length=128)
    nickname: Optional[str] = Field(default=None, max_length=50)
    avatar_url: str = Field(default="/static/default_avatar.png", max_length=255)
    email: Optional[str] = Field(default=None, unique=True, max_length=100)
    # Stored as a bounded VARCHAR: the set of permitted role values is not
    # defined in this module. Switch to sqlalchemy.Enum(...) once it is.
    role: str = Field(
        default="user",
        sa_column=Column("role", String(20), nullable=False, default="user"),
    )
    created_at: datetime = Field(default_factory=utc_now)
    # Set from utc_now() on insert; on UPDATE the database writes its own
    # CURRENT_TIMESTAMP. NOTE(review): that is DB-session time, not
    # necessarily UTC -- confirm the MySQL time_zone setting.
    updated_at: datetime = Field(
        default_factory=utc_now,
        sa_column=SQLColumn(
            "updated_at",
            DateTime,
            default=func.current_timestamp(),
            onupdate=func.current_timestamp(),
        ),
    )
class Department(SQLModel, table=True):
    """Department row stored in the ``departments`` table."""

    __tablename__ = "departments"

    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(unique=True, max_length=100)
    # TEXT: a bare ``str`` field becomes VARCHAR(255) on MySQL, which would
    # cap free-form descriptions.
    description: Optional[str] = Field(
        default=None, sa_column=Column("description", Text)
    )
    created_at: datetime = Field(default_factory=utc_now)
class UserDepartmentLink(SQLModel, table=True):
    """Link table between users and departments (``user_departments``).

    The composite primary key (user_id, dept_id) allows each user/department
    pair to appear only once.
    """
    __tablename__ = "user_departments"
    user_id: int = Field(foreign_key="users.id", primary_key=True)
    dept_id: int = Field(foreign_key="departments.id", primary_key=True)
    # When the membership row was created (aware UTC datetime from utc_now).
    joined_at: datetime = Field(default_factory=utc_now)
class Character(SQLModel, table=True):
    """Role-play character stored in the ``characters`` table."""

    __tablename__ = "characters"

    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(max_length=100, nullable=False)
    # Prompt description. TEXT because a bare ``str`` becomes VARCHAR(255) on
    # MySQL, which is too short for typical prompts.
    trait: str = Field(sa_column=Column("trait", Text, nullable=False))
    created_at: datetime = Field(default_factory=utc_now)
    # Set from utc_now() on insert; the database's CURRENT_TIMESTAMP on UPDATE.
    updated_at: datetime = Field(
        default_factory=utc_now,
        sa_column=SQLColumn(
            "updated_at",
            DateTime,
            default=func.current_timestamp(),
            onupdate=func.current_timestamp(),
        ),
    )
class UserProfile(SQLModel, table=True):
    """Per-user persona settings; ``user_id`` is the primary key (one row per user)."""

    __tablename__ = "user_profiles"

    user_id: int = Field(foreign_key="users.id", primary_key=True)
    # TEXT: these hold prompt-like free text, and a bare ``str`` would become
    # VARCHAR(255) on MySQL.
    personality: Optional[str] = Field(
        default=None, sa_column=Column("personality", Text)
    )
    role_setting: Optional[str] = Field(
        default=None, sa_column=Column("role_setting", Text)
    )
    max_history: int = Field(default=4)
    created_at: datetime = Field(default_factory=utc_now)
    # Set from utc_now() on insert; the database's CURRENT_TIMESTAMP on UPDATE.
    updated_at: datetime = Field(
        default_factory=utc_now,
        sa_column=SQLColumn(
            "updated_at",
            DateTime,
            default=func.current_timestamp(),
            onupdate=func.current_timestamp(),
        ),
    )
class Post(SQLModel, table=True):
    """Blog post stored in the ``posts`` table."""

    __tablename__ = "posts"

    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="users.id")
    title: str = Field(max_length=200)
    # LONGTEXT on MySQL (plain TEXT on other backends) so full post bodies fit.
    content: str = Field(
        sa_column=Column(
            "content", Text().with_variant(LONGTEXT(), "mysql"), nullable=False
        )
    )
    summary: Optional[str] = Field(default=None, max_length=500)
    cover_image: Optional[str] = Field(default=None, max_length=255)
    # NOTE(review): SQLModel table models do not validate on construction, so
    # this pattern is documentation only unless enforced elsewhere.
    visibility: str = Field(default="public", regex="^(public|private|friends)$")
    view_count: int = Field(default=0)
    like_count: int = Field(default=0)
    comment_count: int = Field(default=0)
    tags: Optional[str] = Field(default=None)  # JSON string, e.g. '["AI", "编程"]'
    is_pinned: bool = Field(default=False)
    created_at: datetime = Field(default_factory=utc_now)
    # Set from utc_now() on insert; the database's CURRENT_TIMESTAMP on UPDATE.
    updated_at: datetime = Field(
        default_factory=utc_now,
        sa_column=SQLColumn(
            "updated_at",
            DateTime,
            default=func.current_timestamp(),
            onupdate=func.current_timestamp(),
        ),
    )

    # Secondary indexes backing the feed sort orders.
    __table_args__ = (
        Index("idx_created_at", "created_at"),
        Index("idx_like_count", "like_count"),
        Index("idx_view_count", "view_count"),
    )
class Comment(SQLModel, table=True):
    """Comment on a post, stored in the ``comments`` table."""

    __tablename__ = "comments"

    id: Optional[int] = Field(default=None, primary_key=True)
    post_id: int = Field(foreign_key="posts.id")
    user_id: int = Field(foreign_key="users.id")
    # Self-reference to another comment; presumably None for top-level comments.
    parent_id: Optional[int] = Field(default=None, foreign_key="comments.id")
    # TEXT: a bare ``str`` becomes VARCHAR(255) on MySQL, capping comment length.
    content: str = Field(sa_column=Column("content", Text, nullable=False))
    like_count: int = Field(default=0)
    created_at: datetime = Field(default_factory=utc_now)

    __table_args__ = (
        Index("idx_post_id", "post_id"),
        Index("idx_parent_id", "parent_id"),
    )
class Like(SQLModel, table=True):
    """A user's like on a post or comment, stored in the ``likes`` table.

    The composite primary key (user_id, target_type, target_id) allows at most
    one row per user per target.
    """
    __tablename__ = "likes"
    user_id: int = Field(primary_key=True, foreign_key="users.id")
    # "post" or "comment" per the pattern. NOTE(review): SQLModel table models
    # skip validation on construction, so the pattern is not enforced there.
    target_type: str = Field(primary_key=True, regex="^(post|comment)$")
    # Id of the liked post or comment (depending on target_type). No foreign key
    # is declared, so the database does not enforce that the target exists.
    target_id: int = Field(primary_key=True)
    created_at: datetime = Field(default_factory=utc_now)
    __table_args__ = (
        # Secondary index for lookups by target regardless of user.
        Index("idx_target", "target_type", "target_id"),
    )
class SearchTrend(SQLModel, table=True):
    """Per-keyword search counter stored in ``search_trends``."""

    __tablename__ = "search_trends"

    keyword: str = Field(primary_key=True, max_length=100)
    hit_count: int = Field(default=1)
    # The DB column name matches the attribute name. NOTE: a table created from
    # an earlier version of this model has this column named "updated_at" and
    # needs a rename migration.
    last_searched_at: datetime = Field(
        default_factory=utc_now,
        sa_column=SQLColumn(
            "last_searched_at",
            DateTime,
            default=func.current_timestamp(),
            onupdate=func.current_timestamp(),
        ),
    )
    is_hot: bool = Field(default=False)
class ChatSession(SQLModel, table=True):
    """A chat conversation owned by a user, stored in ``chat_sessions``."""
    __tablename__ = "chat_sessions"
    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="users.id")
    session_name: str = Field(default="未命名对话", max_length=100)
    # Optional foreign key to characters.id.
    character_id: Optional[int] = Field(default=None, foreign_key="characters.id")
    # Required (no default). NOTE(review): SQLModel table models skip validation
    # on construction, so the pattern is not enforced there.
    chat_type: str = Field(regex="^(ai_room|multi_user_chat|roleplay)$")
    created_at: datetime = Field(default_factory=utc_now)
    # Set from utc_now() on insert; on UPDATE the database writes its own
    # CURRENT_TIMESTAMP. NOTE(review): that is DB-session time, not
    # necessarily UTC -- confirm the MySQL time_zone setting.
    updated_at: datetime = Field(default_factory=utc_now,
        sa_column=SQLColumn(
            "updated_at",
            DateTime,
            default=func.current_timestamp(),
            onupdate=func.current_timestamp())
    )
    __table_args__ = (
        Index("idx_user_id", "user_id"),
        # Backs the "most recent session" ordering by updated_at.
        Index("idx_updated_at", "updated_at"),
    )
class Message(SQLModel, table=True):
    """Single chat message belonging to a ``ChatSession``."""

    __tablename__ = "messages"

    id: Optional[int] = Field(default=None, primary_key=True)
    session_id: int = Field(foreign_key="chat_sessions.id")
    sender_type: str = Field(regex="^(user|ai|system)$")
    # save_message_to_db sets this only for user-sent messages.
    sender_id: Optional[int] = Field(default=None, foreign_key="users.id")
    # TEXT: a bare ``str`` becomes VARCHAR(255) on MySQL, which truncates or
    # rejects longer AI replies.
    content: str = Field(sa_column=Column("content", Text, nullable=False))
    timestamp: datetime = Field(default_factory=utc_now)

    __table_args__ = (
        Index("idx_session_id", "session_id"),
        Index("idx_timestamp", "timestamp"),
    )
# ==================== 工具函数:上下文管理器 ====================
@asynccontextmanager
async def get_db():
    """Provide an ``AsyncSession`` wrapped in a unit of work.

    When the ``async with`` body completes normally, the session is committed.
    If the body, or that commit, raises, the session is rolled back and the
    exception is re-raised. The session is closed when the scope exits.
    """
    async with AsyncSessionLocal() as db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
# ==================== 用户相关操作 ====================
async def check_users(account: str, password_hash: str) -> tuple[int, str]:
    """Find ``account``, registering it on first sight.

    For an existing account, returns ``(id, stored_password_hash)``. Otherwise
    creates the user, a default ``UserProfile`` and a first ``roleplay`` chat
    session, and returns ``(new_id, password_hash)``.
    """
    async with get_db() as db:
        lookup = await db.execute(select(User).where(User.account == account))
        existing = lookup.scalar_one_or_none()
        if existing is not None:
            return existing.id, existing.password_hash

        fresh_user = User(account=account, password_hash=password_hash)
        db.add(fresh_user)
        await db.flush()  # populates the auto-increment id
        new_id = fresh_user.id

        db.add_all([
            UserProfile(user_id=new_id, role_setting="你正在扮演一位聪明、幽默又略带毒舌的程序员助手。"),
            ChatSession(user_id=new_id, session_name="我的第一个 AI 对话", chat_type="roleplay"),
        ])
        await db.commit()
        return new_id, password_hash
async def get_user_profile(user_id: int) -> Optional[Dict[str, Any]]:
    """Return the profile of ``user_id`` as a plain dict, or ``None`` if absent."""
    async with get_db() as db:
        query = select(UserProfile).where(UserProfile.user_id == user_id)
        profile = (await db.execute(query)).scalar_one_or_none()
        if profile is None:
            return None
        return profile.model_dump()
async def create_or_update_user_profile(user_id: int, personality: str, role_setting: str) -> bool:
    """Insert or overwrite the profile of ``user_id``.

    Both text arguments are stored with surrounding whitespace stripped.
    Always returns ``True``; database errors propagate.
    """
    async with get_db() as db:
        query = select(UserProfile).where(UserProfile.user_id == user_id)
        current = (await db.execute(query)).scalar_one_or_none()
        clean_personality = personality.strip()
        clean_role = role_setting.strip()
        if current is None:
            db.add(UserProfile(
                user_id=user_id,
                personality=clean_personality,
                role_setting=clean_role,
            ))
        else:
            current.personality = clean_personality
            current.role_setting = clean_role
        await db.commit()
    return True
# ==================== 角色相关 ====================
async def get_all_characters() -> List[Dict[str, Any]]:
    """Return every character row as a plain dict (no filter, no explicit order)."""
    async with get_db() as db:
        characters = (await db.execute(select(Character))).scalars().all()
        return [character.model_dump() for character in characters]
async def get_character_by_id(character_id: int) -> Optional[Dict[str, Any]]:
    """Return the character with primary key ``character_id`` as a dict, or ``None``."""
    async with get_db() as db:
        query = select(Character).where(Character.id == character_id)
        character = (await db.execute(query)).scalar_one_or_none()
        if character is None:
            return None
        return character.model_dump()
# ==================== 聊天历史与消息 ====================
async def load_history_from_db(user_id: int, max_count: int = 4) -> List[Dict[str, str]]:
    """Return the last ``max_count`` messages of the user's newest chat session.

    Args:
        user_id: owner of the chat sessions.
        max_count: maximum number of messages to return. A non-positive value
            returns an empty list.

    Returns:
        Messages ordered oldest-first, as ``{"role": sender_type, "content": ...}``
        dicts. ``role`` is the stored ``Message.sender_type`` unchanged. An empty
        list is returned when the user has no session.
    """
    if max_count <= 0:
        return []

    async with get_db() as db:
        session_result = await db.execute(
            select(ChatSession.id)
            .where(ChatSession.user_id == user_id)
            # updated_at is a second-resolution DATETIME on MySQL; break ties by
            # id so the same "latest" session is picked every time.
            .order_by(ChatSession.updated_at.desc(), ChatSession.id.desc())
            .limit(1)
        )
        session_id = session_result.scalar()
        if not session_id:
            return []

        result = await db.execute(
            select(Message.sender_type, Message.content)
            .where(Message.session_id == session_id)
            # A user message and its reply are often stored within the same
            # second and so share a timestamp; the auto-increment id keeps
            # them in insertion order.
            .order_by(Message.timestamp.desc(), Message.id.desc())
            .limit(max_count)
        )
        rows = result.all()

    # Rows arrive newest-first; reverse so the history reads oldest-first.
    return [
        {"role": sender_type, "content": content}
        for sender_type, content in reversed(rows)
    ]
async def save_message_to_db(user_id: int, role: str, content: str) -> None:
    """Append a message to the user's newest chat session.

    A ``roleplay`` session is created first if the user has none. ``role``
    values "system" and "ai" are stored as the sender type. Any other value is
    stored as "user", and only user messages get ``sender_id = user_id``.
    """
    sender_type = role if role in ("system", "ai") else "user"
    sender_id = user_id if sender_type == "user" else None

    async with get_db() as db:
        session_result = await db.execute(
            select(ChatSession.id)
            .where(ChatSession.user_id == user_id)
            # updated_at is a second-resolution DATETIME on MySQL; break ties by
            # id so the newest session is chosen deterministically.
            .order_by(ChatSession.updated_at.desc(), ChatSession.id.desc())
            .limit(1)
        )
        session_id = session_result.scalar()
        if not session_id:
            new_session = ChatSession(user_id=user_id, chat_type="roleplay")
            db.add(new_session)
            await db.flush()  # populates new_session.id
            session_id = new_session.id

        db.add(Message(
            session_id=session_id,
            sender_type=sender_type,
            sender_id=sender_id,
            content=content,
        ))
        await db.commit()
# ==================== 博客相关 ====================
async def create_post(user_id: int, title: str, content: str, tags: Optional[List[str]] = None) -> int:
    """Insert a new post and return its id.

    Args:
        user_id: author id.
        title: post title.
        content: full body. Its start, stripped of surrounding whitespace, is
            also stored as ``summary``.
        tags: optional tag list, stored as a JSON string. ``None`` or an empty
            list stores NULL.

    Returns:
        The auto-generated ``Post.id``.
    """
    summary_limit = 500  # Post.summary max_length
    ellipsis = "..."
    if len(content) > summary_limit:
        # Reserve room for the ellipsis: content[:500] + "..." would be 503
        # characters, which a VARCHAR(500) column rejects in MySQL strict mode.
        summary = content[: summary_limit - len(ellipsis)].strip() + ellipsis
    else:
        summary = content.strip()

    async with get_db() as db:
        post = Post(
            user_id=user_id,
            title=title,
            content=content,
            summary=summary,
            tags=json.dumps(tags, ensure_ascii=False) if tags else None,
        )
        db.add(post)
        await db.commit()
        await db.refresh(post)
        return post.id
async def get_posts(page: int = 1, page_size: int = 10, order_by: str = "latest") -> List[Dict[str, Any]]:
    """Return one page of public posts with the author's nickname and avatar.

    Args:
        page: 1-based page number. Values below 1 are treated as 1 (previously
            a negative OFFSET produced a SQL error).
        page_size: posts per page. A non-positive size returns an empty list.
        order_by: ``"hot"`` sorts by like count, then newest first. Any other
            value sorts newest first.

    Returns:
        ``Post.model_dump()`` dicts, each extended with ``author_nickname`` and
        ``author_avatar``.
    """
    if page_size <= 0:
        return []
    page = max(page, 1)
    offset = (page - 1) * page_size

    # created_at is a second-resolution DATETIME on MySQL, so id is the final
    # tie-breaker. A total order keeps page boundaries stable, with no
    # duplicated or skipped posts across pages.
    if order_by == "hot":
        ordering = (Post.like_count.desc(), Post.created_at.desc(), Post.id.desc())
    else:
        ordering = (Post.created_at.desc(), Post.id.desc())

    stmt = (
        select(Post, User.nickname, User.avatar_url)
        .join(User, User.id == Post.user_id)
        .where(Post.visibility == "public")
        .order_by(*ordering)
        .offset(offset)
        .limit(page_size)
    )

    async with get_db() as db:
        rows = (await db.execute(stmt)).all()
        return [
            {**post.model_dump(), "author_nickname": nickname, "author_avatar": avatar}
            for post, nickname, avatar in rows
        ]
# ==================== 初始化数据库 ====================
async def create_tables():
    """Run ``SQLModel.metadata.create_all`` for every registered model inside one transaction."""
    async with engine.begin() as connection:
        await connection.run_sync(SQLModel.metadata.create_all)
    print("✅ 所有数据库表已创建完成!")
# Optional shutdown hook: releases the engine's pooled connections.
async def close_db():
    """Dispose of the shared async engine and its connection pool."""
    await engine.dispose()
# 检查一下修改正确与否,不需要优化
# 最新发布