https://github.com/HyeongseokSon1/SRFeat
- config.py
# easydict is commonly used in deep-learning code to hold a global configuration object
from easydict import EasyDict as edict
import json
config = edict()
config.TRAIN = edict()
## Adam
config.TRAIN.batch_size = 9
config.TRAIN.lr_init = 1e-4
config.TRAIN.lr_decay = 0.1
config.TRAIN.beta1 = 0.9
# various log locations
config.TRAIN.checkpoint = 'checkpoint/'
config.TRAIN.save_valid_results = 'result_valid'
config.TRAIN.summary_g = 'summary_init'
config.TRAIN.summary_adv = 'summary'
## train set location
config.TRAIN.hr_img_path = 'SRdataset/DIV_train_cropped/GT/' #need to change
config.TRAIN.lr_img_path = 'DIV_train_cropped/LR_bicubic/' #need to change
config.VALID = edict()
## validation set location
config.VALID.hr_img_path = 'SRdataset/valid/GT/' #need to change
config.VALID.lr_img_path = 'SRdataset/valid/LR_bicubic/' #need to change
config.TEST = edict()
config.TEST.checkpoint = 'models/'
config.TEST.input_path = 'SRdataset/test/LR_bicubic/' #need to change
config.TEST.save_path = 'result_test'
def log_config(filename, cfg):
with open(filename, 'w') as f:
f.write("================================================\n")
f.write(json.dumps(cfg, indent=4))
f.write("\n================================================\n")
- main_gan_eval.py
def modcrop(imgs, modulo): # crop an image so its height and width are divisible by modulo
tmpsz = imgs.shape
sz = tmpsz[0:2]
h = sz[0] - sz[0]%modulo # drop the height remainder
w = sz[1] - sz[1]%modulo # drop the width remainder
imgs = imgs[0:h, 0:w, :] # slice to exactly h x w pixels
return imgs
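A quick worked example of modcrop, assuming a 339x510 RGB array and modulo 4:
import numpy as np
img = np.zeros((339, 510, 3))
print(modcrop(img, 4).shape) # (336, 508, 3): both sides rounded down to a multiple of 4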
def read_all_imgs(img_list, path='', n_threads=32):
""" 通过指定路径和每个图像文件的名称返回数组中的所有图像 """
imgs = []
for idx in range(0, len(img_list), n_threads):
b_imgs_list = img_list[idx : idx + n_threads]
b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=path)
# print(b_imgs.shape)
imgs.extend(b_imgs)
print('read %d from %s' % (len(imgs), path))
return imgs
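A hedged usage sketch (get_imgs_fn is the loader defined elsewhere in the repo; the path comes from config):
hr_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
hr_imgs = read_all_imgs(hr_list[:64], path=config.TRAIN.hr_img_path, n_threads=32) # loads in two batches of 32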
def evaluate():
## create a folder to save the result images
save_dir = config.TEST.save_path
tl.files.exists_or_mkdir(save_dir)
checkpoint_dir = config.TEST.checkpoint
im_path_lr = config.TEST.input_path
###====================== PRE-LOAD DATA ===========================###
valid_lr_img_list = sorted(tl.files.load_file_list(path=im_path_lr, regx='.*.*', printable=False)) # sorted list of every file in the input folder
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
tl.layers.initialize_global_variables(sess)
t_image = tf.placeholder('float32', [None, None, None, 3], name='input_image')
net_g = SRGAN_g(t_image, is_train=False, reuse=False)
###========================== RESTORE G =============================###
tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/SRFeat_full.npz', network=net_g) # load the pretrained generator weights
for imid in range(len(valid_lr_img_list)):
valid_lr_img = get_imgs_fn(valid_lr_img_list[imid],im_path_lr)
print(valid_lr_img.shape)
valid_lr_img = (valid_lr_img / 127.5) - 1 # rescale to [-1, 1]
###======================= EVALUATION =============================###
start_time = time.time() # timing
out = sess.run(net_g.outputs, {t_image: [valid_lr_img]}) # run the generator to produce the SR image
print("took: %4.4fs" % (time.time() - start_time))
print("LR size: %s / generated HR size: %s" % (valid_lr_img.shape, out.shape)) # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
print("[*] save images")
tl.vis.save_image(out[0], save_dir+'/' + valid_lr_img_list[imid]) # save the image
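Note the generator output stays in [-1, 1], matching the input rescaling above; to convert it to an 8-bit image by hand (a sketch, not what the script does):
import numpy as np
sr_uint8 = ((out[0] + 1) * 127.5).clip(0, 255).astype(np.uint8) # map [-1, 1] back to [0, 255]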
- main_gan_init.py
def train():
## create folders to save result images and trained model
checkpoint_dir = config.TRAIN.checkpoint
tl.files.exists_or_mkdir(checkpoint_dir)
###====================== PRE-LOAD DATA ===========================###
train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.bmp', printable=False))
valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.bmp', printable=False))
###========================== DEFINE MODEL ============================###
## train inference
t_image = tf.placeholder('float32', [batch_size, patch_size_l, patch_size_l, 3], name='t_image_input')
t_target_image = tf.placeholder('float32', [batch_size, patch_size_h, patch_size_h, 3], name='t_target_image')
net_g= SRGAN_g(t_image, is_train=True, reuse=False)
## test inference
t_sample_image = tf.placeholder('float32', [5, 56, 56, 3], name='t_sample_image')
net_g_test = SRGAN_g(t_sample_image, is_train=False, reuse=True)
# ###========================== DEFINE TRAIN OPS ==========================###
mse_loss = tl.cost.mean_squared_error(net_g.outputs , t_target_image, is_mean=True) # MSE loss for the initial (pretraining) stage
tf.summary.scalar('mse_loss', mse_loss)
merged = tf.summary.merge_all()
g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
with tf.variable_scope('learning_rate'):
lr_v = tf.Variable(lr_init, trainable=False)
## Pretrain
g_optim_init_= tf.train.AdamOptimizer(lr_v, beta1=beta1)
g_optim_init = g_optim_init_.minimize(mse_loss, var_list=g_vars)
###========================== RESTORE MODEL =============================###
sess = tf.Session(config=tf.ConfigProto( log_device_placement=False))
train_writer = tf.summary.FileWriter(config.TRAIN.summary_g,sess.graph)
train_writer.add_graph(sess.graph)
tl.layers.initialize_global_variables(sess)
if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}.npz'.format('SRFeat'), network=net_g) is False:
tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}_init.npz'.format('SRFeat'), network=net_g)
###============================= TRAINING ===============================###
## use the first few validation images for a quick preview during training
sample_imgs = read_all_imgs(valid_hr_img_list[0:5], path=config.VALID.hr_img_path, n_threads=5) # if no pre-load train set
sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=scale_imgs_fn)
sample_imgs_96 = read_all_imgs(valid_hr_img_list[0:5], path=config.VALID.lr_img_path, n_threads=5)
sample_imgs_96 = tl.prepro.threading_data(sample_imgs_96, fn=scale_imgs_fn)
val_mset = tl.cost.mean_squared_error(net_g_test.outputs , sample_imgs_384, is_mean=True)
val_summary = tf.summary.scalar('val_mse', val_mset)
###========================= INITIALIZE G ====================###
## fixed learning rate
sess.run(tf.assign(lr_v, lr_init))
lr_hr_list = list(zip(train_hr_img_list,train_lr_img_list))
random.shuffle(lr_hr_list)
train_hr_img_list, train_lr_img_list = zip(*lr_hr_list)
i = 0
print(" ** fixed learning rate: %f (for init G)" % lr_init)
for epoch in range(0, n_epoch_init+1):
# epoch_time = time.time()
total_mse_loss, n_iter = 0, 0
if epoch !=0 and (epoch % decay_every == 0):
new_lr_decay = lr_decay ** (epoch // decay_every)
sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
print(log)
elif epoch == 0:
sess.run(tf.assign(lr_v, lr_init))
log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
print(log)
for idx in range(0, len(train_hr_img_list) -batch_size , batch_size):
step_time = time.time()
b_imgs_list = train_hr_img_list[idx : idx + batch_size]
b_imgs_list_lr = train_lr_img_list[idx : idx + batch_size]
b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
b_imgs_96 = tl.prepro.threading_data(b_imgs_list_lr, fn=get_imgs_fn, path=config.TRAIN.lr_img_path)
b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=scale_imgs_fn)
b_imgs_96 = tl.prepro.threading_data(b_imgs_96, fn=scale_imgs_fn)
## update G
summary,errM, out, _ = sess.run([merged,mse_loss,net_g.outputs, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
train_writer.add_summary(summary,i)
print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
total_mse_loss += errM
n_iter += 1
## quick evaluation on train set
if (i != 0) and (i % 20 == 0):
out,val_,val_summ = sess.run([net_g_test.outputs,val_mset,val_summary], {t_sample_image: sample_imgs_96})#; print('gen sub-image:', out.shape, out.min(), out.max())
train_writer.add_summary(val_summ,i)
print("validate")
# save model
if (i != 0) and (i % 100 == 0):
tl.files.save_npz(net_g.all_params, name=checkpoint_dir+'/g_{}_init.npz'.format('SRFeat'), sess=sess)
i = i + 1
train_writer.close()
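With the config defaults (lr_init = 1e-4, lr_decay = 0.1), the schedule above drops the learning rate tenfold every decay_every epochs; decay_every is defined elsewhere in the script, so the value below is only illustrative:
# assuming decay_every = 20:
# epochs  0-19 : 1e-4
# epochs 20-39 : 1e-4 * 0.1**1 = 1e-5
# epochs 40-59 : 1e-4 * 0.1**2 = 1e-6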
- model.py
def SRGAN_g(t_image, is_train=False, reuse=False):
""" Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
feature maps (n) and stride (s)
"""
# weight and BatchNorm initializers
w_init = tf.random_normal_initializer(stddev=0.01)
b_init = None
g_init = tf.ones_initializer()
lrelu = lambda x : tl.act.lrelu(x, 0.2)
with tf.variable_scope("SRGAN_g", reuse=reuse) as vss:
tl.layers.set_name_reuse(reuse)
n = InputLayer(t_image, name='in')
# the blue block in the paper's architecture figure (first 9x9 conv)
n = Conv2d(n, 128, (9, 9), (1, 1), act=None, padding='SAME', W_init=w_init, name='n64s1/c')
temp = []
# B residual blocks
for i in range(16):
nn = Conv2d(n, 128, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c1/%s' % i)
nn = BatchNormLayer(nn, act=lrelu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
nn = Conv2d(nn, 128, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c2/%s' % i)
nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
nn = ElementwiseLayer([n, nn], tf.add, 'b_residual_add/%s' % i)
n = nn
# a 1x1 conv is appended after every residual block; the results are collected in temp for the long skip connections
t = Conv2d(nn, 128, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c3/%s' % i)
temp.append(t)
n = ElementwiseLayer([n] + temp[:15], tf.add, 'add3') # sum the trunk output and the per-block 1x1 conv features (temp[15] is unused)
# B residual blocks end
# the last two small blocks in the figure
# first upsampling (x2)
n = Conv2d(n, 512, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
n = SubpixelConv2d(n, scale=2, act=lrelu, n_out_channel=None, name='pixelshufflerx2/1')
# second upsampling (x2)
n = Conv2d(n, 512, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init ,name='n256s1/2')
n = SubpixelConv2d(n, scale=2, act=lrelu, n_out_channel=None, name='pixelshufflerx2/2')
# the final blue block: output conv
n = Conv2d(n, 3, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='out')
return n
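The two SubpixelConv2d(scale=2) stages make the generator a fixed x4 upscaler, which matches the eval log above (339x510 LR -> 1356x2040 HR). A minimal shape-check sketch (placeholder name hypothetical):
lr_in = tf.placeholder('float32', [1, 96, 96, 3], name='lr_in')
sr_net = SRGAN_g(lr_in, is_train=False, reuse=False)
print(sr_net.outputs.get_shape()) # (1, 384, 384, 3): 96 * 4 = 384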
def SRGAN_d(t_image, is_train=True, reuse=False):
""" Discriminator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
feature maps (n) and stride (s)
"""
# w_init = tf.truncated_normal_initializer(stddev=0.01)
w_init = tf.contrib.layers.variance_scaling_initializer()
b_init = tf.constant_initializer(value=0.0)
# g_init = tf.random_normal_initializer(1., 0.02)
g_init = tf.ones_initializer()
lrelu = lambda x : tl.act.lrelu(x, 0.2)
with tf.variable_scope("SRGAN_d", reuse=reuse) as vs:
tl.layers.set_name_reuse(reuse)
n = InputLayer(t_image, name='in')
# first block
n = Conv2d(n, 64, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c')
n = Conv2d(n, 64, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n64s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n64s2/b')
n = Conv2d(n, 128, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n128s1/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n128s1/b')
n = Conv2d(n, 128, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n128s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n128s2/b')
n = Conv2d(n, 256, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n256s1/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n256s1/b')
n = Conv2d(n, 256, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n256s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n256s2/b')
n = Conv2d(n, 512, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n512s1/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n512s1/b')
n = Conv2d(n, 512, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n512s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n512s2/b')
# final dense layers
n = FlattenLayer(n, name='f')
n = DenseLayer(n, n_units=1024,act=lrelu, name='d1024')
n = DenseLayer(n, n_units=1,name='out')
logits = n.outputs
n.outputs = tf.nn.sigmoid(n.outputs)
return n, logits
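The pre-sigmoid logits are returned alongside the network because tl.cost.sigmoid_cross_entropy expects raw logits (it applies the sigmoid internally, which is numerically stabler); main_gan_train.py consumes the pair like this:
net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
_, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True) # second call shares weights via reuse
d1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real)) # real -> 1
d2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake)) # fake -> 0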
The remaining networks are VGG-related; the paper does not spell out their exact architectures.
def SRGAN_d2(input_images, is_train=True, reuse=False):
w_init = tf.contrib.layers.variance_scaling_initializer()
b_init = None
gamma_init = tf.ones_initializer()
df_dim = 64
lrelu = lambda x: tl.act.lrelu(x, 0.2)
with tf.variable_scope("SRGAN_d", reuse=reuse):
tl.layers.set_name_reuse(reuse)
net_in = InputLayer(input_images, name='input/images')
net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lrelu,
padding='SAME', W_init=w_init, name='h0/c')
net_h1 = Conv2d(net_h0, df_dim*2, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='h1/c')
net_h1 = BatchNormLayer(net_h1, act=lrelu, is_train=is_train,
gamma_init=gamma_init, name='h1/bn')
net_h2 = Conv2d(net_h1, df_dim*4, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='h2/c')
net_h2 = BatchNormLayer(net_h2, act=lrelu, is_train=is_train,
gamma_init=gamma_init, name='h2/bn')
net_h3 = Conv2d(net_h2, df_dim*8, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='h3/c')
net_h3 = BatchNormLayer(net_h3, act=lrelu, is_train=is_train,
gamma_init=gamma_init, name='h3/bn')
net_h4 = Conv2d(net_h3, df_dim*16, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='h4/c')
net_h4 = BatchNormLayer(net_h4, act=lrelu, is_train=is_train,
gamma_init=gamma_init, name='h4/bn')
net_h5 = Conv2d(net_h4, df_dim*32, (4, 4), (2, 2), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='h5/c')
net_h5 = BatchNormLayer(net_h5, act=lrelu, is_train=is_train,
gamma_init=gamma_init, name='h5/bn')
net_h6 = Conv2d(net_h5, df_dim*16, (1, 1), (1, 1), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='h6/c')
net_h6 = BatchNormLayer(net_h6, act=lrelu, is_train=is_train,
gamma_init=gamma_init, name='h6/bn')
net_h7 = Conv2d(net_h6, df_dim*8, (1, 1), (1, 1), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='h7/c')
net_h7 = BatchNormLayer(net_h7, is_train=is_train,
gamma_init=gamma_init, name='h7/bn')
net = Conv2d(net_h7, df_dim*2, (1, 1), (1, 1), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='res/c')
net = BatchNormLayer(net, act=lrelu, is_train=is_train,
gamma_init=gamma_init, name='res/bn')
net = Conv2d(net, df_dim*2, (3, 3), (1, 1), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='res/c2')
net = BatchNormLayer(net, act=lrelu, is_train=is_train,
gamma_init=gamma_init, name='res/bn2')
net = Conv2d(net, df_dim*8, (3, 3), (1, 1), act=None,
padding='SAME', W_init=w_init, b_init=b_init, name='res/c3')
net = BatchNormLayer(net, is_train=is_train,
gamma_init=gamma_init, name='res/bn3')
net_h8 = ElementwiseLayer(layer=[net_h7, net],
combine_fn=tf.add, name='res/add')
net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2)
net_ho = FlattenLayer(net_h8, name='ho/flatten')
net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity,
W_init = w_init, name='ho/dense')
logits = net_ho.outputs
net_ho.outputs = tf.nn.sigmoid(net_ho.outputs)
return net_ho, logits
def SRGAN_vgg_d2(t_image, is_train=True, reuse=False):
# w_init = tf.truncated_normal_initializer(stddev=0.01)
w_init = tf.contrib.layers.variance_scaling_initializer()
b_init = tf.constant_initializer(value=0.0)
# g_init = tf.random_normal_initializer(1., 0.02)
g_init = tf.ones_initializer()
feat_dim = 64
lrelu = lambda x : tl.act.lrelu(x, 0.2)
with tf.variable_scope("SRGAN_vgg_d", reuse=reuse) as vs:
tl.layers.set_name_reuse(reuse)
n = InputLayer(t_image, name='in')
n = Conv2d(n, feat_dim, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c')
n = Conv2d(n, feat_dim, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n64s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n64s2/b')
n = Conv2d(n, feat_dim*2, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n128s1/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n128s1/b')
n = Conv2d(n, feat_dim*2, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n128s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n128s2/b')
n = Conv2d(n, feat_dim*4, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n256s1/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n256s1/b')
n = Conv2d(n, feat_dim*4, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n256s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n256s2/b')
n = Conv2d(n, feat_dim*8, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n512s1/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n512s1/b')
n = Conv2d(n, feat_dim*8, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n512s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n512s2/b')
# network = Conv2d(network, n_filter=1024, filter_size=(3, 3),
# strides=(1,1), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init,name='conv6')
# network = Conv2d(network, n_filter=1024, filter_size=(3, 3),
# strides=(1,1), act=tf.nn.relu, padding='SAME', W_init=w_init, b_init=b_init,name='conv7')
# network = Conv2d(network, n_filter=1, filter_size=(3, 3),
# strides=(1,1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='out')
# logits = network.outputs
# network.outputs = tf.nn.sigmoid(network.outputs)
n = Conv2d(n, 1024, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='d1024')
n = Conv2d(n, 1, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='out')
# n = FlattenLayer(n, name='f')
# n = DenseLayer(n, n_units=1024,act=lrelu, name='d1024')
# n = DenseLayer(n, n_units=1,name='out')
logits = n.outputs
n.outputs = tf.nn.sigmoid(n.outputs)
return n, logits
def SRGAN_vgg_d(t_image, is_train=True, reuse=False):
# w_init = tf.truncated_normal_initializer(stddev=0.01)
w_init = tf.contrib.layers.variance_scaling_initializer()
b_init = tf.constant_initializer(value=0.0)
# g_init = tf.random_normal_initializer(1., 0.02)
g_init = tf.ones_initializer()
feat_dim = 64
lrelu = lambda x : tl.act.lrelu(x, 0.2)
with tf.variable_scope("SRGAN_vgg_d", reuse=reuse) as vs:
tl.layers.set_name_reuse(reuse)
n = InputLayer(t_image, name='in')
n = Conv2d(n, feat_dim, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c')
n = Conv2d(n, feat_dim, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n64s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n64s2/b')
n = Conv2d(n, feat_dim*2, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n128s1/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n128s1/b')
n = Conv2d(n, feat_dim*2, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n128s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n128s2/b')
n = Conv2d(n, feat_dim*4, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n256s1/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n256s1/b')
n = Conv2d(n, feat_dim*4, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n256s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n256s2/b')
n = Conv2d(n, feat_dim*8, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n512s1/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n512s1/b')
n = Conv2d(n, feat_dim*8, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n512s2/c')
n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n512s2/b')
n = FlattenLayer(n, name='f')
n = DenseLayer(n, n_units=1024,act=lrelu, name='d1024')
n = DenseLayer(n, n_units=1,name='out')
logits = n.outputs
n.outputs = tf.nn.sigmoid(n.outputs)
return n, logits
def Vgg19_simple_api(rgb, reuse):
"""
Build the VGG 19 Model
Parameters
-----------
rgb : rgb image placeholder [batch, height, width, 3], values scaled to [0, 1]
"""
VGG_MEAN = [103.939, 116.779, 123.68]
with tf.variable_scope("VGG19", reuse=reuse) as vs:
start_time = time.time()
print("build model started")
rgb = tf.maximum(0.0,tf.minimum(rgb,1.0))
rgb_scaled = rgb * 255.0
# Convert RGB to BGR
if tf.__version__ <= '0.11': # old tf.split signature: tf.split(axis, num_split, value)
red, green, blue = tf.split(3, 3, rgb_scaled)
else: # TF 1.0
# print(rgb_scaled)
red, green, blue = tf.split(rgb_scaled, 3, 3)
# assert red.get_shape().as_list()[1:] == [224, 224, 1]
# assert green.get_shape().as_list()[1:] == [224, 224, 1]
# assert blue.get_shape().as_list()[1:] == [224, 224, 1]
if tf.__version__ <= '0.11':
bgr = tf.concat(3, [
blue - VGG_MEAN[0],
green - VGG_MEAN[1],
red - VGG_MEAN[2],
])
else:
bgr = tf.concat([
blue - VGG_MEAN[0],
green - VGG_MEAN[1],
red - VGG_MEAN[2],
], axis=3)
# assert bgr.get_shape().as_list()[1:] == [224, 224, 3]
""" input layer """
net_in = InputLayer(bgr, name='input')
""" conv1 """
network = Conv2d(net_in, n_filter=64, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv1_1')
network = Conv2d(network, n_filter=64, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv1_2')
network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
padding='SAME', name='pool1')
""" conv2 """
network = Conv2d(network, n_filter=128, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv2_1')
network = Conv2d(network, n_filter=128, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv2_2')
conv2 = network
network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
padding='SAME', name='pool2')
""" conv3 """
network = Conv2d(network, n_filter=256, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_1')
network = Conv2d(network, n_filter=256, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_2')
network = Conv2d(network, n_filter=256, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_3')
network = Conv2d(network, n_filter=256, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_4')
network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
padding='SAME', name='pool3')
conv3 = network
""" conv4 """
network = Conv2d(network, n_filter=512, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_1')
network = Conv2d(network, n_filter=512, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_2')
network = Conv2d(network, n_filter=512, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_3')
network = Conv2d(network, n_filter=512, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_4')
conv4 = network
network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
padding='SAME', name='pool4') # (batch_size, 14, 14, 512)
""" conv5 """
network = Conv2d(network, n_filter=512, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_1')
network = Conv2d(network, n_filter=512, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_2')
network = Conv2d(network, n_filter=512, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_3')
network = Conv2d(network, n_filter=512, filter_size=(3, 3),
strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_4')
conv5 = network
network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
padding='SAME', name='pool5') # (batch_size, 7, 7, 512)
""" fc 6~8 """
# network = FlattenLayer(network, name='flatten')
# network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc6')
# network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc7')
# network = DenseLayer(network, n_units=1000, act=tf.identity, name='fc8')
print("build model finished: %fs" % (time.time() - start_time))
return network, conv5
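A small numpy check of the preprocessing above: a pixel with RGB = (1.0, 0.5, 0.0) is scaled by 255, reordered to BGR, and mean-subtracted:
import numpy as np
VGG_MEAN = [103.939, 116.779, 123.68] # B, G, R channel means
r, g, b = 1.0 * 255.0, 0.5 * 255.0, 0.0 * 255.0
print(np.array([b - VGG_MEAN[0], g - VGG_MEAN[1], r - VGG_MEAN[2]])) # [-103.939 10.721 131.32]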
- main_gan_train.py (differs from main_gan_init.py above only in the loss functions)
def train():
## create folders to save result images and the trained models
save_dir_gan = config.TRAIN.save_valid_results
tl.files.exists_or_mkdir(save_dir_gan)
checkpoint_dir = config.TRAIN.checkpoint
tl.files.exists_or_mkdir(checkpoint_dir)
###====================== LOAD DATA ===========================###
train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
###========================== DEFINE MODEL ============================###
## train inference
t_image = tf.placeholder('float32', [batch_size, patch_size_l, patch_size_l, 3], name='t_image_input_to_generator')
t_target_image = tf.placeholder('float32', [batch_size, patch_size_h, patch_size_h, 3], name='t_target_image')
# Generator
net_g= SRGAN_g(t_image, is_train=False, reuse=False)
# Discriminator
net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
_, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)
# VGG network
net_vgg, vgg_target_emb= Vgg19_simple_api((t_target_image+1)/2, reuse=False)
_, vgg_predict_emb = Vgg19_simple_api((net_g.outputs+1)/2, reuse=True)
# Feature Discriminator
vgg_scale5 = 1/12.75 # rescale the VGG feature maps by 1/12.75, as in the SRGAN paper
net_vgg_d, logits_vgg_real = SRGAN_vgg_d(vgg_scale5*vgg_target_emb.outputs, is_train=True, reuse=False)
_, logits_vgg_fake = SRGAN_vgg_d(vgg_scale5*vgg_predict_emb.outputs, is_train=True, reuse=True)
#########
## validation samples: the generator reused with is_train=False
net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)
###========================== DEFINE TRAIN OPS ==========================###
# Discriminator losses
d_loss1 = 1e-3 *(tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')) # cross entropy: real logits vs. label 1
d_loss2 = 1e-3 *(tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')) # fake logits vs. label 0
d_vgg_loss1 = 1e-3 *(tl.cost.sigmoid_cross_entropy(logits_vgg_real, tf.ones_like(logits_vgg_real), name='d_vgg_1'))
d_vgg_loss2 = 1e-3 *(tl.cost.sigmoid_cross_entropy(logits_vgg_fake, tf.zeros_like(logits_vgg_fake), name='d_vgg_2'))
d_loss = d_loss1 + d_loss2 + d_vgg_loss1 + d_vgg_loss2
## summaries for TensorBoard
d_loss1_summary = tf.summary.scalar('d_loss1', d_loss1)
d_loss2_summary = tf.summary.scalar('d_loss2', d_loss2)
d_vgg_loss1_summary = tf.summary.scalar('d_vgg_loss1', d_vgg_loss1)
d_vgg_loss2_summary = tf.summary.scalar('d_vgg_loss2', d_vgg_loss2)
merged_d = tf.summary.merge([d_loss1_summary, d_loss2_summary,d_vgg_loss1_summary,d_vgg_loss2_summary])
##
#GAN loss
g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g') # same cross entropy, but fake logits vs. label 1
g_gan_vgg_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_vgg_fake, tf.ones_like(logits_vgg_fake), name='g_vgg')
#vgg loss
vgg_loss = tl.cost.mean_squared_error(vgg_scale5*vgg_predict_emb.outputs, vgg_scale5*vgg_target_emb.outputs, is_mean=True) # perceptual loss: MSE on the rescaled VGG conv5 feature maps
## summaries
vgg_summary = tf.summary.scalar('vgg_loss', vgg_loss)
g_gan_summary = tf.summary.scalar('g_gan_loss', g_gan_loss)
g_gan_vgg_summary = tf.summary.scalar('g_gan_vgg_loss', g_gan_vgg_loss)
merged_g = tf.summary.merge([vgg_summary, g_gan_summary, g_gan_vgg_summary]) # merged for TensorBoard
##
#Total loss
g_loss = vgg_loss + g_gan_loss + g_gan_vgg_loss
g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)
d_vgg_vars = tl.layers.get_variables_with_name('SRGAN_vgg_d', True, True)
with tf.variable_scope('learning_rate'):
lr_v = tf.Variable(lr_init, trainable=False)
## optimizers for the SRGAN generator and the two discriminators
g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1, epsilon=1e-10).minimize(g_loss, var_list=g_vars)
d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1, epsilon=1e-10).minimize(d_loss, var_list=d_vars + d_vgg_vars) # flat variable list; both discriminators are updated together
###========================== RESTORE MODEL =============================###
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
summary_writer = tf.summary.FileWriter(config.TRAIN.summary_adv,sess.graph)
summary_writer.add_graph(sess.graph)
tl.layers.initialize_global_variables(sess)
if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}.npz'.format('SRFeat'), network=net_g) is False:
tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}_init.npz'.format('SRFeat'), network=net_g)
tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/d_{}.npz'.format('SRFeat'), network=net_d)
###============================= LOAD VGG ===============================###
vgg19_npy_path = "vgg19.npy"
if not os.path.isfile(vgg19_npy_path):
print("Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg")
exit()
npz = np.load(vgg19_npy_path, encoding='latin1').item()
params = []
count_layers = 0
for val in sorted( npz.items() ): # keys sort alphabetically: conv1_1 ... conv5_4, then fc6/fc7/fc8
if(count_layers<16): # VGG-19 has 2+2+4+4+4 = 16 conv layers; the fc weights that follow are skipped
W = np.asarray(val[1][0])
b = np.asarray(val[1][1])
print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
params.extend([W, b])
count_layers += 1
tl.files.assign_params(sess, params, net_vgg)
###============================= TRAINING ===============================###
## use first `batch_size` of train set to have a quick test during training
sample_list = [0,10001,20001,30001,40001,50001,60001,70001,80001]
sample_imgs = read_all_imgs([train_hr_img_list[k] for k in sample_list], path=config.TRAIN.hr_img_path, n_threads=batch_size) # if no pre-load train set
sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=scale_imgs_fn)
sample_imgs_96 = read_all_imgs([train_lr_img_list[k] for k in sample_list], path=config.TRAIN.lr_img_path, n_threads=batch_size)
sample_imgs_96 = tl.prepro.threading_data(sample_imgs_96, fn=scale_imgs_fn)
tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan+'/_train_sample_96.png')
tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan+'/_train_sample_384.png')
###========================= TRAIN GAN (SRGAN) =========================###
iters = 0
lr_hr_list = list(zip(train_hr_img_list,train_lr_img_list))
random.shuffle(lr_hr_list)
train_hr_img_list, train_lr_img_list = zip(*lr_hr_list)
for epoch in range(0, n_epoch+1):
## gradually decay the learning rate
if epoch !=0 and (epoch % decay_every == 0):
new_lr_decay = lr_decay ** (epoch // decay_every)
sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
print(log)
elif epoch == 0:
sess.run(tf.assign(lr_v, lr_init))
log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
print(log)
epoch_time = time.time()
total_d_loss, total_g_loss, n_iter = 0, 0, 0
for idx in range(0, len(train_hr_img_list)-batch_size, batch_size):
iters = iters+1
b_imgs_list = train_hr_img_list[idx : idx + batch_size]
b_imgs_list_lr = train_lr_img_list[idx : idx + batch_size]
b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
b_imgs_96 = tl.prepro.threading_data(b_imgs_list_lr, fn=get_imgs_fn, path=config.TRAIN.lr_img_path)
b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=scale_imgs_fn)
b_imgs_96 = tl.prepro.threading_data(b_imgs_96, fn=scale_imgs_fn)
step_time = time.time()
###########
summary_d, errD, errD1, errD2, errD3, errD4, _ = sess.run([merged_d, d_loss, d_loss1, d_loss2, d_vgg_loss1, d_vgg_loss2, d_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
## update G
summary_g, errG, errV, errA, errA2, _ = sess.run([merged_g, g_loss, vgg_loss, g_gan_loss, g_gan_vgg_loss, g_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
## write summaries
summary_writer.add_summary(summary_d, iters)
summary_writer.add_summary(summary_g, iters)
print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f (d1: %.8f, d2: %.8f, d3_vgg: %.8f, d4_vgg: %.8f), g_loss: %.8f (vgg: %.6f adv: %.6f adv2: %.6f)" % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errD1, errD2, errD3, errD4, errG, errV, errA, errA2))
total_d_loss += errD
total_g_loss += errG
n_iter += 1
## quick evaluation on the train set
if (iters != 0) and (iters % 100 == 0):
out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})#; print('gen sub-image:', out.shape, out.min(), out.max())
print("[*] save images")
tl.vis.save_images(out, [ni, ni], save_dir_gan+'/train_%d.png' % iters)
## save the models
if (iters != 0) and (iters % 1000 == 0):
tl.files.save_npz(net_g.all_params, name=checkpoint_dir+'/g_{}.npz'.format('SRFeat'), sess=sess)
if (iters != 0) and (iters % 10000 == 0):
tl.files.save_npz(net_d.all_params, name=checkpoint_dir+'/d_{}.npz'.format('SRFeat'), sess=sess)
log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss/n_iter, total_g_loss/n_iter)
print(log)
summary_writer.close()