DTKD核心代码解析

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


import numpy as np



class DTKD(Distiller):
    def __init__(self, student, teacher, cfg):
        super(DTKD, self).__init__(student, teacher)
        self.ce_loss_weight = cfg.DTKD.CE_WEIGHT  # 交叉熵损失的权重
        self.alpha = cfg.DTKD.ALPHA  # 自定义蒸馏权重
        self.beta = cfg.DTKD.BETA # 传统KD损失权重
        self.warmup = cfg.DTKD.WARMUP  # 用于逐步引入蒸馏损失
        self.temperature = cfg.DTKD.T  #  温度参数,用于蒸馏损失计算中的平滑处理

    #  前向传播过程
    def forward_train(self, image, target, **kwargs):
        #  接受Image、target及其他参数
        #  获取学生模型的输出 logits_student
        logits_student, _ = self.student(image)
        #  使用 torch.no_grad() 计算教师模型的输出 logits_teacher,以确保不计算梯度
        with torch.no_grad():
            logits_teacher, _ = self.teacher(image)

        # DTKD Loss
        reference_temp = self.temperature
        #  计算学生和教师模型的最大输出值
        logits_student_max, _ = logits_student.max(dim=1, keepdim=True)
        logits_teacher_max, _ = logits_teacher.max(dim=1, keepdim=True)
        #  使用这些最大值计算自定义温度 logits_student_temp 和 logits_teacher_temp
        logits_student_temp = 2 * logits_student_max / (logits_teacher_max + logits_student_max) * reference_temp 
        logits_teacher_temp = 2 * logits_teacher_max / (logits_teacher_max + logits_student_max) * reference_temp

        # 使用 nn.KLDivLoss(reduction='none') 计算学生和教师模型输出之间的 Kullback-Leibler 散度
        # 因为已经得到了自定义温度了
        ourskd = nn.KLDivLoss(reduction='none')(
            F.log_softmax(logits_student / logits_student_temp, dim=1) , # 学生
            F.softmax(logits_teacher / logits_teacher_temp, dim=1)       # 老师
        )
        #  对学生和教师的输出进行自定义温度下的 softmax 和 log_softmax 处理
        #  最终的损失乘以自定义温度的乘积,并求均值。
        loss_ourskd = (ourskd.sum(1, keepdim=True) * logits_teacher_temp * logits_student_temp).mean()
        
        # Vanilla KD Loss
        #  传统的KD计算
        vanilla_temp = self.temperature
        kd = nn.KLDivLoss(reduction='none')(
            F.log_softmax(logits_student / vanilla_temp, dim=1) , # 学生
            F.softmax(logits_teacher / vanilla_temp, dim=1)       # 老师
        ) 
        loss_kd = (kd.sum(1, keepdim=True) * vanilla_temp ** 2).mean() 
         
        # CrossEntropy Loss
        #  计算学生模型输出和真实标签之间的标准交叉熵损失
        loss_ce = nn.CrossEntropyLoss()(logits_student, target)
        #  总损失计算
        loss_dtkd = min(kwargs["epoch"] / self.warmup, 1.0) * (self.alpha * loss_ourskd + self.beta * loss_kd) + self.ce_loss_weight * loss_ce
        losses_dict = {
            "loss_dtkd": loss_dtkd,
        }
        #  方法返回学生模型的输出 logits_student 和包含总损失的字典 losses_dict。
        return logits_student, losses_dict

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值