GAN作为图像的另一个新领域,本成为21世纪最好的idea。嘿嘿,最近小试牛刀,下载了个WGAN的代码,这里简单分析下,给大家一个参考。
【提示】
本文预计阅读时间5分钟,带灰色底纹的和加粗的为重要部分哦!
(一)WGAN初识
(二)代码分析
2.1 main struct
打开代码后,它的主要结构如下图所示。
我们先看一下wgan_conv主函数,打开之后首先直接到最底main的位置,如下
if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# the dir of pic generated
sample_folder = 'Samples/mnist_wgan_conv'
if not os.path.exists(sample_folder):
os.makedirs(sample_folder)
# net param
generator = G_conv_mnist()
discriminator = D_conv_mnist()
# data param
data = mnist()
# run
wgan = WGAN(generator, discriminator, data)
wgan.train(sample_folder)
这里做几点阐述
1、首先创建了一个目录用来存储你的生成图像,程序会每隔一段时间输出一个图像。
2、搞了三个类,一个generater生成器网络,一个是discriminator判别器类,然后是数据类。
3、又声明一个对象WGAN网络,然后调用它的train函数
OK至此,主函数结构阐述清楚。那此时你会想generater咋定义?discriminator咋定义?
好一个一个看。
2.2 generator
generator是生成器网络,其实就是搭了一个上采样的网络,先将噪声输入一维向量,通过全连接到更多的数据,然后把它展开成二维的图像,这里我们先用的灰度,你也可以改成彩色。然后再上采样,随意搞得,反正最后你要上采样到和你的正样本图像维度一致。如下所示:
class G_conv_mnist(object):
def __init__(self):
self.name = 'G_conv_mnist'
def __call__(self, z):
with tf.variable_scope(self.name) as scope:
#step 1 全连接层,把z白噪声变为8*15*128图
g = tcl.fully_connected(z, 8*15*128, activation_fn = tf.nn.relu, normalizer_fn=tcl.batch_norm,
weights_initializer=tf.random_normal_initializer(0, 0.02))
g = tf.reshape(g, (-1, 8, 15, 128))
#step 2 反卷积/上采样 到16*30*64图 4代表卷积核大小
g = tcl.conv2d_transpose(g, 64, 4,stride=2,
activation_fn=tf.nn.relu, normalizer_fn=tcl.batch_norm, padding='SAME', weights_initializer=tf.random_normal_initializer(0, 0.02))
#step 3 反卷积/上采样 到32*60*1的图,此时和真实手写体的数据是一样的图
g = tcl.conv2d_transpose(g, 1, 4, stride=2,
activation_fn=tf.nn.sigmoid, padding='SAME', weights_initializer=tf.random_normal_initializer(0, 0.02))
print(g.shape)
return g
@property
def vars(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
注意:
这里你会看到一个call函数,它是咋用呢?
一个类下面有个call函数,你就可以生成一个对象后,直接把它当成方法用。例如
class G():
call(x):
print(x)
这样的话你就A = G(),然后再A(1)就打印了1。
其实就是说这个类弄好了,之后可以直接当函数用。
好,然后我们看一下discriminator
2.3 discriminator
和generator干了差不多的事情,他要把X和GX输进去,然后搭建一个卷积网络判别真假。
class D_conv_mnist(object):
def __init__(self):
self.name = 'D_conv_mnist'
def __call__(self, x, reuse=