# GAN: generate MNIST images with a simple fully connected GAN.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from skimage.io import imsave
import os
import shutil
# Load the MNIST data set (downloaded into ./data on first use).
# NOTE(review): 'mninst' is a typo for 'mnist'; kept as-is because train() references it.
mninst=input_data.read_data_sets('data')

# Image geometry of the MNIST samples.
img_height=28
img_width=28
img_size=img_height*img_width  # flattened pixel count per image

output_path='output' # root directory for generated samples and checkpoints
max_epoch=500   # number of training epochs

# Network hyper-parameters.
h1_size=150 # units in the generator's first hidden layer
h2_size=300 # units in the generator's second hidden layer
z_size=100  # dimensionality of the noise (latent) vector
batch_size=256

# Graph inputs: real image batch, noise batch, dropout keep probability.
x_data=tf.placeholder(tf.float32,[batch_size,img_size])
z_prior=tf.placeholder(tf.float32,[batch_size,z_size])
keep_prob=tf.placeholder(tf.float32)

def build_generator(z_prior):
    """Generator MLP: noise [batch, z_size] -> fake images [batch, img_size].

    Three fully connected layers (z_size -> h1_size -> h2_size -> img_size):
    ReLU on the hidden layers, tanh on the output so pixels land in [-1, 1].

    Returns:
        (x_generate, g_params): the generated batch and the list of the
        generator's trainable variables, in creation order.
    """
    layer_dims = [z_size, h1_size, h2_size, img_size]
    g_params = []
    act = z_prior
    for idx in range(3):
        # 'g_' prefix keeps generator variables distinguishable from D's.
        w = tf.Variable(
            tf.truncated_normal([layer_dims[idx], layer_dims[idx + 1]], stddev=0.1),
            name='g_w{}'.format(idx + 1))
        b = tf.Variable(tf.zeros([layer_dims[idx + 1]]), name='g_b{}'.format(idx + 1))
        pre = tf.nn.xw_plus_b(act, w, b)
        # Hidden layers use ReLU; the output layer uses tanh.
        act = tf.nn.relu(pre) if idx < 2 else tf.nn.tanh(pre)
        g_params.extend([w, b])
    return act, g_params

def build_discriminator(x_data, x_generate, keep_prob):
    """Discriminator MLP: images [batch, img_size] -> realness score in (0, 1).

    Real and fake batches are stacked and scored in one forward pass
    (img_size -> h2_size -> h1_size -> 1); the first two layers use
    ReLU + dropout, the last is a plain linear layer.

    Returns:
        (y_data, y_generated, d_params): sigmoid scores for the real and
        the generated halves of the stacked batch, and D's variables.
    """
    # Stack real on top of fake so one pass scores both halves.
    x_in = tf.concat([x_data, x_generate], 0)

    layer_dims = [img_size, h2_size, h1_size, 1]
    d_params = []
    act = x_in
    for idx in range(3):
        w = tf.Variable(
            tf.truncated_normal([layer_dims[idx], layer_dims[idx + 1]], stddev=0.1),
            name='d_w{}'.format(idx + 1))
        b = tf.Variable(tf.zeros([layer_dims[idx + 1]]), name='d_b{}'.format(idx + 1))
        pre = tf.nn.xw_plus_b(act, w, b)
        # Dropout only on the two hidden layers; the score layer stays linear.
        act = tf.nn.dropout(tf.nn.relu(pre), keep_prob) if idx < 2 else pre

        d_params.extend([w, b])

    # First batch_size rows are the real images, the rest are generated.
    y_data = tf.nn.sigmoid(tf.slice(act, [0, 0], [batch_size, -1]))
    y_generated = tf.nn.sigmoid(tf.slice(act, [batch_size, 0], [-1, -1]))
    return y_data, y_generated, d_params

