在这个系列中将系统的构建GAN及其相关的一些变种模型,来了解GAN的基本原理。本片为此系列的第一篇,实现起来很简单,所以不要期待有很好的效果出来。
第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN,使用MINIST数据集,包含0-9的60000张手写数字图像,如图:
原理
首先简单讲一下GAN的工作原理,如下为前向传播的过程:
GAN网络有两个模型,分别是生成器generator和判别器discriminator。generator的作用是生成图片的,也就是我们想要的结果,通过输入随机噪声来生成图片;而discriminator是判断输入的图片是真实数据还是生成的假数据,输入生成的假数据或真实数据,输出真与假的概率值。
而反向传播过程其实是分开的,即generator和discriminator是分别进行梯度更新的。且交替进行训练的,一个模型训练,另一个模型就要保持不变,保持两个模型的能力要相当才能一起进步,否则如果判别器的性能要比生成器要好的话就很容易陷入模式崩溃mdoel collapse或梯度消失等。
下图为discriminator的反向传播的过程:
discriminator的工作是为了将生成的假数据判别为0,将真实的数据判别为1,即公正判别非黑即白,所以loss的计算为:
下图为generator的反向传播的过程:
而generator的工作是为了将生成的假数据让discriminator判别为1,即骗过discriminator颠倒黑白,所以loss的计算为:
代码
下面开始直接上代码,我在网上学习别人代码的习惯是先把所有代码跑起来再来仔细看每个代码模块,我在这也就先放上所有代码再分析各个模块。
model.py:
from torch import nn
import torch
def get_generator_block(input_dim, output_dim):
return nn.Sequential(
nn.Linear(input_dim, output_dim),
nn.BatchNorm1d(output_dim),
nn.ReLU(inplace=True),
)
class Generator(nn.Module):
def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
super(Generator, self).__init__()
self.gen = nn.Sequential(
get_generator_block(z_dim, hidden_dim),
get_generator_block(hidden_dim, hidden_dim * 2),
get_generator_block(hidden_dim * 2, hidden_dim * 4),
get_generator_block(hidden_dim * 4, hidden_dim * 8),
nn.Linear(hidden_dim * 8, im_dim),
nn.Sigmoid()
)
def forward(self, noise):
return self.gen(noise)
def get_gen(self)