f-GAN学习笔记

本文介绍了f-GAN,一种通过F-Divergence改进GAN的变种,强调了其在测量真实与生成分布差距上的优势。文章详细阐述了f-Divergence的定义、性质和在GAN优化中的作用,特别提到了如何解决Mode Collapse和Mode Dropping问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

f-GAN作为GAN的变种之一,它在真实数据的分布与生成数据的分布之间DIvergence(差距)的测量方面做出了改进,即使用F-Divergence来代替,其中F的意为函数function,它可以是KL(进而构成KL散度)、JS(进而构成JS散度)、W(进而构成Wasserstein散度)等等。其通式如下:
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x D_f(P||Q)=\int_xq(x)f(\frac{p(x)}{q(x)})dx Df(PQ)=xq(x)f(q(x)p(x))dx
并满足约束条件:

  • f为凸函数;
  • f(1)=0恒成立。

f-Divergence有如下性质:

  • if p(x) = q(x) for all x, then D f ( P ∣ ∣ Q ) D_f(P||Q) Df(PQ)=0(最小值).
  • D f ( P ∣ ∣ Q ) D_f(P||Q) Df(PQ)始终 ≥ \geq 0.

当f取不同的函数时, D f ( P ∣ ∣ Q ) D_f(P||Q) Df(PQ) 得到实例化,其中几个特例如下:

  • 当f(x) = xlogx时,得KL Divergence = ∫ x p ( x ) l o g p ( x ) q ( x ) d x \int_xp(x)log\frac{p(x)}{q(x)}dx xp(x)logq(x)p(x)dx
  • 当f(x) = -logx时,得Reverse KL Divergence = ∫ x q ( x ) l o g q ( x ) p ( x ) d x \int_xq(x)log\frac{q(x)}{p(x)}dx xq(x)logp(x)q(x)dx
  • 当f(x) = ( x − 1 ) 2 (x-1)^{2} (x1)2时,得ChiSquare Divergence = ∫ x l o g ( p ( x ) − q ( x ) ) 2 q ( x ) d x \int_xlog\frac{(p(x) - q(x))^{2}}{q(x)}dx xlogq(x)(p(x)q(x))2dx

由共轭函数推及GAN的优化目标

f ∗ ( t ) = m a x x ∈ d o m ( f ) x t − f ( x ) f^{*}(t)=\mathop{max}\limits_{x\in dom(f)}{xt-f(x)} f(t)=xdom(f)maxxtf(x), 换成用x做未知量的函数,即
f ∗ ( x ) = m a x t ∈ d o m ( f ) x t − f ( t ) f^{*}(x)=\mathop{max}\limits_{t\in dom(f)}{xt-f(t)} f(x)=tdom(f)maxxtf(t)
p ( x ) q ( x ) \frac{p(x)}{q(x)} q(x)p(x)代替f(x)中的x,可得
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x D_f(P||Q)=\int\limits_{x}q(x)f(\frac{p(x)}{q(x)})dx Df(PQ)=xq(x)f(q(x)p(x))dx
= ∫ x q ( x ) ( m a x t ∈ d o m ( f ) p ( x ) q ( x ) t − f ∗ ( t ) ) d x =\int\limits_{x}{q(x)(\mathop{max}\limits_{t\in dom(f)}{\frac{p(x)}{q(x)}t-f^{*}(t)})}dx =xq(x)(tdom(f)maxq(x)p(x)tf(t))dx
再用D(x)代替t,得到:
D f ( P ∣ ∣ Q ) ≈ m a x D ∫ x p ( x ) D ( x ) d x − ∫ x q ( x ) f ∗ ( D ( x ) ) d x D_f(P||Q)\approx\mathop{max}\limits_{D}{\int\limits_{x}p(x)D(x)dx - \int\limits_{x}q(x)f^{*}(D(x))dx} Df(PQ)Dmaxxp(x)D(x)dxxq(x)f(D(x))dx

D是一个函数,它的输入是x,输出是t. 而优化的过程即找到 D f ( P ∣ ∣ Q ) D_f(P||Q) Df(PQ),找到的D使得上式越大,对应得到的t就越准确,越能逼近真实的 D f ( P ∣ ∣ Q ) D_f(P||Q) Df(PQ).

在上式中引入期望E,得到:
D f ( P ∣ ∣ Q ) = m a x D ∫ x E x ∼ P D ( x ) d x − ∫ x E x ∼ Q f ∗ ( D ( x ) ) d x D_f(P||Q)=\mathop{max}\limits_{D}{\int\limits_{x}E_{x\sim P}D(x)dx - \int\limits_{x}E_{x\sim Q}f^{*}(D(x))dx} Df(PQ)=DmaxxExPD(x)dxxExQf(D(x))dx

