今天看到paperweekly上有人分享了一个WGAN-GP的实现,是以MNIST为数据集,代码简洁,结构清晰。我最近也在看GAN的相关内容,就下载下来做个参考。
代码地址:https://github.com/bojone/gan/
对于这个基于tensorflow实现的代码,我对其进行了简单的注释。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import numpy as np
from scipy import misc,ndimage
#读入本地的MNIST数据集,该函数为mnist专用
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)
batch_size = 100 #每个batch的大小
width,height = 28,28 #每张图片包含28*28个像素点
mnist_dim = width*height #用一个数字数组表示一张图,那么这个数组展开成向量的长度就是28*28=784
random_dim = 10 #每张图表示一个数字,从0到9
epochs = 1000000 #共100万轮
def my_init(size): #从[-0.05,0.05]的均匀分布中采样得到维度是size的输出
return tf.random_uniform(size, -