【研究生工作周报】(第七周)

本文详细介绍了生成对抗网络GAN的基础概念,包括其构成(生成器与判别器)、训练过程、价值函数推导,以及KL散度和极大似然估计的应用。通过实例演示和收敛性分析,展示了GAN如何生成逼真的数据分布。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

学习目标:

  • GAN对抗生成网络
  • 《深度学习》第六章前馈神经网络
  • 《深度学习》第七章深度学习中的正则化

学习内容:

  1. Ian J. Goodfellow《Generative Adversarial Nets》
  2. 李宏毅深度学习课程关于GAN的讲解
  3. GAN算法原理
  4. GAN训练过程
  5. GAN价值函数推导过程

学习时间:

  • 6.19~6.25

学习产出:

  • 优快云 技术博客 1 篇
  • github代码

GAN概述

  生成对抗网络GAN(Generative adversarial nets)是由Goodfellow等人于2014年提出的基于深度学习模型的生成框架,可用于多种生成任务。从名称也不难看出,在GAN中包括了两个部分,分别为”生成”和“对抗”,整两个部分也分别对应了两个网络,即生成网络(Generator)GGG和判别网络(Discriminator)DDD,为描述简单,以图像生成为例:
在这里插入图片描述

  生成网络(Generator)GGG用于生成图片,其输入是一个随机的噪声z\boldsymbol{z}z,通过这个噪声生成图片,记作G(z)G\left ( \boldsymbol{z} \right )G(z)
  判别网络(Discriminator)DDD用于判别一张图片是否是真实的,对应的,其输入是一整图片x\boldsymbol{x}x,输出D(x)D\left ( \boldsymbol{x} \right )D(x)表示的是图片x\boldsymbol{x}x为真实图片的概率。
  在GAN框架的训练过程中,希望生成网络GGG生成的图片尽量真实,能够欺骗过判别网络DDD;而希望判别网络DDD能够把GGG生成的图片从真实图片中区分开。这样的一个过程就构成了一个动态的“博弈”。最终,GAN希望能够使得训练好的生成网络GGG生成的图片能够以假乱真,即对于判别网络DDD来说,无法判断GGG生成的网络是不是真实的。

综上,训练好的生成网络GGG便可以用于生成“以假乱真”的图片。

GAN的框架结构

  GAN的框架是由生成网络GGG和判别网络DDD这两种网络结构组成,通过两种网络的“对抗”过程完成两个网络的训练。由生成网络GGG生成一张“Fake image”,判别网络DDD判断这张图片是否来自真实图片。

在这里插入图片描述

GAN框架的训练过程

  生成网络和判别网络更像是自然界的一对生产者和捕食者,它们在各自独立迭代,进化的同时,也伴随着协同学习,共同进步。用论文里的一句话来描述训练过程就是

The training procedure for G is to maximize the probability of D making a mistake.

  GAN希望的是对于判别网络,其能够正确判定数据是否来自真实的分布,对于生成网络,其能够尽可能使得生成的数据能够“以假乱真”,使得判别网络分辨不了。这样的训练过程是一个动态的“博弈”过程,通过交替训练,最终使得生成网络GGG生成的图片能够“以假乱真”,其具体过程如下图所示:

在这里插入图片描述
  如上图(a)中,黑色的虚线表示的是从真实的分布pxp_{\boldsymbol{x}}px,绿色的实线表示的是需要训练的生成网络的生成的分布pg(G)p_g\left ( G \right )pg(G),蓝色的虚线表示的是判别网络,最下面的横线z\boldsymbol{z}z表示的是从一个先验分布(如图中是一个均匀分布)采样得到的数据点,中间的横线x\boldsymbol{x}x表示的真实分布,两条横线之间的对应关系表示的是生成网络将先验分布映射成一个生成分布pg(G)p_g\left ( G \right )pg(G)。从图(a)到图(d)表示了一个完整的交替训练过程,首先,如图(a)所示,当通过先验分布采样后的数据经过生成网络GGG映射后得到了图上绿色的实线代表的分布,此时判别网络DDD并不能区分数据是否来自真实数据,通过对判别网络的训练,其能够正确地判断生成的数据是否来自真实数据,如图(b)所示;此时更新生成网络GGG,通过对先验分布重新映射到新的生成分布上,如图©中的绿色实线所示。依次交替完成上述步骤,当达到一定迭代的代数后,达到一个平衡状态,此时pg=pdatap_g=p_{data}pg=pdata,判别网络DDD将不能区分图片是否来自真实分布,且D(x)=12D\left ( \boldsymbol{x} \right )=\frac{1}{2}D(x)=21

