基于 Transformer RoBERTa的情感分类任务实践总结之六——知识蒸馏

核心技术概述

实现了一个高效的知识蒸馏(Knowledge Distillation)训练流程,并结合了超参数优化大规模训练优化工具。旨在高效地训练出性能优异但更轻量级的模型。

1. 知识蒸馏 (Knowledge Distillation)

这是代码的核心目标。它旨在将一个大型、高性能的教师模型./best_model)的知识有效迁移到一个更小、更高效的学生模型distilbert-base-uncased)中,从而在保持较高性能的同时,大幅减少模型大小和推理时间。
best_model是reberta微调后的模型。

具体实现了以下蒸馏策略:

  • 软标签蒸馏 (Soft Label Distillation):学生模型学习教师模型输出的平滑概率分布(通过 temperature 参数控制平滑程度)。这通过计算学生模型和教师模型软化后的 logits 之间的 KL 散度实现。
  • 硬标签损失 (Hard Label Loss):学生模型依然直接从真实标签中学习,这通过标准的 交叉熵损失 实现。
  • 注意力蒸馏 (Attention Distillation):学生模型尝试模仿教师模型的内部注意力机制。它通过计算学生模型和教师模型对应层(通过 get_layer_mapping 映射)的注意力权重之间的 MSE 损失来完成。
  • 加权损失组合:最终的蒸馏损失是上述三种损失的加权和,通过 alpha 参数调整硬标签和软标签损失的比例,通过 attention_weight 参数调整注意力蒸馏损失的贡献。

2. 超参数优化 (Hyperparameter Optimization)

为了找到最佳的蒸馏效果,使用了强大的 Optuna 库来自动化超参数搜索过程:

  • 智能搜索策略:Optuna 能够有效地探索 train_micro_batch_size_per_gpugradient_accumulation_stepsalphatemperaturelearning_ratenum_epochs 等多个超参数的最佳组合。
  • 基于验证指标的优化objective 函数在每次试验中训练并验证模型,将验证集上的 F1 分数作为优化目标(direction="maximize")。
  • 智能剪枝 (Pruning):通过 optuna.pruners.MedianPruner,Optuna 会根据中间结果判断哪些试验不太可能达到好的性能,并提前终止这些试验。这大大减少了不必要的计算时间,提高了调参效率。
  • 最终模型训练与保存:在 Optuna 找到最优超参数组合后,代码会使用这组参数进行一次最终训练,并在测试集上进行评估,最后将最佳的蒸馏模型保存到本地。

3. 大规模训练优化工具

为了确保模型训练的高效性和可扩展性,代码集成了 Hugging Face 的 Accelerate 库:

  • 分布式训练支持Accelerate 提供了一个抽象层,使得训练代码无需大幅改动,即可在单设备、多 GPU 或分布式环境中运行。
  • 混合精度训练 (Mixed Precision Training):通过 mixed_precision="fp16",模型可以在训练中使用半精度浮点数(FP16),这能显著减少显存占用并加速计算。
  • 梯度累积 (Gradient Accumulation)Accelerateaccelerator.accumulate 上下文管理器实现了梯度累积,允许使用较小的物理批次大小来模拟更大的逻辑批次,从而在有限的 GPU 内存下处理更多数据。
  • 兼容 DeepSpeed:通过 accelerate config 命令指定 DeepSpeed 配置文件,Accelerate 就会在底层利用 DeepSpeed 的高级优化功能(如 ZeRO 优化器),进一步降低显存消耗并提升训练效率。
    deepspeed 配置:
    ds_config.json
{
   
   
  "zero_optimization": {
   
   
    "stage": 2,
    "offload_optimizer": {
   
   
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
   
   
      "device": "cpu",
      "pin_memory": true
    }
  },
  "train_micro_batch_size_per_gpu": 8,
  "gradient_accumulation_steps": 2,
  "gradient_clipping": 1.0,
  "fp16": {
   
   
    "enabled": true
  }
}


4. 数据处理与模型管理

此外,代码还利用了其他实用的工具和技术:

  • Hugging Face Transformers:用于加载和管理预训练的 Transformer 模型 (AutoTokenizer, AutoModelForSequenceClassification)。特别是 output_attentions=Trueoutput_hidden_states=True 对于获取蒸馏所需的中间输出至关重要。
  • Hugging Face Datasets:高效加载和预处理数据集(例如 IMDB 评论数据集)。
  • PyTorch 优化器与学习率调度器:使用 torch.optim.AdamW 作为优化器,并结合 get_cosine_schedule_with_warmupget_linear_schedule_with_warmup 来动态调整学习率,帮助模型更好地收敛。
  • tqdm:提供了美观的进度条,方便监控训练进度。
  • sklearn.metrics:用于计算模型性能指标,如准确率和 F1 分数。

代码:

#!/usr/bin/env python
# coding: utf-8

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    set_seed
)
from datasets import load_dataset
from accelerate import Accelerator
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# --- 引入 Optuna 库 ---
import optuna
from optuna.pruners import MedianPruner # 用于提前停止不佳的试验

# --- 层映射工具 ---
def get_layer_mapping(teacher_layers, student_layers):
    """
    根据教师模型和学生模型的层数,创建注意力层映射关系。
    """
    mapping = {
   
   }
    for s_idx in range(student_layers):
        # 计算学生层 s_idx 对应的教师层 t_idx
        # 确保映射在教师模型层数范围内,并尽量均匀分布
        t_idx = int(s_idx * (teacher_layers - 1) / (student_layers - 1))
        mapping[t_idx] = s_idx
    return mapping

# --- 计算评估指标 ---
def compute_metrics(preds, labels):
    """
    计算预测结果的准确率和 F1 分数。
    """
    preds = preds.argmax(axis=-1)
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average="weighted")
    return {
   
   "accuracy": acc, "f1": f1}

# --- 蒸馏损失计算 ---
def compute_distillation_loss(
    student_logits, teacher_logits,
    student_outputs, teacher_outputs,
    labels,
    layer_mapping,
    alpha=0.3, temperature=3.0,
    attention_weight=0.5,
    use_attention_distillation=True
):
    """
    计算知识蒸馏的总损失,包括硬标签损失、软标签损失和注意力蒸馏损失。

    Args:
        student_logits: 学生模型的 logits。
        teacher_logits: 教师模型的 logits。
        student_outputs: 学生模型的完整输出(包含 attention)。
        teacher_outputs: 教师模型的完整输出(包含 attention)。
        labels: 真实标签。
        layer_mapping: 教师层到学生层的映射字典。
        alpha: 硬标签损失的权重。
        temperature: 软标签蒸馏的温度参数。
        attention_weight: 注意力蒸馏损失的权重。
        use_attention_distillation: 是否使用注意力蒸馏。

    Returns:
        total_loss: 总蒸馏损失。
        loss_info: 包含各项损失的字典。
    """
    dtype = student_logits.dtype

    # 1. 硬标签损失 (Hard Label Loss)
    # 学生模型预测与真实标签之间的交叉熵损失
    hard_loss = F.cross_entropy(student_logits, labels)

    # 2. 软标签损失 (Soft Label Loss / KL 散度)
    # 学生模型 logits 经过温度 T 软化后的分布与教师模型 logits 经过温度 T 软化后的分布之间的 KL 散度
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).to(dtype)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (temperature ** 2)

    # 3. 注意力蒸馏损失 (Attention Distillation Loss)
    att_loss 
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值