手撕VQVAE(向量量化变分自编码器) – Day1 – dataset
目录
Vector Quantized Variational AutoEncoder (VQVAE) 模型原理
VQVAE 网络结构图
和AE以及VAE的区别
-
Part1:隐空间(latent)不是直接利用神经网络生成的,也不是是利用神经网络生成均值和方差来采样生成隐空间,而是先利用编码器进行图像编码,得到隐藏间的向量(用于查询),然后到一个向量字典(图中的codebook)中进行查询与其最接近的向量,作为最终的编码向量。
- 例如:原本是直接生成(1,10)的向量,或者是生成10个均值 u u u和10个标准差 σ \sigma σ(满足正态分布),然后通过采样生成10个向量值,但是现在是利用生成的10维向量,到字典中查询到最接近的那个10维向量,作为目标的隐空间编码向量。
- 存在的问题:查询,并取最接近的向量的操作是无法进行梯度反向传播的,因此需要利用直通量化操作(Straight-Through Estimator (STE))进行反向传播,来进行编码器的参数更新。
-
Part2:直通量化技巧:设编码器编码得到的隐空间向量为encode_z,查询得到的最接近的隐空间向量为z_q,那么利用公式如下,即可完成梯度的反向传播,更新编码器参数。
-
z q = e n c o d e z + ( z q − e n c o d e z ) . d e t a c h ( ) z_q=encode_z+(z_q-encode_z).detach() zq=encodez+(zq−encodez).detach()
-
这里将得到的 z q z_q zq作为编码器的输入,那么这里就将后续解码器的梯度利用这个公式以梯度为1,传递到编码器。
-
-
Part3:损失函数去掉了KL散度(因为没有正态分布的采样了,因此VQVAE其实也失去了生成的能力,主要起编码的作用),这里添加了两个损失,分别为量化损失和承诺损失,两者的共同点是都是衡量编码器和查询后得到的向量的距离,差别是更新谁的参数,量化损失更新查询字典的参数,承诺损失更新编码器的参数。(个人记法就是,承诺,意味金标准,因此是以查询的那个字典作为标准来更新编码器的参数,而量化则意味着图像量化为向量作为标准,那就更新字典的参数)
-
VAE(变分自编码器)的损失函数由两部分组成:重构损失 和 KL 散度。
1. 重构损失(Reconstruction Loss)
重构损失衡量了生成的数据 ( \hat{x} ) 与输入数据 ( x ) 之间的差异,通常使用 均方误差(MSE) 或 交叉熵 来衡量。对于连续数据,常用 MSE,对于离散数据,常用交叉熵。
L reconstruction = E q ( z ∣ x ) [ ∥ x − x ^ ∥ 2 ] \mathcal{L}_{\text{reconstruction}} = \mathbb{E}_{q(z|x)}[\| x - \hat{x} \|^2] Lreconstruction=Eq(z∣x)[∥x−x^∥2]
这里:- ( x ^ \hat{x} x^ = p(x|z) ) 是解码器生成的样本。
- ( x ) 是原始输入数据。
- ( ∥ x − x ^ ∥ 2 \| x - \hat{x} \|^2 ∥x−x^∥2 ) 是输入数据与生成数据之间的均方误差。
2. 向量量化损失(Vector Quantization Loss)- 更新字典参数
-
向量量化损失用于衡量编码器输出的特征与字典向量之间的差异。其目标是最小化每个编码的特征与最接近的字典向量之间的欧几里得距离。该损失由以下公式表示:
L v q = ∥ sg ( e ( x ) ) − z e ∥ 2 L_{vq} = \| \text{sg}(e(x)) - z_e \|^2 Lvq=∥sg(e(x))−ze∥2
这里:- e ( x ) e(x) e(x) 是编码器输出的特征(编码后的表示)
- z e z_e ze 是量化后的向量
- ( sg ( ⋅ ) \text{sg}(\cdot) sg(⋅) ) 表示 stop gradient 操作,确保梯度不通过量化操作传播。
3. 承诺损失(Commitment Loss)- 更新编码器参数
-
承诺损失用于鼓励编码器的输出尽量接近其最近的字典向量。其目标是最小化编码器输出与其最接近字典向量之间的差距。该损失由以下公式表示:
L c o m m i t = β ∥ e ( x ) − s g ( z e ) ∥ 2 L_{commit} = \beta \| e(x) - sg(z_e) \|^2 Lcommit=β∥e(x)−sg(ze)∥2
这里:- β \beta β 是一个超参数,通常为一个较小的值,控制承诺损失的权重,
- e ( x ) e(x) e(x) 是编码器输出的特征,
- z e z_e ze 是量化后的字典向量。
-
Dataset代码 - 利用Mnist数据集做向量量化变分自编码器
Dataset代码
Part1 库函数
# 该模块主要是为了实现数据集的下载,主要是Mnist数据集
'''
# Part1 引入相关的库函数
'''
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
Part2 初始化一个数据集的类
'''
# Part2 获取数据集的转换操作,以及数据集的获取
'''
transform_action=transforms.Compose([
transforms.ToTensor()
])
Mnist_dataset=torchvision.datasets.MNIST(root='./Mnist',train=True,transform=transform_action,download=True)
Part3 测试
'''
# 开始测试
'''
if __name__=='__main__':
imag,label=Mnist_dataset[0]
plt.figure(figsize=(45,45))
plt.imshow(imag.permute(2,1,0))
plt.show()
参考
视频讲解:DALL·E 2(内含扩散模型介绍)【论文精读】_哔哩哔哩_bilibili