import asyncio
import httpx
import json
import os
import uvicorn
import torch
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.security import OAuth2PasswordBearer
from jinja2 import Environment, FileSystemLoader
from dotenv import load_dotenv
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import JWT_token
#import redis
# 导入数据库操作
import database
BASE_DIR = Path(__file__).resolve().parent.parent  # project root (parent of backend/)
FRONTEND_DIR = BASE_DIR / "frontend"

app = FastAPI()

# Load environment variables (e.g. DASHSCOPE_API_KEY) from backend/.env.
# Built with pathlib instead of the original r"backend\.env" so the path
# also resolves on non-Windows hosts.
load_dotenv(dotenv_path=BASE_DIR / "backend" / ".env")

# Serve the frontend directory both as static assets and as the template root.
app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")
templates = Jinja2Templates(directory=str(FRONTEND_DIR))
# Jinja2 environment used to render HTML pages manually.
template_env = Environment(loader=FileSystemLoader(str(FRONTEND_DIR)))

# NOTE(review): the original called JWT_token.create_access_token() here with
# no arguments and discarded the result — a no-op (or a TypeError) at import
# time. Removed; mint tokens inside the login handler instead.

# Model/tokenizer are populated lazily by start_load().
model, tokenizer = None, None

# --- Disabled redis-backed session helpers (kept for future session/JWT work;
# --- they need `redis`, `uuid4` and `datetime` imports before re-enabling) ---
# r = redis.Redis(host='localhost', port=6379, db=0)
# def create_session(user_id: int):  # create a session
#     session_id = str(uuid4())
#     session_data = {
#         "user_id": user_id,
#         "created_at": datetime.now().isoformat()
#     }
#     r.setex(session_id, 3600, json.dumps(session_data))  # expires in 1 hour
#     return session_id
# def get_session(session_id: str):  # fetch the stored session data
#     if not session_id:
#         return None
#     data = r.get(f"session:{session_id}")
#     if data:
#         return json.loads(data)
#     return None
# def destroy_session(session_id: str):  # destroy a session (used on logout)
#     if session_id:
#         r.delete(f"session:{session_id}")
# def get_current_user(request: Request):  # resolve the logged-in user
#     session_id = request.cookies.get("session_id")
#     return get_session(session_id)
def load_model():
    """Synchronously load the local DeepSeek model and tokenizer (blocking).

    Returns:
        (model, tokenizer) tuple with the model in eval mode.
    """
    # pathlib join instead of the original "model\deepseek-coder-1.3b-instruct":
    # that literal contained the invalid escape "\d" (SyntaxWarning, a future
    # error) and only worked on Windows.
    model_dir = str(BASE_DIR / "model" / "deepseek-coder-1.3b-instruct")
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(model_dir)
    print("Loading model...")
    m = AutoModelForCausalLM.from_pretrained(
        model_dir,
        # fp16 on GPU, fp32 on CPU
        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",  # let accelerate place layers automatically
        low_cpu_mem_usage=True,
    ).eval()
    return m, tok
async def start_load():
    """Load the model without blocking the event loop."""
    global model, tokenizer
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() here has been deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    # Run the blocking HF load in the default thread-pool executor.
    model, tokenizer = await loop.run_in_executor(None, load_model)
    print("✅ Model loaded during startup!")
def shutdown_event():
    """Release the model/tokenizer and return cached CUDA memory at shutdown.

    Rebinds the globals to None instead of the original `del`: deleting a
    global unbinds the name, so any late request would hit NameError instead
    of the `model is None` guard in the chat handler.
    """
    global model, tokenizer
    # Drop the only strong references so the large objects can be collected.
    model, tokenizer = None, None
    if torch.cuda.is_available():
        # Hand cached GPU memory back to the driver.
        torch.cuda.empty_cache()
    print("👋 Cleaned up model and CUDA cache on shutdown.")
@app.get("/", response_class=HTMLResponse)
async def home():
    """Render and return the landing page (home.html)."""
    page = template_env.get_template("home.html").render()
    return HTMLResponse(content=page)
