A Worked Example of Knowledge Distillation

This post walks through knowledge distillation with an image-classification example.

The dataset is CIFAR-10.
Knowledge distillation trains a smaller student model on the outputs of a larger teacher model; see the original paper for background.

The teacher is a simple ResNet-18. ResNet-18 was designed for ImageNet, and CIFAR-10 images are only 32×32, so the stem has to be changed or the feature maps get downsampled until almost no spatial information is left: the first convolution's kernel size and stride are reduced, and the maxpool layer is removed. The forward method below also spells out the backbone layer by layer, making it explicit that the modified conv1 and maxpool, rather than the original ImageNet-sized stem, are the ones applied to the CIFAR-10 inputs.
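For reference, the snippets in this post assume roughly the following imports, plus a small seeding helper (set_seed is called later in main but never shown; this is just one plausible version):

import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


def set_seed(seed):
    # Seed Python, NumPy, and PyTorch (CPU and all GPUs) for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)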

class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.model = torchvision.models.resnet18(pretrained=True)
        # Adapt the first conv layer to 32x32 inputs
        self.model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.model.maxpool = nn.Identity()
        # Replace the final layer for CIFAR-10's 10 classes
        self.model.fc = nn.Linear(512, 10)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.model.fc(x)
        return x
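A quick sanity check (just a sketch) confirms that 32×32 inputs survive the modified stem and come out as 10-way logits:

teacher = TeacherNet()
dummy = torch.randn(2, 3, 32, 32)   # a fake batch of two CIFAR-10-sized images
print(teacher(dummy).shape)          # expected: torch.Size([2, 10])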

First, train this teacher model, adding an early-stopping mechanism.

class EarlyStopping:
    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
            
def train_teacher(train_loader, val_loader, model, epochs=50, device='cuda'):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
    early_stopping = EarlyStopping(patience=5)

    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    best_val_acc = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if i % 100 == 99:
                print(
                    f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f} acc: {100. * correct / total:.2f}%')
                running_loss = 0.0

        train_loss, train_acc = evaluate(model, train_loader, criterion, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'Epoch {epoch + 1}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        # Save the best model so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_teacher.pth')

        scheduler.step(val_loss)

        early_stopping(val_loss)
        if early_stopping.early_stop:
            print(f'Early stopping triggered at epoch {epoch + 1}')
            break

    plot_training_curves(train_losses, train_accs, val_losses, val_accs)

    return train_losses, train_accs, val_losses, val_accs
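train_teacher relies on an evaluate helper and a plot_training_curves helper that are not shown in the original post. Since evaluate is later used in main without a criterion, the sketch below makes the criterion optional; treat both functions as assumptions about the missing helpers rather than the post's actual code:

def evaluate(model, loader, criterion=None, device='cuda'):
    # Returns (avg_loss, accuracy%) when a criterion is given, otherwise just accuracy%
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            if criterion is not None:
                total_loss += criterion(outputs, labels).item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    acc = 100. * correct / total
    return acc if criterion is None else (total_loss / total, acc)


def plot_training_curves(train_losses, train_accs, val_losses, val_accs):
    # Side-by-side loss and accuracy curves, saved to disk
    import matplotlib.pyplot as plt
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    ax1.plot(train_losses, label='train')
    ax1.plot(val_losses, label='val')
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('loss')
    ax1.legend()
    ax2.plot(train_accs, label='train')
    ax2.plot(val_accs, label='val')
    ax2.set_xlabel('epoch')
    ax2.set_ylabel('accuracy (%)')
    ax2.legend()
    fig.savefig('teacher_training_curves.png')
    plt.close(fig)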

Next, use this teacher model to supervise a small student model.
The student defined below has only three convolutional layers.

class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
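It is worth checking how much smaller the student actually is. A rough parameter count (sketch; the numbers in the comments are approximate):

def count_params(model):
    # Total number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Teacher params: {count_params(TeacherNet()):,}")   # roughly 11M for the modified ResNet-18
print(f"Student params: {count_params(StudentNet()):,}")   # roughly 0.29M for the 3-conv student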

Now define the knowledge-distillation loss. It combines a hard-label term and a soft-label term: the hard labels are the ground-truth annotations, while the soft labels are the teacher's outputs softened by a temperature T.
Note that the gradients produced by the soft targets are scaled by roughly 1/T², so when hard and soft targets are used together, the soft term is multiplied by T² to compensate.
The hard loss is standard cross-entropy; the soft loss is the KL divergence between the softened student and teacher distributions.

class DistillationLoss(nn.Module):
    def __init__(self, T=4):
        super(DistillationLoss, self).__init__()
        self.T = T

    def forward(self, student_logits, teacher_logits, labels, alpha=0.5):
        soft_targets = nn.functional.softmax(teacher_logits / self.T, dim=1)
        soft_prob = nn.functional.log_softmax(student_logits / self.T, dim=1)
        # KL divergence: Loss = -Σ P(x) * log(Q(x))   [ignoring the constant term Σ P(x) * log(P(x))]
        # Dividing by soft_prob.size(0) averages over the batch so the loss scale does not depend on batch size
        soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size(0)

        hard_loss = nn.functional.cross_entropy(student_logits, labels)
        # Softening the distributions with temperature T shrinks the soft-target gradients by a factor of T^2
        # Multiply by T^2 so the soft-loss gradients stay on the same scale as the hard loss
        loss = alpha * hard_loss + (1 - alpha) * (self.T ** 2) * soft_targets_loss
        return loss
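The soft-target term above is the cross-entropy of the student against the teacher's softened distribution. Up to a constant that does not affect the student's gradients (the teacher's entropy), it matches PyTorch's built-in KL divergence, so an equivalent formulation would be:

# Equivalent soft loss; reduction='batchmean' divides by the batch size as above
soft_targets_loss = nn.functional.kl_div(soft_prob, soft_targets, reduction='batchmean')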

Now for the knowledge-distillation training itself.
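main below calls a train_student function that the original post does not show. Here is a minimal sketch under the obvious assumptions: the teacher is frozen in eval mode, the DistillationLoss defined above is used, and an Adam optimizer updates the student (the exact hyperparameters are guesses):

def train_student(teacher_model, student_model, train_loader, epochs=15, device='cuda',
                  T=4, alpha=0.5, lr=1e-3):
    teacher_model.eval()            # teacher is frozen; only the student is updated
    student_model.train()
    criterion = DistillationLoss(T=T)
    optimizer = optim.Adam(student_model.parameters(), lr=lr)

    for epoch in range(epochs):
        running_loss, correct, total = 0.0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            with torch.no_grad():   # no gradients flow through the teacher
                teacher_logits = teacher_model(inputs)

            student_logits = student_model(inputs)
            loss = criterion(student_logits, teacher_logits, labels, alpha=alpha)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(student_logits, 1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        print(f'Epoch {epoch + 1}/{epochs}: loss {running_loss / len(train_loader):.4f}, '
              f'train acc {100. * correct / total:.2f}%')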

def main():
    set_seed(225)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root='./dataset', train=True, download=True, transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./dataset', train=False, download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)


    teacher_model = TeacherNet().to(device)
    student_model = StudentNet().to(device)

    teacher_weights = torch.load('best_teacher.pth', map_location=device)
    teacher_model.load_state_dict(teacher_weights)
    print("成功加载教师模型权重!")

    train_student(teacher_model, student_model, trainloader, epochs=15, device=device)

    teacher_acc = evaluate(teacher_model, testloader, device=device)
    student_acc = evaluate(student_model, testloader, device=device)

    print(f"教师网络测试准确率: {teacher_acc:.2f}%")
    print(f"知识蒸馏后的学生网络测试准确率: {student_acc:.2f}%")

    # Save the student model weights
    torch.save(student_model.state_dict(), 'student_distill.pth')

if __name__ == '__main__':
    main()

You can vary the temperature T and the weighting alpha between the hard and soft losses to see how different temperatures and mixing ratios affect the distilled student; a small sweep is sketched below.
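For example (a sketch that would live inside main, where trainloader, testloader, teacher_model, and device are already defined, reusing the train_student sketch above with its assumed T and alpha arguments):

for T in [1, 2, 4, 8]:
    for alpha in [0.1, 0.5, 0.9]:
        student = StudentNet().to(device)
        train_student(teacher_model, student, trainloader, epochs=15, device=device, T=T, alpha=alpha)
        acc = evaluate(student, testloader, device=device)
        print(f'T={T}, alpha={alpha}: test acc {acc:.2f}%')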
