为什么在GAN中使用KL散度会出现模式坍缩的问题?

本文深入探讨了生成对抗网络(GAN)中模式坍缩问题的原因,解析了逆向KL散度对模型分布的影响,以及其如何导致生成样本多样性的降低。

李宏毅老师GAN的课程中曾经对于这个问题有一个直观的解释,大体上可以参考以下的简单理解

https://blog.youkuaiyun.com/Forlogen/article/details/88921247

那么具体在数学上如何理解这个问题的产生呢?在Goodfellow提出的GAN的原始论文中指出,当生成器和判别器的能力足够强的时候,一定可以收敛到纳什均衡点,此时得到的最优的判别器 D ∗ D^* D
D ∗ = p r ( x ) p r ( x ) + p θ ( x ) D^*=\frac{p_{r}(x)}{p_{r}(x)+p_{\theta}(x)} D=pr(x)+pθ(x)pr(x)
而当判别器固定时,生成器的目标函数为:
max ⁡ θ ( E z ∼ p ( z ) [ log ⁡ D ( G ( z , θ ) , ϕ ) ] ) \max _{\theta}\left(\mathbb{E}_{\mathrm{z} \sim p(\mathrm{z})}[\log D(G(\mathbf{z}, \theta), \phi)]\right) θmax(Ezp(z)[logD(G(z,θ),ϕ)])
那么将 D ∗ D^* D的表达式代入上式有
L ′ ( G ∣ D ⋆ ) = E x ∼ p θ ( x ) [ log ⁡ D ⋆ ( x ) ] = E x ∼ p θ ( x ) [ log ⁡ p r ( x ) p r ( x ) + p θ ( x ) ⋅ p θ ( x ) p θ ( x ) ] = − E x ∼ p θ ( x ) [ log ⁡ p θ ( x ) p r ( x ) ] + E x ∼ p θ ( x ) [ log ⁡ p θ ( x ) p r ( x ) + p θ ( x ) ] = − D K L ( p θ ∥ p r ) + E x ∼ p θ ( x ) [ log ⁡ ( 1 − D ⋆ ( x ) ) ] = − D K L ( p θ ∥ p r ) + 2 D J S ( p r ∥ p θ ) − 2 log ⁡ 2 − E x ∼ p r ( x ) [ log ⁡ D ⋆ ( x ) ] \begin{array}{l}{\mathcal{L}^{\prime}\left(G | D^{\star}\right)=\mathbb{E}_{\mathbf{x} \sim p_{\theta}(\mathbf{x})}\left[\log D^{\star}(\mathbf{x})\right]} \\ {\quad=\mathbb{E}_{\mathbf{x} \sim p_{\theta}(\mathbf{x})}\left[\log \frac{p_{r}(\mathbf{x})}{p_{r}(\mathbf{x})+p_{\theta}(\mathbf{x})} \cdot \frac{p_{\theta}(\mathbf{x})}{p_{\theta}(\mathbf{x})}\right]} \\ {=-\mathbb{E}_{\mathbf{x} \sim p_{\theta}(\mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x})}{p_{r}(\mathbf{x})}\right]+\mathbb{E}_{\mathbf{x} \sim p_{\theta}(\mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x})}{p_{r}(\mathbf{x})+p_{\theta}(\mathbf{x})}\right]} \\ \begin{array}{l}{=-D_{\mathrm{KL}}\left(p_{\theta} \| p_{r}\right)+\mathbb{E}_{\mathrm{x} \sim p_{\theta}(\mathrm{x})}\left[\log \left(1-D^{\star}(\mathrm{x})\right)\right]} \\ {=-D_{\mathrm{KL}}\left(p_{\theta} \| p_{r}\right)+2 D_{\mathrm{JS}}\left(p_{r} \| p_{\theta}\right)-2 \log 2-\mathbb{E}_{\mathrm{x} \sim p_{r}(\mathrm{x})}\left[\log D^{\star}(\mathrm{x})\right]}\end{array} \end{array} L(GD)=Expθ(x)[logD(x)]=Expθ(x)[logpr(x)+pθ(x)pr(x)pθ(x)pθ(x)]=Expθ(x)[logpr(x)pθ(x)]+Expθ(x)[logpr(x)+pθ(x)pθ(x)]=DKL(pθpr)+Expθ(x)[log(1D(x))]=DKL(pθpr)+2DJS(prpθ)2log2Expr(x)[logD(x)]
其中后两项为常数,只有前两项与生成器有关,所以可以将生成器的目标函数表示为
max ⁡ θ L ′ ( G ∣ D ⋆ ) = min ⁡ θ D K L ( p θ ∥ p r ) − 2 D J S ( p r ∥ p θ ) \max _{\theta} \mathcal{L}^{\prime}\left(G | D^{\star}\right)=\min _{\theta} D_{\mathrm{KL}}\left(p_{\theta} \| p_{r}\right)-2 D_{\mathrm{JS}}\left(p_{r} \| p_{\theta}\right) θmaxL(GD)=θminDKL(pθpr)2DJS(prpθ)
那么根据原论文中的介绍,其中第二项的 2 D J S ( p r ∥ p θ ) 2 D_{\mathrm{JS}}\left(p_{r} \| p_{\theta}\right) 2DJS(prpθ)为有界函数,因此 L \mathcal{L} L主要受逆向KL散度 D K L ( p θ ∥ p r ) D_{\mathrm{KL}}\left(p_{\theta} \| p_{r}\right) DKL(pθpr)的影响,因此为了更快的收敛,生成器会倾向于生成一些已经骗过判别器的样本,这样的话,样本的多样性就会下降,从而造成了**模式坍缩(model collapse)**的问题。


