知识蒸馏之手写体识别

知识蒸馏

本文主要是根据该网站视频(https://www.bilibili.com/video/BV1s7411h7K2?t=906)进行总结,如有理解误差,望批评指点

1. 首次提出

首次提出:https://arxiv.org/pdf/1503.02531.pdf
作者的动机是想找到一个方法,把多个模型的知识提炼给单个模型。
虽然现在很多分类模型都采用交叉熵衡量预测值与真实值,然而真实值采用的one-hot向量所能提共的信息没有概率分布多。
原理:概率分布比onehot更能提供信息-暗知识。
loss=0.7KL散度(softmax_t(老师输出))+0.3交叉熵(oneHot)
两个分布的loss可以用KL散度。

2. 简单介绍

知识蒸馏,可以将一个网络的知识转移到另一个网络,两个网络可以是同构或者异构。做法是先训练一个teacher网络,然后使用这个teacher网络的输出和数据的真实标签去训练student网络。知识蒸馏,可以用来将网络从大网络转化成一个小网络,并保留接近于大网络的性能;也可以将多个网络的学到的知识转移到一个网络中,使得单个网络的性能接近emsemble的结果。

例如:水果蔬菜分类
在这里插入图片描述

3.pytorch中的一些损失函数

在这里插入图片描述

log_softmax_p = torch.log(torch.softmax(p))
loss1= F.nll_loss(log_softmax_p ,target)
loss2 = F.cross_entropy(p,target))
loss1 == loss2

在这里插入图片描述

4. 核心

在这里插入图片描述

当没有老师网络时候,仅仅将图像经过student网络,softmax之后,输出概率分布值q,将q与真实值求loss就是称为Hard loss,因为这个p是真实值的one-hot向量,我们希望q和p越接近越好。
当有老师的帮助下的时候,loss来自student和teacher 网络。且,老师输出的q‘要经过蒸馏之后(让它更加平滑)得到q’‘再与q求loss,总loss,是它们之间的和。

5. 理论分析

在这里插入图片描述
通过引入软目标:soft-target作为total loss的一部分,诱导学生网络以精简,低复杂度的训练,实现知识迁移。
其中total loss的设计采用的是 交叉熵损失和kl散度(软目标,学生预测),如上图,如果软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献。这对训练初期是很有必要的。但是训练后期需要减少软目标的比重。

图中没加平滑时的数字识别是9,对应的概率中9的概率很小,加了平滑之后,9的比重相对而言大了。
在这里插入图片描述

6. 实例:知识蒸馏实现手写体识别

6.1 Teacher model

简单的卷积网络

class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet,self).__init__()
        self.conv1 = nn.Conv2d(1,32,3,1)
        self.conv2 = nn.Conv2d(32,64,3,1)
        self.dropout1 = nn.Dropout2d(0.3)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216,128)
        self.fc2 = nn.Linear(128,10)
        
    #此处没有softmax
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)
        x = self.dropout1(x)
        x = torch.flatten(x,1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output

6.2 Student model

更简单的线形模型

class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet,self).__init__()
        self.fc1 = nn.Linear(28*28,128)
        self.fc2 = nn.Linear(128,64)
        self.fc3 = nn.Linear(64,10)
        
    def forward(self,x):
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = F.relu(self.fc3(x))
        return output
    

6.3 Teacher 教 Student

KD的Loss

def distillation(y,labels,teacher_scores,temp,alpha):
    return nn.KLDivLoss()(F.log_softmax(y/temp, dim=1),
                          F.softmax(teacher_scores/temp,dim=1))*(temp*temp*2.0*alpha)+F.cross_entropy(y,labels)*(1. - alpha)

老师教学生,前提是教师网络已经是训练好的,然后去教学生。

def train_student_kd(model,device,train_loader,optimizer,epoch):
    model.train()
    teacher_model.eval()
    trained_samples = 0
    for batch_idx,(data,target) in enumerate(train_loader):
        data,target = data.to(device),target.to(device)
        optimizer.zero_grad()
        output = model(data)
        teacher_output = teacher_model(data)
        teacher_output = teacher_output.detach()
        loss = distillation(output,target,teacher_output,temp=5.0,alpha=0.7)
        loss.backward()
        optimizer.step()
        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader)*50)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
#         print("\r Train epoch %d:%d/%d"%(epoch,trained_samples,len(train_loader.dataset)))
        
def test_student_kd(model,device,test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data,target in test_loader:
            data,target = data.to(device),target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item() 
            pred = output.argmax(dim=1,keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print("\r Test average loss:{:.4f},accuracy:{}/{}({:.0f}%)".format(test_loss,correct,len(test_loader.dataset),100.*correct/len(test_loader.dataset)))
    return test_loss, correct/len(test_loader.dataset)

    
def student_kd_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/MNIST',train=True,download=False,
                                                             transform=transforms.Compose([
                                                                 transforms.ToTensor(),
                                                                 transforms.Normalize((0.1307,),(0.3081,))
                                                             ])),
                                              batch_size=batch_size,shuffle=True)
    
    test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/MNIST',train=False,download=False,
                                                             transform=transforms.Compose([
                                                                 transforms.ToTensor(),
                                                                 transforms.Normalize((0.1307,),(0.3081,))
                                                             ])),
                                              batch_size=batch_size,shuffle=True)
    
    model = StudentNet().to(device)
    teacher_model = TeacherNet()
    teacher_model.load_state_dict(torch.load('teacher.pth')) #加载老师网络参数
    
    optimizer = torch.optim.Adadelta(model.parameters())
    student_history = []
    for epoch in range(1,epochs+1):
        train_student_kd(model,device,train_loader,optimizer,epoch)
        loss , acc = test_student_kd(model,device,test_loader)
        student_history.append((loss,acc))
    torch.save(model.state_dict(),'student.pth')
    return model,student_history

在这里插入图片描述

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值