Data Augmentation Generative Adversarial Networks
文章目录
摘要
神经网络的有效训练需要很多数据,在低数据情况下,参数是欠定的,学到的网络泛化能力差。
数据增强通过使用存在的数据有效的缓和了这个问题,但是标准的数据增强仅能产生有限的看似合理的用来代替的数据。鉴于有可能产生更广泛的增强数据,作者设计和训练了一种生成模型进行数据增强。
这种模型基于条件生成对抗模型,从源域获取数据,学习获取任何形式的数据并泛化成其他同类的数据。
由于这种生成过程不依赖于这些类别,所以它可能被用于新的未见过的数据种类。
作者展示了使用 D A G A N DAGAN DAGAN来很好地增强标准的 v a n i l l a vanilla vanilla分类器,同样也是展示了 D A G A N DAGAN DAGAN可以可以增强 f e w − s h o t few-shot few−shot学习系统,比如 M a t c h i n g N e t w o r k Matching\space Network Matching Network等。
作者在 O m n i g h t Omnight Omnight展示了这些方法,在 E M N I S T EMNIST EMNIST上展示的时候已经在omniglot和VGG脸部数据上学习了 D A G A N DAGAN DAGAN
实验中能够看到低数据体制实验的精确度。
M a t c h i n g N e t w o r k Matching\space Network Matching Network也有所提升。
1. 介绍
深度神经网络的长足发展离不开非常大的数据集。现实中,我们需要使用有限的数据集来实现目标,在这些情况下深度神经网络显得有些不足了,由于训练集上的过拟合以及测试集上糟糕的泛化能力。
在低数据体制中,很多用于控制过拟合的方法(BN、Dropout等)也显得不足,网络复杂度太高。这些方法不能利用一致的输入不变性,输入不变性可能形成良好的为参数学习提供信息的先验知识。
也可以通过对现有的数据进行多种变换(包括随机平移、旋转和翻转等)。来产生更多的数据。这些方法就利用了我们知道不会用影响类别的信息。这种技术不但对低数据体制来说非常重要,对任何规模的数据集都很重要。及时是在超大数据集,如 I m a g e N e t ImageNet ImageNet上训练的模型也能从这种方法中受益。
典型的数据增强技术使用了非常有限的已知不变性,这些不变性很容易被发现。
本文作者认识到可以通过学习一种来自不同域的条件生成对抗网络,学习到一个具有更大不变性空间的模型,一般称为源域。这种方法可以用于低数据的感兴趣域(目标域)。
作者展示了这样的 D A G A N DAGAN DAGAN能够开启有效的神经网络训练,即使在低数据目标域。因为 D A G A N DAGAN DAGAN不依赖于类本身,它捕捉跨类信息,将数据点移动到其他同类的点。结果是它能被用于新的未见过的种类, 图 2 图2 图2 说明了这一点
作者还展示了学习好的 D A G A N DAGAN DAGAN能被用于大幅度有效地提升 M a t c h i n g N e t w o r k s Matching\space Networks Matching Networks,它通过用从 D A G A N DAGAN DAGAN生成的每个类别的最相关比较点来增强 M a t c h i n g N e t w o r k s Matching\space Networks Matching Networks和相关模型中的数据,从而实现这一点。这与切线距离有关,但涉及到标定流形之间的距离(由DAGAN定义), 见 图 2 见图2 见图2。
作者训练了一个 D A G A N DAGAN DAGAN,并在低数据目标域评价其表现,使用:
- 标准随机梯度神经网络训练
- 特定的 一次性元学习方法
作者 使用了三个数据集: O m n i g l o t Omniglot Omniglot、 E M N I S T EMNIST EMNIST以及更复杂的 V G G VGG VGG面部数据集。
- 从Omniglot训练的DAGAN被用于未见过的Omniglot目标域,也用于EMNIST域来证明在非常不同的源域和目标域之间迁移所带来的提升。
- VGG面部数据集提供了更加有挑战性额测试,
生成的样本质量良好,通过 v a n i l l a vanilla vanilla分类器的分类任务以及 m a t c h i n g n e t w o r k matching\space network matching network上的提升。
本文的主要贡献如下:
- 使用新的生成对抗网络学一下数据增强的表示和处理
- 展示来自单个新数据点的真实数据增强
- 应用 D A G A N DAGAN DAGAN去增强一个低数据体制下的标准分类器,证明所有任务都在泛化性能上有所提升。
- D A G A N DAGAN DAGAN在元学习空间中的应用,展示性能比之前一般的元学习模型性能要好。
- 通过学习一种网络来产生任何给定测试案例中仅仅是最显著增强的样例,来一次性增强 m a t c h i n g n e t w o r k matching\space network matching network
2. 背景
迁移学习和数据集偏移:
一次性学习问题的数据集偏移的一个特殊情况, 图 2 图2 图2以图模型的方式进行了总结。数据集偏移是将协变量偏移的概念推广到多种域间变化的情况。
一次性学习是类分布的极端偏移,两个分布没有任何共享支持:旧类的概率为0,新类的概率从0到非0。但是有一种假设是类条件分布共享一些共性,所以信息能够从源域传输到一次性目标域。
生成对抗网络
生成对抗网络(GANs),特别是深度卷积生成对抗网络(DCGANs)利用分辨真假样本作为目标,GAN方法可以学到复杂的联合密度。
数据增强
在一个模型中将已知的不变性进行编码不是一件简单的事情。将这些不变性编码在数据中,比通过现有的数据进行变形产生额外的数据要见多。
例如手写字符的标签不会因为位置,小幅度旋转或者修剪、强度的改变、尺寸的变化等而改变。
几乎所有数据增强都来自于先验已知的不变性。
少样本学习和元学习
少样本学习有许多种方法
论文名 | 方法 |
---|---|
One-shot learning with ahierarchical nonparametric bayesian model | use a hierarchical Boltzmann machine, |
One-shot generalization in deep generative models. | modern deep learning architectures for one-shot conditional generation |
Towards a neural statistician | hierarchical variational autoencoders |
Generative adversarial residual pairwise networks for one shot learning | a GAN based one-shot generative model |
Siamese neural networks for one-shot image recognition. | the use of Siamese networks |
… | … |
3. 数据增强模型
如果我们知道类标签对于特定变换而言是不会改变的,那我们就可以用这个变换来产生额外的数据。
但是如果我们不知道什么样的变换才是有效且不改变类标签的呢?但是我们有和这个问题相关的其他数据,那么我们就可以从相关的其他数据中学习这种有效变换。间 图 1 图1 图1,这是元学习的一个例子;我们从其他问题中学习如何改进目标问题的学习。这也是本文的主题。
生成对抗网络是一种从与训练集密度匹配的密度生成样本的方法,这些方法通过最小化生成数据与真实数据指甲的分布差异来学习。
生成对抗网络学习的生成模型的形式:
z = N ~ ( 0 , I ) z=\tilde{N}(0,I) z=N~(0,I) (1)
v = f ( z ) v=f(z) v=f(z) (2)
f f f通过神经网络实现, v v v是生成的向量,z是隐含的高斯变量,负责提供所产生的变化。
一个生成对抗网络可以被看成是学习刻画数据流形的变换: z = 0 z=0 z=0给出流形上的一个点,在每个不同的方向上不断改变的 z z z刻画出数据流形。
3.1 数据增强生成对抗模型。
推广的生成对抗网络也能用于刻画数据增强流形。
给定点x,我们可以学习输入x的表征: r = g ( x ) r=g(x) r=g(x),沿用之前的生成模型,组合模型的形式:
r = g ( x ) r=g(x) r=g(x)
z = N ~ ( 0 , I ) z=\tilde{N}(0,I) z=N~(0,I) (3)
x = f ( z , r ) x=f(z,r) x=f(z,r)
现在神经网络 f f f将表征 r r r和随机量 z z z作为输入,给定任何新的 x ∗ x^* x∗,我们可以
- 获得样本点普适性地有意义的表征: r ∗ = g ( x ∗ ) r*=g(x*) r∗=g(x∗),囊括了需要生成的数据的信息
- 生成了额外的增强数据 v 1 ∗ , v 2 ∗ … … v^*_1 , v^*_2…… v1∗,v2∗…… ,补充了原始的 x ∗ x^* x∗,可被用于分类器。先从高斯分布中取样 z z z,然后使用生成网络产生增强样本。
本文所使用的数据增强模型的精确形式如 图 3 图3 图3所示,模型结构的大纲在后详述。
3.2 学习
数据增强模型(3)能在源域上使用对抗方法进行学习。考虑一个由数据 D = x 1 , x 2 . . . x N D D={x_1,x_2...x_{N_D}} D=x1,x2...xND组成的源域,和对应的标签值: t 1 , t 2 . . . t N D {t_1,t_2...t_{N_D}} t1,t2...tND,作者使用改进的 W G A N WGAN WGAN辨别器,其输入:
(a)或者是一些输入样本点 x i x_i xi和来自相同类别的第二个样本点: x j x_j xj,有 t i = t j t_i=t_j ti=tj
(b)或者是一些输入样本点 x i x_i xi和以 x i x_i xi为输入的生成器的输出 x g x_g xg,
辨别器尝试着去辨别从真实数据 ( a ) (a) (a)生成的数据 ( b ) (b) (b)。生成器被训练用于最小化这种辨别能力,这种能力的度量是通过 W a s s e r s t e i n Wasserstein Wasserstein 距离。
要强调一下向辨别器提供原始 x x x的重要性。这是为了想要确保生成器能够生成不同的数据,而这些数据域当前的数据既相关又不同。通过提供当前数据的信息给辨变器,可以防止 G A N GAN GAN将当前数据简单的自动编码。同时我们又不提供类别信息,所以生成器不得不学着去泛化以和所有类的分布保持一致。
4. 结构
4.1 生成器结构:
-
使用的 G A N GAN GAN生成器是 U N e t UNet UNet和 R e s N e t ResNet ResNet的组合,作者将其称为 U R e s N e t UResNet UResNet。
-
U R e s N e t UResNet UResNet生成器由8块,每块有 4 4 4个卷积层(和l e a k y R e L U eaky ReLU eakyReLU激活函数以及 b a t c h r e n o r m a l i z a t i o n batch renormalization batchrenormalization),后面跟着下采样或者上采样层。
-
(1-4块)中的下采样层使用步长为2的卷积后接 l e a k y R e L U leaky ReLU leakyReLU、 b a t c h r e n o r m a l batchrenormal batchrenormal以及 d r o p o u t dropout dropout。
-
上采样层是步长 1 / 2 1/2 1/2复制子,后跟卷积层,leaky ReLU、 b a t c h r e n o r m a l batchrenormal batchrenormal和 d r o p o u t dropout dropout
对于Omniglot和EMNIST实验,所有层都有64个滤波器。而VGG面部实验,编码器的前两个块和解码器的后两个块有64个滤波器,编码器的后两个块和解码器的前两个块有128个滤波器。
- 另外UResNet生成器的每个块都有跳跃连接。和标准的ResNet一样,一个带步长的1\times 1卷积也能在块之间传递信息,绕开了块之间的非线性来帮助梯度流。最后相同尺寸的滤波器之间也加入了跳跃连接(和UNet一样),图如 附 录 A 附录A 附录A图6所示
4.2 辨别器结构
作者使用DenseNet辨别器,使用 l a y e r n o r m a l i z a t i o n layer\space normalization layer normalization代替 b a t c h n o r m a l i z a t i o n batch\space normalization batch normalization,因为后者会破坏 W G A N WGAN WGAN目标函数的假设条件。
5. 数据集
作者在三个数据集上测试了 D A G A N DAGAN DAGAN增强: O m n i g l o t Omniglot Omniglot, E M N I S T EMNIST EMNIST和 V G G VGG VGG面部数据集。所有的数据集被随机分为源域集、验证域集和测试域集。
对于分类网络,每种特征的所有数据更进一步被分为2个训练case,3个验证case和数量随着实验变化的测试case。
分类器的训练是在所有域的所有样本的训练case上完成的,超参数的选择是在是在验证case上完成的。最终尽在目标域的测试case上报告测试性能。
Case分割每次测试运行时随机划分。
对于一次性学习来说, D A G A N DAGAN DAGAN训练是在源域上完成的,元学习是在源域上完成的,在验证域上进行验证,在目标域上展示结果。然后再在目标域上的训练case训练,然后用目标域上的测试case展示结果。
5.1 Omnigloy
分成源域、验证域和目标域。类顺序被打乱,从而源域和目标域包含多样性的样本。
5.2 EMNIST
分成源域、验证域、和目标域。源域包含类 0 − 34 0-34 0−34(随机打乱),验证域包含类35-42,测试域包含类42-47
每一类只选择 100 100 100个样本的子集来保证低数据体制。
5.3 VGG-Face
每类随机选择 100 100 100个样本,排除脏数据后,从而在数据集中获得了完整的 2622 2622 2622个类中的 2396 2396 2396个。前1802 个 个 个用于源域集, 1803 − 2300 1803-2300 1803−2300用于测试域集, 2300 − 2396 2300-2396 2300−2396用于验证域。
6. DAGAN训练和生成
6.1 在源域训练DAGAN
用多种结构训练 D A G A N DAGAN DAGAN,发现越强劲的网络所构成的生成器效果越好, U R e s N e t UResNet UResNet是作者选择的模型。 图 4 图4 图4展示了强劲结构所生成的改进的变化。在 V G G − F a c e VGG-Face VGG−Face上训练生成的实例如 图 5 图5 图5所示。
6.2 Vanilla 分类器
第一个实验测试 D A G A N DAGAN DAGAN对于 v a n i l l a vanilla vanilla分类器的增强效果怎么样。
-
一个 D e n s e N e t DenseNet DenseNet分类器先仅在真实数据上训练(用标准数据增强),每类 5 5 5个、 10 10 10个或15个例子。
-
第二步,分类器还被传入 D A G A N DAGAN DAGAN生成的增强数据。真假标签也被传入网络层,让网络能够学习如何强调真实而非生成数据。这最后一步对于最大限度发挥DAGAN增强至关重要。
在每个训练周期内,为每个真实样本提供不同数量的增强样本 ( 1 − 10 ) (1-10) (1−10)。最好的注释率通过验证域上的性能被选择。来自目标域的保留测试结果在 表 1 表1 表1种给出。每一种情况增强都会提升分类性能。
7. 结论
数据增强是一种应用广泛用来提升低数据体制下模型性能的方法,DAGAN是一种用来自动学习增强数据的灵活模型。但是除此之外,作者展示了 D A G A N DAGAN DAGAN提升了分类器的性能,即便是在标准的数据增强后。数据增强在所有模型和方法间的通用性意味着 D A G A N DAGAN DAGAN能够成为一个低数据条件下有价值的添加剂。