研究背景与意义
现如今,模型的能力变得越来越强!将其应用于智能设备具有很强的应用场景。但是由于模型的能力主要是由参数量的累计使得模型能力变强,智能设备又不具有强大的计算能力。因此,若能找到一种技术,将模型的参数进行压缩,并且能使得压缩后的模型的性能与原始模型差距不大,这样的技术将具有很大的应用场景。
知识蒸馏就是这样一种模型的压缩技术,通过让学生模型(一个参数量小的模型)去学习教师模型(高性能的模型,大参数量的模型)的中间层表示或者最终的输出分布信息,使得学生模型能具有与教师模型相似的性能。
本文使用的蒸馏方法就是蒸馏模型每一层的信息。
研究现状
当前,使用层粒度的知识蒸馏蒸馏方法效果不好,由于教师模型与学生模型之间的能力差距过大(参数量差距过大),教师模型有太多丰富的信息,学生模型无法完全学习得到,导致学生会出现欠拟合。
本文使用的方法-TED
注意:本文提到得教师模型与学生模型都是预训练过的。
作者在对教师模型与学生模型的中间层都加上一个神经网络,来拟合目标任务,用以找到目标任务的模式信息,将教师模型中的丰富信息过滤。然后将教师模型与学生模型结合来训练,使用教师模型来微调学生模型。具体而言,可以分为两个阶段
stage1
蓝色是预训练好的教师的中间层,黄色是预训练好的学生的中间层。左图虚线包裹的是教师的过滤器,一个神经网络。右图是学生模型的过滤器。在stage1,教师模型与学生模型的中间层参数都是固定的,通过学习目标任务来修改教师模型与学生模型过滤器的参数,使得能够针对具体得任务过滤掉一些不重要得信息。
stage2
固定教师模型的过滤器参数,将学生模型的中间层参数变为可训练。使用教师模型过滤器学习到的隐藏层的状态来更新学生模型的参数。
第一个损失函数:交叉熵损失,ui是教师得logits值,zi是学生的
第二个损失函数:交叉熵损失
第三各损失函数:MSE,教师的过滤器的隐藏层表示与学生的隐藏层表示之间的mse
论文中有详细解释
最小化这样一个损失,第一项损失表示的是最小化学生输出与标签的损失,第二项损失是最小化学生最终输出的向量与教师的损失(这个损失参见知识蒸馏得开篇之作),第三项损失是最小化教师中间层表示与学生中间层表示的损失。
实验结果
KD:最原始的那个蒸馏的算法,其学习目标相对于本文,没有第三项。
LWD:相对于此方法的区别就是教师没有一个过滤器,学生模型直接来拟合到教师模型的中间层表示。其学习目标相对于本文来说,将第三项中的教师模型的表示换为不进行过滤器变换后的表示。
在这个生成式任务中,本文作者提出的方法是可以用的,但是并不在所有指标都是SOTA
在GLUE benchmark中,本文作者提出的方法有着较好的性能。
本文的作者还做了不同的网络结构的过滤器的影响,如上图。