本文来源公众号“Coggle数据科学”,仅用于学术分享,侵权删,干货满满。
原文链接:Kaggle 知识点:知识蒸馏的三种方法
本文介绍了知识蒸馏(Knowledge Distillation)
技术,这是一种将大型、计算成本高昂的模型的知识转移到小型模型上的方法,从而在不损失有效性的情况下实现在计算能力较低的硬件上部署,使得评估过程更快、更高效。
https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
在本教程中,我们将进行一系列实验,专注于通过使用更强大的网络作为教师来提高轻量级神经网络的准确性。通过本教程,你将学习到:
-
如何修改模型类以提取隐藏层表示,并将其用于进一步计算。
-
如何修改PyTorch中的常规训练循环,以包括额外的损失函数,例如在分类上的交叉熵之外。
-
如何通过使用更复杂的模型作为教师来提高轻量级模型的性能。
步骤 1:读取环境
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
步骤 2:加载数据集
CIFAR-10 是一个非常流行的图像数据集,它包含了十个类别,包含 60,000 张 32x32 像素的彩色图像,分为 50,000 张训练图像和 10,000 张测试图像,每个类别有 6,000 张图像。
# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)
#Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset,