目录
参考资料
论文:
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
博客:
CycleGAN:图片风格,想换就换 | ICCV 2017论文解读
视频:
代码:
第1章 CycleGAN的作用
CycleGAN的一个重要应用领域是 Domain Adaptation
(域迁移:可以通俗的理解为画风迁移),比如可以把一张普通的风景照变成梵高化作,或者将游戏画面变化成真实世界画面等等。以下是原论文中给出的一些应用:
第2章 CycleGAN的优势
其实在CycleGAN之前,就已经有了Domain Adaptation模型,比如Pix2Pix,不过 Pix2Pix
要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,而CycleGAN只需要两种域的数据,而不需要他们有严格对应关系,这使得CycleGAN的应用更为广泛。原论文中是这样解释的:
第3章 CycleGAN的网络结构
CycleGAN 可以让两个 domain 的图片互相转化。传统的 GAN 是单向生成,而 CycleGAN 是互相生成,网络是个环形,所以命名为 Cycle。并且 CycleGAN 一个非常实用的地方就是输入的两张图片可以是任意的两张图片,也就是 unpaired
。
3.1 单向GAN
CycleGAN 本质上是两个镜像对称的 GAN,构成了一个环形网络。其实只要理解了一半的单向 GAN 就等于理解了整个CycleGAN。
上图是一个单向 GAN 的示意图。我们希望能够把 domain A
的图片(命名为 A)转化为 domain B
的图片(命名为图片 B)。为了实现这个过程,我们需要两个生成器
G
A
B
G_{AB}
GAB 和
G
B
A
G_{BA}
GBA,分别把 domain A
和 domain B
的图片进行互相转换。
图片
A
A
A 经过生成器
G
A
B
G_{AB}
GAB 表示为 Fake Image in domain B
,用
G
A
B
(
A
)
G_{AB}(A)
GAB(A) 表示。而 经
G
A
B
(
A
)
G_{AB}(A)
GAB(A)过生成器
G
B
A
G_{BA}
GBA表示为图片
A
A
A 的重建图片,用
G
B
A
(
G
A
B
(
A
)
)
G_{BA}(G_{AB}(A))
GBA(GAB(A)) 表示。
最后为了训练这个单向 GAN 需要两个 loss,分别是 生成器的重建 loss
和 判别器的判别 loss
。
(1)判别 loss:判别器
D
B
D_B
DB 是用来判断输入的图片是否是真实的 domain B
图片,于是生成的假图片
G
A
B
(
A
)
G_{AB}(A)
GAB(A)和原始的真图片 B 都会输入到判别器里面,公式挺好理解的,就是一个 0,1 二分类的损失。最后的 loss 表示为:
(2)生成 loss:生成器用来重建图片 A,目的是希望生成的图片 G B A ( G A B ( A ) ) G_{BA}(G_{AB}(A)) GBA(GAB(A)) 和原图 A 尽可能的相似,那么可以很简单的采取 L 1 L o s s L_1\ Loss L1 Loss 或者 。最 L 2 L o s s L_2\ Loss L2 Loss 后生成 Loss 就表示为:
以上就是 A→B 单向 GAN 的原理。
3.2 CycleGAN
CycleGAN 其实就是一个 A→B 单向 GAN 加上一个 B→A 单向 GAN。两个 GAN 有两个生成器,然后各自带一个判别器,所以加起来总共有两个判别器和两个生成器。一个单向 GAN 有两个 loss,而 CycleGAN 加起来总共有四个 loss。CycleGAN 论文的原版原理图和公式如下,其实理解了单向 GAN 那么 CycleGAN 已经很好理解。
(1) X → Y X→Y X→Y 的判别器损失为:
(2) Y → X Y→X Y→X 的判别器损失为:
(3)两个生成器的 loss 加起来表示为:
(4)最终网络的所有损失加起来为:
(5)论文里面提到判别器如果是对数损失(BCE Loss)训练不是很稳定,所以改成的均方误差损失(MSE Loss),如下:
3.3 改进Loss
上面我们提到,我们希望生成器只进行风格的迁移而保证内容不变,具体而言:
- 风格迁移,内容不变: G 吃一张房子的照片,吐一张梵高风格的房子的照片;
- 风格迁移,内容改变: G 吃一张房子的照片,任意吐一张梵高风格的照片
仅靠上面的 Loss 能否保证风格迁移,内容不变呢?我认为不能!以下图为例:
正常情况下,我们希望 G ( x ) = a G(x)=a G(x)=a ,但是根据上面的 L o s s Loss Loss 会不会导致 G ( x ) = b , F ( b ) = x G(x)=b , F(b)=x G(x)=b,F(b)=x 的情况发生呢?答案是肯定的,对于 G ( x ) = b G(x)=b G(x)=b ,虽然产生的图片 b b b 并不是我们希望的,但是由于 b b b 的确是梵高画风,所以判别器会给它高分,这会鼓励生成器错误产生 b b b 的这个行为。
其次,在更新 F F F 参数的时候,由于 L o s s c y c l e Loss_{cycle} Losscycle 中 E x ∼ p d a t a ( x ) [ ‖ F ( G ( x ) ) − x ‖ 1 ] E_{x∼pdata} (x)[‖F(G(x))−x‖_1] Ex∼pdata(x)[‖F(G(x))−x‖1] 一项的存在,即使 G ( x ) G(x) G(x) 错误产生了 b b b , F F F 任然会努力把 G ( x ) G(x) G(x) 错误的结果“掰”回 x x x ,这就像 F F F 在“包庇” G G G 的错误。
同样,在更新 G G G 参数的时候,由于 E y ∼ p d a t a ( y ) [ ‖ G ( F ( y ) ) − y ‖ 1 ] E_{y∼pdata} (y)[‖G(F(y))−y‖_1] Ey∼pdata(y)[‖G(F(y))−y‖1] 的存在, G G G 也会去“包庇” F F F 。这样一来,就会出现上图中风格迁移,内容改变的情况。
这一点在原文中有提到,但原文说 Identity Loss
的作用主要是保证色调不变。Identity Loss
的形式为:
就是将真实的B输入到A生成B的判别器中,查看判别器的识别损失,希望越小越好!说明生成器网络真正的理解了B的结构。
加上 Identity Loss
后,整个损失函数的表达式为:
总结一下 Loss
实现:
3.4 Instance Normalization
图片使用了Instance Normalization而非经典DCGAN中所使用的Batch Normalization,Instance Normalization
和 Batch Normalization
一样,也是Normalization的一种方法,只是IN是作用于单张图片,但是BN作用于一个Batch。
假如现在图像先进行了卷积运算得到如上图所示的激活状态 ( N , C , H , W ) (N,C,H,W) (N,C,H,W) ,其中 N N N 是样本数, C C C 为通道数即特征图数。
- BN:取不同样本的同一个通道的特征做归一化,逐特征维度归一化。这个就是对batch维度进行计算。所以假设5个100通道的特征图的话,就会计算出100个均值方差。5个batch中每一个通道就会计算出来一个均值方差。
- LN:取的是同一个样本的不同通道做归一化,逐个样本归一化。5个10通道的特征图,LN会给出5个均值方差。
- IN:仅仅对每一个图片的每一个通道最归一化。也就是说,对【H,W】维度做归一化。假设一个特征图有10个通道,那么就会得到10个均值和10个方差;要是一个batch有5个样本,每个样本有10个通道,那么IN总共会计算出50个均值方差。
- GN:这个是介于LN和IN之间的一种方法。假设Group分成2个,那么10个通道就会被分成5和5两组。然后5个10通道特征图会计算出10个均值方差。
3.5 PatchGAN
参考:
CycleGAN网络中的判别器使用的是一种叫 PatchGAN
的设计,原始GAN的discriminator的设计是仅输出一个评价值(True or False),该值是对生成器生成的整幅图像的一个评价。
在以往的GAN学习中,判别器D网络的输出是一个标量,介于0~1之间,代表是真实图片的概率。
而PatchGAN的设计不同,PatchGAN设计成全卷积的形式,图像经过各种卷积层后,并不会输入到全连接层或者激活函数中,而是使用卷积将输入映射为N*N矩阵,该矩阵等同于原始GAN中的最后的评价值用以评价生成器的生成图像。
N × N N\times N N×N 矩阵中每个点(true or false)即代表原始图像中的一块小区域(这也就是patch含义)评价值,这也就是“感受野(下图)”的应用。
原来用一个值衡量整幅图,现在使用
N
×
N
N\times N
N×N 的矩阵来评价整幅图(使用 PatchGAN
标签也需要设置成为
N
×
N
N\times N
N×N 的格式,这样就可以进行损失计算了),显然后者可以关注更多的区域,这也就是 PatchGAN
的优势。
PatchGAN主要是用于判别器,普通的判别器我们所得到的是判断一张图像是否为目标图像(输入可以是期望的图像,也可以是生成器生成的图像)。PatchGAN则是基于映射的关系,通过卷积的感受野来判断某个小区域是否为我们想要的目标图片,并最终进行加权。
3.6 训练细节
- (1)图片使用了Instance Normalization而非经典DCGAN中所使用的Batch Normalization;
- (2)写代码时,并没有使用上面Loss中的 log likelihood 形式,而是使用的least-squares loss;
- (3)判别器采用的70×70 PatchGAN形式;
- (4)生成器网络使用了residual blocks;
- (5)训练时的Batch Size为1;
【有以下几种说法】:(参考:为什么 CycleGAN 的batchSize等于1?)
1、如果bs>1的话,一个batch里面存在不同内容的源域数据和不同风格的目标数据,会混淆图片的生成,在风格迁移里,就应该一张图一张图地训练。2、想在高分辨率图像上训练;
3、为了将train和test的batchsize保持一致;
4、instancenormal的使用,使用batchsize=1比较好;
- (6)学习率在前100个epochs不变,在后面的epochs线性衰减;
- (7)使用了Reflection padding而非普通的Zero padding;
- (8)生成器各层激活函数主要为ReLU,判别器各层激活函数主要为LeakyReLU;
- (9)训练判别器时还会用到生成器产生的历史数据(Buffer);
第4章 Pytorch实现CycleGAN
参考资料:
4.1 Models.py
import torch.nn as nn
import torch.nn.functional as F
import torch
# 初始化权重函数
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
if hasattr(m, "bias") and m.bias is not None:
torch.nn.init.constant_(m.bias.data, 0.0)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
##############################
# RESNET
##############################
class ResidualBlock(nn.Module):
def __init__(self, in_features):
super(ResidualBlock, self).__init__()
# 经过两次3x3的卷积,(WxH) -> (W-4)x(H-4)
# 经过两次pandding,(WxH) -> (W+4)x(H+4)
# 所以经过整个操作后,(WxH) -> (WxH)
self.block = nn.Sequential(
# nn.ReflectionPad2d()函数用法参考:https://blog.youkuaiyun.com/LionZYT/article/details/120181586
nn.ReflectionPad2d(1), # 对四周都填充1行
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features),
)
def forward(self, x):
return x + self.block(x)
##############################
# Generator
##############################
class GeneratorResNet(nn.Module):
def __init__(self, input_shape, num_residual_blocks):
super(GeneratorResNet, self).__init__()
# input_shape = (3, 256, 256)
channels = input_shape[0]
# Initial convolution block
out_features = 64
model = [
# (3, 256, 256) -> (3, 262, 262)
nn.ReflectionPad2d(channels),
# (3, 262, 262) -> (64, 256, 256)
nn.Conv2d(channels, out_features, 7),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
# in_features = 64
in_features = out_features
# Downsampling下采样
for _ in range(2):
# 1:out_features = 128
# 2:out_features = 256
out_features *= 2
# (64, 256, 256) -> (256, 64, 64)
model += [
nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
# in_features = 256
in_features = out_features
# Residual blocks
for _ in range(num_residual_blocks):
model += [ResidualBlock(out_features)]
# Upsampling上采样
for _ in range(2):
out_features //= 2
# (256, 64, 64) -> (64, 256, 256)
model += [
nn.Upsample(scale_factor=2), # (W,H)->(Wx2, Hx2)
nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Output layer
# (64, 256, 256) -> (3, 256, 256)
model += [
nn.ReflectionPad2d(channels),
nn.Conv2d(out_features, channels, 7),
nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
##############################
# Discriminator
##############################
class Discriminator(nn.Module):
def __init__(self, input_shape):
super(Discriminator, self).__init__()
# input_shape = (3, 256, 256)
channels, height, width = input_shape
# Calculate output shape of image discriminator (PatchGAN)
# (1, 16, 16)
self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
def discriminator_block(in_filters, out_filters, normalize=True):
"""Returns downsampling layers of each discriminator block"""
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalize:
layers.append(nn.InstanceNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
# (3, 256, 256) -> (1, 16, 16)
self.model = nn.Sequential(
*discriminator_block(channels, 64, normalize=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
nn.ZeroPad2d((1, 0, 1, 0)), # 左右上下的顺序
nn.Conv2d(512, 1, 4, padding=1)
)
def forward(self, img):
return self.model(img)
4.2 Datasets.py
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
"""
主要是ImageDataset函数的操作,__init__操作将trainA和trainB的路径读入files_A 和files_B;
__getitem__对两个文件夹的图片进行读取,若不是RGB图片则进行转换;__len__返回两个文件夹数据数量的大值。
"""
# 转为rgb图片
def to_rgb(image):
rgb_image = Image.new("RGB", image.size)
rgb_image.paste(image)
return rgb_image
# 对数据进行读取
class ImageDataset(Dataset):
def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
self.transform = transforms.Compose(transforms_)
self.unaligned = unaligned
self.files_A = sorted(glob.glob(os.path.join(root, "%sA" % mode) + "/*.*"))
self.files_B = sorted(glob.glob(os.path.join(root, "%sB" % mode) + "/*.*"))
def __getitem__(self, index):
image_A = Image.open(self.files_A[index % len(self.files_A)])
if self.unaligned:
image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
else:
image_B = Image.open(self.files_B[index % len(self.files_B)])
# Convert grayscale images to rgb
if image_A.mode != "RGB":
image_A = to_rgb(image_A)
if image_B.mode != "RGB":
image_B = to_rgb(image_B)
item_A = self.transform(image_A)
item_B = self.transform(image_B)
return {"A": item_A, "B": item_B}
def __len__(self):
return max(len(self.files_A), len(self.files_B))
4.3 Utils.py
import random
import time
import datetime
import sys
from torch.autograd import Variable
import torch
import numpy as np
from torchvision.utils import save_image
"""
主要关注学习率衰减(LambdaLR)。
"""
class ReplayBuffer:
def __init__(self, max_size=50):
assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
self.max_size = max_size
self.data = []
def push_and_pop(self, data):
to_return = []
for element in data.data:
element = torch.unsqueeze(element, 0)
if len(self.data) < self.max_size:
self.data.append(element)
to_return.append(element)
else:
if random.uniform(0, 1) > 0.5:
i = random.randint(0, self.max_size - 1)
to_return.append(self.data[i].clone())
self.data[i] = element
else:
to_return.append(element)
return Variable(torch.cat(to_return))
# 学习率在前100个epochs不变,在后面的epochs线性衰减
class LambdaLR:
def __init__(self, n_epochs, offset, decay_start_epoch):
assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
self.n_epochs = n_epochs
self.offset = offset
self.decay_start_epoch = decay_start_epoch
def step(self, epoch):
return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
4.4 Cyclegan.py
import argparse
import os
import numpy as np
import math
import itertools
from tqdm.autonotebook import tqdm
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from models import *
from datasets import *
from utils import *
import torch.nn as nn
import torch.nn.functional as F
import torch
'''
参数表格
epoch:使用数据集的所有数据进行一次模型训练,一代训练,从第0代开始训练
n_epochs:训练的次数,默认200次
dataset_name:数据集文件夹的名字,默认"monet2photo"
batch_size:使用数据中的一部分数据进行模型权重更新的这部分数据大小,默认1
lr:adam学习率
b1&b2:adam学习参数
decay_epoch:lr学习率开始衰减
n_cpu:训练过程中用到的CPU线程数目
img_height:输入图片的高度,默认256
img_width:输入图片的宽度,默认256
channels:图片的通道数,默认为彩色图片,channels=3
sample_interval:每隔一段时间对训练输出进行采样并展示,默认100
n_residual_blocks:生成器中的residual模块的数量
lambda_cyc:cycle loss权重参数
lambda_id:identity loss权重参数
'''
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="monet2photo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
opt = parser.parse_args()
print(opt)
# 创建文件夹保存模型和采样输出图片
os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)
# 损失函数定义和初始化
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
# 判断电脑是否可以使用GPU进行训练
cuda = torch.cuda.is_available()
# input_shape保存输入图片的通道数,高度,宽度
input_shape = (opt.channels, opt.img_height, opt.img_width)
# 初始化四个网络(G_AB,G_BA,D_A,D_B)
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)
# 采用GPU进行训练
if cuda:
G_AB = G_AB.cuda()
G_BA = G_BA.cuda()
D_A = D_A.cuda()
D_B = D_B.cuda()
criterion_GAN.cuda()
criterion_cycle.cuda()
criterion_identity.cuda()
# 如果不是从第0代开始训练,则从保存的模型中调用模型以及加载开始训练的代数,继续训练
if opt.epoch != 0:
# Load pretrained models
G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
else:
# 如果从头开始训练,就初始化权重
# Initialize weights
G_AB.apply(weights_init_normal)
G_BA.apply(weights_init_normal)
D_A.apply(weights_init_normal)
D_B.apply(weights_init_normal)
# 定义初始化模型的优化器
optimizer_G = torch.optim.Adam(
itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# 按照epoch的次数自动调整学习率
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
# 数据预处理包括resize、crop、flip、normalize等操作
transforms_ = [
transforms.Resize(int(opt.img_height * 1.12), interpolation=InterpolationMode.BICUBIC), # 调整Image对象的尺寸
transforms.RandomCrop((opt.img_height, opt.img_width)), # 扩大后剪切成img_height*img_width大小的图片
transforms.RandomHorizontalFlip(), # 依据概率p对PIL图片进行水平翻转,p默认0.5
transforms.ToTensor(), # 转为tensor格式
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化
]
# 加载训练数据
# Training data loader
dataloader = DataLoader(
# ../表示当前目录的父目录
ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="train"),
batch_size=opt.batch_size,
shuffle=True, # 将数据打乱,数值越大,混乱程度越大
# num_workers=0,
num_workers=opt.n_cpu, # 线程数
)
# 测试数据加载
# Test data loader
val_dataloader = DataLoader(
ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
batch_size=5,
shuffle=True,
num_workers=0,
)
# 定义测试数据喂进网络的输出展示函数
def sample_images(batches_done):
"""Saves a generated sample from the test set"""
imgs = next(iter(val_dataloader))
G_AB.eval()
G_BA.eval()
real_A = Variable(imgs["A"].type(Tensor))
fake_B = G_AB(real_A)
real_B = Variable(imgs["B"].type(Tensor))
fake_A = G_BA(real_B)
# Arange images along x-axis
real_A = make_grid(real_A, nrow=5, normalize=True)
real_B = make_grid(real_B, nrow=5, normalize=True)
fake_A = make_grid(fake_A, nrow=5, normalize=True)
fake_B = make_grid(fake_B, nrow=5, normalize=True)
# Arange images along y-axis
image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)
# ----------
# Training
# 开始训练
# ----------
if __name__ == '__main__':
prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
loop = tqdm(dataloader, colour='red', unit='img')
for i, batch in enumerate(loop):
# 设置模型输入
# Set model input
real_A = Variable(batch["A"].type(Tensor))
real_B = Variable(batch["B"].type(Tensor))
# 对抗生成网络中的真实图片和虚假图片
# Adversarial ground truths
valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)
# ------------------
# Train Generators
# 训练生成器
# ------------------
G_AB.train()
G_BA.train()
# 梯度清零,方便下代训练
optimizer_G.zero_grad()
# Identity loss :
# 用于保证生成图像的连续性,一个图像x,经过其中一个生成器生成图像 G(x),尽可能与原来图像接近。
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_A + loss_id_B) / 2
# GAN loss
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
# Cycle loss
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# 总损失函数
# Total loss
loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
# 反向传播
loss_G.backward()
# 权重更新
optimizer_G.step()
# -----------------------
# Train Discriminator A
# 训练分类器A
# -----------------------
optimizer_D_A.zero_grad()
# Real loss
loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on batch of previously generated samples)
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2
loss_D_A.backward()
optimizer_D_A.step()
# -----------------------
# Train Discriminator B
# 训练分类器B
# -----------------------
optimizer_D_B.zero_grad()
# Real loss
loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on batch of previously generated samples)
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2
loss_D_B.backward()
optimizer_D_B.step()
# 两个分类器损失之和
loss_D = (loss_D_A + loss_D_B) / 2
# --------------
# Log Progress
# 训练进程显示
# --------------
# Determine approximate time left
batches_done = epoch * len(dataloader) + i
# batches_left = opt.n_epochs * len(dataloader) - batches_done
# time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
# prev_time = time.time()
# # Print log
# sys.stdout.write(
# "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
# % (
# epoch,
# opt.n_epochs,
# i,
# len(dataloader),
# loss_D.item(),
# loss_G.item(),
# loss_GAN.item(),
# loss_cycle.item(),
# loss_identity.item(),
# time_left,
# )
# )
# 进度条参数
loop.set_description(f"Epoch [{epoch}/{opt.n_epochs}] Batch[{i}/{len(dataloader)}]")
loop.set_postfix(D_loss=loss_D.item(), G_loss=loss_G.item(),
loss_GAN=loss_GAN.item(), loss_Cycle=loss_cycle.item(),
loss_Identity=loss_identity.item())
# If at sample interval save image
if batches_done % opt.sample_interval == 0:
sample_images(batches_done)
# Update learning rates
# 更新学习率
lr_scheduler_G.step()
lr_scheduler_D_A.step()
lr_scheduler_D_B.step()
if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
# Save model checkpoints
torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))