[Datawhale AI夏令营] 第二日打卡

部署运行你感兴趣的模型镜像

完成任务

将Qwen3-8B 8bit量化模型作为教师模型,根据比赛数据集生成单字段查询问题及答案,并整理为.json文件作为硬标签输入到Qwen3-4B模型,用LoRA方法对模型进行蒸馏。最终在平台测试得分为51分。过程仍有许多需要改进的地方,例如问题的多样性,错误答案的惩罚机制,Agent模式的使用等。

代码

import torch
import pandas as pd
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
    TextStreamer,
    Trainer
)
import requests
import re
import json
from tqdm import tqdm

model_path = "./model/Qwen3-8B"
data_path = "./data/info_table.xlsx"

data = pd.read_excel(data_path)
data = data.fillna('无数据')
data = data.drop(columns=['序号'])

bnb_config = BitsAndBytesConfig(load_in_8bit = True)

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map={"": 0}
)

def call_llm(content: str):
    messages = [
        {
            "role": "user",
            "content": content
        }
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
        enable_thinking=False
    )
    model_inputs = tokenizer([text],return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=32768
    )
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    
    try:
        index = len(output_ids) - output_ids[::-1].index(151668)
    except ValueError:
        index = 0
    
    output_decoder = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
    pattern = re.compile(r'^```json\s*([\s\S]*?)```$', re.IGNORECASE)
    match = pattern.match(output_decoder.strip())
    if match:
        json_str = match.group(1).strip()
        outputs = json.loads(json_str)
        return outputs
    else:
        return output_decoder

def single_require(row: dict):
    question_list = []
    question_list.append(f'{row["车次"]}号车次应该从哪个检票口检票?')
    question_list.append(f'{row["车次"]}号车次应该从哪个站台上车?')
    question_list.append(f'{row["车次"]}号车次的始发站在哪里?')
    question_list.append(f'{row["车次"]}号车次的终点站在哪里?')
    question_list.append(f'乘坐{row["车次"]}号车次应该在哪个候车厅等候?')
    question_list.append(f'{row["车次"]}号车次抵达车站的时间?')
    question_list.append(f'{row["车次"]}号车次离开车站的时间?')
    return question_list

prompt = '''你是列车的乘务员,请你基于给定的列车班次信息回答用户的问题。
# 列车班次信息
{}

# 用户问题列表
{}

'''
output_format = '''# 输出格式
按json格式输出,且只需要输出一个json即可
```json
[{
    "q": "用户问题",
    "a": "问题答案"
},
...
]
```'''
train_data_list = []
for idx,row in tqdm(data.iterrows(), desc='遍历生成答案', total=len(data)):
    row = dict(row)
    row['到点'] = str(row['到点'])
    row['开点'] = str(row['开点'])
    question_list = single_require(row)
    llm_result = call_llm(prompt.format(row, question_list) + output_format)
    train_data_list += llm_result

data_list = []
for data in tqdm(train_data_list, total=len(train_data_list)):
    if isinstance(data, str):
        continue
    data_list.append({'instruction': data['q'], 'output': data['a']})

json.dump(data_list, open('single_row.json', 'w', encoding='utf-8'), ensure_ascii=False)
data_list[:2]

您可能感兴趣的与本文相关的镜像

Qwen3-8B

Qwen3-8B

文本生成
Qwen3

Qwen3 是 Qwen 系列中的最新一代大型语言模型,提供了一整套密集型和专家混合(MoE)模型。基于广泛的训练,Qwen3 在推理、指令执行、代理能力和多语言支持方面取得了突破性进展

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值