一、上章原理回顾

具体过程:
(1)首先我们要先训练出较大模型既teacher模型。(在图中没有出现)
(2)再对teacher模型进行蒸馏,此时我们已经有一个训练好的teacher模型,所以我们能很容易知道teacher模型输入特征x之后,预测出来的结果teacher_preds标签。
(3)此时,求到老师预测结果之后,我们需要求解学生在训练过程中的每一次结果student_preds标签。
(4)先求hard_loss,也就是学生模型的预测student_preds与真实标签targets之间的损失。
(5)再求soft_loss,也就是学生模型的预测student_preds与教师模型teacher_preds的预测之间的损失。
(6)求出hard_loss与soft_loss之后,求和总loss=a*hard_loss + (1-a)soft_loss,a是一个自己设置的权重参数,我在代码中设置为a=0.3。
(7)最后反向传播继续迭代。
二、代码实现
1、数据集
数据集采用的是手写数字的数据集mnist数据集,如果没有下载,代码部分中会进行下载,只需要把download改成True,然后就会保存在当前目录中。该数据集将其分成80%的训练集,20%的测试集,最后返回train_dataset和test_datatset。
class MyDataset(Dataset):
def __init__(self,opt):
self.opt = opt
def MyData(self):
## mnist数据集下载0
mnist = datasets.MNIST(
root='../datasets/', train=True, download=False, transform=transforms.Compose(
[transforms.Resize(self.opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
)
dataset_size = len(mnist)
train_size = int(0.8 * dataset_size)
test_size = dataset_size - train_size
train_dataset, test_dataset = random_split(mnist, [train_size, test_size])
train_dataloader = DataLoader(
train_dataset,
batch_size=self.opt.batch_size,
shuffle=True,
)
test_dataloader = DataLoader(
test_dataset,
batch_size=self.opt.batch_size,
shuffle=False, # 在测试集上不需要打乱顺序
)
return train_dataloader,test_dataloader
2、teacher模型和训练实现
(1) 首先是teacher模型构造,经过三次线性层。
import torch.nn as nn
import torch
img_area = 784
class TeacherModel(nn.Module):
def __init__(self,in_channel=1,num_classes=10):
super(TeacherModel,self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Linear(img_area,1200)
self.fc2 = nn.Linear(1200, 1200)
self.fc3 = nn.Linear(1200, num_classes)
self.dropout = nn.Dropout(p=0.5)
def forward(self, x):
x = x.view(-1, img_area)
x = self.fc1(x)
x = self.dropout(x)
x = self.relu(x)
x = self.fc2(x)
x = self.dropout(x)
x = self.relu(x)
x = self.fc3(x)
return x
(2)训练teacher模型
老师模型训练完成后其权重参数会保存在teacher.pth当中,为以后调用。
import torch.nn as nn
import torch
## 创建文件夹
from tqdm import tqdm
from dist.TeacherModel import TeacherModel
weight_path =

本文介绍了知识蒸馏在深度学习中的应用,包括训练大模型(teacher模型)、模型蒸馏过程(利用teacher模型指导student模型学习)、代码实现(使用PyTorch构建MNIST数据集的teacher和student模型,以及蒸馏训练的具体步骤)和模型压缩的效果。
最低0.47元/天 解锁文章
1738

被折叠的 条评论
为什么被折叠?



