SRFeat 的代码理解

https://github.com/HyeongseokSon1/SRFeat

  1. config.py
# easydict (edict) is commonly used in deep-learning projects to build a global,
# attribute-accessible configuration object.

from easydict import EasyDict as edict
import json

config = edict()
config.TRAIN = edict()

## Adam
config.TRAIN.batch_size = 9
config.TRAIN.lr_init = 1e-4
config.TRAIN.lr_decay = 0.1
config.TRAIN.beta1 = 0.9

# various log location
config.TRAIN.checkpoint = 'checkpoint/'
config.TRAIN.save_valid_results = 'result_valid' 
config.TRAIN.summary_g = 'summary_init'
config.TRAIN.summary_adv = 'summary'

## train set location
config.TRAIN.hr_img_path = 'SRdataset/DIV_train_cropped/GT/' #need to change
config.TRAIN.lr_img_path = 'DIV_train_cropped/LR_bicubic/' #need to change

config.VALID = edict()
## validation set location
config.VALID.hr_img_path = 'SRdataset/valid/GT/' #need to change
config.VALID.lr_img_path = 'SRdataset/valid/LR_bicubic/' #need to change

config.TEST = edict()
config.TEST.checkpoint = 'models/'
config.TEST.input_path = 'SRdataset/test/LR_bicubic/' #need to change
config.TEST.save_path = 'result_test'


def log_config(filename, cfg):
    with open(filename, 'w') as f:
        f.write("================================================\n")
        f.write(json.dumps(cfg, indent=4))
        f.write("\n================================================\n")

  2. main_gan_eval.py
def  modcrop(imgs, modulo):#将图片剪切成可以整除modulo的尺寸

    tmpsz = imgs.shape
    sz = tmpsz[0:2]

    h = sz[0] - sz[0]%modulo #减去余数
    w = sz[1] - sz[1]%modulo #减去余数
    imgs = imgs[0:h+1, 0:w+1,:]
    return imgs

def read_all_imgs(img_list, path='', n_threads=32):
    """ 通过指定路径和每个图像文件的名称返回数组中的所有图像 """
    imgs = []
    for idx in range(0, len(img_list), n_threads):
        b_imgs_list = img_list[idx : idx + n_threads]
        b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=path)
        # print(b_imgs.shape)
        imgs.extend(b_imgs)
        print('read %d from %s' % (len(imgs), path))
    return imgs

def evaluate():
    ## 创建文件夹来保存结果图像
    save_dir = config.TEST.save_path
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = config.TEST.checkpoint
    im_path_lr = config.TEST.input_path

    ###====================== 预先加载数据 ===========================###
    valid_lr_img_list = sorted(tl.files.load_file_list(path=im_path_lr, regx='.*.*', printable=False))#排序
    
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    
    t_image = tf.placeholder('float32', [None, None, None, 3], name='input_image')
      
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/SRFeat_full.npz', network=net_g)#已有的模型,读取
    
    
    for imid in range(len(valid_lr_img_list)):
        valid_lr_img = get_imgs_fn(valid_lr_img_list[imid],im_path_lr)

        print(valid_lr_img.shape)

        valid_lr_img = (valid_lr_img / 127.5) - 1   # rescale to [-1, 1]归一化
        ###======================= EVALUATION =============================###
        start_time = time.time()# 时间
        out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})# 执行模型;输出图像
            
        print("took: %4.4fs" % (time.time() - start_time))
    
        print("LR size: %s /  generated HR size: %s" % (valid_lr_img.shape, out.shape)) # LR size: (339, 510, 3) /  gen HR size: (1, 1356, 2040, 3)
        print("[*] save images")
        tl.vis.save_image(out[0], save_dir+'/' + valid_lr_img_list[imid])    #保存图片
  3. main_gan_init.py