价值函数

首先对于二分类问题,一般使用交叉熵作为损失函数:

J(θ)=−1m∑i=1m[y(i)log  y^(i)+(1−y(i))log  (1−y^(i))]J\left ( \theta \right )=-\frac{1}{m}\sum_{i=1}^{m}\left [ y^{\left ( i \right )}log\; \hat{y}^{\left ( i \right )}+\left ( 1-y^{\left ( i \right )} \right )log\; \left ( 1-\hat{y}^{\left ( i \right )} \right ) \right ]J(θ)=m1i=1m[y(i)logy^(i)+(1y(i))log(1y^(i))]其中,y(i)y^{\left ( i \right )}y(i) 表示的是真实的样本标签,y^(i)\hat{y}^{\left ( i \right )}y^(i)表示的是模型的预测值。

对于GAN的两部分样本,一个是真实的样本{(x(1),1),(x(2),1),⋯ ,(x(m),1)}\left \{ \left ( \boldsymbol{x}^{\left ( 1 \right )},1 \right ),\left ( \boldsymbol{x}^{\left ( 2 \right )},1 \right ),\cdots ,\left ( \boldsymbol{x}^{\left ( m \right )},1 \right ) \right \}{(x(1),1),(x(2),1),,(x(m),1)},另一部分来自生成模型{(G(z(1)),0),(G(z(2)),0),⋯ ,(G(z(m)),0)}\left \{ \left ( G\left ( \boldsymbol{z}^{\left ( 1 \right )} \right ),0 \right ),\left ( G\left ( \boldsymbol{z}^{\left ( 2 \right )} \right ),0 \right ),\cdots ,\left ( G\left ( \boldsymbol{z}^{\left ( m \right )} \right ),0 \right ) \right \}{(G(z(1)),0),(G(z(2)),0),,(G(z(m)),0)}

将两部分带入交叉熵,可得GAN价值函数V(G,D)V\left ( G,D \right )V(G,D)为:

minG  maxD  V(D,G)=Ex∼pdata(x)[log  D(x)]+Ez∼pz(z)[log  (1−D(G(z)))]\underset{G}{min}\; \underset{D}{max}\; V\left ( D,G \right )=\mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; D\left ( \boldsymbol{x} \right ) \right ]+\mathbb{E}_{\boldsymbol{z}\sim p_{\boldsymbol{z}}\left ( \boldsymbol{z} \right )}\left [ log\; \left ( 1-D\left ( G\left ( \boldsymbol{z} \right ) \right ) \right ) \right ]GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中,Ex∼pdata(x)[log  D(x)]\mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; D\left ( \boldsymbol{x} \right ) \right ]Expdata(x)[logD(x)]表示的是D(x)D\left ( \boldsymbol{x} \right )D(x)的期望,Ez∼pz(z)[log  (1−D(G(z)))]\mathbb{E}_{\boldsymbol{z}\sim p_{\boldsymbol{z}}\left ( \boldsymbol{z} \right )}\left [ log\; \left ( 1-D\left ( G\left ( \boldsymbol{z} \right ) \right ) \right ) \right ]Ezpz(z)[log(1D(G(z)))]表示的是log  (1−D(G(z)))log\; \left ( 1-D\left ( G\left ( \boldsymbol{z} \right ) \right ) \right )log(1D(G(z)))的期望。

假设从真实数据中采样mmm个样本{x(1),x(2),⋯ ,x(m)}\left \{ \boldsymbol{x}^{\left ( 1 \right )},\boldsymbol{x}^{\left ( 2 \right )},\cdots ,\boldsymbol{x}^{\left ( m \right )} \right \}{x(1),x(2),,x(m)},从噪音分布pg(z)p_g\left ( \boldsymbol{z} \right )pg(z)中同样采样mmm个样本,记为{z(1),z(2),⋯ ,z(m)}\left \{ \boldsymbol{z}^{\left ( 1 \right )},\boldsymbol{z}^{\left ( 2 \right )},\cdots ,\boldsymbol{z}^{\left ( m \right )} \right \}{z(1),z(2),,z(m)},此时,上述价值函数可以近似表示为:

