https://arxiv.org/pdf/1711.00937v2.pdf(论文下载链接)
之所以将VQ-VAE(Vector Quantised Variational AutoEncoder)论文,主要是为讲解后面两篇论文做准备,VQ-VAE不管是视频还是博客,都有人在讲解,但是这里也做一个总结,以衔接后面的两篇论文讲解,关于VAE(Variational AutoEncoder)相关的论文比较多,并且其中涉及的数学原理以及推导也比较多,导致我们在阅读VAE方法的时候可能存在较多的困惑,自己在看的过程中也遇到了较多的困惑,但是还是准备做一个总结。
目录

现有方法存在的局限性
1. 连续潜在变量的局限性
-
问题:传统变分自编码器(VAE)使用连续潜在变量,但许多模态(如语言、语音)本质上是离散的。
-
局限性:
-
建模不匹配:连续潜在变量难以有效表示离散数据结构(如语言符号、语音音素)。
-
后验坍塌(Posterior Collapse):当解码器过于强大时,潜在变量容易被忽略,导致模型退化为单纯的自回归解码器。
-
2. 离散潜在变量训练的挑战
-
梯度估计困难:离散变量的训练依赖高方差梯度估计方法(如NVIL、VIMCO),收敛速度慢且稳定性差。
-
规模限制:现有方法多在小规模数据集(如MNIST)和低维潜在空间(维度<8)上验证,难以扩展到复杂数据(如ImageNet、语音波形)。
3. 自回归模型的效率问题
-
计算成本高:PixelCNN、WaveNet等自回归模型逐像素或逐样本生成,效率低下。
-
全局结构建模弱:自回归模型更关注局部统计特征,难以捕获图像的全局语义结构。
一.提出目的和方法
1.提出目的
传统的VAE(变分自编码器)在隐空间中使用连续分布,导致生成的隐变量难以进行有效的离散化表示(如用于序列建模或强化学习)。
2.提出方法
VQ-VAE提出了一种离散隐变量的自编码方法,通过向量量化(Vector Quantization, VQ) 实现隐空间的离散化,从而提升表征的可解释性和生成质量。具体方法:编码器网络输出离散而非连续代码;且先验分布是动态学习而非静态预设。为学习离散潜在表征,融入了向量量化(VQ)的核心思想。采用VQ方法使模型能够规避VAE框架中常见的"后验坍塌"问题(即当潜在变量与强大的自回归解码器结合时被忽略)
二.VQ-VAE贡献点
- 提出VQ-VAE模型:该模型结构简单,采用离散潜在变量,既不会出现"后验坍塌"问题,也不存在方差异常;
与本研究最相关的工作当属变分自编码器(VAEs)。VAE包含以下核心组件:
1)编码器网络:用于参数化离散潜在随机变量z的后验分布q(z|x),其中x为输入数据;
2)先验分布p(z);
3)解码器网络:建立输入数据条件分布p(x|z)。
传统VAE通常假设后验分布与先验分布均为对角协方差的正态分布,这种设定可利用高斯重参数化技巧。现有扩展方法包括:
本研究提出的VQ-VAE创新性地采用离散潜在变量,其训练方法受向量量化(VQ)启发。该模型中:
1)后验与先验分布均为类别分布;
2)从分布中采样的离散值作为嵌入表的索引;
3)检索到的嵌入向量将作为解码器网络的输入。
三.VQ-VAE具体方法
1.离散化隐藏向量

注:也就是这里计算嵌入空间和编码器输出向量Ze(x)之间的距离,找到最小距离的索引K,然后下面将其转换为one-hot编码格式。

转换为one-hot编码格式之后。通过为1的位置获得对应的离散向量。 为了更好的理解这个过程,使用下面的图来给大家表示一下:(建议结合代码看)



四.VQ-VAE学习方式

虽然方程2没有明确定义的梯度,但本文采用类似于直通估计器(straight-through estimator)的方法来近似梯度,即直接将解码器输入 zq(x) 的梯度复制到编码器输出 ze(x)。
# TODO gradient copy trick (Add the residue back to the latents)
quantized_latents = latents + (quantized_latents - latents).detach()
计算量化误差(zq − ze),但通过 .detach() 断开梯度回传,确保这部分不会影响编码器的梯度。
- 前向传播时,等价于直接返回 zq(因为 ze+(zq−ze)=zq)。
- 反向传播时,由于右侧项被
detach(),梯度会直接通过左侧的latents(即 ze)回传,相当于将 zq 的梯度复制给了 ze。



核心代码实现
class VectorQuantizer(nn.Module):
"""
Reference:
[1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
beta: float = 0.25):
super(VectorQuantizer, self).__init__()
self.K = num_embeddings
self.D = embedding_dim
self.beta = beta
#TODO 定义的嵌入向量e
self.embedding = nn.Embedding(self.K, self.D)
self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
def forward(self, latents: Tensor) -> Tensor:
#TODO 编码器输出的编码向量
latents = latents.permute(0, 2, 3, 1).contiguous() # [B x D x H x W] -> [B x H x W x D]
latents_shape = latents.shape
flat_latents = latents.view(-1, self.D) # [BHW x D]
# TODO 计算隐藏向量和嵌入向量权重之间的L2距离 Compute L2 distance between latents and embedding weights
dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight ** 2, dim=1) - \
2 * torch.matmul(flat_latents, self.embedding.weight.t()) # [BHW x K]
# TODO 获得最小距离对应的索引 Get the encoding that has the min distance
encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1) # [BHW, 1]
# TODO 将其索引转换为对应的one-hot编码 Convert to one-hot encodings
device = latents.device
encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
encoding_one_hot.scatter_(1, encoding_inds, 1) # [BHW x K]
#TODO 获得离散化隐藏向量空间 Quantize the latents
quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight) # [BHW, D]
quantized_latents = quantized_latents.view(latents_shape) # [B x H x W x D]
# TODO Compute the VQ Losses
commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
embedding_loss = F.mse_loss(quantized_latents, latents.detach())
vq_loss = commitment_loss * self.beta + embedding_loss
# Add the residue back to the latents
quantized_latents = latents + (quantized_latents - latents).detach()
return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss # [B x D x H x W]
五.实验比较





悄悄举手:若觉得文章有用,不妨留下一个小赞?(´▽`ʃƪ)

918

被折叠的 条评论
为什么被折叠?



