本文通过一个图像分类的例子练习一下如何做知识蒸馏。
数据集选用CIFAR10.
知识蒸馏是用一个比较大的教师模型的output来训练较小的学生模型,paper可以参考这篇。
本文用一个简单的教师模型resnet18,resnet18本来用于imageNet, 现在cifar10的图片太小,需要修改一下防止图片被下采样没信息了,改卷积层的kernel size和maxpool layer. 另外,不要直接调用model.forward, 那样会调用修改之前的模型,还是会报错,因为修改之前的resnet18针对的是imageNet的图片尺寸,不适用于cifar10.
class TeacherNet(nn.Module):
def __init__(self):
super(TeacherNet, self).__init__()
self.model = torchvision.models.resnet18(pretrained=True)
# 修改第一层以适应32x32输入
self.model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.model.maxpool = nn.Identity()
# 修改最后一层以适应CIFAR-10的10个类别
self.model.fc = nn.Linear(512, 10)
def forward(self, x):
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
x = self.model.layer4(x)
x = self.model.avgpool(x)
x = torch.flatten(x, 1)
x = self.model.fc(x)
return x
先把这个教师模型训练出来。加上early stop机制。
class EarlyStopping:
def __init__(self, patience=7, min_delta=0):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = None
self.early_stop = False
def __call__(self, val_loss):
if self.best_loss is None:
self.best_loss = val_loss
elif val_loss > self.best_loss - self.min_delta:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
基于CIFAR10的知识蒸馏图像分类实践

最低0.47元/天 解锁文章
1089

被折叠的 条评论
为什么被折叠?



