Knowledge Distillation(5)——Deep Mutual Learning

本文介绍了一种名为DeepMutualLearning的知识蒸馏方法,该方法不依赖于教师模型,而是让多个学生模型相互学习,通过引入额外的相对熵损失,提升模型性能。实验表明,这种方法比传统的教师-学生模式更有效。

之前都是对knowledge重新定义,衍生出的knowledge distillation的变体模型。
本篇博客开始,介绍知识蒸馏的第二类方法:改变学习方式,提高student perfomance。

Deep Mutual Learning CVPR2018

概述

本文核心idea是,没有teacher,一系列student之间相互学习。
本质也是学习了另一个网络的输出,但是学习方式遍了,因而我把它归到第二类论文
在这里插入图片描述
motivation:
在这里插入图片描述
至于why work?
在这里插入图片描述
在这里插入图片描述
这种学习方式更稳定,更容易收敛:
在这里插入图片描述
在这里插入图片描述
以及这个分析还不错(似乎也是翻译的论文)

Method

Model

具体方法很简单,定义了两个网络。除了使用GT对各自进行监督,还引入了一个额外的loss,引导他们进行相互间的学习。
在这里插入图片描述
这个额外的loss,利用相对熵,即KL散度来定义:
在这里插入图片描述
这种额外的loss对于两个网络来说是不一样的,非对称的。也可以用下面这个对称的loss替代,而且效果几乎无差别:
在这里插入图片描述

Optimisation

作者的训练步骤
在这里插入图片描述
两个模型一起训练,且互相学习直到收敛。不会像传统的distillation,student学习一个已经训练好的teacher,还只是用来初始化参数!!

Experiments

两个students也是一个大网络,一个小网络。其相互学习的效果比,比teacher-student的模式要好:
在这里插入图片描述
在这里插入图片描述

提供的参考引用中未提及知识蒸馏与小样本学习结合的项目实例相关内容,不过可以从理论角度了解两者结合的可能应用场景及实例思路。 在医学影像诊断领域,医学影像数据标注成本高、数据量有限,属于典型的小样本场景。通过知识蒸馏与小样本学习结合,可以先训练一个大型的教师模型,使用大量模拟生成的医学影像数据或者公共的医学影像数据集进行预训练,让教师模型学习到丰富的医学影像特征和诊断知识。然后将教师模型的知识蒸馏到一个小样本训练的学生模型上,该学生模型可以在少量标注的特定医院或特定疾病的医学影像数据上进行训练。例如,针对罕见病的医学影像诊断,由于病例稀少,数据量小,利用这种结合方法可以提高诊断的准确性和效率。 在工业缺陷检测中,新的产品型号或者新的缺陷类型出现时,往往只有少量的缺陷样本。可以利用知识蒸馏与小样本学习结合的方法。先构建一个基于大量常见产品缺陷样本训练的教师模型,该模型可以学习到各种缺陷的特征模式。之后将教师模型的知识转移到一个在少量新缺陷样本上训练的学生模型中,使得学生模型能够快速准确地检测新出现的缺陷类型。 ```python # 以下是一个简单的伪代码示例,展示知识蒸馏与小样本学习结合的基本流程 import torch import torch.nn as nn import torch.optim as optim # 定义教师模型 class TeacherModel(nn.Module): def __init__(self): super(TeacherModel, self).__init__() # 定义教师模型的层结构 self.fc1 = nn.Linear(10, 20) self.fc2 = nn.Linear(20, 2) def forward(self, x): x = torch.relu(self.fc1(x)) x = self.fc2(x) return x # 定义学生模型 class StudentModel(nn.Module): def __init__(self): super(StudentModel, self).__init__() # 定义学生模型的层结构 self.fc1 = nn.Linear(10, 15) self.fc2 = nn.Linear(15, 2) def forward(self, x): x = torch.relu(self.fc1(x)) x = self.fc2(x) return x # 初始化教师模型和学生模型 teacher_model = TeacherModel() student_model = StudentModel() # 假设这里是小样本数据 small_sample_data = torch.randn(5, 10) small_sample_labels = torch.randint(0, 2, (5,)) # 定义损失函数 criterion = nn.CrossEntropyLoss() distillation_loss = nn.KLDivLoss() # 定义优化器 optimizer = optim.Adam(student_model.parameters(), lr=0.001) # 知识蒸馏训练过程 for epoch in range(100): # 教师模型预测 teacher_output = teacher_model(small_sample_data) # 学生模型预测 student_output = student_model(small_sample_data) # 计算知识蒸馏损失 distill_loss = distillation_loss(torch.log_softmax(student_output, dim=1), torch.softmax(teacher_output, dim=1)) # 计算分类损失 class_loss = criterion(student_output, small_sample_labels) # 总损失 total_loss = class_loss + distill_loss # 反向传播和优化 optimizer.zero_grad() total_loss.backward() optimizer.step() ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值