参考教程:
https://arxiv.org/pdf/2012.12877.pdf
https://github.com/facebookresearch/deit
文章目录
概述
在之前的章节中提到过,VIT模型训练的一个问题是对数据的要求比较高,因为基于transformer的模型相对于基于卷积的模型,更加flexible。卷积的模型有着预设好的感受野,而transformer的模型需要自己去学习哪部分更加重要,因此训练上也更困难。
在这种情况下,想独自训练一个效果比较好的transformer模型是很困难的,你很难准备大几百万的数据集用于训练。这也给论文复现带来了难度,你看别人的模型效果好,你想去学习,但是没有资源训练出相当的模型。
DEIT提出了一种基于token的蒸馏方法,使用和训练卷积网络差不多的时间,只用imagenet作为训练集,就实现了非常不错的效果。
总的来说,DEIT做出了以下贡献(这一段直接翻译的论文原文):
- 证明了不包含卷积层的网络在只是用ImageNet数据的情况下也能取得很有竞争力的表现。
- 提出了一种基于token的蒸馏方法,并且这个方法的效果明显超过了普通的蒸馏方法。
- 有趣的是,基于transformer的模型以convnet为老师时表现的比以transformer为老师时要好。
- 他们的基于imagenet预训练的模型应用于其它下游任务时效果也很不错。
Knowledge Distillation
在这里补充一点知识蒸馏相关的内容。
知识蒸馏简单来说呢,就是把我们想要训练的模型当作“学生”模型,在向我们的hard label,也就是ground truth的结果靠近的同时,也让它向一个“老师”模型(一般是一个效果更好的、体量更大的模型)输出的soft label靠近。
比较简单的方法就是直接让学生模型的输出logits去拟合老师模型的输出logits,复杂一点的会增加层与层之间的拟合。
下面的代码就来自一个比较早期的repohttps://github.com/haitongli/knowledge-distillation-pytorch/tree/master
可以看到KD_loss明显有两部分组成。
T = params.temperature
KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + F.cross_entropy(outputs, labels) * (1. - alpha)
第一个部分就是我们的软目标损失,使用KLD散度计算输出的logits与老师模型输出的logits的差距,T在这里是一个温度系数,T越大得到的概率分布就越平滑。第二个部分就是我们的硬目标损失,也就是输出与label的交叉熵损失。
DEIT
base model: VIT
首先来重新介绍一下DEIT方法中使用的模型框架,其实也就是复习了一遍VIT。
transformer block
DEIT的工作是在VIT模型的基础上完成的。使用固定大小的RGB图像作为输入,这个图像被拆解成N个大小为16*16的小patch,N的大小一般是14*14。也就是说默认图像的大小是224*224。
每个patch都会被处理成一个指定维度的token。在之前的章节中我们介绍过这里有两个常用做法,再次复述一下。
第一种做法是使用reshape之后,使用全连接层完成维度的变化。
self.proj = Rearrange('b c (h p) (w p ) -> b (h w) (p1 p2 c)', p = patch_size)
self.linear = nn.Linear(patch_size * patch_size * in_c, embed_dim)
第二种做法是直接使用卷积。
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size = patch_size, stride=patch_size)
目前来说第二种方法是更常用的。
然后再给得到的embedding加上一个class_token和一个position_embeddings。就构成了一个完整的输入。
class token
VIT中模仿BERT的做法,在得到的patch embedding上concat了一个可训练的class token。这个class token也会贯穿整个网络,并且最终用于分类。它相当于起到了串联所有patch_embedding的作用,它包含的也是一个整体的信息。
也就是说在整个过程中,transformer一共使用了N+1个token,但是只有第一个class token被用来进行结果的预测。
position embedding
已知transformer中最重要的结构就是MSA,在MSA中会根据你的输入计算三个vector,分别是Query, Key, Value。并使用Q和K的内积计算attention。
我们直接看一下源码,可以看到这个qkv是通过全连接得到的,它完成的是从embed_dim到embed_dim的映射,这个过程是和embed的数量无关的。
self.qkv = nn.Linear(emb_size, emb_size*3)
所以一个在low-resolution的图像上训练的模型,也是很容易用在high-resolution的图像上的。只要使用一样的patch_size就可以。
这时候聪明的你可能会发现一个问题,patch_size大小一样,在high-resolution图像上得到的patch的数量肯定比low-resolution要多呀。那么position_embedding是会受到影响的,position_embedding的大小是和我们的数量以及embed_size都有关系的。
self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))
原VIT论文中的做法是这样的
We therefore perform 2D interpolation of the pre-trained position embeddings, according to their location in the original image.
Distillation through attention
作者在论文中对蒸馏的部分进行了比较详细的介绍。
soft distillation
软蒸馏就是上面介绍的,用学生模型的logits向老师模型的logits学习,两者的差距使用KL散度来衡量。
hard distillation
硬蒸馏是将老师模型预测的结果也作为真实的标签,让你的学生模型也去学习这个标签。
L g l o b a l h a r d D i s t i l l = ( 1 − ϵ ) × 1 2 L C E ( ψ ( Z s ) , y ) + ϵ × 1 2 L C E ( ψ ( Z s ) , y t ) L^{hardDistill}_{global} = (1-\epsilon)\times\frac{1}{2}L_{CE}(\psi(Z_s),y) + \epsilon\times\frac{1}{2}L_{CE}(\psi(Z_s),yt) L

DEIT通过引入基于token的蒸馏方法,展示了在仅使用ImageNet数据集和类似训练卷积网络的时间下,纯Transformer模型也能达到优秀效果。文章详细阐述了知识蒸馏的概念,DEIT模型的基础——VisionTransformer(VIT)的结构,以及TransformerBlock、classtoken和positionembedding。此外,还介绍了DistilledVisionTransformer的实现,包括distillationtoken和相应的损失函数。DEIT证明了Transformer模型可以有效地从convnet老师那里学习,并在下游任务中表现出色。
最低0.47元/天 解锁文章
369






