LLMs-from-scratch第六章:文本分类微调技术深度解析

LLMs-from-scratch第六章:文本分类微调技术深度解析

【免费下载链接】LLMs-from-scratch 从零开始逐步指导开发者构建自己的大型语言模型(LLM),旨在提供详细的步骤和原理说明,帮助用户深入理解并实践LLM的开发过程。 【免费下载链接】LLMs-from-scratch 项目地址: https://gitcode.com/GitHub_Trending/ll/LLMs-from-scratch

你是否曾想过如何让强大的语言模型学会分辨垃圾邮件?是否好奇如何将通用语言模型转变为专业分类工具?本文将带你深入探索第六章的文本分类微调技术,通过实际案例和代码解析,掌握将GPT模型微调为高性能分类器的完整流程。读完本文,你将能够独立完成从数据准备到模型部署的全流程操作,并理解微调背后的关键技术原理。

微调技术概述

文本分类微调是将预训练语言模型(如GPT)适配到特定分类任务的关键技术。与从零开始训练相比,微调具有以下优势:

  • 显著降低计算资源需求
  • 大幅缩短训练时间
  • 利用预训练模型的语言理解能力
  • 提高小数据集上的分类性能

本章重点介绍如何将GPT模型微调到垃圾邮件检测任务,主要代码实现位于ch06/01_main-chapter-code/gpt_class_finetune.py。微调过程主要分为四个阶段:数据准备、模型修改、训练优化和评估部署。

数据准备全流程

高质量的数据准备是微调成功的基础。本章采用经典的SMS垃圾邮件数据集,完整处理流程如下:

数据集下载与处理

首先通过download_and_unzip_spam_data函数获取并预处理数据:

def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return
    
    # 下载文件
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())
    
    # 解压文件
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)
    
    # 重命名文件
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")

数据平衡与划分

现实世界的数据集往往存在类别不平衡问题,create_balanced_dataset函数通过采样解决这一问题:

def create_balanced_dataset(df):
    # 计算"spam"类别的数量
    num_spam = df[df["Label"] == "spam"].shape[0]
    
    # 随机采样"ham"类别以匹配"spam"数量
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    
    # 合并子集
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    
    return balanced_df

数据集划分通过random_split函数实现,按70%/10%/20%比例分为训练集、验证集和测试集:

def random_split(df, train_frac, validation_frac):
    # 打乱数据
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)
    
    # 计算划分索引
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)
    
    # 划分数据
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]
    
    return train_df, validation_df, test_df

自定义数据集类

ch06/01_main-chapter-code/gpt_class_finetune.py中的SpamDataset类实现了文本到模型输入的转换:

class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)
        
        # 预编码文本
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]
        
        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # 截断过长序列
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]
        
        # 填充序列至最长长度
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

模型修改与微调策略

预训练模型加载

本章使用GPT-2作为基础模型,通过ch06/01_main-chapter-code/gpt_download.py中的download_and_load_gpt2函数加载预训练权重:

settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)

分类头修改

为适应分类任务,需要修改GPT模型的输出层。原语言模型的下一个token预测头被替换为二分类头:

# 冻结大部分参数
for param in model.parameters():
    param.requires_grad = False

# 添加新的分类头
num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
model.to(device)

# 解冻最后几层以进行微调
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True

for param in model.final_norm.parameters():
    param.requires_grad = True

这种部分解冻的策略平衡了训练效率和模型性能,只更新模型顶部的几层参数,大幅减少计算量。

训练循环实现

完整的训练循环在train_classifier_simple函数中实现,关键步骤包括:

def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                        eval_freq, eval_iter):
    # 初始化跟踪列表
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # 主训练循环
    for epoch in range(num_epochs):
        model.train()  # 设置训练模式

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # 重置梯度
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # 计算梯度
            optimizer.step()  # 更新权重
            examples_seen += input_batch.shape[0]
            global_step += 1

            # 定期评估
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

评估指标计算

评估函数calc_accuracy_loadercalc_loss_loader分别计算模型的准确率和损失:

def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches or num_batches is None:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # 获取最后一个token的输出
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples

实验结果与分析

不同模型性能对比

