知识蒸馏实战代码教学二(代码实战部分)

本文介绍了知识蒸馏在深度学习中的应用,包括训练大模型(teacher模型)、模型蒸馏过程(利用teacher模型指导student模型学习)、代码实现(使用PyTorch构建MNIST数据集的teacher和student模型,以及蒸馏训练的具体步骤)和模型压缩的效果。

一、上章原理回顾

具体过程:

        (1)首先我们要先训练出较大模型既teacher模型。(在图中没有出现)

        (2)再对teacher模型进行蒸馏,此时我们已经有一个训练好的teacher模型,所以我们能很容易知道teacher模型输入特征x之后,预测出来的结果teacher_preds标签。

        (3)此时,求到老师预测结果之后,我们需要求解学生在训练过程中的每一次结果student_preds标签。

        (4)先求hard_loss,也就是学生模型的预测student_preds与真实标签targets之间的损失。

        (5)再求soft_loss,也就是学生模型的预测student_preds与教师模型teacher_preds的预测之间的损失。

        (6)求出hard_loss与soft_loss之后,求和总loss=a*hard_loss + (1-a)soft_loss,a是一个自己设置的权重参数,我在代码中设置为a=0.3。

        (7)最后反向传播继续迭代。

二、代码实现

1、数据集

        数据集采用的是手写数字的数据集mnist数据集,如果没有下载,代码部分中会进行下载,只需要把download改成True,然后就会保存在当前目录中。该数据集将其分成80%的训练集,20%的测试集,最后返回train_dataset和test_datatset。

class MyDataset(Dataset):
    def __init__(self,opt):
        self.opt = opt

    def MyData(self):
        ## mnist数据集下载0
        mnist = datasets.MNIST(
            root='../datasets/', train=True, download=False, transform=transforms.Compose(
                [transforms.Resize(self.opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        )

        dataset_size = len(mnist)
        train_size = int(0.8 * dataset_size)
        test_size = dataset_size - train_size

        train_dataset, test_dataset = random_split(mnist, [train_size, test_size])

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.opt.batch_size,
            shuffle=True,
        )

        test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.opt.batch_size,
            shuffle=False,  # 在测试集上不需要打乱顺序
        )
        return train_dataloader,test_dataloader

2、teacher模型和训练实现

       (1) 首先是teacher模型构造,经过三次线性层。

import torch.nn as nn
import torch

img_area = 784

class TeacherModel(nn.Module):
    def __init__(self,in_channel=1,num_classes=10):
        super(TeacherModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(img_area,1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, img_area)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

        (2)训练teacher模型

        老师模型训练完成后其权重参数会保存在teacher.pth当中,为以后调用。

import torch.nn as nn
import torch


## 创建文件夹
from tqdm import tqdm

from dist.TeacherModel import TeacherModel

weight_path = 
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值