@app.get("/login", response_class=HTMLResponse)
async def login():
    """Render and return the login page (login.html)."""
    page = template_env.get_template("login.html").render()
    return HTMLResponse(content=page)
@app.post("/login")
async def checklogin(request: Request):
    """Validate JSON credentials {"account", "password"}; reply {"success": bool}.

    Fixes vs. original:
    - removed `response_class=bool` — not a Response subclass, so invalid.
    - hashes the password string instead of the one-element list [password].
    - an unknown account now yields success=False instead of success=True.
    """
    data = await request.json()
    account = data.get("account")
    password = data.get("password")
    # NOTE(review): this hash is handed to check_users, but real verification
    # happens below against the stored hash — confirm check_users ignores it.
    hash_password = JWT_token.pwd_context.hash(password)
    ID, data_password = database.check_users(account, hash_password)
    if data_password:
        # verify() re-hashes `password` using the salt embedded in the stored hash.
        is_correct = JWT_token.pwd_context.verify(password, data_password)
        return JSONResponse({"success": is_correct})
    # No stored hash for this account → authentication failed.
    return JSONResponse({"success": False})
@app.get("/user1", response_class=HTMLResponse)
async def chat1(request: Request):
    """Serve the local-model chat page, lazily loading the model.

    Fixes vs. original:
    - `ID` was an undefined name (session lookup is commented out), so every
      request crashed with NameError.
    - asyncio.run() raises inside a running event loop; the coroutine is
      awaited directly, and only when the model is not yet loaded.
    - The shutdown handler was re-registered on every visit.
    - The failure branch returned a raw Template (get_template takes no
      context); it now renders to an HTMLResponse.
    """
    # NOTE(review): real session validation is commented out; the mere
    # presence of the session_id cookie is a stand-in — TODO restore
    # get_session()/JWT verification before shipping.
    ID = request.cookies.get("session_id")
    if ID:
        template = template_env.get_template("myapp1.html")
        if model is None or tokenizer is None:
            await start_load()
        # Register the cleanup hook once, not once per request.
        if shutdown_event not in app.router.on_shutdown:
            app.add_event_handler("shutdown", shutdown_event)
        return HTMLResponse(content=template.render())
    # Not logged in: show the login page with a hint flag for the template.
    context = {
        "request": request,
        "login-hint": True,
    }
    return HTMLResponse(content=template_env.get_template("login.html").render(context))
@app.post("/user1/chat")
async def deepseek_chat(request: Request):
    """Answer a chat message with the locally loaded DeepSeek model.

    Expects JSON {"message": str}; returns {"reply": str} or an error payload.
    Fix vs. original: `ID` was an undefined name, so every successful
    generation still ended in a NameError → 500 when saving the conversation.
    """
    data = await request.json()
    # NOTE(review): session lookup is commented out; the raw cookie value is a
    # placeholder user id for save_conversation — TODO restore real auth.
    ID = request.cookies.get("session_id")
    user_message = data.get("message")
    try:
        global model, tokenizer
        if model is None or tokenizer is None:
            return JSONResponse({"error": "模型尚未加载,请先启动模型"}, status_code=500)
        # Single-turn conversation; the chat template wraps it in the
        # model-specific prompt format.
        messages = [
            {"role": "user", "content": user_message}
        ]
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True  # tell the model to produce the assistant turn
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        # NOTE(review): generate() blocks the event loop; consider
        # run_in_executor if concurrent requests matter.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.85,
                top_p=0.95,
                do_sample=True,
                repetition_penalty=1.1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id  # decoder-only models define no pad token
            )
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
        # Keep only the assistant turn: split on the marker the chat template
        # added, or fall back to slicing off the prompt text.
        if "<|assistant|>" in full_response:
            reply = full_response.split("<|assistant|>")[-1].strip()
        else:
            reply = full_response[len(input_text):].strip()
        # Trim DeepSeek's end-of-turn token and anything after it.
        eot_token = "<|EOT|>"
        if eot_token in reply:
            reply = reply.split(eot_token)[0].strip()
        # Persist the exchange (character_id 0 = the plain assistant).
        database.save_conversation(ID, 0, user_message, reply)
        return JSONResponse({"reply": reply})
    except Exception as e:
        import traceback
        error_msg = traceback.format_exc()  # full traceback for the error payload
        return JSONResponse({"error": f"推理失败: {str(e)}", "detail": error_msg}, status_code=500)