为什么当 L \mathcal{L} L主要受逆向KL散度 D K L ( p θ ∥ p r ) D_{\mathrm{KL}}\left(p_{\theta} \| p_{r}\right) DKL(pθpr)的影响时就会出现模式坍缩呢?这要从逆向KL散度的性质来看。我们知道KL散度是非对称的,印象按照表达式中分布顺序的不同,可以分为正向KL散度和逆向KL散度,表达如下所示
D K L ( p r ∥ p θ ) = ∫ p r ( x ) log ⁡ p r ( x ) p θ ( x ) d x D K L ( p θ ∥ p r ) = ∫ p θ ( x ) log ⁡ p θ ( x ) p r ( x ) d x \begin{aligned} D_{\mathrm{KL}}\left(p_{r} \| p_{\theta}\right) &=\int p_{r}(\mathrm{x}) \log \frac{p_{r}(\mathrm{x})}{p_{\theta}(\mathrm{x})} d \mathrm{x} \\ D_{\mathrm{KL}}\left(p_{\theta} \| p_{r}\right) &=\int p_{\theta}(\mathrm{x}) \log \frac{p_{\theta}(\mathrm{x})}{p_{r}(\mathrm{x})} d \mathrm{x} \end{aligned} DKL(prpθ)DKL(pθpr)=pr(x)logpθ(x)pr(x)dx=pθ(x)logpr(x)pθ(x)dx
在前向KL散度中:

  • p x → 0 p_{x} \rightarrow 0 px0 p θ > 0 p_{\theta}>0 pθ>0时,积分号内整体趋近于0,所以 p θ ( x ) p_{\theta}(x) pθ(x)的取值对于KL散度的计算不起任何作用
  • p x > 0 p_{x} > 0 px>0 p θ → 0 p_{\theta} \rightarrow0 pθ0时,积分号内整体趋近于 ∞ \infty ,所以KL散度会变得非常大

从中我们可以看出,前向KL散度会鼓励模型分布 p θ ( x ) p_{\theta}(x) pθ(x) 尽可能的覆盖所有真实分布 p r ( x ) > 0 p_{r}(x) >0 pr(x)>0的点,而不用会回避 p r ( x ) ≈ 0 p_{r}(x) \approx 0 pr(x)0的点。换言之,如果我们使用前向KL散度来衡量两个分布,生成器会更加倾向于生成多样性更好的样本。

而在逆向的KL散度中

  • p x → 0 p_{x} \rightarrow 0 px0 p θ > 0 p_{\theta}>0 pθ>0时,积分号内整体趋近于 ∞ \infty ,即当 p θ ( x ) p_{\theta}(x) pθ(x)趋近于0时,而 p θ ( x ) p_{\theta}(x) pθ(x)有一定的密度时,逆向KL散度的值就会变得非常大
  • p θ → 0 p_{\theta} \rightarrow0 pθ0时,无论 p r ( x ) p_{r}(x) pr(x)如何取值,积分号内的整体都将趋近于0

从中我们可以看出,逆向KL散度会鼓励模型分布 p θ ( x ) p_{\theta}(x) pθ(x) 尽可能避开所有真实分布 p r ( x ) ≈ 0 p_{r}(x) \approx 0 pr(x)0 的点,而不需要考虑是否覆盖所有 p r ( x ) > 0 p_{r}(x) > 0 pr(x)>0的点。这样的话,只要真实分布不趋近于0,模型就会尽量生成类似于已接近真实分布的新样本,而不是去生成所有覆盖真实分布的样本。

