VAE的损失函数的拆分

本文深入探讨了VAE(变分自编码器)的损失函数,特别是其Disentangle概念,即解纠缠,旨在寻找数据中可解释的、分离的因子。文章详细介绍了VAE的损失函数的不同形式,以及如何通过分解目标函数来理解各部分的作用,包括后验分布的近似、条件分布的一致性和边际分布的匹配。通过移除目标函数的不同项,分析了它们对模型学习行为的影响。

「Structured Disentangled Representations」这篇文章对VAE的损失函数提出了一个统一化的解释,根据这个解释可以很好的分析近几年来对VAE的各种变形。

什么是Disentangle

Disentangle 的意思是解纠缠,所谓解纠缠,也叫做解耦,就是将原始数据空间中纠缠着的数据变化,变换到一个好的表征空间中,在这个空间中,不同要素的变化是可以彼此分离的。
比如,人脸数据集经过编码器,在潜变量空间Z中,我们就会获得人脸是否微笑、头发颜色、方位角等信息的分离表示,我们把这些分离表示称为Factors。 解纠缠的变量通常包含可解释的语义信息,并且能够反映数据变化中的分离的因子。在生成模型中,我们就可以根据这些分布进行特定的操作,比如改变人脸宽度、添加眼镜等操作。(转自知乎用户:大象)

VAE 要做的就是要找到这些隐式的解纠缠变量,当然也有很多其他方法能做到这一点,比如狄利克雷过程。
这篇文章主要讲了自动编码机上的应用。

VAE损失函数的不同形式

  1. VAE的目标以往被定义成每个数据点ELBO(Evidence Lower Bound)的期望值在数据点X上的集合,或者被定义为一个用有限个数的数据点来接近实际分布的经验性的分布。个人理解就是如果你有这个要求的分布的很多点的话,就能拿来定义这个分布。
    即:
    LVAE(θ,ϕ):=Eq(x)[Eqϕ(z∣x)[logpθ(x,z)qϕ(z∣x)]]L^{VAE}(\theta,\phi):=E_{q(x)}[E_{q\phi(z|x)}[log \frac{p_\theta(x,z)}{q_\phi(z|x)}]]LVAE(θ,ϕ):=Eq(x)[Eqϕ(zx)[logqϕ(zx)pθ(x,z)]]
    q(x):=1N∑Nn=1δx()xq(x) := \frac{1}{N} \sum_{N}^{n=1}\delta_x()xq(x):=N1Nn=1δx()x
    这里的分布q和分布p分别被ϕ\phiϕθ\thetaθ参数化。
  2. 把VAE的损失函数定义为KL散度。这个KL散度是生成模型(generative model)pθ(z,x)p_\theta(z,x)pθ(z,x)与推测模型(inference model)qϕ(z,x)=q(z∣x)q(x)q_\phi(z,x)= q(z |x)q(x)qϕ
好的,首先需要安装PyTorch库: ```python !pip install torch ``` 然后,我们可以按照以下步骤构建两个编码器和一个解码器: ```python import torch import torch.nn as nn import torch.nn.functional as F # 定义编码器 class Encoder(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(Encoder, self).__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, output_dim) def forward(self, x): x = F.relu(self.fc1(x)) x = self.fc2(x) return x # 定义解码器 class Decoder(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim, dropout): super(Decoder, self).__init__() self.dropout = nn.Dropout(dropout) self.gat1 = nn.MultiheadAttention(input_dim, 2) self.gat2 = nn.MultiheadAttention(input_dim, 2) self.mlp1 = nn.Linear(input_dim*2, hidden_dim) self.mlp2 = nn.Linear(hidden_dim, hidden_dim) self.mlp3 = nn.Linear(hidden_dim, output_dim) def forward(self, x, adj): # GAT处理速度和度特征 x1 = self.dropout(x) x1 = x1.permute(1, 0, 2) x1, _ = self.gat1(x1, x1, x1) x1 = x1.permute(1, 0, 2) # GAT处理速度和星期 x2 = self.dropout(x) x2 = x2.permute(1, 0, 2) x2, _ = self.gat2(x2, x2, x2) x2 = x2.permute(1, 0, 2) # 两个GAT的输出进行拼接 x = torch.cat((x1, x2), dim=2) # MLP处理拼接后的特征 x = F.relu(self.mlp1(x)) x = F.relu(self.mlp2(x)) x = self.mlp3(x) return x # 定义VAE class VAE(nn.Module): def __init__(self, input_dim, hidden_dim, latent_dim, dropout): super(VAE, self).__init__() self.encoder1 = Encoder(input_dim, hidden_dim, latent_dim) self.encoder2 = Encoder(input_dim, hidden_dim, latent_dim) self.decoder = Decoder(latent_dim, hidden_dim, input_dim, dropout) def forward(self, x, adj): # 编码器1处理速度特征 mu1 = self.encoder1(x) logvar1 = self.encoder1(x) std1 = torch.exp(0.5*logvar1) eps1 = torch.randn_like(std1) z1 = eps1.mul(std1).add_(mu1) # 编码器2处理度特征和星期特征 mu2 = self.encoder2(x) logvar2 = self.encoder2(x) std2 = torch.exp(0.5*logvar2) eps2 = torch.randn_like(std2) z2 = eps2.mul(std2).add_(mu2) # 将两个编码器的输出拼接并送入解码器 z = torch.cat((z1, z2), dim=2) recon_x = self.decoder(z, adj) return mu1, logvar1, mu2, logvar2, recon_x ``` 每行代码的功能如下: - 第3-9行:定义了一个编码器,它包含两个线性层和一个ReLU激活函数。 - 第11-23行:定义了一个解码器,包含两个GAT层和三个线性层,其中第一个线性层和第二个GAT层处理速度和度特征,第二个线性层和第二个GAT层处理速度和星期特征,最后一个线性层将两个GAT的输出进行拼接并送入MLP中处理。 - 第25-39行:定义了VAE模型,包含两个编码器和一个解码器。编码器1处理速度特征,编码器2处理度特征和星期特征。将两个编码器的输出拼接并送入解码器中进行解码。 现在,我们可以生成形状为(16992,307,12,3)的数据集,并将其按照batchsize=16送入模型: ```python # 随机生成数据集 data = torch.randn(16992, 307, 12, 3) # 将数据集按照batchsize=16拆分,送入模型 batch_size = 16 for i in range(0, data.shape[0], batch_size): x = data[i:i+batch_size].reshape(-1, 12, 3) adj = torch.randn(x.shape[0], x.shape[1], x.shape[1]) mu1, logvar1, mu2, logvar2, recon_x = model(x, adj) print('VAE隐变量的输出维度:', mu1.shape, logvar1.shape, mu2.shape, logvar2.shape) print('解码器的输出维度:', recon_x.shape) ``` 在这里,我们按照batchsize=16将数据集拆分为多个小批量,然后对每个小批量进行编码和解码。最后,我们打印出VAE隐变量的输出维度和解码器的输出维度。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值