import os
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.utils import get_prompt, extract_from_output

model_path = '/data/disk4/home/junkai/ai-project/Datawhale/Qwen2_72b'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# device_map="auto" shards the 72B model across the available GPUs;
# without it the weights would be loaded onto the CPU
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
).eval()
# Read the JSONL input file, one JSON object per line
input_file = '/data/disk4/home/junkai/ai-project/Datawhale/logical_reasoning/data/input/test_data_demo.jsonl'
with open(input_file, 'r', encoding='utf-8') as f:
    data = [json.loads(line.strip()) for line in f if line.strip()]
output_file = '/data/disk4/home/junkai/ai-project/Datawhale/logical_reasoning/data/output/0726_test_qwen2_72b_1.jsonl'
# Truncate any existing output file so results are not appended to stale data
if os.path.exists(output_file):
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('')
results = []
batch_size = 1  # write results to file after each problem

for idx, item in enumerate(data):
    problem = item["problem"]
    questions = item["questions"]
    item_id = item["id"]
    item_results = {"id": item_id, "questions": []}
    for question in questions:
        question_text = question["question"]
        options = question["options"]
        prompt = get_prompt(problem, question_text, options)
        messages = [
            {"role": "user", "content": prompt}
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )
        # Move the tokenized prompt to the same device as the model
        inputs = inputs.to(model.device)
        # do_sample=True with top_k=1 is effectively greedy decoding;
        # note that max_length counts the prompt tokens as well
        gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            # Strip the prompt tokens, keeping only the newly generated part
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print("================>output:\n", output)
        extracted_answer = extract_from_output(output)
        print("================>extracted_answer:\n", extracted_answer)
        # Save the extracted answer
        item_results["questions"].append({"answer": extracted_answer})
    results.append(item_results)

    # Write results to file after each batch of problems and log progress
    if (idx + 1) % batch_size == 0:
        with open(output_file, 'a', encoding='utf-8') as f:  # append mode
            for result in results:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')
        print(f"Processed {idx + 1} problems, results written to {output_file}")
        results = []  # clear the buffer for the next batch

# Flush anything left over when len(data) is not a multiple of batch_size
if results:
    with open(output_file, 'a', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
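Each line of the output file is the JSON record for one problem, with one answer entry per question. An illustrative line (the id and letters here are invented for the example):

{"id": "round1_test_data_001", "questions": [{"answer": "A"}, {"answer": "B"}]}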
The relevant code in utils.py is:
import re

def get_prompt(problem, question, options):
    '''Build the prompt text for a logical-reasoning question'''
    # Convert the options list into a string with letter labels, one option
    # per line, with labels running from 'A' to 'G'
    options = '\n'.join(
        f"{'ABCDEFG'[i]}. {info}" for i, info in enumerate(options)
    )
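The section cuts off here, so the rest of get_prompt and the extract_from_output helper called in the main script are not shown. A minimal sketch of how they might continue, assuming a plain-text prompt that concatenates the problem, question, and lettered options, and a regex that picks out the final option letter; the exact prompt wording and pattern below are assumptions, not the original code:

    # (continuing get_prompt -- the prompt wording is an assumption)
    prompt = (
        "You are an expert in logical reasoning. Here is a problem:\n"
        f"{problem}\n\n"
        f"Question: {question}\n"
        f"{options}\n"
        'Reply with the letter of the correct option, e.g. "The answer is A".'
    )
    return prompt

def extract_from_output(output):
    '''Extract the chosen option letter from the model reply (sketch).'''
    # Prefer the last occurrence of "answer is X"; fall back to the last
    # standalone option letter anywhere in the text, else return None
    matches = re.findall(r'answer is\s*([A-G])', output, re.IGNORECASE)
    if matches:
        return matches[-1].upper()
    letters = re.findall(r'\b([A-G])\b', output)
    return letters[-1] if letters else None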