import asyncio
import json
import logging
import os
import time
import traceback
from datetime import timedelta
from pathlib import Path

import httpx
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Request, Response
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from jinja2 import Environment, FileSystemLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from jwt_handler import (
    get_current_user_id,
    verify_password,
    get_password_hash,
    create_access_token,
)
import database
BASE_DIR = Path(__file__).resolve().parent.parent  # project root
FRONTEND_DIR = BASE_DIR / "frontend"

# Load backend/.env before anything below reads configuration (e.g. LOG_LEVEL).
load_dotenv(dotenv_path=BASE_DIR / "backend" / ".env")

# Logging setup. LOG_LEVEL (DEBUG, INFO, WARNING, ...) controls verbosity;
# unknown or empty values fall back to INFO.
_log_level = getattr(logging, os.getenv("LOG_LEVEL", "INFO").strip().upper(), logging.INFO)
if not isinstance(_log_level, int):
    _log_level = logging.INFO
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Mount static assets
app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")
# NOTE(review): `templates` is not referenced elsewhere in this file; pages are
# rendered through `template_env` instead.
templates = Jinja2Templates(directory=str(FRONTEND_DIR))
template_env = Environment(loader=FileSystemLoader(str(FRONTEND_DIR)))

# Set by start_load(); None until the model has been loaded.
model, tokenizer = None, None


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log one line per HTTP request so visitors can be traced.

    Records the client address, method, path, response status, latency and
    user-agent. Only the *presence* of the ``access_token`` cookie is logged,
    never its value, so credentials do not end up in log files. Exceptions
    raised by the application are logged with their traceback and re-raised
    unchanged.
    """
    started = time.perf_counter()
    client = request.client.host if request.client else "-"
    user_agent = request.headers.get("user-agent", "-")
    has_token = "access_token" in request.cookies
    try:
        response = await call_next(request)
    except Exception:
        elapsed_ms = (time.perf_counter() - started) * 1000
        logger.exception(
            "%s %s %s -> unhandled error after %.1f ms (token=%s, ua=%r)",
            client, request.method, request.url.path, elapsed_ms, has_token, user_agent,
        )
        raise
    elapsed_ms = (time.perf_counter() - started) * 1000
    logger.info(
        "%s %s %s -> %d in %.1f ms (token=%s, ua=%r)",
        client, request.method, request.url.path, response.status_code,
        elapsed_ms, has_token, user_agent,
    )
    return response
def load_model():
    """Load the local deepseek-coder-1.3b-instruct tokenizer and model.

    The checkpoint is read from ``<project root>/model/deepseek-coder-1.3b-instruct``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to eval mode.
    """
    checkpoint_dir = str(BASE_DIR / "model/deepseek-coder-1.3b-instruct")

    print("Loading tokenizer...")
    tokenizer_obj = AutoTokenizer.from_pretrained(checkpoint_dir)

    print("Loading model...")
    # float16 when CUDA is available, otherwise float32.
    weights_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    loaded_model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        torch_dtype=weights_dtype,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    loaded_model.eval()
    return loaded_model, tokenizer_obj
async def start_load():
    """Run ``load_model`` in the default executor and publish the result.

    ``load_model`` is synchronous, so it is executed off the event loop; the
    returned pair is stored in the module-level ``model`` / ``tokenizer``.

    NOTE(review): nothing in this file registers this coroutine as a startup
    hook (no ``@app.on_event("startup")`` or lifespan handler), so ``model``
    and ``tokenizer`` stay ``None`` unless something else awaits it. The chat
    endpoint in this file calls the remote DashScope API and does not read them.
    """
    global model, tokenizer
    # Inside a coroutine the asyncio docs recommend get_running_loop() over
    # get_event_loop().
    loop = asyncio.get_running_loop()
    model, tokenizer = await loop.run_in_executor(None, load_model)
    print("✅ Model loaded during startup!")
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the landing page rendered from ``home.html`` (no template context)."""
    html = template_env.get_template("home.html").render()
    return HTMLResponse(content=html)
@app.get("/login", response_class=HTMLResponse)
async def login_page():
    """Serve the login form rendered from ``login.html`` (no template context)."""
    return HTMLResponse(content=template_env.get_template("login.html").render())
@app.post("/login")
async def login(request: Request):
    """Authenticate ``account`` / ``password`` from a JSON body and set a JWT cookie.

    Responses:
        400 -- body is not a JSON object, or account/password is missing.
        401 -- unknown account or wrong password.
        200 -- ``{"success": True, "account": ...}`` plus an HttpOnly
               ``access_token`` cookie whose lifetime matches the token's
               (``ACCESS_TOKEN_EXPIRE_MINUTES``, default 30).
    """
    log = logging.getLogger(__name__)
    client = request.client.host if request.client else "-"

    # A malformed or non-object body is treated like missing credentials
    # instead of escaping as an unhandled 500.
    try:
        data = await request.json()
    except ValueError:
        data = None
    if not isinstance(data, dict):
        data = {}
    account = data.get("account")
    password = data.get("password")

    # Validate before hashing: the original hashed first, so a missing (None)
    # password reached get_password_hash before this check could reject it.
    if not account or not password:
        log.info("login rejected: missing credentials (client=%s)", client)
        return JSONResponse(
            {"success": False, "message": "请输入用户名和密码"},
            status_code=400,
        )

    # Look the user up in the database.
    # NOTE(review): check_users is still given a hash computed from the
    # submitted password, as before; verify what it does with it.
    hash_password = get_password_hash(password)
    result = database.check_users(account, hash_password)
    # Presumably check_users returns None / an empty row for an unknown
    # account; unpacking that would raise and surface as a 500.
    if not result:
        log.info("login failed: unknown account %r (client=%s)", account, client)
        return JSONResponse(
            {"success": False, "message": "用户名或密码错误"},
            status_code=401,
        )
    user_id, hashed_password_from_db = result

    # Verify the password against the stored hash.
    if not verify_password(password, hashed_password_from_db):
        log.info("login failed: wrong password for %r (client=%s)", account, client)
        return JSONResponse(
            {"success": False, "message": "用户名或密码错误"},
            status_code=401,
        )

    # Create the JWT.
    access_token_expires = timedelta(minutes=int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 30)))
    access_token = create_access_token(
        data={"sub": str(user_id)},  # the subject is the user id
        expires_delta=access_token_expires,
    )

    # Build the response and attach the token as an HttpOnly cookie.
    response = JSONResponse({
        "success": True,
        "account": account,
    })
    response.set_cookie(
        key="access_token",
        value=access_token,
        httponly=True,
        secure=False,  # False for local HTTP development; set True in production (HTTPS)
        samesite="lax",
        # Max-Age must be an integer: a float renders as "Max-Age=1800.0",
        # which RFC 6265 user agents ignore (turning this into a session cookie).
        max_age=int(access_token_expires.total_seconds()),
    )
    log.info("login succeeded: account=%r user_id=%s (client=%s)", account, user_id, client)
    return response
@app.get("/user2", response_class=HTMLResponse)
async def chat_page(request: Request, current_user_id: str = Depends(get_current_user_id)):
    """Render the chat page (``myapp.html``) with the list of characters.

    Redirects to ``/login`` when the auth dependency yields no user id. The
    characters are passed to the template twice: as the raw objects
    (``characters``) and as a JSON string (``characters_json``) holding only
    ``id``, ``name`` and ``trait``.
    """
    log = logging.getLogger(__name__)
    client = request.client.host if request.client else "-"

    if not current_user_id:
        log.info("GET /user2 without an authenticated user; redirecting to /login (client=%s)", client)
        return RedirectResponse(url="/login")
    log.info("user %s opened the chat page (client=%s)", current_user_id, client)

    template = template_env.get_template("myapp.html")
    characters = database.get_all_characters()
    # NOTE(review): template_env is created without autoescaping and json.dumps
    # does not escape "<", so if myapp.html emits characters_json inside a
    # <script> tag, a name/trait containing "</script>" could break out —
    # confirm how the template uses it.
    characters_json = json.dumps(
        [{"id": c["id"], "name": c["name"], "trait": c["trait"]} for c in characters],
        ensure_ascii=False,
    )
    content = template.render(characters=characters, characters_json=characters_json)
    return HTMLResponse(content=content)
@app.post("/user2/chat")
async def dashscope_chat(
    request: Request,
    current_user_id: str = Depends(get_current_user_id)
):
    """Answer a chat message in character via DashScope's OpenAI-compatible API.

    Expects a JSON body ``{"character_id": ..., "message": ...}``. On success
    the exchange is stored with ``database.save_conversation`` and
    ``{"reply": ...}`` is returned.

    Error responses:
        401 -- no authenticated user.
        400 -- body is not a JSON object, or a required field is missing.
        404 -- unknown character.
        500 -- DASHSCOPE_API_KEY not configured, or any unexpected failure.
               The full traceback is logged server-side; it is sent to the
               client in ``detail`` only when the DEBUG env var is truthy
               (otherwise ``detail`` is the exception class name).
        other -- DashScope answered with a non-200 status (forwarded as-is).
    """
    log = logging.getLogger(__name__)
    client_host = request.client.host if request.client else "-"

    if not current_user_id:
        log.info("chat rejected: no authenticated user (client=%s)", client_host)
        return JSONResponse({"error": "未授权访问"}, status_code=401)

    # A malformed or non-object body is treated like missing parameters
    # instead of escaping as an unhandled 500.
    try:
        data = await request.json()
    except ValueError:
        data = None
    if not isinstance(data, dict):
        log.info("chat rejected for user %s: body is not a JSON object", current_user_id)
        return JSONResponse({"error": "缺少必要参数"}, status_code=400)

    character_id = data.get("character_id")
    user_message = data.get("message")
    if not character_id or not user_message:
        log.info("chat rejected for user %s: missing character_id or message", current_user_id)
        return JSONResponse({"error": "缺少必要参数"}, status_code=400)

    # Look up the character being role-played.
    character = database.get_character_by_id(character_id)
    if not character:
        log.info("chat rejected for user %s: character %r not found", current_user_id, character_id)
        return JSONResponse({"error": "角色不存在"}, status_code=404)

    system_prompt = "\n".join([
        f"你正在扮演 {character['name']}。",
        f"人物设定:{character['trait']}",
        "请始终以这个角色的身份、语气和思维方式回答问题。",
        "不要暴露你是 AI 的事实。",
    ])

    api_key = os.getenv("DASHSCOPE_API_KEY")
    if not api_key:
        log.error("DASHSCOPE_API_KEY is not set; cannot serve chat for user %s", current_user_id)
        return JSONResponse({"error": "API密钥未配置"}, status_code=500)

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    payload = {
        "model": "qwen-plus",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        "temperature": 0.85,
        "top_p": 0.95,
        "max_tokens": 512,
        "stream": False,
    }

    started = time.perf_counter()
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
                json=payload,
                headers=headers,
            )
        elapsed_ms = (time.perf_counter() - started) * 1000

        if resp.status_code != 200:
            log.warning(
                "DashScope returned %d for user %s after %.0f ms: %s",
                resp.status_code, current_user_id, elapsed_ms, resp.text[:500],
            )
            # NOTE(review): the upstream status is forwarded as-is, so e.g. a
            # DashScope 401 (bad API key) reaches the browser as 401 and may be
            # mistaken for an expired session; consider 502 instead.
            return JSONResponse(
                {"error": f"远程API错误 [{resp.status_code}]", "detail": resp.text},
                status_code=resp.status_code,
            )

        result = resp.json()
        reply = result["choices"][0]["message"]["content"].strip()
        # Persist the exchange.
        database.save_conversation(int(current_user_id), character_id, user_message, reply)
    except Exception as e:
        # Request-boundary handler: always keep the traceback in the server
        # log; only expose it to the client when DEBUG is explicitly enabled.
        log.exception("chat failed for user %s, character %r", current_user_id, character_id)
        debug = os.getenv("DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
        detail = traceback.format_exc() if debug else type(e).__name__
        return JSONResponse(
            {"error": f"请求失败: {str(e)}", "detail": detail},
            status_code=500,
        )

    log.info(
        "chat ok: user=%s character=%r reply_chars=%d upstream_ms=%.0f (client=%s)",
        current_user_id, character_id, len(reply), elapsed_ms, client_host,
    )
    return JSONResponse({"reply": reply})
if __name__ == "__main__":
    # Development server. reload=True requires passing the app as an import
    # string ("module:attribute") rather than the app object.
    uvicorn.run(
        "myapp:app",
        host="127.0.0.1",
        port=8000,
        reload=True,
    )
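# 帮我添加调试功能,方便我跟踪访问用户
# (Request: add debugging features to make it easier to track visiting users.)
# 最新发布 ("latest release")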