A First Look at Generative Adversarial Networks in TensorFlow
Original: 比昂, 比昂日记, March 28
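I introduced the basic principles of generative adversarial networks in an earlier post (生成对抗网络浅析(GAN), "A Brief Analysis of Generative Adversarial Networks").
Today, using the popular TensorFlow library, let's look at the implementation behind those principles.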
01
Model
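The previous post (生成对抗网络浅析(GAN)) defined the GAN model. With TFGAN, we mainly need to define four components:
a. Generator: produces fake images from random noise;
b. Discriminator: decides whether its input is a real image (from the training set) or a fake one;
c. Real images;
d. Random noise.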
Generator
import tensorflow as tf  # TensorFlow 1.x: tf.contrib no longer exists in TF 2.x

tfgan = tf.contrib.gan
layers = tf.contrib.layers
framework = tf.contrib.framework


def generator_fn(noise, weight_decay=2.5e-5, is_training=True):
    """Generator network that produces MNIST-like images.

    Args:
        noise: A Tensor of random noise.
        weight_decay: Strength of the (light) L2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`,
            batch norm uses the exponential moving averages collected during
            training.

    Returns:
        Generated images in the range [-1, 1].
    """
    with framework.arg_scope(
            [layers.fully_connected, layers.conv2d_transpose],
            activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
            weights_regularizer=layers.l2_regularizer(weight_decay)), \
         framework.arg_scope([layers.batch_norm], is_training=is_training,
                             zero_debias_moving_mean=True):
        net = layers.fully_connected(noise, 1024)
        net = layers.fully_connected(net, 7 * 7 * 256)
        net = tf.reshape(net, [-1, 7, 7, 256])                    # 7x7x256
        net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)  # 14x14x64
        net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)  # 28x28x32
        # Make sure the generator output is in the same range as `inputs`,
        # i.e. [-1, 1].
        net = layers.conv2d(net, 1, 4, normalizer_fn=None, activation_fn=tf.tanh)
        return net
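To check the shapes in the generator above, the sketch below traces them by hand in plain Python. It assumes TF-style `SAME` padding (the default for `layers.conv2d` and `layers.conv2d_transpose`), under which a stride-2 transposed convolution doubles height and width and a stride-1 convolution keeps them. `trace_generator` and its helpers are hypothetical names, not part of TFGAN.

```python
import math


def conv_same_out(size: int, stride: int) -> int:
    # conv2d with SAME padding: output size is ceil(size / stride).
    return math.ceil(size / stride)


def conv_transpose_same_out(size: int, stride: int) -> int:
    # conv2d_transpose with SAME padding: output size is size * stride.
    return size * stride


def trace_generator(batch: int = 1):
    """Return (layer, NHWC shape) pairs for the generator above."""
    shapes = [("fully_connected 1024", (batch, 1024)),
              ("fully_connected 7*7*256", (batch, 7 * 7 * 256))]
    h, w, c = 7, 7, 256
    shapes.append(("reshape", (batch, h, w, c)))
    h, w, c = conv_transpose_same_out(h, 2), conv_transpose_same_out(w, 2), 64
    shapes.append(("conv2d_transpose 64, stride 2", (batch, h, w, c)))
    h, w, c = conv_transpose_same_out(h, 2), conv_transpose_same_out(w, 2), 32
    shapes.append(("conv2d_transpose 32, stride 2", (batch, h, w, c)))
    h, w, c = conv_same_out(h, 1), conv_same_out(w, 1), 1
    shapes.append(("conv2d 1, tanh", (batch, h, w, c)))
    return shapes


for name, shape in trace_generator():
    print(f"{name:32s} {shape}")
```

So the 7×7 feature map grows to 14×14 and then 28×28, matching the 28×28×1 MNIST images the discriminator expects.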
Discriminator
def leaky_relu(net):
    # Not defined in the original excerpt; alpha=0.01 is an assumed slope.
    return tf.nn.leaky_relu(net, alpha=0.01)


def discriminator_fn(img, unused_conditioning, weight_decay=2.5e-5,
                     is_training=True):
    """Discriminator network on MNIST digits.

    Args:
        img: Real or generated images, in the range [-1, 1].
        unused_conditioning: The TFGAN API can handle conditional GANs, which
            feed extra "conditioning" information to both the generator and
            the discriminator. This example is not conditional, so the
            argument is unused.
        weight_decay: Strength of the L2 weight decay.
        is_training: Same as in the generator.

    Returns:
        Logits for the probability that the image is real.
    """
    with framework.arg_scope(
            [layers.conv2d, layers.fully_connected],
            activation_fn=leaky_relu, normalizer_fn=None,
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
        net = layers.conv2d(img, 64, [4, 4], stride=2)   # 14x14x64
        net = layers.conv2d(net, 128, [4, 4], stride=2)  # 7x7x128
        net = layers.flatten(net)
        with framework.arg_scope([layers.batch_norm], is_training=is_training):
            net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)
        return layers.linear(net, 1)
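The discriminator mirrors the generator: two stride-2 convolutions shrink 28×28 to 7×7, and the final `layers.linear(net, 1)` returns one unbounded logit per image rather than a probability. The sketch below traces the shapes (again assuming `SAME` padding) and shows, in plain Python, a leaky ReLU with an assumed slope of 0.01 and the sigmoid that turns a logit into the probability that an image is real. All names are illustrative.

```python
import math


def leaky_relu(x: float, alpha: float = 0.01) -> float:
    # Negative inputs are scaled by alpha instead of being zeroed.
    return x if x >= 0 else alpha * x


def sigmoid(logit: float) -> float:
    # Maps a discriminator logit in (-inf, inf) to a probability in (0, 1).
    return 1.0 / (1.0 + math.exp(-logit))


def trace_discriminator(h: int = 28, w: int = 28):
    """Return (layer, shape) pairs for one image through the discriminator."""
    shapes = [("input", (h, w, 1))]
    h, w = math.ceil(h / 2), math.ceil(w / 2)
    shapes.append(("conv2d 64, stride 2", (h, w, 64)))
    h, w = math.ceil(h / 2), math.ceil(w / 2)
    shapes.append(("conv2d 128, stride 2", (h, w, 128)))
    shapes.append(("flatten", (h * w * 128,)))
    shapes.append(("fully_connected 1024", (1024,)))
    shapes.append(("linear 1 (logit)", (1,)))
    return shapes


for name, shape in trace_discriminator():
    print(f"{name:24s} {shape}")
print(sigmoid(0.0))  # a logit of 0 means "50% real"
```

Returning logits keeps the loss computation numerically stable, since sigmoid cross-entropy can be evaluated directly on them.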
Real images: the MNIST dataset provides the real-image input.
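# `data_provider` comes from the MNIST example in tensorflow/models (research/gan).
from mnist import data_provider

batch_size = 32                     # example value (assumed)
MNIST_DATA_DIR = '/tmp/mnist-data'  # example path (assumed)

with tf.device('/cpu:0'):
    real_images, _, _ = data_provider.provide_data(
        'train', batch_size, MNIST_DATA_DIR)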
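Because the generator ends in `tanh`, real images must be scaled to the same [-1, 1] range before the discriminator sees them. The sketch below shows one common scaling of 8-bit pixels, (p − 128) / 128, plus its inverse for display; we assume, without having checked here, that it matches what `data_provider` does, so confirm against its source before relying on it.

```python
def scale_pixels(pixels):
    """Map 8-bit pixel values in [0, 255] to [-1, 1) via (p - 128) / 128."""
    return [(p - 128.0) / 128.0 for p in pixels]


def unscale_pixels(values):
    """Inverse mapping, clipped to [0, 255], e.g. to display generator output."""
    return [min(255, max(0, round(v * 128.0 + 128.0))) for v in values]


print(scale_pixels([0, 128, 255]))  # [-1.0, 0.0, 0.9921875]
```

Keeping real and generated images in the same range matters: otherwise the discriminator can tell them apart by intensity alone.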
GANModel Tuple
noise_dims = 64  # length of each noise vector (example value, assumed)

gan_model = tfgan.gan_model(
    generator_fn,
    discriminator_fn,
    real_data=real_images,
    generator_inputs=tf.random_normal([batch_size, noise_dims]))
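Conceptually, `tfgan.gan_model` runs the generator on the noise, runs one discriminator (with shared variables) on both the real and the generated batch, and packs the results into a GANModel tuple that the loss functions read from. The pure-Python sketch below mimics that wiring with scalar data; `ToyGANModel` and `toy_gan_model` are hypothetical stand-ins that keep only a few of the real tuple's fields.

```python
import random
from typing import Callable, List, NamedTuple


class ToyGANModel(NamedTuple):
    """Hypothetical, much-reduced stand-in for TFGAN's GANModel tuple."""
    generator_inputs: List[float]
    generated_data: List[float]
    real_data: List[float]
    discriminator_real_outputs: List[float]
    discriminator_gen_outputs: List[float]


def toy_gan_model(generator_fn: Callable[[float], float],
                  discriminator_fn: Callable[[float], float],
                  real_data: List[float],
                  generator_inputs: List[float]) -> ToyGANModel:
    # 1. Run the generator on the noise to get fake samples.
    generated = [generator_fn(z) for z in generator_inputs]
    # 2. Run the *same* discriminator on real and on generated samples.
    d_real = [discriminator_fn(x) for x in real_data]
    d_gen = [discriminator_fn(x) for x in generated]
    return ToyGANModel(generator_inputs, generated, real_data, d_real, d_gen)


rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(4)]
model = toy_gan_model(
    generator_fn=lambda z: 2.0 * z + 1.0,      # toy affine "generator"
    discriminator_fn=lambda x: 0.5 * x - 1.0,  # toy linear "discriminator" (logit)
    real_data=[3.0, 3.5, 2.5, 3.0],
    generator_inputs=noise)
print(model.discriminator_real_outputs)
```

Bundling everything in one tuple is what lets any loss be written as a plain function of the GANModel.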
02
Loss functions
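A loss function measures how far the model's prediction f(x) is from the true value Y. A typical training objective is
$\theta^* = \arg\min_\theta \frac{1}{N}\sum_{i=1}^{N} L(y_i, f(x_i;\theta)) + \lambda\,\Phi(\theta)$
The first term, the average loss over the N training examples, is the empirical risk, and L is the loss function. The second term Φ is the regularizer (or penalty term), weighted by λ; it can be L1, L2, or another regularization function. The whole expression asks for the θ that minimizes the objective.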
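As a worked example of this objective, the sketch below takes a hypothetical linear model f(x; θ) = θ₀ + θ₁x, the squared loss as L, and the L2 penalty Φ(θ) = Σⱼ θⱼ², and picks the best θ from a few candidates. The model, data, λ and candidate grid are made up for illustration.

```python
from typing import Sequence


def model(x: float, theta: Sequence[float]) -> float:
    # Hypothetical linear model f(x; theta) = theta0 + theta1 * x.
    return theta[0] + theta[1] * x


def squared_loss(y: float, y_pred: float) -> float:
    return (y - y_pred) ** 2


def objective(theta, xs, ys, lam):
    # Empirical risk: average loss over the N training examples.
    risk = sum(squared_loss(y, model(x, theta)) for x, y in zip(xs, ys)) / len(xs)
    # L2 regularizer Phi(theta) = sum_j theta_j^2, weighted by lam.
    penalty = sum(t * t for t in theta)
    return risk + lam * penalty


xs, ys = [1.0, 2.0], [2.0, 4.0]
candidates = [[0.0, 1.0], [0.0, 2.0], [1.0, 1.0]]
best = min(candidates, key=lambda th: objective(th, xs, ys, lam=0.1))
print(best)  # [0.0, 2.0]
```

Here [0, 2] fits the data exactly, so its objective is just the penalty 0.1 × 4 = 0.4, lower than the other candidates' 2.6 and 0.7.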
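For GANs, the loss in the original paper is a two-player minimax game:
$\min_G \max_D V(D, G) = \mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))]$
TFGAN provides this minimax loss directly: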
# Use the minimax loss from the original GAN paper.
vanilla_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.minimax_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss)
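The sketch below evaluates the paper's value function from discriminator logits in plain Python, using batch means. It follows the formula rather than TFGAN's exact reduction, and the function names are illustrative. It also includes the non-saturating generator loss that the GAN paper recommends in practice.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def mean(values) -> float:
    values = list(values)
    return sum(values) / len(values)


def minimax_discriminator_loss(d_real_logits, d_gen_logits) -> float:
    # D maximizes log D(x) + log(1 - D(G(z))), so we minimize the negative.
    # log(1 - sigmoid(l)) == log_sigmoid(-l).
    return -(mean(log_sigmoid(l) for l in d_real_logits) +
             mean(log_sigmoid(-l) for l in d_gen_logits))


def minimax_generator_loss(d_gen_logits) -> float:
    # G minimizes log(1 - D(G(z))); flat when D confidently rejects fakes.
    return mean(log_sigmoid(-l) for l in d_gen_logits)


def non_saturating_generator_loss(d_gen_logits) -> float:
    # Alternative from the GAN paper: G maximizes log D(G(z)) instead.
    return -mean(log_sigmoid(l) for l in d_gen_logits)


print(minimax_discriminator_loss([0.0], [0.0]))  # 2*log(2) when D outputs 0.5
print(minimax_generator_loss([-10.0]), non_saturating_generator_loss([-10.0]))
```

With a confident discriminator (logit −10) the minimax generator loss is nearly zero and gives almost no gradient, while the non-saturating loss is about 10. TFGAN offers that variant as `tfgan.losses.modified_generator_loss`.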
The Wasserstein and improved Wasserstein (WGAN-GP) losses can be used as well; see https://arxiv.org/pdf/1701.07875.pdf (WGAN) and https://arxiv.org/abs/1704.00028 (WGAN-GP).
# Wasserstein loss (https://arxiv.org/abs/1701.07875) with the gradient
# penalty from Improved WGAN (https://arxiv.org/abs/1704.00028).
improved_wgan_loss = tfgan.gan_loss(
    gan_model,
    # We make the loss explicit for demonstration, even though the default is
    # Wasserstein loss.
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0)
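The sketch below spells out what Wasserstein losses with a gradient penalty compute: the critic loss is mean D(G(z)) − mean D(x) plus λ·(‖∇D(x̂)‖ − 1)², averaged over random interpolates x̂ between real and fake samples, and the generator loss is −mean D(G(z)). It is plain Python under simplifying assumptions: the critic `f` is any Python function on a list of floats, its gradient is estimated with central finite differences instead of automatic differentiation, and `gp_weight` plays the role of `gradient_penalty_weight`.

```python
import random
from typing import Callable, List, Tuple

Vector = List[float]
Critic = Callable[[Vector], float]


def numerical_grad(f: Critic, x: Vector, eps: float = 1e-5) -> Vector:
    """Central finite-difference estimate of the gradient of f at x."""
    grad = []
    for i in range(len(x)):
        plus, minus = list(x), list(x)
        plus[i] += eps
        minus[i] -= eps
        grad.append((f(plus) - f(minus)) / (2 * eps))
    return grad


def gradient_penalty(real, fake, f: Critic, rng: random.Random) -> float:
    """Mean of (||grad f(x_hat)|| - 1)^2 over random real/fake interpolates."""
    total = 0.0
    for xr, xf in zip(real, fake):
        t = rng.random()  # interpolation coefficient in [0, 1)
        x_hat = [t * a + (1.0 - t) * b for a, b in zip(xr, xf)]
        norm = sum(g * g for g in numerical_grad(f, x_hat)) ** 0.5
        total += (norm - 1.0) ** 2
    return total / len(real)


def wgan_gp_losses(real, fake, f: Critic, gp_weight: float = 1.0,
                   seed: int = 0) -> Tuple[float, float]:
    """Return (generator_loss, critic_loss)."""
    rng = random.Random(seed)
    d_real = sum(f(x) for x in real) / len(real)
    d_fake = sum(f(x) for x in fake) / len(fake)
    g_loss = -d_fake  # the generator wants high critic scores on fakes
    d_loss = d_fake - d_real + gp_weight * gradient_penalty(real, fake, f, rng)
    return g_loss, d_loss


# A linear critic with gradient norm 2 is penalized by (2 - 1)^2 = 1.
print(wgan_gp_losses([[1.0, 0.0]], [[0.0, 0.0]], lambda x: 2.0 * x[0]))
```

The critic output is an unbounded score (no sigmoid), which is why these losses act on raw outputs; the penalty pushes the critic toward gradient norm 1, a soft version of the Lipschitz constraint.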
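For reference, here is TF's implementation of the Wasserstein generator loss (from tensorflow/contrib/gan, losses_impl.py):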
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary


def _to_float(tensor):
    return math_ops.to_float(tensor)


# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).
def wasserstein_generator_loss(
        discriminator_gen_outputs,
        weights=1.0,
        scope=None,
        loss_collection=ops.GraphKeys.LOSSES,
        reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
        add_summaries=False):
    """Wasserstein generator loss for GANs.

    See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.

    Args:
        discriminator_gen_outputs: Discriminator output on generated data.
            Expected to be in the range of (-inf, inf).
        weights: Optional `Tensor` whose rank is either 0, or the same rank as
            `discriminator_gen_outputs`, and must be broadcastable to
            `discriminator_gen_outputs` (i.e., all dimensions must be either
            `1`, or the same as the corresponding dimension).
        scope: The scope for the operations performed in computing the loss.
        loss_collection: collection to which this loss will be added.
        reduction: A `tf.losses.Reduction` to apply to loss.
        add_summaries: Whether or not to add detailed summaries for the loss.

    Returns:
        A loss Tensor. The shape depends on `reduction`.
    """
    with ops.name_scope(scope, 'generator_wasserstein_loss', (
            discriminator_gen_outputs, weights)) as scope:
        discriminator_gen_outputs = _to_float(discriminator_gen_outputs)

        loss = - discriminator_gen_outputs
        loss = losses.compute_weighted_loss(
            loss, weights, scope, loss_collection, reduction)

        if add_summaries:
            summary.scalar('generator_wass_loss', loss)

    return loss
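The only GAN-specific line in that function is `loss = - discriminator_gen_outputs`; the rest is the standard weighted-loss reduction. With the default `SUM_BY_NONZERO_WEIGHTS`, the weighted losses are summed and divided by the number of non-zero weights, which for the default scalar weight of 1.0 is simply the batch mean. The plain-Python sketch below reproduces that arithmetic for a 1-D batch; it is an illustration, not TF's code, and it ignores scopes, collections and summaries.

```python
from typing import Sequence, Union


def wasserstein_generator_loss_sketch(
        d_gen_outputs: Sequence[float],
        weights: Union[float, Sequence[float]] = 1.0) -> float:
    """Plain-Python mirror of the TF function above with its default reduction."""
    losses = [-d for d in d_gen_outputs]          # loss = -D(G(z))
    if isinstance(weights, (int, float)):
        weights = [float(weights)] * len(losses)  # broadcast a scalar weight
    weighted_sum = sum(l * w for l, w in zip(losses, weights))
    num_nonzero = sum(1 for w in weights if w != 0)
    # SUM_BY_NONZERO_WEIGHTS: weighted sum / number of non-zero weights (0 if none).
    return weighted_sum / num_nonzero if num_nonzero else 0.0


print(wasserstein_generator_loss_sketch([1.0, 2.0, 3.0]))  # -2.0, the negated mean
```

Zero weights therefore act as a mask: masked examples neither contribute to the sum nor inflate the denominator.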
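Custom loss functions: any function written in terms of the GANModel tuple can serve as a loss. The pair below only illustrates the interface; as the "silly" names suggest, it is not a sensible objective (the generator loss even has the opposite sign of wasserstein_generator_loss).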
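def silly_custom_generator_loss(gan_model, add_summaries=False):
    return tf.reduce_mean(gan_model.discriminator_gen_outputs)


def silly_custom_discriminator_loss(gan_model, add_summaries=False):
    return (tf.reduce_mean(gan_model.discriminator_gen_outputs) -
            tf.reduce_mean(gan_model.discriminator_real_outputs))


# Custom losses plug into tfgan.gan_loss exactly like the built-in ones.
custom_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=silly_custom_generator_loss,
    discriminator_loss_fn=silly_custom_discriminator_loss)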
03
Training & evaluation
Training
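Training a GAN alternates between the generator and discriminator networks, keeping the two in continual optimization and competition, just as in the algorithm of the original paper.
The setup is fairly simple: configure the optimizers, then build the GANTrainOps tuple with tfgan.gan_train_ops.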
generator_optimizer = tf.train.AdamOptimizer(0.001, beta1=0.5)
discriminator_optimizer = tf.train.AdamOptimizer(0.0001, beta1=0.5)
gan_train_ops = tfgan.gan_train_ops(
    gan_model,
    improved_wgan_loss,
    generator_optimizer,
    discriminator_optimizer)
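Both optimizers use Adam with `beta1=0.5` instead of the default 0.9, so the first-moment (momentum) estimate forgets old gradients faster, and the discriminator gets a learning rate 10× smaller than the generator's. The scalar sketch below follows the update rule from the Adam paper to show what these settings do; `ScalarAdam` is a hypothetical helper, and TensorFlow's `AdamOptimizer` places epsilon slightly differently, so treat it as illustrative.

```python
import math


class ScalarAdam:
    """Adam for a single scalar parameter (illustrative only)."""

    def __init__(self, lr, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = 0.0  # running (first-moment) average of gradients
        self.v = 0.0  # running (second-moment) average of squared gradients
        self.t = 0    # step counter used for bias correction
        self.last_m_hat = 0.0

    def step(self, param: float, grad: float) -> float:
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad
        m_hat = self.m / (1 - self.beta1 ** self.t)  # bias-corrected moments
        v_hat = self.v / (1 - self.beta2 ** self.t)
        self.last_m_hat = m_hat
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


# Learning rates and beta1 taken from the two optimizers above.
g_opt = ScalarAdam(0.001, beta1=0.5)
d_opt = ScalarAdam(0.0001, beta1=0.5)
print(g_opt.step(0.0, 3.0), d_opt.step(0.0, 3.0))  # first steps are about -lr
```

On the first step Adam moves each parameter by roughly the learning rate regardless of the gradient's scale. When the gradient sign flips, `beta1=0.5` lets the momentum follow the new direction much sooner than 0.9 would, which suits the constantly shifting targets of adversarial training.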
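Evaluation
Inception Score and Frechet Inception Distance measure how close the distribution of generated images is to that of real images. Here both are computed with a pretrained MNIST classifier (the frozen graph below) in place of the Inception network.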
# `util` also comes from the MNIST example in tensorflow/models (research/gan).
from mnist import util

num_images_to_eval = 500
MNIST_CLASSIFIER_FROZEN_GRAPH = './mnist/data/classify_mnist_graph_def.pb'

# Reuse the training job's variable scope so the trained generator
# variables are loaded.
with tf.variable_scope('Generator', reuse=True):
    eval_images = gan_model.generator_fn(
        tf.random_normal([num_images_to_eval, noise_dims]),
        is_training=False)

# Compute the Inception-style score (with the MNIST classifier).
eval_score = util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)

# Compute the Frechet distance between real and generated images.
with tf.device('/cpu:0'):
    real_images, _, _ = data_provider.provide_data(
        'train', num_images_to_eval, MNIST_DATA_DIR)
frechet_distance = util.mnist_frechet_distance(
    real_images, eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)

# Arrange the first 20 generated images into a grid for display.
generated_data_to_visualize = tfgan.eval.image_reshaper(
    eval_images[:20, ...], num_cols=10)
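Both metrics are easy to state once a classifier has produced outputs. The plain-Python sketch below assumes we already have, for each image, the classifier's class probabilities (for the score) or a feature vector (for the distance). The score is exp(E_x KL(p(y|x) ‖ p(y))): high when each image is classified confidently and the predicted classes are diverse. The Frechet distance fits a Gaussian to real and generated features and compares them; to stay dependency-free, the sketch assumes diagonal covariances, where the general trace term Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^{1/2}) reduces to Σⱼ(σ₁ⱼ − σ₂ⱼ)². The real metric uses full covariance matrices, so the function names here are hypothetical simplifications, not `util`'s API.

```python
import math
from typing import List, Sequence, Tuple


def classifier_score(probs: Sequence[Sequence[float]]) -> float:
    """exp(mean_x KL(p(y|x) || p(y))), with p(y) the average prediction."""
    n, k = len(probs), len(probs[0])
    marginal = [sum(p[j] for p in probs) / n for j in range(k)]
    mean_kl = sum(
        sum(pj * math.log(pj / marginal[j]) for j, pj in enumerate(p) if pj > 0)
        for p in probs) / n
    return math.exp(mean_kl)


def gaussian_stats(features) -> Tuple[List[float], List[float]]:
    """Per-dimension mean and (population) variance of feature vectors."""
    n, d = len(features), len(features[0])
    mu = [sum(f[j] for f in features) / n for j in range(d)]
    var = [sum((f[j] - mu[j]) ** 2 for f in features) / n for j in range(d)]
    return mu, var


def frechet_distance_diag(mu1, var1, mu2, var2) -> float:
    """||mu1 - mu2||^2 + sum_j (sqrt(var1_j) - sqrt(var2_j))^2."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    cov_term = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(var1, var2))
    return mean_term + cov_term


# Two confident, different predictions over 2 classes give the maximum score, 2.
print(classifier_score([[1.0, 0.0], [0.0, 1.0]]))
```

Higher scores (at most the number of classes, 10 for MNIST) and lower Frechet distances indicate generated images that are both recognizable and statistically closer to the real ones.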
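Training process and results
TFGAN follows the alternating training scheme that comes from the GAN minimax game, and the ratio of generator to discriminator updates can be changed.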
import time

train_step_fn = tfgan.get_sequential_train_steps()

global_step = tf.train.get_or_create_global_step()
loss_values, mnist_scores, frechet_distances = [], [], []

with tf.train.SingularMonitoredSession() as sess:
    start_time = time.time()
    for i in range(1601):
        cur_loss, _ = train_step_fn(
            sess, gan_train_ops, global_step, train_step_kwargs={})
        loss_values.append((i, cur_loss))
        # Evaluate every 200 steps.
        if i % 200 == 0:
            mnist_score, f_distance, digits_np = sess.run(
                [eval_score, frechet_distance, generated_data_to_visualize])
            mnist_scores.append((i, mnist_score))
            frechet_distances.append((i, f_distance))
            print('Current loss: %f' % cur_loss)
            print('Current MNIST score: %f' % mnist_scores[-1][1])
            print('Current Frechet distance: %f' % frechet_distances[-1][1])
            # Plotting helper, not shown in this post.
            visualize_training_generator(i, start_time, digits_np)
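The sketch below models what a sequential train step with a configurable G:D ratio does: each iteration runs some number of generator updates followed by some number of discriminator updates. `TrainSteps`, `make_sequential_train_step` and `run_schedule` are hypothetical stand-ins, and the generator-first order is an assumption of this sketch; in TFGAN the ratio should correspond to the `tfgan.GANTrainSteps` argument of `get_sequential_train_steps`. The WGAN paper, for example, uses 5 critic updates per generator update.

```python
from typing import Callable, List, NamedTuple, Tuple


class TrainSteps(NamedTuple):
    """Hypothetical stand-in for a (generator, discriminator) update ratio."""
    generator_train_steps: int = 1
    discriminator_train_steps: int = 1


def make_sequential_train_step(g_step: Callable[[], float],
                               d_step: Callable[[], float],
                               steps: TrainSteps = TrainSteps()):
    """Return a function that performs one alternating training iteration."""
    def train_step() -> Tuple[float, float]:
        g_loss = d_loss = 0.0
        for _ in range(steps.generator_train_steps):
            g_loss = g_step()  # one generator update
        for _ in range(steps.discriminator_train_steps):
            d_loss = d_step()  # one discriminator update
        return g_loss, d_loss
    return train_step


def run_schedule(num_iterations: int, steps: TrainSteps) -> List[str]:
    """Record the order of updates over a few iterations."""
    log: List[str] = []

    def g_step() -> float:
        log.append("G")
        return 0.0

    def d_step() -> float:
        log.append("D")
        return 0.0

    train_step = make_sequential_train_step(g_step, d_step, steps)
    for _ in range(num_iterations):
        train_step()
    return log


print(run_schedule(2, TrainSteps(1, 1)))  # ['G', 'D', 'G', 'D']
```

Raising the discriminator count keeps the critic closer to optimal between generator updates, which is what makes the Wasserstein estimate meaningful.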
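As in the paper, the generated samples improve step by step as training progresses (see 生成对抗网络浅析(GAN)).
The evaluation metrics change over training time as follows: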
Further reading
References:
https://github.com/tensorflow/models/tree/master/research/gan
https://arxiv.org/pdf/1701.07875.pdf
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/losses/python/losses_impl.py
https://blog.youkuaiyun.com/stalbo/article/details/79356739
https://zhuanlan.zhihu.com/p/44407513
http://www.csuldw.com/2016/03/26/2016-03-26-loss-function/
https://arxiv.org/pdf/1606.03498.pdf
THE END
- Good night -