生成式对抗网络简介
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。(参考百度百科)
简单来说,GAN通过生成器和判别器的相互博弈,首先训练判别器辨别图片真伪,然后给以生成器真图片的数据不断训练,从而使得生成器可以达到以假乱真的效果,来骗过判别器;通过二者的反复博弈和训练后,生成器最终输出与真实值效果相近的图片。
原始 GAN 理论中,并不要求 G 和 D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D 。一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。
话不多说我们直接看代码
代码实战
模型整体搭建
首先在构造函数内搭建我们GAN的整体框架,这里看不懂没关系,我会在训练模块详细讲解模型的训练过程。
class GAN():
def __init__(self):
# --------------------------------- #
# 行28,列28,也就是mnist的shape
# --------------------------------- #
self.img_rows = 28
self.img_cols = 28
self.channels = 1
# 28,28,1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
# adam优化器
optimizer = Adam(0.0002, 0.5)
# 判别器
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
# 生成器
self.generator = self.build_generator()
gan_input = Input(shape=(self.latent_dim,))
img = self.generator(gan_input)
# 在训练generate的时候不训练discriminator
self.discriminator.trainable = False
# 对生成的假图片进行预测
validity = self.discriminator(img)
self.combined = Model(gan_input, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
生成器模型
输入:
输入维度自拟为100长度的随机噪声序列。
生成器采用3个全连接的隐层:
激活函数:α为0.2的LeakyRelu激活函数
批标准化:设置移动均值和方差动量为0.8
输出:
输出维度与图片维度大小保持一致
def build_generator(self):
# --------------------------------- #
# 生成器,输入一串随机数字
# --------------------------------- #
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
# np.prod计算所有元素的乘积,保证生成输出的维度与图片维度一致
model.add(Reshape