beta-VAE

beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework

β−VAE\beta-VAEβVAE的目标是学习独立的特征,让某种特征对应某个生成因素,而独立于其他因素。For example,a model trained on photos of human faces might capture the gentle, skincolor, hair color, hair length, emotion, whether wearing a pair of glasses andmany other relatively independent factors in separate dimensions. Such a disentangled representation is very beneficial to facial image generation.
数据集D={X,V,W}\mathcal{D}=\{X, V, W\}D={X,V,W},其中x∈RN\mathbf{x} \in \mathbb{R}^{N}xRN代表图片,而v∈RK\mathbf{v} \in \mathbb{R}^{K}vRK为条件独立的隐变量,即log⁡p(v∣x)=∑klog⁡p(vk∣x)\log p(\mathbf{v} | \mathbf{x})=\sum_{k} \log p\left(v_{k} | \mathbf{x}\right)logp(vx)=klogp(vkx),还有条件依赖的隐变量w∈RH\mathbf{w} \in \mathbb{R}^{H}wRH。那么图片由着两个隐变量共同生成,则p(x∣v,w)=Sim⁡(v,w)p(\mathbf{x} | \mathbf{v}, \mathbf{w})=\operatorname{Sim}(\mathbf{v}, \mathbf{w})p(xv,w)=Sim(v,w)
我们现在希望通过一种无监督的方法,仅仅利用数据XXX就能得到x\mathbf xxz\mathbf zz的联合分布,其中z∈RM\mathbf{z} \in \mathbb{R}^{M}zRMM≥KM \geq KMK。也就是说有p(x∣z)≈p(x∣v,w)=Sim⁡(v,w)p(\mathbf{x} | \mathbf{z}) \approx p(\mathbf{x} | \mathbf{v}, \mathbf{w})=\operatorname{Sim}(\mathbf{v}, \mathbf{w})p(xz)p(xv,w)=Sim(v,w)。那么目标函数为max⁡θEpθ(z)[pθ(x∣z)] \max _{\theta} \mathbb{E}_{p_{\theta}(\mathbf{z})}\left[p_{\theta}(\mathbf{x} | \mathbf{z})\right] θmaxEpθ(z)[pθ(xz)]与VAE一样,引入一个inferred后验分布qϕ(z∣x)q_{\phi}(\mathbf{z} | \mathbf{x})qϕ(zx),我们的目标是qϕ(z∣x)q_{\phi}(\mathbf{z} | \mathbf{x})qϕ(zx)能够以解耦的方式将v\mathbf vv解耦出。为了能够达到目的,我们可以让这个后验接近p(z)p(\mathbf z)p(z),这样既能限制隐变量的信息瓶颈,又能达到之前说的解耦的效果,且p(z)=N(0,I)p(\mathbf{z})=\mathcal{N}(0, I)p(z)=N(0,I)。则目标函数为max⁡ϕ,θEx∼D[Eqϕ(z∣x)[log⁡pθ(x∣z)]] subject to DKL(qϕ(z∣x)∥p(z))&lt;ϵ \max _{\phi, \theta} \mathbb{E}_{x \sim \mathbf{D}}\left[\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} | \mathbf{z})\right]\right] \quad \text { subject to } D_{K L}\left(q_{\phi}(\mathbf{z} | \mathbf{x}) \| p(\mathbf{z})\right)&lt;\epsilon ϕ,θmaxExD[Eqϕ(zx)[logpθ(xz)]] subject to DKL(qϕ(zx)p(z))<ϵ求解上式,利用拉格朗日KKT条件F(θ,ϕ,β;x,z)=Eqϕ(z∣x)[log⁡pθ(x∣z)]−β(DKL(qϕ(z∣x)∥p(z))−ϵ) \mathcal{F}(\theta, \phi, \beta ; \mathbf{x}, \mathbf{z})=\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} | \mathbf{z})\right]-\beta\left(D_{K L}\left(q_{\phi}(\mathbf{z} | \mathbf{x}) \| p(\mathbf{z})\right)-\epsilon\right) F(θ,ϕ,β;x,z)=Eqϕ(zx)[logpθ(xz)]β(DKL(qϕ(zx)p(z))ϵ)又因为β,ϵ≥0\beta, \epsilon \geq 0β,ϵ0,则F(θ,ϕ,β;x,z)≥L(θ,ϕ;x,z,β)=Eqϕ(z∣x)[log⁡pθ(x∣z)]−βDKL(qϕ(z∣x)∥p(z)) \mathcal{F}(\theta, \phi, \beta ; \mathbf{x}, \mathbf{z}) \geq \mathcal{L}(\theta, \phi ; \mathbf{x}, \mathbf{z}, \beta)=\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} | \mathbf{z})\right]-\beta D_{K L}\left(q_{\phi}(\mathbf{z} | \mathbf{x}) \| p(\mathbf{z})\right) F(θ,ϕ,β;x,z)L(θ,ϕ;x,z,β)=Eqϕ(zx)[logpθ(xz)]βDKL(qϕ(zx)p(z))可以发现与vanilla-VAE的唯一区别就是在KL项引入了。Therefore a higher β\betaβ encourages more efficient latent encoding and further encouragesthe disentanglement. Meanwhile, a higher β\betaβ may create a trade-off between reconstruction quality and the extent of disentanglement.

