使用 GAN 生成 MNIST 手写数字

本文介绍了一个使用生成对抗网络(GAN)生成 MNIST 手写数字的项目。该项目基于 TensorFlow 1.x(tf.placeholder / tf.Session 风格的 API)实现,展示了如何训练 GAN 来生成逼真的手写数字样本。

GAN + MNIST

# -*- coding: utf-8 -*-
# @Time     : 2017/11/2 18:02
# @File     : MNISTGAN.py
# @Author   : Zhiwei Zhong
# @Function :
from __future__ import division
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt

# ---- Hyper-parameters -------------------------------------------------------
BATCH_SIZE = 64
LR_G = LR_D = 0.001  # Adam learning rates for the generator / discriminator
# NOTE(review): MEAN, SIGMA, TRAIN_STEP and PRINT_STEP are not referenced anywhere
# else in this script; they look like leftovers from an earlier experiment.
MEAN = [5, 5]
SIGMA = [[1, 0], [0, 1]]
TRAIN_STEP = 2000000
PRINT_STEP = 5000
EPOCHS = 30  # number of full passes over the MNIST training set

# ---- Reproducibility --------------------------------------------------------
# Seed both the TensorFlow graph-level RNG and NumPy's global RNG.
tf.set_random_seed(1)
np.random.seed(1)