minGmaxDV(D,G)≈∑i=1m[logD(x(i))]+∑i=1m[log(1−D(G(z(i))))]\underset{G}{min}\underset{D}{max}V(D,G)\approx \sum_{i=1}^{m}{[logD(x^{(i)})]}+ \sum_{i=1}^{m}{[log(1−D(G(z^{(i)})))]}GminDmaxV(D,G)i=1m[logD(x(i))]+i=1m[log(1D(G(z(i))))]

用随机梯度下降法训练过程在论文中如下流程所示:(分两步ascending,descending,加梯度减梯度交替训练生成器判别器)
在这里插入图片描述

KL散度

需要刻画两个分布是否相似,需要用到KL散度(KL divergence)。KL散度是统计学中的一个基本概念,用于衡量两个分布的相似程度,数值越小,表示两种概率分布越接近。
对于离散的概率分布,定义如下:

(P∥Q)=∑iP(i)logP(i)Q(i)\left ( P\parallel Q \right )=\sum_{i}P\left ( i \right )log\frac{P\left ( i \right )}{Q\left ( i \right )}(PQ)=iP(i)logQ(i)P(i)对于连续的概率分布,定义如下:

(P∥Q)=∫−∞+∞p(x)logp(x)q(x)\left ( P\parallel Q \right )=\int_{-\infty }^{+\infty }p\left ( x \right )log\frac{p\left ( x \right )}{q\left ( x \right )}(PQ)=+p(x)logq(x)p(x)

极大似然估计

上述需要求解生成分布pg(x;θ)p_g\left ( \boldsymbol{x};\theta \right )pg(x;θ)中的参数θ\thetaθ,需要用到极大似然估计。根据极大似然估计的方式,由于最终是希望生成的分布pg(x;θ)p_g\left ( \boldsymbol{x};\theta \right )pg(x;θ)与原始的真实分布pdata(x)p_{data}\left ( \boldsymbol{x} \right )pdata(x),首先从真实分布pdata(x)p_{data}\left ( \boldsymbol{x} \right )pdata(x)采样mmm个数据点,记为{x(1),x(2),⋯ ,x(m)}\left \{ \boldsymbol{x}^{\left ( 1 \right )},\boldsymbol{x}^{\left ( 2 \right )},\cdots ,\boldsymbol{x}^{\left ( m \right )} \right \}{x(1),x(2),,x(m)},根据生成的分布,得到似然函数为:

L=∏i=1mpg(x(i);θ)L=\prod_{i=1}^{m}p_g\left ( \boldsymbol{x}^{\left ( i \right )};\theta \right )L=i=1mpg(x(i);θ)

取log后,得到等价的log似然:

L=∑i=1mlog  pg(x(i);θ)L=\sum_{i=1}^{m}log\; p_g\left ( \boldsymbol{x}^{\left ( i \right )};\theta \right )L=i=1mlogpg(x(i);θ)
此时,θ∗\theta ^{\ast }θ为:

θ∗=argmaxθ∑i=1mlogpg(x(i);θ)θ∗=arg\underset{\theta}{max}\sum_{i=1}^{m}{logp_{g}(x^{(i)};θ)}θ=argθmaxi=1mlogpg(x(i);θ)
  ≈argmaxθEx∼pdata[logpg(x;θ)]≈arg\underset{\theta}{max}\mathbb{E}_{\boldsymbol{x}\sim p_{data}}[logp_{g}(x;θ)]argθmaxExpdata[logpg(x;θ)]
  =argmaxθ∫xpdata(x)logpg(x;θ)dx=arg\underset{\theta}{max}∫_{x}p_{data}(x)logp_{g}(x;θ)dx=argθmaxxpdata(x)logpg(x;θ)dx

对上述的公式做一些修改,增加一个与θ\thetaθ无关的项∫xpdata(x)log  pdata(x)dx\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_{data}\left ( \boldsymbol{x} \right )d\boldsymbol{x}xpdata(x)logpdata(x)dx,这样并不改变对θ∗\theta ^{\ast }θ的求解,此时,公式变为:

argmaxθ∫xpdata(x)log  pg(x;θ)dx−∫xpdata(x)log  pdata(x)dx\underset{\theta }{argmax}\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_g\left ( \boldsymbol{x};\theta \right )d\boldsymbol{x}-\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_{data}\left ( \boldsymbol{x} \right )d\boldsymbol{x}θargmaxxpdata(x)logpg(x;θ)dxxpdata(x)logpdata(x)dx
将最大值求解变成最小值为:

