使用SFTTrainer进行监督式微调:HuggingFace SmolLM2实践指南
smol-course A course on aligning smol models. 项目地址: https://gitcode.com/gh_mirrors/smo/smol-course
引言
在自然语言处理领域,监督式微调(Supervised Fine-Tuning, SFT)是将预训练语言模型适配到特定任务的重要技术。本文将详细介绍如何使用HuggingFace生态中的SFTTrainer
工具对SmolLM2-135M
模型进行监督式微调。
环境准备
开始前需要确保已安装必要的Python库:
pip install transformers datasets trl huggingface_hub
这些库分别提供:
transformers
:HuggingFace的核心模型库datasets
:数据集加载和处理工具trl
:包含SFTTrainer等训练工具huggingface_hub
:模型和数据集仓库交互工具
模型加载与初始化
首先加载基础模型和对应的分词器:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model_name = "HuggingFaceTB/SmolLM2-135M"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
SmolLM2-135M
是一个1.35亿参数的小型语言模型,适合在资源有限的环境中进行微调实验。
对话模板设置
现代对话模型通常需要特定的对话格式模板:
from trl import setup_chat_format
model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)
这一步骤确保模型能够正确处理对话格式的输入,其中每条消息包含role
(角色)和content
(内容)字段。
基础模型测试
微调前先测试基础模型的生成能力:
prompt = "写一首关于编程的俳句"
messages = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=100)
print("微调前结果:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
这个测试可以帮助我们对比微调前后的模型表现差异。
数据集准备
监督式微调需要准备格式化的训练数据。HuggingFace提供了多种数据集选择:
- 入门级:
HuggingFaceTB/smoltalk
- 日常对话数据集 - 进阶级:
bigcode/the-stack-smol
- 代码生成数据集 - 自定义级:选择与您实际应用场景相关的数据集
from datasets import load_dataset
# 示例:加载日常对话数据集
ds = load_dataset(path="HuggingFaceTB/smoltalk", name="everyday-conversations")
数据集需要转换为包含role
和content
字段的对话格式,TRL会自动处理这种格式的数据。
SFTTrainer配置
SFTTrainer
是专门为监督式微调设计的训练器,配置参数包括:
from trl import SFTConfig, SFTTrainer
sft_config = SFTConfig(
output_dir="./sft_output",
max_steps=1000, # 训练步数
per_device_train_batch_size=4, # 批大小
learning_rate=5e-5, # 学习率
logging_steps=10, # 日志记录频率
save_steps=100, # 模型保存频率
evaluation_strategy="steps", # 评估策略
eval_steps=50, # 评估频率
use_mps_device=(True if device == "mps" else False),
hub_model_id="SmolLM2-FT-MyDataset", # 模型保存名称
)
训练过程
初始化并启动训练:
trainer = SFTTrainer(
model=model,
args=sft_config,
train_dataset=ds["train"],
tokenizer=tokenizer,
eval_dataset=ds["test"],
)
trainer.train()
trainer.save_model("./SmolLM2-FT-MyDataset")
训练过程中会定期评估模型表现,并保存检查点。
微调后模型测试
训练完成后,测试模型在新数据上的表现:
prompt = "写一首关于编程的俳句"
messages = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=100)
print("微调后结果:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
对比微调前后的输出,可以直观评估微调效果。
进阶建议
-
数据集选择:根据应用场景选择合适的数据集
- 对话系统:选择对话数据集
- 代码生成:选择代码数据集
- 特定领域:收集领域相关数据
-
参数调优:尝试不同的学习率、批大小和训练步数组合
-
评估指标:设计定量评估指标,而不仅依赖生成样例
-
模型部署:考虑将微调后的模型部署为API服务
总结
本文详细介绍了使用SFTTrainer
对SmolLM2-135M
进行监督式微调的全过程。通过这种方法,开发者可以将通用语言模型适配到特定任务或领域,显著提升模型在目标场景下的表现。监督式微调是构建定制化NLP应用的重要技术,值得深入学习和实践。
smol-course A course on aligning smol models. 项目地址: https://gitcode.com/gh_mirrors/smo/smol-course
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考