GAN网络调参经验

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.youkuaiyun.com/qq_38262728/article/details/98327082


自从 GAN 提出后,它变得越来越火热,吸引了众多的爱好者前来学习实践。

但是只要你自己去从无到有写出一个 GAN 模型并运行,除非你运气太好,大多数情况下你都会发现自己的GAN并不能很好地 work 。

下面首先对 GAN 进行简要的介绍,然后整理了我自己在 GAN 的设计网络结构、调整参数等方面的经验。

1 什么是GAN?

GAN 是一种生成模型,由知名的学者 Ian Goodfellow 首先提出,并给出了实验结果和理论推导 https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

它以造假币为例对 GAN 的工作原理进行解释,生成器(Generator)就像造假币的人,判别器(Discriminator)就像警察,原始数据的分布类比于真钱,生成的数据分布类比于假钱。

造假币的人不断模仿真钱去造假币,造出来的钱混入真钱一同交给警察去判断。造假币的人的目的是让自己造出来的假币不断逼近于真钱,而警察既需要认出假钱、又不能冤枉真钱。

两者以此方式,不断地对抗提升自己造假和打假的能力,最终理想的结果是使得造假币的人能造出几乎无法辨识的假钱成功迷惑警察。

2 GAN存在的问题

  • 训练不稳定,损失值波动幅度大
  • 判别器收敛迅速,损失值快速降到零
  • 生成器无能为力,损失函数不断增大

3 训练的经验

3.1 不要纠结于损失函数的选择

刚开始你可能会认为损失函数对结果会产生较大的影响,但是实践证明,一般来说其对结果的影响一般并没有你想象的那样大。

因此,对于 GAN 理论入门不久,正在打开实践大门的人,我的建议是选择最简单的损失函数就可以开始实验了。

因为后续还有好多事情值得你去头疼,微调损失函数可以留到最后一步再考虑。

3.2 关于增加模型的容量

当GAN生成的图像不够准确、清晰时,可尝试增加卷积层中的卷积核的大小和数量,特别是初始的卷积层。

卷积核的增大可以增加卷积的视野域,平滑卷积层的学习过程,使得训练不过分快速地收敛。

增加卷积核数(特别是生成器),可以增加网络的参数数量和复杂度,增加网络的学习能力。

但同时也可能存在,增加生成器的模型 capacity 但是对于它快速被判别器打败的事实无济于事的情况,每个人都使用不同的模型和数据,会有不同的情况,需要具体问题具体分析。

3.3 尝试改变标签

如果使用的是真实数据标签为1,生成数据标签为0的分配方法,可将其交换为真实数据标签为0,生成数据标签为1。

这个小技巧会帮助网络在早期快速进行梯度计算,帮助稳定训练过程。

此外,还可使用软标签和带噪声的标签。

所谓软标签指不是使用0和1作为标签,而是使用和0或1接近的小数来标记,这样会减弱梯度的传播速度,稳定训练。

而使用带噪声的标签指对少数的标签进行随机的扰动,这也是一个帮助训练的小技巧。

3.4 尝试使用 batch normalization

我在实践的过程中使用 batch normalization ,发现对结果的提升具有明显的帮助,它在每一层都对数据进行归一化,有利于防止数据发散,进而保护训练的过程与结果的稳定性。

3.5 尝试分次训练

对于一般的 GAN 模型和多分类问题,最好分次训练,一次只训练一个类别,以降低网络训练的难度并提高准确性。

而对于条件 GAN 等,比如可以将类比标签一同作为输入,以类别为先验条件的 GAN ,可适度增大训练的难度。

3.6 最好不要提早结束

有时候我们会看到自己模型的损失函数在几个batch训练过后就停止波动了,但是这个时候先不要为了节省时间而提前停止训练,实践证明这个时候网络很可能仍然在不断地调整结构中。

有时候损失函数也可能突然出现很大的异常波动,这个时候也不要马上提前停止训练,多观察一会儿。

非常建议在训练的过程中,通过保存等方式不断记录当前时刻下的训练结果。通过对结果图像的观察分析来判断训练的过程,损失函数可能会一时蒙蔽双眼,结果应该不会。

因此除非损失马上收敛到接近于0,否则耐心地等待网络训练完再评估结果,调整网络结构和参数。

3.7 关于k的选择

原论文中的 k 指每优化一次生成器的损失函数,优化判别器的损失函数 k 次。

但是在实验中,经常出现判别器迅速打败生成器的情况(即判别器的损失函数快速下降,生成器快速上升)。

