2024 山东大学软件学院创新项目实训(五)

模型调用

(1)数据库建设

根据前端展现数据需要,我们共设计了七个表结构:

backadmin:管理员信息表

question:随机问题表

records:历史对话表

user_feedback:用户反馈信息表

user_info:用户信息表

user_login:用户登录情况表

user_record:侧边栏展示历史记录表

(2)模型调用API

为在服务器上访问本地数据库,需进行端口映射,内网穿透,从而实现外网访问。

使用路由侠进行内网映射:

在服务器端编写调用模型的方法:

首先设置工具类:

class DBUtils(object):
    """Thin wrapper holding a pymysql connection and its cursor."""

    # Both stay None when connecting fails — see __init__'s except branch.
    _db_conn = None
    _db_cursor = None

    def __init__(self, host, user, password, db, port=3306, charset='utf8mb4'):
        """
        Open a MySQL connection and create a cursor.

        :param host: database host (here: the NAT-mapped public hostname)
        :param user: MySQL user name
        :param password: MySQL password
        :param db: database (schema) name
        :param port: MySQL port, default 3306. The original source held the
            Chinese placeholder identifier ``内网映射端口号`` as the default,
            an undefined name that raised NameError at class-definition time;
            pass your mapped port explicitly when it differs from 3306.
        :param charset: connection charset; utf8mb4 so full Unicode is stored safely
        """
        try:
            self._db_conn = pymysql.connect(host=host, user=user, password=password,
                                            port=port, db=db, charset=charset)
            self._db_cursor = self._db_conn.cursor()
        except Exception as e:
            # Best-effort: log and leave _db_conn/_db_cursor as None.
            # Callers must tolerate an unconnected instance.
            logging.error(e)

连接数据库方法:

def get_db_conn():
    """
    Build and return a DBUtils connection to the ``model`` database.

    :return: DBUtils instance wrapping the MySQL connection
    """
    # NOTE(review): placeholder credentials — substitute the real local
    # database password for '本地数据库密码' before running.
    conn_params = {
        'host': 'xxxx.e3.luyouxia.net',
        'user': 'root',
        'password': '本地数据库密码',
        'db': 'model',
    }
    return DBUtils(**conn_params)

模型对话接口:

from fastapi import FastAPI, Request
import uvicorn, json, datetime
import torch
from flask_cors import CORS
from fastapi.middleware.cors import CORSMiddleware
from dbutils import DBUtils
import uuid
from transformers import AutoTokenizer, AutoModel, AutoConfig
from datetime import datetime, timedelta
import os


# Torch device string: "cuda:0" when a device id is set, else bare "cuda".
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE

def torch_gc():
    """Free cached GPU memory on CUDA_DEVICE; silently does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    with torch.cuda.device(CUDA_DEVICE):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

# FastAPI app with wide-open CORS so a browser front end served from another
# origin can call /api/home.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# overly permissive for production — confirm this is intentional.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def get_db_conn():
    """
    Create a fresh DBUtils connection to the ``model`` database.

    :return: db_conn — DBUtils connection object
    """
    # NOTE(review): hard-coded root credentials in source — move these to
    # environment variables or a config file.
    return DBUtils(
        host='hxyw.e3.luyouxia.net',
        user='root',
        password='123456',
        db='model',
    )

def msg(status, data='未加载到数据'):
    """
    Wrap a payload in the API's response envelope.

    :param status: status code — 200 success, 201 no data found
    :param data: response payload; defaults to the Chinese "no data loaded" text
    :return: dict of the form {'status': <code>, 'data': <payload>}
    """
    return dict(status=status, data=data)

@app.post("/api/home")
async def create_item(request: Request):
    """
    Chat endpoint: run the user's prompt through ChatGLM, optionally persist
    the Q/A pair into the ``records`` table, and return the model's answer.

    Expected JSON body keys (presumably set by the front end — confirm):
    ``message`` (prompt), ``type``, ``history``, ``time`` (question time),
    ``user_id``, ``status`` (1 means "save this exchange to the database").

    :return: msg(200, {...}) carrying the answer, record_id and user_id
    """
    db_conn = get_db_conn()
    global model, tokenizer

    # request.json() already returns a parsed dict — the original code
    # round-tripped it through json.dumps/json.loads for no effect.
    payload = await request.json()
    print(payload)  # request trace; consider switching to logging

    prompt = payload.get('message')
    record_type = payload.get('type')  # renamed: don't shadow builtin `type`
    history = payload.get('history')
    q_time = payload.get('time')       # question timestamp supplied by client
    user_id = payload.get('user_id')
    status = payload.get('status')

    # Generation hyper-parameters, fixed server-side.
    max_length = 1024
    top_p = 0.7
    temperature = 0.95
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length,
                                   top_p=top_p,
                                   temperature=temperature)

    a_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")  # answer timestamp
    record_id = str(uuid.uuid4())

    # Persist only when the client requests it (status == 1). Values are
    # bound as parameters, not interpolated — no SQL injection risk here.
    if status == 1:
        sql_str = ("INSERT INTO records "
                   "(record_id, user_id, question, answer,q_time,a_time,type) "
                   "VALUES (%s, %s, %s ,%s,%s,%s,%s)")
        db_conn.insert(sql_str, (record_id, user_id, prompt, response,
                                 q_time, a_time, record_type))

    torch_gc()  # reclaim GPU cache after each generation
    return msg(200, {
        "message": response,
        "record_id": record_id,
        "user_id": user_id,
    })

if __name__ == '__main__':
    # P-Tuning v2 prefix length — must match the value used at fine-tuning time.
    pre_seq_len = 300
    # Prefix-encoder checkpoint produced by the ptuning training scripts.
    checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-3/checkpoint-3000"

    # Load base ChatGLM2-6B; the config carries pre_seq_len so the model is
    # built with a prefix encoder to receive the fine-tuned weights.
    # NOTE(review): tokenizer uses "./THUDM/chatglm2-6b" while config/model
    # use "THUDM/chatglm2-6b" — presumably the same local directory; confirm
    # both resolve to the same path.
    tokenizer = AutoTokenizer.from_pretrained("./THUDM/chatglm2-6b", trust_remote_code=True)
    config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", config=config, device_map="auto", trust_remote_code=True)
    # Keep only the prefix-encoder tensors from the checkpoint, stripping the
    # "transformer.prefix_encoder." key prefix so they load into the submodule.
    prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        if k.startswith("transformer.prefix_encoder."):
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    model.eval()
    # Bind on all interfaces; single worker so the in-process model is shared.
    uvicorn.run(app, host='0.0.0.0', port=8082, workers=1)

在使用时,需要先建立 SSH 连接,才能建立服务器与本地之间的通信服务:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值