argminθ∫xpdata(x)log  pdata(x)dx−∫xpdata(x)log  pg(x;θ)dx\underset{\theta }{argmin}\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_{data}\left ( \boldsymbol{x} \right )d\boldsymbol{x}-\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_g\left ( \boldsymbol{x};\theta \right )d\boldsymbol{x}θargminxpdata(x)logpdata(x)dxxpdata(x)logpg(x;θ)dx

通过积分公式的合并,得到:

argminθ∫xpdata(x)log  pdata(x)log  pg(x;θ)dx\underset{\theta }{argmin}\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\;\frac{p_{data}\left ( \boldsymbol{x} \right )}{log\; p_g\left ( \boldsymbol{x};\theta \right )} d\boldsymbol{x}θargminxpdata(x)loglogpg(x;θ)pdata(x)dx

由KL散度可知,上述可以表示为:

argminθ  KL(pdata(x)∥pg(x;θ))\underset{\theta }{argmin}\; KL\left ( p_{data}\left ( \boldsymbol{x} \right )\parallel p_g\left ( \boldsymbol{x};\theta \right ) \right )θargminKL(pdata(x)pg(x;θ))
由此可以看出最小化KL散度等价于最大化似然函数。


收敛性分析

当生成网络G确定后,价值函数可以表示为:

maxD  V(D)=∫xpdata(x)log(D(x))dx+∫zpz(z)log(1−D(G(z)))dz\underset{D}{max}\;V(D)=∫_{x}p_{data}(x)log(D(x))dx+∫_{z}p_{z}(z)log(1−D(G(z)))dzDmaxV(D)=xpdata(x)log(D(x))dx+zpz(z)log(1D(G(z)))dz
      =∫xpdata(x)log(D(x))+pg(x)log(1−D(x))dx=∫_{x}p_{data}(x)log(D(x))+p_{g}(x)log(1−D(x))dx=xpdata(x)log(D(x))+pg(x)log(1D(x))dx

由于上述的积分与D无关,上述可以简化成求解:

maxD[pdata(x)log  (D(x))+pg(x)log  (1−D(x))]\underset{D}{max}\left [ p_{data}\left ( \boldsymbol{x} \right )log\; \left ( D\left ( \boldsymbol{x} \right ) \right )+p_{g}\left ( \boldsymbol{x} \right )log\; \left ( 1-D\left ( \boldsymbol{x} \right ) \right )\right ]Dmax[pdata(x)log(D(x))+pg(x)log(1D(x))]

求导数并令其为0,便可以得到最大的D:

D∗=pdata(x)pdata(x)+pg(x)D^{\ast }=\frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )}D=pdata(x)+pg(x)pdata(x)

D(x)∈[0,1]D\left ( \boldsymbol{x} \right )\in \left [ 0,1 \right ]D(x)[0,1],将其带入到价值函数中,可得

V(D∗,G)=Ex∼pdata(x)[log  pdata(x)pdata(x)+pg(x)]+Ex∼pg(x)[log  (1−pdata(x)pdata(x)+pg(x))]V\left ( D^{\ast },G \right )=\mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; \frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ]+\mathbb{E}_{\boldsymbol{x}\sim p_g\left ( \boldsymbol{x} \right )}\left [ log\; \left ( 1-\frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ) \right ]V(D,G)=Expdata(x)[logpdata(x)+pg(x)pdata(x)]+Expg(x)[log(1pdata(x)+pg(x)pdata(x))]

对上式简化,可得:
V(D∗,G)=Ex∼pdata(x)[log  pdata(x)pdata(x)+pg(x)]+Ex∼pg(x)[log  (1−pdata(x)pdata(x)+pg(x))]V\left ( D^{\ast },G \right )=\mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; \frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ]+\mathbb{E}_{\boldsymbol{x}\sim p_g\left ( \boldsymbol{x} \right )}\left [ log\; \left ( 1-\frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ) \right ]V(D,G)=Expdata(x)[logpdata(x)+pg(x)pdata(x)]+Expg(x)[log(1pdata(x)+pg(x)pdata(x))]

      =∫xpdata(x)logpdata(x)pdata(x)+pg(x)dx+∫xpg(x)logpg(x)pdata(x)+pg(x)dx=∫_{x}p_{data}(x)log\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}dx+∫_{x}p_{g}(x)log\frac{p_{g}(x)}{p_{data}(x)+p_{g}(x)}dx=xpdata(x)logpdata(x)+pg(x)pdata(x)dx+xpg(x)logpdata(x)+pg(x)pg(x)dx


