一、什么是模型蒸馏and为什么要模型蒸馏
模型蒸馏简单来说,就是将一个大模型(比如BERT)的权重,通过一定规则,压缩到小模型(比如RoBERTa)的权重中。
蒸馏其实特别形象,就像把一杯饱和盐水蒸馏成纯净水一样,质量减少了效果却没有多大变化。
在大模型(LLM)领域,模型的发展趋势是参数量不断增大。因为模型的性能和模型的规模是正相关的。
从2017年到2024年,LLM的参数数量经历了指数级增长:
2017年,Transformer模型的参数量为0.05亿。
2018年,GPT模型的参数量为0.11亿,BERT模型为0.34亿。
2020年,GPT-2的参数量达到15亿,MegatronLM为83亿。
2021年,GPT-3的参数量达到1750亿,T-NLG为170亿。
2022年,MT-NLG的参数量达到5300亿。
2023年,GPT-4的参数量突破万亿。
2024年,最大且最常用的模型参数量约为700亿。
而模型蒸馏的目标就是减少模型的参数量,从而减少模型的计算量和内存占用,从而提高模型的性能。
因此模型蒸馏就是一种模型压缩的方法。
二、模型蒸馏的原理
模型蒸馏的原理是知识蒸馏。
首先介绍几个概念:
2.1 学生模型和教师模型
所谓的学生模型,就是模型蒸馏的目标模型。一般比被蒸馏的模型也就是教师模型要少一些参数。
教师模型通常是一个大型的、复杂的并且经过充分训练的模型,它在目标任务上表现非常出色,在绝大多数的情况下,教师模型决定了学生模型的性能上限,并且学生模型的性能不会超过教师模型。所以教师模型一般会是一个性能出色的模型。
教师模型就像是一位知识渊博的专家,它对知识的掌握全面而深入,其丰富的参数和复杂的结构使其能够学习到数据中深层次的信息。
而学生模型则相对小巧简单,参数数量较少,结构也更为精简 。学生模型的目标是通过学习教师模型的知识,尽可能地逼近教师模型的性能。尽管学生模型自身的容量有限,但通过模型蒸馏这一过程,它能够从教师模型那里获取关键知识,从而在保持较小规模的同时,实现较好的任务表现,就像学生在老师的教导下不断成长,逐渐掌握解决问题的能力。
2.2 软标签和硬标签
在机器学习和深度学习中,硬标签(Hard Labels)和软标签(Soft Labels) 是两种不同的标签表示方式。
硬标签(Hard Labels)在模型蒸馏中,硬标签是指 真实数据的标注,通常是 独热编码(One-Hot Encoding) 的形式。硬标签表示每个样本的真实类别,用于 监督学习。
还是举个例子:
假设有一个三分类问题,某个样本的真实类别是第2类,那么它的硬标签为:
[0, 1, 0]
特点 是离散性和明确性:
离散性:硬标签是离散的,每个样本只有一个类别被标记为1,其余类别为0。
明确性:硬标签直接反映了数据的真实类别,没有概率信息。
软标签(Soft Labels) 在模型蒸馏中,软标签是指教师模型的输出概率分布。教师模型对每个样本进行预测,输出每个类别的概率值,这些概率值构成了软标签。
还是举个例子:
假设教师模型对某个样本 的 预测Logits(原始输出) 为 [2.0, 1.0, 0.1],温度参数 T=2
我们首先要将Logits(原始输出) 转化为概率分布,即公式:
Softmax ( z i ) = e z i ∑ j e z j \text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}} Softmax(zi)=∑jezjezi
但我们要考虑到温度参数T,这是一个超参数,即是可以人为给定的一个数值,它控制了概率分布的平滑程度。所以我们采用下面的公式:
Softmax ( z i T ) = e z i / T ∑ j e z j / T \text{Softmax}\left(\frac{z_i}{T}\right) = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}} Softmax(Tzi)=∑jezj/Tezi/T
其中温度T有以下特点:
高温度(T>1):使 Softmax 输出更加平滑,类别之间的差异变小。
低温度(T=1):得到标准的 Softmax 输出。
极低温度(T→0):使 Softmax 输出接近独热编码。
那么软标签计算如下:
p 1 = exp ( 2.0 / 2 ) exp ( 2.0 / 2 ) + exp ( 1.0 / 2 ) + exp ( 0.1 / 2 ) ≈ 0.65 p_1 = \frac{\exp(2.0/2)}{\exp(2.0/2) + \exp(1.0/2) + \exp(0.1/2)} \approx 0.65 p1=exp(2.0/2)+exp(1.0/2)+exp(0.1/2)exp(2.0/2)≈0.65
p 2 = exp ( 1.0 / 2 ) exp ( 2.0 / 2 ) + exp ( 1.0 / 2 ) + exp ( 0.1 / 2 ) ≈ 0.25 p_2 = \frac{\exp(1.0/2)}{\exp(2.0/2) + \exp(1.0/2) + \exp(0.1/2)} \approx 0.25 p2=exp(2.0/2)+exp(1.0/2)+exp(0.1/2)exp(1.0/2)≈0.25
p 3 = exp ( 0.1 / 2 ) exp ( 2.0 / 2 ) + exp ( 1.0 / 2 ) + exp ( 0.1 / 2 ) ≈ 0.10 p_3 = \frac{\exp(0.1/2)}{\exp(2.0/2) + \exp(1.0/2) + \exp(0.1/2)} \approx 0.10 p3=exp(2.0/2)+exp(1.0/2)+exp(0.1/2)exp(0.1/2)≈0.10
所以最终的软标签为:
[0.65, 0.25, 0.10]
一句话总结就是:Logits 是模型在最后一层的原始输出,表示模型对每个类别的预测置信度。它们通常需要经过 Softmax 函数处理,以转换为概率分布,用于分类决策和损失计算。在模型蒸馏中,Logits 通过温度参数调整,生成软标签,帮助学生模型学习教师模型的隐含知识。
2.3 蒸馏损失函数
2.3.1 软标签蒸馏损失函数(Soft Label Loss)
软标签损失使用Kullback-Leibler 散度损失(KL Divergence Loss)来计算。KL 散度衡量的是两个概率分布之间的差异。在模型蒸馏中,它用于衡量学生模型预测的概率分布与教师模型输出的概率分布之间的差异。公式如下:
Soft Loss = ∑ i = 1 C ( log ( q i p i ) ⋅ q i − q i + p i ) \text{Soft Loss} = \sum_{i=1}^{C} \left( \log\left(\frac{q_i}{p_i}\right) \cdot q_i - q_i + p_i \right) Soft Loss=i=1∑C(log(piqi)⋅qi−qi+pi)
其中:
- ( C ) 是类别总数。
- ( p_i ) 是学生模型预测样本属于第 ( i ) 类的概率。
- ( q_i ) 是教师模型预测样本属于第 ( i ) 类的概率(软标签)。
为了简化计算和避免数值不稳定,还有为了进行温度调整(为了避免学生模型过于关注教师模型的软标签而忽略真实标签),KL 散度损失通常使用以下等价形式:
Soft Loss = 1 T 2 ∑ i = 1 C q i log ( q i p i ) \text{Soft Loss} = \frac{1}{T^2} \sum_{i=1}^{C} q_i \log \left(\frac{q_i}{p_i}\right) Soft Loss=T21i=1∑Cqilog(piqi)
其中,T 是温度参数,通常 T>1。
2.3.2 硬标签蒸馏损失函数(Hard Label Loss)
硬标签损失使用**交叉熵损失(Cross Entropy Loss)**来计算。交叉熵损失用于衡量两个概率分布之间的差异。在模型蒸馏中,它用于衡量学生模型预测的概率分布与真实标签之间的差异。公式如下:
Hard Loss = − ∑ i = 1 C y i log ( p i ) \text{Hard Loss} = -\sum_{i=1}^{C} y_i \log(p_i) Hard Loss=−i=1∑Cyilog(pi)
- ( C ) 是类别总数。
- ( y_i ) 是真实标签的独热编码,如果样本属于第 ( i ) 类,则 ( y_i = 1 ),否则 ( y_i = 0 )。
- ( p_i ) 是模型预测样本属于第 ( i ) 类的概率。
2.3.3 总损失函数
学生模型的总损失函数是硬标签损失和软标签损失的加权和
Total Loss = α ⋅ Soft Loss + ( 1 − α ) ⋅ Hard Loss \text{Total Loss} = \alpha \cdot \text{Soft Loss} + (1 - \alpha) \cdot \text{Hard Loss} Total Loss=α⋅Soft Loss+(1−α)⋅Hard Loss
2.3.4 为什么使用软标签
软标签在模型蒸馏中的主要作用是传递教师模型的隐含知识,帮助学生模型学习到更复杂的决策逻辑。具体来说:
类别间关系:软标签反映了类别之间的相似性。例如,如果教师模型预测某个样本属于类别A的概率为0.7,属于类别B的概率为0.2,这意味着类别A和类别B之间有一定的相似性。
平滑输出:通过调整温度参数 T,软标签可以变得更加平滑,帮助学生模型学习到更广泛的特征。
总结 :
硬标签是真实数据的标注,用于监督学习。
软标签是教师模型的输出概率分布,用于传递教师模型的隐含知识。
学生模型通过结合硬标签损失和软标签损失进行训练,既能学习到正确的分类规则,又能继承教师模型的复杂决策逻辑。
三、模型蒸馏的具体流程
具体的任务
假设我们需要对金融文本进行分类,判断其是否涉及风险。我们将使用模型蒸馏技术将一个复杂的教师模型的知识迁移到一个轻量级的学生模型中。
3.1 选择合适的教师模型
教师模型选择:使用 finBERT 作为教师模型,它是一个基于BERT架构的预训练模型,专门用于金融文本处理,性能强大
3.2 设计学生模型
学生模型架构:选择一个轻量级的网络,例如3层LSTM+Attention,适合在资源受限的设备上运行。
初始化:学生模型的参数可以随机初始化,也可以通过预训练模型进行初始化。
3.3 使用教师模型生成软标签(教师模型的输出)
注意硬标签是真实类别,是在训练集中的,不是由教师模型生成的,教师模型只输出软标签,即概率分布。
教师模型对训练数据进行预测,输出每个类别的概率分布(软标签)。这些软标签不仅包含正确答案,还反映了类别之间的相似性
教师模型的输入 可以是教师模型的训练集,或者是其他的输入数据,如金融文本。
3.4 学生模型的输入
学生模型的输入就是教师模型的输入,学生模型和教师模型的输入相同。
3.5 学生模型的训练策略
学生模型的训练目标是使其输出尽可能接近教师模型的软标签,同时结合真实标签进行学习。
我们采用软硬标签混合损失函数,即软标签损失和硬标签损失的加权组合。具体来说,使用70%的软标签损失和30%的真实标签损失。
Total Loss = 0.7 ⋅ Soft Loss + ( 1 − 0.7 ) ⋅ Hard Loss \text{Total Loss} = \ 0.7 \cdot \text{Soft Loss} + (1 - \ 0.7 ) \cdot \text{Hard Loss} Total Loss= 0.7⋅Soft Loss+(1− 0.7)⋅Hard Loss
温度调节:使用温度参数(T=3)对软标签进行平滑处理,使学生模型能够学习到更多的类别间关系。
渐进式温度调节:在训练过程中,温度从5逐渐降低到1,帮助学生模型更好地学习教师模型的知识
3.6 进行训练
输入是和教师模型一样的输入,根据教师模型产生的软标签和真实标签计算损失,并使用反向传播进行梯度下降。
1.输入数据同时传递给教师模型和学生模型。
2.计算学生模型输出与教师模型软标签之间的KL散度损失,以及学生模型输出与真实标签之间的交叉熵损失。
3.通过反向传播更新学生模型的参数
3.7 模型评估和优化
评估指标:使用准确率、F1分数等指标评估学生模型的性能。
优化策略: 如果学生模型的性能不理想,可以调整蒸馏参数(如温度、α系数)或对模型进行微调。
四、思考:模型蒸馏和将教师模型的训练数据直接训练学生模型有什么区别
模型蒸馏:
学生模型能够通过蒸馏获得更高的性能,尤其是在资源受限的环境中(如移动设备或边缘计算)。这是因为蒸馏过程允许学生模型继承教师模型的部分知识。
蒸馏过程还可以通过调整温度参数和损失函数的组合,进一步优化学生模型的学习效果
直接训练学生模型:
直接训练学生模型通常会导致性能较低,因为学生模型缺乏教师模型的复杂结构和深层次知识。
在有限的数据和计算资源下,直接训练学生模型可能无法达到教师模型的性能水平
模型蒸馏通过软标签的知识传递,使学生模型能够学习到教师模型的深层次知识,从而在资源受限的环境中获得更高的性能和更好的泛化能力。而直接训练学生模型则更依赖于数据和计算资源,适合在资源充足的情况下使用。
如果您觉得有用的话,可以为我点个赞和关注,我会持续更新有用简单免费的教程!