编者按:更高性能的深度神经网络往往伴随着愈加庞大的参数量,而大量的计算需求使其难以部署在移动端。为此,精巧的网络结构设计(如MobileNet、ShuffleNet)、模型压缩策略(剪枝、二值化等)及其他优化方法应运而生。
Hinton等人在2015年提出的模型蒸馏算法,利用预训练好的大网络当作教师来向小网络传递知识,从而提高小网络性能。而模型蒸馏算法需要有提前预训练好的大网络,且仅可对小网络进行单向的知识传递。古人云“三人行必有我师焉”,本文作者提出了一种“深度互学习Deep Mutual Learning”策略,使得小网络之间能够互相学习共同进步。
1.研究动机
近几年来,深度神经网络在计算机视觉、语音识别、语言翻译等领域中取得了令人瞩目的成果,为了完成更加复杂的信息处理任务,网络在设计上不断增加深度或宽度,使得模型参数量越来越大,如经典的VGG、Inception、ResNet系列网络。尽管更深或更宽的神经网络取得了更好的性能,大量计算需求使得它们难以部署在资源条件有限的环境中,如手机、平板、车载等移动端应用。这促使研究者们采用各种各样的方法去探索更高效的模型,如更精巧的网络结构设计MobileNet和ShuffleNet,还有网络压缩、剪枝、二值化,以及比较有趣的模型蒸馏等。
模型蒸馏算法由Hinton等人在2015年提出,利用一个预训练好的大网络当作教师来提供小网络额外的知识即平滑后的概率估计,实验表明小网络通过模仿大网络估计的类别概率,优化过程变得更容易,且表现出与大网络相近甚至更好的性能。然而模型蒸馏算法需要有提前预训练好的大网络,且大网络在学习过程中保持固定,仅对小网络进行单向的知识传递,难以从小网络的学习状态中得到反馈信息来对训练过程进行优化调整。
我们尝试探索一种能够学习到更强大小网络的训练机制—深度互学习,即采用多个网络同时进行训练,每个网络在训练过程中不仅接受来自真值标记的监督,还参考同伴网络的学习经验来进一步提升泛化能力。在整个过程中,两个网络之间不断分享学习经验,实现互相学习共同进步。
2.算法描述
图1 深度互学习算法框架
具体来说,每个网络在学习过程中有两个损失函数,一个是传统的监督损失函数,采用交叉熵损失来度量网络预测的目标类别与真实标签之间的差异,另一个是网络间的交互损失函数,采用KL散度来度量两个网络预测概率分布之间的差异。公式表示为
采用这两种损失函数,不仅可以使得网络学习到如何区分不同的类别,还能够使其参考另一个网络的概率估计来提升自身泛化能力。
接下来我们给出网络的优化策略。对于单块GPU,我们采用交替迭代的方式依次更新两个网络,当有多块GPU时,我们可以采用分布式训练,每次迭代时两个网络同时计算概率估计差异并更新模型参数。实验发现分布式训练可以获得更好的性能。目前关于分布式训练为何能比串行训练获得更好的性能还未有比较好的理论解释,一些研究者认为在分布式训练中每个worker对附近参数空间的探索实际上提高了模型在连续梯度下降方面的统计性能。
我们提出的互学习算法也很容易扩展到多网络学习和半监督学习场景中。当有K个网络时,深度互学习学习每个网络时将其余K-1个网络分别作为教师来提供学习经验。另外一种策略是将其余K-1个网络融合后得到一个教师来提供学习经验 。在半监督互学习场景中,我们对有标签的数据计算监督损失和交互损失,而针对无标签数据我们仅计算交互损失来帮助网络从训练数据中挖掘更多有用