ch06/03_bonus_imdb-classification/README.md中提供了不同模型在IMDb影评分类任务上的性能对比:

排名模型测试集准确率
1395M ModernBERT Large95.07%
2304M DeBERTa-v394.69%
3149M ModernBERT Base93.79%
4355M RoBERTa92.95%
5124M GPT-2 Baseline91.88%
666M DistilBERT91.40%
7340M BERT90.89%
8逻辑回归基线88.85%

值得注意的是,124M参数的GPT-2在分类任务上表现优于同等规模的BERT模型,展示了decoder-only架构在某些分类任务上的优势。

微调效率分析

通过对比不同模型的训练时间和性能,可以得出以下关键发现:

  • 较小模型(如66M DistilBERT)训练速度快(4.26分钟)但性能略低
  • 较大模型(如395M ModernBERT Large)性能最佳但训练时间最长(27.69分钟)
  • GPT-2在性能和效率之间取得了良好平衡(9.48分钟达到91.88%准确率)

实战案例:IMDb影评情感分析

ch06/03_bonus_imdb-classification提供了一个更复杂的情感分析案例,使用50k条IMDb影评数据进行情感分类。

数据集准备

通过download_prepare_dataset.py脚本获取并预处理IMDb数据集:

python download_prepare_dataset.py

该脚本会创建train.csvvalidation.csvtest.csv三个文件,包含预处理后的影评文本和情感标签。

不同模型训练命令

  1. GPT-2模型
python train_gpt.py --trainable_layers "all" --num_epochs 1
  1. BERT模型
python train_bert_hf.py --trainable_layers "all" --num_epochs 1 --model "bert"
  1. RoBERTa模型
python train_bert_hf.py --trainable_layers "last_block" --num_epochs 1 --model "roberta"
  1. 逻辑回归基线
python train_sklearn_logreg.py

交互式界面部署

ch06/04_user_interface提供了一个基于Chainlit的交互式界面,让用户可以直观地与微调后的垃圾邮件分类器交互。

界面启动步骤

  1. 安装依赖
pip install -r requirements-extra.txt
  1. 运行界面
chainlit run app.py

启动后,你将看到一个类似ChatGPT的界面,可以输入文本并获得分类结果:

THE 0TH POSITION OF THE ORIGINAL IMAGE

界面实现原理

界面核心代码在ch06/04_user_interface/app.py中,使用Chainlit框架实现实时预测:

import chainlit as cl
from transformers import pipeline

# 加载模型和分词器
classifier = pipeline(
    "text-classification",
    model="path/to/finetuned/model",
    return_all_scores=True
)

@cl.on_message
async def main(message: cl.Message):
    # 获取预测结果
    result = classifier(message.content)
    
    # 格式化响应
    response = f"垃圾邮件概率: {result[0][1]['score']:.4f}\n正常邮件概率: {result[0][0]['score']:.4f}"
    
    # 发送响应
    await cl.Message(content=response).send()

总结与展望

第六章深入探讨了文本分类微调技术,从数据准备到模型部署的完整流程。通过本章学习,你应该掌握:

  1. 文本分类任务的数据预处理方法,包括数据平衡、划分和编码
  2. 预训练语言模型的微调策略,特别是参数冻结和解冻技巧
  3. 分类模型的评估方法和性能优化
  4. 实际应用部署的基本流程

微调技术是将通用语言模型适应特定任务的关键方法,通过本章介绍的技术,你可以将预训练模型应用于各种分类问题,如情感分析、垃圾邮件检测、意图识别等。

后续章节将进一步探讨指令微调技术,使模型能够理解和遵循复杂指令,为构建更智能的对话系统奠定基础。要深入实践本章内容,建议从运行ch06/01_main-chapter-code/ch06.ipynb笔记本开始,逐步体验完整的微调流程。

【免费下载链接】LLMs-from-scratch 从零开始逐步指导开发者构建自己的大型语言模型(LLM),旨在提供详细的步骤和原理说明,帮助用户深入理解并实践LLM的开发过程。 【免费下载链接】LLMs-from-scratch 项目地址: https://gitcode.com/GitHub_Trending/ll/LLMs-from-scratch

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值