本系列已完结,全部文章地址为:
1 VAE的改进——数学上的解释
本节从数学角度说明VAE的改进,并说明与前一篇直觉上的解释是统一的。
1.1 VAE的目标:最大化似然函数
VAE的目标是学习到样本生成的概率分布,这是生成模型的一个共同目标(详细说明可参考笔者对GAN的学习笔记博文)。下文想说明的一个故事是,在求解这个概率分布过程中,使用神经网络解决了推导的困难,达成了目标。
假设要估计的概率分布是P(x),其似然函数表示为:
目标是最大化此函数。
1.2 引入潜变量
其中
引入z这一潜变量,认为生成样本的过程是先通过概率分布生成z。z是一个向量,维度是自己决定的,均值也是向量,方差是矩阵。每个z都根据一个函数(用神经网络来拟合)得到对应的均值和标准差,然后再由这个概率分布生成样本x。z实际上就是编码,这个函数其实就是Decoder。
此处假设z服从高斯分布是合理的,这是因为:第一,神经网络可以拟合出各种函数,因此无论什么分布都不太影响结果;第二,假设z的某个维度代表某种特征,这个特征服从正态分布表示极端情况少,普通情况多,也是合理的。另外也假设不同维度的z之间是独立的。
与高斯混合分布的关系如下
高斯混合分布模型(GMM)是将若干高斯分布加权组合在一起。先根据m的概率分布生成1个m,每个m对应一种高斯分布,再由对应的高斯分布产生样本x。用分布产生样本,可以提高生成的多样性。GMM中的m只能是有限的簇(cluster),而VAE下m是由分布得到的,因此有无限多个。
是不好直接求的,从公式字面意思看,其包含两部分,一是z的分布,可以假设是一个正态分布;二是条件概率分布,给定z下得到样本数据的概率最大。如果基于这个式子设计一个神经网络,x甚至都不用输入到网络中,仅凭随机的z得到输出却能和样本尽量一致,这听起来就不现实。
1.3 最大化下界
由于P(x)不好直接求,因此对其变形。
其中q(z|x)对于任意的分布都成立,因为将z做sum之后q(z|x)=1
可以看到,推导后右式是KL散度,是一定大于0的,因此左式是一个Lower Bound,记为Lb,一定小于似然函数。现在将优化目标变成最大化Lb。
可以看到,前文将目标表示成,希望找到
使得似然函数最大,此处将问题转化成了找到
和
使得似然函数最大,多了一项q(z|x)。
最大化Lb增加与最大化似然函数,解析如下:
q(z|x)与logP(x)是无关的,因此调整q(z|x)对于logP(x)是没有影响的。那么在最大化Lb的过程中,q(z|x)提高了Lb必然降低KL,这样logP(x)与Lb就会更加接近,因此最大化Lb就等价于最大化logP(x)。优化到最后,KL减小,q(z|x)会接近p(z|x)
继续变形
证明
这里假设i=1,即只有1个维度
括号外实际就是正态分布的概率密度,括号内第一项不含x,因此可以提出,第二项表示
二阶矩,
,第三项表示
1.4 优化目标的两部分
因此VAE的优化目标包含两部分:
上式化简后第一部分代表Encoder产生的分布应尽可能接近标准正态分布,避免标准差接近0而退化成AE。该结果与上一篇博文中描述一致。
上式化简后第二部分表示reconstruction error,即重建损失。因为在给定q(z|x)的情况下P(x|z)尽可能高,这就表示先用Encoder根据输入x得到输出z,然后利用Decoder根据z得到一个正态分布的均值和方差,为了让其产生x的几率最大,均值应该与x接近,标准差越小越好,实践中只让Decoder产生均值,不会再让其产生标准差了。
2 思维流程图
将上述推导总结如下
3 示例代码
代码来自卷积变分自编码器 | TensorFlow Core
这里使用的是卷积VAE,思路与VAE是一致的,只是Encoder和Decoder中的运算含有卷积和反卷积。
下面对关键代码做解析
3.1 z的维度
self.latent_dim = latent_dim
这里定义了z的维度,定义为2。
之所以定义为2,是因为后文遍历所有潜变量传入decoder查看效果。
分别遍历x和y,2个变量即可组成编码层,这样x和y组成横纵坐标,坐标上的画出decoder产生的结果,即可观察潜变量对应的图片。如果latent_dim是3,就是一个三维的图像.如果latent_dim再多,就不太好可视化分析了。
3.2 重参数化
def reparameterize(self, mean, logvar):
"""
重参数化
encoder输出mean和logvar,运算后传入decoder
epsilon * e^(1/2 * logvar) + mean
"""
eps = tf.random.normal(shape=mean.shape)
return eps * tf.exp(logvar * .5) + mean
重参数化,与博文的分析一致,等价于z=mu+eps*sigma。详见上一篇博文分析。
3.3 Encoder
self.encoder = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(shape=(28, 28, 1)),
tf.keras.layers.Conv2D(...
# No activation
tf.keras.layers.Dense(latent_dim + latent_dim),
]
)
Encoder的组成,输入是28*28,使用了卷积核。注意No activation,因为直接训练出对数方差,不需要再加入激活函数了。
3.4 Encoder输出均值和方差
def encode(self, x):
"""
encoder产生的结果拆分成2部分:mean和logvar
"""
mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
return mean, logvar
encode操作,是将输出分成2部分,分别是均值和对数方差
3.5 Decoder
Decoder是由反卷积组成,因为要从低维向高维映射。
self.decoder = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(shape=(latent_dim,)),
tf.keras.layers.Dense(units=7 * 7 * 32, activation=tf.nn.relu),
tf.keras.layers.Reshape(target_shape=(7, 7, 32)),
tf.keras.layers.Conv2DTranspose(...),
]
)
Decoder的组成,注意其中用了3个反卷积核,stride分别为2、2、1,这样就将边长为7的图像扩充为7*2*2*1=28,与输入一致。
3.6 损失函数
def compute_loss(model: CVAE, x):
"""
计算loss
除了交叉熵代表reconstruction error
还计算了分布与标准正态分布的差别,这里使用概率分布计算的
"""
mean, logvar = model.encode(x)
z = model.reparameterize(mean, logvar)
x_logit = model.decode(z)
cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
logpz = log_normal_pdf(z, 0., 0.)
logqz_x = log_normal_pdf(z, mean, logvar)
return -tf.reduce_mean(logpx_z + logpz - logqz_x)
损失函数包括2部分,重建损失以及分布与N(0, 1)的差别。这里没有用KL散度,而是用两个分布概率密度函数的对数作差,相当于KL散度中只取了log的部分。log_normal_pdf函数表示正态分布概率密度取对数。
3.7 训练
@tf.function
def train_step(model, x, optimizer):
"""
单步执行,更新权重。单步执行是为了查看每一步产生的图像变化,观察结果是否越来越清晰
"""
"""Executes one training step and returns the loss.
This function computes the loss and gradients, and uses the latter to
update the model's parameters.
"""
with tf.GradientTape() as tape:
loss = compute_loss(model, x)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
单步执行更新权重。单步执行是为了查看每一步产生的图像变化,观察结果是否越来越清晰
3.8 产生图片
训练10次,每次都调用sample函数,生成图片并保存,并将10张图片组合成一个gif。
第一次训练得到的图片是
最后一次训练得到的图片是
可以发现确实更清晰。
3.9 遍历编码层绘制图片
def plot_latent_images(model: CVAE, n, digit_size=28):
...
得到的结果为
还是能看出来一些规律的,笔者猜测越靠左下圆圈的特征越不明显,越靠右上圆圈特征越明显。总之编码层学习到了一些特征。
参考资料
李宏毅VAE课程