手撕VAE(变分自编码器) – Day4 – predict.py
目录
Variational Auto Encoder (VAE) 模型原理
VAE 网络结构图
和AE的区别
-
Part1:隐空间(latent)不是直接利用神经网络生成的,而是利用神经网络生成均值和方差来采样生成隐空间。
- 例如:原本是直接生成(1,10)的向量,但是现在是生成10个均值 u u u和10个标准差 σ \sigma σ(满足正态分布),然后通过采样生成10个向量值。
- 存在的问题:采样的操作是无法进行梯度反向传播的,因此需要利用重参数化进行反向传播,来进行均值和方差的参数更新。
-
Part2:重参数化技巧,利用正态分布之间的转换关系,来链接生成的均值,方差和生成的样本直接的关系。
- 例子:对于(0,1)的标准正态分布z,和( u u u, σ \sigma σ)的正态分布x,满足z= x − u σ \frac{x-u}{\sigma} σx−u,得到 x = u + σ z x=u+\sigma z x=u+σz。所以我们只需要随机生成一个标准正态分布z,乘上标准差,加上均值即可,这样就构建出了 u u u 和 σ \sigma σ 与x之间的关系,也就可以梯度反传了(非常巧妙的方法)。
-
Part3:损失函数添加了一个KL散度,从原来的重建损失,添加了一个分布损失,用来保证生成隐空间的分布,满足标准正态分布,因此该损失主要是衡量生成分布和标准正态分布之间的差距。
-
VAE(变分自编码器)的损失函数由两部分组成:重构损失 和 KL 散度。
1. 重构损失(Reconstruction Loss)
重构损失衡量了生成的数据 ( \hat{x} ) 与输入数据 ( x ) 之间的差异,通常使用 均方误差(MSE) 或 交叉熵 来衡量。对于连续数据,常用 MSE,对于离散数据,常用交叉熵。
L reconstruction = E q ( z ∣ x ) [ ∥ x − x ^ ∥ 2 ] \mathcal{L}_{\text{reconstruction}} = \mathbb{E}_{q(z|x)}[\| x - \hat{x} \|^2] Lreconstruction=Eq(z∣x)[∥x−x^∥2]
这里:- ( x ^ \hat{x} x^ = p(x|z) ) 是解码器生成的样本。
- ( x ) 是原始输入数据。
- ( ∥ x − x ^ ∥ 2 \| x - \hat{x} \|^2 ∥x−x^∥2 ) 是输入数据与生成数据之间的均方误差。
2. KL 散度(KL Divergence)
KL 散度项衡量了潜在空间的分布 ( q(z|x) ) 与标准正态分布 ( p(z) ) 之间的差异。通过最小化 KL 散度,我们希望潜在空间的分布接近标准正态分布 ( \mathcal{N}(0, I) ),从而避免过拟合,并保持潜在空间的结构。
KL 散度的公式为:
L KL = D KL [ q ( z ∣ x ) ∣ ∣ p ( z ) ] \mathcal{L}_{\text{KL}} = D_{\text{KL}}[q(z|x) || p(z)] LKL=DKL[q(z∣x)∣∣p(z)]
-
对于每个潜在维度 ( z_i ),KL 散度可以表示为:
D
KL
[
N
(
μ
(
x
)
,
σ
(
x
)
2
)
∣
∣
N
(
0
,
1
)
]
=
1
2
(
μ
i
2
(
x
)
+
σ
i
2
(
x
)
−
1
−
log
σ
i
2
(
x
)
)
D_{\text{KL}}[\mathcal{N}(\mu(x), \sigma(x)^2) || \mathcal{N}(0, 1)] = \frac{1}{2} \left( \mu_i^2(x) + \sigma_i^2(x) - 1 - \log \sigma_i^2(x) \right)
DKL[N(μ(x),σ(x)2)∣∣N(0,1)]=21(μi2(x)+σi2(x)−1−logσi2(x))
这里:
-
μ
(
x
)
\mu(x)
μ(x)和
σ
(
x
)
\sigma(x)
σ(x) 分别是编码器网络输出的潜在变量的均值和标准差。
-
N
(
0
,
1
)
\mathcal{N}(0, 1)
N(0,1) 是标准正态分布。
VAE 测试代码 - 变分自编码器
VAE 测试代码
Part1 引入相关库函数
# 该模块主要是为了预测推理的,输入一个图像得到一个浅层或者输入浅层得到一个图像
'''
# Part1 引入相关的模型
'''
import torch
from dataset import Mnist_dataset
import matplotlib.pyplot as plt
Part2 初始化 VAE 的模型
'''
# part2 下载模型
'''
net = torch.load('VAE_encoder_eopch_20.pt')
net.eval()
net1 = torch.load('VAE_decoder_eopch_20.pt')
net1.eval()
data_cs = Mnist_dataset
Part3 测试
'''
# Part3 开始测试
'''
if __name__ == '__main__':
with torch.no_grad():
img, label = data_cs[2]
# 开始绘制初始的图像
print('已经绘制了初始图像')
plt.imshow(img.permute(2, 1, 0))
plt.show()
img=img.unsqueeze(0)
mid_latent_predict = net(img.view(1,-1))
mu, sigma = torch.chunk(mid_latent_predict, chunks=2, dim=1)
# 重参数化
latent = mu + sigma * torch.randn_like(sigma) # 利用均值和标准差,生成对应的latent取样值
# 生成中间1*10的浅层,生成对应的图像
latent_cs=latent
print('生成的编码向量为:{}'.format(latent_cs))
img_predict=net1(latent_cs)
result = img_predict.view(img.size()[1],img.size()[2],img.size()[3])
print(result.size())
# 开始绘制结果图像
print('已经绘制了结果图像')
plt.imshow(result.permute(2,1,0))
plt.show()
参考
视频讲解:DALL·E 2(内含扩散模型介绍)【论文精读】_哔哩哔哩_bilibili