LLMs-from-scratch第六章:文本分类微调技术深度解析
你是否曾想过如何让强大的语言模型学会分辨垃圾邮件?是否好奇如何将通用语言模型转变为专业分类工具?本文将带你深入探索第六章的文本分类微调技术,通过实际案例和代码解析,掌握将GPT模型微调为高性能分类器的完整流程。读完本文,你将能够独立完成从数据准备到模型部署的全流程操作,并理解微调背后的关键技术原理。
微调技术概述
文本分类微调是将预训练语言模型(如GPT)适配到特定分类任务的关键技术。与从零开始训练相比,微调具有以下优势:
- 显著降低计算资源需求
- 大幅缩短训练时间
- 利用预训练模型的语言理解能力
- 提高小数据集上的分类性能
本章重点介绍如何将GPT模型微调到垃圾邮件检测任务,主要代码实现位于ch06/01_main-chapter-code/gpt_class_finetune.py。微调过程主要分为四个阶段:数据准备、模型修改、训练优化和评估部署。
数据准备全流程
高质量的数据准备是微调成功的基础。本章采用经典的SMS垃圾邮件数据集,完整处理流程如下:
数据集下载与处理
首先通过download_and_unzip_spam_data函数获取并预处理数据:
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
if data_file_path.exists():
print(f"{data_file_path} already exists. Skipping download and extraction.")
return
# 下载文件
with urllib.request.urlopen(url) as response:
with open(zip_path, "wb") as out_file:
out_file.write(response.read())
# 解压文件
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extracted_path)
# 重命名文件
original_file_path = Path(extracted_path) / "SMSSpamCollection"
os.rename(original_file_path, data_file_path)
print(f"File downloaded and saved as {data_file_path}")
数据平衡与划分
现实世界的数据集往往存在类别不平衡问题,create_balanced_dataset函数通过采样解决这一问题:
def create_balanced_dataset(df):
# 计算"spam"类别的数量
num_spam = df[df["Label"] == "spam"].shape[0]
# 随机采样"ham"类别以匹配"spam"数量
ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
# 合并子集
balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
return balanced_df
数据集划分通过random_split函数实现,按70%/10%/20%比例分为训练集、验证集和测试集:
def random_split(df, train_frac, validation_frac):
# 打乱数据
df = df.sample(frac=1, random_state=123).reset_index(drop=True)
# 计算划分索引
train_end = int(len(df) * train_frac)
validation_end = train_end + int(len(df) * validation_frac)
# 划分数据
train_df = df[:train_end]
validation_df = df[train_end:validation_end]
test_df = df[validation_end:]
return train_df, validation_df, test_df
自定义数据集类
ch06/01_main-chapter-code/gpt_class_finetune.py中的SpamDataset类实现了文本到模型输入的转换:
class SpamDataset(Dataset):
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
self.data = pd.read_csv(csv_file)
# 预编码文本
self.encoded_texts = [
tokenizer.encode(text) for text in self.data["Text"]
]
if max_length is None:
self.max_length = self._longest_encoded_length()
else:
self.max_length = max_length
# 截断过长序列
self.encoded_texts = [
encoded_text[:self.max_length]
for encoded_text in self.encoded_texts
]
# 填充序列至最长长度
self.encoded_texts = [
encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
for encoded_text in self.encoded_texts
]
模型修改与微调策略
预训练模型加载
本章使用GPT-2作为基础模型,通过ch06/01_main-chapter-code/gpt_download.py中的download_and_load_gpt2函数加载预训练权重:
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
分类头修改
为适应分类任务,需要修改GPT模型的输出层。原语言模型的下一个token预测头被替换为二分类头:
# 冻结大部分参数
for param in model.parameters():
param.requires_grad = False
# 添加新的分类头
num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
model.to(device)
# 解冻最后几层以进行微调
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
for param in model.final_norm.parameters():
param.requires_grad = True
这种部分解冻的策略平衡了训练效率和模型性能,只更新模型顶部的几层参数,大幅减少计算量。
训练循环实现
完整的训练循环在train_classifier_simple函数中实现,关键步骤包括:
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter):
# 初始化跟踪列表
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
# 主训练循环
for epoch in range(num_epochs):
model.train() # 设置训练模式
for input_batch, target_batch in train_loader:
optimizer.zero_grad() # 重置梯度
loss = calc_loss_batch(input_batch, target_batch, model, device)
loss.backward() # 计算梯度
optimizer.step() # 更新权重
examples_seen += input_batch.shape[0]
global_step += 1
# 定期评估
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"Ep {epoch+1} (Step {global_step:06d}): "
f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
评估指标计算
评估函数calc_accuracy_loader和calc_loss_loader分别计算模型的准确率和损失:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches or num_batches is None:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # 获取最后一个token的输出
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
实验结果与分析
不同模型性能对比
ch06/03_bonus_imdb-classification/README.md中提供了不同模型在IMDb影评分类任务上的性能对比:
| 排名 | 模型 | 测试集准确率 |
|---|---|---|
| 1 | 395M ModernBERT Large | 95.07% |
| 2 | 304M DeBERTa-v3 | 94.69% |
| 3 | 149M ModernBERT Base | 93.79% |
| 4 | 355M RoBERTa | 92.95% |
| 5 | 124M GPT-2 Baseline | 91.88% |
| 6 | 66M DistilBERT | 91.40% |
| 7 | 340M BERT | 90.89% |
| 8 | 逻辑回归基线 | 88.85% |
值得注意的是,124M参数的GPT-2在分类任务上表现优于同等规模的BERT模型,展示了decoder-only架构在某些分类任务上的优势。
微调效率分析
通过对比不同模型的训练时间和性能,可以得出以下关键发现:
- 较小模型(如66M DistilBERT)训练速度快(4.26分钟)但性能略低
- 较大模型(如395M ModernBERT Large)性能最佳但训练时间最长(27.69分钟)
- GPT-2在性能和效率之间取得了良好平衡(9.48分钟达到91.88%准确率)
实战案例:IMDb影评情感分析
ch06/03_bonus_imdb-classification提供了一个更复杂的情感分析案例,使用50k条IMDb影评数据进行情感分类。
数据集准备
通过download_prepare_dataset.py脚本获取并预处理IMDb数据集:
python download_prepare_dataset.py
该脚本会创建train.csv、validation.csv和test.csv三个文件,包含预处理后的影评文本和情感标签。
不同模型训练命令
- GPT-2模型
python train_gpt.py --trainable_layers "all" --num_epochs 1
- BERT模型
python train_bert_hf.py --trainable_layers "all" --num_epochs 1 --model "bert"
- RoBERTa模型
python train_bert_hf.py --trainable_layers "last_block" --num_epochs 1 --model "roberta"
- 逻辑回归基线
python train_sklearn_logreg.py
交互式界面部署
ch06/04_user_interface提供了一个基于Chainlit的交互式界面,让用户可以直观地与微调后的垃圾邮件分类器交互。
界面启动步骤
- 安装依赖
pip install -r requirements-extra.txt
- 运行界面
chainlit run app.py
启动后,你将看到一个类似ChatGPT的界面,可以输入文本并获得分类结果:
THE 0TH POSITION OF THE ORIGINAL IMAGE
界面实现原理
界面核心代码在ch06/04_user_interface/app.py中,使用Chainlit框架实现实时预测:
import chainlit as cl
from transformers import pipeline
# 加载模型和分词器
classifier = pipeline(
"text-classification",
model="path/to/finetuned/model",
return_all_scores=True
)
@cl.on_message
async def main(message: cl.Message):
# 获取预测结果
result = classifier(message.content)
# 格式化响应
response = f"垃圾邮件概率: {result[0][1]['score']:.4f}\n正常邮件概率: {result[0][0]['score']:.4f}"
# 发送响应
await cl.Message(content=response).send()
总结与展望
第六章深入探讨了文本分类微调技术,从数据准备到模型部署的完整流程。通过本章学习,你应该掌握:
- 文本分类任务的数据预处理方法,包括数据平衡、划分和编码
- 预训练语言模型的微调策略,特别是参数冻结和解冻技巧
- 分类模型的评估方法和性能优化
- 实际应用部署的基本流程
微调技术是将通用语言模型适应特定任务的关键方法,通过本章介绍的技术,你可以将预训练模型应用于各种分类问题,如情感分析、垃圾邮件检测、意图识别等。
后续章节将进一步探讨指令微调技术,使模型能够理解和遵循复杂指令,为构建更智能的对话系统奠定基础。要深入实践本章内容,建议从运行ch06/01_main-chapter-code/ch06.ipynb笔记本开始,逐步体验完整的微调流程。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