@app.get("/user2", response_class=HTMLResponse)
async def chat2(request: Request):
    """Serve the role-play page with the character list embedded.

    Fixes vs. original: `ID` was an undefined name (NameError on every
    request), and the failure branch returned a raw Template object from an
    invalid get_template(name, context) call.
    """
    # NOTE(review): placeholder for the commented-out session lookup — TODO
    # restore get_session()/JWT verification.
    ID = request.cookies.get("session_id")
    if ID:
        template = template_env.get_template("myapp2.html")
        characters = database.get_all_characters()
        # JSON copy for the frontend script; ensure_ascii=False keeps CJK readable.
        characters_json = json.dumps(
            [{"id": c["id"], "name": c["name"], "trait": c["trait"]} for c in characters],
            ensure_ascii=False,
        )
        content = template.render(characters=characters, characters_json=characters_json)
        return HTMLResponse(content=content)
    # Not logged in: show the login page with a hint flag for the template.
    context = {
        "request": request,
        "login-hint": True,
    }
    return HTMLResponse(content=template_env.get_template("login.html").render(context))
@app.post("/user2/chat")
async def dashscope_chat(request: Request):
    """Role-play chat endpoint backed by the DashScope (Qwen) API.

    Expects JSON {"character_id", "message"}; returns {"reply": str}.
    Fixes vs. original: `ID` was an undefined name (NameError when saving the
    conversation), and the DB connection leaked if the character query raised.
    """
    # NOTE(review): session lookup is commented out; the raw cookie value is a
    # placeholder user id for save_conversation — TODO restore real auth.
    ID = request.cookies.get("session_id")
    data = await request.json()
    character_id = data.get("character_id")
    user_message = data.get("message")
    # Fetch the character's name/persona; close the connection even on error.
    conn = database.get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT name, trait FROM characters WHERE id = %s", (character_id,))
        character = cursor.fetchone()
        cursor.close()
    finally:
        conn.close()
    if not character:
        return JSONResponse({"error": "角色不存在"}, status_code=404)
    # System prompt that pins the model to the selected persona.
    system_prompt = f"""
你正在扮演 {character['name']}。
人物设定:{character['trait']}
请始终以这个角色的身份、语气和思维方式回答问题。
不要暴露你是 AI 的事实。
""".strip()
    try:
        API_KEY = os.getenv("DASHSCOPE_API_KEY")
        if not API_KEY:
            return JSONResponse({"error": "API密钥未配置"}, status_code=500)
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
            "Accept": "application/json"
        }
        payload = {
            "model": "qwen-plus",
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message}
            ],
            "temperature": 0.85,
            "top_p": 0.95,
            "max_tokens": 512,
            "stream": False
        }
        # OpenAI-compatible chat completions endpoint.
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
                json=payload,
                headers=headers,
                timeout=30.0
            )
        if response.status_code != 200:
            error_detail = response.text
            return JSONResponse(
                {"error": f"远程API错误 [{response.status_code}]", "detail": error_detail},
                status_code=response.status_code
            )
        result = response.json()
        reply = result["choices"][0]["message"]["content"].strip()
        # Persist the exchange under the chosen character.
        database.save_conversation(ID, character_id, user_message, reply)
        return JSONResponse({"reply": reply})
    except Exception as e:
        import traceback
        error_msg = traceback.format_exc()  # full traceback for the error payload
        return JSONResponse(
            {"error": f"请求失败: {str(e)}", "detail": error_msg},
            status_code=500
        )
if __name__ == "__main__":  # Start the development server with auto-reload.
    # NOTE(review): "myapp:app" assumes this file is saved as myapp.py — confirm.
    uvicorn.run("myapp:app", host="127.0.0.1", port=8000, reload=True)
# TODO: add a JWT-based auth system to this code.
# (This line was stray prose — "add a JWT system to my code" — which Python
# parsed as a bare identifier and raised NameError at import.)