Knowledge distillation: a teacher-student paradigm in which knowledge is transferred from the teacher model to the student model through soft targets.
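To make "soft targets" concrete, here is a minimal sketch (the logit values are made up for illustration) showing how dividing the logits by a temperature T > 1 softens the output distribution and exposes the relative similarity between the non-target classes:

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for a single 3-class sample.
logits = torch.tensor([[2.0, 1.0, 0.1]])

# Standard softmax (T = 1): a fairly peaked distribution.
print(F.softmax(logits, dim=1))      # ~[0.659, 0.242, 0.099]

# Temperature-scaled softmax (T = 4): a much softer distribution;
# the "dark knowledge" about class similarity becomes visible.
T = 4
print(F.softmax(logits / T, dim=1))  # ~[0.417, 0.324, 0.259]
```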

Training the teacher model alone
The teacher model uses three fully connected hidden layers with 1200 -> 2400 -> 1200 neurons; Dropout is added to prevent overfitting.
Training uses cross-entropy loss and the Adam optimizer with a learning rate of 1e-4, for 10 epochs with a batch size of 64.
```python
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader


class TeacherModel(nn.Module):
    def __init__(self, in_channel=1, num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        # Three hidden layers: 784 -> 1200 -> 2400 -> 1200 -> num_classes
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 2400)
        self.fc3 = nn.Linear(2400, 1200)
        self.fc4 = nn.Linear(1200, num_classes)
        # Dropout to prevent overfitting
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)  # flatten 28x28 MNIST images into 784-dim vectors
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc4(x)
        return x  # raw logits; nn.CrossEntropyLoss applies log-softmax internally


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = torchvision.datasets.MNIST(
    root="./dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root="./dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

if __name__ == "__main__":
    """
    Train the teacher model from scratch.
    """
    model = TeacherModel().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    epochs = 10
    for epoch in range(epochs):
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)
            prediction = model(data)
            loss = criterion(prediction, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Evaluate on the test set after each epoch
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                prediction = model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
            acc = (num_correct / num_samples).item()
        model.train()
        print("Epoch:{}\t Accuracy:{:.4f}".format(epoch, acc))
        torch.save(model.state_dict(), './weights/teacher/teacher_{}.pth'.format(acc))

"""
Teacher model:
Epoch:8  Accuracy:0.9831
"""
```

After 10 epochs of training, the teacher model reaches an accuracy of 0.9831.
Training the student model alone
The student model uses three fully connected layers, with hidden layers of 20 -> 20 neurons; no Dropout is needed.
The student model's training settings are kept identical to the teacher's. A parameter-count comparison of the two networks follows below.
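To quantify the size gap, here is a quick sketch (assuming the teacher.py and student.py modules used later by the distillation script) that counts trainable parameters:

```python
# Minimal sketch: compare trainable-parameter counts of the two models.
from teacher import TeacherModel
from student import StudentModel

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_params(TeacherModel()))  # 6,717,610 (~6.7M)
print(count_params(StudentModel()))  # 16,330 (~16K), roughly 1/400 of the teacher
```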
```python
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader


class StudentModel(nn.Module):
    def __init__(self, in_channel=1, num_classes=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        # Much smaller network: 784 -> 20 -> 20 -> num_classes, no Dropout
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)

    def forward(self, x):
        x = x.view(-1, 784)  # flatten 28x28 images
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x  # raw logits


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = torchvision.datasets.MNIST(
    root="./dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root="./dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

if __name__ == "__main__":
    """
    Train the student model from scratch.
    """
    model = StudentModel().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    epochs = 10
    for epoch in range(epochs):
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)
            prediction = model(data)
            loss = criterion(prediction, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Evaluate on the test set after each epoch
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                prediction = model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
            acc = (num_correct / num_samples).item()
        model.train()
        print("Epoch:{}\t Accuracy:{:.4f}".format(epoch, acc))
        torch.save(model.state_dict(), './weights/student/student_{}.pth'.format(acc))

"""
Student model:
Epoch:9  Accuracy:0.9224
"""
```

After 10 epochs of training, the student model reaches an accuracy of 0.9224.
Knowledge distillation training
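Before the code, it helps to write out the objective it optimizes. With student logits $f_s(x)$, teacher logits $f_t(x)$, temperature $T$, and hard-loss weight $\alpha$, the total loss used below is

$$
\mathcal{L} \;=\; \alpha \,\mathrm{CE}\!\left(f_s(x),\, y\right) \;+\; (1-\alpha)\, T^2 \,\mathrm{KL}\!\left(\mathrm{softmax}\!\left(\tfrac{f_t(x)}{T}\right) \,\Big\|\, \mathrm{softmax}\!\left(\tfrac{f_s(x)}{T}\right)\right)
$$

The $T^2$ factor compensates for the $1/T^2$ scaling that temperature introduces into the soft-loss gradients, keeping the two terms on a comparable scale, as in Hinton et al.'s original formulation.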
```python
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from teacher import TeacherModel
from student import StudentModel

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

train_dataset = torchvision.datasets.MNIST(
    root="./dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root="./dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

if __name__ == "__main__":
    """
    Distillation training of the student model.
    """
    student_model = StudentModel().to(device)
    # Load the pretrained teacher and keep it frozen in eval mode
    teacher_model = TeacherModel().to(device).eval()
    teacher_model.load_state_dict(torch.load("./weights/teacher/teacher_0.9830999970436096.pth"))
    student_model.train()

    Temp = 4      # temperature for softening the distributions
    alpha = 0.8   # weight of the hard-label loss
    hard_loss = nn.CrossEntropyLoss()
    # KLDivLoss expects log-probabilities as input and probabilities as target
    soft_loss = nn.KLDivLoss(reduction='batchmean')
    optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)

    epochs = 20
    for epoch in range(epochs):
        student_model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)
            # Teacher forward pass without gradients
            with torch.no_grad():
                teacher_predictions = teacher_model(data)
            teacher_predictions = teacher_predictions.detach()
            student_predictions = student_model(data)
            # Hard loss: cross-entropy against the ground-truth labels
            student_loss = hard_loss(student_predictions, targets)
            # Soft loss: KL divergence between temperature-scaled distributions
            distillation_loss = soft_loss(
                F.log_softmax(student_predictions / Temp, dim=1),
                F.softmax(teacher_predictions / Temp, dim=1)
            )
            # Temp * Temp compensates for the 1/T^2 shrinkage of soft-loss gradients
            loss = (1 - alpha) * Temp * Temp * distillation_loss + alpha * student_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Evaluate the student on the test set after each epoch
        student_model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                prediction = student_model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
            acc = (num_correct / num_samples).item()
        student_model.train()
        print("Epoch:{}\t Accuracy:{:.4f}".format(epoch, acc))
        torch.save(student_model.state_dict(), './weights/knowledge_distillation/student_{}.pth'.format(acc))

"""
Temp = 4, alpha = 0.3
Student alone:     Epoch:9  Accuracy:0.9224
Teacher:           Epoch:8  Accuracy:0.9831
Distilled student: Epoch:19 Accuracy:0.9288

Temp = 4, alpha = 0.5
Student alone:     Epoch:9  Accuracy:0.9224
Teacher:           Epoch:8  Accuracy:0.9831
Distilled student: Epoch:19 Accuracy:0.9308

Temp = 4, alpha = 0.8
Student alone:     Epoch:9  Accuracy:0.9224
Teacher:           Epoch:8  Accuracy:0.9831
Distilled student: Epoch:19 Accuracy:0.9293
"""
```
For distillation training, following the four practical notes from another blog post, I fixed the temperature at T = 4 and ran separate experiments with loss weight alpha = 0.3, 0.5, and 0.8. The results are summarized in the table below.
| Model | Temperature T | Loss weight alpha | Accuracy |
|---|---|---|---|
| Teacher | – | – | 98.31% |
| Student (trained alone) | – | – | 92.24% |
| Distilled student | 4 | 0.3 | 92.88% |
| Distilled student | 4 | 0.5 | 93.08% |
| Distilled student | 4 | 0.8 | 92.93% |
Summary: the comparison in the table demonstrates the effectiveness of knowledge distillation; every distilled student outperforms the student trained alone. Moreover, different values of the loss weight alpha yield different distilled-student accuracies, showing that both the temperature and the loss weight have an important influence on the distillation process.
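For completeness, here is a minimal inference sketch for the distilled student. The checkpoint filename is hypothetical; substitute the one actually written by the distillation run above.

```python
import torch
from student import StudentModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = StudentModel().to(device)
# Hypothetical checkpoint path; use the file saved by your own run.
state = torch.load("./weights/knowledge_distillation/student_0.9308.pth", map_location=device)
model.load_state_dict(state)
model.eval()

with torch.no_grad():
    x = torch.rand(1, 1, 28, 28, device=device)  # dummy MNIST-shaped input
    logits = model(x)
    print(logits.argmax(dim=1))  # predicted class index
```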