def save_result(gen_val,filename,grid_size=(8,8),grid_pad=5):
    """Tile a batch of generated samples into one grid image and save it.

    Args:
        gen_val: array of shape [batch, img_size] with values in [-1, 1]
            (tanh output of the generator).
        filename: path of the image file to write.
        grid_size: (rows, cols) layout of the grid.
        grid_pad: pixel spacing between neighboring cells.
    """
    # Map from the tanh range [-1, 1] back to [0, 1].
    val_data = 0.5 * gen_val.reshape(gen_val.shape[0], img_height, img_width) + 0.5
    rows, cols = grid_size
    # Total canvas size: cells plus the padding between them.
    grid_h = img_height * rows + grid_pad * (rows - 1)
    grid_w = img_width * cols + grid_pad * (cols - 1)
    img_grid = np.zeros([grid_h, grid_w], np.uint8)
    for i, res in enumerate(val_data):
        if i >= rows * cols:
            break
        img = (res * 255).astype(np.uint8)
        # BUG FIX: the row index must be derived from the number of COLUMNS
        # (i // cols); the original used i // grid_size[0], which only
        # happened to work because the default grid is square.
        row = (i // cols) * (img_height + grid_pad)
        col = (i % cols) * (img_width + grid_pad)
        img_grid[row:row + img_height, col:col + img_width] = img
    imsave(filename, img_grid)


def train():
    """Build the GAN graph and run adversarial training on MNIST.

    Each step updates the discriminator and the generator once. After every
    epoch a grid of generated samples is written to disk (for eyeballing
    progress and deciding when to stop) and the model is checkpointed.
    """
    # BUG FIX: make sure the output directories exist before anything
    # writes to them; the original crashed on the first imsave/saver.save.
    for sub in ('sample', 'random', 'model'):
        os.makedirs(os.path.join(output_path, sub), exist_ok=True)

    # Generator: noise -> fake image batch.
    x_generated, g_params = build_generator(z_prior)
    # Discriminator scores for the real and the fake batches.
    y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)

    # Small epsilon keeps tf.log away from log(0) = -inf when the
    # discriminator saturates.
    eps = 1e-8
    # Generator loss: push D towards believing fakes are real.
    g_loss = -tf.reduce_mean(tf.log(y_generated + eps))
    # Discriminator loss: standard GAN cross-entropy on real + fake halves.
    d_loss = -tf.reduce_mean(tf.log(y_data + eps) + tf.log(1 - y_generated + eps))

    # Separate var_lists so each update touches only its own network.
    optimizer = tf.train.AdamOptimizer(0.0001)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)
    d_trainer = optimizer.minimize(d_loss, var_list=d_params)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        # Resume from the latest checkpoint if one exists.
        ckpt = tf.train.get_checkpoint_state(os.path.join(output_path, 'model'))
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        steps = mninst.train.num_examples // batch_size
        for epoch in range(max_epoch):
            for step in range(steps):
                # BUG FIX: draw fresh noise every step; the original reused
                # a single noise batch for the whole epoch, starving the
                # generator of input diversity.
                z_value = np.random.normal(0, 1, size=[batch_size, z_size])
                x_value, _ = mninst.train.next_batch(batch_size)
                # Rescale pixels from [0, 1] to [-1, 1] to match tanh output.
                x_value = 2 * x_value - 1

                feed = {x_data: x_value, z_prior: z_value, keep_prob: 0.7}
                # One D update then one G update per step.
                d_loss_, _ = sess.run([d_loss, d_trainer], feed_dict=feed)
                g_loss_, _ = sess.run([g_loss, g_trainer], feed_dict=feed)
                print('Epoch:{}-iter:{},d_loss:{},g_loss:{}'.format(
                    epoch, step, d_loss_, g_loss_))

            # Samples from the last training noise batch, saved so a human
            # can judge progress and decide when to stop training.
            x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_value})
            save_result(x_gen_val, os.path.join(output_path, 'sample/{}.jpg'.format(epoch)))
            # Samples from an independent noise batch ("test" noise).
            z_test_value = np.random.normal(0, 1, size=[batch_size, z_size])
            x_test_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
            save_result(x_test_gen_val, os.path.join(output_path, 'random/{}.jpg'.format(epoch)))
            saver.save(sess, os.path.join(output_path, 'model/gan.model'), global_step=epoch)


# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    train()