Knowledge distillation-知识蒸馏

知识蒸馏是一种技术,通过教师网络的软目标引导学生网络的训练,实现复杂模型的知识向轻量级模型的迁移。教师网络通常是高准确率但计算量大的模型,而学生网络则是小巧且资源高效的替代方案。研究表明,在NLP领域,预训练蒸馏(PD)方法优于传统的预训练加微调(PF),特别是在资源有限的情况下,小规模的学生网络也能达到良好的性能。实验结果显示,PD方法的表现最佳,其次是PF,而仅基于有标签数据的基本训练效果最差。

1. 简介

Knowledge distillation-知识蒸馏(暗知识提取)的概念,通过引入与教师网络(teacher network:复杂、但推理性能优越)相关的软目标(soft-target)作为total loss的一部分,以诱导学生网络(student network:精简、低复杂度)的训练,实现知识迁移(knowledge transfer)。

教师网络teacher:高准确率,但模型很大。
学生网络student:模型小,可以在有限资源下使用。

本文参考2019年《WELL-READ STUDENTS LEARN BETTER: ON THE IMPORTANCE OF PRE-TRAINING COMPACT MODELS》。描述了在NLP方面,Distillation有助于提升模型表现,比传统的pre-training+fine-tuning方法好。这样在实际应用中,我们就可以用非常小的student网络取得很好的表现。

各模型介绍:
在这里插入图片描述

我们实验发现PD(Pre-trained Distillation)好于pre-training+fine-tuning (PF)。揭示了Distillation 的好处。PD如下图:
在这里插入图片描述

实验结果如下图:PD最好,其次是PF。Basic training(只是基于有label的数据训练)最差。Distillation(蒸馏)基于Unlabeled transfer data进行训练。

在这里插入图片描述

### FitNets架构及其实现 FitNets 是一种早期的知识蒸馏方法,其核心思想在于不仅传递最终的预测结果(即软标签),还通过中间层特征图的学习来增强模型之间的知识迁移。这种方法特别适用于小型学生网络无法完全匹配大型教师网络复杂度的情况。 #### 架构设计 FitNets 的主要创新点之一是引入了 **hint layers** 来指导学生网络学习教师网络的关键特征表示。具体来说,在训练过程中,除了传统的输出层损失外,还会计算一个额外的中间层损失函数,用于强制学生网络模仿教师网络特定层次的激活模式[^1]。这种机制使得即使学生网络的整体结构较浅或参数较少,也能捕获到教师网络的重要语义信息。 以下是 FitNets 中涉及的主要组件: - **Teacher Network**: 提供高质量的知识源,通常是一个深层复杂的神经网络。 - **Student Network**: 需要优化的目标轻量化模型。 - **Hint Layer**: 教师网络中的某一层被选作 hint layer,该层负责向学生网络提供详细的内部表征信息。 #### 实现细节 为了有效实施上述策略,FitNess采用了两阶段训练流程: 1. **预训练阶段**: 利用标准监督信号单独对 teacher 和 student 进行初始化权重调整; 2. **联合训练阶段**: 同时最小化两个目标——一个是基于 softmax 输出的概率分布差异;另一个则是衡量选定 hint 层之间欧几里得距离的标准 MSE (Mean Squared Error) 损失项。 下面给出一段 Python 伪代码演示如何构建这样的系统框架: ```python import torch.nn as nn import torch.optim as optim class TeacherModel(nn.Module): def __init__(self): super(TeacherModel, self).__init__() # Define a deep network here... class StudentModel(nn.Module): def __init__(self): super(StudentModel, self).__init__() # Define a shallow network here... teacher = TeacherModel() student = StudentModel() criterion_kd = nn.MSELoss() # For intermediate feature matching criterion_ce = nn.CrossEntropyLoss() # Standard classification loss optimizer = optim.Adam(student.parameters(), lr=0.001) for data, target in dataloader: output_teacher, features_teachers = teacher(data) output_student, features_students = student(data) kd_loss = criterion_kd(features_students['hint'], features_teachers['hint']) ce_loss = criterion_ce(output_student, target) total_loss = alpha * kd_loss + beta * ce_loss optimizer.zero_grad() total_loss.backward() optimizer.step() ``` 其中 `alpha` 和 `beta` 控制两种不同类型误差的重要性权衡系数[^2]。 #### 结果评估 实验表明,相比仅依赖于传统交叉熵损失的传统压缩技术而言,加入 fitnets 所提出的辅助约束条件可以显著提升小规模 CNNs 在图像识别任务上的表现效果[^3]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值