于是常规的思路,就是增加生成器的训练次数。没训练一次判别器,训练k次生成器。这样可以增加生成器的学习次数,使得训练在开始时稍稳定。

然而实践证明,如果判别器真的比生成器强太多,这种调节k只是让结果崩溃来的晚一些。或者说只是相当于节省了少训练几次判别器的时间,稍稍提升了结果。

我个人不建议出问题就改k的习惯,还是应该从网络结构本身找问题所在才是治本的关键。

3.8 关于学习率

调整学习率是解决生成器崩溃的一剂良方。

当出现崩溃时,尝试降低学习率,可能会带来意想不到的效果。

3.9 增加噪声

与标签噪声相似,还可在数据中引入一定量的噪声,大多数情况下都能 work 。
在这里插入图片描述

3.10 不要使用性能太好的判别器

WGAN论文提到过,若使用性能过好的判别器可能会使得判别器的损失函数在训练一开始就降到非常低,后续对抗无法继续进行,或者使得训练出来的生成器性能不够好。

3.10 可以尝试最新的multi-scale gradient方法

https://arxiv.org/abs/1903.06048

对于稳定训练帮助很大。

3.11 可以尝试使用TTUR

https://arxiv.org/abs/1706.08500

对于生成器和判别器使用不同的学习率,看似简单的 trick 对结果的提升却有奇效。

3.12 使用Spectral Normalization

https://arxiv.org/abs/1802.05957

对卷积核使用Spectral Normalization,极力安利。

4 正常的损失函数波动情况

目前来看,正常的损失函数应该是:

  • 训练初始,生成器和判别器的损失函数快速波动,但是大致都分别朝着增大或减小的方向。
  • 趋于稳定后,生成器和判别器的损失函数在小的范围内做上下波动,此时模型趋于稳定。

参考

[1] https://arxiv.org/pdf/1406.2661.pdf
[2] https://mp.weixin.qq.com/s?__biz=MzUzNTA1NTQ3NA==&mid=2247486336&idx=1&sn=57c9fe8324a1addd73016c2f9dad4db8&chksm=fa8a169dcdfd9f8b17a02ab37eba61fdb3a1d89f694eaf89e2159275a553efe81c848e9a597c&mpshare=1&scene=1&srcid=&sharer_sharetime=1564758935068&sharer_shareid=f48c6499a7bee75abed9252093ec8062&key=83b29471f317cf4cb4c43b8d6f0f7141528141839d921d9fa05354867868f61243968a92b031ba8d4867003242ab09f1ca621380db5b7bc77bfcab13dc9cc7a0960adac628f5a805694c9fef0468a345&ascene=1&uin=MjYxNDk4MjcwNg%3D%3D&devicetype=Windows+10&version=62060833&lang=zh_CN&pass_ticket=RHjQjboJdAJhysQNM17TfCzpyiuR4K3LIS%2FvyT9wAnt%2BBDxNq0hsDyAO0BNEjE6l
[3] https://towardsdatascience.com/10-lessons-i-learned-training-generative-adversarial-networks-gans-for-a-year-c9071159628

                                </div>
            <link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-e44c3c0e64.css" rel="stylesheet">
                </div>

更多《计算机视觉与图形学》知识,可关注下方公众号:
在这里插入图片描述

