核心技术概述
实现了一个高效的知识蒸馏(Knowledge Distillation)训练流程,并结合了超参数优化和大规模训练优化工具。旨在高效地训练出性能优异但更轻量级的模型。
1. 知识蒸馏 (Knowledge Distillation)
这是代码的核心目标。它旨在将一个大型、高性能的教师模型( ./best_model)的知识有效迁移到一个更小、更高效的学生模型(distilbert-base-uncased)中,从而在保持较高性能的同时,大幅减少模型大小和推理时间。
best_model是reberta微调后的模型。
具体实现了以下蒸馏策略:
- 软标签蒸馏 (Soft Label Distillation):学生模型学习教师模型输出的平滑概率分布(通过
temperature参数控制平滑程度)。这通过计算学生模型和教师模型软化后的 logits 之间的 KL 散度实现。 - 硬标签损失 (Hard Label Loss):学生模型依然直接从真实标签中学习,这通过标准的 交叉熵损失 实现。
- 注意力蒸馏 (Attention Distillation):学生模型尝试模仿教师模型的内部注意力机制。它通过计算学生模型和教师模型对应层(通过
get_layer_mapping映射)的注意力权重之间的 MSE 损失来完成。 - 加权损失组合:最终的蒸馏损失是上述三种损失的加权和,通过
alpha参数调整硬标签和软标签损失的比例,通过attention_weight参数调整注意力蒸馏损失的贡献。
2. 超参数优化 (Hyperparameter Optimization)
为了找到最佳的蒸馏效果,使用了强大的 Optuna 库来自动化超参数搜索过程:
- 智能搜索策略:Optuna 能够有效地探索
train_micro_batch_size_per_gpu、gradient_accumulation_steps、alpha、temperature、learning_rate、num_epochs等多个超参数的最佳组合。 - 基于验证指标的优化:
objective函数在每次试验中训练并验证模型,将验证集上的 F1 分数作为优化目标(direction="maximize")。 - 智能剪枝 (Pruning):通过
optuna.pruners.MedianPruner,Optuna 会根据中间结果判断哪些试验不太可能达到好的性能,并提前终止这些试验。这大大减少了不必要的计算时间,提高了调参效率。 - 最终模型训练与保存:在 Optuna 找到最优超参数组合后,代码会使用这组参数进行一次最终训练,并在测试集上进行评估,最后将最佳的蒸馏模型保存到本地。
3. 大规模训练优化工具
为了确保模型训练的高效性和可扩展性,代码集成了 Hugging Face 的 Accelerate 库:
- 分布式训练支持:
Accelerate提供了一个抽象层,使得训练代码无需大幅改动,即可在单设备、多 GPU 或分布式环境中运行。 - 混合精度训练 (Mixed Precision Training):通过
mixed_precision="fp16",模型可以在训练中使用半精度浮点数(FP16),这能显著减少显存占用并加速计算。 - 梯度累积 (Gradient Accumulation):
Accelerate的accelerator.accumulate上下文管理器实现了梯度累积,允许使用较小的物理批次大小来模拟更大的逻辑批次,从而在有限的 GPU 内存下处理更多数据。 - 兼容 DeepSpeed:通过
accelerate config命令指定 DeepSpeed 配置文件,Accelerate就会在底层利用 DeepSpeed 的高级优化功能(如 ZeRO 优化器),进一步降低显存消耗并提升训练效率。
deepspeed 配置:
ds_config.json
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
}
},
"train_micro_batch_size_per_gpu": 8,
"gradient_accumulation_steps": 2,
"gradient_clipping": 1.0,
"fp16": {
"enabled": true
}
}
4. 数据处理与模型管理
此外,代码还利用了其他实用的工具和技术:
- Hugging Face Transformers:用于加载和管理预训练的 Transformer 模型 (
AutoTokenizer,AutoModelForSequenceClassification)。特别是output_attentions=True和output_hidden_states=True对于获取蒸馏所需的中间输出至关重要。 - Hugging Face Datasets:高效加载和预处理数据集(例如 IMDB 评论数据集)。
- PyTorch 优化器与学习率调度器:使用
torch.optim.AdamW作为优化器,并结合get_cosine_schedule_with_warmup或get_linear_schedule_with_warmup来动态调整学习率,帮助模型更好地收敛。 tqdm:提供了美观的进度条,方便监控训练进度。sklearn.metrics:用于计算模型性能指标,如准确率和 F1 分数。
代码:
#!/usr/bin/env python
# coding: utf-8
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
set_seed
)
from datasets import load_dataset
from accelerate import Accelerator
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
# --- 引入 Optuna 库 ---
import optuna
from optuna.pruners import MedianPruner # 用于提前停止不佳的试验
# --- 层映射工具 ---
def get_layer_mapping(teacher_layers, student_layers):
"""
根据教师模型和学生模型的层数,创建注意力层映射关系。
"""
mapping = {
}
for s_idx in range(student_layers):
# 计算学生层 s_idx 对应的教师层 t_idx
# 确保映射在教师模型层数范围内,并尽量均匀分布
t_idx = int(s_idx * (teacher_layers - 1) / (student_layers - 1))
mapping[t_idx] = s_idx
return mapping
# --- 计算评估指标 ---
def compute_metrics(preds, labels):
"""
计算预测结果的准确率和 F1 分数。
"""
preds = preds.argmax(axis=-1)
acc = accuracy_score(labels, preds)
f1 = f1_score(labels, preds, average="weighted")
return {
"accuracy": acc, "f1": f1}
# --- 蒸馏损失计算 ---
def compute_distillation_loss(
student_logits, teacher_logits,
student_outputs, teacher_outputs,
labels,
layer_mapping,
alpha=0.3, temperature=3.0,
attention_weight=0.5,
use_attention_distillation=True
):
"""
计算知识蒸馏的总损失,包括硬标签损失、软标签损失和注意力蒸馏损失。
Args:
student_logits: 学生模型的 logits。
teacher_logits: 教师模型的 logits。
student_outputs: 学生模型的完整输出(包含 attention)。
teacher_outputs: 教师模型的完整输出(包含 attention)。
labels: 真实标签。
layer_mapping: 教师层到学生层的映射字典。
alpha: 硬标签损失的权重。
temperature: 软标签蒸馏的温度参数。
attention_weight: 注意力蒸馏损失的权重。
use_attention_distillation: 是否使用注意力蒸馏。
Returns:
total_loss: 总蒸馏损失。
loss_info: 包含各项损失的字典。
"""
dtype = student_logits.dtype
# 1. 硬标签损失 (Hard Label Loss)
# 学生模型预测与真实标签之间的交叉熵损失
hard_loss = F.cross_entropy(student_logits, labels)
# 2. 软标签损失 (Soft Label Loss / KL 散度)
# 学生模型 logits 经过温度 T 软化后的分布与教师模型 logits 经过温度 T 软化后的分布之间的 KL 散度
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).to(dtype)
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (temperature ** 2)
# 3. 注意力蒸馏损失 (Attention Distillation Loss)
att_loss

最低0.47元/天 解锁文章
1025

被折叠的 条评论
为什么被折叠?



