Notation
- $p_\theta(z|x)$: intractable posterior
- $p_\theta(x|z)$: probabilistic decoder
- $q_\phi(z|x)$: recognition model, variational approximation to $p_\theta(z|x)$, also regarded as a probabilistic encoder
- $p_\theta(z)p_\theta(x|z)$: generative model
- $\phi$: variational parameters
- $\theta$: generative parameters
Abbreviation
- SGVB: Stochastic Gradient Variational Bayes
- AEVB: auto-encoding VB
- ML: maximum likelihood
- MAP: maximum a posteriori
Motivation
Problem
- How to perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distribution $p_\theta(z|x)$ and large datasets?
Existing Solution and Difficulty
- VB: involves the optimization of an approximation to the intractable posterior
- mean-field: requires analytical solutions of expectations w.r.t. the approximate posterior, which are also intractable in the general case
Contribution of this paper
- (1) SGVB estimator: an estimator of the variational lower bound
- yielded by a reparameterization of the variational lower bound
- simple & differentiable & unbiased
- straightforward to optimize using standard SG ascent techniques
- (2) AEVB algorithm
- using SGVB to optimize a recognition model that allows us to perform very efficient approximate posterior inference using simple ancestral sampling, which in turn allows us to efficiently learn the model parameters, without the need for expensive iterative inference schemes (such as MCMC) per datapoint.
- condition: i.i.d. dataset $X = \{x^{(i)}\}_{i=1}^{N}$ & continuous latent variable $z$ per datapoint
Methodology
assumption
- directed graphical models with continuous latent variables
- i.i.d. dataset with latent variables per datapoint
- where we would like to perform
- ML or MAP inference on the (global) parameters $\theta$
- variational inference on the latent variable $z$
- $p_\theta(z)$ and $p_\theta(x|z)$: both PDFs are differentiable almost everywhere w.r.t. both $\theta$ and $z$
target case
- intractability
- $p_\theta(x) = \int p_\theta(z) p_\theta(x|z) \, dz$: so we cannot evaluate or differentiate the marginal likelihood.
- $p_\theta(z|x) = p_\theta(x|z) p_\theta(z) / p_\theta(x)$: so the EM algorithm cannot be used.
- the required integrals for any reasonable mean-field VB algorithm are also intractable: so the mean-field VB algorithm cannot be used.
- these intractabilities appear even for moderately complicated likelihood functions, e.g. a neural network with a nonlinear hidden layer
- a large dataset
- batch optimization is too costly => update with minibatches or single datapoints
- sampling-based solutions are too slow, e.g. Monte Carlo EM, since they involve a typically expensive sampling loop per datapoint.
solution and application
- efficient approximate ML or MAP estimation for $\theta$ (full): Appendix F
- allows us to mimic the hidden random process and generate artificial data that resemble the real data
- efficient approximate posterior inference $p_\theta(z|x)$ for a choice of $\theta$
- useful for coding or data representation tasks
- efficient approximate marginal inference of $x$: Appendix D
- allows us to perform all kinds of inference tasks where a prior over $x$ is required, such as image denoising, inpainting, and super-resolution.
1. derivation of the variational bound
$$\log p_\theta(x^{(1)}, \cdots, x^{(N)}) = \sum_{i=1}^{N} \log p_\theta(x^{(i)})$$
Here we use $x_i$ to represent $x^{(i)}$.
$$
\begin{aligned}
\log p(x_i) &= \int_z q(z|x_i) \log p(x_i) \, dz \quad (q \text{ can be any distribution}) \\
&= \int_z q(z|x_i) \log \frac{p(z, x_i)}{p(z|x_i)} \, dz \\
&= \int_z q(z|x_i) \log \left[ \frac{q(z|x_i)}{p(z|x_i)} \cdot \frac{p(z, x_i)}{q(z|x_i)} \right] dz \\
&= \int_z q(z|x_i) \log \frac{q(z|x_i)}{p(z|x_i)} \, dz + \int_z q(z|x_i) \log \frac{p(z, x_i)}{q(z|x_i)} \, dz \\
&= D_{KL}[q_\phi(z|x_i) \| p_\theta(z|x_i)] + L_B(\theta, \phi; x_i)
\end{aligned}
$$
Because $D_{KL} \geq 0$, $L_B$ is called the (variational) lower bound, and we have:
$$\log p_\theta(x_i) \geq L_B(\theta, \phi; x_i)$$
$$
\begin{aligned}
L_B(\theta, \phi; x_i) &= E_{q_\phi(z|x)}[\log p_\theta(x, z) - \log q_\phi(z|x)] \\
&= \int_z q(z|x_i) \log \frac{p(z, x_i)}{q(z|x_i)} \, dz \\
&= \int_z q(z|x_i) \log \frac{p(x_i|z)\, p(z)}{q(z|x_i)} \, dz \\
&= -D_{KL}[q(z|x_i) \| p(z)] + \int_z q(z|x_i) \log p(x_i|z) \, dz \\
&= -D_{KL}[q_\phi(z|x_i) \| p_\theta(z)] + E_{q_\phi(z|x_i)}[\log p_\theta(x_i|z)]
\end{aligned}
$$
- $-D_{KL}[q_\phi(z|x_i) \| p_\theta(z)]$: acts as a regularizer
- $E_{q_\phi(z|x_i)}[\log p_\theta(x_i|z)]$: an expected negative reconstruction error
- TARGET: differentiate and optimize $L_B$ w.r.t. both $\phi$ and $\theta$. However, $\nabla_\phi L_B$ is problematic.
2. Solution-1 Naive Monte Carlo gradient estimator
- disadvantages: high variance & impractical for this purpose; the formula given in the original paper is:
$$\nabla_\phi E_{q_\phi(z)}[f(z)] = E_{q_\phi(z)}[f(z) \nabla_{q_\phi(z)} \log q_\phi(z)] \simeq \frac{1}{L} \sum_{l=1}^{L} f(z) \nabla_{q_\phi(z^{l})} \log q_\phi(z^{l}) \quad \text{where } z^{l} \sim q_\phi(z|x_i)$$
However, since I have not yet studied the theory of Monte Carlo gradient estimators, and judging from the formulas later in the VAE paper, I personally think the formula above should instead be the following (my math background is limited, so I am not sure whether the two are equivalent):
$$\nabla_\phi E_{q_\phi(z)}[f(z)] = E_{q_\phi(z)}[f(z) \nabla_{q_\phi(z)} \log q_\phi(z)] \simeq \frac{1}{L} \sum_{l=1}^{L} f(z^{l}) \quad \text{where } z^{l} \sim q_\phi(z|x_i)$$

3. Solution-2 SGVB estimator
- reparameterize $\tilde{z} \sim q_\phi(z|x)$ using a differentiable transformation $g_\phi(\epsilon, x)$ of an (auxiliary) noise variable $\epsilon$
- under certain mild conditions for a chosen approximate posterior $q_\phi(z|x)$
$$\tilde{z} = g_\phi(\epsilon, x) \quad \text{with} \quad \epsilon \sim p(\epsilon)$$
- form Monte Carlo estimates as follows:
$$q_\phi(z|x) \prod_i dz_i = p(\epsilon) \prod_i d\epsilon_i$$
$$\int q_\phi(z|x) f(z) \, dz = \int p(\epsilon) f(z) \, d\epsilon = \int p(\epsilon) f(g_\phi(\epsilon, x)) \, d\epsilon$$
$$E_{q_\phi(z|x_i)}[f(z)] = E_{p(\epsilon)}[f(g_\phi(\epsilon, x_i))] \simeq \frac{1}{L} \sum_{l=1}^{L} f(g_\phi(\epsilon_l, x_i)) \quad \text{where } \epsilon_l \sim p(\epsilon)$$
- apply the MC estimator technique to $L_B(\theta, \phi; x_i)$, yielding 2 SGVB estimators $\tilde{L}^A$ and $\tilde{L}^B$:
$$\tilde{L}^A(\theta, \phi; x_i) = \frac{1}{L} \sum_{l=1}^{L} \log p_\theta(x_i, z_{i,l}) - \log q_\phi(z_{i,l}|x_i)$$
$$\tilde{L}^B(\theta, \phi; x_i) = -D_{KL}(q_\phi(z|x_i) \| p_\theta(z)) + \frac{1}{L} \sum_{l=1}^{L} \log p_\theta(x_i|z_{i,l})$$
$$\text{where } z_{i,l} = g_\phi(\epsilon_{i,l}, x_i) \text{ and } \epsilon_l \sim p(\epsilon)$$
- minibatch with size $M$: $L(\theta, \phi; X) \simeq \tilde{L}^M(\theta, \phi; X^M) = \frac{N}{M} \sum_{i=1}^{M} \tilde{L}(\theta, \phi; x_i)$
- the number of samples $L$ per datapoint can be set to 1 as long as the minibatch size $M$ is large enough, e.g. $M = 100$ in the paper's experiments
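To make the estimator concrete, here is a minimal numpy sketch (not from the paper; `mu`, `sigma`, and the test function `f` are illustrative assumptions) that estimates $E_{q_\phi(z|x)}[f(z)]$ for a diagonal-Gaussian $q_\phi$ via $z = g_\phi(\epsilon, x) = \mu + \sigma \odot \epsilon$ and compares the result with the closed-form expectation:

```python
# Reparameterized Monte Carlo estimate E_q[f(z)] ~= (1/L) * sum_l f(g(eps_l, x))
# for q(z|x) = N(mu, sigma^2 I); mu, sigma, f are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)

mu = np.array([0.5, -1.0])      # pretend these are encoder outputs
sigma = np.array([0.3, 2.0])

def g(eps, mu, sigma):
    """Location-scale transform: z = mu + sigma * eps, with eps ~ N(0, I)."""
    return mu + sigma * eps

def f(z):
    """Test function f(z) = sum_j z_j^2, whose exact expectation under
    N(mu, sigma^2 I) is sum_j (mu_j^2 + sigma_j^2)."""
    return np.sum(z ** 2, axis=-1)

L = 100_000
eps = rng.standard_normal((L, mu.size))
mc_estimate = f(g(eps, mu, sigma)).mean()
exact = np.sum(mu ** 2 + sigma ** 2)
print(mc_estimate, exact)   # the two values should be close
```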
3.1 Minibatch version of AEVB algorithm
- $\theta, \phi \leftarrow$ Initialize parameters
- repeat
- $X^M \leftarrow$ Random minibatch of $M$ datapoints
- $\epsilon \leftarrow$ Random samples from noise distribution $p(\epsilon)$
- $g \leftarrow \nabla_{\theta,\phi} \tilde{L}^M(\theta, \phi; X^M, \epsilon)$
- $\theta, \phi \leftarrow$ Update parameters using gradients $g$ (e.g. SGD or Adagrad)
- until convergence of parameters
- return $\theta, \phi$
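A minimal PyTorch sketch of this minibatch loop, assuming a hypothetical `model(x_batch)` that returns the lower-bound estimate for the batch (e.g. built from the $\tilde{L}^B$ estimator above); `train_aevb`, the dataset handling, and the hyperparameters are placeholders, not from the paper:

```python
# Minimal AEVB-style minibatch training loop (sketch, not the paper's code).
# Assumes `model(x_batch)` returns the scalar lower-bound estimate for the
# minibatch, drawing the noise eps ~ p(eps) internally via reparameterization.
import torch
from torch.utils.data import DataLoader, TensorDataset

def train_aevb(model, data, M=100, epochs=10, lr=0.01):
    loader = DataLoader(TensorDataset(data), batch_size=M, shuffle=True)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)
    for _ in range(epochs):             # "repeat ... until convergence of parameters"
        for (x_batch,) in loader:       # X^M <- random minibatch of M datapoints
            optimizer.zero_grad()
            loss = -model(x_batch)      # maximize the lower bound <=> minimize its negative
            loss.backward()             # g <- gradients w.r.t. (theta, phi)
            optimizer.step()            # update parameters using gradients g
    return model
```

Adagrad is used here only because the algorithm lists it as one option; plain SGD or any other first-order optimizer fits the same loop.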
choice of $q_\phi(z|x)$, $p(\epsilon)$, and $g_\phi(\epsilon, x)$
There are three basic approaches:
- Tractable inverse CDF. Let $\epsilon \sim U(0, I)$ and let $g_\phi(\epsilon, x)$ be the inverse CDF of $q_\phi(z|x)$.
- Examples: Exponential, Cauchy, Logistic, Rayleigh, Pareto, Weibull, Reciprocal, Gompertz, Gumbel and Erlang distributions.
- "location-scale" family of distributions: choose the standard distribution (with location = 0, scale = 1) as the auxiliary variable $\epsilon$, and let $g(\cdot) = \text{location} + \text{scale} \cdot \epsilon$
- Examples: Laplace, Elliptical, Student’s t, Logistic, Uniform, Triangular and Gaussian distributions.
- The VAE introduced below falls into this case.
- Composition: It is often possible to express random variables as different transformations of auxiliary variables.
- Examples: Log-Normal (exponentiation of normally distributed variable), Gamma (a sum over exponentially distributed variables), Dirichlet (weighted sum of Gamma variates), Beta, Chi-Squared, and F distributions.
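A small numpy sketch illustrating the three approaches (the specific distributions and parameter values are arbitrary assumptions, chosen only because their means are easy to check):

```python
# Three ways to reparameterize z as a deterministic transform of auxiliary noise.
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# 1) Tractable inverse CDF: Exponential(rate) via z = F^{-1}(eps), eps ~ U(0, 1)
rate = 2.0
eps_u = rng.uniform(size=n)
z_exp = -np.log(1.0 - eps_u) / rate          # inverse CDF of Exponential(rate)
print(z_exp.mean(), 1.0 / rate)              # both ~0.5

# 2) Location-scale: N(mu, sigma^2) via z = mu + sigma * eps, eps ~ N(0, 1)
mu, sigma = 1.5, 0.7
eps_n = rng.standard_normal(n)
z_gauss = mu + sigma * eps_n
print(z_gauss.mean(), mu)                    # both ~1.5

# 3) Composition: Log-Normal as exponentiation of a reparameterized Gaussian
z_lognormal = np.exp(z_gauss)
print(z_lognormal.mean(), np.exp(mu + 0.5 * sigma ** 2))  # both ~exp(mu + sigma^2/2)
```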
Variational Auto-Encoder
- let $p_\theta(z) = \mathcal{N}(z; 0, I)$
- $p_\theta(z|x)$ is intractable
- use a neural network for $q_\phi(z|x)$
- $\phi$ and $\theta$ are optimized jointly with the AEVB algorithm
- parameters of $p_\theta(x|z)$ are computed from $z$ with an MLP (multi-layer perceptron, a fully-connected neural network with a single hidden layer)
- multivariate Gaussian: in case of real-valued data
- Bernoulli: in case of binary data (see the sketch after the formula below), e.g.
$$\log p(x|z) = \sum_{i=1}^{D} x_i \log y_i + (1 - x_i) \cdot \log(1 - y_i)$$
$$\text{where } y = f_\sigma(W_y \tanh(W_h z + b_h) + b_y), \quad f_\sigma(\cdot): \text{elementwise sigmoid activation function}$$
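A minimal PyTorch sketch of the Bernoulli MLP decoder above; the layer sizes are arbitrary assumptions, and the two linear layers play the roles of $(W_h, b_h)$ and $(W_y, b_y)$:

```python
# Bernoulli MLP decoder: log p(x|z) = sum_i x_i*log y_i + (1 - x_i)*log(1 - y_i),
# with y = sigmoid(W_y tanh(W_h z + b_h) + b_y). Layer sizes are illustrative.
import torch
import torch.nn as nn

class BernoulliDecoder(nn.Module):
    def __init__(self, z_dim=2, h_dim=200, x_dim=784):
        super().__init__()
        self.hidden = nn.Linear(z_dim, h_dim)   # W_h, b_h
        self.out = nn.Linear(h_dim, x_dim)      # W_y, b_y

    def log_prob(self, x, z):
        y = torch.sigmoid(self.out(torch.tanh(self.hidden(z))))
        y = y.clamp(1e-7, 1 - 1e-7)             # numerical safety for the logs
        return (x * torch.log(y) + (1 - x) * torch.log(1 - y)).sum(dim=-1)
```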
From what was mentioned above, we have:
$$\log q_\phi(z|x_i) = \log \mathcal{N}(z; \mu_i, \sigma_i^2 I)$$
where $\mu_i$ and $\sigma_i$ are outputs of the encoding MLP. We sample from $z_{i,l} \sim q_\phi(z|x_i)$ using $z_{i,l} = g_\phi(x_i, \epsilon_l) = \mu_i + \sigma_i \odot \epsilon_l$, where $\epsilon_l \sim \mathcal{N}(0, I)$ and $\odot$ denotes element-wise product.
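A matching sketch of the Gaussian MLP encoder and the reparameterized sample $z_{i,l} = \mu_i + \sigma_i \odot \epsilon_l$ (layer sizes and the choice to predict $\log \sigma^2$ rather than $\sigma$ are assumptions made for numerical convenience):

```python
# Gaussian MLP encoder q_phi(z|x) = N(z; mu(x), diag(sigma(x)^2)) plus the
# reparameterized sample z = mu + sigma * eps with eps ~ N(0, I).
import torch
import torch.nn as nn

class GaussianEncoder(nn.Module):
    def __init__(self, x_dim=784, h_dim=200, z_dim=2):
        super().__init__()
        self.hidden = nn.Linear(x_dim, h_dim)
        self.mu = nn.Linear(h_dim, z_dim)
        self.log_var = nn.Linear(h_dim, z_dim)   # predicts log sigma^2

    def forward(self, x):
        h = torch.tanh(self.hidden(x))
        return self.mu(h), self.log_var(h)

def reparameterize(mu, log_var):
    sigma = torch.exp(0.5 * log_var)
    eps = torch.randn_like(sigma)                # eps ~ N(0, I)
    return mu + sigma * eps                      # element-wise product
```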
Here both $p_\theta(z)$ and $q_\phi(z|x)$ are Gaussian, so we can use the estimator $\tilde{L}^B$, since the KL divergence term is analytical. Then we have:
$$
\begin{aligned}
L(\theta, \phi; x_i) &\simeq -D_{KL}(q_\phi(z|x_i) \| p_\theta(z)) + \frac{1}{L} \sum_{l=1}^{L} \log p_\theta(x_i|z_{i,l}) \\
&\simeq \frac{1}{2} \sum_{j=1}^{J} \left( 1 + \log(\sigma_{i,j}^2) - \mu_{i,j}^2 - \sigma_{i,j}^2 \right) + \frac{1}{L} \sum_{l=1}^{L} \log p_\theta(x_i|z_{i,l})
\end{aligned}
$$
$$\text{where } z_{i,l} = \mu_i + \sigma_i \odot \epsilon_l \text{ and } \epsilon_l \sim \mathcal{N}(0, I)$$
Solution of $-D_{KL}(q_\phi(z) \| p_\theta(z))$, Gaussian case
Let $J$ be the dimensionality of $z$, then we have:
$$\int q_\theta(z) \log p_\theta(z) \, dz = \int \mathcal{N}(z; \mu, \sigma^2) \log \mathcal{N}(z; 0, I) \, dz = -\frac{J}{2} \log(2\pi) - \frac{1}{2} \sum_{j=1}^{J} (\mu_j^2 + \sigma_j^2)$$
And:
$$\int q_\theta(z) \log q_\theta(z) \, dz = \int \mathcal{N}(z; \mu, \sigma^2) \log \mathcal{N}(z; \mu, \sigma^2) \, dz = -\frac{J}{2} \log(2\pi) - \frac{1}{2} \sum_{j=1}^{J} (1 + \log \sigma_j^2)$$
Therefore:
$$-D_{KL}(q_\phi(z) \| p_\theta(z)) = \int q_\theta(z) \left( \log p_\theta(z) - \log q_\theta(z) \right) dz = \frac{1}{2} \sum_{j=1}^{J} \left( 1 + \log(\sigma_j^2) - \mu_j^2 - \sigma_j^2 \right)$$
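As a quick sanity check on this closed form, the following numpy sketch (the values of $\mu$ and $\sigma$ are arbitrary) compares the analytic expression with a Monte Carlo estimate of $E_q[\log p_\theta(z) - \log q_\phi(z)]$:

```python
# Numerical check of -D_KL( N(mu, sigma^2 I) || N(0, I) )
# analytic: 0.5 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2)
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([0.3, -1.2, 0.8])
sigma = np.array([0.5, 1.5, 0.9])

analytic = 0.5 * np.sum(1 + np.log(sigma**2) - mu**2 - sigma**2)

def log_normal(z, m, s):
    """Log density of a diagonal Gaussian N(m, s^2 I), summed over dimensions."""
    return np.sum(-0.5 * np.log(2 * np.pi * s**2) - (z - m)**2 / (2 * s**2), axis=-1)

# Monte Carlo: E_q[ log p(z) - log q(z) ] with reparameterized z = mu + sigma * eps
eps = rng.standard_normal((200_000, mu.size))
z = mu + sigma * eps
mc = np.mean(log_normal(z, 0.0, 1.0) - log_normal(z, mu, sigma))

print(analytic, mc)   # the two numbers should agree closely
```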
Visualization
Since the prior of the latent space is Gaussian, linearly spaced coordinates on the unit square were transformed through the inverse CDF of the Gaussian to produce values of the latent variables $z$.
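A sketch of that visualization for a 2-D latent space, assuming a hypothetical `decode(z)` callable that maps a latent code to pixel means (e.g. the Bernoulli decoder above); the grid size and image shape are assumptions:

```python
# Latent-space manifold: map a linearly spaced unit-square grid through the
# inverse Gaussian CDF to get z values, then decode each z to an image.
import numpy as np
import torch
from scipy.stats import norm

def plot_manifold(decode, n=20, img_shape=(28, 28)):
    """`decode(z)` is assumed to map a (1, 2) latent tensor to pixel means."""
    grid = np.linspace(0.05, 0.95, n)       # linearly spaced coords, endpoints avoided
    canvas = np.zeros((n * img_shape[0], n * img_shape[1]))
    for i, u in enumerate(grid):
        for j, v in enumerate(grid):
            z = torch.tensor([[norm.ppf(u), norm.ppf(v)]], dtype=torch.float32)
            with torch.no_grad():
                y = decode(z).reshape(img_shape).numpy()
            canvas[i * img_shape[0]:(i + 1) * img_shape[0],
                   j * img_shape[1]:(j + 1) * img_shape[1]] = y
    return canvas   # display with e.g. matplotlib.pyplot.imshow(canvas, cmap="gray")
```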