手撕VAE(变分自编码器) – Day1 – dataset
目录
Variational Auto Encoder (VAE) 模型原理
VAE 网络结构图
和AE的区别
-
Part1:隐空间(latent)不是直接利用神经网络生成的,而是利用神经网络生成均值和方差来采样生成隐空间。
- 例如:原本是直接生成(1,10)的向量,但是现在是生成10个均值 u u u和10个标准差 σ \sigma σ(满足正态分布),然后通过采样生成10个向量值。
- 存在的问题:采样的操作是无法进行梯度反向传播的,因此需要利用重参数化进行反向传播,来进行均值和方差的参数更新。
-
Part2:重参数化技巧:利用正态分布之间的转换关系,来链接生成的均值,方差和生成的样本直接的关系。
- 例子:对于(0,1)的标准正态分布z,和( u u u, σ \sigma σ)的正态分布x,满足z= x − u σ \frac{x-u}{\sigma} σx−u,得到 x = u + σ z x=u+\sigma z x=u+σz。所以我们只需要随机生成一个标准正态分布z,乘上标准差,加上均值即可,这样就构建出了 u u u 和 σ \sigma σ 与x之间的关系,也就可以梯度反传了(非常巧妙的方法)。
-
Part3:损失函数添加了一个KL散度,从原来的重建损失,添加了一个分布损失,用来保证生成隐空间的分布,满足标准正态分布,因此该损失主要是衡量生成分布和标准正态分布之间的差距。
-
VAE(变分自编码器)的损失函数由两部分组成:重构损失 和 KL 散度。
1. 重构损失(Reconstruction Loss)
重构损失衡量了生成的数据 ( \hat{x} ) 与输入数据 ( x ) 之间的差异,通常使用 均方误差(MSE) 或 交叉熵 来衡量。对于连续数据,常用 MSE,对于离散数据,常用交叉熵。
L reconstruction = E q ( z ∣ x ) [ ∥ x − x ^ ∥ 2 ] \mathcal{L}_{\text{reconstruction}} = \mathbb{E}_{q(z|x)}[\| x - \hat{x} \|^2] Lreconstruction=Eq(z∣x)[∥x−x^∥2]
这里:- ( x ^ \hat{x} x^ = p(x|z) ) 是解码器生成的样本。
- ( x ) 是原始输入数据。
- ( ∥ x − x ^ ∥ 2 \| x - \hat{x} \|^2 ∥x−x^∥2 ) 是输入数据与生成数据之间的均方误差。
2. KL 散度(KL Divergence)
KL 散度项衡量了潜在空间的分布 ( q(z|x) ) 与标准正态分布 ( p(z) ) 之间的差异。通过最小化 KL 散度,我们希望潜在空间的分布接近标准正态分布 ( \mathcal{N}(0, I) ),从而避免过拟合,并保持潜在空间的结构。
KL 散度的公式为:
L KL = D KL [ q ( z ∣ x ) ∣ ∣ p ( z ) ] \mathcal{L}_{\text{KL}} = D_{\text{KL}}[q(z|x) || p(z)] LKL=DKL[q(z∣x)∣∣p(z)]对于每个潜在维度 ( z_i ),KL 散度可以表示为:
D KL [ N ( μ ( x ) , σ ( x ) 2 ) ∣ ∣ N ( 0 , 1 ) ] = 1 2 ( μ i 2 ( x ) + σ i 2 ( x ) − 1 − log σ i 2 ( x ) ) D_{\text{KL}}[\mathcal{N}(\mu(x), \sigma(x)^2) || \mathcal{N}(0, 1)] = \frac{1}{2} \left( \mu_i^2(x) + \sigma_i^2(x) - 1 - \log \sigma_i^2(x) \right) DKL[N(μ(x),σ(x)2)∣∣N(0,1)]=21(μi2(x)+σi2(x)−1−logσi2(x))这里:
- μ ( x ) \mu(x) μ(x)和 σ ( x ) \sigma(x) σ(x) 分别是编码器网络输出的潜在变量的均值和标准差。
- N ( 0 , 1 ) \mathcal{N}(0, 1) N(0,1) 是标准正态分布。
-
Dataset代码 - 利用Mnist数据集做自编码器
Dataset代码
Part1 库函数
# 该模块主要是为了实现数据集的下载,主要是Mnist数据集
'''
# Part1 引入相关的库函数
'''
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
Part2 初始化一个数据集的类
'''
# Part2 获取数据集的转换操作,以及数据集的获取
'''
transform_action=transforms.Compose([
transforms.ToTensor()
])
Mnist_dataset=torchvision.datasets.MNIST(root='./Mnist',train=True,transform=transform_action,download=True)
Part3 测试
'''
# 开始测试
'''
if __name__=='__main__':
imag,label=Mnist_dataset[0]
plt.figure(figsize=(45,45))
plt.imshow(imag.permute(2,1,0))
plt.show()
参考
视频讲解:DALL·E 2(内含扩散模型介绍)【论文精读】_哔哩哔哩_bilibili