I. Project Introduction
This project fine-tunes the THUDM/glm-4-9b-chat model for a vertical domain: LoRA fine-tuning on medical question-answer pairs. This post covers half-precision fine-tuning; the next post covers 4-bit fine-tuning, which uses even less GPU memory.
II. Code Walkthrough
1. Half-Precision Fine-Tuning
Import dependencies
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
Load the data
dataset = load_dataset("csv", data_files="./问答.csv", split="train")
dataset = dataset.filter(lambda x: x["answer"] is not None)
print(dataset)
datasets = dataset.train_test_split(test_size=0.1)
print(datasets)
print(datasets['train'][:2])
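Note that load_dataset("csv", ...) expects a header row whose column names match the remove_columns call used later (id, question, answer). Below is a minimal sketch of the assumed layout of 问答.csv, with hypothetical sample rows and a placeholder filename so the real data file is not overwritten:

import csv

# Hypothetical rows illustrating the assumed schema: id, question, answer
rows = [
    {"id": 1, "question": "高血压患者的饮食需要注意什么?", "answer": "建议低盐低脂饮食,控制钠摄入……"},
    {"id": 2, "question": "感冒发烧时可以运动吗?", "answer": "发烧期间应以休息为主,避免剧烈运动……"},
]
with open("./问答_sample.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["id", "question", "answer"])
    writer.writeheader()
    writer.writerows(rows)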
Data preprocessing
tokenizer = AutoTokenizer.from_pretrained("./glm-4-9b-chat", trust_remote_code=True)
print(tokenizer)
def process_func(example):
    MAX_LENGTH = 256
    instruction = example["question"].strip()  # query
    # Apply the chat template to build the prompt:
    # '[gMASK] <sop> <|user|> \nquery <|assistant|>'
    instruction = tokenizer.apply_chat_template(
        [{"role": "user", "content": instruction}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    # "\n" + response; the eos token is still missing here and is appended below
    response = tokenizer("\n" + example["answer"], add_special_tokens=False)
    prompt_ids = instruction["input_ids"][0].tolist()
    prompt_mask = instruction["attention_mask"][0].tolist()
    input_ids = prompt_ids + response["input_ids"] + [tokenizer.eos_token_id]
    attention_mask = prompt_mask + response["attention_mask"] + [1]
    # Mask the prompt with -100 so the loss is computed on the answer tokens only
    labels = [-100] * len(prompt_ids) + response["input_ids"] + [tokenizer.eos_token_id]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
tokenized_ds = datasets['train'].map(process_func, remove_columns=['id', 'question', 'answer'])
tokenized_ts = datasets['test'].map(process_func, remove_columns=['id', 'question', 'answer'])
print(tokenized_ds)
print(tokenizer.decode(tokenized_ds[1]["input_ids"]))
print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))
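As a quick sanity check (a sketch, not part of the original pipeline), you can verify that the three sequences are aligned and that exactly the prompt prefix is excluded from the loss:

sample = tokenized_ds[1]
assert len(sample["input_ids"]) == len(sample["attention_mask"]) == len(sample["labels"])
# Positions labeled -100 are the prompt and contribute nothing to the loss
num_masked = sum(1 for l in sample["labels"] if l == -100)
print("prompt tokens (masked):", num_masked)
print("answer tokens (trained):", len(sample["labels"]) - num_masked)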
Create the model instance
import torch
model = AutoModelForCausalLM.from_pretrained("./glm-4-9b-chat", trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, device_map="auto")
for name, param in model.named_parameters():
    print(name)
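As a rough back-of-the-envelope check of why bfloat16 matters here: 9B parameters at 2 bytes each already take about 18 GiB of weights, before activations, gradients, and optimizer state. A short sketch:

total = sum(p.numel() for p in model.parameters())
# bfloat16 stores each parameter in 2 bytes
print(f"{total / 1e9:.1f}B parameters, ~{total * 2 / 1024**3:.0f} GiB of weights in bf16")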
LoRA configuration
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["query_key_value"],
    modules_to_save=["post_attention_layernorm"],
)
print(config)
model = get_peft_model(model, config)
for name, parameter in model.named_parameters():
    print(name)
model.print_trainable_parameters()  # prints by itself; wrapping it in print() adds a stray "None"
print(model)
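GLM-4 fuses the Q/K/V projections into a single query_key_value linear layer per block, which is why that name is the LoRA target above. A quick sketch to confirm the name actually occurs in the wrapped model:

qkv = [name for name, _ in model.named_modules() if "query_key_value" in name]
print(len(qkv), "query_key_value modules found")  # one per transformer block
print(qkv[:2])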
Configure training arguments, create the trainer, and train
# With gradient_checkpointing=True and a frozen base model, the inputs must
# require grads, otherwise backprop through the checkpointed blocks fails
model.enable_input_require_grads()

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    logging_steps=100,
    num_train_epochs=10,
    learning_rate=1e-4,
    remove_unused_columns=False,
    save_strategy="epoch",
)
trainer = Trainer(
    model=model,
    args=args,
    # Trains on the first 10,000 examples; adjust if your dataset is smaller
    train_dataset=tokenized_ds.select(range(10000)),
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()
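After training, you can reload the base model and attach the saved adapter for a quick smoke test. A minimal sketch, where "./chatbot/checkpoint-xxx" is a placeholder for an actual checkpoint directory written by save_strategy="epoch":

from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "./glm-4-9b-chat", trust_remote_code=True,
    torch_dtype=torch.bfloat16, device_map="auto",
)
p_model = PeftModel.from_pretrained(base, "./chatbot/checkpoint-xxx")
p_model.eval()

inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "高血压患者的饮食需要注意什么?"}],
    add_generation_prompt=True, tokenize=True,
    return_tensors="pt", return_dict=True,
).to(p_model.device)

with torch.inference_mode():
    out = p_model.generate(**inputs, max_new_tokens=256, do_sample=False)
# Decode only the newly generated tokens after the prompt
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))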