ef train():
    ## create folders to save result images and trained model
    checkpoint_dir = config.TRAIN.checkpoint 
    tl.files.exists_or_mkdir(checkpoint_dir)
    ###====================== 预先加载数据 ===========================###
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.bmp', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.bmp', printable=False))


    ###========================== 定义模型 ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, patch_size_l, patch_size_l, 3], name='t_image_input')
    t_target_image = tf.placeholder('float32', [batch_size, patch_size_h, patch_size_h, 3], name='t_target_image')

    net_g= SRGAN_g(t_image, is_train=True, reuse=False)

    ## test inference
    t_sample_image = tf.placeholder('float32', [5, 56, 56, 3], name='t_sample_image')
    net_g_test = SRGAN_g(t_sample_image, is_train=False, reuse=True)

    # ###========================== 定义训练 OPS ==========================###
    mse_loss = tl.cost.mean_squared_error(net_g.outputs , t_target_image, is_mean=True)#用MSE进行初始训练
    tf.summary.scalar('mse_loss', mse_loss)
    merged = tf.summary.merge_all()

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init_= tf.train.AdamOptimizer(lr_v, beta1=beta1)
    g_optim_init = g_optim_init_.minimize(mse_loss, var_list=g_vars)
    ###========================== 恢复模型 =============================###
    sess = tf.Session(config=tf.ConfigProto( log_device_placement=False))
    train_writer = tf.summary.FileWriter(config.TRAIN.summary_g,sess.graph)
    train_writer.add_graph(sess.graph)
    tl.layers.initialize_global_variables(sess)
    if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}.npz'.format('SRFeat'), network=net_g) is False:
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}_init.npz'.format('SRFeat'), network=net_g)

    ###============================= 训练 ===============================###
    ## 使用训练集的第一个' batch_size '在训练期间进行快速测试
    sample_imgs = read_all_imgs(valid_hr_img_list[0:5], path=config.VALID.hr_img_path, n_threads=5) # if no pre-load train set
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=scale_imgs_fn)
    
    sample_imgs_96 = read_all_imgs(valid_hr_img_list[0:5], path=config.VALID.lr_img_path, n_threads=5)
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_96, fn=scale_imgs_fn)
   
    val_mset = tl.cost.mean_squared_error(net_g_test.outputs , sample_imgs_384, is_mean=True)  
    val_summary = tf.summary.scalar('val_mse', val_mset)
    
    ###========================= 初始化 G ====================###
    ## 固定的学习速率
    sess.run(tf.assign(lr_v, lr_init))
    
    lr_hr_list = list(zip(train_hr_img_list,train_lr_img_list))
    random.shuffle(lr_hr_list)
    train_hr_img_list, train_lr_img_list = zip(*lr_hr_list)
    
    i =0
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init+1):
#        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0
 
        if epoch !=0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay ** (epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
            print(log)
      
        for idx in range(0, len(train_hr_img_list) -batch_size , batch_size):        
            step_time = time.time()
            b_imgs_list = train_hr_img_list[idx : idx + batch_size]
            b_imgs_list_lr = train_lr_img_list[idx : idx + batch_size]

            b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_list_lr, fn=get_imgs_fn, path=config.TRAIN.lr_img_path)

            b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=scale_imgs_fn)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_96, fn=scale_imgs_fn)          
            ## update G
            summary,errM, out, _ = sess.run([merged,mse_loss,net_g.outputs, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            train_writer.add_summary(summary,i)
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
                     
            total_mse_loss += errM
            n_iter += 1
    
            ## quick evaluation on train set
            if (i != 0) and (i % 20 == 0):
                out,val_,val_summ = sess.run([net_g_test.outputs,val_mset,val_summary], {t_sample_image: sample_imgs_96})#; print('gen sub-image:', out.shape, out.min(), out.max())

                train_writer.add_summary(val_summ,i)
                print("validate")
    
            # save model
            if (i != 0) and (i % 100 == 0):
                tl.files.save_npz(net_g.all_params, name=checkpoint_dir+'/g_{}_init.npz'.format('SRFeat'), sess=sess)
            i= i+1


    train_writer.close()


  4. model.py
    （原文此处有插图：网络结构示意图，抓取时未保留）
        for i in range(16):
            nn = Conv2d(n, 128, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c1/%s' % i)
            
            nn = BatchNormLayer(nn, act=lrelu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            
            nn = Conv2d(nn, 128, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c2/%s' % i)
            
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            
            nn = ElementwiseLayer([n, nn], tf.add, 'b_residual_add/%s' % i)
            n = nn       
            
            t = Conv2d(nn, 128, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c3/%s' % i)
            temp.append(t)
            

        n = ElementwiseLayer([n, temp[0],temp[1],temp[2],temp[3],
                              temp[4],temp[5],temp[6],temp[7], 
                              temp[8],temp[9],temp[10],temp[11], 
                              temp[12],temp[13],temp[14]], tf.add, 'add3')
        # B residual blacks end

（原文此处有插图，抓取时未保留）

def SRGAN_g(t_image, is_train=False, reuse=False):
    """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    feature maps (n) and stride (s) feature maps (n) and stride (s)
    """
    #初始化
    w_init = tf.random_normal_initializer(stddev=0.01)
    b_init = None
    g_init = tf.ones_initializer()
    lrelu = lambda x : tl.act.lrelu(x, 0.2)
    
    with tf.variable_scope("SRGAN_g", reuse=reuse) as vss:
        tl.layers.set_name_reuse(reuse)
        n = InputLayer(t_image, name='in')
        #蓝色的
        n = Conv2d(n, 128, (9, 9), (1, 1), act=None, padding='SAME', W_init=w_init, name='n64s1/c')
        temp = []

        # B residual blocks
        for i in range(16):
            nn = Conv2d(n, 128, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c1/%s' % i)
            
            nn = BatchNormLayer(nn, act=lrelu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            
            nn = Conv2d(nn, 128, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c2/%s' % i)
            
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            
            nn = ElementwiseLayer([n, nn], tf.add, 'b_residual_add/%s' % i)
            n = nn       
            #每一块后面都加了一个卷积层
            t = Conv2d(nn, 128, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c3/%s' % i)
            temp.append(t)
            

        n = ElementwiseLayer([n, temp[0],temp[1],temp[2],temp[3],
                              temp[4],temp[5],temp[6],temp[7], 
                              temp[8],temp[9],temp[10],temp[11], 
                              temp[12],temp[13],temp[14]], tf.add, 'add3')
        # B residual blacks end
        #最后两个小框
        #第一次上采样
        n = Conv2d(n, 512, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
        n = SubpixelConv2d(n, scale=2, act=lrelu, n_out_channel=None, name='pixelshufflerx2/1')
        
        #第二次上采样
        n = Conv2d(n, 512, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init ,name='n256s1/2')
        n = SubpixelConv2d(n, scale=2, act=lrelu, n_out_channel=None,  name='pixelshufflerx2/2')
        
        #最后的蓝色的
        n = Conv2d(n, 3, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='out')
        return n

（原文此处有插图，抓取时未保留）

def SRGAN_d(t_image, is_train=True, reuse=False):
    """ Discriminator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    feature maps (n) and stride (s) feature maps (n) and stride (s)
    """
#    w_init = tf.truncated_normal_initializer(stddev=0.01)
    w_init = tf.contrib.layers.variance_scaling_initializer()

    b_init = tf.constant_initializer(value=0.0)
#    g_init = tf.random_normal_initializer(1., 0.02)
    g_init = tf.ones_initializer()

    lrelu = lambda x : tl.act.lrelu(x, 0.2)
    with tf.variable_scope("SRGAN_d", reuse=reuse) as vs:
        tl.layers.set_name_reuse(reuse)
        n = InputLayer(t_image, name='in')
        #开始第一块
        n = Conv2d(n, 64, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c')

        n = Conv2d(n, 64, (3, 3), (2, 2),  padding='SAME', W_init=w_init, b_init=b_init, name='n64s2/c')
        n = BatchNormLayer(n,  act=lrelu, is_train=is_train,  gamma_init=g_init, name='n64s2/b')

        n = Conv2d(n, 128, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n128s1/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n128s1/b')

        n = Conv2d(n, 128, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n128s2/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n128s2/b')

        n = Conv2d(n, 256, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n256s1/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n256s1/b')

        n = Conv2d(n, 256, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n256s2/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n256s2/b')

        n = Conv2d(n, 512, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n512s1/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n512s1/b')

        n = Conv2d(n, 512, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n512s2/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n512s2/b')
        #倒数Dense模块
        n = FlattenLayer(n, name='f')
        n = DenseLayer(n, n_units=1024,act=lrelu, name='d1024')
        n = DenseLayer(n, n_units=1,name='out')

        logits = n.outputs
        n.outputs = tf.nn.sigmoid(n.outputs)

        return n, logits

剩下的部分和 VGG 有关，论文中没有给出具体的网络结构。

def SRGAN_d2(input_images, is_train=True, reuse=False):
    w_init = tf.contrib.layers.variance_scaling_initializer()
    
    b_init = None 
    gamma_init = tf.ones_initializer()
    
    df_dim = 64
    lrelu = lambda x: tl.act.lrelu(x, 0.2)
    with tf.variable_scope("SRGAN_d", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        net_in = InputLayer(input_images, name='input/images')
        net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lrelu,
                padding='SAME', W_init=w_init, name='h0/c')

        net_h1 = Conv2d(net_h0, df_dim*2, (4, 4), (2, 2), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='h1/c')
        net_h1 = BatchNormLayer(net_h1, act=lrelu, is_train=is_train,
                gamma_init=gamma_init, name='h1/bn')
        net_h2 = Conv2d(net_h1, df_dim*4, (4, 4), (2, 2), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='h2/c')
        net_h2 = BatchNormLayer(net_h2, act=lrelu, is_train=is_train,
                gamma_init=gamma_init, name='h2/bn')
        net_h3 = Conv2d(net_h2, df_dim*8, (4, 4), (2, 2), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='h3/c')
        net_h3 = BatchNormLayer(net_h3, act=lrelu, is_train=is_train,
                gamma_init=gamma_init, name='h3/bn')
        net_h4 = Conv2d(net_h3, df_dim*16, (4, 4), (2, 2), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='h4/c')
        net_h4 = BatchNormLayer(net_h4, act=lrelu, is_train=is_train,
                gamma_init=gamma_init, name='h4/bn')
        net_h5 = Conv2d(net_h4, df_dim*32, (4, 4), (2, 2), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='h5/c')
        net_h5 = BatchNormLayer(net_h5, act=lrelu, is_train=is_train,
                gamma_init=gamma_init, name='h5/bn')
        net_h6 = Conv2d(net_h5, df_dim*16, (1, 1), (1, 1), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='h6/c')
        net_h6 = BatchNormLayer(net_h6, act=lrelu, is_train=is_train,
                gamma_init=gamma_init, name='h6/bn')
        net_h7 = Conv2d(net_h6, df_dim*8, (1, 1), (1, 1), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='h7/c')
        net_h7 = BatchNormLayer(net_h7, is_train=is_train,
                gamma_init=gamma_init, name='h7/bn')

        net = Conv2d(net_h7, df_dim*2, (1, 1), (1, 1), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='res/c')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train,
                gamma_init=gamma_init, name='res/bn')
        net = Conv2d(net, df_dim*2, (3, 3), (1, 1), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='res/c2')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train,
                gamma_init=gamma_init, name='res/bn2')
        net = Conv2d(net, df_dim*8, (3, 3), (1, 1), act=None,
                padding='SAME', W_init=w_init, b_init=b_init, name='res/c3')
        net = BatchNormLayer(net, is_train=is_train,
                gamma_init=gamma_init, name='res/bn3')
        net_h8 = ElementwiseLayer(layer=[net_h7, net],
                combine_fn=tf.add, name='res/add')
        net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2)

        net_ho = FlattenLayer(net_h8, name='ho/flatten')
        net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity,
                W_init = w_init, name='ho/dense')
        logits = net_ho.outputs
        net_ho.outputs = tf.nn.sigmoid(net_ho.outputs)

    return net_ho, logits


def SRGAN_vgg_d2(t_image, is_train=True, reuse=False):
#    w_init = tf.truncated_normal_initializer(stddev=0.01)
    w_init = tf.contrib.layers.variance_scaling_initializer()

    b_init = tf.constant_initializer(value=0.0)
#    g_init = tf.random_normal_initializer(1., 0.02)
    g_init = tf.ones_initializer()

    feat_dim = 64
    lrelu = lambda x : tl.act.lrelu(x, 0.2)
    with tf.variable_scope("SRGAN_vgg_d", reuse=reuse) as vs:
        tl.layers.set_name_reuse(reuse)
        n = InputLayer(t_image, name='in')
        n = Conv2d(n, feat_dim, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c')

        n = Conv2d(n, feat_dim, (3, 3), (2, 2),  padding='SAME', W_init=w_init, b_init=b_init, name='n64s2/c')
        n = BatchNormLayer(n,  act=lrelu, is_train=is_train,  gamma_init=g_init, name='n64s2/b')

        n = Conv2d(n, feat_dim*2, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n128s1/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n128s1/b')

        n = Conv2d(n, feat_dim*2, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n128s2/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n128s2/b')

        n = Conv2d(n, feat_dim*4, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n256s1/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n256s1/b')

        n = Conv2d(n, feat_dim*4, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n256s2/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n256s2/b')

        n = Conv2d(n, feat_dim*8, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n512s1/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n512s1/b')

        n = Conv2d(n, feat_dim*8, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n512s2/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n512s2/b')

#         network = Conv2d(network, n_filter=1024, filter_size=(3, 3),
#                          strides=(1,1), act=tf.nn.relu, padding='SAME',  W_init=w_init, b_init=b_init,name='conv6')
#         network = Conv2d(network, n_filter=1024, filter_size=(3, 3),
#                          strides=(1,1), act=tf.nn.relu, padding='SAME',  W_init=w_init, b_init=b_init,name='conv7')             
#         network = Conv2d(network, n_filter=1, filter_size=(3, 3),
#                          strides=(1,1), act=None, padding='SAME',  W_init=w_init, b_init=b_init, name='out')
#         logits = network.outputs
#         network.outputs = tf.nn.sigmoid(network.outputs)
        
        n = Conv2d(n, 1024, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='d1024')
        n = Conv2d(n, 1, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='out')
#        n = FlattenLayer(n, name='f')
#        n = DenseLayer(n, n_units=1024,act=lrelu, name='d1024')
#        n = DenseLayer(n, n_units=1,name='out')

        logits = n.outputs
        n.outputs = tf.nn.sigmoid(n.outputs)

        return n, logits
def SRGAN_vgg_d(t_image, is_train=True, reuse=False):
#    w_init = tf.truncated_normal_initializer(stddev=0.01)
    w_init = tf.contrib.layers.variance_scaling_initializer()

    b_init = tf.constant_initializer(value=0.0)
#    g_init = tf.random_normal_initializer(1., 0.02)
    g_init = tf.ones_initializer()

    feat_dim = 64
    lrelu = lambda x : tl.act.lrelu(x, 0.2)
    with tf.variable_scope("SRGAN_vgg_d", reuse=reuse) as vs:
        tl.layers.set_name_reuse(reuse)
        n = InputLayer(t_image, name='in')
        n = Conv2d(n, feat_dim, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init,b_init=b_init, name='n64s1/c')

        n = Conv2d(n, feat_dim, (3, 3), (2, 2),  padding='SAME', W_init=w_init, b_init=b_init, name='n64s2/c')
        n = BatchNormLayer(n,  act=lrelu, is_train=is_train,  gamma_init=g_init, name='n64s2/b')

        n = Conv2d(n, feat_dim*2, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n128s1/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n128s1/b')

        n = Conv2d(n, feat_dim*2, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n128s2/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n128s2/b')

        n = Conv2d(n, feat_dim*4, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n256s1/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n256s1/b')

        n = Conv2d(n, feat_dim*4, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n256s2/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train, gamma_init=g_init, name='n256s2/b')

        n = Conv2d(n, feat_dim*8, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='n512s1/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n512s1/b')

        n = Conv2d(n, feat_dim*8, (3, 3), (2, 2), padding='SAME', W_init=w_init, b_init=b_init, name='n512s2/c')
        n = BatchNormLayer(n, act=lrelu, is_train=is_train,  gamma_init=g_init, name='n512s2/b')

        n = FlattenLayer(n, name='f')
        n = DenseLayer(n, n_units=1024,act=lrelu, name='d1024')
        n = DenseLayer(n, n_units=1,name='out')

        logits = n.outputs
        n.outputs = tf.nn.sigmoid(n.outputs)

        return n, logits


def Vgg19_simple_api(rgb, reuse):
    """
    Build the VGG 19 Model

    Parameters
    -----------
    rgb : rgb image placeholder [batch, height, width, 3] values scaled [0, 1]输入的RGB图像是归一化的
    """
    VGG_MEAN = [103.939, 116.779, 123.68]
    with tf.variable_scope("VGG19", reuse=reuse) as vs:
        start_time = time.time()
        print("build model started")
        rgb = tf.maximum(0.0,tf.minimum(rgb,1.0))        
        rgb_scaled = rgb * 255.0
        # Convert RGB to BGR
        if tf.__version__ <= '0.11':
            red, green, blue = tf.split(3, 3, rgb_scaled)
        else: # TF 1.0
            # print(rgb_scaled)
            
            red, green, blue = tf.split(rgb_scaled, 3, 3)
#        assert red.get_shape().as_list()[1:] == [224, 224, 1]
#        assert green.get_shape().as_list()[1:] == [224, 224, 1]
#        assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        if tf.__version__ <= '0.11':
            bgr = tf.concat(3, [
                blue - VGG_MEAN[0],
                green - VGG_MEAN[1],
                red - VGG_MEAN[2],
            ])
        else:
            bgr = tf.concat([
                blue - VGG_MEAN[0],
                green - VGG_MEAN[1],
                red - VGG_MEAN[2],
            ], axis=3)
#        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        """ input layer """
        net_in = InputLayer(bgr, name='input')
        """ conv1 """
        network = Conv2d(net_in, n_filter=64, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv1_1')
        network = Conv2d(network, n_filter=64, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv1_2')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
                    padding='SAME', name='pool1')
        """ conv2 """
        network = Conv2d(network, n_filter=128, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv2_1')
        network = Conv2d(network, n_filter=128, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv2_2')
        conv2 = network
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
                    padding='SAME', name='pool2')
        """ conv3 """
        network = Conv2d(network, n_filter=256, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_1')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_2')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_3')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
                    padding='SAME', name='pool3')
        conv3 = network
        """ conv4 """
        network = Conv2d(network, n_filter=512, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_1')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_2')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_3')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_4')
        conv4 = network
        
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
                    padding='SAME', name='pool4')                               # (batch_size, 14, 14, 512)
        """ conv5 """
        network = Conv2d(network, n_filter=512, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_1')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_2')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_3')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3),
                    strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_4')
        conv5 = network
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
                    padding='SAME', name='pool5')                               # (batch_size, 7, 7, 512)
        """ fc 6~8 """
#        network = FlattenLayer(network, name='flatten')
#        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc6')
#        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc7')
#        network = DenseLayer(network, n_units=1000, act=tf.identity, name='fc8')
        print("build model finished: %fs" % (time.time() - start_time))
        return network, conv5
  5. main_gan_train.py：与前面 init 版本的主要区别在于 loss 函数不同

def train():
    """Run the adversarial (GAN) training stage of SRFeat.

    Builds the generator, the image-domain discriminator, the
    VGG-feature-domain discriminator and the VGG-19 feature extractor,
    restores any existing checkpoints, then alternates discriminator /
    generator updates over the cropped training patches.  Sample images,
    TensorBoard summaries and ``.npz`` checkpoints are written
    periodically.

    Relies on module-level globals defined elsewhere in this file:
    ``batch_size``, ``patch_size_l``, ``patch_size_h``, ``lr_init``,
    ``lr_decay``, ``beta1``, ``decay_every``, ``n_epoch``, ``ni``, the
    builders ``SRGAN_g`` / ``SRGAN_d`` / ``SRGAN_vgg_d`` /
    ``Vgg19_simple_api``, the helpers ``read_all_imgs`` / ``get_imgs_fn``
    / ``scale_imgs_fn`` and the global ``config``.
    """
    ## Create folders to save result images and trained models.
    save_dir_gan = config.TRAIN.save_valid_results
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = config.TRAIN.checkpoint
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== LOAD DATA ===========================###
    # Both lists are sorted so HR and LR entries pair up by filename.
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))

    ###========================== DEFINE MODEL ============================###
    ## Placeholders for one training batch (LR input, HR target).
    t_image = tf.placeholder('float32', [batch_size, patch_size_l, patch_size_l, 3], name='t_image_input_to_generator')
    t_target_image = tf.placeholder('float32', [batch_size, patch_size_h, patch_size_h, 3], name='t_target_image')

    # Generator.
    # NOTE(review): is_train=False here mirrors the original release;
    # confirm whether batch-norm layers should instead run in training mode.
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)
    # Image-domain discriminator (real vs. generated HR patches).
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _,     logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)
    # VGG-19 feature extractor; inputs are rescaled from [-1, 1] to [0, 1].
    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((net_g.outputs + 1) / 2, reuse=True)
    # Feature-domain discriminator on (rescaled) VGG conv5 feature maps.
    vgg_scale5 = 1 / 12.75
    net_vgg_d, logits_vgg_real = SRGAN_vgg_d(vgg_scale5 * vgg_target_emb.outputs, is_train=True, reuse=False)
    _,         logits_vgg_fake = SRGAN_vgg_d(vgg_scale5 * vgg_predict_emb.outputs, is_train=True, reuse=True)

#########
    ## Generator reused (shared weights) for quick validation snapshots.
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator losses: cross-entropy of real logits against 1 and
    # fake logits against 0, for both discriminators.
    d_loss1 = 1e-3 * (tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1'))
    d_loss2 = 1e-3 * (tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2'))

    d_vgg_loss1 = 1e-3 * (tl.cost.sigmoid_cross_entropy(logits_vgg_real, tf.ones_like(logits_vgg_real), name='d_vgg_1'))
    d_vgg_loss2 = 1e-3 * (tl.cost.sigmoid_cross_entropy(logits_vgg_fake, tf.zeros_like(logits_vgg_fake), name='d_vgg_2'))

    d_loss = d_loss1 + d_loss2 + d_vgg_loss1 + d_vgg_loss2

    ## Scalar summaries for TensorBoard plotting.
    d_loss1_summary = tf.summary.scalar('d_loss1', d_loss1)
    d_loss2_summary = tf.summary.scalar('d_loss2', d_loss2)
    d_vgg_loss1_summary = tf.summary.scalar('d_vgg_loss1', d_vgg_loss1)
    d_vgg_loss2_summary = tf.summary.scalar('d_vgg_loss2', d_vgg_loss2)

    merged_d = tf.summary.merge([d_loss1_summary, d_loss2_summary, d_vgg_loss1_summary, d_vgg_loss2_summary])

    ##
    # Generator adversarial losses (fake logits pushed toward 1).
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    # FIX: the original reused the op name 'g' for this second loss.
    g_gan_vgg_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_vgg_fake, tf.ones_like(logits_vgg_fake), name='g_vgg')

    # Perceptual (VGG feature) loss between generated and target patches.
    vgg_loss = tl.cost.mean_squared_error(vgg_scale5 * vgg_predict_emb.outputs, vgg_scale5 * vgg_target_emb.outputs, is_mean=True)  # weight..? feature map rescale?

    ## Generator-side summaries for TensorBoard plotting.
    vgg_summary = tf.summary.scalar('vgg_loss', vgg_loss)
    g_gan_summary = tf.summary.scalar('g_gan_loss', g_gan_loss)
    g_gan_vgg_summary = tf.summary.scalar('g_gan_vgg_loss', g_gan_vgg_loss)

    merged_g = tf.summary.merge([vgg_summary, g_gan_summary, g_gan_vgg_summary])

    ##
    # Total generator loss.
    g_loss = vgg_loss + g_gan_loss + g_gan_vgg_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)
    d_vgg_vars = tl.layers.get_variables_with_name('SRGAN_vgg_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    ## Adam optimizers for generator and (both) discriminators.
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1, epsilon=1e-10).minimize(g_loss, var_list=g_vars)
    # FIX: minimize() expects a flat list of Variables; the original passed
    # the nested list [d_vars, d_vgg_vars], which TensorFlow rejects.
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1, epsilon=1e-10).minimize(d_loss, var_list=d_vars + d_vgg_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    summary_writer = tf.summary.FileWriter(config.TRAIN.summary_adv, sess.graph)
    summary_writer.add_graph(sess.graph)
    tl.layers.initialize_global_variables(sess)
    # Prefer a GAN-stage checkpoint; fall back to the init (pre-trained) one.
    if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format('SRFeat'), network=net_g) is False:
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format('SRFeat'), network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format('SRFeat'), network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        # FIX: error message previously named the wrong file (vgg19.npz).
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        exit()
    # The file stores a pickled dict; allow_pickle is mandatory on NumPy >= 1.16.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
#
    # Keep only the 16 convolutional layers (fc6-fc8 are commented out in
    # the VGG builder above, so their weights must not be assigned).
    params = []
    count_layers = 0
    for val in sorted(npz.items()):
        if count_layers < 16:
            W = np.asarray(val[1][0])
            b = np.asarray(val[1][1])
            print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
            params.extend([W, b])
        count_layers += 1

    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## Use a fixed subset of the train set for quick visual checks.
    sample_list = [0, 10001, 20001, 30001, 40001, 50001, 60001, 70001, 80001]
    sample_imgs = read_all_imgs([train_hr_img_list[k] for k in sample_list], path=config.TRAIN.hr_img_path, n_threads=batch_size)  # if no pre-load train set
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=scale_imgs_fn)

    sample_imgs_96 = read_all_imgs([train_lr_img_list[k] for k in sample_list], path=config.TRAIN.lr_img_path, n_threads=batch_size)
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_96, fn=scale_imgs_fn)

    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

    ###========================= TRAIN GAN (SRGAN) =========================###
    iters = 0
    # Shuffle HR/LR lists together so pairs stay aligned.
    lr_hr_list = list(zip(train_hr_img_list, train_lr_img_list))
    random.shuffle(lr_hr_list)
    train_hr_img_list, train_lr_img_list = zip(*lr_hr_list)

    for epoch in range(0, n_epoch + 1):
        ## Step-wise learning-rate decay.
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay ** (epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        for idx in range(0, len(train_hr_img_list) - batch_size, batch_size):
            iters = iters + 1

            b_imgs_list = train_hr_img_list[idx : idx + batch_size]
            b_imgs_list_lr = train_lr_img_list[idx : idx + batch_size]
            b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_list_lr, fn=get_imgs_fn, path=config.TRAIN.lr_img_path)
            b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=scale_imgs_fn)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_96, fn=scale_imgs_fn)

            step_time = time.time()
            ## update D  (FIX: this line was mis-indented by one space in the
            ## original, making the whole file fail to parse)
            summary_d, errD, errD1, errD2, errD3, errD4, _ = sess.run([merged_d, d_loss, d_loss1, d_loss2, d_vgg_loss1, d_vgg_loss2, d_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})

            ## update G
            summary_g, errG, errV, errA, errA2, _ = sess.run([merged_g, g_loss, vgg_loss, g_gan_loss, g_gan_vgg_loss, g_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})

            ## TensorBoard summaries.
            summary_writer.add_summary(summary_d, iters)
            summary_writer.add_summary(summary_g, iters)
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f (d1: %.8f, d2: %.8f, d3_vgg: %.8f, d4_vgg: %.8f), g_loss: %.8f (vgg: %.6f adv: %.6f adv2: %.6f)" % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errD1, errD2, errD3, errD4,  errG, errV, errA, errA2))

            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

            ## Quick evaluation on the fixed sample batch.
            if (iters != 0) and (iters % 100 == 0):
                out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  # ; print('gen sub-image:', out.shape, out.min(), out.max())
                print("[*] save images")
                tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % iters)

            ## Periodic checkpointing (generator often, discriminator rarely).
            if (iters != 0) and (iters % 1000 == 0):
                tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}.npz'.format('SRFeat'), sess=sess)
            if (iters != 0) and (iters % 10000 == 0):
                tl.files.save_npz(net_d.all_params, name=checkpoint_dir + '/d_{}.npz'.format('SRFeat'), sess=sess)

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter, total_g_loss / n_iter)
        print(log)

    summary_writer.close()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值