WGAN代码解读及实验总结

GAN作为图像的另一个新领域,本成为21世纪最好的idea。嘿嘿,最近小试牛刀,下载了个WGAN的代码,这里简单分析下,给大家一个参考。
【提示】
本文预计阅读时间5分钟,带灰色底纹的和加粗的为重要部分哦!

(一)WGAN初识

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

(二)代码分析

2.1 main struct

打开代码后,它的主要结构如下图所示。
在这里插入图片描述
我们先看一下wgan_conv主函数,打开之后首先直接到最底main的位置,如下

if __name__ == '__main__':

	os.environ['CUDA_VISIBLE_DEVICES'] = '0'
	# the dir of pic generated
	sample_folder = 'Samples/mnist_wgan_conv'
	if not os.path.exists(sample_folder):
		os.makedirs(sample_folder)

	# net param
	generator = G_conv_mnist()
	discriminator = D_conv_mnist()
	# data param
	data = mnist()

	# run
	wgan = WGAN(generator, discriminator, data)
	wgan.train(sample_folder)

这里做几点阐述

1、首先创建了一个目录用来存储你的生成图像,程序会每隔一段时间输出一个图像。
2、搞了三个类,一个generater生成器网络,一个是discriminator判别器类,然后是数据类。
3、又声明一个对象WGAN网络,然后调用它的train函数

OK至此,主函数结构阐述清楚。那此时你会想generater咋定义?discriminator咋定义?
好一个一个看。

2.2 generator

generator是生成器网络,其实就是搭了一个上采样的网络,先将噪声输入一维向量,通过全连接到更多的数据,然后把它展开成二维的图像,这里我们先用的灰度,你也可以改成彩色。然后再上采样,随意搞得,反正最后你要上采样到和你的正样本图像维度一致。如下所示:

class G_conv_mnist(object):
	def __init__(self):
		self.name = 'G_conv_mnist'

	def __call__(self, z):
		with tf.variable_scope(self.name) as scope:
			#step 1 全连接层,把z白噪声变为8*15*128图
			g = tcl.fully_connected(z, 8*15*128, activation_fn = tf.nn.relu, normalizer_fn=tcl.batch_norm,
									weights_initializer=tf.random_normal_initializer(0, 0.02))
			g = tf.reshape(g, (-1, 8, 15, 128)) 
			#step 2 反卷积/上采样 到16*30*64图    4代表卷积核大小
			g = tcl.conv2d_transpose(g, 64, 4,stride=2, 
									activation_fn=tf.nn.relu, normalizer_fn=tcl.batch_norm, padding='SAME', weights_initializer=tf.random_normal_initializer(0, 0.02))
			#step 3 反卷积/上采样 到32*60*1的图,此时和真实手写体的数据是一样的图
			g = tcl.conv2d_transpose(g, 1, 4, stride=2, 
										activation_fn=tf.nn.sigmoid, padding='SAME', weights_initializer=tf.random_normal_initializer(0, 0.02))
			print(g.shape)
			return g
	@property
	def vars(self):
		return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
	

注意:
这里你会看到一个call函数,它是咋用呢?
一个类下面有个call函数,你就可以生成一个对象后,直接把它当成方法用。例如
class G():
call(x):
print(x)
这样的话你就A = G(),然后再A(1)就打印了1。
其实就是说这个类弄好了,之后可以直接当函数用。

好,然后我们看一下discriminator

2.3 discriminator

和generator干了差不多的事情,他要把X和GX输进去,然后搭建一个卷积网络判别真假。

class D_conv_mnist(object):
	def __init__(self):
		self.name = 'D_conv_mnist'
	def __call__(self, x, reuse=
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值