CycleGAN更换MNIST底色
完整代码:https://github.com/SongDark/domain_transfer_mnist
概述
可能你会觉得我吃饱了撑的,杀鸡用牛刀,拿GAN来做MNIST的底色变换。其实只是为了试验方便,我懒得下载大型数据集罢了。。。你如果感兴趣,去这里下载数据集,能实现“从普通马到斑马的转换”,模型是一样的。
数据准备
从 这里 下载 mnist.npz
。
- 将背景改成彩色,数字保持白色,背景rgb随机生成。
# [28, 28] -> [28, 28, 3]
def change_background(img):
rgb = np.random.randint(low=0, high=255, size=(3,))
res = np.tile(img[:,:,None], (1,1,3))
for i in range(3):
res[:,:,i][res[:,:,i]<127.5] = rgb[i]
return res
- 将数字改成彩色,背景改为白色,数字rgb随机生成。
# [28, 28] -> [28, 28, 3]
def change_numeral(img):
rgb = np.random.randint(low=0, high=255, size=(3,))
res = np.tile(img[:,:,None], (1,1,3))
for i in range(3):
res[:,:,i][res[:,:,i]>=127.5] = rgb[i]
res[:,:,i][res[:,:,i]<127.5] = 255
return res
原始图像 | 彩底白字 | 白底彩字 |
---|---|---|
![]() |
![]() |
![]() |
CycleGAN
论文参考:CycleGAN论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
代码实现参考: https://hardikbansal.github.io/CycleGANBlog/
概述
Pix2Pix之类的网络做Domain Transfer时需要两个Domain中对应的两个样本才能训练,而CycleGAN不需要这样的配对样本,也能实现转换。CycleGAN有两个步骤,首先将原Domain样本映射到目标Domain,然后再映射回来。映射到目标域的工作由Generator实现,Generator生成的样本质量由一个Discriminator判断。如果Generator的结构足够复杂,那么对抗训练总能够保证它生成的图片属于目标Domain,但不一定是我们所期望的。倘若Generator学到的是“周围彩色,中间白色”,那么它生成一些周围彩色中间一坨白色的无意义样本似乎也是合理的。我们所期望的是,生成的样本中能包含关于源图像的一些有用信息,利用这些信息可以恢复出原图像,那么Generator至少应该学着提取“周围彩色,中间白色,数字是8”这样的信息才行。
为了实现上面提到的功能,需要两对Generator和Discriminator,Generator输入图像输出图像,Discriminator输入图像输出真假判定,下面的两张图给出了样本的走向。
(图片出自 https://hardikbansal.github.io/CycleGANBlog/)


生成器设计
生成器Generator
接受图像输入,输出同等尺寸(通道数除外)的图像,例如输入黑白MNIST图像 ( 28 , 28 , 1 ) (28,28,1) (28,28,1),就应该输出尺寸为 ( 28 , 28 , 3 ) (28,28,3) (28,28,3)的彩色图像。Generator
要具备提取特征的能力(需要卷积层)、从特征生成图像的能力(需要解卷积层)和保持原图像基本不变只微调部分细节的能力(这里指变色,需要Res层)。
Generator
中的激活函数应用relu
。
Generator
的最后一层激活函数需依照真实数据样本分布决定,若为 ( 0 , 1 ) (0,1) (0,1)则用sigmoid
, ( − 1 , 1 ) (-1,1) (−1,1)则用tanh
。

def resnet_block(x, dim, is_training=True, name='resnet'):
with tf.variable_scope(name):
out = tf.pad(x, [[0,0],[1,1],[1,1],[0,0]], "REFLECT")
out = tf.nn.relu(bn(conv2d(out, dim, 3, 3, 1, 1, 0.02