“神仙姐姐”CycleGAN
在“风格迁移四部曲系列”的《风格迁移的“精神始祖”Conditional GAN》文章中,已经跟大伙一起在MNIST手写数据集上手撸了CGAN,让GAN学会了“认标签,写数字”。然后,我们将CGAN“拟合条件概率分布”的思想发扬光大,在文章 《用CGAN做图像转换的鼻祖pix2pix》 中,让GAN学会了“看图学画风”,并用学会的图片风格渲染新图片。到这里GAN是不是已经有了点艺术家的气质了~ 但是,前面介绍的两个GAN只能算是“阿朱、阿碧”那样的小丫鬟。本项目介绍的CycleGAN才是真正的大小姐“王姑娘”。既然Pix2Pix也能干风格迁移的活儿,为什么就和CycleGAN丫鬟小姐不同命呢?打个比方,非是两个丫头不够聪明(Pix2Pix效果不够好),而是她们不认识字(适用范围窄),武功秘籍都得大侠念给她们听才能记得(得让训练集的两组图片一一对应才能训练)。王姑娘则从小接受书香门第的全面素质教育(CycleGAN经朱俊彦大神悉心改造),自家的武功秘籍还能可劲儿看(网上的图片按域特征分成两组就能喂给CycleGAN),自然识得天下武功(CycleGAN应用发扬光大)。再说,Pix2Pix效果再惊艳,也不能老蹭人家分割任务的数据集用吧。比如,下面这个将照片转变为大师画作的任务中,只要备好了一组照片和一组大师的作品作为数据集,CycleGAN就能轻松搞定:CycleGAN的介绍
1.CycleGAN的原理
CycleGAN,即循环生成对抗网络,出自发表于 ICCV17 的论文 《Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks》 ,和它的兄长Pix2Pix(均为朱大神作品)一样,用于图像风格迁移任务。以前的GAN都是单向生成,CycleGAN为了突破Pix2Pix对数据集图片一一对应的限制,采用了双向循环生成的结构,因此得名CycleGAN。 首先,CycleGAN也是一个GAN模型,通过判别器和生成器的对抗训练,学习数据集图片的像素概率分布来生成图片。原理已经在前面的文章 《通俗理解经典GAN》 中详细介绍过了。 要完成X域到Y域的图片风格迁移,就要求GAN网络既要拟合Y域图片的风格分布分布,又要保持X域图片对应的内容特征。打个比方,用草图风格的猫图片生成照片风格的猫图片时,要求生成的猫咪“即要活灵活现,又要姿势不变”。“拟合数据分布”本来就是GAN干的活儿,而“保持原图片特征”在Pix2Pix上是这么实现的(详解可参考 《用CGAN做图像转换的鼻祖pix2pix》 ):2.CycleGAN的流程
下面,我们就来看看循环生成网络(CycleGAN)到底是怎么“循环起来”的:3.CycleGAN的结构
接下来,我们再看看这两对判别器、生成器怎么摆:4.CycleGAN的loss函数
前面分析了CycleGAN的原理,我们已经知道了CycleGAN的loss由对抗损失(称为gan loss或adversarial loss)和循环一致性损失(consitency loss)组成,下面看看公式:CycleGAN的实现
下面,我们就来用Paddle的动态图模式,实现这个将妹子照片转化为二次元风格的“讨喜神器”(单方精妙、小心炼制、谨慎使用~)。1.数据集准备
将selfie2anime数据集解压到/home/aistudio/data/data50363/路径下,trainA文件夹下存储照片风格训练集图片,trainB文件夹下存储卡通风格训练集图片,testA和testB分别存储照片风格和卡通风格的测试集图片。数据集的读取器和上个文章 《用CGAN做图像转换的鼻祖pix2pix》 一样使用Paddle套件代码库里的脚本。与其不同的是,得益于CycleGAN的训练数据适应能力,我们无需每次送入模型一对对应的图片,只需送入两个单独的读取器从两组图片中各自shuffle后输出的任意两张图片。这样,还能通过打乱顺序增加模型的泛化能力。 此外,为了实现模型的更佳效果,还使用了明暗、对比度、饱和度、拉伸、旋转等数据增强效果。具体的使用原因我们在最后的对比分析中再详细解释。# 解压数据集,首次运行后注释# !unzip -qa -d /home/aistudio/data/data50363/ /home/aistudio/data/data50363/selfie2anime_textlist.zipimport paddle.fluid as fluidimport data_reader_epoch as data_readerimport paddleimport matplotlib.pylab as plt
%matplotlib inlineimport numpy as npdef show_pics(pics, heatmap=np.zeros((1, 1))):
plt.figure(figsize=(3 * len(pics), 3), dpi=80)for i in range(len(pics)):
pics[i] = (pics[i][0].tran