Disentangling by Factorising
This paper aims to fix the blurry reconstructions produced by β-VAE. It begins by analyzing where the problem with β-VAE comes from:
The β-VAE objective is

$$
\frac{1}{N} \sum_{i=1}^{N}\left[\mathbb{E}_{q(z \mid x^{(i)})}\left[\log p(x^{(i)} \mid z)\right]-\beta\, KL\left(q(z \mid x^{(i)}) \,\|\, p(z)\right)\right]
$$

Averaged over the data, the KL term can be further decomposed as

$$
\mathbb{E}_{p_{\text{data}}(x)}\left[KL(q(z \mid x) \,\|\, p(z))\right]=I(x ; z)+KL(q(z) \,\|\, p(z))
$$

Here $I(x;z)$ is the mutual information between $x$ and $z$ under the joint distribution $p_{\text{data}}(x)\,q(z \mid x)$, and $q(z)$ is the aggregate posterior:

$$
q(z)=\mathbb{E}_{p_{\text{data}}(x)}\left[q(z \mid x)\right]
$$

The term $KL(q(z) \,\|\, p(z))$ pulls $q(z)$ toward the factorized prior $p(z)$, which encourages the dimensions of $z$ to be independent. Penalizing the mutual information term, however, reduces how much information $z$ carries about $x$, so a larger $\beta$ degrades reconstruction quality. We would therefore like a large weight on $KL(q(z) \,\|\, p(z))$, but not on the mutual information term.
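The decomposition above is an exact identity, so it can be checked numerically. The following is a minimal sketch built on a toy discrete setup with made-up numbers: three equally likely data points, an encoder `q_z_given_x` over two latent states, and a uniform prior `p_z` (all hypothetical). It evaluates both sides of the identity and shows that β scales the mutual information term and the marginal KL term equally.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical toy setup: uniform p_data(x) over 3 points, encoder q(z|x) over 2 states
p_data = [1 / 3, 1 / 3, 1 / 3]
q_z_given_x = [
    [0.9, 0.1],
    [0.5, 0.5],
    [0.2, 0.8],
]
p_z = [0.5, 0.5]

# Aggregate posterior q(z) = E_{p_data(x)}[q(z|x)]
q_z = [sum(px * row[k] for px, row in zip(p_data, q_z_given_x)) for k in range(len(p_z))]

# Left-hand side: the data-averaged KL that beta multiplies in beta-VAE
lhs = sum(px * kl(row, p_z) for px, row in zip(p_data, q_z_given_x))

# I(x; z) = E_{p_data(x)}[KL(q(z|x) || q(z))]
mi = sum(px * kl(row, q_z) for px, row in zip(p_data, q_z_given_x))
rhs = mi + kl(q_z, p_z)

print(f"E[KL(q(z|x)||p(z))]     = {lhs:.6f}")
print(f"I(x;z) + KL(q(z)||p(z)) = {rhs:.6f}")

# beta multiplies the whole left-hand side, so it penalizes I(x;z)
# exactly as strongly as KL(q(z)||p(z))
for beta in (1.0, 4.0):
    print(f"beta={beta}: MI penalty {beta * mi:.4f}, "
          f"marginal-KL penalty {beta * kl(q_z, p_z):.4f}")
```

Because $\beta$ multiplies both terms together, β-VAE cannot increase the pressure on $KL(q(z)\,\|\,p(z))$ without also pushing $I(x;z)$ down.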
The method proposed in this paper instead keeps the standard VAE objective and adds an extra regularizer:

$$
\frac{1}{N} \sum_{i=1}^{N}\left[\mathbb{E}_{q(z \mid x^{(i)})}\left[\log p(x^{(i)} \mid z)\right]-KL\left(q(z \mid x^{(i)}) \,\|\, p(z)\right)\right]-\gamma\, KL\left(q(z) \,\|\, \overline{q}(z)\right)
$$

where $\overline{q}(z) := \prod_{j=1}^{d} q(z_j)$. The term $KL(q(z) \,\|\, \overline{q}(z))$ is known as the Total Correlation (TC), a measure of dependence among multiple random variables. It is zero exactly when the dimensions of $z$ are independent under $q(z)$. Because $q(z)$ and its marginals depend on the entire dataset, this term is hard to compute directly, but it can be estimated from samples as follows:
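To sample from $q(z)$, pick a data point $x^{(i)}$ at random and draw $z$ from $q(z \mid x^{(i)})$. To sample from $\overline{q}(z)$, first draw $d$ independent samples from $q(z)$, then keep only the $j$-th dimension of the $j$-th sample. The resulting vector is one sample from $\overline{q}(z)$. The paper uses a different, more efficient method.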
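Both sampling schemes for $\overline{q}(z)$ fit in a few lines. The snippet below is illustrative only. `sample_q` stands in for "pick a random $x^{(i)}$ and sample from the encoder" and is a hypothetical two-dimensional $q(z)$ with strongly correlated dimensions. `sample_q_bar_naive` follows the procedure described above. `permute_dims` is a batch-level alternative that shuffles each latent dimension independently across a minibatch. This is the method commonly associated with FactorVAE. The text above does not spell out the paper's efficient method, however, so treat that attribution as an assumption.

```python
import random

def sample_q_bar_naive(sample_q, d):
    """One sample of q_bar(z) = prod_j q(z_j): draw d independent samples
    from q(z) and keep the j-th coordinate of the j-th sample."""
    samples = [sample_q() for _ in range(d)]
    return [samples[j][j] for j in range(d)]

def permute_dims(batch, rng=random):
    """Shuffle each latent dimension independently across the batch.
    Every column keeps its marginal q(z_j); cross-dimension dependence is broken."""
    n, d = len(batch), len(batch[0])
    columns = []
    for j in range(d):
        col = [z[j] for z in batch]
        rng.shuffle(col)
        columns.append(col)
    return [[columns[j][i] for j in range(d)] for i in range(n)]

def correlation(batch, a, b):
    """Pearson correlation between dimensions a and b of a batch of vectors."""
    n = len(batch)
    xa = [z[a] for z in batch]
    xb = [z[b] for z in batch]
    ma, mb = sum(xa) / n, sum(xb) / n
    cov = sum((u - ma) * (v - mb) for u, v in zip(xa, xb)) / n
    va = sum((u - ma) ** 2 for u in xa) / n
    vb = sum((v - mb) ** 2 for v in xb) / n
    return cov / (va * vb) ** 0.5

rng = random.Random(0)

def sample_q():
    # Hypothetical q(z): two strongly correlated latent dimensions
    z1 = rng.gauss(0.0, 1.0)
    return [z1, z1 + 0.1 * rng.gauss(0.0, 1.0)]

batch = [sample_q() for _ in range(5000)]
permuted = permute_dims(batch, rng)
print("corr before permutation:", round(correlation(batch, 0, 1), 3))
print("corr after permutation: ", round(correlation(permuted, 0, 1), 3))
print("one naive q_bar sample: ", sample_q_bar_naive(sample_q, 2))
```

The naive scheme spends $d$ draws from $q(z)$ on every sample of $\overline{q}(z)$. Permuting a batch of $n$ encoder samples yields $n$ samples of $\overline{q}(z)$ at once. Permutation preserves each column's marginal exactly while breaking the correlation between dimensions. Because the permuted rows all come from one batch, they are only approximately independent draws from $\overline{q}(z)$.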
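For the newly introduced KL term, the paper uses the density-ratio trick:

$$
TC(z) = KL(q(z) \,\|\, \overline{q}(z))=\mathbb{E}_{q(z)}\left[\log \frac{q(z)}{\overline{q}(z)}\right] \approx \mathbb{E}_{q(z)}\left[\log \frac{D(z)}{1-D(z)}\right]
$$

Here $D(z)$ is a binary classifier (discriminator) that outputs the probability that $z$ came from $q(z)$ (label 1) rather than from $\overline{q}(z)$ (label 0). Suppose $D$ is trained with cross-entropy on equally many samples from each distribution. Its optimum is then $D(z) = q(z)/(q(z)+\overline{q}(z))$, so $D(z)/(1-D(z)) = q(z)/\overline{q}(z)$, which is why the approximation holds. The overall algorithm therefore alternates between two updates. The VAE is trained on the objective above with this estimate in place of the TC term, and $D$ is trained to distinguish samples of $q(z)$ from samples of $\overline{q}(z)$.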
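The following sketch checks the density-ratio trick on a toy example. It uses a hypothetical joint `q_joint` over two binary latent dimensions in place of a learned encoder. Rather than training a network, it plugs in the Bayes-optimal discriminator `d_optimal`, i.e. $D(z) = q(z)/(q(z)+\overline{q}(z))$. This is the function a cross-entropy classifier approaches when trained on balanced samples. With that $D$, $\mathbb{E}_{q(z)}[\log D/(1-D)]$ equals the exact TC, and a Monte Carlo average over samples from $q(z)$ approximates it.

```python
import math
import random

# Hypothetical joint q(z) over two binary latent dimensions z = (z1, z2)
q_joint = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.1, (1, 1): 0.4}

# Marginals q(z_j) and the factorized distribution q_bar(z) = q(z1) q(z2)
q1 = {a: sum(p for (z1, _), p in q_joint.items() if z1 == a) for a in (0, 1)}
q2 = {b: sum(p for (_, z2), p in q_joint.items() if z2 == b) for b in (0, 1)}
q_bar = {(a, b): q1[a] * q2[b] for a in (0, 1) for b in (0, 1)}

# Exact total correlation TC = KL(q(z) || q_bar(z))
tc_exact = sum(p * math.log(p / q_bar[z]) for z, p in q_joint.items() if p > 0)

def d_optimal(z):
    """Bayes-optimal P(label=1 | z) for equal class priors; label 1 means 'from q(z)'."""
    return q_joint[z] / (q_joint[z] + q_bar[z])

def logit(p):
    """log(p / (1 - p)), i.e. log D / (1 - D)."""
    return math.log(p / (1.0 - p))

# Density-ratio trick with exact expectation under q(z)
tc_ratio = sum(p * logit(d_optimal(z)) for z, p in q_joint.items())

# Monte Carlo version: average the discriminator logit over samples from q(z)
rng = random.Random(0)
states, weights = zip(*q_joint.items())
samples = rng.choices(states, weights=weights, k=20000)
tc_mc = sum(logit(d_optimal(z)) for z in samples) / len(samples)

print(f"TC exact:         {tc_exact:.4f}")
print(f"TC via D (exact): {tc_ratio:.4f}")
print(f"TC via D (MC):    {tc_mc:.4f}")
```

In a real implementation $D$ is a neural network, and $\log \frac{D(z)}{1-D(z)}$ is simply its pre-sigmoid logit. The TC estimate is then the mean discriminator logit over a batch of encoder samples. How accurate the estimate is depends on how close the trained $D$ gets to the optimum used here.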