The Concept of Knowledge Distillation
For background on knowledge distillation, see Hinton et al.'s 2015 paper "Distilling the Knowledge in a Neural Network" (arXiv:1503.02531).
In its narrow sense, knowledge distillation means transferring knowledge from a complex model to improve the performance of a simple model. The complex model is called the teacher model, and the simple model the student model. I recently revisited the concept and ran validation experiments on the CIFAR-100 dataset.
Logits, hard targets, and soft targets: logits are the raw (pre-softmax) outputs of the network's final layer; the hard target is the one-hot encoding of the ground-truth label; the soft target is the probability distribution obtained by applying softmax to the logits.
Soft targets with a temperature term: to make the post-softmax probability distribution softer, Hinton proposed applying softmax to the logits with a temperature parameter:

q_i = exp(z_i / T) / Σ_j exp(z_j / T)

where z_i are the logits and T is the temperature; the larger T is, the flatter (softer) the resulting distribution.
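As a quick illustration (my own minimal sketch, not from the original post), the following shows how raising T flattens the distribution:

import torch
import torch.nn.functional as F

logits = torch.tensor([5.0, 2.0, 1.0])  # example logits for a 3-class problem

for T in (1.0, 4.0, 10.0):
    # Temperature-scaled softmax: exp(z_i / T) / sum_j exp(z_j / T)
    soft_target = F.softmax(logits / T, dim=0)
    print(T, soft_target)
# As T grows, the probabilities approach a uniform distribution;
# at T=1 this reduces to the ordinary softmax.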
Dataset: CIFAR-100 is a classic image-classification dataset with 100 classes, containing 60,000 32x32 color images (50,000 for training and 10,000 for testing).
The dataset is loaded directly through the official dataset class provided by torchvision.
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader  # needed for the loaders below

# Per-channel mean/std of the CIFAR-100 training set, used for normalization
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
# Training-time augmentation (crop, flip, rotation) plus normalization
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
])
# Test-time: tensor conversion and normalization only
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
])
train_dataset = torchvision.datasets.CIFAR100(
    root="./dataset/",
    train=True,
    transform=transform_train,
    download=True
)
test_dataset = torchvision.datasets.CIFAR100(
    root="./dataset/",
    train=False,
    transform=transform_test,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)
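As a quick sanity check (my own addition, not in the original post), one batch can be pulled from the loader to confirm the shapes:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 32, 32])
print(labels.shape)  # torch.Size([128])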
Classification models: ResNet-50 serves as the teacher model and VGG16 as the student model (see the construction sketch after the VGG definition below).
VGG16 network definition code:
"""vgg in pytorch
[1] Karen Simonyan, Andrew Zisserman
Very Deep Convolutional Networks for Large-Scale Image Recognition.
https://arxiv.org/abs/1409.1556v6
"""
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
# Layer configurations from the VGG paper: 'A' = VGG11, 'B' = VGG13, 'D' = VGG16, 'E' = VGG19.
# Integers are conv output channels; 'M' marks a 2x2 max-pool.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}
class VGG(nn.Module):
    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features
        # For 32x32 CIFAR inputs, five max-pools reduce the feature map to 512x1x1,
        # so the classifier's input dimension is 512
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(output.size()[0], -1)  # flatten to (batch, 512)
        output = self.classifier(output)
        return output
def make_layers(cfg, batch_norm=False):
    # Build the convolutional feature extractor from a cfg list
    layers = []
    input_channel = 3
    for l in cfg:
        if l == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue
        layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]
        if batch_norm:
            layers += [nn.BatchNorm2d(l)]
        layers += [nn.ReLU(inplace=True)]
        input_channel = l
    return nn.Sequential(*layers)
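The code above only defines the building blocks; below is a minimal sketch of putting them together (the vgg16_bn helper and the torchvision teacher are my additions, not part of the original code):

def vgg16_bn(num_class=100):
    # Configuration 'D' is VGG16; batch_norm=True inserts BN after every conv
    return VGG(make_layers(cfg['D'], batch_norm=True), num_class=num_class)

student = vgg16_bn()

# One possible teacher: torchvision's ResNet-50 with a 100-way head.
# Note that torchvision's ResNet targets 224x224 ImageNet inputs; for 32x32 CIFAR
# images, a CIFAR-style ResNet-50 (3x3 stem, no initial max-pool) is commonly used instead.
import torchvision.models as models
teacher = models.resnet50(num_classes=100)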