代码地址:https://github.com/gongpx20069/CycleGAN-TensorFlow
这是Van Huy巨佬的代码,做一个学习巨佬CycleGAN代码的小笔记。CycleGAN的一个巨大的优点就是不需要X和Y两个域(相互转化的两个域)有一一对应的关系。
不得不说神经网络这种东西很消耗GPU内存,显卡内存决定了网络层数、输入图片大小这些很重要的东西。
总体代码笔记
整体来看,大佬的代码有如下的文件:
- sample:一些已经训练好的图片示例;
- bulid_data.py:用于将data目录下的trainA和trainB转化为data/tfrecords中的.tfrecords文件,方便网络读取;
- discriminator.py:定义了判别器的类;
- download_dataset.sh:下载数据集,查看具体代码可以看出它是在项目中新建了data文件,下载斑马马,苹果橘子的数据集的压缩包,并且解压;解压后data中会有trainA和trainB(必要),以及testA和testB的文件(非必要)。
- export_graph.py:将保存的模型(checkpoints)发布为.pb这样的模型文件,比如apple2orange.pb等;
- generator.py:定义了生成器的类;
- inference.py:用于使用模型文件(.pb,以及发布的模型文件)来测试将图片X变为图片Y,或将图片Y变为图片X;
- model.py:CycleGAN的具体模型参数,这里引用了(generator.py)生成器类和(discriminator.py)判别器类来分别实体化G生成器,以及F生成器,以及D(X)判别器,以及D(Y)判别器;
- ops.py:即operations,指tensorflow中的具体操作,比如可视化、具体的神经网络某一层;
- reader.py:读入tfrecords文件的类
- train.py:规定了训练批次,每批图片数量,图片大小(256*256)等,可以从checkpoints继续训练;
- utils.py:定义了两种函数,一种是将图片从像素点[0, 255]转化为[-1, 1],另一种刚好相反,将[-1, 1]转化为[0, 255];并且使用tf.map_fn函数来批处理这两种函数;
具体模型搭建笔记
由于是在看作者到底如和搭建成的CycleGAN,我们的思路就应当随着train的过程慢慢深入。
目前来看,作者模型的搭建依赖关系是:
ops(神经网络某一层)
->generator(生成器类别)|discriminator(判别器类)
->model(CycleGAN具体模型)
->train(训练的批次等参数)
因此我们的学习顺序也是同一个方向:从ops.py到train.py
1.0 ops.py
ops.py定义了如下几个函数:
1. def c7s1_k(input, k, reuse=False, norm=‘instance’, activation=‘relu’, is_training=True, name=‘c7s1_k’)
函数的作用:
首先为输入图片左右都填充3条边,再用一个773的过滤器,步长为1,将结果先通过normal再用激活函数(tanh或者relu)输出,输出结果的深度为k。
输入参数的解释为:
- input:输入是一个4D-Tensor,即一批图像;
- k:输出的深度,也是过滤器最后一个参数;
- reuse:tf.variable_scope函数中的一个参数,一般来讲reuse=tf.AUTO_REUSE;或者在该命名域再次被使用时为resue = True;
- norm:可以选择"instance"或者"batch",分别代表instance_normal和batch_normal;
- activation:该卷积层在输出时的激活函数"relu"或者"tanh";
- is_training: