A Minimal Qwen2.5 Inference Script

import random
import logging
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def chat_with_llm(args):
    model_path = {
        "Qwen2": "../modelscope/Qwen2.5-0.5B/",
    }

    tokenizer = AutoTokenizer.from_pretrained(model_path[args.model_name])
    # Load the weights in float16 to halve memory use; drop torch_dtype for full precision
    model = AutoModelForCausalLM.from_pretrained(
        model_path[args.model_name],
        torch_dtype=torch.float16,
    )

    # Move the model to the device chosen in llm_main (this respects --disable-cuda)
    device = args.device
    model.to(device)

    def generate_answer(question, max_length=120, num_return_sequences=1):
        # Encode the question into input IDs and move them to the model's device
        inputs = tokenizer.encode(question, return_tensors='pt').to(device)

        # Fall back to eos_token_id when the tokenizer defines no pad token
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        # Generate the answer
        outputs = model.generate(
            inputs,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True
        )

        # Decode the generated sequences (note: each one still includes the prompt tokens)
        answers = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs.sequences]

        return answers

    # Example prompt:
    # prompt = "Q: What is the capital of France?\nA: "
    # Input text; replace the XXX placeholder with your actual question
    prompt = "Question: XXX"

    # Generate the answer
    answers = generate_answer(prompt)

    # Log each generated answer
    for i, answer in enumerate(answers):
        logging.info(f"Answer {i + 1}: {answer}")

    print("#######end########")


def llm_main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_path", default='./cached_prompts', type=str)
    parser.add_argument("--output_dir", default='./output', type=str)
    parser.add_argument("--max_length", default=4096, type=int)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--prompt_size", default=1, type=int)
    parser.add_argument("--rule_size", default=10, type=int)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--model_name", default='Qwen2', type=str,
                        choices=['ChatGPT', 'GPT3', 'GPT2', 'GPT4', 'Qwen2', 'vicuna-13b', 'Llama'])
    parser.add_argument("--add_rule", default='none', type=str, choices=['none', 'self', 'new', 'retri', 'all'])
    parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
    args = parser.parse_args()
    args.device = None
    if not args.disable_cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO
    )
    logging.info("Parameters:{}\n".format(args))
    random.seed(args.seed)

    chat_with_llm(args)


if __name__ == '__main__':
    llm_main()
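One caveat in generate_answer: max_length counts the prompt tokens as well, so a long prompt leaves almost no room for the reply, and the --max_length value parsed in llm_main is never passed through. A minimal sketch of a drop-in replacement for the generate call, using the standard max_new_tokens parameter instead (inputs, pad_token_id, and num_return_sequences are the names already defined inside generate_answer):

        outputs = model.generate(
            inputs,
            max_new_tokens=120,  # budget for newly generated tokens only; prompt length excluded
            num_return_sequences=num_return_sequences,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True
        )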

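The script above feeds a raw text prompt to the base 0.5B checkpoint. If you point model_path at an instruction-tuned variant instead (the Qwen2.5-0.5B-Instruct path below is an assumption; adjust it to your local directory), the usual approach is to build the prompt with tokenizer.apply_chat_template. A minimal self-contained sketch:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assumed local path to an instruction-tuned checkpoint
model_path = "../modelscope/Qwen2.5-0.5B-Instruct/"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

messages = [{"role": "user", "content": "What is the capital of France?"}]
# apply_chat_template wraps the message in the model's chat format and appends
# the assistant prefix, so generation starts right at the model's reply
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(device)
outputs = model.generate(inputs, max_new_tokens=120)
# Slice off the prompt so only the newly generated reply is decoded
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))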