CycleGAN(基于PyTorch框架)
0.论文简介
CycleGAN是一款实现风格迁移的模型,其论文可以在各大平台找到。我们在aixiv上可以找到:https://arxiv.org/pdf/1703.10593.pdf。
我们复现的的代码是来自下面这个github仓库:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix。
虽然看起来很简单,不过对于刚入门的新手来说难度还真不低,下面让我们来仔细看看代码的结构。
0.1本文主要的工作
- 在训练集缺失的情况下,将图片从某一种风格转移到另一种风格。希望学习到一种映射规则G,使得G(X)=Y。
- 希望找到G的可逆变换F使得F(G(X))=X。
- 本文进行了很多任务方面的尝试,并和之前的方法作对比。
0.2引言
- 整体步骤概述:捕获风格1的图像特征,在没有训练集提示的情况下将其转化为风格2的特征。
- 研究背景:获取不同风格的成对数据存在一定的困难。
- 具体实施:虽无法获取图像级别的监督(缺乏有标注的图像对),但可获取集合级别的监督(X和Y中各有一组图像,我们不知道X中的某张图对应Y中的哪张图,但我们可以知道X和Y这两个集合是彼此对应的)。经过训练,使得y ̂=G(X)与y无法区分,也就是使得y ̂和y的分布尽可能一致。
- 遇到的问题:第一是无法确定哪一个才是有意义的配对(可能会有很多组映射G),第二是独立地去优化对抗损失是一件很困难的事(配对程序会将所有输入的映射图像都转换成同一个输出图像)。
- 解决措施:添加循环一致性损失,把F(G(X))与x、G(F(Y))与y的损失也都加到对抗网络的损失里。
0.3方法
- 整体损失项:损失函数一共是4个部分,其中有两个是对抗损失,两个是循环一致性损失。四项分别是:(1)D_Y:用于衡量y ̂=G(x)与y的损失;(2)D_X:用于衡量x ̂=F(y)与x的损失;(3)F(G(x))与x之间的损失;(4)G(F(y))与y的损失。
- 对抗损失:基本公式如下。
优化思想是:
- 循环一致性损失。基本公式如下:
- 总损失函数
优化目标是
1.代码结构
1.1根目录中的文件
我们将目光聚集到根目录这个位置。
这里面,我们先来看根目录下的文件:
- README.md就是说明书。
- requirements.txt是说明这样的一个仓库所需要的各种包的版本。
- .gitignore文件是那些上传的时候要忽略的东西(不是所有的数据都需要被commit到仓库里,有时候可能只需要交源码)。
- LICENSE文件是许可证文件,会告知我们有什么样的权限(比如这个项目的代码我们可以下载并在本地中修改,但是不能修改远程仓库的内容)。
- .replit文件提供了所使用的信息,便于在浏览器中运行代码,这样一来就无需在本地配置环境。这是在使用云编辑器repl.it的时候可能会用到的设置。
- environment.yml这个文件相当于Python+requirement.txt,我们可以直接使用
conda env create -f environment.yml
来创建一个environment.yml文件里指定的环境(里头什么样的包、环境名是什么、Python版本是多少)。当然,如果你现在就有这样的一个环境,想要导出这个环境所对应的environment.yml,只需要下面的命令就可以conda env export | grep -v "^prefix: " > environment.yml
。 - train.py是用于训练的主脚本,可以指定使用什么样的数据集和模型。我们使用–model选项可以指定使用什么样的模型(例如:pix2pix, cyclegan, colorization),通过–dataset_mode指定数据模式(例如:aligned, unaligned, single, colorization),通过–dataroot指定数据集路径,通过–name指定实验名称。这里可以列举一个命令供参考
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
。 - test.py是用于测试的主脚本,通过–checkpoints_dir可以设置模型读取的路径,通过–results_dir可以设置结果的保存路径,通过–dataroot可以设置数据集的路径,通过–name可以设置任务名称,通过–model设置所采用的模型。对于CycleGAN双向检验,可以利用命令
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
来实现,其中的–model cycle_gan会将数据导入的模式变为双向。对于CycleGAN的单项检验,可以利用命令python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
来实现。–no_dropout指的是不需要dropout;–model test指的是单向验证CycleGAN模型,这将使得–dataset_mode自动变成single,也就是导入单一集合的数据。 - CycleGAN.ipynb和pix2pix.ipynb里头是两个模型的运行教程(在jupyter notebook)。
1.1.1 train.py文件
21-25行不说了,都是导入一些基本的类。
第27行的意思是,如果这个脚本作为主脚本使用,那么就运行下方的东西。28行是先把TrainOption实例化成对象,然后用parse进行解析,这样形成一个结果,赋给opt,也就是说,opt解析出来的结果。29行是根据这个结果去创建数据集。30行获取数据集中样本的数量。31行不说了。
33行是创建模型。34行是根据opt创建合适的学习率调整策略、导入网络并打印。第35行是根据opt创建可视化实例。36行是训练迭代次数。
第38行是迭代过程的开始,opt.epoch_count是从哪个epoch开始,opt.n_epochs_decay是持续多少epoch。39-40行不说了,获取这个epoch开始的时间和本轮epoch导入数据的时间。41行是在本轮epoch当中的第几次迭代。42行是可视化机器的重置,保证在每个epoch里它至少有一次保存图片。43行是在每次epoch之前率先更新一下学习率。
第44行就是每个epoch内部的循环了,enumerate函数的作用是同时列出数据和下标,这个无需多说,注意这里的i是batch的编号,而data也不是一张图,而是一个batch的图。45行是本次iteration开始的时间。46-47行是说如果总迭代次数total_iters到了opt.print_freq的整倍数,就计算t_data,也就是本轮iteration开始的时刻到本轮epoch导入数据的时间已经过去了多久。49-50行指的是,一共多少个数据参与了训练以及本轮epoch里有多少数据参与了迭代。51行是把每一个数据解包,52行是参数优化,这些都是在model当中的basemodel.py当中定义的。
54-57行是数据可视化的部分,如果本轮epoch当中已经参与迭代的样本总数是opt.display_freq的整数倍,那么执行55-57行的操作。55行是返回一个叫save_result的布尔值,用于判定是否需要存出结果到html文件里。第56行是只有在着色任务中才有用,是展示图片的命令,其他的模型中compute_visuals函数只有一个命令,那就是pass。第57行则是存储到html文件里的命令,其中save_result就决定本行是否执行,可以参见util文件夹里的visualizer.py。
59-64行是打印的部分。如果本轮epoch当中已经参与迭代的样本总数是opt.print_freq的整数倍,那么执行60-64行的操作。第60行是获取当前的损失函数。第61行是计算每个图片所用的时间。第62行是输出当前的损失值,后面的参数含义大家可以点击util文件夹下的visuallizer.py去查看相应的函数。63-64行是损失值可视化的部分,如果window id of the web display这个值大于0,那么就利用plot_current_losses函数输出,参数的含义可以点击util文件夹下的visuallizer.py去查看相应的函数。
66-69行是保存权重文件的部分,如果本轮epoch当中已经参与迭代的样本总数是opt.save_latest_freq的整数倍,那么执行67-69行的操作。67行不说了,68行是设置保存后缀,69行是保存模型。
71行是重新获取时间。72-75行也是在保存模型,不过这次是在每个epoch结束的时候。
77行是输出,无需多言。
1.1.2 test.py文件
它将从’–checkpoints_dir’加载保存的模型,并将结果保存到’–results_dir’。它首先在给定opt选项的情况下创建模型和数据集。它将硬编码一些参数。然后,它对“–num_test”图像运行推断,并将结果保存到HTML文件中。
29-34行不用多说了,导入包。
36-39行是导入wandb包,可以帮我们记录超参数指标。
42-43行不说了,和上面的train.py有异曲同工之妙。45行,测试模式仅能使用单线程,至于哪一个线程,你可以去自己指定;46行,batch_size只能为1;47行是确定数据需不需要打乱;48行则是是否翻转;49行是放弃展示图片;50-52行不说了,上面的train.py里解释过是什么含义。
54-57行,没太看懂。不太熟悉wandb这个包的含义。
59-行,是在创建网站。60行是在确定地址。61-62行是要根据本轮迭代来确定网址的域名(整体)。63行不说了,就是打印一下结果。64行是确定网页的地址和标题(这块可能得在util文件夹里头找html.py)。68-69行是评估模式开启。
70-72行比较简单,不再重复。73-74行也比较容易理解,分别是解包数据、测试。75-78行可以参考注释,获取图像结果、获取图像路径,每隔5个图片打印一次。79行是保存图片到html中,参数的含义可以点击util文件夹下的visuallizer.py去查看相应的函数。最后80行,保存html。
1.2根目录中的文件夹
之后,我们再来看看各个文件夹都是怎么回事。
1.2.1 docs文件夹
docs文件夹不多说了,里头是各种说明文档。
1.2.2 .git文件夹
.git文件夹也不多说了,这是用于分布式版本管理的工具,具体什么是git请自行百度。在我【教程搬运】的专栏下也有专门介绍git的博文。
1.2.3 data文件夹
data文件夹,里头是各种和数据加载、处理的模块。里头的__init__.py是一个接口文件,basedataset.py是一个基础文件(包含一些常见的转换功能,有点相当于公用的“基类”,不知道怎么描述了),template_dataset.py这是一个模板,相当于示例文件。其它的都是具体的数据集对应的文件了。
1.2.3.1 template_dataset.py
首先让我们聚焦一个模板文件,也就是template_dataset.py,在这里我们仅仅给出一些说明,读完之后觉得抽象也没关系,我们后面还有例子(1.2.3.2节之后),慢慢体会,慢慢读就可以了。
这个文件主要起到一个模板的作用,是一个参考,具体说明如下:
- 这个脚本可以被当做一个模板,被用于创建新的数据类型。如果说此时此刻的我们想建立一个新的数据类型dummy,就需要在这个根目录底下创建一个名叫dummy_dataset.py的文件,里面需要定义一个类,名叫DummyDataset,而这个类需要继承父类BaseDataset(当然这个类就在data文件夹中的base_dataset.py之中),在类DummyDataset这里面需要实现四个重要的功能,我们将在后面仔细分析。
- 创建完之后如何使用呢?可以通过–dataset_mode template来指定,但需要注意,你所创建的类名class TemplateDataset、在–data_mode后面所指定的template、文件名template_dataset.py这三者都要保持一致,在实际应用中把template换为你自己的数据集名。具体的命名规范在template_dataset.py这个脚本前面有表述。
好了,刚刚我们已经说明了这个模板函数的作用,下面让我们详细地说一下要实现的是个具体功能:
- __len__函数,用于统计数据集里有多少数据,这无需多言,里面需要传入一个self参数,这显然是实例化之后的对象。返回值一般是len(self.A_path),括号内的内容是访问self的路径属性。
- modify_commandline_options函数,用于添加针对这个数据集特定的选项,这个脚本里头只是一个样例。
- __getitem__函数,这个函数将用来获取数据点,最后要返回的是数据和数据的路径,{‘data_A’: data_A, ‘data_B’: data_B, ‘path’: path},一切信息就都包含在这样一个字典里头。
- __init__函数,注意到它需要传入两个参数,一个是self,另一个是opt,前者就是将类进行实例化出来的对象,不用管;后者是我们添加的选项,在options里头文件夹里头有一些BaseOption,我们的opt必须是其中的子集。然后先得继承一下BaseDataset.__init__这个方法,之后在此处要获取数据集的路径,并且还要对输入数据进行一定的预处理。
为了便于大家理解__init__函数,我列举了single_dataset.py这个脚本里的内容进行举例。