import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller
class DTKD(Distiller):
    def __init__(self, student, teacher, cfg):
        super(DTKD, self).__init__(student, teacher)
        self.ce_loss_weight = cfg.DTKD.CE_WEIGHT  # weight of the cross-entropy loss
        self.alpha = cfg.DTKD.ALPHA  # weight of the dynamic-temperature distillation loss
        self.beta = cfg.DTKD.BETA  # weight of the vanilla KD loss
        self.warmup = cfg.DTKD.WARMUP  # number of epochs over which the distillation loss is ramped up
        self.temperature = cfg.DTKD.T  # reference temperature used to soften the logits
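
    # The hyper-parameters above come from the experiment config. A hypothetical
    # YAML fragment with illustrative values (not taken from the paper or the
    # shipped configs) might look like:
    #
    #   DTKD:
    #     CE_WEIGHT: 1.0
    #     ALPHA: 1.0
    #     BETA: 1.0
    #     WARMUP: 20
    #     T: 4.0
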
    # Forward pass used during training
    def forward_train(self, image, target, **kwargs):
        # Takes the input images, the ground-truth targets, and extra keyword
        # arguments (e.g. the current epoch, used for the warmup schedule)
        # Student logits
        logits_student, _ = self.student(image)
        # Teacher logits, computed under torch.no_grad() so that no gradients
        # are tracked for the teacher
        with torch.no_grad():
            logits_teacher, _ = self.teacher(image)
        # DTKD loss
        reference_temp = self.temperature
        # Per-sample maximum logit of the student and the teacher
        logits_student_max, _ = logits_student.max(dim=1, keepdim=True)
        logits_teacher_max, _ = logits_teacher.max(dim=1, keepdim=True)
        # Split the reference temperature into a sample-wise temperature for
        # each model, proportional to that model's own maximum logit; the two
        # temperatures always sum to 2 * reference_temp
        logits_student_temp = 2 * logits_student_max / (logits_teacher_max + logits_student_max) * reference_temp
        logits_teacher_temp = 2 * logits_teacher_max / (logits_teacher_max + logits_student_max) * reference_temp
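        # Worked example with illustrative numbers: with reference_temp = 4,
        # logits_student_max = 2.0 and logits_teacher_max = 6.0 give
        #   logits_student_temp = 2 * 2 / 8 * 4 = 2.0
        #   logits_teacher_temp = 2 * 6 / 8 * 4 = 6.0
        # so the more confident model (larger max logit) is softened by the
        # higher temperature.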
        # KL divergence between the two distributions, each softened with its
        # own dynamic temperature computed above
        ourskd = nn.KLDivLoss(reduction="none")(
            F.log_softmax(logits_student / logits_student_temp, dim=1),  # student
            F.softmax(logits_teacher / logits_teacher_temp, dim=1),  # teacher
        )
        # Sum the per-class terms, rescale by the product of the two dynamic
        # temperatures, and average over the batch
        loss_ourskd = (ourskd.sum(1, keepdim=True) * logits_teacher_temp * logits_student_temp).mean()
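        # Per sample, the term above is
        #   T_stu * T_tea * KL( softmax(z_tea / T_tea) || softmax(z_stu / T_stu) ),
        # which generalizes the usual T^2 factor; with T_stu == T_tea == T it
        # reduces to the vanilla KD loss computed below.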
        # Vanilla KD loss at the fixed reference temperature
        vanilla_temp = self.temperature
        kd = nn.KLDivLoss(reduction="none")(
            F.log_softmax(logits_student / vanilla_temp, dim=1),  # student
            F.softmax(logits_teacher / vanilla_temp, dim=1),  # teacher
        )
        loss_kd = (kd.sum(1, keepdim=True) * vanilla_temp ** 2).mean()
        # Standard cross-entropy loss between the student logits and the
        # ground-truth labels
        loss_ce = nn.CrossEntropyLoss()(logits_student, target)
        # Total loss: the two distillation terms are ramped up linearly over
        # the warmup epochs, then kept at full strength
        loss_dtkd = (
            min(kwargs["epoch"] / self.warmup, 1.0)
            * (self.alpha * loss_ourskd + self.beta * loss_kd)
            + self.ce_loss_weight * loss_ce
        )
        losses_dict = {
            "loss_dtkd": loss_dtkd,
        }
        # Return the student logits and a dict containing the total loss
        return logits_student, losses_dict
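
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# Assumptions: the Distiller base class stores `student` / `teacher` as
# attributes (as forward_train above relies on), models return a
# (logits, features) tuple, and cfg.DTKD is faked with a SimpleNamespace;
# real runs read these values from the project's config files. Run as a
# module (python -m ...) so the relative import at the top resolves.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        DTKD=SimpleNamespace(CE_WEIGHT=1.0, ALPHA=1.0, BETA=1.0, WARMUP=20, T=4.0)
    )

    class TinyNet(nn.Module):
        # Toy model mimicking the repo's (logits, features) output convention
        def __init__(self, num_classes=10):
            super().__init__()
            self.fc = nn.Linear(32, num_classes)

        def forward(self, x):
            return self.fc(x), None

    distiller = DTKD(TinyNet(), TinyNet(), cfg)
    images = torch.randn(8, 32)
    targets = torch.randint(0, 10, (8,))
    logits, losses = distiller.forward_train(images, targets, epoch=1)
    print(logits.shape, losses["loss_dtkd"].item())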