损失函数:从Softmax到AMSoftmax

不同于Softmax,AMSoftmax属于Metric Learning——缩小类内距增大类间距的策略。下图2形象的解释了Softmax 和 AMSoftmax的区别,Softmax能做到的只能是划分类别间的界线——绿色虚线,而AMSoftmax可以缩小类内距增大类间距,将类的区间缩小到Target region范围,同时又会产生margin大小的类间距。

SoftMaxLoss—>A-SoftMaxLoss—>AM-SoftMaxLoss—>Arc-SoftMaxLoss




参考资料:
从Softmax到AMSoftmax(附可视化代码和实现代码)
AM-Softmax

### A-SoftmaxAM-Softmax 和 AAM-Softmax 的区别与应用场景 #### 背景介绍 在深度学习领域,尤其是人脸识别任务中,设计有效的损失函数对于提升模型性能至关重要。A-SoftmaxAM-Softmax 和 AAM-Softmax 都是在传统 SoftMax Loss 基础上改进而来的高级损失函数,旨在增强类间区分性和类内紧凑性。 --- #### 1. **A-Softmax** A-Softmax(Angular Softmax)的核心思想是通过对角度施加约束来增大类别之间的间隔[^2]。具体来说,它通过引入一个超参数 \( m \),强制要求预测的角度 \( \theta_{y_i} \) 至少比其他类别的最小角度大 \( m\pi/rad \)[^4]。这种机制使得特征分布在高维空间中的分布更加集中于特定区域,从而提高分类准确性。 其主要特点如下: - 使用球面约束使特征向量规范化到单位长度。 - 提升了类间的判别能力,但由于优化目标复杂化可能导致收敛速度变慢。 实现方式通常涉及对原始 logits 进行变换操作: ```python import torch import torch.nn as nn class ASoftmaxLoss(nn.Module): def __init__(self, num_classes, embedding_size, margin=4): super(ASoftmaxLoss, self).__init__() self.num_classes = num_classes self.embedding_size = embedding_size self.margin = margin def forward(self, x, labels): cos_theta = F.linear(F.normalize(x), F.normalize(self.weight)) phi_theta = (-1)**k * cos(m*theta) - 2*k one_hot = torch.zeros_like(cos_theta) one_hot.scatter_(1, labels.view(-1, 1).long(), 1) output = (one_hot * phi_theta) + ((1.0 - one_hot) * cos_theta) logit = output * s loss = F.cross_entropy(logit, labels) return loss ``` --- #### 2. **AM-Softmax** AM-Softmax(Additive Margin Softmax)进一步简化了 A-Softmax 中复杂的三角函数计算,转而采用一种线性的加法形式来增加边界[^3]。它的核心在于直接修改 logits 计算公式,在正样本方向减去一个固定值 \( m \),而在负样本方向保持不变。这样既保留了 A-Softmax 的优势又降低了训练难度。 关键特性包括: - 更容易实现且稳定; - 减少了因多次幂运算带来的数值不稳定风险。 以下是其实现代码片段: ```python class AMSoftmaxLoss(nn.Module): def __init__(self, num_classes, embedding_size, margin=0.35, scale=30): super(AMSoftmaxLoss, self).__init__() self.num_classes = num_classes self.scale = scale self.margin = margin self.weight = nn.Parameter(torch.Tensor(num_classes, embedding_size)) def forward(self, embeddings, labels): cosine = F.linear(F.normalize(embeddings), F.normalize(self.weight)) one_hot = torch.zeros_like(cosine) one_hot.scatter_(1, labels.view(-1, 1), 1) adjusted_cosine = self.scale * (cosine - one_hot * self.margin) loss = F.cross_entropy(adjusted_cosine, labels) return loss ``` --- #### 3. **AAM-Softmax** AAM-Softmax(Additive Angular Margin Softmax),也称为 ArcFace,是对 AM-Softmax 的又一次升级版本。不同于简单地在线性维度上加入偏移项,该方法提出了基于角度的加法规则——即给定真实标签对应的余弦相似度减少一定弧度作为惩罚因子。这种方法能够更好地捕捉数据内在几何结构并促进更优解的学习过程。 技术亮点总结为以下几点: - 明确考虑到了欧几里得距离之外的角度关系; - 可以有效缓解由于尺度缩放引起的信息丢失问题。 下面是 Python 实现示例: ```python class AAMSoftmaxLoss(nn.Module): def __init__(self, num_classes, embedding_size, margin=0.5, scale=64): super(AAMSoftmaxLoss, self).__init__() self.num_classes = num_classes self.scale = scale self.margin = margin self.weight = nn.Parameter(torch.FloatTensor(num_classes, embedding_size)) def forward(self, embeddings, labels): cosine = F.linear(F.normalize(embeddings), F.normalize(self.weight)) theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7)) target_logits = torch.cos(theta[range(len(labels)), labels] + self.margin) final_target_logits = torch.cos(target_logits) onehot = torch.zeros_like(cosine) onehot.scatter_(1, labels.view(-1, 1), 1) adjusted_cosine = cosine + onehot * (final_target_logits - cosine[range(len(labels)), labels].unsqueeze(1)) scaled_logit = self.scale * adjusted_cosine loss = F.cross_entropy(scaled_logit, labels) return loss ``` --- ### 总结对比表 | 特性/算法 | A-Softmax | AM-Softmax | AAM-Softmax | |------------------|------------------------------------|-----------------------------------|----------------------------------| | 边界类型 | 多次幂 | 加法 | 角度加法 | | 数学表达式 | \( (\cos(\theta))^m \) | \( \cos(\theta)-m \) | \( \cos(\theta+m) \) | | 收敛效率 | 较低 | 较高 | 较高 | | 类间分离程度 | 很好 | 好 | 极佳 | --- ### 应用场景建议 - 如果追求极致精度并且可以接受较长训练时间,则推荐选用 **AAM-Softmax** 或者 **A-Softmax**。 - 对于资源受限或者希望快速迭代实验的情况下,可以选择相对简单的 **AM-Softmax** 来获得不错的平衡表现。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值