GAN + MNIST
# -*- coding: utf-8 -*-
# @Time : 2017/11/2 18:02
# @File : MNISTGAN.py
# @Author : Zhiwei Zhong
# @Function : train a simple fully connected GAN on MNIST
from __future__ import division
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
tf.set_random_seed(1)
np.random.seed(1)
BATCH_SIZE = 64
LR_G = 0.001
LR_D = 0.001
# Unused leftovers (MEAN/SIGMA look like remnants of a Gaussian-toy GAN):
MEAN = [5, 5]
SIGMA = [[1, 0], [0, 1]]
TRAIN_STEP = 2000000
PRINT_STEP = 5000
EPOCHS = 30
mnist = input_data.read_data_sets('./mnist', one_hot=True)
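# read_data_sets downloads MNIST into ./mnist on the first run; one_hot=True
# makes the labels 10-dim one-hot vectors, though this unconditional GAN never uses them.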
print(mnist.train.images.shape)
def fake_data():
    # latent vectors z ~ N(0, 1), one 100-dim vector per sample in the batch
    return np.random.normal(0, 1, BATCH_SIZE * 100).reshape(BATCH_SIZE, 100)
    # return np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))

def noise():
    # pixel-space noise for 5 images; defined but never used below
    return np.random.normal(0, 1, 5 * 784).reshape(5, 784)
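# Either a standard normal or a uniform prior on z is common in practice; the
# commented-out uniform line above is a drop-in alternative. Whichever prior is
# used for training must also be used when sampling from the trained generator.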
with tf.variable_scope("GEN"):
    G_in = tf.placeholder(tf.float32, [None, 100], name="G_in")
    G_l1 = tf.layers.dense(G_in, 128, name="G_Layer1")
    # leaky ReLU with slope 0.01
    G_l1 = tf.maximum(0.01 * G_l1, G_l1)
    G_l1 = tf.layers.dropout(G_l1, rate=0.2, training=True)  # dropout is a no-op unless training=True
    # G_l2 = tf.layers.dense(G_l1, 128, tf.nn.sigmoid, name="G_Layer2")
    G_out = tf.layers.dense(G_l1, 28 * 28, tf.nn.sigmoid, name="G_out")
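# The sigmoid output keeps generated pixels in [0, 1], matching the unscaled
# MNIST images fed to the discriminator below. If the inputs were rescaled to
# [-1, 1] (see the commented-out line in the training loop), tanh would be the
# matching output activation instead.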
with tf.variable_scope("DISC"):
    # discriminator applied to real MNIST images
    D_in = tf.placeholder(tf.float32, [None, 28 * 28], name="real_data")
    D_l1 = tf.layers.dense(D_in, 128, name="D_Layer1")
    D_l1 = tf.maximum(D_l1, 0.01 * D_l1)  # leaky ReLU
    # D_l1 = tf.layers.dropout(D_l1, 0.2)
    # D_l2 = tf.layers.dense(D_l1, 64, tf.nn.tanh, name="D_Layer2")
    D_Real_Out = tf.layers.dense(D_l1, 1, tf.nn.sigmoid, name="D_Out")
    # the same discriminator applied to generated images; reuse=True shares the weights
    D_l3 = tf.layers.dense(G_out, 128, name="D_Layer1", reuse=True)
    D_l3 = tf.maximum(D_l3, 0.01 * D_l3)
    # D_l4 = tf.layers.dense(D_l3, 64, tf.nn.relu, name="D_Layer2", reuse=True)
    D_Fake_Out = tf.layers.dense(D_l3, 1, tf.nn.sigmoid, name="D_Out", reuse=True)
# small epsilon keeps tf.log away from log(0), which would produce NaNs
EPS = 1e-8
D_Loss = -tf.reduce_mean(tf.log(D_Real_Out + EPS) + tf.log(1. - D_Fake_Out + EPS))
G_Loss = tf.reduce_mean(tf.log(1. - D_Fake_Out + EPS))  # minimax generator loss: minimize log(1 - D(G(z)))
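# A common alternative (a sketch, not what this script trains with) is the
# non-saturating generator loss, which gives stronger gradients early in
# training when the discriminator rejects fakes with high confidence:
# G_Loss = -tf.reduce_mean(tf.log(D_Fake_Out + EPS))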
# each optimizer only updates its own network's variables, selected by scope
train_D = tf.train.AdamOptimizer(LR_D).minimize(
    D_Loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="DISC"))
train_G = tf.train.AdamOptimizer(LR_G).minimize(
    G_Loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# only the generator's variables are checkpointed; they are all we need for sampling later
saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))
# f, a = plt.subplots(1, 5, figsize=(5, 2))
# plt.ion()
samples = []
for e in range(EPOCHS):
    for _ in range(mnist.train.num_examples // BATCH_SIZE):  # floor division: whole batches only
        b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
        b_x = b_x.reshape((BATCH_SIZE, 784))
        # data_x = b_x * 2 - 1  # rescale to [-1, 1] (would require a tanh generator output)
        data_z = fake_data()
        # b_x[:5] = sess.run(G_out, {G_in: data_z})[:5]
        data_x = b_x
        # alternate: one discriminator step, then one generator step
        sess.run(train_D, feed_dict={G_in: data_z, D_in: data_x})
        sess.run(train_G, feed_dict={G_in: data_z})
    saver.save(sess, "./checkpoints/generator.ckpt")
    DiscFake, DiscReal, DLoss, GLoss = sess.run(
        [D_Fake_Out, D_Real_Out, D_Loss, G_Loss], {G_in: data_z, D_in: data_x})
    print("EPOCH: {}, DISC_FAKE: {}, DISC_REAL: {}, D_LOSS: {}, G_LOSS: {}".format(
        e, np.mean(DiscFake), np.mean(DiscReal), DLoss, GLoss))
    data_z = fake_data()
    pic = sess.run(G_out, {G_in: data_z})[:5]
"""try:
for i in range(5):
a[i].clear()
a[i].imshow(np.reshape(pic[i], (28, 28)), cmap='gray')
a[i].set_xticks(())
a[i].set_yticks(())
plt.draw()
plt.pause(0.01)
except:
pass"""
samples.append(pic)
import pickle

# record the sampled generator outputs from every epoch
with open('train_samples.pkl', 'wb') as f:
    pickle.dump(samples, f)
# reload immediately to demonstrate how the samples can be read back later
with open('train_samples.pkl', 'rb') as f:
    samples = pickle.load(f)
def view_samples(epoch, samples):
    """
    epoch: index of the training epoch whose images to show
    samples: the list of per-epoch sample batches recorded above
    """
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=1)
    for i in range(5):
        axes[i].imshow(np.reshape(samples[epoch][i], (28, 28)), cmap='gray')
    plt.show()
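# For example, view_samples(-1, samples) shows five generator outputs from the
# last training epoch; view_samples(0, samples) shows the (noisy) first epoch.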
# generate new images from the saved generator checkpoint
saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    # sample from the same N(0, 1) prior the generator was trained with
    sample_noise = np.random.normal(0, 1, size=(25, 100))
    gen_samples = sess.run(G_out, {G_in: sample_noise})
# wrap gen_samples in a list so view_samples' samples[epoch][i] indexing works
view_samples(0, [gen_samples])
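
# A minimal sketch (view_grid is an assumed helper, not part of the original
# script) that shows all 25 generated digits in a 5x5 grid rather than the
# single column that view_samples draws:
def view_grid(images, rows=5, cols=5):
    fig, axes = plt.subplots(figsize=(7, 7), nrows=rows, ncols=cols)
    for ax, img in zip(axes.flatten(), images):
        ax.imshow(np.reshape(img, (28, 28)), cmap='gray')
        ax.set_xticks(())
        ax.set_yticks(())
    plt.show()

view_grid(gen_samples)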