一、GAN存在的问题
在这之前,我们实现了GAN以及DCGAN。从训练效果看。都能生成相应的数据,但是GAN的训练过程中会出现很多问题。主要是训练的不稳定。在理论上,我们是先把判别器训练好,再去训练生成器。但是当判别器训练的越好。生成器反而越难优化。接下来,我们先分析一下GAN的问题。
真实数据的分布通常是一个低维度流形。流形是指数据虽然在高纬度空间里,但实际上并不具备高纬度特性,而是存在于一个嵌入在高纬度的低纬度空间里。生成器的工作是把一个低纬度的空间Z映射到与真实数据相同的高纬度空间上,我们希望的是能够把我们生成的低维度流形尽可能的逼近真实数据的流形。
如果真会数据与生成数据在空间上完全不相交的话,就可以得到一个判别器来完美划分真实数据与生成数据。在实际实践中,生成数据和真实数据在空间中完美重合的概率是非常低的。所以大部分情况都会得到一个完美的判别器。因此。在网络的反向传播中梯度更新几乎为零。也就是说当判别器D接近完美判别器的时候,生成器优化的梯度会有一个非常小的上界,并无限接近于0.公式如下:
当D能很好的区分生成数据时,生成器的极限趋近于0。梯度更新几乎为0。并且当真实数据分布与生成数据分布
。如果无法全维度重合的话。则KL散度的值为无穷大,JSD散度的值log2。有的时候即使结果非常好。他们的值依然是这样的结果。这样是不利于训练的。也就是说采用这些公式来计算两者的相似度似乎不是一个好得主意。
为了解决这个问题,有一个办法是换一个不同的梯度函数。如下公式:
经过试验发现这个梯度函数会导致网咯更新不稳定。在训练的过程中,随着迭代次数的上升,梯度上升非常快,同时曲线的噪声也在变大。也就是说梯度的方差在变大。导致图像质量低。
另外一个方法是对判别器的输入假如一个随机的噪声。当真实数据分布与生成数据分布很接近的时候,加入了随机噪声可以使得两者的流形能够有更多的几率重合。但是这样也存在问题。当生成数据与真实数据本身相似度距离较远的话。添加噪声的方案可能就无效了。因此。我们接下来讨论一个更好的方案--WGAN。
二、WGAN的理论研究
设定真实数据分布和生成数据分布
。我们先来看一下几种分布距离公式。都是用来描述两个分布之间的相似度的。
总变差距离:数学含义是指与
在区间范围内数值变化的差值的综合。
KL散度:是非对称的。
JSD散度:,是对称的。
Wasserstein距离:也称作EM距离。公式如下:其中是指真实数据与生成数据的联合概率分布。该距离又称推土机距离,意思是相当于推土机把一堆土搬到另一堆的最下成本。
假设有一个二维空间,假设真实数据的分布式X轴为0、Y轴为随机变量的分布。生成数据的分布为X轴为0,Y轴也为随机变量的分布,由此可以得到四个公式的结果:
从上面的四个距离公式可以看到。当逼近零的过程中,只有W距离公式在减小,而其他几种距离公式都是一个固定的值或者是无穷大。所以,其他距离公式无法优化整个网络,而EM距离则具备了一个连续可用的梯度。
另外,假如生成器满足Lipschiiz条件的话。可以推导出EM距离处处连续且可导。Lipschitz条件是指函数的导数始终小于某个固定的藏书K。当K=1时称为1-Lipschitz。我们进一步改写公式,其含义是对于真实数据分布中的输入x与生成数据分布的输入x,求它们分别对于所有满足1-Lipschitz条件的函数f(x)的期望值差值的上确界。加入1-Lipschitz条件是为了保证f(x)的梯度变化不会过大,从而使得网络能够保持正常的梯度优化。
我们令函数f(x)满足参数化条件,使得所有函数在Lipschitz条件上成立,这样继续把公式改写
为了将其参数化,我们使用权值裁剪,将权重的范围限制在[-c,c]之间。最后,网络要做的事情是通过判别器的梯度来优化网络参数,让生成数据尽可能靠近真实数据分布。
三、WGAN的代码实现
1. 导包
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# 动态申请显存
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
2. 初始化
class WGAN():
# 初始化信息
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
# 优化参数
self.n_critic = 5
self.clip_value = 0.01
optimizer = RMSprop(lr=0.00005)
# 建造编译判别器
self.critic = self.build_critic()
self.critic.compile(loss=self.wassertein_loss,
optimizer=optimizer,
metrics=['accuracy'])
# 建造生成器
self.generator = self.build_generator()
# 输入噪音得到图像
z = Input(shape=(100,))
img = self.generator(z)
# 到此仅训练判别器
self.critic.trainable = False
# 获得判别结果
valid = self.critic(img)
# 编译模型(生成器和判别器的堆叠)
self.combined = Model(z, valid)
self.combined.compile(loss=self.wassertein_loss,
optimizer=optimizer,
metrics=['accuracy'])
3. EM损失
def wassertein_loss(selfs, y_true, y_pred):
return K.mean(y_true * y_pred)
4. 构建生成器
def build_generator(self):
model = Sequential()
model.add(Dense(128 * 7 * 7, activation='relu', input_dim=self.latent_dim))
model.add(Reshape((7, 7, 128)))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=4, padding='same'))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation('relu'))
model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=4, padding='same'))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation('relu'))
model.add(Conv2D(self.channels, kernel_size=4, padding='same'))
model.add(Activation('tanh'))
model.summary()
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
5. 构建判别器
def build_critic(self):
model = Sequential()
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(32, kernel_size=3, strides=2, padding='same'))
model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=3, strides=1, padding='same'))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1))
model.summary()
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
6. 训练
def train(self, epochs, batch_size=128, sample_interval=50):
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
# 真实值
valid = -np.ones((batch_size, 1))
fake = np.ones((batch_size, 1))
for epoch in range(epochs):
for _ in range(self.n_critic):
# 选择随机批度图像
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
# 噪声作为生成器输入
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# 生成新图像
gen_imgs = self.generator.predict(noise)
# 训练判别器
d_loss_real = self.critic.train_on_batch(imgs, valid)
d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
# 权重消减
for l in self.critic.layers:
weights = l.get_weights()
weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
l.set_weights(weights)
g_loss = self.combined.train_on_batch(noise, valid)
print("%d [D loss: %f] [G loss: %f]" % (epoch, 1 - d_loss[0], 1 - g_loss[0]))
if epoch % sample_interval == 0:
self.sample_images(epoch)
7. 展示数据
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
gen_imgs = 0.5 * gen_imgs + 1
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
axs[i, j].axis('off')
cnt += 1
fig.savefig("images/mnist_%d.png" % epoch)
plt.close()
8. 运行代码
if __name__ == '__main__':
wgan = WGAN()
wgan.train(epochs=4000, batch_size=32, sample_interval=50)