用GAN中的真实数据分布 P d a t a P_{data} Pdata代替分布P,生成数据分布 P g e n P_{gen} Pgen代替Q, 得到如下:
D f ( P ∣ ∣ Q ) = m a x D ∫ x E x ∼ P d a t a D ( x ) d x − ∫ x E x ∼ P g e n f ∗ ( D ( x ) ) d x D_f(P||Q)=\mathop{max}\limits_{D}{\int\limits_{x}E_{x\sim P_{data}}D(x)dx - \int\limits_{x}E_{x\sim P_{gen}}f^{*}(D(x))dx} Df(PQ)=DmaxxExPdataD(x)dxxExPgenf(D(x))dx,
再在前加入一个求最小约束,即得到Generator的目标函数:
G ∗ = a r g m i n G D f ( P d a t a ∣ ∣ P G ) G^{*}=\mathop {argmin}\limits_{G}D_f(P_{data}||P_G) G=GargminDf(PdataPG)
= a r g m i n G m a x D ∫ x E x ∼ P d a t a D ( x ) d x − ∫ x E x ∼ P g e n f ∗ ( D ( x ) ) d x =arg\mathop {min}\limits_{G}\mathop{max}\limits_{D}{\int\limits_{x}E_{x\sim P_{data}}D(x)dx - \int\limits_{x}E_{x\sim P_{gen}}f^{*}(D(x))dx} =argGminDmaxxExPdataD(x)dxxExPgenf(D(x))dx
= a r g m i n G m a x D V ( G , D ) =arg\mathop {min}\limits_{G}\mathop{max}\limits_{D}V(G, D) =argGminDmaxV(G,D)

我们想让判别器将来自真实分布的数据判断为正类,将生成得到的数据判断为负类,因此需要极大化上式;同时,对于生成器来说,要想让自己生成的数据尽可能骗过判别器,即判别器给生成的数据打高分(接近正类),因此需极小化上式的第二项。其中, f ∗ f^* f为待定函数,当f取何种divergence,后者就计算什么。

为什么要引入f-GAN?

  • 可以有效解决Mode Collapse(指生成的内容会局限于真实空间的某一个形态,无法捕捉到各种模式的信息,变得less diverse).
  • 可以有效解决Mode Dropping(训练过程中,生成信息的某一维度发生改变,其他维度均保持不变的).
### StyleGAN2-ADA 模型介绍 StyleGAN2-ADA 是由 NVIDIA 提供的一个改进版本的生成对抗网络(GAN),该模型特别之处在于加入了自适应判别器增强(Adaptive Discriminator Augmentation, ADA)。这种机制能够有效防止过拟合现象,在数据集较小的情况下也能保持良好的泛化能力[^1]。 此开源项目存在两个主要实现平台——TensorFlow 和 PyTorch。其中,PyTorch 实现由于其灵活性以及更快速度的优势受到了广泛关注。通过采用一系列优化措施和技术手段,如Top-K训练策略等,使得StyleGAN2-ADA可以在多种应用场景下发挥出色的表现,包括但不限于低数据量场景下的图像生成、数据集扩展等方面的工作[^2]。 对于开发者而言,该项目具备较高的易用性和广泛的适用范围。除了提供详细的文档指导外,官方还准备了Colab笔记本形式的教学材料,允许用户直接利用云端GPU资源来进行实验操作;同时支持与原有TensorFlow版预训练权重之间的无缝对接,极大地方便了研究人员开展对比研究和迁移学习任务[^3]。 ### 使用方法概述 为了帮助初次接触StyleGAN2-ADA的新手更好地理解和应用这一强大工具,下面给出了一些基本的操作指南: #### 准备工作 确保安装好Python开发环境,并按照官方说明完成依赖库的配置。如果打算使用Google Colab,则可以跳过这一步骤,因为大部分必要的设置已经在模板文件里完成了。 #### 加载并运行现有模型 可以通过加载已有的预训练参数来启动一个新的会话: ```python import dnnlib from training import networks_stylegan2_ada_pytorch as networks G = networks.Generator().eval() # 创建生成器实例 D = networks.Discriminator().eval() # 创建判别器实例 Gs = dnnlib.tflib.load_network_pkl('path_to_pretrained_model.pkl')['G_ema'] # 载入带有EMA更新后的生成器状态字典 ``` #### 自定义训练流程 当拥有自己的数据集时,可以根据需求调整超参数设定,编写适合具体项目的训练脚本。这里展示了一个简单的损失函数计算过程作为例子: ```python def compute_loss(G, D, batch_real_images, labels=None): # 获取一批随机噪声向量z作为输入特征 z = torch.randn([batch_size, G.z_dim]).cuda() # 利用这些特征生成假图片fake_imgs fake_imgs = G(z=z, c=labels).detach() # 将真假样本分别送入判别器得到预测得分logits real_logits = D(batch_real_images, labels) fake_logits = D(fake_imgs, labels) # 根据交叉熵准则构建最终的目标函数L_adv loss_G = -(F.softplus(-real_logits)).mean() + (F.softplus(fake_logits)).mean() return loss_G ``` 以上代码片段展示了如何在一个批次内处理真实样例及其对应的标签信息,进而求得相应的对抗性损失值$[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值