Understanding disentangling in β\betaβ-VAE

从信息瓶颈的角度看β−VAE\beta-VAEβVAE。首先先科普下什么是information bottlenec。假设我们面对分类任务,标注数据对是(x1,y1),…,(xN,yN)\left(x_{1}, y_{1}\right), \ldots,\left(x_{N}, y_{N}\right)(x1,y1),,(xN,yN)。我们把这个任务分为两步来理解,第一步是编码,第二步就是分类。第一步是把xxx编码为一个隐变量zzz,然后分类器把zzz识别为类别yyy
我们试想在zzz处加一个“瓶颈”β\betaβ,它像一个沙漏,进入的信息量可能有很多,但是出口就只有β\betaβ那么大,所以这个瓶颈的作用是:不允许流过zzz的信息量多于β\betaβ。跟沙漏不同的是,沙漏的沙过了瓶颈就完事了,而信息过了信息瓶颈后,还需要完成它要完成的任务(分类、回归等),所以模型迫不得已,只好想办法让最重要的信息通过瓶颈。这就是信息瓶颈的原理!
那么衡量信息量大小的式子刚好有互信息I(Z,Y;θ)=∫dzdyp(z,y∣θ)log⁡p(z,y∣θ)p(z∣θ)p(y∣θ) I(Z, Y ; \boldsymbol{\theta})=\int d z d y p(z, y | \boldsymbol{\theta}) \log \frac{p(z, y | \boldsymbol{\theta})}{p(z | \boldsymbol{\theta}) p(y | \boldsymbol{\theta})} I(Z,Y;θ)=dzdyp(z,yθ)logp(zθ)p(yθ)p(z,yθ)显然上式最大的时候刚好为Z=YZ=YZ=Y时,这样我们的学习就没有意义了。因此可以考虑对上式加上一定的约束max⁡θI(Z,Y;θ) s.t. I(X,Z;θ)≤Ic \max _{\boldsymbol{\theta}} I(Z, Y ; \boldsymbol{\theta}) \text { s.t. } \quad I(X, Z ; \boldsymbol{\theta}) \leq I_{c} θmaxI(Z,Y;θ) s.t. I(X,Z;θ)IcX,ZX,ZX,Z的互信息约束在一个范围中,这样就是相当于一个瓶颈,提取十分必要的信息,而略去一些没有意义的信息!求解上式同样适用拉格朗日乘子,最后得到RIB(θ)=I(Z,Y;θ)−βI(Z,X;θ) R_{I B}(\boldsymbol{\theta})=I(Z, Y ; \boldsymbol{\theta})-\beta I(Z, X ; \boldsymbol{\theta}) RIB(θ)=I(Z,Y;θ)βI(Z,X;θ)Intuitively, the first term in RIB encourages Z to be predictive of Y ; thesecond term encourages Z to“forget”X. Essentially it forces Z to act likea minimal sufficient statistic of X for predicting Y .
那么β−VAE\beta-VAEβVAE就能用信息瓶颈的理论解释了。首先第一部分为一个重建问题max⁡Eq(z∣x)[log⁡p(x∣z)] \max E_{q(\mathbf{z} | \mathbf{x})}[\log p(\mathbf{x} | \mathbf{z})] maxEq(zx)[logp(xz)]使得上式最大时,就是在隐含着互信息I(Z,Y;θ)I(Z, Y ; \boldsymbol{\theta})I(Z,Y;θ)变大,其中Y=XY=XY=X。只有X,ZX,ZX,Z的相关程度很高的时候,重建才能很好!更加极端的情况就是AE自动编码机,Z=encoder(X)Z=encoder(X)Z=encoder(X)为一个决定性关系,此时重构可以做到很高的精度,而且X,ZX,ZX,Z的互信息也很大!
KLKLKL部分DKL(qϕ(z∣x)∥p(z))D_{K L}\left(q_{\phi}(\mathbf{z} | \mathbf{x}) \| p(\mathbf{z})\right)DKL(qϕ(zx)p(z)),在IB理论中是减少X,ZX,ZX,Z的互信息,当我们减少KLKLKL散度时,则有qϕ(z∣x)→p(z)q_{\phi}(\mathbf{z} | \mathbf{x}) \rightarrow p(\mathbf z)qϕ(zx)p(z),那就是说明X,ZX,ZX,Z就是独立了,从而互信息也就为0了。那么KLKLKL部分越大,说明该维度的隐变量含有的信息量就越大。
β−VAE\beta-VAEβVAE虽然能够学习到很好的特征,但是其重建能力真的很差,如下图所示。
在这里插入图片描述
本文发现逐渐增大信息瓶颈可以让重建能力上升,且能够学习到很棒的特征。因此提出了如下的目标函数:L(θ,ϕ;x,z,C)=Eqϕ(z∣x)[log⁡pθ(x∣z)]−γ∣DKL(qϕ(z∣x)∥p(z))−C∣ \mathcal{L}(\theta, \phi ; \mathbf{x}, \mathbf{z}, C)=\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} | \mathbf{z})\right]-\gamma\left|D_{K L}\left(q_{\phi}(\mathbf{z} | \mathbf{x}) \| p(\mathbf{z})\right)-C\right| L(θ,ϕ;x,z,C)=Eqϕ(zx)[logpθ(xz)]γDKL(qϕ(zx)p(z))C其中CCC为一个逐渐增大的数字,一般γ\gammaγ取一个很大的数字,保证KLKLKL部分能满足CCC
在这里插入图片描述

