Supervised Contrastive Learning
Motivation
- 交叉熵损失是监督学习中应用最广泛的损失函数,度量两个分布(标签分布和经验回归分布)之间的KL散度,但是也存在对于有噪声的标签缺乏鲁棒性、可能存在差裕度(允许有余地的余度)导致泛化性能下降的问题。而大多数替代方案还不能很好地用于像ImageNet这样的大规模数据集。
- 在对比学习中,核心思想是拉近某一个锚点与其正样本之间的距离,拉远锚点与该锚点其他负样本之间的距离,通常来说,一个锚点只有一个正样本,其他全视为负样本。
Contribution
-
提出了一个新的扩展对比损失函数,允许每个锚点有多个正对。因此,将对比学习适应于完全监督的setting。
-
作者通过梯度计算的角度说明了文中提出的loss可以更好地关注于 hard positives and negatives,从而获得更好的效果。
Method
-
表征学习框架
总的来说,有监督对比学习框架的结构类似于表征学习框架,由如下几个部分组成:
-
数据增强模块
A ( ⋅ ) A(\cdot) A(⋅)的作用是将输入图像转换为随机增强的图像 x ˉ \bar x xˉ,对每张图像都生成两张增强的子图像,代表原始数据的不同视图。数据增强分为两个阶段:第一阶段是对数据进行随机裁剪,然后将其调整为原分辨率大小;第二阶段使用了三种不同的增强方法,具体包括:(1)自动增强,(2)随机增强,(3)Sim增强(按照顺序进行随机颜色失真和高斯模糊,并可能在序列最后进行额外的稀疏图像扭曲操作)。
-
编码器网络
编码器网络 E ( ⋅ ) E(\cdot) E(⋅)的作用是将增强后的图像 x ˉ \bar x xˉ映射到表征空间,每对子图像输入到同一个编码器中得到一对表征向量,本文用的是ResNet50和ResNet200,最后使用池化层得到一个2048维的表征向量。表征层使用单位超球面进行正则化。
-
投影网络
投影网络 P ( ⋅ ) P (\cdot) P(⋅)的作用是将表征向量映射成一个最终向量 z z z进行loss的计算,本文用的是只有一个隐藏层的多层感知器,输出维度为128。同样使用单位超球面进行正则化。在训练完成后,这个网络会被一个单一线性层取代。
-
-
对比损失
本文的数据是带有标签的,采用mini batch的方法获取数据,首先从数据中随机采样 N N N个样本对,记为 { x k , y k } k = 1 , 2 , … , N {\left\{x_k , y_k\right\}}_{k = 1,2,\dots,N} { xk,yk}k=1,2,…,N , y k y_k yk是 x k x_k xk的标签,之后进行数据增强获得 2 N 2N 2N个数据样本 { x ˉ k , y ˉ k } k = 1 , 2 , … , 2 N {\left\{\bar x_k , \bar y_k\right\}}_{k = 1,2,\dots,2N} { xˉk,yˉk}k=1,2,…,2N,其中, x ˉ