蒸馏法文章选读——Correlation Congruence for Knowledge Distillation

本文介绍了知识蒸馏的相关性一致性(CCDK)算法。该算法创新点在于不仅学习单个样本向量差异,还学习样本间相关性。目标损失包含交叉熵、KL散度和新的Lcc损失。采用非线形RBF核,介绍了两种采样策略。实验表明小网络学习大训练集困难,CCDK能让类内更紧密,学习更多结构信息。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

7, Correlation Congruence for Knowledge Distillation

https://arxiv.org/abs/1904.01802

1),创新点:原始的蒸馏法只是用学生网络的某个向量去拟合教师网络的该向量,无论是kl散度还是欧式距离,只是向量之间的映射;但是由于教师网络和学生网络本来的差异性,所以不应该仅仅学习教师网络和学生网络单个样本向量间差异,还应该学习这两个样本间的相关性,instance congruence (学术网络和教师网络预测值的kl散度,KL divergence on predictions of teacher and student) and correlation congruence (教师网络和学生网络的相关性之间的欧式距离 euclidean distance on correlation of teacher and student).

这里的f是网络的embedded feature space的一个点,即一个图像映射到embedded feature space的一个点,

又有一个映射关系如下,其中C是相关矩阵,

Cij表示的是xi和xj在embedding space的相关性

上面的fi函数可以用各种correlation关系来度量,引入下面的correlation congruence函数

2),目标损失:所以文章最终的CCDK目标损失为,第一项为真实标签的交叉熵损失,第二项为原始蒸馏法的kl散度损失,第三项为本文新的到的Lcc损失

整个算法的流程如下所示:

3),最后一步就是选择一个合适的F->C的映射

最终选择的是一个非线形RBF核,因为非线形核更加灵活 可以补货向量之间的非线形关系

对右边的公式进行泰勒级数展开

4),minibatch的采样策略

CUR类一致性随机采样:每个minibatch包含6个类别,每个类别取8张图,总共有48张

SUR超类一致性随机采样:使用教师模型对所有数据分k类,然后每个minibatch包含6个超级类别,在每个超级类别中取8张图,总共有48张

 

5) 实验结果分析

综合上面两个分别在reid和人脸识别的应用来看,用小网络去学习一个大的/有难度的训练集是很困难的,单纯的one-hot向量很难帮助网络学习到好的泛化能力。有两点结论:

 

1,小网络使用one-hot去直接学习大训练集(ms1m)/难的训练集(MSMT17),很困难。

学生:小网络

教师:大网络

2,小网络主要差在泛化能力上,表4其实在干扰项只有百、千量级的时候差别不明显,到十万量级的时候差别就非常大了,

batch取为40,所以k=1,意味着一个minibatch只有一个人的40张照片,当k=20,一个人只有两张;但是k=1时,由于一个人的图片数可能少于40,所以一个minibatch是很可能出现两个人的图片;综合来看,单个人图片越少越难训练,因为k-20结果是最差的,而且差距很大。

k=4优于k=1说明适当多的个体数是好的。

类内的差异如上图所示,cckd的类内整体颜色要比kd的颜色要深,所以一定程度上让类内更紧密。为什么会这样?再回到公式7,相比传统的蒸馏法,多了一项Lcc,再看公式6,学生网络特征间的距离还需要和教师网络特征间的距离更接近,假设教师网络类内学习的较好,那么cckd相比kd也会有更好的类内的差距。(注意这里严重依赖于教师网络类内学习的效果)

传统kd只需要教师网络的单个向量和学生网络的单个向量接近,cckd还需要学习教师网络的向量间和学生网络的向量间距离更接近,即类内的结构信息,比如教师网络认为某人带眼镜和不带眼镜相比这个人长发和短发差异更大,那学生网络也需要学习到该信息,并且这个差异是通过RBF的p阶泰勒展开式得到的,所以能学习到更多的向量本身每个维度代表的信息。

 

### HDMNet: Hierarchical Dense Correlation Distillation for Few-Shot Segmentation HDMNet 是一种针对少样本分割任务提出的创新框架,旨在通过分层密集关联蒸馏来提升模型的表现。该方特别适用于图像分割领域,在仅有少量标记样本的情况下仍能保持较高的准确性。 #### 方概述 HDMNet 的核心在于引入了一种新颖的分层机制,能够有效地捕捉不同尺度下的特征表示,并利用教师-学生网络结构进行知识迁移。具体来说: 1. **多尺度特征提取**:为了更好地适应目标对象的变化形态,HDMNet 设计了一个可以处理多种分辨率输入的支持向量机(SVM),从而获得更加鲁棒和支持性的特征描述子[^1]。 2. **跨层注意力模块**:借鉴于层次注意原型网络(HAPN)[^3]的设计理念,HDMNet 中加入了跨层注意力机制(Cross-Level Attention Module, CLA),它允许低级别到高级别的逐层交互学习,增强了局部细节与全局上下文之间的联系。 3. **稠密对应关系建模**:不同于传统的一对一匹配方式,HDMNet 提出了稠密对应的策略(Dense Correspondence Modeling, DCM),即在整个支持集中寻找最相似区域并建立一对一或多对多的关系映射表,以此指导查询图片中相应位置像素标签预测过程。 4. **渐进式蒸馏损失函数**:考虑到直接优化整个网络可能会遇到梯度消失等问题,因此采用了自底向上逐步精细化调整的方式定义了渐进式的蒸馏损失(Progressive Distillation Loss, PDL)。这不仅有助于稳定训练过程,而且促进了从简单模式到复杂场景的有效过渡。 ```python import torch.nn as nn class HDMDistiller(nn.Module): def __init__(self, student_net, teacher_net): super().__init__() self.student = student_net self.teacher = teacher_net def forward(self, support_set, query_image): # Extract features from both networks s_features = self.student(support_set) t_features = self.teacher(support_set) # Compute dense correspondence between layers corr_matrix = compute_correspondences(s_features, t_features) # Apply cross-level attention mechanism attended_sfeat = apply_attention(corr_matrix, s_features) # Predict segmentation mask using distilled knowledge pred_mask = predict_segmentation(attended_sfeat, query_image) return pred_mask def progressive_distill_loss(student_output, teacher_output): loss = 0. for level in range(len(student_output)): l2_diff = F.mse_loss(student_output[level], teacher_output[level]) kl_divergence = F.kl_div( F.log_softmax(student_output[level]/T), F.softmax(teacher_output[level]/T)) loss += alpha * l2_diff + beta * kl_divergence return loss ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Eva_Hua

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值