The project "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs", published by NVIDIA and UC Berkeley in 2018, is the technical foundation of the "Bringing Old Photos Back to Life" project; it can be regarded as an upstream open-source project.
The technical core of that project is the Pix2PixHD model. Here we share the relevant source code with annotations; working through it deepens the understanding of the "Bringing Old Photos Back to Life" project (especially since the model and training source code of the old-photo project has not yet been open-sourced).
The GitHub repository of "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs" is: https://github.com/NVIDIA/pix2pixHD
The Pix2PixHD model is built with PyTorch, and the code is clean and well organized. The relevant source code lives mainly in three files: ./models/models.py, ./models/pix2pixHD_model.py and ./models/networks.py.
They are described in turn below:
(1) ./models/models.py
Calls Pix2PixHDModel() to create the model.
import torch

# Create the model specified by the options and return it
def create_model(opt):
    if opt.model == 'pix2pixHD':  # select the pix2pixHD model
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        if opt.isTrain:           # True when training
            model = Pix2PixHDModel()
        else:                     # False when only running the forward pass for inference/demos
            model = InferenceModel()
    else:                         # otherwise select the UI model
        from .ui_model import UIModel
        model = UIModel()
    model.initialize(opt)         # initialize the model from the options
    if opt.verbose:               # verbose output; defaults to False
        print("model [%s] was created" % (model.name()))  # report which model was created
    if opt.isTrain and len(opt.gpu_ids) and not opt.fp16:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)  # multi-GPU training
    return model
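For context, here is a hedged sketch of how this factory function is typically driven from the repository's train.py (TrainOptions comes from the repository's options/train_options.py; the exact call site may differ slightly from this illustration):

from options.train_options import TrainOptions
from models.models import create_model

opt = TrainOptions().parse()  # parses --model, --gpu_ids, --fp16, ... and sets opt.isTrain = True
model = create_model(opt)     # Pix2PixHDModel, wrapped in DataParallel when several GPUs are given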
(2) ./models/pix2pixHD_model.py
This file builds the core of the model:
defines the generator, discriminator and encoder (the encoder produces low-dimensional per-instance features) of the conditional GAN (Pix2PixHDModel);
defines the loss functions (the GAN loss, the VGG loss and the feature-matching loss);
defines the optimizers for the generator and the discriminator;
defines the inputs of each module;
defines the forward function.
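For reference, the feature-matching loss listed above is equation (4) of the paper. For each of the multi-scale discriminators D_k it compares intermediate discriminator features extracted from the real pair and from the synthesized pair:

\mathcal{L}_{FM}(G, D_k) = \mathbb{E}_{(\mathbf{s},\mathbf{x})} \sum_{i=1}^{T} \frac{1}{N_i} \Big[ \big\| D_k^{(i)}(\mathbf{s}, \mathbf{x}) - D_k^{(i)}(\mathbf{s}, G(\mathbf{s})) \big\|_1 \Big]

where \mathbf{s} is the semantic label map, \mathbf{x} the corresponding real image, D_k^{(i)} the i-th layer of discriminator D_k, T the number of layers, and N_i the number of elements in layer i.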
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'

    # Loss filter: the g_gan, d_real and d_fake losses are always returned.
    # g_gan_feat is the feature-matching loss from the paper (equation (4));
    # g_vgg is the paper's VGG perceptual loss, which slightly improves the output.
    # Whether g_gan_feat and g_vgg are returned is controlled by the training options
    # opt.no_ganFeat_loss and opt.no_vgg_loss (by default both losses are returned).
    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real, d_fake), flags) if f]
        return loss_filter
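    # Toy illustration (not part of the original file): with use_gan_feat_loss=False,
    # the returned loss_filter drops the g_gan_feat entry:
    #     flags  = (True, False, True, True, True)
    #     losses = ('g_gan', 'g_gan_feat', 'g_vgg', 'd_real', 'd_fake')
    #     [l for (l, f) in zip(losses, flags) if f]
    #     # -> ['g_gan', 'g_vgg', 'd_real', 'd_fake']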
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks
        ### Generator network
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1             # add the instance-edge channel (separates individual objects)
        if self.use_features:
            netG_input_nc += opt.feat_num  # add the feature-map channels (produced by the encoder)
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
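        # Channel arithmetic (added; illustrative, assuming Cityscapes-style defaults):
        # label_nc=35 one-hot label channels + 1 instance-edge channel
        # + feat_num=3 encoded feature channels -> netG_input_nc = 39.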
        ### Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc  # the discriminator sees the label map concatenated with a (real or fake) image
            if not opt.no_instance:
                netD_input_nc += 1  # add the instance-edge channel (separates individual objects)
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
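            # Note (added): use_sigmoid is True only for the vanilla GAN loss (--no_lsgan);
            # with the default LSGAN objective the discriminator outputs raw scores and
            # networks.GANLoss regresses them to 1.0 (real) / 0.0 (fake) with an MSE loss.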
        ### Encoder network
        # (built through define_G() with the 'encoder' architecture; it produces the per-instance features)
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks (saved models)
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)  # initialize fake_pool: num_imgs = 0, images = []
            self.old_lr = opt.lr

            # define loss functions (used in .forward());
            # ganFeat_loss and vgg_loss are enabled by default
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)

            # Names so we can break out each loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake')
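            # With the default options both optional losses are enabled, so
            # self.loss_names is ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake'].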
            # initialize optimizers
            # optimizer G (also covers the encoder's parameters)
            if opt.niter_fix_global > 0:  # first train only the local enhancer for this many epochs
                import sys
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set  # Python 2 fallback
                    finetune_list = Set()
                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
                if self.gen_features:
                    params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
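            # Usage note (added; a hedged sketch of how the repository's train.py
            # drives these optimizers, with loss_dict built from self.loss_names):
            #     loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            #     loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)
            #     optimizer_G.zero_grad(); loss_G.backward(); optimizer_G.step()
            #     optimizer_D.zero_grad(); loss_D.backward(); optimizer_D.step()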
    # feat = feature, inst = instance
    # In the label_map each pixel value encodes the pixel's object class; in the inst_map
    # each pixel holds a unique object ID for every individual object instance.
    # Extract the boundaries (edges) of the instance map, concatenate the edge_map with the
    # one-hot encoding of the label_map, wrap the result in a Variable and assign it to input_label.
    # real_image and feat_map (when provided) are likewise wrapped in Variables.
    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):