[YOLOv9][Model Optimization][Knowledge Distillation] — How to Implement Feature-Based Knowledge Distillation?


All of the implementation here is based on the yzd-v/FGD GitHub repository, which is built on the mmdetection deep learning framework.

Introduction and Implementation Workflow

Here, we extract knowledge from the neck features of a teacher model and distill it into a student model in the form of a feature loss. The implementation is easier to follow if you work through the steps below (a high-level sketch of how they fit together follows the list).

  1. Load the teacher model. The teacher is, of course, a pretrained model trained on the same dataset.

  2. Since this is feature-based knowledge distillation, features must be extracted from the necks of the student and teacher models. Here, PyTorch forward hooks are used to do this.

  3. The extracted features must be distilled into the student model in the form of a loss. The feature loss implemented in FGD is slightly modified to fit the YOLO format.
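The sketch below shows how the three steps combine in a single training iteration. It is only an outline; the helper names (student_detection_loss, collect_kd_losses) are hypothetical, and the actual wiring is covered step by step in the sections that follow.

import torch

def distillation_step(student, teacher, batch, optimizer):
    # 1. Frozen, pretrained teacher: its forward pass fires the forward hooks
    #    that capture the teacher neck features.
    with torch.no_grad():
        teacher(batch["img"])

    # 2. Student forward pass: its hooks capture the student neck features,
    #    and the usual YOLO detection loss is computed as in normal training.
    det_loss = student_detection_loss(student, batch)   # hypothetical helper

    # 3. FGD feature losses between matching neck layers (implemented below).
    kd_loss = collect_kd_losses(batch)                   # hypothetical helper

    (det_loss + kd_loss).backward()
    optimizer.step()
    optimizer.zero_grad()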

Loading the Teacher Model

We can create a YAML configuration for knowledge distillation as shown below, which makes the distillation settings easier to manage and reuse.

# model settings
temp : 0.5
alpha_fgd : 0.002
beta_fgd : 0.001
gamma_fgd : 0.001
lambda_fgd : 0.00001


scale_y : 1.0
scale_x : 1.0
distiller : 
    type : "DetectionDistiller"
    teacher_pretrained : "yolov9e.pt"
    student_model: "yolov9c.yaml"
    distill_cfg :
        - neck_15 : 
            name : "loss_fgd_neck_15"
            type : "KDLOSS"
            student_channels : 256
            teacher_channels : 256


        - neck_18 :
            name : "loss_fgd_neck_18"
            type : "KDLOSS"
            student_channels : 512
            teacher_channels : 512


        - neck_21 :
            name : "loss_fgd_neck_21"
            type : "KDLOSS"
            student_channels : 512
            teacher_channels : 512

From there, you can load the teacher model as shown below.

self.teacher_model = check_model_file_from_stem(self.distill_cfg["distiller"]["teacher_pretrained"])
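The line above assumes the YAML has already been parsed into self.distill_cfg. A minimal sketch of reading it (the file name "distill.yaml" is an assumption):

import yaml

# Parse the distillation YAML shown earlier into a dict.
with open("distill.yaml", "r", encoding="utf-8") as f:
    distill_cfg = yaml.safe_load(f)

distill_cfg["distiller"]["teacher_pretrained"]                              # "yolov9e.pt"
distill_cfg["distiller"]["distill_cfg"][0]["neck_15"]["student_channels"]   # 256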

Switch the teacher model to evaluation mode, freeze its gradients so it is not updated during training, and give it the same attributes as the student model.

self.teacher_model.eval()
self.teacher_model = self.teacher_model.to(self.device)
for param in self.teacher_model.parameters():
    param.requires_grad = False


self.teacher_model.nc = self.data["nc"]
self.teacher_model.names = self.data["names"]
self.teacher_model.args = self.args

We can now load the teacher model! So how do we extract features from it?

Extracting Features from the Student and Teacher Models

Where and how do we extract the features? The paper extracts features from the neck of the object detector. The neck is the feature-pyramid part of the network designed to detect objects at multiple scales; in the figure below, P3, P4, and P5 inside the red box are the neck features. Going from P3 to P5, the features contain increasingly abstract information.

[Figure: YOLOv8 model architecture diagram]

YOLOv9 has the same kind of neck. The model configuration of YOLOv9-c looks like this.

# YOLOv9


# parameters
nc: 80  # number of classes


# gelan backbone
backbone:
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]]  # 2
  - [-1, 1, ADown, [256]]  # 3-P3/8
  - [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]]  # 4
  - [-1, 1, ADown, [512]]  # 5-P4/16
  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]]  # 6
  - [-1, 1, ADown, [512]]  # 7-P5/32
  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]]  # 8
  - [-1, 1, SPPELAN, [512, 256]]  # 9


head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]]  # 12


  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]]  # 15 (P3/8-small)


  - [-1, 1, ADown, [256]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]]  # 18 (P4/16-medium)


  - [-1, 1, ADown, [512]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]]  # 21 (P5/32-large)


  - [[15, 18, 21], 1, Detect, [nc]]  # DDetect(P3, P4, P5)

Where are the neck layers we actually care about? If you look at the Detect layer in the last line of the config ([[15, 18, 21], 1, Detect, [nc]]), you can see that detection is performed from layers 15, 18, and 21 — these are the neck layers.

For the teacher model, YOLOv9-e, the neck layers are 35, 38, and 41. In other words, we need to extract features from layers 35, 38, and 41 of the teacher, compare them with the student's neck layers 15, 18, and 21, and build the loss from them.
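Rather than hard-coding these indices, they can also be read from the Detect head itself: in the Ultralytics model graph, each parsed module keeps its source layers in an f attribute. A sketch (verify the attribute name on your version):

student_neck_idx = self.model.model[-1].f          # expected [15, 18, 21] for YOLOv9-c
teacher_neck_idx = self.teacher_model.model[-1].f  # expected [35, 38, 41] for YOLOv9-e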

Features can be extracted by registering PyTorch forward hooks. Hooks are registered at positions 15, 18, and 21 of the student model as follows.

student_15_hook = self.model.model[15].register_forward_hook(student_15_hook_forward)
student_18_hook = self.model.model[18].register_forward_hook(student_18_hook_forward)
student_21_hook = self.model.model[21].register_forward_hook(student_21_hook_forward)

A hook function must take the module, its input, and its output as arguments, as shown below.

def student_15_hook_forward(module, input, output):
    global student_15_feature
    student_15_feature = output


def student_18_hook_forward(module, input, output):
    global student_18_feature
    student_18_feature = output


def student_21_hook_forward(module, input, output):
    global student_21_feature
    student_21_feature = output

Features 15, 18, and 21 are declared as global variables and later used as inputs to the feature loss.
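If you prefer to avoid module-level globals, the same thing can be done with a small closure that writes into a dict. This is only an equivalent sketch, not the code used in the rest of the article:

student_features = {}

def make_hook(storage, key):
    def hook(module, input, output):
        storage[key] = output   # keep the neck feature for the current batch
    return hook

for idx in (15, 18, 21):
    self.model.model[idx].register_forward_hook(make_hook(student_features, idx))
# After a forward pass, student_features[15] / [18] / [21] hold the neck outputs.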

For the teacher model, create hook functions for the corresponding layers and register them in the same way.

teacher_35_hook = self.teacher_model.model[35].register_forward_hook(teacher_35_hook_forward)
teacher_38_hook = self.teacher_model.model[38].register_forward_hook(teacher_38_hook_forward)
teacher_41_hook = self.teacher_model.model[41].register_forward_hook(teacher_41_hook_forward)


def teacher_35_hook_forward(module, input, output):
    global teacher_35_feature
    teacher_35_feature = output


def teacher_38_hook_forward(module, input, output):
    global teacher_38_feature
    teacher_38_feature = output


def teacher_41_hook_forward(module, input, output):
    global teacher_41_feature
    teacher_41_feature = output
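One practical note: register_forward_hook returns a handle, and it is worth removing the hooks once distillation training is finished (for example before validation or export), otherwise they keep firing and holding references to the captured features. A sketch:

for handle in (student_15_hook, student_18_hook, student_21_hook,
               teacher_35_hook, teacher_38_hook, teacher_41_hook):
    handle.remove()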

Implementing the Feature Loss

I implemented it in ultralytics/utils/loss.py. The implementation is essentially the same as the code here (https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py), but the input/output interfaces of the mmdetection and ultralytics frameworks differ, so a few modifications were added to remove the errors caused by those differences.
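Before diving into the code, it helps to keep the overall shape of the loss in mind. Using the parameter names from the YAML config above, the per-layer loss that forward() below produces (and that the training code later sums) is:

loss = alpha_fgd * L_fg + beta_fgd * L_bg + gamma_fgd * L_mask + lambda_fgd * L_global

where L_fg and L_bg are the attention-weighted foreground/background feature losses, L_mask is the L1 distance between the student's and teacher's attention maps, and L_global is the GcBlock-based global (relation) loss.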

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# kaiming_init / constant_init below are the weight-initialization helpers used by the
# original FGD code (mmcv); any equivalent based on nn.init.kaiming_normal_ /
# nn.init.constant_ applied to the weight tensors works here.


class FGDLoss(nn.Module):
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 name,
                 temp=0.5,
                 alpha_fgd=0.001,
                 beta_fgd=0.0005,
                 gamma_fgd=0.0005,
                 lambda_fgd=0.000005,
                 scale_y=1.0,
                 scale_x=1.0
                 ):


        super().__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd
        self.name = name
        self.scale_y = scale_y
        self.scale_x = scale_x


        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels,
                                   kernel_size=1, stride=1, padding=0, device='cuda', dtype=torch.half)
        else:
            self.align = None


        self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1, device='cuda', dtype=torch.half)
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1, device='cuda', dtype=torch.half)


        # GcBlock
        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1, device='cuda', dtype=torch.half),
            nn.LayerNorm([teacher_channels//2, 1, 1], device='cuda', dtype=torch.half),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1, device='cuda', dtype=torch.half)
        )
        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1, device='cuda', dtype=torch.half),
            nn.LayerNorm([teacher_channels//2, 1, 1], device='cuda', dtype=torch.half),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1, device='cuda', dtype=torch.half)
        )


        self.reset_parameters()


    def forward(self,
                preds_S,
                preds_T,
                batch,
                ):
        batch_num = len(batch['img'])
        gt_bboxes = [] # batched, cx, cy, w, h
        obj_idx = 0
        for i in range(batch_num):
            obj_num = batch["batch_idx"].tolist().count(i)
            xyxy = self.xywh2xyxy(batch["bboxes"])
            gt_bboxes.append(xyxy[obj_idx:obj_idx + obj_num])
            obj_idx += obj_num
        img_metas = batch["resized_shape"] # resized image shape
        # batch: im_file, ori_shape, resized_shape, img, cls, bboxes, batch_idx


        scale_y = self.scale_y
        scale_x = self.scale_x


        assert preds_S.shape[-2:] == preds_T.shape[-2:], 'the output dim of teacher and student differ'


        if self.align is not None:
            preds_S = self.align(preds_S)


        N, C, H, W = preds_S.shape


        S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)
        S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)


        Mask_fg = torch.zeros_like(S_attention_t)
        Mask_bg = torch.ones_like(S_attention_t)
        wmin, wmax, hmin, hmax = [], [], [], []
        for i in range(N):
            new_bboxes = torch.ones_like(gt_bboxes[i])
            new_bboxes[:, 0] = gt_bboxes[i][:, 0] * W / scale_x # xmin
            new_bboxes[:, 2] = gt_bboxes[i][:, 2] * W / scale_x # xmax
            new_bboxes[:, 1] = gt_bboxes[i][:, 1] * H / scale_y # ymin
            new_bboxes[:, 3] = gt_bboxes[i][:, 3] * H / scale_y # ymax


            wmin.append(torch.floor(new_bboxes[:, 0]).int())
            wmax.append(torch.ceil(new_bboxes[:, 2]).int())
            hmin.append(torch.floor(new_bboxes[:, 1]).int())
            hmax.append(torch.ceil(new_bboxes[:, 3]).int())


            area = 1.0/(hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1))/(wmax[i].view(1,-1) + 1 - wmin[i].view(1, -1))
            for j in range(len(gt_bboxes[i])):
                Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \
                    torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1],
                                  area[0][j])
            Mask_bg[i] = torch.where(Mask_fg[i]>0, 0, 1)
            if torch.sum(Mask_bg[i]):
                Mask_bg[i] /= torch.sum(Mask_bg[i])
        fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg,
                                             C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        rela_loss = self.get_rela_loss(preds_S, preds_T)


        fg_loss *= self.alpha_fgd/scale_x/scale_y
        bg_loss *= self.beta_fgd/scale_x/scale_y
        mask_loss *= self.gamma_fgd/scale_x/scale_y
        rela_loss *= self.lambda_fgd/scale_x/scale_y


        return S_attention_t, S_attention_s, C_attention_t, C_attention_s, fg_loss, bg_loss, mask_loss, rela_loss


    def get_attention(self, preds, temp):
        N, C, H, W = preds.shape


        value = torch.abs(preds) # shape [16, 128, 80, 80]
        fea_map = value.mean(axis=1, keepdim=True) # shape [16, 1, 80, 80]
        S_attention = (H*W*F.softmax((fea_map/temp).view(N, -1), dim=1)).view(N, H, W) # shape [16, 80, 80]


        channel_map = value.mean(axis=2, keepdim=False).mean(axis=2, keepdim=False) # shape [16, 128]
        C_attention = C * F.softmax(channel_map/temp, dim=1) # shape [16, 128]


        return S_attention, C_attention


    def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t): # feature loss
        loss_mse = nn.MSELoss(reduction='sum')


        Mask_fg = Mask_fg.unsqueeze(dim=1)
        Mask_bg = Mask_bg.unsqueeze(dim=1)


        C_t = C_t.unsqueeze(dim=-1)
        C_t = C_t.unsqueeze(dim=-1)


        S_t = S_t.unsqueeze(dim=1)


        C_s = C_s.unsqueeze(dim=-1)
        C_s = C_s.unsqueeze(dim=-1)


        S_s = S_s.unsqueeze(dim=1)


        fea_t = torch.mul(preds_T, torch.sqrt(S_t))
        fea_t = torch.mul(fea_t, torch.sqrt(C_t))
        fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
        bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))


        fea_s = torch.mul(preds_S, torch.sqrt(S_s))
        fea_s = torch.mul(fea_s, torch.sqrt(C_s))
        fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))


        bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))


        fg_loss = loss_mse(fg_fea_s / len(Mask_fg), fg_fea_t / len(Mask_fg)) 
        bg_loss = loss_mse(bg_fea_s / len(Mask_bg), bg_fea_t / len(Mask_bg)) 
        return fg_loss, bg_loss


    def get_mask_loss(self, C_s, C_t, S_s, S_t): # Attention mask loss


        mask_loss = torch.sum(torch.abs((C_s - C_t))) / len(C_s) + torch.sum(torch.abs((S_s - S_t))) / len(S_s)


        return mask_loss


    def spatial_pool(self, x, in_type):
        batch, channel, width, height = x.size()
        input_x = x


        input_x = input_x.view(batch, channel, height * width)


        input_x = input_x.unsqueeze(1)


        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)


        context_mask = context_mask.view(batch, 1, height * width)


        context_mask = F.softmax(context_mask, dim=2)


        context_mask = context_mask.unsqueeze(-1)


        context = torch.matmul(input_x, context_mask)


        context = context.view(batch, channel, 1, 1)


        return context


    def get_rela_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')


        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)


        out_s = preds_S
        out_t = preds_T


        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s


        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t
        rela_loss = loss_mse(out_s / (len(out_s)), out_t / (len(out_s))) 
        # rela_loss = loss_mse(out_s, out_t)
        return rela_loss


    def last_zero_init(self, m):
        if isinstance(m, nn.Sequential):
            constant_init(m[-1].weight, val=0)
        else:
            constant_init(m.weight, val=0)


    def reset_parameters(self):
        kaiming_init(self.conv_mask_s.weight, mode='fan_in', nonlinearity='relu')
        kaiming_init(self.conv_mask_t.weight, mode='fan_in', nonlinearity='relu')
        self.conv_mask_s.inited = True
        self.conv_mask_t.inited = True


        self.last_zero_init(self.channel_add_conv_s)
        self.last_zero_init(self.channel_add_conv_t)

    def xywh2xyxy(self, x):
        """
        Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
        top-left corner and (x2, y2) is the bottom-right corner.


        Args:
            x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.


        Returns:
            y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
        """
        assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
        y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
        dw = x[..., 2] / 2  # half-width
        dh = x[..., 3] / 2  # half-height
        y[..., 0] = x[..., 0] - dw  # top left x
        y[..., 1] = x[..., 1] - dh  # top left y
        y[..., 2] = x[..., 0] + dw  # bottom right x
        y[..., 3] = x[..., 1] + dh  # bottom right y
        return y

Next, initialize a feature loss for each neck layer. A separate loss is created per layer because the attention maps and losses are computed layer by layer.

self.kd_loss_15_fn = FGDLoss(
    student_channels=self.distill_cfg["distiller"]["distill_cfg"][0]["neck_15"]["student_channels"],
    teacher_channels=self.distill_cfg["distiller"]["distill_cfg"][0]["neck_15"]["teacher_channels"],
    name=self.distill_cfg["distiller"]["distill_cfg"][0]["neck_15"]["name"],
    temp=self.distill_cfg["temp"],
    alpha_fgd=self.distill_cfg["alpha_fgd"],
    beta_fgd=self.distill_cfg["beta_fgd"],
    gamma_fgd=self.distill_cfg["gamma_fgd"],
    lambda_fgd=self.distill_cfg["lambda_fgd"],
    scale_y=self.distill_cfg["scale_y"],
    scale_x=self.distill_cfg["scale_x"],
)


self.kd_loss_18_fn = FGDLoss(
    student_channels=self.distill_cfg["distiller"]["distill_cfg"][1]["neck_18"]["student_channels"],
    teacher_channels=self.distill_cfg["distiller"]["distill_cfg"][1]["neck_18"]["teacher_channels"],
    name=self.distill_cfg["distiller"]["distill_cfg"][1]["neck_18"]["name"],
    temp=self.distill_cfg["temp"],
    alpha_fgd=self.distill_cfg["alpha_fgd"],
    beta_fgd=self.distill_cfg["beta_fgd"],
    gamma_fgd=self.distill_cfg["gamma_fgd"],
    lambda_fgd=self.distill_cfg["lambda_fgd"],
    scale_y=self.distill_cfg["scale_y"],
    scale_x=self.distill_cfg["scale_x"],
)


self.kd_loss_21_fn = FGDLoss(
    student_channels=self.distill_cfg["distiller"]["distill_cfg"][2]["neck_21"]["student_channels"],
    teacher_channels=self.distill_cfg["distiller"]["distill_cfg"][2]["neck_21"]["teacher_channels"],
    name=self.distill_cfg["distiller"]["distill_cfg"][2]["neck_21"]["name"],
    temp=self.distill_cfg["temp"],
    alpha_fgd=self.distill_cfg["alpha_fgd"],
    beta_fgd=self.distill_cfg["beta_fgd"],
    gamma_fgd=self.distill_cfg["gamma_fgd"],
    lambda_fgd=self.distill_cfg["lambda_fgd"],
    scale_y=self.distill_cfg["scale_y"],
    scale_x=self.distill_cfg["scale_x"],
)
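The three nearly identical blocks above can also be written as a loop over the distill_cfg list. A compact sketch under the same config structure:

common = dict(temp=self.distill_cfg["temp"],
              alpha_fgd=self.distill_cfg["alpha_fgd"],
              beta_fgd=self.distill_cfg["beta_fgd"],
              gamma_fgd=self.distill_cfg["gamma_fgd"],
              lambda_fgd=self.distill_cfg["lambda_fgd"],
              scale_y=self.distill_cfg["scale_y"],
              scale_x=self.distill_cfg["scale_x"])

self.kd_loss_fns = {}
for entry in self.distill_cfg["distiller"]["distill_cfg"]:
    for layer_key, cfg in entry.items():   # e.g. "neck_15" -> {...}
        self.kd_loss_fns[layer_key] = FGDLoss(student_channels=cfg["student_channels"],
                                              teacher_channels=cfg["teacher_channels"],
                                              name=cfg["name"],
                                              **common)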

Then, by feeding the student features, teacher features, and the input batch into each loss, we compute the teacher and student spatial/channel attention maps together with the foreground, background, mask, and global (relation) losses.

self.kd_feat_15_spatial_t, self.kd_feat_15_spatial_s, self.kd_feat_15_channel_t, self.kd_feat_15_channel_s, self.kd_loss_15_fg, self.kd_loss_15_bg, self.kd_loss_15_ma, self.kd_loss_15_re = self.kd_loss_15_fn(student_15_feature, teacher_35_feature, batch)
self.kd_feat_18_spatial_t, self.kd_feat_18_spatial_s, self.kd_feat_18_channel_t, self.kd_feat_18_channel_s, self.kd_loss_18_fg, self.kd_loss_18_bg, self.kd_loss_18_ma, self.kd_loss_18_re = self.kd_loss_18_fn(student_18_feature, teacher_38_feature, batch)
self.kd_feat_21_spatial_t, self.kd_feat_21_spatial_s, self.kd_feat_21_channel_t, self.kd_feat_21_channel_s, self.kd_loss_21_fg, self.kd_loss_21_bg, self.kd_loss_21_ma, self.kd_loss_21_re = self.kd_loss_21_fn(student_21_feature, teacher_41_feature, batch)
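Note that the hooked features are only valid after both models have run a forward pass on the same pre-processed batch, so a typical ordering inside the training loop looks roughly like this (variable names are illustrative):

with torch.no_grad():
    self.teacher_model(batch["img"])   # fills teacher_35/38/41_feature via the hooks
preds = self.model(batch["img"])       # fills student_15/18/21_feature via the hooks
# ... compute the regular detection loss from preds, then the FGD losses as above.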

Then compute the final loss and run the backward pass.

self.kd_loss_fg = self.kd_loss_15_fg + self.kd_loss_18_fg + self.kd_loss_21_fg
self.kd_loss_bg = self.kd_loss_15_bg + self.kd_loss_18_bg + self.kd_loss_21_bg
self.kd_loss_ma = self.kd_loss_15_ma + self.kd_loss_18_ma + self.kd_loss_21_ma
self.kd_loss_re = self.kd_loss_15_re + self.kd_loss_18_re + self.kd_loss_21_re
self.loss_items_kd = [self.kd_loss_fg, self.kd_loss_bg, self.kd_loss_ma, self.kd_loss_re]
self.tloss_kd = ([(x*i + y) / (i + 1) for x, y in zip(self.tloss_kd, self.loss_items_kd)] if self.tloss_kd is not None else self.loss_items_kd)
self.loss_kd = sum(self.loss_items_kd) * len(batch["im_file"])
self.loss_kd_total = self.loss + self.loss_kd
self.scaler.scale(self.loss_kd_total).backward()

Checking the Results

The goal of the feature loss is to increase the similarity between the student's and teacher's spatial/channel attention maps. The images below confirm that as the training epochs increase, the similarity grows and the loss decreases.

COCO dataset, teacher (YOLOv9-e), student (YOLOv9-c)

[Figure (epoch 2): spatial attention maps for each neck layer (top) / channel attention maps for each neck layer (bottom)]

[Figure (epoch 52): spatial attention maps for each neck layer (top) / channel attention maps for each neck layer (bottom)]

[Figure (epoch 116): spatial attention maps for each neck layer (top) / channel attention maps for each neck layer (bottom)]

This confirms that as the number of epochs increases, not only the spatial attention maps but also the channel attention maps become more similar. However, the mAP50:95 did not exceed that of the baseline YOLOv9-c model.

Conclusion

Although the feature similarity between the student's and the teacher's necks increased (high fidelity), this did not translate into better generalization performance for the student.

