一、动机:
(1)在现实中对图像分类难度不一,采用一个固定的框架对图片进行分类时有时不够灵活。
主要思想就是在一个网络中有多个分类出口(创新点),对于简单图像可以直接从前面某个分类出口得到结果,而难分类的网络可能要到网络后面的某一层才能得到可靠的结果.
(2)动态网络的early-exit可以减少计算量,同时latter-exit也不必重复进行浅层backbone的inference,加快速度 。
主要应用领域:图像分类,语义分割
数据集:MNIST, CIFAR 10,CIFAR 100, ImageNet, COCO, PASCAL VOC
二、exit branch是一个分类器,输出每个类别的概率。
(1)数量:2~n个exit
(2)结构:卷积层、全连接层、池化层+softmax
(3)位置:位于每个block之后。
一些文献指出,在backbone中稍后面的放置一个exit并不一定会提高该分支的整体准确性,可能在更前的位置效果更好。
一般来说,exit位于“自然块”之后,例如concatenation layer,residual connection, dense block之后,它们的性能更好
全连接层作为branch,参考:https://github.com/ArchipLab-LinfengZhang/pytorch-scalable-neural-networks
max pool + avg pool +全连接作为branch,参考:https://github.com/yigitcankaya/Shallow-Deep-Networks
conv+conv+avg pool+全连接作为branch,参考:https://github.com/kalviny/MSDNet-PyTorch
三、多级网络的训练方法:
(1)end-to-end, one-stage, 联合训练所有branch的loss,每个branch的loss有一个超参数权重。问题:该方法对branch位置敏感,一个branch的acc可能受其他branch影响
(2)layer-wise, 分级训练, n-stage,第一次训练模型直到第一个branch的部分,第二次冻结之前的权重,训练模型剩余部分直到第二个branch的部分,依次进行。
(3)classifier-wise, n-stage, 首先训练backbone+final exit,然后冻结backbone,单独训练每个branch
(4)在上述三种基础上,加入知识蒸馏,其中许多文献采用自蒸馏方法,采用最后一级branch作为teacher蒸馏前面的branch
(5)可能存在的其他创新
1、End-to-End训练,并加入自蒸馏:参考 https://github.com/ArchipLab-LinfengZhang/pytorch-scalable-neural-networks
作者采用三个损失
(1)分类损失,交叉熵。
(2)自蒸馏损失,用final exit蒸馏前面的exit,交叉熵。
(3)特征损失,最后一层ResBlock与前面的ResBlock输出的特征损失,计算特征误差的平方和。
for epoch in range(args.epoch):
net.train()
sum_loss = 0.0
correct = 0.0
total = 0.0
for i, data in enumerate(trainloader, 0):
length = len(trainloader)
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs, feature_loss = net(inputs) #outputs包含四个branch的输出 [out4, out3, out2, out1], feature_loss是最后的branch输出特征和前面的branch特征的平方差之和
ensemble = sum(outputs[:-1])/len(outputs) #ensemble是所有branch的平均输出
ensemble.detach_()
ensemble.requires_grad = False
# compute loss
loss = torch.FloatTensor([0.]).to(device)
#最后一级branch的loss
loss += criterion(outputs[0], labels)
#最后一级branch作为teacher
teacher_output = outputs[0].detach