center loss 论文学习

本文深入解析了CenterLoss的工作原理及其在深度学习中的应用。通过对比传统网络框架,详细介绍了CenterLoss如何使同一类别的输出结果更加集中,从而提高分类准确性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

center loss框架

从网络的的框架来看,center loss的主要工作是下图中的“Discriminative Features”。
fig1

普通的网络框架,在反向传播的过程中,根据类别标签,会将不同的类别划分开。如“Separable Features”所示,一开始两种颜色是混杂的,通过改变网络参数,让不同颜色能被分类器分开,就达到了目的。而这个过程中,只对不同类有要求,同一类没有进行约束。
center loss则是让类内的输出结果更加集中。

为了展示实际的效果,作者在mnist上进行了测试,下图是softmax分类器前面增加的一层的参数,其维度为2,这样就可以进行可视化的显示。

F=WX

X是上一层的输出,维度为800(根据论文计算得到),F为施加center loss的全连接层的输出,维度为2。那么权重参数F为{800,2}的矩阵。
fig2
在没有采用center loss时,不同类别的输出图像是一种花瓣,其特点是同一类的方差较大。可以找到分界线将不同类别区分开,虽然花瓣外尖端与其他类间距很大,花瓣中心的区分很小,很容易造成错误,如橘色区域,红线表示分类线。
这里写图片描述

如何让同一类颜色更集中呢?文中采用了center loss:
centerloss
很简单,每个将输出点与这类中心点的距离累加作为损失。
回想方差公式:
v
是不是很类似?降低center loss其实也可以看作是降低同类的方差。

实现

推荐EncodeTS/TensorFlow_Center_Loss的代码,使用TensorFlow实现,且有详细的中文注释。

center loss流程大致为:

  1. 初始化权重中心centers,形状为[num_classes, len_features],中心值为0
  2. 在一次iteration中,获取mini-batch中每一个样本对应的中心值,centers_batch,形状为[batch_size, feature_length](使用tf.gather技巧)
  3. 计算loss,特征与中心features - centers_batch的l2范数
  4. 根据论文公式(3)(4)更新权重中心:
    在一个mini-batch中,某一类j出现了n次,分解来看:
    1. 属于该类的第i个样本与中心距离cjxi
      • 同理算出这个类出现的n次样本的距离,并汇总求和
      • 除以n+1
      • loss
        center loss

### Center Loss的定义 Center Loss是一种用于深度学习模型的损失函数,旨在提升特征表示的判别能力。它的核心思想是通过最小化每个类别内部样本之间的距离,使得同一类别的特征更加集中,从而提高分类的准确性。具体而言,Center Loss通过计算每个样本特征与其对应类别中心之间的欧氏距离,并将这些距离累加作为损失值。公式可以表示为: $$ L_{center} = \frac{1}{2} \sum_{i=1}^{m} \| x_i - c_{y_i} \|^2 $$ 其中,$ x_i $ 是第 $ i $ 个样本的特征向量,$ c_{y_i} $ 是该样本所属类别的中心点,$ m $ 是样本数量[^3]。 ### Center Loss的原理 在深度学习中,Center Loss通常与Softmax Loss结合使用。Softmax Loss负责将不同类别的特征分开,而Center Loss则负责压缩同一类别的特征,使其更加紧凑。这种双重优化机制有助于模型学习到更具判别性的特征表示。 具体来说,在训练过程中,每个类别的中心点会不断更新,以反映当前所有样本的特征分布。通过最小化样本特征与类中心之间的距离,Center Loss促使同一类别的样本在特征空间中聚集在一起,从而减少类内差异,增加类间差异[^2]。 ### Center Loss的应用 Center Loss在人脸识别等领域表现出色。例如,在ECCV2016的一篇论文中,研究人员利用Center Loss辅助Softmax Loss进行人脸训练,成功提高了特征的判别能力。通过这种方式,模型能够在特征空间中更好地分离不同个体,同时压缩同一类别的特征,从而提升识别准确率[^2]。 此外,Center Loss还可以应用于其他需要高精度特征表示的任务,如图像检索和物体检测。在这些任务中,通过减少类内差异,Center Loss能够帮助模型更有效地找到目标对象[^3]。 ### 示例代码 以下是一个简单的PyTorch实现示例,展示了如何计算Center Loss: ```python import torch import torch.nn as nn class CenterLoss(nn.Module): def __init__(self, num_classes, feat_dim, device): super(CenterLoss, self).__init__() self.centers = nn.Parameter(torch.randn(num_classes, feat_dim).to(device)) def forward(self, x, labels): batch_size = x.size(0) feat_dim = x.size(1) centers = self.centers[labels] loss = (x - centers).pow(2).sum() / 2 / batch_size return loss ``` ### 总结 Center Loss通过最小化样本特征与类中心之间的距离,有效地提升了深度学习模型的特征表示能力。它在人脸识别等任务中表现出色,能够显著提高模型的准确性和鲁棒性。结合Softmax LossCenter Loss为模型提供了一种强大的优化机制,使其在特征空间中更好地分离不同类别[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值