通过对分子和分母分别除2,可得:

V(D∗,G)=∫xpdata(x)log12pdata(x)pdata(x)+pg(x)2dx+∫xpg(x)log12pg(x)pdata(x)+pg(x)2dxV\left ( D^{\ast },G \right )=∫_{x}p_{data}(x)log\frac{\frac{1}{2}p_{data}(x)}{\frac{p_{data}(x)+p_{g}(x)}{2}}dx+∫_{x}p_{g}(x)log\frac{\frac{1}{2}p_{g}(x)}{\frac{p_{data}(x)+p_{g}(x)}{2}}dxV(D,G)=xpdata(x)log2pdata(x)+pg(x)21pdata(x)dx+xpg(x)log2pdata(x)+pg(x)21pg(x)dx
      =−2log2+KL(pdata(x)∥pdata(x)+pg(x)2)+KL(pg(x)∥pdata(x)+pg(x)2)=−2log2+KL(p_{data}(x)∥\frac{p_{data}(x)+p_{g}(x)}{2})+KL(p_{g}(x)∥\frac{p_{data}(x)+p_{g}(x)}{2})=2log2+KL(pdata(x)2pdata(x)+pg(x))+KL(pg(x)2pdata(x)+pg(x))


这里引入另一个散度:JS散度(Jensen-Shannon Divergence)

JSD(P∥Q)=12[KL(P∥M)+KL(Q∥M)]JSD\left ( P\parallel Q \right )=\frac{1}{2}\left [ KL\left ( P\parallel M \right ) + KL\left ( Q\parallel M \right )\right ]JSD(PQ)=21[KL(PM)+KL(QM)]

其中,M=P+Q2M=\frac{P+Q}{2}M=2P+Q 。因此V(D∗,G)V\left ( D^{\ast },G \right )V(D,G)可以表示为:

V(D∗,G)=−log  4+2JSD(pdata(x)∥pg(x))V\left ( D^{\ast },G \right )=-log\; 4+2JSD\left ( p_{data}\left ( \boldsymbol{x} \right )\parallel p_g\left ( \boldsymbol{x} \right ) \right )V(D,G)=log4+2JSD(pdata(x)pg(x))

已知JS散度是一个非负值,且值域为[0,1]\left [ 0,1 \right ][0,1],当两个分布相同时取0,不同时取1。对于V(D∗,G)V\left ( D^{\ast },G \right )V(D,G)的最小值为当JS=0JS=0JS=0时,即最小值是−log  4-log\;4log4。此时pdata(x)=pg(x)p_{data}\left ( \boldsymbol{x} \right )=p_g\left ( \boldsymbol{x} \right )pdata(x)=pg(x),求得的生成网络GGG生成的数据分布与真实的数据分布差异性最小,即GAN所要求的目标:pg=pdatap_g=p_{data}pg=pdata

训练结果

初始结果:
在这里插入图片描述

训练20个Epoch:
在这里插入图片描述
生成器性能不断优化:
在这里插入图片描述

总结

  本周重点是研究Ian Goodfellow在2014年发表的论文《Generative Adversarial Nets》中提出的GAN网络,结合李宏毅老师的深度学习课程相关部分针对对抗生成网络进行系统的了解学习。
  GAN的生成器和判别器就像自然界的一对物种,生成器为了骗过判别器不断优化,判别器也通过不断训练努力提升辨别能力,最终我们将会得到一个非常优秀的生成器。
  在训练过程推导价值函数时,接触到了散度这个新概念,结合之前学过的极大似然估计,发现这里的最小化KL散度等价于最大化似然函数。JSD散度相比较于KL散度,具备了对称性,且JS 散度的取值为 0 到 log2。若两个分布完全没有交集,那么 JS 散度取最大值 log2;若两个分布完全一样,那么 JS 散度取最小值 0,当且仅当 P=Q,JSD散度取最小值0,求得的G会使得真实的数据分布和生成的数据分布差异性最小,这样自然可以生成一个和原分布尽可能接近的分布,同时也摆脱了计算极大似然估计,所以GAN的本质是通过改变训练的过程来避免繁琐的计算。所以GAN在建模数据分布的优势不言而喻,同时它的局限性也很明显:难训练不稳定,生成器判别器需要很好的同步,还有别的论文中提到的多次训练梯度消失问题,模式缺失(Mode Collapse)问题(即只生成简单重复样本点)。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值