MixMatch:半监督学习的整体方法
Abstract
半监督学习已被证明是一个强大的范例,利用未标记的数据,以减轻对大型标记数据集的依赖。在这项工作中,我们统一了目前半监督学习的主流方法,提出了一种新的算法MixMatch,该算法为数据增强的未标记样本猜测低熵标签,并使用MixUp混合标记和未标记的数据。MixMatch通过大量的数据集和标记的数据量获得最先进的结果。例如,在具有250个标签的CIFAR-10上,我们将错误率降低了4倍(从38%降低到11%),在STL-10上降低了2倍。我们还演示了MixMatch如何帮助在差异隐私中实现更精确的隐私权衡。最后,我们进行了一项消融研究来梳理MixMatch的哪些成分对其成功最重要。我们发布了实验中使用的所有代码: https://github.com/google-research/mixmatch。
1 Introduction
最近在训练大型深层神经网络方面取得的成功很大程度上要归功于大型标记数据集的存在。然而,对于许多学习任务来说,收集标记数据是昂贵的,因为它必然涉及专家知识。这也许是最好的说明,在医疗任务中,测量需要昂贵的机器和标签是一个耗时的分析,从多个人类专家的成果。此外,数据标签可能包含私人信息。相比之下,在许多任务中,获取未标记的数据要容易或便宜得多。
半监督学习[6](SSL)通过允许模型利用未标记的数据,在很大程度上缓解了对标记数据的需求。最近的许多半监督学习方法都增加了一个损失项,该损失项是在未标记数据上计算的,并鼓励模型更好地推广到不可见数据。在最近的工作中,这个损失项分为三类(在第2节中进一步讨论):熵最小化[18,28]——它鼓励模型对未标记的数据输出有信心的预测;一致性正则化,它鼓励模型在输入受到扰动时产生相同的输出分布;以及泛型正则化,使模型具有良好的泛化能力,避免了训练数据的过度拟合。
在本文中,我们介绍了MixMatch,一个SSL算法,它引入了一个单一的损失,优雅地统一了这些主要方法的半监督学习。与以前的方法不同,MixMatch一次以所有属性为目标,我们发现这有以下好处:
- 实验表明,MixMatch在所有标准图像基准上都获得了最先进的结果(第4.2节),并将CIFAR-10上的错误率降低了4倍;
- 我们在一项消融研究中进一步表明,MixMatch大于其各部分之和;
- 我们在第4.3节中证明了MixMatch对于差异化的隐私学习是有用的,使学生在PATE框架[36]中获得新的最新成果,同时加强隐私保障和准确性。
简而言之,MixMatch为未标记的数据引入了一个统一的损失项,它无缝地减少了熵,同时保持了一致性,并与传统的正则化技术保持兼容。
2 Related Work
为了设置MixMatch的阶段,我们首先介绍现有的SSL方法。我们主要关注那些目前最先进的技术,以及MixMatch的基础;有很多关于SSL技术的文献我们不在这里讨论(例如,“转换”模型[14,22,21],基于图的方法[49,4,29],生成建模[3,27,41,9,17,23,38,34
,42]等)。更全面的概述见[49,6]。下面,我们将引用一个通用模型
p
m
o
d
e
l
(
y
∣
x
;
θ
)
p_{model}(y | x;θ)
pmodel(y∣x;θ),对于参数为
θ
θ
θ的输入
x
x
x,它在类标签
y
y
y上产生一个分布。
2.1 Consistency Regularization(一致性正则化)
在有监督学习中,一种常见的正则化技术是数据增强,它应用假设的输入变换,使类的语义不受影响。例如,在图像分类中,通常对输入图像进行弹性变形或添加噪声,这可以显著地改变图像的像素内容而不改变其标签[7,43,10]。粗略地说,这可以通过生成近乎无限的新的、修改过的数据流来人为地扩大训练集的大小。一致性正则化将数据增强应用到半监督学习中,它利用了这样一种思想,即对于一个未标记的例子,即使在它被增强之后,分类器也应该输出相同的类分布。更正式地说,一致性正则化强制一个未标记的示例x应该被分类为与 A u g m e n t ( x ) Augment(x) Augment(x)相同的类, A u g m e n t ( x ) Augment(x) Augment(x)是其自身的一种扩充。
在最简单的情况下,对于未标记的点
x
x
x,先前的工作[25,40]添加了损失项:
请注意,
A
u
g
m
e
n
t
(
x
)
Augment(x)
Augment(x)是一个随机变换,因此等式(1)中的两项不相同。”Mean Teacher“[44]使用模型参数值的指数移动平均值,将公式(1)中的一个术语替换为模型的输出。这提供了一个更稳定的目标,并被发现经验显着改善结果。这些方法的一个缺点是它们使用特定领域的数据扩充策略。”虚拟对抗训练“[31](VAT)通过计算附加扰动来解决这一问题,该扰动应用于最大程度改变输出类分布的输入。MixMatch通过使用图像的标准数据增强(随机水平翻转和裁剪)。
2.2 Entropy Minimization(熵最小化)
在许多半监督学习方法中,一个共同的基本假设是分类器的决策边界不应通过边缘数据分布的高密度区域。一种方法是要求分类器对未标记的数据输出低熵预测。这是在 [18] 中明确完成的,使用损失项最小化未标记数据 x x x的 p m o d e l ( y ∣ x ; θ ) p_{model}(y | x; θ) pmodel(y∣x;θ) 的熵。这种熵最小化的形式与[31]中的VAT相结合,以获得更强的结果。”伪标签“[28]通过从未标记数据的高置信度预测中构造硬(1-hot)标签并将其用作标准交叉熵损失中的训练目标,隐式地实现熵最小化。MixMatch还通过对未标记数据的目标分布使用“锐化”函数隐式实现熵最小化,如第3.2节所述。
2.3 Traditional Regularization(传统正则化)
正则化是指对模型施加约束以使其更难记忆训练数据的一般方法,因此有望使其更好地泛化到看不见的数据 [19]。 我们使用权重衰减来惩罚模型参数的 L2 范数 [30, 46]。 我们还在 MixMatch 中使用 MixUp [47] 来鼓励“在”样本之间的凸行为。 我们将 MixUp 用作正则化器(应用于标记数据点)和半监督学习方法(应用于未标记数据点)。 MixUp 之前已经应用于半监督学习; 特别是,[45] 的并发工作使用了 MixMatch 中使用的方法的一个子集。 我们澄清了消融研究中的差异(第 4.2.3 节)。
3 MixMatch
在本节中,我们将介绍我们提出的半监督学习方法MixMatch。MixMatch是一种“整体”方法,它结合了第2节中讨论的SSL主流范例的思想和组件。给定一批带有一个热目标的标记样本
X
X
X(代表 L 个可能的标签之一)和一个相同大小的批次
U
U
U 的未标记样本,MixMatch 生成一批经过处理的增强标记样本
X
′
X'
X′和一批具有 “猜测”标签的增强未标记样本
U
′
U'
U′。然后用
U
′
U'
U′和
X
′
X'
X′分别计算有标记和无标记的损失项。更正式地说,半监督学习的组合损失L定义为:
其中
H
(
p
,
q
)
H(p,q)
H(p,q)是分布
p
p
p和
q
q
q之间的交叉熵,
T
,
K
,
α
,
T,K,α,
T,K,α, 以及
λ
u
λ_u
λu是下面描述的超参数。算法1中提供了完全混合匹配算法,图1中给出了标签猜测过程的示意图。接下来,我们描述了混合匹配的各个部分。
3.1 Data Augmentation(数据扩充)
正如许多SSL方法中的典型情况一样,我们对标记和未标记的数据都使用数据扩充。对于一批标记数据 X X X中的每个 x b x_b xb,我们生成一个转换版本 x b ^ \hat{x_b} xb^= A u g m e n t ( x b ) Augment(x_b) Augment(xb) (算法1,第3行)。对于一批未标记数据 U U U中的每个 u b u_b ub,我们生成 K K K个扩充 u b , k ^ \hat{u_{b,k}} ub,k^= A u g m e n t ( u b ) Augment(u_b) Augment(ub), k ∈ ( 1 , . . . , K ) k∈ (1, . . . , K) k∈(1,...,K) (算法1,第5行)。我们使用这些单独的扩充来为每个 u b u_b ub生成一个“猜测标签” q b q_b qb,这个过程我们将在下面的小节中描述。
3.2 Label Guessing(标签猜测)
对于
U
U
U中的每个未标记样本,MixMatch使用模型的预测生成样本标签的“猜测”。这种猜测后来用于无监督损失术语。为了做到这一点,我们在算法1第7行计算了
u
b
u_b
ub的所有
K
K
K个扩充的模型预测类分布的平均值。
在一致性正则化方法中,使用数据增强来获得未标记样本的人工目标是常见的[25,40,44]。
Sharpening. 在生成标签猜测的过程中,我们执行了另一个步骤,其灵感来自于半监督学习中熵最小化的成功(在第2.2节中讨论)。给出了在扩充上的平均预测
q
b
ˉ
\bar{q_b}
qbˉ、 我们使用锐化函数来降低标签分布的熵。在实践中,对于锐化函数,我们使用调整这个分类分布的“温度”的常用方法[16],这被定义为操作
其中
p
p
p是一些输入分类分布(特别是在 MixMatch 中,p 是在扩充上的平均类别预测
q
b
ˉ
\bar{q_b}
qbˉ,如算法 1,第 8 行所示),
T
T
T 是超参数。 作为
T
→
0
T → 0
T→0,
S
h
a
r
p
e
n
(
p
,
T
)
Sharpen(p, T)
Sharpen(p,T) 的输出将接近Dirac(“one-hot”)分布。 由于我们稍后将使用
q
b
=
S
h
a
r
p
e
n
(
q
b
ˉ
,
T
)
q_b = Sharpen(\bar{q_b}, T)
qb=Sharpen(qbˉ,T) 作为模型预测
u
b
u_b
ub扩充的目标,因此降低温度会鼓励模型产生低熵预测。
3.3 MixUp
我们使用MixUp进行半监督学习,与过去的SSL工作不同,我们将标签样本和未标签样本与标签猜测(如第3.2节所述生成)混合。为了与我们单独的损失项相兼容,我们定义了一个稍微修改的MixUp版本。对于一对具有相应标签概率
(
x
1
,
p
1
)
,
(
x
2
,
p
2
)
(x_1,p_1),(x_2,p_2)
(x1,p1),(x2,p2)的两个例子,我们通过
其中
α
α
α 是超参数。Vanilla MixUp省略了公式(9)(即,它设定
λ
′
=
λ
λ'= λ
λ′=λ). 假设标记和未标记的样本在同一批中串联在一起,我们需要保持批的顺序以适当地计算单个损失分量。这可通过等式(9)实现,其确保
x
′
x'
x′更接近
x
1
x_1
x1而不是
x
2
x_2
x2。为了应用混合,我们首先收集所有带标签的扩充标记样本和所有带猜测标签的未标记样本
(算法 1,第 10-11 行)。 然后,我们组合这些集合并将结果打乱以形成 W W W,它将作为 MixUp 的数据源(算法 1,第 12 行)。对于 X ^ \hat{X} X^中的每个第 i 个样本-标签对,我们计算 MixUp( X i ^ \hat{X_i} Xi^, W i W_i Wi) 并将结果添加到集合 X ′ X' X′(算法 1,第 13 行)。我们计算 U i ′ = M i x U p ( U i ^ , W i + ∣ X ^ ∣ ) U'_i = MixUp(\hat{U_i}, W_{i+| \hat{X}|}) Ui′=MixUp(Ui^,Wi+∣X^∣) i ∈ ( 1 , . . . , ∣ U ^ ∣ ) i∈ (1, . . . , |\hat{U} |) i∈(1,...,∣U^∣),故意使用 X ′ X' X′构造中未使用的 W W W的剩余部分(算法1,第14行)。总而言之,MixMatch 将 X X X 转换为 X ′ X' X′,这是一组应用了数据扩充和 MixUp(可能与未标记的样本混合)的标记样本。 类似地, U U U 被转换为 U ′ U' U′,这是每个未标记样本的多个扩充的集合,带有相应的标签猜测。
3.4 Loss Function(损失函数)
给定我们处理过的批次 X ′ X' X′ 和 U ′ U' U′ ,我们使用等式中所示的标准半监督损失。 (3) 至 (5)。 等式 (5) 将来自 X ′ X' X′的标签和模型预测之间的典型交叉熵损失与来自 U ′ U' U′ 的预测和猜测标签的平方 L 2 L_2 L2 损失相结合。 我们在方程中使用这个 L 2 L_2 L2 损失。 (4)(多类 Brier 分数 [5]),因为与交叉熵不同,它是有界的,对错误预测不太敏感。 出于这个原因,它经常被用作 SSL 中未标记的数据丢失 [25, 44] 以及预测不确定性的度量 [26]。 我们不通过计算猜测的标签来传播梯度,这是标准的 [25, 44, 31, 35]
3.5 Hyperparameters(超参数)
由于 MixMatch 结合了多种利用未标记数据的机制,它引入了各种超参数——特别是锐化温度 T T T、未标记扩充的数量 K K K、MixUp 中 Beta 的 α α α 参数和无监督损失权重 λ u λ_u λu。 在实践中,具有许多超参数的半监督学习方法可能会出现问题,因为小验证集很难进行交叉验证 [35, 39, 35]。 然而,我们在实践中发现 MixMatch 的大部分超参数都可以固定,不需要在每个实验或每个数据集的基础上进行调整。 具体来说,对于所有实验,我们设置 T = 0.5 T = 0.5 T=0.5 和 K = 2 K = 2 K=2。此外,我们仅在每个数据集的基础上更改 α α α 和 λ u λ_u λu; 我们发现 α α α = 0.75 和 λ u λ_u λu = 100 是很好的调整起点。 在所有实验中,我们在训练的前 16,000 步中将 λ u λ_u λu 线性增加到其最大值,这是常见的做法 [44]。
4 Experiments
我们在标准SSL基准上测试MixMatch的有效性(第4.2节)。我们的消融研究梳理了MixMatch每个成分的贡献(第4.2.3节)。作为一个附加应用,我们在第4.3节中考虑了隐私保护学习。
4.1 Implementation details(实施细则)
除非另有说明,在所有实验中,我们都使用[35]中的“宽ResNet-28”模型。我们对模型和训练过程的实现与[35]中的非常接近(包括使用5000个样本来选择超参数),除了以下区别:首先,我们没有衰减学习率,而是使用衰减率为0.999的参数指数移动平均值来评估模型。其次,我们对Wide ResNet-28模型在每次更新时应用0.0004的权重衰减。最后,我们对每 2 16 2^{16} 216个训练样本进行检查,并报告最后20个检查点的中间错误率。例如,通过平均检查点[2]或选择具有最低验证错误的检查点,这以潜在的准确性成本简化了分析。
4.2 Semi-Supervised Learning (半监督学习)
首先,我们评估了MixMatch在四个标准基准数据集上的有效性:CIFAR-10和CIFAR-100[24]、SVHN[32]和STL-10[8]。在前三个数据集上评估半监督学习的标准实践是将大多数数据集视为未标记数据,并使用一小部分作为标记数据。STL-10是一个专门为SSL设计的数据集,有5000个标记图像和100000个未标记图像,这些图像的分布与标记数据略有不同。
4.2.1 Baseline Methods
作为基线,我们考虑[35]中考虑的四种方法(Π-模型[25,40]、平均教师[44]、虚拟对抗训练[31]和伪标签[28]),这些都在第2节中进行了描述。我们还使用混合[47]作为基线。MixUp被设计为监督学习的正则化器,因此我们将其应用于SSL,将其同时应用于扩充标记样本和扩充未标记样本及其相应的预测。根据混合的标准用法,我们在混合产生的猜测标签和模型预测之间使用了交叉熵损失。正如[35]所倡导的,我们在相同的代码库中重新实现了这些方法中的每一种,并将它们应用于相同的模型(如第4.1节所述),以确保公平的比较。我们重新调整了每个基线方法的超参数,与[35]中的方法相比,这通常会导致边缘精度的提高,从而为测试MixMatch提供了更具竞争力的实验设置。
4.2.2 Results
CIFAR-10 对于CIFAR-10,我们使用从250到4000的不同数量的标记样本来评估每种方法的准确性(这是标准实践)。结果如图2所示,对于CIFAR-10,我们使用 λ u λ_u λu=75。我们为每个标记点创建了5个分割,每个具有不同的随机种子。每个模型在每次分割时进行训练,错误率通过分割的平均值和方差来报告。我们发现MixMatch比其他方法有显著的优势,例如在4000个标签上达到6.24%的错误率。作为参考,在同一模型上,对5万个样本进行全监督训练,错误率为4.17%。此外,MixMatch仅使用250个标签就获得了11.08%的错误率。相比之下,在250个标签时,次优方法(VAT[31])的错误率为36.03,超过4.5× 考虑到4.17%是我们的模型在完全监督学习下获得的误差极限,比MixMatch更高。此外,在4000个标签上,次优的方法(Mean Teacher[44])获得了10.36%的错误率,这表明MixMatch可以在只有1/16个标签的情况下获得类似的性能。我们认为,最有趣的比较是与很少的标记数据点进行比较,因为它揭示了方法的样本效率,这是SSL的核心。
CIFAR-10 and CIFAR-100 with a larger model 一些先前的工作[44,2]也考虑了使用更大的2600万参数模型。我们的基本模型,如[35]所用,只有150万个参数,这使得与这些结果的比较很混乱。为了与这些结果进行更合理的比较,我们测量了增加基础ResNet模型宽度的效果,并评估了MixMatch在28层宽ResNet模型上的性能,该模型每层有135个过滤器,产生2600万个参数。我们还评估了带有10000个标签的CIFAR-100上这个更大模型的MixMatch,以与[2]的相应结果进行比较。结果如表1所示。一般来说,MixMatch匹配或优于[2]中的最佳结果,尽管我们注意到,由于[44,2]中的模型也使用了更复杂的“shake-shake”正则化[15],因此比较仍然存在问题。对于这个模型,我们使用了0.0008的权重衰减。我们曾经对于CIFAR-10使用 λ u λ_u λu=75和对于CIFAR-100, λ u λ_u λu=150。
SVHN and SVHN+Extra 与CIFAR-10一样,我们使用250到4000个不同的标签数来评估SVHN上每个SSL方法的性能。作为标准实践,我们首先考虑将73257样本训练集分为标记和未标记数据的设置。结果如图3所示。
我们使用
λ
u
λ_u
λu=250。这里再次对模型进行了评估,对每个标记点的数量进行5次分割,每个标记点具有不同的随机种子。我们发现MixMatch的性能在所有数量的标记数据中都是相对稳定的(并且比所有其他方法都好)。令人惊讶的是,经过额外的调整后,我们能够从普通教师那里获得非常好的表现[44],尽管它的错误率始终略高于MixMatch。
注意,SVHN有两个训练集:train和extra。在完全监督学习中,两个集合被连接起来形成完整的训练集(604388个样本)。在SSL中,由于历史原因,多余的设置被放在一边,只使用了train(73257个样本)。我们认为,利用train和extra来处理未标记数据更有趣,因为它显示出比标记样本更高的未标记样本比率。我们在表3中报告了SVHN和SVHN+Extra的错误率。对于我们使用的SVHN+额外 α = 0.25 , λ u = 250 α = 0.25, λ_u=250 α=0.25,λu=250,由于可用数据量较大,重量衰减较低,为0.000002。我们发现,在两个训练集上,MixMatch几乎立即匹配同一训练集上的完全监督性能–例如,MixMatch在SVHN+Extra上仅使用250个标签时,错误率为2.22%,而完全监督性能为1.71%。有趣的是,在SVHN+Extra-MixMatch上,在没有额外(2.59%的误差)的情况下,在SVHN上,每个标记的数据量都优于完全监督训练。为了强调这一点的重要性,请考虑以下场景:您有来自SVHN的73257个示例,其中250个示例已标记,并且可以选择:您可以获得8个× 更多未标记数据并使用MixMatch或获取293× 更多的标记数据和使用完全监督学习。我们的结果表明,获得额外的未标记数据和使用MixMatch更有效,这可能比获得293更便宜× 更多标签。
STL-10 STL-10包含5000个训练样本,旨在与10个预定义折叠一起使用(我们仅使用前5个),每个1000个示例。然而,之前的一些工作对所有5000个样本进行了培训。因此,我们比较了两种实验设置。MixMatch拥有1000个样本,超过了1000个样本的最新技术,也超过了使用所有5000个标记示例的最新技术。请注意,表2中没有一条基线使用相同的实验装置(即模型),因此很难直接比较结果;但是,由于MixMatch以2的因子获得最小的误差,因此我们认为这是对我们的方法的信任投票。我们使用
λ
u
λ_u
λu=50。
4.2.3 Ablation Study(消融研究)
由于MixMatch结合了各种半监督学习机制,因此它与文献中现有的方法有很多共同点。因此,我们研究移除或添加组件的效果,以便进一步深入了解MixMatch的性能。具体来说,我们衡量
- 使用K个扩充的平均类分布或使用单个扩充的类分布(即设置K=1)
- 消除温度锐化(即设置T=1)
- 在产生猜测标签时使用模型参数的指数移动平均(EMA),正如Mean Teacher所做的那样[44]
- 只在有标签的样本和无标签的样本之间进行MixUp,而不在有标签和无标签的样本之间进行
- 使用插值一致性训练[45],这可以看作是本消融研究的一个特例,其中只使用未标记的混合,不应用锐化,使用EMA参数进行标签猜测
对250和4000个标签的CIFAR-10进行消融;结果如表4所示。我们发现每种成分对MixMatch的性能都有贡献,其中250标签设置的差异最大。尽管Mean Teacher’s SVHN(图3)的有效性,我们发现使用类似的EMA参数值略微有损害MixMatch的性能。
4.3 Privacy-Preserving Learning and Generalization(隐私保护学习与推广)
隐私学习可以让我们衡量我们的方法的概括能力。实际上,保护训练数据的隐私等同于证明模型没有过度拟合:一个学习算法如果添加、修改,或者,移除任何训练样本都保证不会导致所学习的模型参数出现统计上的显著差异[13]。因此,在实践中,利用不同隐私进行学习是一种正规化的形式[33]。每个训练数据访问都构成了潜在的隐私泄露,编码为输入及其标签对。因此,从私有训练数据进行深度学习的方法,如DP-SGD[1]和PATE[36],在计算模型参数的更新时,受益于访问尽可能少的标记私有训练点。半监督学习自然适合这种环境。
我们使用PATE框架进行隐私学习。学生是在半监督的方式下从公共的未标记数据中训练出来的,其中一部分是由一组教师标记的,这些教师可以访问私有的标记训练数据。学生要求达到固定准确度的标签越少,它提供的隐私保障就越强。教师使用嘈杂的投票机制来回应学生的标签查询,当他们无法达成足够强烈的共识时,他们可能会选择不提供标签。因此,如果MixMatch改进了PATE的性能,那么它还可以从每个类的几个典型示例中说明MixMatch改进的泛化。
我们比较了MixMatch与SVHN上的VAT[31]基线的准确性和隐私权衡。VAT达到了之前最先进的91.6%的测试准确率,隐私损失为ε = 4.96 [37]。因为MixMatch在标记点较少的情况下表现良好,所以它能够达到95.21± 测试8的准确率为0.17%,隐私损失小得多ε = 0.97。因为 e ε e^ε eε 是用来衡量隐私度的,改进幅度约为 e 4 ≈ 55 × e^4≈ 55× e4≈55×, 一个显著的进步。失去隐私ε 低于1意味着一个更强大的隐私保障。请注意,在“私人训练”设置中,学生模型总共只使用10000个样本。
5 Conclusion
我们介绍了MixMatch,这是一种半监督学习方法,它结合了当前SSL主流模式的思想和组件。通过对半监督和隐私保护学习的大量实验,我们发现MixMatch在我们研究的所有环境中都比其他方法表现出显著的性能改进,错误率通常减少两倍或更多。在未来的工作中,我们有兴趣将半监督学习文献中的其他想法整合到混合方法中,并继续探索哪些组件可以产生有效的算法。另外,大多数关于半监督学习算法的现代工作是在图像基准上进行评估的;我们有兴趣探索MixMatch在其他领域的有效性。