# ---- Data -------------------------------------------------------------------
# Load MNIST from ./mnist with one-hot labels, and report the training-image array shape.
mnist = input_data.read_data_sets('./mnist', one_hot=True)
print(mnist.train.images.shape)
def fake_data(batch_size=None, dim=100):
    """Sample a batch of generator input noise z ~ N(0, 1).

    Args:
        batch_size: number of noise vectors to draw; defaults to BATCH_SIZE
            (the original fixed behaviour).
        dim: length of each noise vector; must match the width of the G_in
            placeholder (100).

    Returns:
        An ndarray of shape (batch_size, dim) with i.i.d. standard-normal values.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    # Drawing with size=(b, d) consumes the RNG in the same order as the old
    # normal(0, 1, b * d).reshape(b, d), so seeded runs are unchanged.
    return np.random.normal(0, 1, size=(batch_size, dim))

def noise():
    """Return a (5, 784) array of i.i.d. standard-normal values.

    Not called anywhere in this script.
    """
    rows, cols = 5, 784
    return np.random.normal(loc=0, scale=1, size=(rows, cols))

with tf.variable_scope("GEN"):
    # Generator: 100-d noise -> dense(128) + leaky ReLU + dropout -> dense(784) sigmoid.
    G_in = tf.placeholder(tf.float32, [None, 100], name="G_in")
    # tf.layers.dropout only drops units when `training` is truthy; its default is
    # False, so without this flag the dropout layer below was a silent no-op.
    # Dropout is active by default; feed {G_training: False} to disable it when sampling.
    G_training = tf.placeholder_with_default(True, shape=(), name="G_training")
    G_l1 = tf.layers.dense(G_in, 128, name="G_Layer1")
    # Leaky ReLU with slope 0.01.
    G_l1 = tf.maximum(0.01 * G_l1, G_l1)
    G_l1 = tf.layers.dropout(G_l1, rate=0.2, training=G_training)
    # Sigmoid keeps each of the 28*28 output pixels in (0, 1).
    G_out = tf.layers.dense(G_l1, 28 * 28, tf.nn.sigmoid, name="G_out")

def _disc_tower(x, reuse=None):
    """Apply the discriminator MLP to `x`: dense(128) + leaky ReLU -> dense(1) sigmoid.

    Must be called inside the DISC variable scope. Passing reuse=True makes the
    call share the D_Layer1 / D_Out weights created by an earlier call.

    Returns:
        (hidden, prob): the 128-unit activation and the sigmoid output.
    """
    hidden = tf.layers.dense(x, 128, name="D_Layer1", reuse=reuse)
    hidden = tf.maximum(hidden, 0.01 * hidden)  # leaky ReLU, slope 0.01
    prob = tf.layers.dense(hidden, 1, tf.nn.sigmoid, name="D_Out", reuse=reuse)
    return hidden, prob


with tf.variable_scope("DISC"):
    D_in = tf.placeholder(tf.float32, [None, 28*28], name="real_data")
    # Tower on real images: creates the DISC/D_Layer1 and DISC/D_Out variables.
    D_l1, D_Real_Out = _disc_tower(D_in)
    # Tower on generated images: reuses exactly the same variables.
    D_l3, D_Fake_Out = _disc_tower(G_out, reuse=True)

# Keeps tf.log away from log(0) = -inf (and the NaN gradients that follow) when a
# float32 sigmoid output saturates at exactly 0 or 1.
LOG_EPS = 1e-8

# Discriminator: maximise log D(x) + log(1 - D(G(z))), i.e. minimise the negation.
D_Loss = -tf.reduce_mean(tf.log(D_Real_Out + LOG_EPS) + tf.log(1. - D_Fake_Out + LOG_EPS))

# Generator: non-saturating loss -log D(G(z)). The original log(1 - D(G(z)))
# shares the same fixed point. However, its gradient vanishes early in training,
# when D confidently rejects the fake samples.
G_Loss = -tf.reduce_mean(tf.log(D_Fake_Out + LOG_EPS))

# Each optimizer only updates the trainable weights under its own variable scope,
# so the D step never moves the generator and vice versa.
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="DISC")
train_D = tf.train.AdamOptimizer(LR_D).minimize(D_Loss, var_list=disc_vars)

gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN")
train_G = tf.train.AdamOptimizer(LR_G).minimize(G_Loss, var_list=gen_vars)

import os

sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Only the generator's trainable weights are checkpointed.
saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))
# tf.train.Saver.save will not create missing parent directories, so make sure
# the checkpoint directory exists before the first save.
if not os.path.isdir("./checkpoints"):
    os.makedirs("./checkpoints")

samples = []  # first 5 generated images recorded at the end of every epoch
for e in range(EPOCHS):
    # One pass over the training set; the trailing partial batch is skipped (floor division).
    for _ in range(mnist.train.num_examples // BATCH_SIZE):
        b_x, _labels = mnist.train.next_batch(BATCH_SIZE)  # labels are unused by the GAN
        data_x = b_x.reshape((BATCH_SIZE, 784))
        data_z = fake_data()
        # One discriminator update, then one generator update on the same noise.
        sess.run(train_D, feed_dict={G_in: data_z, D_in: data_x})
        sess.run(train_G, feed_dict={G_in: data_z})

    saver.save(sess, "./checkpoints/generator.ckpt")
    # Report D outputs and both losses on the last training batch of the epoch.
    DiscFake, DiscReal, DLoss, GLoss = sess.run(
        [D_Fake_Out, D_Real_Out, D_Loss, G_Loss], {G_in: data_z, D_in: data_x})
    print("EPOCH:{}, DISC_FAKE:{},DISC_REAL:{}, D_LOSS:{}, G_LOSS:{}".format(
        e, np.mean(DiscFake), np.mean(DiscReal), np.mean(DLoss), np.mean(GLoss)))

    # Keep a few fresh samples per epoch so training progress can be inspected later.
    data_z = fake_data()
    pic = sess.run(G_out, {G_in: data_z})[:5]
    samples.append(pic)
import pickle

# Persist the per-epoch generator samples, then read them straight back.
# The file is written by this script itself, so unpickling it is trusted.
SAMPLES_FILE = 'train_samples.pkl'
with open(SAMPLES_FILE, 'wb') as f:
    pickle.dump(samples, f)
with open(SAMPLES_FILE, 'rb') as f:
    samples = pickle.load(f)


def view_samples(epoch, samples, n=5):
    """Show up to `n` images from ``samples[epoch]`` stacked in a single column.

    Args:
        epoch: index into `samples` selecting which batch of images to plot
            (for the per-epoch training record, the epoch number).
        samples: indexable collection of image batches; each image must hold
            28 * 28 values so it can be reshaped to 28x28.
        n: maximum number of images to plot (default 5, the original count).
            Batches with fewer images are shown in full instead of raising
            IndexError.
    """
    images = samples[epoch][:n]
    count = len(images)
    if count == 0:
        return
    # squeeze=False keeps `axes` 2-D even when count == 1, so indexing stays uniform.
    fig, axes = plt.subplots(figsize=(7, 7), nrows=count, ncols=1, squeeze=False)
    for ax, img in zip(axes[:, 0], images):
        ax.imshow(np.reshape(img, (28, 28)), cmap='gray')
        ax.set_xticks(())
        ax.set_yticks(())
    # Force non-interactive mode so plt.show() blocks until the window is closed.
    plt.ioff()
    plt.show()



# Generate new images from the saved generator checkpoint.
saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GEN"))
with tf.Session() as sess:
    # Only GEN variables are restored; evaluating G_out does not depend on DISC.
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    # Draw z from the same N(0, 1) prior the generator was trained on
    # (see fake_data); the previous uniform(-1, 1) draw did not match it.
    sample_noise = np.random.normal(0, 1, size=(25, 100))
    gen_samples = sess.run(G_out, {G_in: sample_noise})
# view_samples(epoch, samples) plots images from samples[epoch]. Wrap this single
# batch in a list so index 0 selects it. The old call view_samples(10, gen_samples)
# picked one 784-pixel row and then failed reshaping a scalar to 28x28.
view_samples(0, [gen_samples])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值