初识DCGAN
问:什么是DCGAN?
答:DC意为deep convolution,它把卷积神经网络应用在对抗生成网络中。
问:DCGAN相对于GAN做了哪些改变?
答:有以下几点:
(1)池化层pooling被卷积层convolution代替,网络结构中没有池化层。
具体而言,在生成模型中,允许卷积层代替池化层完成空间上采样的学习;
在判别模型中,允许卷积层代替池化层完成空间下采样的学习;
(2)在生成模型和判别模型中使用batchnorm。解决的问题是1)初始化差的问题;2)梯度消失、弥散等问题;3)防止生成模型把所有样本收敛于同一点;
(3)相比CNN移除了全连接层;
(4)使用激活函数不同,生成模型中出输出层使用tanh外,其他全部采用 Relu;判别模型全部采用Leaky ReLU。
DCGAN网络结构
1 G网络
100z代表一个100维的噪音向量,先通过一个简单的全连接层reshape成4X4X1024的特征图形式,再通过四层CONV层实现反卷积,最终输出一个64X64X3的图片。
2 D网络
代码讲解
1 数据及代码
数据:人脸数据 提取码:c8u3
卡通图像 提取码:6u6m
DCGAN代码提取码:umsx
2 代码讲解
(1)创建结构
main.py文件
dcgan = DCGAN(
sess,
input_width=FLAGS.input_width, //输入输出数据的大小
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size, //一次迭代用到图像的数量
c_dim=FLAGS.c_dim, //通道数,黑白为1,彩色为3
dataset_name=FLAGS.dataset, //数据集名字
input_fname_pattern=FLAGS.input_fname_pattern,
is_crop=FLAGS.is_crop, //是否进行crop
checkpoint_dir=FLAGS.checkpoint_dir, //存储模型参数的路径
sample_dir=FLAGS.sample_dir)
model.py文件
class DCGAN(object):
def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
batch_size=64, sample_num = 64, output_hei