摘要
本文提出了跨阶段连接路径(cross-stage connection paths)的概念,这是在知识蒸馏中的首次尝试。实验表明,使用教师网络的低级特征来指导学生网络的深层特征可以显著提高整体性能。基于上述发现,作者提出了一种新颖的框架,称为“知识复习”,使用教师网络的多个层级来指导学生网络的单级学习。为了进一步提高知识复习机制,提出了基于注意力的融合(ABF)模块和分层上下文损失(HCL)函数。
一、引言
1.1 新发现
之前有关KD 的工作集中于转换和损失函数。本文通过教师和学生之间的连接路径入手。如图1(a)-(c)所示,所有以前的方法都只使用相同级别的信息来指导学生。本文研究了以往被忽视的连接路径设计在KD中的重要性,并据此提出了一个新的有效框架。关键的改进是使用教师网络中的低级特征来监督学生的更深层次的特征,从而大大提高了整体性能。
如图(d)我们使用教师的多个层次来监督学生的一个层次。
1.2 知识复习机制
复习机制是使用以前的(较浅的)特征来指导当前特征。这意味着学生必须经常检查以前学过的东西,以更新对“旧知识”的理解和背景。
然而,如何从教师提供的多层次信息中提取出有用的信息,并将其传递给学生,是一个具有挑战性的问题。为了解决这些问题,我们提出了一个残差学习框架,使学习过程稳定和有效。此外,一个新的基于注意力的融合(ABF)模块和分层上下文丢失(HCL)函数的设计,以提高性能。我们提出的框架使学生网络大大提高了学习的有效性。
1.3 主要贡献
- 提出了一种新的知识蒸馏复习机制,利用教师的多级信息指导学生网络的一级进行学习。
- 为了更好地实现复习机制的学习过程,提出了一个残差学习框架。
- 为了进一步提高知识复习机制,提出了一个基于注意力的融合(ABF)模块和分层上下文损失(HCL)函数。
- 在多个CV任务上实现SOTA
二、所提方法
首先形式化知识蒸馏过程和复习机制。在此基础上,提出了一种新的框架,并引入了基于注意力的融合模块和分层上下文损失函数。
2.1 复习机制
给定输入,定义学生网络和教师网络的输出、中间层特征,单层KD损失
多层KD损失
我们的复习机制是使用以前的特征来指导当前的特征。具有复习机制的单层知识蒸馏公式化为:
它与多层KD(4)虽然乍看上去有相似之处,但实际上有着本质的区别。这里,学生网络的特征被固定到FiS,使用教师的前i个级别的特征来指导FiS。Mi,jS简单地由卷积层和最近的插值层组成,来转移学生的第i个特征,以匹配教师的第j个特征的大小。
上式为具有复习机制的多层KD
整个KD损失更改为
在我们提出的复习机制中,只使用教师的浅层特征来监督学生的深层特征。我们发现,相反,带来边际效益,浪费了许多资源。直观的解释是,更深层次和更抽象的特征对于早期学习来说太复杂了。
2.2 残差学习框架
(a)MI,jS简单地由卷积层和最近的插值层组成,来转移学生的第i个特征,以匹配教师的第j个特征的大小。不改变换教师特征Ft。学生网络特征转换为与教师网络特征相同的大小。
(b)显示了将该复习机制直接应用于多层蒸馏(学生网络的当前层不仅与教师网络的当前层有关,还与教师网络之前的层有关),并提取了所有阶段的特征。然而,这种策略并不是最优的,因为阶段之间存在巨大的信息差异。并且该学习过程繁琐并且耗费许多资源。
删除特征转化,简化等式(6)
切换i和j的两个求和的顺序
当j固定时,等式(9)累加教师特征Fjt和学生特征Fjs——Fns之间的距离。
将距离的总和近似为融合特征的距离
U是一个融合特征模块,下图(c)中的红色部分
(c)图表示了这种融合近似
为了更高的效率,可以如图(d)所示以递归方式进一步优化融合的计算
将表示为从
到Fns的特征融合,损失重新写为
2.3 ABF和HCL
- 基于注意力的融合(ABF):首先,将学生网络中来自更高层的特征resize到与低层特征相同的尺寸。接着,将resize后的高层特征与低层特征在通道维度上进行拼接。通过拼接的特征计算出两组注意力图(attention maps),每组注意力图的大小与特征图的高和宽相同。(这些注意力图是通过将拼接特征输入到一个全连接层或卷积层来生成的)
使用计算出的注意力图对拼接的特征进行加权,即每个特征图的每个通道乘以对应的注意力值。最后,将加权后的低层和高层特征图相加,得到最终的融合特征图。
ABF模块的优势在于其能够自适应地聚合来自不同网络层级的特征,因为注意力机制可以根据输入特征动态调整不同区域的重要性。这种方法使得学生网络能够更加灵活地学习教师网络的知识,尤其是在特征融合阶段,能够更好地捕捉到教师网络中的关键信息。
分层上下文损失(HCL):HCL利用空间金字塔池化(Spatial Pyramid Pooling)来提取学生和教师网络特征图中的多尺度上下文信息。金字塔池化是一种特征提取方法,它可以将一个特征图分割成多个不同分辨率的子区域,每个子区域捕获不同层次的上下文信息。通过金字塔池化得到的多尺度特征表示,HCL可以在不同的抽象层次上分别计算损失。这意味着损失函数不仅仅关注整体的特征差异,而是在多个尺度上分别度量学生和教师网络特征的相似度。HCL的目的是在不同的抽象层次上有效地传递教师网络的上下文信息给学生网络。通过分别计算不同层次的损失并累加,HCL鼓励学生网络在各个层次上都学习到教师网络的特征表示。
三. 实验
对比实验略
3.4 更多的分析
跨阶段KD。我们分析了跨阶段知识蒸馏的有效性。在CIFAR-100数据集上,我们使用ResNet 20作为学生网络,ResNet 56作为教师网络。ResNet 20和ResNet 56有四个阶段。我们选择学生的不同阶段和教师的不同阶段做实验。结果总结于表6中。
由表格知,利用教师的低层次信息来监督学生的深层次是有帮助的。相反,教师的更深层次和更抽象的特征对于早期的学生来说太复杂了。这与我们提出的复习机制是一致的,即利用教师的浅层阶段来监督学生的深层阶段。
消融实验。
RM 为复习机制,RLF 为残差学习框架,ABF为注意力融合模块, HCL分层上下文损失。
使用我们提出的复习机制,结果比baseline有所改进,如第二行所示,它使用了图2(B)中所示的结构。当我们用残差学习框架进一步细化结构时,学生会获得更大的收益。基于注意力的融合模块和分层上下文损失功能在单独使用时也提供了很大的改进。当我们将它们聚集在一起时,会得到最好的结果。令人惊讶的是,他们甚至比教师更好。