目录
摘要
本周学习了生成式对抗网络,了解了生成器和判别器的基础原理,同时学习了 GAN 的训练过程和算法思想。最后了解了 Wasserstein Distance,学习了 WGAN 的一点基础。
Abstract
This week, I learned about Generative Adversarial Networks and understood the basic principles of the generator and discriminator. I also studied the training process and the conceptual ideas behind GAN algorithms. Finally, I was introduced to the concept of Wasserstein Distance and gained some foundational knowledge about WGAN.
1.Generative Adversarial Network
1.1.Generator
之前我们学习到呢 network 是这样子的:输入一个 x x x 就会输出一个固定的 y y y。我们在这个原则上添加一个简单的分布,比如高斯分布。每次输入一个 x x x 的时候会从分布中采样一个值与 x x x 一起丢进 network 中,使得每次生成的 y y y 值都不固定,并且输出出来的值都是一个复杂的分布,我们把这种 network 称为 generator。
生成器(Generator):生成器的目标是生成尽可能接近真实数据的假数据。它通常是一个神经网络,通过学习真实数据的分布特征,生成新的数据样本。生成器的输入是随机噪声(通常来自高斯分布),输出是生成的数据样本。
那究竟为什么我们要使得 y y y 的输入不固定呢?
考虑这样一个例子:如上图这是一个小精灵游戏的视频,我们可以把这个视频的前几帧丢到一个 network 里面,然后预测出来它的下一帧。
在某个训练资料里面,小精灵到了一个特定的转角之后可能学习到的是向左转,但是在另一个训练资料里面,小精灵学习到的可能是向右转。又因为一个训练集里面可能同时包含着两个训练资料,这就导致了机器会出现“两面讨好”的现象,小精灵到了某个转角之后下一帧会同时向左向右分裂。
因此,当我们对网络加入一个分布的采样之后,对于一个固定的训练资料就能输出不一样的结果。
也就是说,一个固定的输入加上分布 z z z,可以获得更加多样性的输出。
特别是对于那些需要 “创造力” 的任务来说, 我们就非常需要 generative model 。其中,最著名的一个 generative model 就是 Generative Adversarial Network。
Generative Adversarial Network(GAN,生成对抗网络)是一种深度学习模型,它由两个主要部分组成:生成器(Generator)和判别器(Discriminator)。GANs的设计目的是为了让生成器学会模仿真实数据的分布,从而能够生成与训练数据相似的新样本。这两个组件通过“对抗”的方式共同进化,最终达到一种平衡状态,在这种状态下生成的数据难以被区分真假。
以上图为例,假设先不考虑输入 x x x ,输入到 generator 中的应该是从一个正态分布中采样的一个低维向量,通过 generator 之后将会产生一个代表不同图片的高维向量。
1.2.Discriminator
在 GAN 中还需要一个组件 discriminator (判别器),这个判别器也是一个神经网络,我们可以将生成器产生的图片输入判别器中生成一个数值(scalar),这个数值越大说明越接近真实值,越小说明越远离真实值。
1.3.Basic Idea of GAN
在生成二次元图像的例子中,第一代的生成器生成图像之后会交给第一代的判别器进行判别,判别器对比真实图像和生成图像的差距对生成图像进行鉴别。在训练过程中,生成器和判别器同时更新它们的参数并进行迭代进化。
生成器试图最大化判别器错误分类其生成样本的概率,即让判别器将其生成的数据识别为真实数据。相反,判别器试图最小化错误分类的概率,即准确地区分真实数据和生成数据。这种训练过程可以看作是一种“零和游戏”,其中生成器和判别器互相对抗,直到生成器能够生成几乎不可辨别的虚假数据为止。
1.4.Algorithm
1.4.1.Step 1:Fix generator G, and update discriminator D
第一步,我们应该先固定生成器 G 的参数,然后更新判别器 D 的参数。
例如上图,我们将随机采样的向量丢到生成器里产生四张图片,然后再从数据集里采样四张图片,最后训练判别器来区分真实图片和生成出来的图片。这个任务既可以看做是分类问题也可以看做是回归问题,即可以用两种方法来玩法。
1.4.2.Step 2:Fix discriminator D, and update generator G
第二步,我们应该固定判别器 D 的参数,然后更新生成器 G 的参数。
我们在这个阶段的任务是让生成器去“骗过”判别器。我们将一个采样出来的向量通过这个 GAN 网络之后会得到一个分数,此时我们通过调整生成器的参数(不调整判别器的参数)来使得我们最后得出来的分数越大越好。这个阶段的训练和之前学习的没什么不同,可以采用梯度下降法来求解。
最后就是反复进行这个过程:固定住 G 的参数,更新 D 的参数,然后再固定住 D 的参数,更新 G 的参数,一直迭代运行下去。
2.Theory behind GAN
我们将一个简单的分布 z z z 传入生成器后会产生一个新的分布 P G P_G PG 。假设真实的数据也是一个分布,令其为 P d a t a P_{data} Pdata。我们的任务就是让 P G P_G PG 和 P d a t a P_{data} Pdata 越接近越好。此时可以定义一个新的函数 G ∗ = a r g m i n D i v ( P G , P d a t a ) G^*=arg\ min\ Div(P_G,P_{data}) G∗=arg min Div(PG