FixMatch:一致性正则与伪标签方法在SSL中的最佳实践

谷歌研究团队的FixMatch算法简化SSL,结合一致性正则和伪标签,使用弱强数据增强。通过飞桨实现,FixMatch在CIFAR-10上展现SOTA性能,核心在于伪标签筛选和优化策略。

46591dc0c696894160a21d70f9b768d6.gif

0e548e2e80018f65017f5ad7b59acf2b.png

背景

半监督学习(SSL)提供了一种利用无标签数据提高模型性能的有效方法,这一领域最近取得了快速进展,但以往的算法需要借助复杂的损失函数和大量难以调整的超参数。本文介绍了谷歌的研究团队提出的FixMatch[1],这是一种大大简化现有 SSL 方法的算法。FixMatch是SSL的两种方法的组合:一致性正则和伪标签。

ab9e3af2fb8f6721c0feb0fbb3bf78dc.png

如图所示为FixMatch的流程图。FixMatch的新颖之处在于,对于无标签的样本:

  • FixMatch首先对弱增强的无标签图像预测伪标签,对于给定的图像,只有当模型产生高于阈值的预测时,才会保留作为伪标签;

  • 再对同一图像的强增强版本预测出分类概率,通过交叉熵损失衡量强弱二者的预测的一致性。

我们首先回顾一下利用了一致性正则和伪标签方法的经典SSL算法,然后再详细描述FixMatch算法,并穿插关键部分的飞桨代码实现(采用飞桨最新稳定版本)。

一致性正则

一致性正则是许多SSL算法的重要组成部分。一致性正则的思想是——即使在无标签的样本被注入噪声之后,分类器也应该为其输出相同的类分布概率。即强制一个无标签的样本  应该被分类为与自身的增强  相同的分类[2]。

伪标签

指使用模型本身为无标签数据获取标签的方法。具体而言,将模型输出的softmax概率分布视为软伪标签;或将经过argmax或者one_hot得到的预测视为硬伪标签。利用这些伪标签作为监督损失进一步训练模型。

SSL算法的比较

下表提到了与生成伪标签(Artificial label)相关的SSL算法。其中列出了用于伪标签的数据增强、模型的预测以及应用于伪标签的后处理。

a0847c21c9416700d2a46f707656e9a8.png

FixMatch的核心是一致性正则和伪标签方法的简单组合,无标签模型预测与UDA一样采用RandAugment[3]进行强增强,详细实现见AI Studio项目。

3da9723d8628f10454dc252640195e0c.png

FixMatch

FixMatch的损失函数

FixMatch的损失函数由两个交叉熵损失项组成:一个是应用于有标签数据的全监督损失,另一个是用于无标签数据的一致性正则损失。

76e01823e918967d86d273dd24c3da05.png表示batch size为  的有标签样本,其中  是训练样本,  是相应的one-hot标签。

0bb6099dc02e33c3fd187151059f4402.png是batch size为  μ 的无标签样本,其中 μ 是决定60d1a1a1d3f3edcf29596e09ffb5eaab.pngf7cd90dafe11bc9a100c9fdddcabbeff.png的数量关系的超参数。令  表示输入样本  时模型输出的类概率分布。将两个概率分布  和  之间的交叉熵表示为  。强增强表示为  ,弱增强表示为 α 。

  • 对于有标签样本,FixMatch均采用弱增强,其损失函数为:

  • </
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值