[Further Reading] Bringing Old Photos Back to Life (Part 5): Annotated Source Code of the Pix2PixHD Model

This article builds on the Pix2PixHD project from NVIDIA and UC Berkeley, which tackles high-resolution image synthesis and semantic manipulation. Centered on the Pix2PixHD model, it walks through three key source files (models.py, pix2pixHD_model.py, and networks.py), covering model construction, loss functions, and the neural network components. Understanding this code helps clarify how "Bringing Old Photos Back to Life" is implemented.


The project "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs", published by NVIDIA and UC Berkeley in 2018, is the technical foundation of the "Bringing Old Photos Back to Life" project and can be regarded as its open-source predecessor.

The technical core of that project is the Pix2PixHD model. Here we share its source code with annotations, which helps deepen the understanding of the "Bringing Old Photos Back to Life" project (especially since the model and training code of the old-photo project has not been open-sourced).

The GitHub repository of "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs" is: https://github.com/NVIDIA/pix2pixHD

The Pix2PixHD model is built with PyTorch, and the code is clean and well organized. The relevant source code consists mainly of three files: ./models/models.py, ./models/pix2pixHD_model.py, and ./models/networks.py.

They are described below:

(1)./models/models.py

Provides create_model(), which instantiates Pix2PixHDModel() (or InferenceModel / UIModel) and returns the model.

import torch

# create the model and return it
def create_model(opt):
    if opt.model == 'pix2pixHD':  # select the pix2pixHD model
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        if opt.isTrain:  # True when training
            model = Pix2PixHDModel()
        else:  # False when only running forward passes for inference/demo
            model = InferenceModel()
    else:  # otherwise select the UIModel
        from .ui_model import UIModel
        model = UIModel()
    model.initialize(opt)  # initialize the model from the parsed options
    if opt.verbose:  # verbose defaults to False; when True, print extra information
        print("model [%s] was created" % (model.name()))  # e.g. reports that the label2city model was created

    if opt.isTrain and len(opt.gpu_ids) and not opt.fp16:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)  # multi-GPU training

    return model
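
For reference, here is a minimal sketch of how this factory is used; it mirrors the flow of the upstream train.py (TrainOptions parses command-line options such as --model and --gpu_ids):

from options.train_options import TrainOptions
from models.models import create_model

opt = TrainOptions().parse()  # e.g. --model pix2pixHD --gpu_ids 0
model = create_model(opt)     # Pix2PixHDModel, possibly wrapped in DataParallel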

(2)./models/pix2pixHD_model.py 

This file contains the core of the model construction:

defines the generator, discriminator, and encoder (which produces low-dimensional per-instance features) of the conditional GAN (Pix2PixHDModel);

defines the loss functions (GAN loss, VGG loss, and the feature-matching loss; see the sketch right after this list);

defines the optimizers for the generator and the discriminator;

defines the inputs of each module;

defines the forward function.
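
As a concrete example, the feature-matching loss (Eq. (4) in the paper) is an L1 distance accumulated over the intermediate feature maps of every discriminator scale. A simplified sketch, with names following the upstream forward() code (pred_fake and pred_real are lists of per-scale feature lists, where the last entry of each inner list is the final prediction):

loss_G_GAN_Feat = 0
if not self.opt.no_ganFeat_loss:
    feat_weights = 4.0 / (self.opt.n_layers_D + 1)  # weight per feature level
    D_weights = 1.0 / self.opt.num_D                # weight per discriminator scale
    for i in range(self.opt.num_D):                 # over the multi-scale discriminators
        for j in range(len(pred_fake[i]) - 1):      # over the intermediate layers
            loss_G_GAN_Feat += D_weights * feat_weights * \
                self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

The annotated source follows: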

import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks

class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'

    # Loss filter: the g_gan, d_real and d_fake losses are always returned
    # g_gan_feat is the paper's feature-matching loss (Eq. (4))
    # g_vgg is the paper's VGG perceptual loss, which slightly improves the results
    # whether g_gan_feat and g_vgg are returned depends on the train options (not opt.no_ganFeat_loss, not opt.no_vgg_loss); both are returned by default
    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
        return loss_filter
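    # Example (sketch): with both optional losses enabled, loss_filter(a, b, c, d, e)
    # returns [a, b, c, d, e]; with use_vgg_loss=False it returns [a, b, d, e].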
    
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks        
        # Generator network
        netG_input_nc = input_nc        
        if not opt.no_instance:
            netG_input_nc += 1  # add an instance edge channel (to distinguish individual instances)
        if self.use_features:
            netG_input_nc += opt.feat_num  # add feature-map channels (produced by the encoder)
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        
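        # define_G builds the paper's coarse-to-fine generator: a 'global' network
        # plus n_local_enhancers local enhancer networks (see networks.py)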

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc  # conditioned input: label channels + image channels
            if not opt.no_instance:
                netD_input_nc += 1  # add an instance edge channel (to distinguish individual instances)
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
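        # define_D builds num_D PatchGAN discriminators operating on different image
        # scales (MultiscaleDiscriminator in networks.py); the getIntermFeat flag
        # (= not opt.no_ganFeat_loss) also exposes intermediate features for the
        # feature-matching loss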

        ### Encoder network
        # (the encoder is one of the network types built by define_G())
        if self.gen_features:          
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)  
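            # per the paper, the encoder's instance-wise pooled features let the
            # model sample different plausible appearances for each object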
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load pretrained networks (models)
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)  
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)              

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)  # initialize fake_pool: num_imgs = 0, images = []
            self.old_lr = opt.lr

            # define the loss functions used in .forward()
            # (ganFeat_loss and vgg_loss are enabled by default)
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)   
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:             
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
                
            # Names so we can break out each loss individually
            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')

            # initialize optimizers
            # optimizer G (also covers the encoder's parameters)
            if opt.niter_fix_global > 0:                
                import sys
                if sys.version_info >= (3,0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():       
                    if key.startswith('model' + str(opt.n_local_enhancers)):                    
                        params += [value]
                        finetune_list.add(key.split('.')[0])  
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))                         
            else:
                params = list(self.netG.parameters())
            if self.gen_features:              
                params += list(self.netE.parameters())         
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

            # optimizer D                        
            params = list(self.netD.parameters())    
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    # feat = feature, inst = instance
    # each pixel of label_map holds that pixel's object class; each pixel of inst_map holds a unique ID for each individual object
    # extract the boundaries (edges) of the instance map, concatenate edge_map with the one-hot encoding of label_map, wrap as a Variable, and assign to input_label
    # real_image and feat_map are wrapped as Variables; under opt.label_feat, label_map is assigned to inst_map
    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())
            if self.opt.label_feat:
                inst_map = label_map.cuda()

        return input_label, inst_map, real_image, feat_map
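
To see how these pieces fit together during training, here is a minimal sketch of one iteration; it follows the structure of the upstream train.py (model is the object returned by create_model(), data is one batch from the data loader, model.module assumes the DataParallel wrapper, and the loss names match self.loss_names above):

losses, generated = model(Variable(data['label']), Variable(data['inst']),
                          Variable(data['image']), Variable(data['feat']), infer=False)
losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]  # average across GPUs
loss_dict = dict(zip(model.module.loss_names, losses))

# combine the filtered losses (the .get() calls cover the case where a loss is disabled)
loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

model.module.optimizer_G.zero_grad()
loss_G.backward()
model.module.optimizer_G.step()

model.module.optimizer_D.zero_grad()
loss_D.backward()
model.module.optimizer_D.step()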