GitHub code: https://github.com/tgeorgy/mgan
Contributions of the paper:
1. The generator takes an input x and outputs both a segmentation mask and an intermediate image y; the mask is used to blend the input x with the intermediate image y to produce the final generated image. As a result, the background of the generated image matches the input x, while the foreground is generated.
2. Training is end to end: on top of the CycleGAN losses, an extra constraint is added on the generated output image.
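In symbols, with mask $m \in [0,1]$ and $\odot$ denoting elementwise multiplication, the final output is $\hat{x} = m \odot y + (1 - m) \odot x$; this is exactly what the forward pass shown further below computes.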
The model structure is as follows.
The generator first outputs a segmentation mask and an intermediate image y, then blends y with the input according to the mask; the blended result is the final generated image. The generator code is as follows:
import torch
import torch.nn as nn

# ResnetBlock (a standard residual block) is defined elsewhere in the repo.
class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=4, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, n_blocks=6):
        assert n_blocks >= 0
        super(Generator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # Stem: 7x7 convolution at full resolution
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        # Downsampling: two stride-2 convolutions
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
        # Residual blocks at the bottleneck resolution
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_dropout=use_dropout)]
        # Upsampling: 1x1 channel expansion + PixelShuffle instead of transposed convolutions
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ReflectionPad2d(1),
                      nn.Conv2d(ngf * mult, int(ngf * mult / 2),
                                kernel_size=3, stride=1),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True),
                      nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2) * 4,
                                kernel_size=1, stride=1),
                      nn.PixelShuffle(2),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        # Head: 7x7 convolution producing mask (1 channel) + image (3 channels)
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        self.model = nn.Sequential(*model)
In the code, the generator has 3 input channels and 4 output channels: the first output channel is the mask, and the remaining three channels are the intermediate generated image. The forward pass blends them:
    def forward(self, input):
        output = self.model(input)
        mask = torch.sigmoid(output[:, :1])   # channel 0: blending mask in [0, 1] (the repo uses the older F.sigmoid)
        oimg = output[:, 1:]                  # channels 1-3: intermediate image y
        mask = mask.repeat(1, 3, 1, 1)        # broadcast the mask to 3 channels
        # Blend: generated foreground where the mask is high, original input elsewhere
        oimg = oimg * mask + input * (1 - mask)
        return oimg, mask
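As a quick sanity check, the generator can be run on a random tensor (a minimal sketch, not from the repo; it assumes the repo's ResnetBlock definition is in scope):
# Hypothetical smoke test for the Generator above.
g = Generator(input_nc=3, output_nc=4, ngf=64, n_blocks=6)
g.eval()  # use running BatchNorm statistics so batch size 1 is safe
x = torch.randn(1, 3, 128, 128)
oimg, mask = g(x)
print(oimg.shape, mask.shape)  # both torch.Size([1, 3, 128, 128])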
The model adopts the CycleGAN structure: two generators and two discriminators.
For each generator, the loss has three parts. The first, loss_P2N_cyc, is the CycleGAN cycle-consistency loss: the output of generator g1 is fed into generator g2, and the reconstruction should match the original input as closely as possible. The second, loss_P2N_gan, is the GAN loss: the discriminator should classify the generated image as real. The third, loss_N2P_idnt, is an identity loss that keeps the generator's output as close as possible to its input (in the code, fake_neg is compared with real_pos_v); this is the paper's extra constraint on the generated image, making training effectively end to end, a term plain CycleGAN does not have on the translated output.
criterion_cycle = nn.L1Loss()      # cycle-consistency loss
criterion_identity = nn.L1Loss()   # identity loss on the generated image
criterion_gan = nn.MSELoss()       # least-squares GAN loss (LSGAN)
# Train P2N Generator (positive -> negative direction)
# Variable comes from torch.autograd (pre-0.4 PyTorch style).
real_pos_v = Variable(real_pos)
fake_neg, mask_neg = netP2N(real_pos_v)      # translate positive -> negative
rec_pos, _ = netN2P(fake_neg)                # cycle back: negative -> positive
fake_neg_lbl = netDN(fake_neg)               # discriminator score of the fake
loss_P2N_cyc = criterion_cycle(rec_pos, real_pos_v)             # cycle consistency
loss_P2N_gan = criterion_gan(fake_neg_lbl, Variable(real_lbl))  # fool the discriminator
loss_N2P_idnt = criterion_identity(fake_neg, real_pos_v)        # stay close to the input
# Train N2P Generator (negative -> positive direction)
real_neg_v = Variable(real_neg)
fake_pos, mask_pos = netN2P(real_neg_v)
rec_neg, _ = netP2N(fake_pos)
fake_pos_lbl = netDP(fake_pos)
loss_N2P_cyc = criterion_cycle(rec_neg, real_neg_v)
loss_N2P_gan = criterion_gan(fake_pos_lbl, Variable(real_lbl))
loss_P2N_idnt = criterion_identity(fake_pos, real_neg_v)
# Total generator loss: GAN terms plus weighted cycle and identity terms
loss_G = ((loss_P2N_gan + loss_N2P_gan)*0.5 +
          (loss_P2N_cyc + loss_N2P_cyc)*lambda_cycle +
          (loss_P2N_idnt + loss_N2P_idnt)*lambda_identity)
The discriminators judge whether their input is real or fake:
# Train Discriminators
netDN.zero_grad()
netDP.zero_grad()
# Fakes are detached so no gradients flow back into the generators here.
fake_neg_score = netDN(fake_neg.detach())
loss_D = criterion_gan(fake_neg_score, Variable(fake_lbl))
fake_pos_score = netDP(fake_pos.detach())
loss_D += criterion_gan(fake_pos_score, Variable(fake_lbl))
real_neg_score = netDN(real_neg_v)
loss_D += criterion_gan(real_neg_score, Variable(real_lbl))
real_pos_score = netDP(real_pos_v)
loss_D += criterion_gan(real_pos_score, Variable(real_lbl))
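A training step then back-propagates these losses through two optimizers. The following is a sketch under assumed names: optG, optD, and their hyperparameters are illustrative, not from the repo excerpt.
# Hypothetical setup, created once before the training loop.
optG = torch.optim.Adam(list(netP2N.parameters()) + list(netN2P.parameters()),
                        lr=2e-4, betas=(0.5, 0.999))
optD = torch.optim.Adam(list(netDN.parameters()) + list(netDP.parameters()),
                        lr=2e-4, betas=(0.5, 0.999))

# Per-iteration update, after the loss computations above.
optG.zero_grad()
loss_G.backward()    # the GAN terms also leave gradients in the discriminators...
optG.step()
netDN.zero_grad()
netDP.zero_grad()    # ...so clear them before the discriminator update
loss_D.backward()    # fakes were detached, so only the discriminators get gradients
optD.step()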
---------------------
Author: imperfect00 | Source: CSDN | Original: https://blog.youkuaiyun.com/u011961856/article/details/79057469
---------------------
This post covers the Domain Transfer Network (DTN). Given two domains S and T, the goal is to learn a generator function G mapping samples from domain S to domain T such that a given feature function f is preserved: f should produce (approximately) the same output whether it is applied to a sample or to that sample's translation by G.
The network structure is as follows:
The generator is composed of two functions, f and g. f extracts features from the input image, producing a feature vector; g takes f's output and produces an image in the target style. Training is unsupervised: source images and target images are not paired, and a source image set and a target-style image set are used directly for training. A source image fed through the generator G yields a target-style image; a target image fed through G should come out unchanged.
The network also contains a discriminator D, which judges whether its input is a generated image (fake) or a real input image (real).
Loss functions
1. For the generator, a source image as input yields a target-style image. We also want the generator to act as an identity mapping on target images: when the input is already a target image, the output should be that same image. This gives the loss $L_{TID}$:
$L_{TID} = \sum_{x \in t} \lVert x - G(x) \rVert^2$
where $x \in t$ means the image $x$ is a target image and $t$ is the target image set.
2. For the function f, we want the feature vector extracted from a source image to be as close as possible to the feature vector extracted from that image's translation through the generator G:
$L_{CONST} = \sum_{x \in s} \lVert f(x) - f(G(x)) \rVert^2$
where $x \in s$ means the image $x$ is a source image and $s$ is the source image set.
3. The discriminator D judges whether its input is a generated image or a real input image. It sees three kinds of inputs: generated images from source images, generated images from target images, and real target images. Its loss is:
$L_D = -\mathbb{E}_{x \in s} \log D_1(G(x)) - \mathbb{E}_{x \in t} \log D_2(G(x)) - \mathbb{E}_{x \in t} \log D_3(x)$
where $D_1$ denotes the probability D assigns to generated images of source images, $D_2$ to generated images of target images, and $D_3$ to real target images.
4. The generator's loss is:
$L_G = L_{GANG} + \alpha L_{CONST} + \beta L_{TID} + \gamma L_{TV}$
where $L_{GANG}$ is the generator's adversarial term (pushing D to classify generated images as real target images), $L_{TV}$ is a total-variation smoothness term on the generated image, and $\beta = 1$.
Code analysis
The network operates on 32×32 images (3-channel on the SVHN source side, 1-channel on the MNIST target side).
The feature extractor f consists of 4 convolutional layers: the first three use 3×3 kernels and the fourth uses a 4×4 kernel, all with stride 2. A classification layer can be appended after f to classify the input, e.g., handwritten digits into 10 classes; training the network on this classification task serves as pretraining. The code is as follows:
def content_extractor(self, images, reuse=False):
    # images: (batch, 32, 32, 3) or (batch, 32, 32, 1)
    if images.get_shape()[3] == 1:
        # For mnist dataset, replicate the gray scale image 3 times.
        images = tf.image.grayscale_to_rgb(images)
    with tf.variable_scope('content_extractor', reuse=reuse):
        with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
                            stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
            with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
                                activation_fn=tf.nn.relu,
                                is_training=(self.mode == 'train' or self.mode == 'pretrain')):
                net = slim.conv2d(images, 64, [3, 3], scope='conv1')    # (batch_size, 16, 16, 64)
                net = slim.batch_norm(net, scope='bn1')
                net = slim.conv2d(net, 128, [3, 3], scope='conv2')      # (batch_size, 8, 8, 128)
                net = slim.batch_norm(net, scope='bn2')
                net = slim.conv2d(net, 256, [3, 3], scope='conv3')      # (batch_size, 4, 4, 256)
                net = slim.batch_norm(net, scope='bn3')
                net = slim.conv2d(net, 128, [4, 4], padding='VALID', scope='conv4')  # (batch_size, 1, 1, 128)
                net = slim.batch_norm(net, activation_fn=tf.nn.tanh, scope='bn4')
                if self.mode == 'pretrain':
                    net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out')
                    net = slim.flatten(net)
                return net
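In 'pretrain' mode the extra 10-way output lets f be trained as an ordinary digit classifier before the transfer training. A hedged sketch of such a pretraining objective follows; the placeholder and variable names here are illustrative, not the repo's:
# Hypothetical pretraining loss for f; 'model' is an instance of the class above.
images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
labels = tf.placeholder(tf.int64, [None], 'svhn_labels')
logits = model.content_extractor(images)  # (batch, 10) since mode == 'pretrain'
pretrain_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
pretrain_op = tf.train.AdamOptimizer(0.001).minimize(pretrain_loss)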
The function g decodes the feature vector output by f back into an image; it is essentially the inverse of f, i.e., g consists of 4 transposed-convolution layers:
def generator(self, inputs, reuse=False):
    # inputs: (batch, 1, 1, 128)
    with tf.variable_scope('generator', reuse=reuse):
        with slim.arg_scope([slim.conv2d_transpose], padding='SAME', activation_fn=None,
                            stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
            with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
                                activation_fn=tf.nn.relu, is_training=(self.mode == 'train')):
                net = slim.conv2d_transpose(inputs, 512, [4, 4], padding='VALID',
                                            scope='conv_transpose1')  # (batch_size, 4, 4, 512)
                net = slim.batch_norm(net, scope='bn1')
                net = slim.conv2d_transpose(net, 256, [3, 3], scope='conv_transpose2')  # (batch_size, 8, 8, 256)
                net = slim.batch_norm(net, scope='bn2')
                net = slim.conv2d_transpose(net, 128, [3, 3], scope='conv_transpose3')  # (batch_size, 16, 16, 128)
                net = slim.batch_norm(net, scope='bn3')
                net = slim.conv2d_transpose(net, 1, [3, 3], activation_fn=tf.nn.tanh,
                                            scope='conv_transpose4')  # (batch_size, 32, 32, 1)
                return net
The discriminator likewise consists of 4 convolutional layers:
def discriminator(self, images, reuse=False):
    # images: (batch, 32, 32, 1)
    with tf.variable_scope('discriminator', reuse=reuse):
        with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None,
                            stride=2, weights_initializer=tf.contrib.layers.xavier_initializer()):
            with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
                                activation_fn=tf.nn.relu, is_training=(self.mode == 'train')):
                net = slim.conv2d(images, 128, [3, 3], activation_fn=tf.nn.relu, scope='conv1')  # (batch_size, 16, 16, 128)
                net = slim.batch_norm(net, scope='bn1')
                net = slim.conv2d(net, 256, [3, 3], scope='conv2')   # (batch_size, 8, 8, 256)
                net = slim.batch_norm(net, scope='bn2')
                net = slim.conv2d(net, 512, [3, 3], scope='conv3')   # (batch_size, 4, 4, 512)
                net = slim.batch_norm(net, scope='bn3')
                net = slim.conv2d(net, 1, [4, 4], padding='VALID', scope='conv4')  # (batch_size, 1, 1, 1)
                net = slim.flatten(net)
                return net
The inputs are the source images src_images and the target images trg_images:
self.src_images = tf.placeholder(tf.float32, [None, 32, 32, 3], 'svhn_images')
self.trg_images = tf.placeholder(tf.float32, [None, 32, 32, 1], 'mnist_images')
For the S domain, source images are passed through f and g to obtain the feature vector fx and the generated images fake_images; the generated images are fed to the discriminator, and also passed back through f (fgfx) for the f-constancy loss:
# source domain (svhn to mnist)
self.fx = self.content_extractor(self.src_images)
self.fake_images = self.generator(self.fx)
self.logits = self.discriminator(self.fake_images)
self.fgfx = self.content_extractor(self.fake_images, reuse=True)
# loss
self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits))
self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits))
self.f_loss_src = tf.reduce_mean(tf.square(self.fx - self.fgfx)) * 15.0
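Note that the code simplifies the paper's ternary discriminator to a binary real/fake score: d_loss_src trains D to call translated source images fake, g_loss_src trains G to fool it, and f_loss_src is the f-constancy term $L_{CONST}$ with weight 15 playing the role of $\alpha$.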
For the T domain, target images are passed through f and g, and both the target images and their reconstructions are fed to the discriminator:
# target domain (mnist)
self.fx = self.content_extractor(self.trg_images, reuse=True)
self.reconst_images = self.generator(self.fx, reuse=True)
self.logits_fake = self.discriminator(self.reconst_images, reuse=True)
self.logits_real = self.discriminator(self.trg_images, reuse=True)
# loss
self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake))
self.d_loss_real_trg = slim.losses.sigmoid_cross_entropy(self.logits_real, tf.ones_like(self.logits_real))
self.d_loss_trg = self.d_loss_fake_trg + self.d_loss_real_trg
self.g_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.ones_like(self.logits_fake))
self.g_loss_const_trg = tf.reduce_mean(tf.square(self.trg_images - self.reconst_images)) * 15.0
self.g_loss_trg = self.g_loss_fake_trg + self.g_loss_const_trg
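Similarly, d_loss_fake_trg and d_loss_real_trg train D to separate reconstructed target images from real ones, while g_loss_const_trg is the identity term $L_{TID}$ with weight 15.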
Experiments
First, clone the code:
git clone https://github.com/yunjey/domain-transfer-network
Download the training data:
cd domain-transfer-network/
./download.sh
Resize the handwritten digits to 32×32:
python prepro.py
Pretrain:
python main.py --mode='pretrain'
Train:
python main.py --mode='train'
Evaluate:
python main.py --mode='eval'
---------------------
Author: imperfect00 | Source: CSDN | Original: https://blog.youkuaiyun.com/u011961856/article/details/78606706