import os

import numpy as np
import tensorflow as tf
from skimage.io import imsave
from tensorflow.examples.tutorials.mnist import input_data
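# A minimal fully connected GAN on MNIST, written against the TensorFlow 1.x
# graph API (tf.placeholder / tf.Session); it will not run unmodified under
# TensorFlow 2.x, where both were removed.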
mnist = input_data.read_data_sets('data')
img_height = 28
img_width = 28
img_size = img_height * img_width
output_path = 'output'
max_epoch = 500
h1_size = 150  # first generator hidden layer (second discriminator hidden layer)
h2_size = 300  # second generator hidden layer (first discriminator hidden layer)
z_size = 100   # dimensionality of the noise vector z
batch_size = 256
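# Placeholders; batch_size is baked into the shapes because the discriminator
# slices its stacked real/fake output at the batch boundary below.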
x_data = tf.placeholder(tf.float32, [batch_size, img_size])
z_prior = tf.placeholder(tf.float32, [batch_size, z_size])
keep_prob = tf.placeholder(tf.float32)
def build_generator(z_prior):
    """Three-layer fully connected generator: z -> h1_size -> h2_size -> img_size."""
    w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name='g_w1')
    b1 = tf.Variable(tf.zeros([h1_size]), name='g_b1')
    h1 = tf.nn.relu(tf.nn.xw_plus_b(z_prior, w1, b1))
    w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name='g_w2')
    b2 = tf.Variable(tf.zeros([h2_size]), name='g_b2')
    h2 = tf.nn.relu(tf.nn.xw_plus_b(h1, w2, b2))
    w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], stddev=0.1), name='g_w3')
    b3 = tf.Variable(tf.zeros([img_size]), name='g_b3')
    # tanh keeps generated pixels in [-1, 1], matching the rescaled real images.
    x_generate = tf.nn.tanh(tf.nn.xw_plus_b(h2, w3, b3))
    g_params = [w1, b1, w2, b2, w3, b3]
    return x_generate, g_params
def build_discriminator(x_data, x_generate, keep_prob):
    """Score real and generated batches in one pass.

    Real and fake images are concatenated along the batch axis so both go
    through the same weights; the sigmoid outputs are then split back apart.
    """
    x_in = tf.concat([x_data, x_generate], 0)
    w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name='d_w1')
    b1 = tf.Variable(tf.zeros([h2_size]), name='d_b1')
    h1 = tf.nn.dropout(tf.nn.relu(tf.nn.xw_plus_b(x_in, w1, b1)), keep_prob)
    w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name='d_w2')
    b2 = tf.Variable(tf.zeros([h1_size]), name='d_b2')
    h2 = tf.nn.dropout(tf.nn.relu(tf.nn.xw_plus_b(h1, w2, b2)), keep_prob)
    w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name='d_w3')
    b3 = tf.Variable(tf.zeros([1]), name='d_b3')
    h3 = tf.nn.xw_plus_b(h2, w3, b3)
    # The first batch_size rows score the real images, the rest the generated ones.
    y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1]))
    y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1]))
    d_params = [w1, b1, w2, b2, w3, b3]
    return y_data, y_generated, d_params
def save_result(gen_val, filename, grid_size=(8, 8), grid_pad=5):
    """Tile the first grid_size[0] * grid_size[1] samples into one image file."""
    # Map tanh output from [-1, 1] back to [0, 1].
    val_data = 0.5 * gen_val.reshape(gen_val.shape[0], img_height, img_width) + 0.5
    grid_h = img_height * grid_size[0] + grid_pad * (grid_size[0] - 1)
    grid_w = img_width * grid_size[1] + grid_pad * (grid_size[1] - 1)
    img_grid = np.zeros([grid_h, grid_w], np.uint8)
    for i, res in enumerate(val_data):
        if i >= grid_size[0] * grid_size[1]:
            break
        img = (res * 255).astype(np.uint8)
        # The row index must divide by the number of columns (grid_size[1]);
        # dividing by grid_size[0] only happens to work for square grids.
        row = (i // grid_size[1]) * (img_height + grid_pad)
        col = (i % grid_size[1]) * (img_width + grid_pad)
        img_grid[row:row + img_height, col:col + img_width] = img
    imsave(filename, img_grid)
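# The losses below are the standard non-saturating GAN objectives: the
# discriminator minimizes -(log D(x) + log(1 - D(G(z)))) and the generator
# minimizes -log D(G(z)) rather than the saturating log(1 - D(G(z))).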
def train():
    x_generated, g_params = build_generator(z_prior)
    y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)
    # A small epsilon inside the logs guards against log(0) when the
    # discriminator saturates; reduce_mean turns the losses into scalars.
    g_loss = -tf.reduce_mean(tf.log(y_generated + 1e-8))
    d_loss = -tf.reduce_mean(tf.log(y_data + 1e-8) + tf.log(1 - y_generated + 1e-8))
    optimizer = tf.train.AdamOptimizer(0.0001)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)
    d_trainer = optimizer.minimize(d_loss, var_list=d_params)
    # Make sure the output directories exist before writing into them.
    for sub in ('sample', 'random', 'model'):
        os.makedirs(os.path.join(output_path, sub), exist_ok=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(os.path.join(output_path, 'model'))
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        steps = mnist.train.num_examples // batch_size
        for i in range(max_epoch):
            # One fixed noise batch per epoch, so the per-epoch sample grids
            # stay comparable across the epoch's updates.
            z_value = np.random.normal(0, 1, size=[batch_size, z_size])
            for j in range(steps):
                x_value, _ = mnist.train.next_batch(batch_size)
                x_value = 2 * x_value - 1  # rescale [0, 1] pixels to [-1, 1]
                # Alternate one discriminator update with one generator update.
                d_loss_, _ = sess.run([d_loss, d_trainer],
                                      feed_dict={x_data: x_value, z_prior: z_value, keep_prob: 0.7})
                g_loss_, _ = sess.run([g_loss, g_trainer],
                                      feed_dict={x_data: x_value, z_prior: z_value, keep_prob: 0.7})
                print('Epoch:{}-iter:{},d_loss:{},g_loss:{}'.format(i, j, d_loss_, g_loss_))
            x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_value})
            save_result(x_gen_val, os.path.join(output_path, 'sample/{}.jpg'.format(i)))
            z_test_value = np.random.normal(0, 1, size=[batch_size, z_size])
            x_test_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
            save_result(x_test_gen_val, os.path.join(output_path, 'random/{}.jpg'.format(i)))
            saver.save(sess, os.path.join(output_path, 'model/gan.model'), global_step=i)
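# Everything lands under output/: per-epoch grids from the epoch's fixed noise
# in output/sample, grids from fresh noise in output/random, and model
# checkpoints in output/model.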
if __name__ == '__main__':
    train()