Pytorch框架实现WGAN-GP

文章介绍了WGAN-GP的背景,作为解决原始GAN和WGAN训练不稳定的改进方案。WGAN-GP通过梯度惩罚而非权重约束来稳定训练,适用于多种GAN架构,包括图像和语言模型。文章提供了代码实现和数据集下载链接,展示了在解决梯度消失和爆炸问题上的优势。
部署运行你感兴趣的模型镜像

目录

1.WGAN-GP产生背景

2.WGAN-GP的主要成就

3.权重限制的困难(Difficulties with weight constraints)

(1)WGAN-GP算法流程

(2)梯度消失(梯度弥散)和梯度爆炸

(3)梯度惩罚(gradient penalty) 

4.WGAN-GP网络结构

5.数据集下载

6.WGAN-GP代码实现

7.mainWindow窗口显示生成器生成的图片

8.模型下载 


GAN原理及Pytorch框架实现GAN(比较容易理解)

Pytorch框架实现DCGAN(比较容易理解)

CycleGAN的基本原理以及Pytorch框架实现

WGAN基本原理及Pytorch实现WGAN

提示:本文不会对WGAN-GP其中的公式进行推导,主要是讲解WGAN-GP算法的实现。

1.WGAN-GP产生背景

        自从生成对抗网络(Generative Adversarial Networks 简称GAN)被提出来之后,各种各样在原始GAN的基础上提出了很多的创新的GAN,但是原始的GAN网络训练不稳定,并且对超参数非常的敏感,还会导致模式崩塌的现象。因此,DCGAN对原始的GAN进行了改进,提出了一些新的方法和技巧来让GAN训练更加的稳定,但是这些技巧不是从最根本的方法去解决问题,所以WGAN诞生了。

        WGAN(Wasserstein GAN)提出了一种新的方法,可以说是从本质的方向上去解决GAN训练不稳定的问题,并且WGAN在这方面取得了很大的进展,但是有时仍然产生只有较差的样本或者无法收敛。该问题主要是由于在WGAN中使用权重裁剪时,对于critic模型来说,可能会导致不可预料的行为,因此提出了WGAN-GP,继续对其进行改进。

        WGAN-GP提出了一种剪辑的替代方案权重:惩罚critic相对于其输入的梯度数。WGAN-GP所提出的方法比标准WGAN执行得更好,并且能够在几乎没有超参数调整的情况下稳定地训练各种GAN架构,包括101层ResNets和具有连续生成器的语言模型,包括在CIFAR-10和LSUN卧室上实现了高质量的生成器。

提示:在论文的附录中给出很多关于WGAN-GP和其他GAN进行对比的效果,更加说明了WGAN-GP在训练过程中更加的稳定,效果更好。 

 

 

 2.WGAN-GP的主要成就

  • 在玩具数据集上,我们展示了critic的权重削减是如何导致不可预料的行为的。

  • 提出了梯度惩罚(WGAN-GP),它不会遇到同样的问题,并且效果更加的稳定。

  • 展示了各种GAN架构的稳定训练和性能改进超权重剪裁、高质量图像生成和字符级GAN语言没有任何离散采样的模型。

3.权重限制的困难(Difficulties with weight constraints)

  • WGAN中的权重裁剪会导致优化困难,即使优化成功,最终的critic也可能具有无法控制的价值表面。
  • 实验使用了特定形式的权重约束(每个权重大小的硬剪裁),并且也尝试了其他权重约束(L2范数剪裁、权重归一化)以及软约束(L1和L2权重衰减),发现它们表现出类似的问题。
  • 在某种程度上,这些问题可以通过critic中的批处理规范化来缓解。然而,即使使用批量标(Batch Normalization)准化,我们也观察到WGAN的critic往往无法达成一致

         左图:使用(顶部)权重剪裁和(底部)梯度惩罚,训练WGAN的critic值曲面在玩具数据集上进行运算。经过权重裁剪训练的critic无法捕捉到数据分布的更高时刻。生成器固定在真实数据处加上高斯噪声处

        右图:在Swiss Roll数据集上训练期间,深度WGAN critic的梯度范数在使用权重剪裁时爆炸或消失,但在使用梯度惩罚。权重剪裁(顶部)将权重推向两个值(剪裁范围的极端),这与渐变惩罚(底部)不同。可以看到使用权重裁剪(weight clipping[-c,c])时,对去那种的初始化出现了两个极端。然而使用梯度惩罚的话,权重在整个的区间都是呈现一种连续的。

(1)WGAN-GP算法流程

 

 具体代码实现如下:

    for batch_idx,(data,_) in enumerate(dataLoader):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        #Train: Critic : max[critic(real)] - E[critic(fake)]
        loss_critic = 0
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(size = (cur_batch_size,Z_DIM,1,1),device=device)
            fake_img = gen(noise)
            #使用reshape主要是将最后的维度从[1,1,1,1]=>[1]
            critic_real = critic(data).reshape(-1)
            critic_fake = critic(fake_img).reshape(-1)

            gp = gradient_penality(critic,data,fake_img,device=device)

            loss_critic = -(torch.mean(critic_real)- torch.mean(critic_fake)) + LAMBDA_GP*gp
            opt_critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        #将维度从[1,1,1,1]=>[1]
        gen_fake = critic(fake_img).reshape(-1)
        #max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        loss_gen = -torch.mean(gen_fake)
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

(2)梯度消失(梯度弥散)和梯度爆炸

        由于权重约束和成本函数之间的相互作用,WGAN优化过程很困难,这导致梯度消失或梯度爆炸,而没有仔细调整剪裁阈值c。

        如上图所示:在Swiss Roll toy数据集上训练WGAN,在[0.1,0.01,0.001]中改变剪裁阈值c,并绘制关于连续激活层的临界损失梯度的范数。生成器和critic都是12层神经网络(多层感知机 - ReLU MLPs),没有批处理规范化(Batch Normalization)。

        由图可观察到,对于这些值中的每一个,随着在网络中不断进行训练,梯度要么呈指数增长,要么呈指数衰减。而WGAN-GP的方法产生了更稳定的梯度,既不会消失也不会爆炸,从而允许训练更复杂的网络。

(3)梯度惩罚(gradient penalty) 

        WGAN+GP:在WGAN的基础上加上梯度惩罚。公式的前一半是WGAN的critic的原始优化目标公式,后一半是梯度惩罚(GP)。


def gradient_penality(critic,real,fake,device='cpu'):
    """
    :param critic: 判别器模型
    :param real: 真实样本
    :param fake: 生成的样本
    :param device: 设备CUP or GPU
    :return:
    """
    BATCH_SIZE,C,H,W = real.shape
    alpha = torch.randn(size=(BATCH_SIZE,1,1,1)).repeat(1,C,H,W).to(device)
    interpolated_images=real*alpha + fake*(1-alpha)

    #计算判别器输出
    mixed_scores = critic(interpolated_images)
    #求导
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True
    )[0]
    gradient = gradient.view(gradient.shape[0],-1)
    gradient_norm = gradient.norm(2,dim = 1)
    gradient_penality = torch.mean((gradient_norm - 1)**2)
    return gradient_penality

        理解为什么权重裁剪在WGAN critic中存在问题,建议读者去看原论文的附录证明。

4.WGAN-GP网络结构

Pytorch框架实现DCGAN(比较容易理解)

5.数据集下载

链接:https://pan.baidu.com/s/1i_VU3aQpLkCx4Z5fhDVKHA 
提取码:79y3

6.WGAN-GP代码实现

提示:代码放在了Github上,本文的代码是参考下面这位博主写的,但是自己其中只是做了一下修改,并且其中加了一个mainWindows界面代码,方便后面训练的模型进行图像风格的转换。

参考博主的代码:https://b23.tv/QUc0CNb

本文的代码下载:https://github.com/KeepTryingTo/Pytorch-GAN

7.mainWindow窗口显示生成器生成的图片

提示:这里编写了一个显示生成器显示图片的程序(mainWindow.py),加载之前训练之后保存的生成器模型,之后可使用该模型进行随机生成图片,如下:

(1)运行mainWindow.py 初始界面如下

 

 点击随机生成图片:

 

 

8.模型下载 

链接:https://pan.baidu.com/s/1NyQzaxgX3B-pUAuwJWwCgQ 
提取码:39p5

您可能感兴趣的与本文相关的镜像

PyTorch 2.9

PyTorch 2.9

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值