1.生成对抗网络训练心得
作者:阿阿阿阿毛
https://www.jianshu.com/p/aab68eb0f7ed2.GAN — Ways to improve GAN performance
作者:Jonathan Hui
https://towardsdatascience.com/gan-ways-to-improve-gan-performance-acf37f9f59b
一、权重
a. 调节Generator loss中GAN loss的权重
G loss和Gan loss在一个尺度上或者G loss比Gan loss大一个尺度。但是千万不能让Gan loss占主导地位, 这样整个网络权重会被带偏。
二、训练次数
b. 调节Generator和Discrimnator的训练次数比
一般来说,Discrimnator要训练的比Genenrator多。比如训练五次Discrimnator,再训练一次Genenrator(WGAN论文 是这么干的)。
三、学习率
c. 调节learning rate
这个学习速率不能过大。一般要比Genenrator的速率小一点。
四、优化器
d. Optimizer的选择不能用基于动量法的
如Adam和momentum。可使用RMSProp或者SGD。
五、结构
e. Discrimnator的结构可以改变
如果用WGAN,判别器的最后一层需要去掉sigmoid。但是用原始的GAN,需要用sigmoid,因为其loss function里面需要取log,所以值必须在[0,1]。这里用的是邓炜的critic模型当作判别器。之前twitter的论文里面的判别器即使去掉了sigmoid也不好训练。
六、 loss曲线
f. Generator loss的误差曲线走向
因为Generator的loss定义为: G_loss = -tf.reduce_mean(D_fake) Generator_loss = gen_loss + lamda*G_loss 其中gen_loss为Generator的loss,G_loss为Discrimnator的loss,目标是使Generator_loss不断变小。所以理想的Generator loss的误差曲线应该是不断往0靠的下降的抛物线。
g. Discrimnator loss的误差曲线走向
因为Discrimnator的loss定义为: D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)这个是一个和Generator抗衡的loss。目标就是使判别器分不清哪个是生成器的输出哪个是真实的label。所以理想的Discrimnator loss的误差曲线应该是最终在0附近振荡,即傻傻分不清。换言之,就是判别器有50%的概率判断你是真的,50%概率判断你是假的。

本文总结了Wasserstein GAN(WGAN)训练的关键技巧,包括调节Generator和Discriminator的权重、训练次数比例、学习率、优化器选择、Discriminator结构调整以及Generator和Discriminator的loss曲线分析。通过这些方法,可以有效改善GAN的性能和稳定性。
744

被折叠的 条评论
为什么被折叠?



