条件GAN就是在GAN的基础上加入了一个条件y,在生成器和判别器中加入条件参与训练,这样训练出来的模型可以根据设置的条件生成想到的图,一般条件可以为label。CGAN的论文为:《Conditional Generative Adversarial Nets》。CGAN的结构图如下:
CGAN的实现只需要在GAN的基础上稍作修改即可,代码如下:
#coding=utf-8
import pickle
import tensorflow as tf
import numpy as np
import matplotlib.gridspec as gridspec
import os
import shutil
from scipy.misc import imsave
# 定义一个mnist数据集的类
class mnistReader():
def __init__(self,mnistPath,onehot=True):
self.mnistPath=mnistPath
self.onehot=onehot
self.batch_index=0
print ('read:',self.mnistPath)
fo = open(self.mnistPath, 'rb')
self.train_set,self.valid_set,self.test_set = pickle.load(fo,encoding='bytes')
fo.close()
self.data_label_train=list(zip(self.train_set[0],self.train_set[1]))
np.random.shuffle(self.data_label_train)
# 获取下一个训练集的batch
def next_train_batch(self,batch_size=100):
if self.batch_index < int(len(self.data_label_train)/batch_size):
# print ("batch_index:",self.batch_index )
datum=self.data_label_train[self.batch_index*batch_size:(self.batch_index+1)*batch_size]
self.batch_index+=1
return self._decode(datum,self.onehot)
else:
self.batch_index=0
np.random.shuffle(self.data_label_train)
datum=self.data_label_train[self.batch_index*batch_size:(self.batch_index+1)*batch_size]
self.batch_index+=1
return self._decode(datum,self.onehot)
# 获取测试集的数据
def test_data(self):
tdata,tlabel=self.test_set
data_label_test=list(zip(tdata,tlabel))
return self._decode(data_label_test,self.onehot)
# 把一个batch的训练数据转换为可以放入模型训练的数据