12 变分推断(Variational Inference)
12.1 背景介绍
这一小节的主要目的:为什么要使用Variational Inference,Inference到底有什么用。机器学习,我们可以从频率角度和贝叶斯角度两个角度来看,其中频率角度可以被解释为优化问题,贝叶斯角度可以被解释为积分问题。
12.1.1 频 率 角 度 → 优 化 问 题 频率角度\rightarrow优化问题 频率角度→优化问题
为什么说频率派角度的分析是一个优化问题呢?从回归和SVM两个例子上进行分析。数据集描述为: D = { ( x i , y i ) } i = 1 N , x i ∈ R p , y i ∈ R 1 D = \{ (x_i,y_i) \}_{i=1}^N,x_i \in \mathbb{R}^p,y_i \in \mathbb{R}^1 D={(xi,yi)}i=1N,xi∈Rp,yi∈R1。
- 回归问题
- 回归模型定义: f ( w ) = w T x . (12.1.1) f(w) = w^Tx.\tag{12.1.1} f(w)=wTx.(12.1.1)
- 回归模型策略:
- 其中loss function被定义为:
L ( w ) = ∑ i = 1 N ∣ ∣ w T x i − y i ∣ ∣ 2 (12.1.2) L(w) = \sum_{i=1}^N || w^Tx_i - y_i ||^2\tag{12.1.2} L(w)=i=1∑N∣∣wTxi−yi∣∣2(12.1.2) - 优化可以表达为
w
^
=
a
r
g
m
i
n
L
(
w
)
(12.1.3)
\hat{w} = argmin\ L(w)\tag{12.1.3}
w^=argmin L(w)(12.1.3)
这是个无约束优化问题。
- 其中loss function被定义为:
- 回归模型求解方法可以分成两种:数值解和解析解。
- 解析解的解法为:
∂ L ( w ) ∂ w = 0 ⇒ w ∗ = ( X T X ) − 1 X T Y (12.1.3) \frac{\partial L(w)}{\partial w} = 0 \Rightarrow w^{\ast} = (X^TX)^{-1}X^TY\tag{12.1.3} ∂w∂L(w)=0⇒w∗=(XTX)−1XTY(12.1.3)
其中, X X X是一个 N × p N\times p N×p的矩阵。 - 数值解常用的是GD算法,也就是 G r a d i e n t D e s c e n t Gradient\;Descent GradientDescent,或者 S t o c h a s t i c G r a d i e n t d e s c e n t ( S G D ) Stochastic\;Gradient\;descent (SGD) StochasticGradientdescent(SGD)。
- 解析解的解法为:
- SVM(分类问题)
- SVM的模型: f ( w ) = s i g n ( w T x + b ) (12.1.4) f(w) = sign(w^Tx+b)\tag{12.1.4} f(w)=sign(wTx+b)(12.1.4)
- SVM的策略:
loss function为:
{ min 1 2 w T w s . t . y i ( w T x i + b ) ≥ 1 (12.1.5) \left\{ \begin{array}{ll} \min\ \frac{1}{2}w^Tw & \\ s.t. \quad y_i(w^Tx_i + b) \geq 1 & \\ \end{array}\right.\tag{12.1.5} {min 21wTws.t.yi(wTxi+b)≥1(12.1.5)
这是一个有约束的Convex优化问题。 - SVM求解方法:
常用的解决条件为:拉格朗日乘子法、QP方法和Lagrange 对偶。
- EM算法
- EM优化目标为:
θ ^ = arg max log P ( x ∣ θ ) (12.1.6) \hat{\theta} = \arg\max\ \log P(x|\theta)\tag{12.1.6} θ^=argmax logP(x∣θ)(12.1.6) - EM优化的迭代算法为:
θ ( t + 1 ) = arg max θ ∫ z log P ( x , z ∣ θ ) ⋅ p ( z ∣ x , θ ( t ) ) d z (12.1.7) \theta^{(t+1)} = \arg\underset{\theta}{\max}\int_{z} \log P(x,z|\theta)\cdot p(z|x,\theta^{(t)}) dz\tag{12.1.7} θ(t+1)=argθmax∫zlogP(x,z∣θ)⋅p(z∣x,θ(t))dz(12.1.7)
- EM优化目标为:
12.1.2 贝 叶 斯 角 度 → 积 分 问 题 贝叶斯角度\rightarrow积分问题 贝叶斯角度→积分问题
从贝叶斯的角度来说,这就是一个积分问题,为什么呢?从Bayes公式的表达看:
P
(
θ
∣
x
)
=
P
(
x
∣
θ
)
P
(
θ
)
P
(
x
)
(12.1.8)
P(\theta|x) = \frac{P(x|\theta)P(\theta)}{P(x)}\tag{12.1.8}
P(θ∣x)=P(x)P(x∣θ)P(θ)(12.1.8)
其中, P ( θ ∣ x ) P(\theta|x) P(θ∣x)称为后验公式, P ( x ∣ θ ) P(x|\theta) P(x∣θ)称为似然函数, P ( θ ) P(\theta) P(θ)称为先验分布, P ( x ) P(x) P(x)为已知的概率分布,并且 P ( x ) = ∫ θ P ( x ∣ θ ) P ( θ ) d θ P(x) = \int_{\theta}P(x|\theta)P(\theta)d\theta P(x)=∫θP(x∣θ)P(θ)dθ。贝叶斯角度分为 推 断 ( I n f e r e n c e ) \color{red}推断(Inference) 推断(Inference)和 决 策 \color{red}决策 决策。
- 贝叶斯推断(inference)(求后验
P
(
θ
∣
x
)
P(\theta|x)
P(θ∣x))
什么是推断呢?通俗的说就是求解后验分布 P ( θ ∣ x ) P(\theta|x) P(θ∣x),求解推断可以分为: 精 确 推 断 \color{red}精确推断 精确推断和 近 似 推 断 \color{red}近似推断 近似推断。- 精确推断
直接求解 P ( θ ∣ x ) P(\theta|x) P(θ∣x)。 - 近似推断
P ( θ ∣ x ) P(\theta|x) P(θ∣x)的计算在高维空间的时候非常的复杂,通常不能直接精确的求得,需要采用方法来求一个近似的解。- 确定性近似推断
变分推断(VI) - 随机近似推断
MCMC、MH、Gibbs
- 确定性近似推断
- 精确推断
- 贝叶斯决策
数据集 X X X(N个样本)。我们用数学的语言来表述也就是, x ~ \widetilde{x} x 为新的样本,求 p ( x ~ ∣ X ) p(\widetilde{x}|X) p(x ∣X):
l l P ( x ~ ∣ X ) = ∫ θ P ( x ~ , θ ∣ X ) d θ = ∫ θ P ( x ~ ∣ θ ) ⋅ P ( θ ∣ X ) d θ ( P ( θ ∣ X ) 为 公 式 ( 12.1.8 ) 中 的 后 验 ) = E θ ∣ X [ P ( x ^ ∣ θ ) ] (12.1.9) \begin{aligned}{ll}P(\widetilde{x}|X) & = \int_{\theta} P(\widetilde{x},\theta|X) d\theta \\ & = \int_{\theta} P(\widetilde{x}|\theta)\cdot P(\theta|X)d\theta\color{green}{(P(\theta|X)为公式(12.1.8)中的后验)}\\ & = \mathbf{E}_{\theta|X} [P(\hat{x}|\theta)]\end{aligned}\tag{12.1.9} llP(x ∣X)=∫θP(x ,θ∣X)dθ=∫θP(x ∣θ)⋅P(θ∣X)dθ(P(θ∣X)为公式(12.1.8)中的后验)=Eθ∣X[P(x^∣θ)](12.1.9)
本章主讲:
贝 叶 斯 角 度 → 贝 叶 斯 推 断 → 近 似 推 断 → 确 定 性 近 似 推 断 → 变 分 推 断 \color{red}贝叶斯角度\rightarrow贝叶斯推断\rightarrow近似推断\rightarrow确定性近似推断\rightarrow变分推断 贝叶斯角度→贝叶斯推断→近似推断→确定性近似推断→变分推断
12.2 公式推导
-
数据
有以下数据:- X : o b s e r v e d v a r i a b l e → X : { x i } i = 1 N X:observed\;variable\rightarrow X:\left \{x_{i}\right \}_{i=1}^{N} X:observedvariable→X:{xi}i=1N
- Z : l a t e n t v a r i a b l e + p a r a m e t e r → Z : { z i } i = 1 N Z:latent\;variable + parameter\rightarrow Z:\left \{z_{i}\right \}_{i=1}^{N} Z:latentvariable+parameter→Z:{zi}i=1N
- ( X , Z ) : c o m p l e t e d a t a (X,Z):complete\;data (X,Z):completedata
记 z z z为隐变量和参数的集合。接着变换概率 p ( x ) p(x) p(x)的形式然后引入分布 q ( z ) q(z) q(z):
l o g p ( x ) = l o g p ( x , z ) − l o g p ( z ∣ x ) = l o g p ( x , z ) q ( z ) − l o g p ( z ∣ x ) q ( z ) (12.2.1) \color{blue}log\; p(x)=log\; p(x,z)-log\; p(z|x)=log\; \frac{p(x,z)}{q(z)}-log\; \frac{p(z|x)}{q(z)}\tag{12.2.1} logp(x)=logp(x,z)−logp(z∣x)=logq(z)p(x,z)−logq(z)p(z∣x)(12.2.1) -
公式简化
对公式(12.2.1)进行简化,式子两边同时对 q ( z ) q(z) q(z)求积分(期望):
左 边 = ∫ z q ( z ) ⋅ l o g p ( x ∣ θ ) d z = l o g p ( x ∣ θ ) ∫ z q ( z ) d z = l o g p ( x ∣ θ ) (12.2.2) 左边=\int _{z}q(z)\cdot log\; p(x |\theta )\mathrm{d}z=log\; p(x|\theta )\int _{z}q(z )\mathrm{d}z=log\; p(x|\theta )\tag{12.2.2} 左边=∫zq(z)⋅logp(x∣θ)dz=logp(x∣θ)∫zq(z)dz=logp(x∣θ)(12.2.2)
右 边 = ∫ z q ( z ) l o g p ( x , z ∣ θ ) q ( z ) d z ⏟ E L B O ( e v i d e n c e l o w e r b o u n d ) − ∫ z q ( z ) l o g p ( z ∣ x , θ ) q ( z ) d z ⏟ K L ( q ( z ) ∣ ∣ p ( z ∣ x , θ ) ) = L ( q ) ⏟ 变 分 + K L ( q ∣ ∣ p ) ⏟ ≥ 0 (12.2.3) 右边=\underset{ELBO(evidence\; lower\; bound)}{\underbrace{\int _{z}q(z)log\; \frac{p(x,z|\theta )}{q(z)}\mathrm{d}z}}\underset{KL(q(z)||p(z|x,\theta ))}{\underbrace{-\int _{z}q(z)log\; \frac{p(z|x,\theta )}{q(z)}\mathrm{d}z}}=\underset{变分}{\underbrace{L(q)}} + \underset{\geq 0}{\underbrace{KL(q||p)}}\tag{12.2.3} 右边=ELBO(evidencelowerbound) ∫zq(z)logq(z)p(x,z∣θ)dzKL(q(z)∣∣p(z∣x,θ)) −∫zq(z)logq(z)p(z∣x,θ)dz=变分 L(q)+≥0 KL(q∣∣p)(12.2.3)
Evidence Lower Bound (ELBO)是变分, L ( q ) L(q) L(q)和 K L ( q ∣ ∣ p ) KL(q||p) KL(q∣∣p)被记为:
{ L ( q ) = ∫ z q ( z ) log p ( x , z ∣ θ ) q ( z ) d z K L ( q ∣ ∣ p ) = − ∫ z q ( z ) log p ( z ∣ x ) q ( z ) d z \color{blue}\{ \begin{array}{ll}L(q)&=\int_z q(z)\log\ \frac{p(x,z|\theta)}{q(z)}dz\\ KL(q||p)&= - \int_z q(z)\log\ \frac{p(z|x)}{q(z)}dz \end{array} {L(q)KL(q∣∣p)=∫zq(z)log q(z)p(x,z∣θ)dz=−∫zq(z)log q(z)p(z∣x)dz
p ( x ) p(x) p(x)是个定值,我们的目的是寻找一个使得 q ( z ) q(z) q(z)与 p ( z ∣ x , θ ) p(z|x,\theta) p(z∣x,θ)更接近,也就是使 K L ( q ∣ ∣ p ) KL(q||p) KL(q∣∣p)越小越好,也就是要使 L ( q ) L(q) L(q)越大越好:
q ~ ( z ) = arg max q ( z ) L ( q ) ⇒ q ~ ( z ) ≈ p ( z ∣ x ) (12.2.4) \color{blue}\tilde{q}(z)=\arg\underset{q(z)}{\max}\; L(q)\Rightarrow \tilde{q}(z)\approx p(z|x)\tag{12.2.4} q~(z)=argq(z)maxL(q)⇒q~(z)≈p(z∣x)(12.2.4)- L ( q ) \color{red}L(q) L(q)并非普通的函数,而是以函数 q q q为自变量的函数,这就是 泛 函 \color{red}泛函 泛函。泛函可以看成是函数概念的推广,而变分方法是处理泛函的数学领域,和处理函数的普通微积分相对。
- 变 分 法 最 终 寻 求 的 是 极 值 函 数 : 它 们 使 得 泛 函 取 得 极 大 或 极 小 值 。 \color{red}变分法最终寻求的是极值函数:它们使得泛函取得极大或极小值。 变分法最终寻求的是极值函数:它们使得泛函取得极大或极小值。
-
模型求解
平均场理论:把多维变量的不同维度分为 M M M组,组与组之间是相互独立的:
q ( z ) = ∏ i = 1 M q i ( z i ) (12.2.5) \color{red}q(z)=\prod_{i=1}^{M}q_{i}(z_{i})\tag{12.2.5} q(z)=i=1∏Mqi(zi)(12.2.5)在这种分解的思想中,我们每次只考虑第 j \color{blue}j j个分布,那么令 q i ( 1 , 2 , ⋯ , j − 1 , j + 1 , ⋯ , M ) \color{blue}q_i(1,2,\cdots,j-1,j+1,\cdots,M) qi(1,2,⋯,j−1,j+1,⋯,M)个分布 f i x e d \color{blue}fixed fixed。将 L ( q ) L(q) L(q)写作两部分:
L ( q ) = ∫ z q ( z ) l o g p ( x , z ) d z ⏟ ① − ∫ z q ( z ) l o g q ( z ) d z ⏟ ② (12.2.6) L(q)=\underset{①}{\underbrace{\int _{z}q(z)log\; p(x,z)\mathrm{d}z}}-\underset{②}{\underbrace{\int _{z}q(z)log\; q(z)\mathrm{d}z}}\tag{12.2.6} L(q)=① ∫zq(z)logp(x,z)dz−② ∫zq(z)logq(z)dz(12.2.6)- 对于①:
① = ∫ z q ( z ) l o g p ( x , z ) d z = ∫ z ∏ i = 1 M q i ( z i ) l o g p ( x , z ) d z 1 d z 2 ⋯ d z M = ∫ z j q j ( z j ) ( ∫ z 1 ∫ z 2 ⋯ ∫ z M ∏ i ≠ j M q i ( z i ) l o g p ( x , z ) d z 1 d z 2 ⋯ d z M ( i ≠ j ) ) ⏟ ∫ z − z j l o g p ( x , z ) ∏ i ≠ j M q i ( z i ) d z i d z j = ∫ z j q j ( z j ) ⋅ E ∏ i ≠ j M q i ( z i ) [ l o g p ( x , z ) ] ⋅ d z j (12.2.7) \begin{aligned}①&=\int _{z}q(z)log\; p(x,z)\mathrm{d}z\\ &=\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; p(x,z)\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ &=\int _{z_{j}}q_{j}(z_{j})\underset{\int _{z-z_{j}}log\; p(x,z)\prod_{i\neq j}^{M}q_{i}(z_{i})\mathrm{d}z_{i}}{\underbrace{\left (\int_{z_1}\int_{z_2}\cdots\int_{z_M}\prod_{i\neq j}^{M}q_{i}(z_{i})log\; p(x,z)\underset{(i\neq j)}{\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}}\right )}}\mathrm{d}z_{j}\\ &=\int _{z_{j}}q_{j}(z_{j})\cdot E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p(x,z)]\cdot \mathrm{d}z_{j}\end{aligned}\tag{12.2.7} ①=∫zq(z)logp(x,z)dz=∫zi=1∏Mqi(zi)logp(x,z)dz1dz2⋯dzM=∫zjqj(zj)∫z−zjlogp(x,z)∏i=jMqi(zi)dzi ⎝⎛∫z1∫z2⋯∫zMi=j∏Mqi(zi)logp(x,z)(i=j)dz1dz2⋯dzM⎠⎞dzj=∫zjqj(zj)⋅E∏i=jMqi(zi)[logp(x,z)]⋅dzj(12.2.7)
因为我们仅仅只关注第 j j j项,其他的项都不关注。为了进一步表达计算,我们将:
E ∏ i ≠ j M q i ( z i ) [ log p ( x , z ) ] = log p ^ ( x , z j ) (12.2.8) \mathbf{E}_{\prod_{i \neq j}^Mq_i(z_i)}\left[ \log p(x,z) \right] = \log \hat{p}(x,z_j)\tag{12.2.8} E∏i=jMqi(zi)[logp(x,z)]=logp^(x,zj)(12.2.8)
那么(12.2.7)式可以写作:
① = ∫ z q ( z ) l o g p ( x , z ) d z = ∫ z j q j ( z j ) log p ^ ( x , z j ) d z j (12.2.9) \color{red}\begin{aligned}①&=\int _{z}q(z)log\; p(x,z)\mathrm{d}z\\ & = \int_{z_j}q_j(z_j) \log \hat{p}(x,z_j) dz_j\end{aligned}\tag{12.2.9} ①=∫zq(z)logp(x,z)dz=∫zjqj(zj)logp^(x,zj)dzj(12.2.9)
这里的 p ^ ( x , z j ) \hat{p}(x,z_j) p^(x,zj)表示为一个相关的函数形式,假设具体参数未知。 - 对于②:
② = ∫ z q ( z ) l o g q ( z ) d z = ∫ z ∏ i = 1 M q i ( z i ) ∑ i = 1 M l o g q i ( z i ) d z = ∫ z ∏ i = 1 M q i ( z i ) [ l o g q 1 ( z 1 ) + l o g q 2 ( z 2 ) + ⋯ + l o g q M ( z M ) ] d z (12.2.10) \begin{aligned}②&=\int _{z}q(z)log\; q(z)\mathrm{d}z\\ &=\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})\sum_{i=1}^{M}log\; q_{i}(z_{i})\mathrm{d}z\\ &=\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})[log\; q_{1}(z_{1})+log\; q_{2}(z_{2})+\cdots +log\; q_{M}(z_{M})]\mathrm{d}z\end{aligned}\tag{12.2.10} ②=∫zq(z)logq(z)dz=∫zi=1∏Mqi(zi)i=1∑Mlogqi(zi)dz=∫zi=1∏Mqi(zi)[logq1(z1)+logq2(z2)+⋯+logqM(zM)]dz(12.2.10)
对其中第一项进行处理:
∫ z ∏ i = 1 M q i ( z i ) l o g q 1 ( z 1 ) d z = ∫ z 1 z 2 ⋯ z M q 1 ( z 1 ) q 2 ( z 2 ) ⋯ q M ( z M ) ⋅ l o g q 1 ( z 1 ) d z 1 d z 2 ⋯ d z M = ∫ z 1 q 1 ( z 1 ) l o g q 1 ( z 1 ) d z 1 ⋅ ∫ z 2 q 2 ( z 2 ) d z 2 ⏟ = 1 ⋅ ∫ z 3 q 3 ( z 3 ) d z 3 ⏟ = 1 ⋯ ∫ z M q M ( z M ) d z M ⏟ = 1 = ∫ z 1 q 1 ( z 1 ) l o g q 1 ( z 1 ) d z 1 (12.2.11) \begin{aligned} &\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; q_{1}(z_{1})\mathrm{d}z\\ & =\int _{z_{1}z_{2}\cdots z_{M}}q_{1}(z_{1})q_{2}(z_{2})\cdots q_{M}(z_{M})\cdot log\; q_{1}(z_{1})\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ &=\int _{z_{1}}q_{1}(z_{1})log\; q_{1}(z_{1})\mathrm{d}z_{1}\cdot \underset{=1}{\underbrace{\int _{z_{2}}q_{2}(z_{2})\mathrm{d}z_{2}}}\cdot \underset{=1}{\underbrace{\int _{z_{3}}q_{3}(z_{3})\mathrm{d}z_{3}}}\cdots \underset{=1}{\underbrace{\int _{z_{M}}q_{M}(z_{M})\mathrm{d}z_{M}}}\\ &=\int _{z_{1}}q_{1}(z_{1})log\; q_{1}(z_{1})\mathrm{d}z_{1}\end{aligned}\tag{12.2.11} ∫zi=1∏Mqi(zi)logq1(z1)dz=∫z1z2⋯zMq1(z1)q2(z2)⋯qM(zM)⋅logq1(z1)dz1dz2⋯dzM=∫z1q1(z1)logq1(z1)dz1⋅=1 ∫z2q2(z2)dz2⋅=1 ∫z3q3(z3)dz3⋯=1 ∫zMqM(zM)dzM=∫z1q1(z1)logq1(z1)dz1(12.2.11)
也就是说:
∫ z ∏ i = 1 M q i ( z i ) l o g q k ( z k ) d z = ∫ z k q k ( z k ) l o g q k ( z k ) d z k (12.2.12) \int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; q_{k}(z_{k})\mathrm{d}z=\int _{z_{k}}q_{k}(z_{k})log\; q_{k}(z_{k})\mathrm{d}z_{k}\tag{12.2.12} ∫zi=1∏Mqi(zi)logqk(zk)dz=∫zkqk(zk)logqk(zk)dzk(12.2.12)
则:
② = ∫ z q ( z ) l o g q ( z ) d z = ∑ i = 1 M ∫ z i q i ( z i ) l o g q i ( z i ) d z i = ∫ z j q j ( z j ) l o g q j ( z j ) d z j + C (12.2.13) \color{red}\begin{aligned}②&=\int _{z}q(z)log\; q(z)\mathrm{d}z\\&=\sum_{i=1}^{M}\int _{z_{i}}q_{i}(z_{i})log\; q_{i}(z_{i})\mathrm{d}z_{i}\\ &=\int _{z_{j}}q_{j}(z_{j})log\; q_{j}(z_{j})\mathrm{d}z_{j}+C\end{aligned}\tag{12.2.13} ②=∫zq(z)logq(z)dz=i=1∑M∫ziqi(zi)logqi(zi)dzi=∫zjqj(zj)logqj(zj)dzj+C(12.2.13) -
L
(
q
)
L(q)
L(q)可以写成:
L ( q ) = ∫ z q ( z ) l o g p ( x , z ) d z ⏟ ① − ∫ z q ( z ) l o g q ( z ) d z ⏟ ② = ∫ z j q j ( z j ) log p ^ ( x , z j ) d z j − ∫ z j q j ( z j ) l o g q j ( z j ) d z j + C = − K L ( q j ∣ ∣ p ^ ( x , z j ) ) + C (12.2.14) \begin{aligned}L(q)&=\underset{①}{\underbrace{\int _{z}q(z)log\; p(x,z)\mathrm{d}z}}-\underset{②}{\underbrace{\int _{z}q(z)log\; q(z)\mathrm{d}z}}\\ &= \int_{z_j}q_j(z_j) \log \hat{p}(x,z_j) dz_j - \int _{z_{j}}q_{j}(z_{j})log\; q_{j}(z_{j})\mathrm{d}z_{j}+C\\ &=-KL(q_j || \hat{p}(x,z_j)) +C \end{aligned}\tag{12.2.14} L(q)=① ∫zq(z)logp(x,z)dz−② ∫zq(z)logq(z)dz=∫zjqj(zj)logp^(x,zj)dzj−∫zjqj(zj)logqj(zj)dzj+C=−KL(qj∣∣p^(x,zj))+C(12.2.14)
其中 − K L ( q j ∣ ∣ p ^ ( x , z j ) ) ≤ 0 -KL(q_j || \hat{p}(x,z_j)) \leq 0 −KL(qj∣∣p^(x,zj))≤0,根据公式(12.2.4)可得:
q ~ ( z ) = arg max q ( z ) L ( q ) = arg max q j ( z j ) − K L ( q j ∣ ∣ p ^ ( x , z j ) ) = arg min q j ( z j ) K L ( q j ∣ ∣ p ^ ( x , z j ) ) (12.2.15) \color{red}\begin{aligned}\tilde{q}(z)&=\arg\underset{q(z)}{\max}\; L(q)\\ & = \arg\underset{q_j(z_j)}{\max}\; -KL(q_j || \hat{p}(x,z_j))\\ & = \arg\underset{q_j(z_j)}{\min}\;KL(q_j || \hat{p}(x,z_j))\end{aligned}\tag{12.2.15} q~(z)=argq(z)maxL(q)=argqj(zj)max−KL(qj∣∣p^(x,zj))=argqj(zj)minKL(qj∣∣p^(x,zj))(12.2.15)
当 log p ^ ( x , z j ) = E ∏ i ≠ j M q i ( z i ) [ log p ( x , z ) ] 取 最 小 值 \color{red}\log \hat{p}(x,z_j)=\mathbf{E}_{\prod_{i \neq j}^Mq_i(z_i)}\left[ \log p(x,z) \right]取最小值 logp^(x,zj)=E∏i=jMqi(zi)[logp(x,z)]取最小值:
log q j ( z j ) = E ∏ i ≠ j q i ( z i ) [ log p ( x , z ∣ θ ) ] + C (12.2.16) \color{red}\log q_j(z_j) = \mathbf{E}_{\prod_{i \neq j}q_i(z_i)}\left[ \log p(x,z|\theta) \right] + C\tag{12.2.16} logqj(zj)=E∏i=jqi(zi)[logp(x,z∣θ)]+C(12.2.16)- 公 式 ( 12.2.16 ) 就 是 V I 算 法 的 基 本 思 路 \color{blue}公式(12.2.16)就是VI算法的基本思路 公式(12.2.16)就是VI算法的基本思路。但是现实生活中 z z z很难求解,因此需要用平均场理论进行一下化简。
- 下一节将回归EM算法,并给出求解的过程。
- 对于①:
12.3 再回首
- Variational Inference(VI)的核心思想是在于用一个分布
q
(
z
)
q(z)
q(z)来近似得到
p
(
z
∣
x
)
p(z|x)
p(z∣x)。其中优化目标为:
q ^ = a r g m i n K L ( q ∣ ∣ p ) \hat{q} = argmin\ KL(q||p) q^=argmin KL(q∣∣p) - 在这个求解中,主要想求的是
q
(
z
)
q(z)
q(z),那么需要弱化
θ
\theta
θ的作用。所以,计算的目标函数为:
q ^ = arg min q K L ( q ∣ ∣ p ) = arg max q L ( q ) (12.3.1) \color{blue}\hat{q} = \arg\min_{q} KL(q||p) = \arg\max_q \mathcal{L}(q)\tag{12.3.1} q^=argqminKL(q∣∣p)=argqmaxL(q)(12.3.1)
所以本节对上一节的一些地方进行解释、对EM算法的符号进行规范化处理,以及对迭代方法进行求解。
-
平均场理论解释
平均场理论:把多维变量的不同维度分为 M M M组,组与组之间是相互独立的:
q ( z ) = ∏ i = 1 M q i ( z i ) (12.3.2) \color{red}q(z)=\prod_{i=1}^{M}q_{i}(z_{i})\tag{12.3.2} q(z)=i=1∏Mqi(zi)(12.3.2)
注: z i 表 示 的 不 是 一 个 数 , 而 是 一 个 数 据 维 度 的 集 合 , {\color{red} z_i表示的不是一个数,而是一个数据维度的集合,} zi表示的不是一个数,而是一个数据维度的集合, 它 表 示 的 不 是 一 个 维 度 , 而 是 一 个 类 似 的 最 大 团 , 也 就 是 多 个 维 度 凑 在 一 起 。 {\color{red} 它表示的不是一个维度,而是一个类似的最大团,也就是多个维度凑在一起。} 它表示的不是一个维度,而是一个类似的最大团,也就是多个维度凑在一起。 -
数学符号规范化(仔细与上一节进行对比)
- 数据
有以下数据:- X : o b s e r v e d v a r i a b l e → X : { x ( i ) } i = 1 N X:observed\;variable\rightarrow X:\left \{x^{(i)}\right \}_{i=1}^{N} X:observedvariable→X:{x(i)}i=1N
- Z : l a t e n t v a r i a b l e + p a r a m e t e r → Z : { z ( i ) } i = 1 N Z:latent\;variable + parameter\rightarrow Z:\left \{z^{(i)}\right \}_{i=1}^{N} Z:latentvariable+parameter→Z:{z(i)}i=1N
- ( X , Z ) : c o m p l e t e d a t a (X,Z):complete\;data (X,Z):completedata
- ELBO和KL
在这里我们弱化了相关参数 θ \theta θ,也就是求解过程中,不太考虑 θ \theta θ起到的作用。展示一下似然函数:
log p θ ( X ) = log ∏ i = 1 N p θ ( x ( i ) ) = ∑ i = 1 N log p θ ( x ( i ) ) (12.3.3) \log p_{\theta}(X) = \log \prod_{i=1}^N p_{\theta}(x^{(i)}) = \sum_{i=1}^N \log p_{\theta}(x^{(i)})\tag{12.3.3} logpθ(X)=logi=1∏Npθ(x(i))=i=1∑Nlogpθ(x(i))(12.3.3)
目标是使每一个 x ( i ) x^{(i)} x(i)最大,所以将对ELBO和 K L ( p ∣ ∣ q ) KL(p||q) KL(p∣∣q)进行规范化表达:-
E L B O \color{blue}ELBO ELBO(第十讲:公式(10.5.6)):
E q ( z ) [ log p θ ( x ( i ) , z ) q ( z ) ] = E q ( z ) [ log p θ ( x ( i ) , z ) ] + H ( q ( z ) ) (12.3.4) \mathbf{E}_{q(z)}\left[ \log \frac{p_{\theta}(x^{(i)},z)}{q(z)} \right] = \mathbf{E}_{q(z)}\left[ \log p_{\theta}(x^{(i)},z) \right]+ H(q(z))\tag{12.3.4} Eq(z)[logq(z)pθ(x(i),z)]=Eq(z)[logpθ(x(i),z)]+H(q(z))(12.3.4) -
K L \color{blue}KL KL(第十讲:公式(10.5.2)):
K L ( q ∣ ∣ p ) = ∫ q ( z ) ⋅ log q ( z ) p θ ( z ∣ x ( i ) ) d z (12.3.5) KL(q||p) = \int q(z)\cdot \log \frac{q(z)}{p_{\theta}(z|x^{(i)})} dz\tag{12.3.5} KL(q∣∣p)=∫q(z)⋅logpθ(z∣x(i))q(z)dz(12.3.5) -
log q j ( z j ) \color{blue}\log\;q_j(z_j) logqj(zj)(本节:公式(12.2.16))
log q j ( z j ) = E ∏ i ≠ j q i ( z i ) [ log p θ ( x ( i ) , z ) ] + C = ∫ q 1 ∫ q 2 ⋯ ∫ q j − 1 ∫ q j + 1 ⋯ ∫ q M q 1 q 2 ⋯ q j − 1 q j + 1 ⋯ q M log p θ ( x ( i ) , z ) d q 1 d q 2 ⋯ d q j − 1 d q j + 1 ⋯ d q M (12.3.6) \begin{aligned} & \log q_j(z_j)\\ & = \mathbf{E}_{\prod_{i \neq j} q_i(z_i)}\left[ \log p_{\theta} (x^{(i)},z) \right] + C \\ & = \int_{q_1} \int_{q_2} \cdots \int_{q_{j-1}}\int_{q_{j+1}} \cdots \int_{q_{M}} q_1q_2\cdots q_{j-1}q_{j+1} \cdots q_M \log p_{\theta} (x^{(i)},z)dq_1dq_2 \cdots dq_{j-1}dq_{j+1} \cdots dq_{M} \\ \end{aligned}\tag{12.3.6} logqj(zj)=E∏i=jqi(zi)[logpθ(x(i),z)]+C=∫q1∫q2⋯∫qj−1∫qj+1⋯∫qMq1q2⋯qj−1qj+1⋯qMlogpθ(x(i),z)dq1dq2⋯dqj−1dqj+1⋯dqM(12.3.6)
-
- 数据
-
VI算法的具体求解
根据 公 式 ( 12.2.16 ) 公式(12.2.16) 公式(12.2.16)使用迭代算法来进行求解:
q ^ 1 ( z 1 ) = ∫ q 2 ⋯ ∫ q M q 2 ⋯ q M [ log p θ ( x ( i ) , z ) ] d q 2 ⋯ d q M q ^ 2 ( z 2 ) = ∫ q ^ 1 ( z 1 ) ∫ q 3 ⋯ ∫ q M q ^ 1 q 3 ⋯ q M [ log p θ ( x ( i ) , z ) ] q ^ 1 d q 2 ⋯ d q M ⋮ q ^ M ( z M ) = ∫ q ^ 1 ⋯ ∫ q ^ M − 1 q ^ 1 ⋯ q ^ M − 1 [ log p θ ( x ( i ) , z ) ] d q ^ 1 ⋯ d q ^ M − 1 (12.3.7) \color{red}\begin{array}{ll} \hat{q}_1(z_1) = \int_{q_2} \cdots \int_{q_{M}} q_2 \cdots q_M \left[ \log p_{\theta}(x^{(i)},z) \right]dq_2 \cdots dq_{M} \\ \hat{q}_2(z_2) = \int_{\hat{q}_1(z_1)}\int_{q_3} \cdots \int_{q_{M}} \hat{q}_1q_3 \cdots q_M \left[ \log p_{\theta}(x^{(i)},z) \right]\hat{q}_1dq_2 \cdots dq_{M} \\ \vdots \\ \hat{q}_M(z_M) = \int_{\hat{q}_1} \cdots \int_{\hat{q}_{M-1}} \hat{q}_1 \cdots \hat{q}_{M-1} \left[ \log p_{\theta}(x^{(i)},z) \right]d\hat{q}_1 \cdots d\hat{q}_{M-1}\end{array}\tag{12.3.7} q^1(z1)=∫q2⋯∫qMq2⋯qM[logpθ(x(i),z)]dq2⋯dqMq^2(z2)=∫q^1(z1)∫q3⋯∫qMq^1q3⋯qM[logpθ(x(i),z)]q^1dq2⋯dqM⋮q^M(zM)=∫q^1⋯∫q^M−1q^1⋯q^M−1[logpθ(x(i),z)]dq^1⋯dq^M−1(12.3.7)
如果将 q 1 , q 2 , ⋯ , q M {q}_1,{q}_2,\cdots,{q}_M q1,q2,⋯,qM看成一个个的坐标点,那么随着计算的深入,知道的坐标点越来越多,这实际上就是一种坐标上升的方法(Coordinate Ascend)。这是一种迭代算法,那怎么考虑迭代的停止条件呢?设置当 L ( t + 1 ) ≤ L ( t ) \color{blue}\mathcal{L}^{(t+1)} \leq \mathcal{L}^{(t)} L(t+1)≤L(t)时停止迭代。
-
VI算法的整体步骤
针对平均场变分分布, 坐 标 上 升 近 似 推 断 算 法 ( C A V I ) \color{green}坐标上升近似推断算法(CAVI) 坐标上升近似推断算法(CAVI)是最常见的优化方法。CAVI交替地更新每个隐变量,更新时固定其他的隐变量的变分分布参数,用来计算当前隐变量 z j z_j zj的坐标上升公式。CAVI的算法步骤如下图所示。
用一张图来表示 q q q分布的变化。
-
Mean Field Theory(平均场理论)的存在问题
-
假
设
太
强
\color{red}假设太强
假设太强
首先这个假设太强了。在假设中,假设变分后验分式是一种完全可分解的分布。实际上,这样的适用条件挺少的。大部分时候都并不会适用。 -
I
n
t
r
a
c
t
a
b
l
e
\color{red}Intractable
Intractable
本来就是因为后验分布 p ( Z ∣ X ) p(Z|X) p(Z∣X)的计算非常的复杂,所以才使用变分推断来进行计算。但这个迭代的方法也非常的难以计算,
log q j ( z j ) = E ∏ i ≠ j q i ( z i ) [ log p ( X , Z ∣ θ ) ] + C (12.3.8) \log q_j(z_j) = \mathbf{E}_{\prod_{i \neq j}q_i(z_i)}\left[ \log p(X,Z|\theta) \right] + C\tag{12.3.8} logqj(zj)=E∏i=jqi(zi)[logp(X,Z∣θ)]+C(12.3.8)
并且公式(12.3.8)的计算也非常的复杂。所以需要寻找一种更加优秀的方法,比如Stein Disparency等等。Stein变分是个非常Fashion的东西,机器学习理论中非常强大的算法,以后会详细的分析。
-
假
设
太
强
\color{red}假设太强
假设太强
12.4 随机梯度变分推断-SGVI-1
- 在上一小节分析了 M e a n F i e l d T h e o r y V a r i a t i o n a l I n f e r e n c e \color{green}Mean\;Field\;Theory\;Variational\;Inference MeanFieldTheoryVariationalInference(平均场论变分推断),通过平均假设来得到变分推断的理论,是一种 C l a s s i c a l V I Classical\;VI ClassicalVI,可以将其看成 C o o r d i n a t e A s c e n d \color{green}Coordinate\;Ascend CoordinateAscend(坐标上升)。
- 本节为了克服Mean Field Theory的存在问题,介绍另一种方法是
S
t
o
c
h
a
s
t
i
c
G
r
a
d
i
e
n
t
V
a
r
i
a
t
i
o
n
a
l
I
n
f
e
r
e
n
c
e
\color{green}Stochastic\;Gradient\;Variational\;Inference
StochasticGradientVariationalInference(SGVI,随机梯度变分推断)。
对于隐变量参数 z z z和数据集 x x x。
- z ⟶ x \color{red}z \longrightarrow x z⟶x是Generative Model,也就是 p ( x ∣ z ) p(x|z) p(x∣z)和 p ( x , z ) p(x,z) p(x,z),这个过程也被我们称为 D e c o d e r \color{red}Decoder Decoder。
- x ⟶ z \color{red}x \longrightarrow z x⟶z是Inference Model,表达关系是 p ( z ∣ x ) p(z|x) p(z∣x),这个过程被我们称为 E n c o d e r \color{red}Encoder Encoder。
本节先对SGVI参数规范,然后SGVI的梯度推导。
- SGVI参数规范
本节的 S t o c h a s t i c G r a d i e n t V a r i a t i o n a l I n f e r e n c e ( S G V I ) \color{green}Stochastic\;Gradient\;Variational\;Inference (SGVI) StochasticGradientVariationalInference(SGVI)方法的基本思路(此处参数更新和平均场论变分推理方法的参数的更新方法类似)为:
ϕ ( t + 1 ) ⟶ ϕ ( t ) + λ ( t ) ∇ L ( q ) (12.4.1) \color{red}\phi^{(t+1)} \longrightarrow \phi^{(t)} + \lambda^{(t)}\nabla {L}(q)\tag{12.4.1} ϕ(t+1)⟶ϕ(t)+λ(t)∇L(q)(12.4.1)
其中, q ( z ∣ x ) q(z|x) q(z∣x)简化表示为 q ( z ) q(z) q(z);令 q ( z ) q(z) q(z)是一个固定形式的概率分布, ϕ \phi ϕ为这个分布的参数,那么这个概率可写成 q ϕ ( z ) \color{blue}q_{\phi}(z) qϕ(z)。 目 标 就 是 求 解 ∇ L ( q ) ( ∇ ϕ L ( ϕ ) ) \color{blue}目标就是求解\nabla {L}(q)(\nabla_{\phi}{L}(\phi)) 目标就是求解∇L(q)(∇ϕL(ϕ))。- 那么ELBO(
L
(
q
)
=
∫
z
q
(
z
)
log
p
(
x
,
z
∣
θ
)
q
(
z
)
d
z
\color{blue}L(q)=\int_z q(z)\log\ \frac{p(x,z|\theta)}{q(z)}dz
L(q)=∫zq(z)log q(z)p(x,z∣θ)dz)被记为:
E L B O = L ( ϕ ) = E q ϕ ( z ) [ log p θ ( x ( i ) , z ) − log q ϕ ( z ) ] (12.4.2) \color{red}ELBO = {L}(\phi)= \mathbf{E}_{q_{\phi}(z)}\left[ \log p_{\theta}(x^{(i)},z) - \log q_{\phi}(z) \right]\tag{12.4.2} ELBO=L(ϕ)=Eqϕ(z)[logpθ(x(i),z)−logqϕ(z)](12.4.2) -
l
o
g
p
(
x
)
log\;p(x)
logp(x)(
l
o
g
p
(
x
)
=
E
L
B
O
+
K
L
(
q
∣
∣
p
)
log\; p(x)=ELBO+KL(q||p)
logp(x)=ELBO+KL(q∣∣p))可以写为:
log p θ ( x ( i ) ) = E L B O + K L ( q ∣ ∣ p ) ≥ L ( ϕ ) (12.4.3) \log p_{\theta}(x^{(i)}) = ELBO + KL(q||p) \geq {L}(\phi)\tag{12.4.3} logpθ(x(i))=ELBO+KL(q∣∣p)≥L(ϕ)(12.4.3)
因此求解目标转换成:
p ^ = arg max ϕ L ( ϕ ) (12.4.4) \hat{p} = \arg\max_{\phi} {L}(\phi)\tag{12.4.4} p^=argϕmaxL(ϕ)(12.4.4)
- 那么ELBO(
L
(
q
)
=
∫
z
q
(
z
)
log
p
(
x
,
z
∣
θ
)
q
(
z
)
d
z
\color{blue}L(q)=\int_z q(z)\log\ \frac{p(x,z|\theta)}{q(z)}dz
L(q)=∫zq(z)log q(z)p(x,z∣θ)dz)被记为:
- SGVI的梯度推导
- 根据公式(12.4.1)和公式(12.4.2)得:
∇ ϕ L ( ϕ ) = ∇ ϕ E q ϕ [ l o g p θ ( x , z ) − l o g q ϕ ( z ) ] = ∇ ϕ ∫ q ϕ ( z ) [ l o g p θ ( x , z ) − l o g q ϕ ( z ) ] d z = ∫ ∇ ϕ q ϕ ( z ) ⋅ [ l o g p θ ( x , z ) − l o g q ϕ ( z ) ] d z ⏟ ① + ∫ q ϕ ( z ) ∇ ϕ [ l o g p θ ( x , z ) − l o g q ϕ ( z ) ] d z ⏟ ② (12.4.5) \color{blue}\begin{aligned}\nabla_{\phi }L(\phi )& =\nabla_{\phi }E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ &=\nabla_{\phi }\int q_{\phi }(z)[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ &=\underset{①}{\underbrace{\int \nabla_{\phi }q_{\phi }(z)\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z}}+\underset{②}{\underbrace{\int q_{\phi }(z)\nabla_{\phi }[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z}}\end{aligned}\tag{12.4.5} ∇ϕL(ϕ)=∇ϕEqϕ[logpθ(x,z)−logqϕ(z)]=∇ϕ∫qϕ(z)[logpθ(x,z)−logqϕ(z)]dz=① ∫∇ϕqϕ(z)⋅[logpθ(x,z)−logqϕ(z)]dz+② ∫qϕ(z)∇ϕ[logpθ(x,z)−logqϕ(z)]dz(12.4.5) - 在对其中①和②单独计算:
② = ∫ q ϕ ( z ) ∇ ϕ [ l o g p θ ( x , z ) ⏟ 与 ϕ 无 关 − l o g q ϕ ( z ) ] d z = − ∫ q ϕ ( z ) ∇ ϕ l o g q ϕ ( z ) d z = − ∫ q ϕ ( z ) 1 q ϕ ( z ) ∇ ϕ q ϕ ( z ) d z = − ∫ ∇ ϕ q ϕ ( z ) d z = − ∇ ϕ ∫ q ϕ ( z ) d z = − ∇ ϕ 1 = 0 (12.4.6) \begin{aligned}②&=\int q_{\phi }(z)\nabla_{\phi }[\underset{与\phi 无关}{\underbrace{log\; p_{\theta }(x,z)}}-log\; q_{\phi }(z)]\mathrm{d}z\\ &=-\int q_{\phi }(z)\nabla_{\phi }log\; q_{\phi }(z)\mathrm{d}z\\ &=-\int q_{\phi }(z)\frac{1}{q_{\phi }(z)}\nabla_{\phi }q_{\phi }(z)\mathrm{d}z\\ & =-\int \nabla_{\phi }q_{\phi }(z)\mathrm{d}z\\ &=-\nabla_{\phi }\int q_{\phi }(z)\mathrm{d}z\\ &=-\nabla_{\phi }1=0\end{aligned}\tag{12.4.6} ②=∫qϕ(z)∇ϕ[与ϕ无关 logpθ(x,z)−logqϕ(z)]dz=−∫qϕ(z)∇ϕlogqϕ(z)dz=−∫qϕ(z)qϕ(z)1∇ϕqϕ(z)dz=−∫∇ϕqϕ(z)dz=−∇ϕ∫qϕ(z)dz=−∇ϕ1=0(12.4.6)
因此公式(12.4.5)可以简化为:
∇ ϕ L ( ϕ ) = ① = ∫ ∇ ϕ q ϕ ( z ) ⋅ [ l o g p θ ( x , z ) − l o g q ϕ ( z ) ] d z = ∫ q ϕ ( z ) ∇ ϕ l o g q ϕ ( z ) ⋅ [ l o g p θ ( x , z ) − l o g q ϕ ( z ) ] d z = E q ϕ [ ( ∇ ϕ l o g q ϕ ( z ) ) ( l o g p θ ( x , z ) − l o g q ϕ ( z ) ) ] (12.4.7) \begin{aligned}\nabla_{\phi }L(\phi )=① &=\int {\color{Red}{\nabla_{\phi }q_{\phi }(z)}}\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ &=\int {\color{Red}{q_{\phi }(z)\nabla_{\phi }log\; q_{\phi }(z)}}\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ &=E_{q_{\phi }}[(\nabla_{\phi }log\; q_{\phi }(z))(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))]\end{aligned}\tag{12.4.7} ∇ϕL(ϕ)=①=∫∇ϕqϕ(z)⋅[logpθ(x,z)−logqϕ(z)]dz=∫qϕ(z)∇ϕlogqϕ(z)⋅[logpθ(x,z)−logqϕ(z)]dz=Eqϕ[(∇ϕlogqϕ(z))(logpθ(x,z)−logqϕ(z))](12.4.7)
其中红色部分是根据公式(12.4.6)的第二行到第四行得到的。因此:
∇ ϕ L ( ϕ ) = E q ϕ [ ∇ ϕ log q ϕ ( log p θ ( x ( i ) , z ) − log q ϕ ) ] (12.4.8) \color{red}\nabla_{\phi} {L}(\phi) = \mathbf{E}_{q_{\phi}} \left[ \nabla_{\phi}\log q_{\phi} (\log p_{\theta}(x^{(i)},z) - \log q_{\phi}) \right]\tag{12.4.8} ∇ϕL(ϕ)=Eqϕ[∇ϕlogqϕ(logpθ(x(i),z)−logqϕ)](12.4.8)
那么如何求这个期望呢?采用的是蒙特卡罗采样法,假设 z l ∼ q ϕ ( z ) l = 1 , 2 , ⋯ , L z^l \sim q_{\phi} (z)\ l = 1, 2, \cdots, L zl∼qϕ(z) l=1,2,⋯,L,那么有:
∇ ϕ L ( ϕ ) ≈ 1 L ∑ l = 1 L ∇ ϕ log q ϕ ( z ( l ) ) [ log p θ ( x ( i ) , z ) − log q ϕ ( z ( l ) ) ] (12.4.9) \color{blue}\nabla_{\phi} {L}(\phi) \approx \frac{1}{L} \sum_{l=1}^L \nabla_{\phi}\log q_{\phi}(z^{(l)})\left[ \log p_{\theta}(x^{(i)},z) - \log q_{\phi}(z^{(l)})\right]\tag{12.4.9} ∇ϕL(ϕ)≈L1l=1∑L∇ϕlogqϕ(z(l))[logpθ(x(i),z)−logqϕ(z(l))](12.4.9)
- 根据公式(12.4.1)和公式(12.4.2)得:
12.5 随机梯度变分推断-SGVI-2
本节继上一节的内容,介绍Variance Reduction(方差缩减) 。
- 存在问题
上节最后的公式(12.4.8):
∇ ϕ L ( ϕ ) = E q ϕ [ ∇ ϕ log q ϕ ( log p θ ( x ( i ) , z ) − log q ϕ ) ] \nabla_{\phi} {L}(\phi) = \mathbf{E}_{q_{\phi}} \left[ {\color{red}\nabla_{\phi}\log q_{\phi}}( \log p_{\theta}(x^{(i)},z) - \log q_{\phi}) \right] ∇ϕL(ϕ)=Eqϕ[∇ϕlogqϕ(logpθ(x(i),z)−logqϕ)]
这样的求法存在问题?- 一方面在采样的过程中,可能采到 q ϕ ( z ) ⟶ 0 \color{red}q_{\phi}(z) \longrightarrow 0 qϕ(z)⟶0的点,对于log函数来说, lim x ⟶ 0 l o g x = ∞ \color{red}\underset{x\longrightarrow 0}{\lim}log\;x = \infty x⟶0limlogx=∞,那么梯度的变化会非常的剧烈,非常的不稳定。就会出现 H i g h V a r i a n c e High\;Variance HighVariance的问题,没有办法求解。
- 另一方面 ϕ ^ ⟶ q ( z ) \color{red}\hat{\phi} \longrightarrow q(z) ϕ^⟶q(z)也有误差,此误差和梯度剧烈变化带来的误差,误差叠加,这算法根本没有办法用。
- 解决方法
- 整体思路:利用一个确定的解 p ( ϵ ) \color{red}p(\epsilon) p(ϵ),简化计算。因为 z z z来自于 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x),将 z z z中的随机变量给解放出来。
- 改写方法
即: 使用 转 换 z = g ϕ ( ϵ , x ( i ) ) \color{red}转换z = g_{\phi}(\epsilon, x^{(i)}) 转换z=gϕ(ϵ,x(i)),其中 ϵ ∼ p ( ϵ ) \epsilon \sim p(\epsilon) ϵ∼p(ϵ)。这样做有什么好处呢?- 原来的 ∇ ϕ E q ϕ [ ⋅ ] \nabla_{\phi} \mathbf{E}_{q_{\phi}}[\cdot] ∇ϕEqϕ[⋅]将转换为 E p ( ϵ ) [ ∇ ϕ ( ⋅ ) ] \mathbf{E}_{p(\epsilon)}[\nabla_{\phi}(\cdot)] Ep(ϵ)[∇ϕ(⋅)], 方 差 \color{blue}方差 方差不再是连续的关于 ϕ \phi ϕ的采样,可以有效的降低方差。
- 并且,
z
z
z 是一个关于
ϵ
{\epsilon}
ϵ 的函数,我们将随机性转移到了
ϵ
{\epsilon}
ϵ,那么问题就可以简化为:
z ∼ q ϕ ( z ∣ x ( i ) ) ⟶ ϵ ∼ p ( ϵ ) (12.5.1) \color{red}z \sim q_{\phi}(z|x^{(i)}) \longrightarrow \epsilon \sim p(\epsilon)\tag{12.5.1} z∼qϕ(z∣x(i))⟶ϵ∼p(ϵ)(12.5.1) - 因为
∫
q
ϕ
(
z
∣
x
(
i
)
)
d
z
=
∫
p
(
ϵ
)
d
ϵ
=
1
\int q_{\phi}(z|x^{(i)})dz = \int p(\epsilon)d\epsilon = 1
∫qϕ(z∣x(i))dz=∫p(ϵ)dϵ=1,则
q
ϕ
(
z
∣
x
(
i
)
)
q_{\phi}(z|x^{(i)})
qϕ(z∣x(i))和
p
(
ϵ
)
p(\epsilon)
p(ϵ)之间存在一个变换关系,即:
∣ q ϕ ( z ∣ x ( i ) ) d z ∣ = ∣ p ( ϵ ) d ϵ ∣ (12.5.2) \color{red}|q_{\phi}(z|x^{(i)})dz| = |p(\epsilon)d\epsilon|\tag{12.5.2} ∣qϕ(z∣x(i))dz∣=∣p(ϵ)dϵ∣(12.5.2)
- 改写
∇
ϕ
L
(
ϕ
)
\nabla_{\phi} \mathcal{L}(\phi)
∇ϕL(ϕ)
改写 ∇ ϕ L ( ϕ ) \nabla_{\phi} \mathcal{L}(\phi) ∇ϕL(ϕ):
∇ ϕ L ( ϕ ) = ∇ ϕ E q ϕ [ log p θ ( x ( i ) , z ) − log q ϕ ] = ∇ ϕ ∫ [ log p θ ( x ( i ) , z ) − log q ϕ ] q ϕ d z = ∇ ϕ ∫ [ log p θ ( x ( i ) , z ) − log q ϕ ] p ( ϵ ) d ϵ = ∇ ϕ E p ( ϵ ) [ log p θ ( x ( i ) , z ) − log q ϕ ] ( E p ( ϵ ) 中 的 p ( ϵ ) 与 梯 度 ϕ 无 关 ) = E p ( ϵ ) ∇ ϕ [ ( log p θ ( x ( i ) , z ) − log q ϕ ) ] = E p ( ϵ ) ∇ z [ ( log p θ ( x ( i ) , z ) − log q ϕ ( z ∣ x ( i ) ) ) ∇ ϕ z ] = E p ( ϵ ) ∇ z [ ( log p θ ( x ( i ) , z ) − log q ϕ ( z ∣ x ( i ) ) ) ∇ ϕ z ] = E p ( ϵ ) ∇ z [ ( log p θ ( x ( i ) , z ) − log q ϕ ( z ∣ x ( i ) ) ) ∇ ϕ g ϕ ( ϵ , x ( i ) ) ] (12.5.3) \begin{aligned} \nabla_{\phi} \mathcal{L}(\phi) & = \nabla_{\phi} \mathbf{E}_{q_{\phi}}\left[ \log p_{\theta}(x^{(i)},z) - \log q_{\phi} \right] \\ &= \nabla_{\phi} \int \left[ \log p_{\theta}(x^{(i)},z) - \log q_{\phi} \right]q_{\phi} dz \\ & = \nabla_{\phi} \int \left[ \log p_{\theta}(x^{(i)},z) - \log q_{\phi} \right]p(\epsilon) d\epsilon \\ & = \nabla_{\phi} \mathbf{E}_{p(\epsilon)}\left[ \log p_{\theta}(x^{(i)},z) - \log q_{\phi} \right] ({\color{blue}\mathbf{E}_{p(\epsilon)}中的p(\epsilon)与梯度\phi无关})\\ & = \mathbf{E}_{p(\epsilon)} \nabla_{\phi} \left[( \log p_{\theta}(x^{(i)},z) - \log q_{\phi}) \right] \\ & = \mathbf{E}_{p(\epsilon)}\nabla_{z}\left[( \log p_{\theta}(x^{(i)},z) - \log q_{\phi}(z|x^{(i)}))\nabla_{\phi}z \right] \\ & = \mathbf{E}_{p(\epsilon)}\nabla_{z}\left[( \log p_{\theta}(x^{(i)},z) - \log q_{\phi}(z|x^{(i)}))\nabla_{\phi}z \right] \\ & = \mathbf{E}_{p(\epsilon)}\nabla_{z}\left[( \log p_{\theta}(x^{(i)},z) - \log q_{\phi}(z|x^{(i)}))\nabla_{\phi}g_{\phi}(\epsilon, x^{(i)}) \right]\end{aligned}\tag{12.5.3} ∇ϕL(ϕ)=∇ϕEqϕ[logpθ(x(i),z)−logqϕ]=∇ϕ∫[logpθ(x(i),z)−logqϕ]qϕdz=∇ϕ∫[logpθ(x(i),z)−logqϕ]p(ϵ)dϵ=∇ϕEp(ϵ)[logpθ(x(i),z)−logqϕ](Ep(ϵ)中的p(ϵ)与梯度ϕ无关)=Ep(ϵ)∇ϕ[(logpθ(x(i),z)−logqϕ)]=Ep(ϵ)∇z[(logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕz]=Ep(ϵ)∇z[(logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕz]=Ep(ϵ)∇z[(logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕgϕ(ϵ,x(i))](12.5.3)
即:
∇ ϕ L ( ϕ ) = E p ( ϵ ) ∇ z [ ( log p θ ( x ( i ) , z ) − log q ϕ ( z ∣ x ( i ) ) ) ∇ ϕ g ϕ ( ϵ , x ( i ) ) ] (12.2.4) \nabla_{\phi} \mathcal{L}(\phi)= \mathbf{E}_{p(\epsilon)}\nabla_{z}\left[( \log p_{\theta}(x^{(i)},z) - \log q_{\phi}(z|x^{(i)}))\nabla_{\phi}g_{\phi}(\epsilon, x^{(i)}) \right]\tag{12.2.4} ∇ϕL(ϕ)=Ep(ϵ)∇z[(logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕgϕ(ϵ,x(i))](12.2.4)
因为 p ( ϵ ) p(\epsilon) p(ϵ)的采样与 ϕ \phi ϕ无关,求解步骤可以是:- 先求关于 ϕ \phi ϕ的梯度;
- 然后再求关于 z z z的梯度;
- 最后,我们再对结果进行采样,
ϵ
(
l
)
∼
p
(
ϵ
)
,
l
=
1
,
2
,
⋯
,
L
\epsilon^{(l)} \sim p(\epsilon), \quad l = 1, 2, \cdots, L
ϵ(l)∼p(ϵ),l=1,2,⋯,L。
那么这三者之间就互相隔离开了。
- 小结
SGVI可以简要的表述为:对于分布为 q ϕ ( Z ∣ X ) q_{\phi}(Z|X) qϕ(Z∣X), ϕ \phi ϕ为参数,参数的更新方法为:
ϕ ( t + 1 ) ⟶ ϕ ( t ) + λ ( t ) ∇ ϕ L ( ϕ ) (12.5.5) \color{red}\phi^{(t+1)} \longrightarrow \phi^{(t)} + \lambda^{(t)}\nabla_{\phi} \mathcal{L}(\phi)\tag{12.5.5} ϕ(t+1)⟶ϕ(t)+λ(t)∇ϕL(ϕ)(12.5.5)
对公式(12.2.4)使用蒙特卡洛方法, ∇ ϕ L ( ϕ ) \nabla_{\phi} \mathcal{L}(\phi) ∇ϕL(ϕ)为:
∇ ϕ L ( ϕ ) ≈ 1 L ∑ i = 1 L ∇ z [ log p θ ( x ( i ) , z ) − log q ϕ ( z ∣ x ( i ) ) ) ∇ ϕ g ϕ ( ϵ , x ( i ) ) ] (12.5.6) \color{red}\nabla_{\phi} \mathcal{L}(\phi) \approx \frac{1}{L} \sum_{i=1}^L \nabla_{z} \left[ \log p_{\theta}(x^{(i)},z) - \log q_{\phi}(z|x^{(i)}))\nabla_{\phi}g_{\phi}(\epsilon, x^{(i)}) \right]\tag{12.5.6} ∇ϕL(ϕ)≈L1i=1∑L∇z[logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕgϕ(ϵ,x(i))](12.5.6)
其中 z ⟵ g ϕ ( ϵ ( i ) , x ( i ) ) z \longleftarrow g_{\phi}(\epsilon^{(i)},x^{(i)}) z⟵gϕ(ϵ(i),x(i)), ϵ ( l ) ∼ p ( ϵ ) , l = 1 , 2 , ⋯ , L \epsilon^{(l)} \sim p(\epsilon), \quad l = 1, 2, \cdots, L ϵ(l)∼p(ϵ),l=1,2,⋯,L。