代码的摘要说明
一、整体功能概述
这段 Python 代码主要实现了基于 Hugging Face Transformers 库对预训练语言模型(具体为 TAIDE-LX-7B-Chat 模型)进行微调(Fine-tuning)的功能,使其能更好地应用于生成唐诗相关内容的任务。整个流程涵盖了数据加载与预处理、模型配置、模型训练以及训练后模型的测试与结果保存等环节。二、代码各部分详细说明
导入模块部分
导入了多个常用的 Python 库,如 os、sys、argparse 等基础库,以及和深度学习、自然语言处理相关的库,像 torch、transformers、peft 等,还导入了用于忽略警告信息、文本颜色控制等功能的相关模块,为后续代码实现做准备。
基础配置与参数设置部分
模型相关路径:定义了模型名称、模型所在目录、工作目录、输出目录以及检查点目录等路径相关的变量,确保后续文件操作(如加载模型、保存训练结果等)能找到正确的位置。
训练参数设置:设定了训练的总轮数(num_epoch)、学习率(LEARNING_RATE),还包括一系列影响模型训练和推理的超参数,例如用于训练的数据集数量(num_train_data)、文本截断长度(CUTOFF_LEN)、LORA 相关参数(LORA_R、LORA_ALPHA、LORA_DROPOUT)、验证集大小(VAL_SET_SIZE)、微批次大小(MICRO_BATCH_SIZE)、梯度累积步数(GRADIENT_ACCUMULATION_STEPS)等,这些参数将用于控制模型训练过程和微调效果。
分布式训练配置:根据环境变量 WORLD_SIZE 判断是否进行分布式训练,若需要则配置相应的设备映射(device_map)。
数据加载与预处理部分
创建目录:首先确保输出目录和检查点目录存在,若不存在则创建它们,用于后续保存训练结果和模型检查点。
加载数据集:从指定的 json 文件路径(dataset_dir)读取原始数据集,先将其加载为 Python 的字典数据结构(通过 json.load),然后从中截取指定数量(由 num_train_data 决定)的数据保存为临时数据集文件(tmp_dataset_path),再通过 load_dataset 函数以 json 格式重新加载这个临时数据集用于后续处理。
数据处理函数定义:定义了 generate_training_data 函数,用于将包含指令、输入和输出文本的数据点转换为适合模型训练的格式,具体是通过 tokenizer 将文本转换为令牌(token)形式,并处理好输入令牌、对应标签以及注意力掩码等信息,构建训练样本。
划分数据集(可选):根据 VAL_SET_SIZE 参数的值决定是否划分训练集和验证集,如果 VAL_SET_SIZE 大于 0,则从原始训练集中划分出验证集,对训练集和验证集的数据都应用 generate_training_data 函数进行预处理,使其格式符合模型训练要求;若 VAL_SET_SIZE 为 0,则整个数据集都作为训练集进行处理。
模型相关配置部分
创建模型和令牌器:从指定的模型路径(model_path)加载预训练的语言模型,并根据相关配置(如量化配置 BitsAndBytesConfig 用于以特定量化方式加载模型以节省内存等)进行初始化;同时加载对应的分词器(tokenizer),并设置其填充令牌(pad_token)与结束令牌(eos_token)相同,方便后续文本处理。
模型准备与微调配置:将加载的模型通过 prepare_model_for_int8_training 函数进行处理,使其适合 INT8 训练方式,接着使用 LoraConfig 配置 LORA(低秩适配,一种参数高效微调方法)相关参数,并通过 get_peft_model 函数应用 LORA 配置到模型上,以便后续进行参数高效的微调训练。
生成配置定义:定义了 GenerationConfig 对象,用于指定模型在生成文本时的一些解码参数,例如是否采样(do_sample)、温度参数(temperature)、束搜索的束数量(num_beams)等,这些参数会影响模型生成唐诗内容时的质量和多样性等特性。
模型训练部分
使用 transformers.Trainer 类来实例化训练器对象(trainer),传入模型、训练数据集、验证数据集(可为 None)以及训练相关的各种参数(如批次大小、学习率、训练轮数、梯度累积步数、日志记录和模型保存相关的配置等),同时指定了数据整理器(data_collator)用于整理训练数据的格式。在训练前还禁用了模型的缓存功能(model.config.use_cache = False),然后通过 try-except 语句块尝试执行模型训练过程,若训练出现异常则打印出错误信息。
模型保存部分
同样使用 try-except 语句块尝试将训练后的模型保存到指定的检查点目录(ckpt_dir)中,若保存过程出现异常则打印相应的错误提示。
模型测试与结果保存部分
从指定的测试数据文件(test_data_path)中加载测试数据集(通过 json.load),然后循环遍历每条测试数据,调用 evaluate 函数(该函数基于给定的指令、生成配置和输入等信息,利用已训练好的模型生成相应回复内容)获取模型针对测试数据的回复,将测试数据的输入以及模型生成的回复拼接后写入到指定的结果文件(output_path)中保存起来,同时也会打印出来方便查看,用于评估模型在测试集上的表现。
原文链接:https://blog.youkuaiyun.com/chenchihwen/article/details/144000079
整体代码参照这篇文章
导入必要的库
import os
import sys
import argparse
import json
import warnings
import logging
- 导入了多个 Python 标准库,
os
用于操作系统相关的操作(如文件路径处理、目录创建等),sys
用于处理 Python 运行时环境相关的配置和参数,argparse
常用于命令行参数解析(但此代码中未体现其使用),json
用于处理 JSON 格式的数据读写,warnings
用于控制警告信息的显示,logging
用于记录日志信息。
# 忽略警告信息
warnings.filterwarnings("ignore")
- 配置
warnings
模块,使其忽略所有警告信息,避免在程序运行过程中输出大量警告干扰正常的输出和分析。
import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset, load_from_disk
import transformers
from peft import PeftModel
from colorama import Fore, Back, Style
- 导入了
torch
(深度学习常用的张量计算库)及其nn
模块(用于构建神经网络),bitsandbytes
(可能用于模型量化等相关功能),load_dataset
和load_from_disk
函数用于加载数据集,transformers
库(用于自然语言处理相关的预训练模型、工具等),PeftModel
(可能是某种特定的模型扩展相关类),colorama
(用于在终端输出添加颜色等样式,不过此处代码后续未看到其使用)。
# 以下是根据Hugging Face Transformers库的相关功能进行的设置
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import GenerationConfig
from peft import (
prepare_model_for_int8_training,
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_kbit_training
)
- 从
transformers
库中导入了多个用于创建模型、配置模型以及处理特定训练相关设置的类和函数,比如AutoTokenizer
用于自动加载合适的文本令牌化器,AutoConfig
用于自动加载模型配置,AutoModelForCausalLM
用于加载适用于因果语言建模的预训练模型,BitsAndBytesConfig
用于配置模型量化相关的参数;从peft
模块导入的函数主要用于对模型进行特定的微调(如int8
训练准备、配置LORA
相关参数、获取基于LORA
配置后的模型等)。
模型和路径相关的参数设置
# 模型名称
model_name = "TAIDE-LX-7B-Chat"
# 模型所在目录
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TAIDE-LX-7B-Chat")
- 定义了模型的名称以及模型所在的本地文件系统路径,通过
os.path.join
结合当前脚本文件所在目录(os.path.dirname(os.path.abspath(__file__))
获取)来确定完整的模型路径。
# 设置工作目录
work_dir = os.path.dirname(os.path.abspath(__file__))
# 输出目录
output_dir = os.path.join(work_dir, "output")
# 检查点目录
ckpt_dir = os.path.join(work_dir, "checkpoint")
- 确定了工作目录(即当前脚本所在目录),并基于此设置了后续用于存储输出结果的目录和保存模型检查点的目录路径。
# 训练的总Epoch数
num_epoch = 1
# 学习率
LEARNING_RATE = 3e-4
- 定义了模型训练的关键超参数,总训练轮数为
1
轮,学习率设置为3e-4
,学习率决定了模型在训练过程中参数更新的步长大小。
# 配置模型和训练参数
cache_dir = os.path.join(work_dir, "cache")
from_ckpt = False
ckpt_name = None
- 设置了模型缓存目录路径,以及两个与是否从检查点加载模型、具体检查点名称相关的变量,这里表示不从检查点加载(
from_ckpt
为False
)且没有指定具体检查点名称(ckpt_name
为None
)。
# 以下是补充的超参数设置
# 用于训练的数据集数量
num_train_data = 1040 # 可设置的最大值为5000,数据量越多,模型性能可能越好,但训练时间也会增加
# 以下参数影响模型的训练和推理
CUTOFF_LEN = 256 # 文本截断的最大长度
LORA_R = 8 # LORA的R值
LORA_ALPHA = 16 # LORA的Alph