<think>好的,我需要回答用户关于如何使用生成对抗网络GAN)及其应用场景和实现方式的问题。首先,我应该回顾GAN的基本原理,确保自己理解正确。GAN由生成器和判别器组成,两者通过对抗训练提升性能。生成器负责生成数据,判别器负责判断数据是真实的还是生成的。 接下来,用户询问使用方法,我需要分步骤说明。首先,明确任务类型,比如图像生成、风格迁移等。然后选择适合的GAN变体,例如DCGAN、CycleGAN等。接着准备数据集,可能需要图像数据预处理,如归一化、调整尺寸。构建网络结构是关键部分,生成器和判别器的设计需要详细说明,比如使用卷积层、激活函数等。定义损失函数和优化器,对抗训练的过程需要解释清楚,包括交替训练生成器和判别器。最后是评估生成效果,可能需要使用FID、IS等指标,或者人工检查。 应用场景部分,根据引用资料,用户提到了图像生成、数据增强、图像编辑、视频生成等。我需要将这些场景与具体例子结合,比如生成人脸、艺术风格转换等。同时,引用对应的参考文献,确保信息准确。 实现方式方面,可以建议使用TensorFlow或PyTorch框架,并给出简单的代码示例。例如,使用PyTorch搭建生成器和判别器的基本结构。此外,还需要提到训练过程中的注意事项,如平衡生成器和判别器的训练,防止模式崩溃等。 用户提供的引用中,引用3提到了风光场景生成的MATLAB代码,但用户主要使用Python和TensorFlow,所以可能需要调整示例代码的语言。不过,保持示例代码的简洁和易懂更重要,使用Python更合适。 最后,生成相关问题,需要围绕GAN的扩展应用、挑战、改进方法等方面提问,帮助用户进一步探索。 现在需要确保所有数学公式正确使用LaTeX,行内公式用$...$,独立公式用$$...$$。例如,生成器的损失函数可以写成$\mathcal{L}_G = -\mathbb{E}[\log D(G(z))]$。同时,引用标记要正确添加在相关段落末尾。 检查是否有遗漏的应用场景或实现细节,确保回答结构清晰,步骤明确,引用恰当。避免使用复杂术语,保持语言简洁易懂,符合用户需要。</think>### 如何使用生成对抗网络GAN)及其常见应用场景与实现方式 #### **1. GAN的基本原理** 生成对抗网络GAN)由生成器(Generator)和判别器(Discriminator)两部分组成,两者通过对抗训练实现优化: - **生成器**:接收随机噪声$z$,生成与真实数据分布接近的假数据$G(z)$。 - **判别器**:判断输入数据是来自真实数据集还是生成器,输出概率值$D(x)$。 两者的目标函数为: $$\min_G \max_D \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$ #### **2. GAN的使用方法** **步骤1:明确任务类型** - 图像生成(如人脸、风景) - 图像到图像的转换(如风格迁移) - 数据增强(生成补充样本) **步骤2:选择GAN变体** - **DCGAN**:深度卷积GAN,适合图像生成[^1]。 - **CycleGAN**:用于无配对数据的图像转换(如马变斑马)。 - **WGAN**:改进训练稳定性,使用Wasserstein距离替代原始损失函数。 **步骤3:准备数据集** - 数据需归一化(如像素值缩放到$[-1, 1]$)。 - 图像需统一尺寸(如$64 \times 64$)。 **步骤4:构建网络结构** - **生成器**:通常包含反卷积层(Deconvolution)和激活函数(如ReLU)。 - **判别器**:使用卷积层和LeakyReLU,输出为二元分类概率。 **步骤5:定义损失函数与优化器** - 生成器损失:$\mathcal{L}_G = -\mathbb{E}[\log D(G(z))]$ - 判别器损失:$\mathcal{L}_D = -\mathbb{E}[\log D(x)] - \mathbb{E}[\log(1 - D(G(z)))]$ - 优化器常用Adam,学习率需调参(如$0.0002$)。 **步骤6:对抗训练** - 交替训练判别器和生成器,保持两者能力平衡。 - 每轮训练判别器$k$次(如$k=1$),再训练生成器1次。 **步骤7:评估生成效果** - 定量指标:Fréchet Inception Distance(FID)[^2]。 - 定性评估:人工观察生成样本的真实性。 #### **3. 常见应用场景** 1. **图像生成** - 生成逼真的人脸、艺术作品(如DeepArt)[^1][^4]。 - 风光场景生成(如新能源场景模拟)[^3]。 2. **数据增强** - 医学图像扩充,提升小样本模型的泛化能力。 3. **图像编辑与风格转换** - 照片转油画风格(CycleGAN)[^2]。 - 图像修复(如填补缺失区域)。 4. **视频生成与预测** - 生成未来帧或动态场景(如VideoGAN)。 #### **4. 实现示例(PyTorch框架)** ```python import torch import torch.nn as nn # 生成器定义 class Generator(nn.Module): def __init__(self): super().__init__() self.main = nn.Sequential( nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 784), nn.Tanh() # 输出归一化到[-1,1] ) def forward(self, z): return self.main(z) # 判别器定义 class Discriminator(nn.Module): def __init__(self): super().__init__() self.main = nn.Sequential( nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid() ) def forward(self, x): return self.main(x) ``` #### **5. 训练注意事项** - **模式崩溃**:生成器仅生成单一模式样本,可通过多样化损失函数(如WGAN-GP)缓解。 - **训练稳定性**:使用梯度惩罚(Gradient Penalty)或调整学习率衰减策略。 --- ###
评论 3
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值