模型调用
(1)数据库建设
根据前端展现数据需要,我们共设计了七个表结构:
backadmin:管理员信息表
question:随机问题表
records:历史对话表
user_feedback:用户反馈信息表
user_info:用户信息表
user_login:用户登录情况表
user_record:侧边栏展示历史记录表
(2)模型调用API
为在服务器上访问本地数据库,需进行端口映射与内网穿透,从而实现外网访问。
使用路由侠进行内网映射:
在服务器端编写调用模型的方法:
首先设置工具类:
class DBUtils(object):
    """MySQL helper backed by pymysql.

    Holds one connection and one cursor. Both remain ``None`` when the
    connection attempt fails — the error is logged, not raised, so
    callers must check before use.
    """

    # Connection and cursor objects; stay None until a connection succeeds.
    _db_conn = None
    _db_cursor = None

    def __init__(self, host, user, password, db, port=3306, charset='utf8mb4'):
        """Open a connection to the given database.

        :param host: database host (here: the intranet-mapped public host)
        :param user: database user
        :param password: database password
        :param db: database (schema) name
        :param port: TCP port — set this to the intranet-mapped port.
            The original source left an untranslated prose placeholder
            here, which was an undefined name and made the class
            definition itself raise NameError; defaults to MySQL's 3306.
        :param charset: connection character set
        """
        try:
            self._db_conn = pymysql.connect(host=host, user=user,
                                            password=password, port=port,
                                            db=db, charset=charset)
            self._db_cursor = self._db_conn.cursor()
        except Exception as e:
            # Best-effort: log and leave _db_conn/_db_cursor as None.
            logging.error(e)
连接数据库方法:
def get_db_conn():
    """Build and return a DBUtils connection to the ``model`` database.

    :return: DBUtils instance (connection may be None if connect failed)
    """
    conn_args = dict(
        host='xxxx.e3.luyouxia.net',  # intranet-mapped public hostname
        user='root',
        password='本地数据库密码',
        db='model',
    )
    return DBUtils(**conn_args)
模型对话接口:
from fastapi import FastAPI, Request
import uvicorn, json, datetime
import torch
from flask_cors import CORS
from fastapi.middleware.cors import CORSMiddleware
from dbutils import DBUtils
import uuid
from transformers import AutoTokenizer, AutoModel, AutoConfig
from datetime import datetime, timedelta
import os
# Target device for inference; CUDA_DEVICE becomes "cuda:0" when a
# device id is configured, otherwise just the bare device string.
DEVICE = "cuda"
DEVICE_ID = "0"
if DEVICE_ID:
    CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}"
else:
    CUDA_DEVICE = DEVICE
def torch_gc():
    """Release cached CUDA memory on CUDA_DEVICE; no-op without a GPU."""
    if not torch.cuda.is_available():
        return
    with torch.cuda.device(CUDA_DEVICE):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
# Application instance plus a fully-open CORS policy so the browser
# front end (served from a different origin) can call this API.
app = FastAPI()
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True
# is both wide open and rejected by browsers for credentialed requests —
# consider pinning the front end's real origin(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # any origin
    allow_credentials=True,   # allow cookies / auth headers
    allow_methods=["*"],      # any HTTP method
    allow_headers=["*"],      # any request header
)
def get_db_conn():
    """Build and return a DBUtils connection to the ``model`` database.

    Credentials were hard-coded in source; they can now be supplied via
    the DB_HOST / DB_USER / DB_PASSWORD / DB_NAME environment variables.
    The defaults reproduce the original values, so existing deployments
    keep working unchanged.

    :return: DBUtils instance (connection may be None if connect failed)
    """
    return DBUtils(
        host=os.environ.get('DB_HOST', 'hxyw.e3.luyouxia.net'),
        user=os.environ.get('DB_USER', 'root'),
        password=os.environ.get('DB_PASSWORD', '123456'),
        db=os.environ.get('DB_NAME', 'model'),
    )
def msg(status, data='未加载到数据'):
    """Build the standard response envelope.

    :param status: status code — 200 success, 201 no data found
    :param data: response payload
    :return: e.g. {'status': 201, 'data': '未加载到数据'}
    """
    return dict(status=status, data=data)
@app.post("/api/home")
async def create_item(request: Request):
    """Chat endpoint.

    Expects a JSON body with keys ``message``, ``type``, ``history``,
    ``time``, ``user_id``, ``status`` (assumed schema — confirm against
    the front end). Runs the model on ``message``, stores the Q/A pair
    in the ``records`` table when ``status == 1``, and returns the
    answer wrapped in the standard envelope.

    Fixes vs. the original: drops the pointless json.dumps/json.loads
    round-trip of the already-parsed body; stops shadowing the builtins
    ``type`` and ``time``; removes the unused ``log`` string, whose
    concatenation raised TypeError whenever 'time' was absent; removes
    the unnecessary ``global`` statement (the globals are only read)
    and the dead ``x if x else y`` fallbacks on always-truthy values.
    """
    db_conn = get_db_conn()
    payload = await request.json()  # already a dict — no re-serialization
    prompt = payload.get('message')
    record_type = payload.get('type')
    history = payload.get('history')
    q_time = payload.get('time')    # question timestamp, client-supplied
    user_id = payload.get('user_id')
    status = payload.get('status')
    # Fixed generation settings for ChatGLM2's chat() API.
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=1024,
                                   top_p=0.7,
                                   temperature=0.95)
    a_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")  # answer timestamp
    record_id = str(uuid.uuid4())
    # Persist the exchange only when the client asks for it.
    if status == 1:
        sql_str = "INSERT INTO records (record_id, user_id, question, answer,q_time,a_time,type) VALUES (%s, %s, %s ,%s,%s,%s,%s)"
        db_conn.insert(sql_str, (record_id, user_id, prompt, response,
                                 q_time, a_time, record_type))
    torch_gc()  # reclaim cached CUDA memory after each request
    return msg(200, {
        "message": response,
        "record_id": record_id,
        "user_id": user_id,
    })
if __name__ == '__main__':
    # Length of the trainable prefix used by P-Tuning v2; must match the
    # value the checkpoint below was trained with.
    pre_seq_len = 300
    checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-3/checkpoint-3000"
    # Load the base ChatGLM2-6B tokenizer, config and model
    # (trust_remote_code is required by this model family).
    tokenizer = AutoTokenizer.from_pretrained("./THUDM/chatglm2-6b", trust_remote_code=True)
    config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", config=config, device_map="auto", trust_remote_code=True)
    # Load the fine-tuned prefix-encoder weights and strip the
    # "transformer.prefix_encoder." key prefix so the keys match the
    # prefix encoder's own state dict.
    prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        if k.startswith("transformer.prefix_encoder."):
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    model.eval()  # inference mode
    # Serve the API; workers=1 keeps the single in-process model instance.
    uvicorn.run(app, host='0.0.0.0', port=8082, workers=1)
在使用时,需要先建立SSH连接,才能建立服务器与本地的通信服务: