使用pytorch构建一个初级的无监督的GAN网络模型

在这个系列中将系统的构建GAN及其相关的一些变种模型,来了解GAN的基本原理。本片为此系列的第一篇,实现起来很简单,所以不要期待有很好的效果出来。

第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN,使用MINIST数据集,包含0-9的60000张手写数字图像,如图:
在这里插入图片描述

原理

首先简单讲一下GAN的工作原理,如下为前向传播的过程:
在这里插入图片描述
GAN网络有两个模型,分别是生成器generator和判别器discriminator。generator的作用是生成图片的,也就是我们想要的结果,通过输入随机噪声来生成图片;而discriminator是判断输入的图片是真实数据还是生成的假数据,输入生成的假数据或真实数据,输出真与假的概率值。

而反向传播过程其实是分开的,即generator和discriminator是分别进行梯度更新的。且交替进行训练的,一个模型训练,另一个模型就要保持不变,保持两个模型的能力要相当才能一起进步,否则如果判别器的性能要比生成器要好的话就很容易陷入模式崩溃mdoel collapse或梯度消失等。
下图为discriminator的反向传播的过程:
在这里插入图片描述
discriminator的工作是为了将生成的假数据判别为0,将真实的数据判别为1,即公正判别非黑即白,所以loss的计算为:
在这里插入图片描述

下图为generator的反向传播的过程:
在这里插入图片描述
而generator的工作是为了将生成的假数据让discriminator判别为1,即骗过discriminator颠倒黑白,所以loss的计算为:
在这里插入图片描述

代码

下面开始直接上代码,我在网上学习别人代码的习惯是先把所有代码跑起来再来仔细看每个代码模块,我在这也就先放上所有代码再分析各个模块。
model.py:

from torch import nn
import torch

def get_generator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    )

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            get_generator_block(z_dim, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim * 2),
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid()
        )
    def forward(self, noise):
        return self.gen(noise)
    def get_gen(self)
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值