import os
import random
import json
from sklearn.metrics import classification_report
from tqdm import tqdm
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel, AutoModelForCausalLM
import logging
import datetime
import argparse
import torch
def chat_with_llm(args):
    model_path = {
        "Qwen2": "../modelscope/Qwen2.5-0.5B/",
    }
    tokenizer = AutoTokenizer.from_pretrained(model_path[args.model_name])
    # Full-precision alternative:
    # model = AutoModelForCausalLM.from_pretrained(model_path[args.model_name])
    model = AutoModelForCausalLM.from_pretrained(
        model_path[args.model_name],
        torch_dtype=torch.float16,
    )
    # Move the model to the GPU if one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    def generate_answer(question, max_length=120, num_return_sequences=1):
        # Encode the question
        inputs = tokenizer.encode(question, return_tensors='pt').to(device)
        # Fall back to eos_token_id if the tokenizer has no pad_token_id
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        # Generate the answer(s)
        outputs = model.generate(
            inputs,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True
        )
        # Decode the generated sequences (prompt + continuation)
        answers = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs.sequences]
        return answers

    # Example prompt
    # prompt = "Q: What is the capital of France?\nA: "
    # Input text
    prompt = "问题:XXX"
    # Generate the answer(s)
    answers = generate_answer(prompt)
    # Log the generated answers
    for i, answer in enumerate(answers):
        logging.info(f"Answer {i + 1}: {answer}")
    print("#######end########")
def llm_main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_path", default='./cached_prompts', type=str)
    parser.add_argument("--output_dir", default='./output', type=str)
    parser.add_argument("--max_length", default=4096, type=int)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--prompt_size", default=1, type=int)
    parser.add_argument("--rule_size", default=10, type=int)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--model_name", default='Qwen2', type=str,
                        choices=['ChatGPT', 'GPT3', 'GPT2', 'GPT4', 'Qwen2', 'vicuna-13b', 'Llama'])
    parser.add_argument("--add_rule", default='none', type=str, choices=['none', 'self', 'new', 'retri', 'all'])
    parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
    args = parser.parse_args()
    args.device = None
    if not args.disable_cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')
    # sys.stdout = open(os.path.join(date_string+"_log.txt"), 'w')
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO
    )
    logging.info("Parameters:{}\n".format(args))
    random.seed(args.seed)
    chat_with_llm(args)
if __name__ == '__main__':
    llm_main()
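
Assuming the file is saved as, say, qwen_chat.py (the post does not give a filename), it can be run with `python qwen_chat.py --model_name Qwen2`. Note that the script above feeds a raw prompt string to the base model; Qwen2.5 Instruct checkpoints are normally driven through the tokenizer's chat template instead. A minimal sketch, assuming an instruct variant such as Qwen2.5-0.5B-Instruct at a similar local path (the path and max_new_tokens value are illustrative assumptions, not from the original post):

# Hedged sketch: chat-template generation with a Qwen2.5 Instruct checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def chat_template_demo(model_dir="../modelscope/Qwen2.5-0.5B-Instruct/"):
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    # Render the conversation with the model's built-in chat template
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(device)
    output_ids = model.generate(**inputs, max_new_tokens=128)
    # Strip the prompt tokens so only the newly generated reply is decoded
    reply_ids = output_ids[0][inputs.input_ids.shape[1]:]
    print(tokenizer.decode(reply_ids, skip_special_tokens=True))

The main difference from generate_answer above is that the chat template inserts the role markers the instruct model was trained on, and max_new_tokens bounds only the reply length rather than the whole sequence.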