I. Project Introduction
This project fine-tunes the THUDM/glm-4-9b-chat model for a vertical domain: LoRA fine-tuning on medical-domain question-answer pairs. The model is loaded in 4-bit (QLoRA-style), so the fine-tune fits in a much smaller GPU memory footprint.
II. Hands-On Code
1. 4-bit fine-tuning
Import dependencies
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
Load the data
dataset = load_dataset("csv", data_files="./问答.csv", split="train")  # CSV with id / question / answer columns
dataset = dataset.filter(lambda x: x["answer"] is not None)  # drop rows with an empty answer
print(dataset)
datasets = dataset.train_test_split(test_size=0.1)  # hold out 10% for evaluation
print(datasets)
print(datasets['train'][:2])
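For reference, the loader above expects a UTF-8 CSV with id, question, and answer columns. A minimal sketch that writes two hypothetical sample rows in that schema:

import csv

# Hypothetical sample rows matching the expected schema (id, question, answer)
with open("问答.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["id", "question", "answer"])
    writer.writerow(["1", "胃溃疡有哪些症状？", "常见症状包括上腹部疼痛、反酸、嗳气等。"])
    writer.writerow(["2", "高血压患者饮食需要注意什么？", "建议低盐低脂饮食，控制体重，戒烟限酒。"])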
Data preprocessing
tokenizer = AutoTokenizer.from_pretrained("./glm-4-9b-chat", trust_remote_code=True)
print(tokenizer)
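As a quick sanity check, you can render the chat template once without tokenizing; the output should match the commented prompt string inside process_func below:

demo = tokenizer.apply_chat_template([{"role": "user", "content": "你好"}],
                                     add_generation_prompt=True,
                                     tokenize=False)
print(demo)  # expected roughly: '[gMASK] <sop> <|user|> \n你好 <|assistant|>'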
def process_func(example):
    MAX_LENGTH = 768
    instruction = example["question"].strip()  # user query
    # Apply the GLM-4 chat template to the question; the rendered prompt looks like
    # '[gMASK] <sop> <|user|> \nquery <|assistant|>'
    instruction = tokenizer.apply_chat_template(
        [{"role": "user", "content": instruction}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    # Tokenize "\n" + answer without special tokens; the missing EOS token is appended manually below
    response = tokenizer("\n" + example["answer"], add_special_tokens=False)
    input_ids = instruction["input_ids"][0].numpy().tolist() + response["input_ids"] + [tokenizer.eos_token_id]
    attention_mask = instruction["attention_mask"][0].numpy().tolist() + response["attention_mask"] + [1]
    # Mask the prompt tokens with -100 so the loss is computed only on the answer
    labels = [-100] * len(instruction["input_ids"][0]) + response["input_ids"] + [tokenizer.eos_token_id]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
tokenized_ds = datasets['train'].map(process_func, remove_columns=['id', 'question', 'answer'])
tokenized_ts = datasets['test'].map(process_func, remove_columns=['id', 'question', 'answer'])
print(tokenized_ds)
print(tokenizer.decode(tokenized_ds[1]["input_ids"]))  # full prompt + answer
print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))  # answer portion only
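A quick consistency check on one tokenized example: the three fields should have equal length, the first label should be -100 (the prompt is masked), and the last label should be the EOS id unless the example was truncated:

sample = tokenized_ds[1]
assert len(sample["input_ids"]) == len(sample["attention_mask"]) == len(sample["labels"])
print(sample["labels"][0], sample["labels"][-1], tokenizer.eos_token_id)  # -100, eos_id (if not truncated)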
Create the model and the LoRA configuration
import torch
from transformers import BitsAndBytesConfig

# NF4 4-bit quantization with bf16 compute (QLoRA-style setup); passing the bnb_4bit_*
# options directly to from_pretrained is deprecated, so wrap them in a BitsAndBytesConfig
model = AutoModelForCausalLM.from_pretrained("./glm-4-9b-chat",
                                             trust_remote_code=True,
                                             low_cpu_mem_usage=True,
                                             device_map="auto",
                                             quantization_config=BitsAndBytesConfig(
                                                 load_in_4bit=True,
                                                 bnb_4bit_quant_type="nf4",
                                                 bnb_4bit_compute_dtype=torch.bfloat16))
# for name, param in model.named_parameters():
# print(name)
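One step worth adding before attaching LoRA: PEFT's prepare_model_for_kbit_training is the commonly recommended preparation for 4-bit/8-bit training. It freezes the base weights, upcasts layer norms for numerical stability, and enables input gradients, which gradient_checkpointing=True below relies on. A minimal sketch:

from peft import prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)  # freeze base weights, upcast norms, enable input grads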
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
# "query_key_value" is the fused QKV projection in GLM's attention blocks;
# the post-attention LayerNorm is trained and saved in full alongside the LoRA weights
config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                    target_modules=["query_key_value"],
                    modules_to_save=["post_attention_layernorm"])
print(config)
model = get_peft_model(model, config)
model.print_trainable_parameters()  # prints the counts itself; no outer print() needed
Configure the training arguments and train
args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,   # effective batch size = 1 * 16
    gradient_checkpointing=True,      # trade compute for memory
    logging_steps=100,
    num_train_epochs=10,
    learning_rate=1e-4,
    remove_unused_columns=False,      # keep the columns produced by process_func
    save_strategy="epoch"
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds.select(range(10000)),  # subsample 10,000 training examples
    eval_dataset=tokenized_ts,  # tokenized test split; used when you call trainer.evaluate()
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()
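With an effective batch size of 16 (1 per device times 16 accumulation steps), the 10,000-example subset yields 10000 / 16 = 625 optimizer steps per epoch, so the checkpoint-1875 directory inspected below is the checkpoint saved at the end of epoch 3.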
Inspect the training results
from safetensors import safe_open

# Verify that the fully-trained post_attention_layernorm weights were saved in the adapter
with safe_open("./chatbot/checkpoint-1875/adapter_model.safetensors", framework="pt") as f:
    for key in f.keys():
        if ".0.post_attention_layernorm" in key:
            print(key)
            print(f.get_tensor(key))
model.eval()
# glm-4-9b-chat does not ship a .chat() helper (ChatGLM3 did); use the chat template plus generate()
inputs = tokenizer.apply_chat_template([{"role": "user", "content": "治疗胃溃疡用什么药物最好。"}],
                                       add_generation_prompt=True, tokenize=True,
                                       return_tensors="pt", return_dict=True).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))  # strip the prompt
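To serve the adapter later without keeping the training process alive, the checkpoint can be reattached to a freshly loaded 4-bit base model (a sketch reusing the same quantization config and checkpoint path as above):

base = AutoModelForCausalLM.from_pretrained("./glm-4-9b-chat", trust_remote_code=True,
                                            low_cpu_mem_usage=True, device_map="auto",
                                            quantization_config=BitsAndBytesConfig(
                                                load_in_4bit=True,
                                                bnb_4bit_quant_type="nf4",
                                                bnb_4bit_compute_dtype=torch.bfloat16))
inference_model = PeftModel.from_pretrained(base, "./chatbot/checkpoint-1875")  # attach the trained LoRA adapter
inference_model.eval()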