形象的表示如下


在这里插入图片描述

因此,如果使用逆向KL散度作为衡量手段的话,就会出现模式坍缩的问题。为了解决这个问题,不少人针对两个分布之间的衡量提出了很多GAN的变体,例如WGAN、Improved-WAN、LSGAN等等。

邱锡鹏《神经网络与深度学习》13.3.5.1节

<think>嗯,用户问的是Wasserstein距离相比传统GAN中的KL有哪些优势。我需要先回忆一下GAN的基本知识。传统的GAN使用的是JS(Jensen-Shannon divergence),而WGAN用的是Wasserstein距离。那为什么会有这种改变呢? 首先,传统GAN问题可能在于梯消失。当生成分布和真实分布没有重叠或者重叠很少时,JS会饱和,导致梯接近于零,这样生成器无法有效更新。而Wasserstein距离即使在分布不重叠的情况下也能提供有效的梯,这应该是一个主要优势。 其次,Wasserstein距离衡量的是两个分布之间的“距离”,而不是单纯的相似KL和JS都是概率分布的差异量,但Wasserstein考虑了分布之间的几何结构,比如在移动概率质量时的成本,这可能使得训练更稳定。 另外,训练过程中的指标问题。传统GAN的判别器输出经过sigmoid,得到的是概率,而Wasserstein的判别器(或者说Critic)输出的是一个实数,这可以更直接地反映分布之间的距离,从而提供更有意义的训练信号。 还有,Wasserstein距离的平滑性可能有助于模式覆盖。传统GAN容易导致模式崩溃,生成器可能只生成有限的几种样本,而Wasserstein距离可能鼓励生成分布覆盖更多的真实分布模式,因为它在计算时考虑的是最优运输成本,而不是单纯的差异。 需要确认这些点是否正确。例如,JS问题确实在原始论文中被指出,当两个分布没有交集时,JS是常数,导致梯为零。而Wasserstein在这种情况下仍然能给出有意义的梯。因此,Wasserstein GAN在训练稳定性和避免模式崩溃上有优势。 然后,在解释的时候要分点说明优势,比如解决梯消失、提供更平滑的梯、评估指标更合理等。还要提到WGAN的改进,比如权重裁剪或者梯惩罚来满足Lipschitz约束。 可能还需要比较KL和Wasserstein的不同。KL不对称,而且当生成分布有零而真实分布非零时,KL会趋向无穷大,导致训练不稳定。而Wasserstein距离没有这样的问题,因为它基于最优运输理论,考虑的是分布间的“移动”成本,而不是严格的概率匹配。 最后,总结一下Wasserstein距离的优势,并确保回答结构清晰,分点说明,然后给出相关的问题,比如WGAN的实现、Lipschitz约束的处理方法等。</think><< Wasserstein距离(Earth-Mover距离)相比传统GAN使用KL(Kullback-Leibler divergence)具有以下核心优势: 1. **解决梯消失问题** 传统GAN在训练时可能出现梯消失,尤其是在生成分布与真实分布无重叠时。KL和JS在此情况下会饱和(值为常数),导致梯接近于零。而Wasserstein距离即使分布不重叠时仍能提供有意义的梯,确保训练稳定。 2. **更合理的距离量** Wasserstein距离直接衡量两个分布之间的"最小运输成本",能反映分布之间的几何关系。而KL不对称且对分布微小变化不敏感(例如当生成分布覆盖真实分布但存在偏移时,KL可能极大)。 3. **缓解模式崩溃(Mode Collapse)** 传统GAN容易因KL/JS的缺陷导致生成器仅覆盖部分真实数据模式。Wasserstein距离通过平滑的梯引导生成器更全面地覆盖真实分布。 4. **训练过程可解释性** Wasserstein距离的值与生成质量呈负相关(值越小生成质量越高),可直接作为训练指标。而传统GAN的判别器输出概率无法直接反映生成质量。 **代码示例(WGAN-GP实现片段)** ```python # WGAN-GP的Critic(判别器)需满足Lipschitz约束,使用惩罚 def gradient_penalty(critic, real, fake, device): batch_size = real.size(0) epsilon = torch.rand(batch_size, 1, 1, 1).to(device) interpolates = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True) critic_interpolates = critic(interpolates) gradients = torch.autograd.grad( outputs=critic_interpolates, inputs=interpolates, grad_outputs=torch.ones_like(critic_interpolates), create_graph=True, retain_graph=True )[0] gradients = gradients.view(gradients.size(0), -1) gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() return gradient_penalty ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值