06-24
### Beta-VAE 介绍 Beta-VAE 是一种变分自编码器(Variational Autoencoder, VAE)的扩展形式,旨在通过引入一个额外的超参数 $\beta$ 来增强模型学习解耦表示的能力。传统的 VAE 通过最大化数据的边缘似然下界来训练,而 Beta-VAE 则在损失函数中对 KL 散度项进行加权,其目标函数可以表示为: $$ \mathcal{L}_{\beta} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \beta \cdot D_{KL}(q(z|x) \| p(z)) $$ 其中,$\beta > 1$ 的设置会加强对潜在变量分布 $q(z|x)$ 与先验分布 $p(z)$ 差异的惩罚,从而鼓励模型学习出更加独立的潜在因子。这种方法在无监督学习任务中特别有用,尤其是在需要解释性较强的潜在空间的应用场景中[^2]。 ### 使用方法 使用 Beta-VAE 的关键在于合理配置其超参数 $\beta$ 和网络结构。以下是一般步骤: 1. **数据准备**:确保输入数据已经被适当预处理并标准化。 2. **构建模型**:定义编码器和解码器网络结构。通常采用卷积神经网络(CNN)或全连接网络来实现。 3. **设定 $\beta$ 值**:选择合适的 $\beta$ 值以控制潜在空间的稀疏性和解耦程度。较大的 $\beta$ 值会导致更强的正则化效果。 4. **训练模型**:优化重构损失与 KL 散度之间的平衡。 5. **评估与应用**:利用训练好的模型进行数据生成、特征提取或进一步用于下游任务如分类或强化学习。 具体实现时可参考项目文档中的说明,尤其是关于配置文件的详细描述,这有助于更好地理解每个参数的作用及其影响[^1]。 ### 实现代码 以下是一个基于 PyTorch 的简单 Beta-VAE 模型实现示例: ```python import torch import torch.nn as nn import torch.nn.functional as F class BetaVAE(nn.Module): def __init__(self, latent_dim=10, beta=4): super(BetaVAE, self).__init__() self.latent_dim = latent_dim self.beta = beta # Encoder self.encoder = nn.Sequential( nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(), nn.Conv2d(64, 256, kernel_size=4, stride=2, padding=1), nn.ReLU() ) self.fc_mu = nn.Linear(256*4*4, latent_dim) self.fc_logvar = nn.Linear(256*4*4, latent_dim) # Decoder self.decoder = nn.Sequential( nn.Linear(latent_dim, 256*4*4), nn.ReLU(), nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(), nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(), nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1), nn.Sigmoid() ) def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return mu + eps * std def forward(self, x): h = self.encoder(x) h = h.view(h.size(0), -1) mu = self.fc_mu(h) logvar = self.fc_logvar(h) z = self.reparameterize(mu, logvar) recon_x = self.decoder(z).view_as(x) return recon_x, mu, logvar def loss_function(self, recon_x, x, mu, logvar): BCE = F.binary_cross_entropy(recon_x, x, reduction='sum') KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return BCE + self.beta * KLD ``` 上述代码展示了一个基本的 Beta-VAE 架构,适用于 MNIST 或类似图像数据集。它包括了编码器、解码器以及重参数化技巧的实现,并定义了带有 $\beta$ 调整的损失函数。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值