一.DCGAN(Deep Convolutional GAN)
依旧有Generator,Discriminator,可使用MNSIT训练生成图片。
和GAN的不同:
1.增加了反卷积,能生成更好的图片,但依旧还是2分类,判断是不是手写数字。
2.采用了BatchNormalization 防止梯度消失和过拟合
代码引用的《Web安全之强化学习与GAN》,位置:
https://github.com/duoergun0729/3book/tree/master/code/keras-dcgan.py
生成器G代码:
def generator_model():
model = Sequential()
model.add(Dense(input_dim=100, units=1024))
model.add(Activation('tanh'))
model.add(Dense(128*7*7))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(64, (5, 5), padding='same'))
model.add(Activation('tanh'))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(1, (5, 5), padding='same'))
model.add(Activation('tanh'))
return model
判别器D代码:
def discriminator_model():
model = Sequential()
model.add(
Conv2D(64, (5, 5),
padding='same',
input_shape=(28, 28, 1))
)
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (5, 5)))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation('tanh'))
model.add(Dense(1))
model.add(Activation('sigmoid'))
return model
训练图: