扩散模型需要一点变分贝叶斯的知识,基本出于功利性目的,停留在浅尝辄止的程度。
朴素贝叶斯
- 数据(Data)x={x1,x2,⋯ ,xn}{\bold x}=\{x_1,x_2,\cdots,x_n\}x={x1,x2,⋯,xn}
- 参数(Parameter)z={z1,z2,⋯ ,zm}{\bold z}=\{z_1,z_2,\cdots,z_m\}z={z1,z2,⋯,zm}
- 先验(Prior)p(z)p({\bold z})p(z)
- 后验(Posterior)p(z∣x)p({\bold z}\vert{\bold x})p(z∣x)
先验一般使用常见的概率分布,比如扩散模型(Diffusion Model)选择高斯分布 N(0,I){\cal N}({\bold 0},{\bold I})N(0,I)。
后验 p(z∣x)p({\bold z}\vert{\bold x})p(z∣x) 即在 x\bold xx 的分布下 z\bold zz 的条件概率,也就是根据数据 x\bold xx 来估计参数 z\bold zz。相比于先验,后验经过了数据 x{\bold x}x 的修正,因此能够更加贴合真实值。
后验的计算方法(贝叶斯公式):
p(z∣x)=p(z,x)p(x)=p(x∣z)⋅p(z)p(x)
p({\bold z}\vert{\bold x})=\frac{p({\bold z},{\bold x})}{p({\bold x})}=\frac{p({\bold x}\vert{\bold z})\cdot p({\bold z})}{p({\bold x})}
p(z∣x)=p(x)p(z,x)=p(x)p(x∣z)⋅p(z)
- 似然(Likelihood)p(x∣z)p({\bold x}\vert{\bold z})p(x∣z)
- 证据(Evidence)p(x)p({\bold x})p(x)
变分贝叶斯
后验没有解析表示,需要采用近似方法计算。马尔科夫链蒙特卡洛(MCMC)就是一种典型思路,虽然它能得到相对精确的结果,但是速度非常慢。巧妙一点的方法是把问题转化到凸优化上来。
假设在某一函数族 Q{\mathbb Q}Q 内寻找与 x\bold xx 无关的概率密度函数 q(z)q({\bold z})q(z) 来近似 p(z∣x)p({\bold z}\vert{\bold x})p(z∣x),优化目标:
q∗(z)=argminq(z)∈QL(q(z),p(z∣x))
q^*({\bold z})=\mathop{\arg\min}\limits_{q({\bold z})\in{\mathbb Q}}{{\cal L}\left(q({\bold z}),p({\bold z}\vert{\bold x})\right)}
q∗(z)=q(z)∈QargminL(q(z),p(z∣x))
q∗(z)q^*({\bold z})q∗(z) 是我们追求的理想近似函数,L\cal LL 是我们的度量函数,用于衡量函数的近似水平。
扩散模型一般选择 KL 散度:
KL(q(z)∥p(z∣x))=∫zq(z)logq(z)p(z∣x)dz=∫zq(z)logq(z)⋅p(x)p(z,x)dz=∫zq(z)logq(z)dz−∫zq(z)logp(z,x)dz+logp(x)∫zq(z)dz=Eqlogq(z)−Eqlogp(z,x)+logp(x)
\begin{aligned}
\mathop{\rm KL}{\left(q({\bold z})\middle\Vert p({\bold z}\vert{\bold x})\right)}
&=\int_{\bold z}{q({\bold z})\log\frac{q({\bold z})}{p({\bold z}\vert{\bold x})}{{\rm d}{\bold z}}} \\
&=\int_{\bold z}{q({\bold z})\log\frac{q({\bold z})\cdot p({\bold x})}{p({\bold z},{\bold x})}{{\rm d}{\bold z}}} \\
&=\int_{\bold z}{q({\bold z})\log{q({\bold z})}{{\rm d}{\bold z}}}
-\int_{\bold z}{q({\bold z})\log{p({\bold z},{\bold x})}{{\rm d}{\bold z}}}
+\log{p({\bold x})\int_{\bold z}{q({\bold z})}{{\rm d}{\bold z}}}
\\
&=\mathop{{\bf E}_q}{\log{q({\bold z})}}-\mathop{{\bf E}_q}{\log{p({\bold z},{\bold x})}}+\log{p({\bold x})}
\end{aligned}
KL(q(z)∥p(z∣x))=∫zq(z)logp(z∣x)q(z)dz=∫zq(z)logp(z,x)q(z)⋅p(x)dz=∫zq(z)logq(z)dz−∫zq(z)logp(z,x)dz+logp(x)∫zq(z)dz=Eqlogq(z)−Eqlogp(z,x)+logp(x)
上式的前两项取负号,记作证据下界(Evidence Lower Bound,ELBO):
ELBO(q)=Eqlogp(z,x)−Eqlogq(z)
{\rm ELBO}(q)=\mathop{{\bf E}_q}{\log{p({\bold z},{\bold x})}}-\mathop{{\bf E}_q}{\log{q({\bold z})}}
ELBO(q)=Eqlogp(z,x)−Eqlogq(z)
于是:
logp(x)=KL(q(z)∥p(z∣x))+ELBO(q)≥ELBO(q)
\log{p({\bold x})}=\mathop{\rm KL}{\left(q({\bold z})\middle\Vert p({\bold z}\vert{\bold x})\right)}+{\rm ELBO}(q)\ge{\rm ELBO}(q)
logp(x)=KL(q(z)∥p(z∣x))+ELBO(q)≥ELBO(q)
上式证据 logp(x)\log{p({\bold x})}logp(x) 是与 qqq 无关的常数,从而优化目标等价于:
q∗(z)=argminq(z)∈QKL(q(z)∥p(z∣x))=argmaxq(z)∈QELBO(q)
q^*({\bold z})=\mathop{\arg\min}\limits_{q({\bold z})\in{\mathbb Q}}{\mathop{\rm KL}{\left(q({\bold z})\middle\Vert p({\bold z}\vert{\bold x})\right)}}=\mathop{\arg\max}\limits_{q({\bold z})\in{\mathbb Q}}{{\rm ELBO}(q)}
q∗(z)=q(z)∈QargminKL(q(z)∥p(z∣x))=q(z)∈QargmaxELBO(q)
坐标下降法
为参数 z\bold zz 的每一个分量独立估计各自的分布(平均场假设):
q(z)=∏j=1mqj(zj)
q({\bold z})=\prod_{j=1}^m{q_j(z_j)}
q(z)=j=1∏mqj(zj)
固定其余参数,优化目标:
qj∗(zj)=argmaxqjELBO(q)
q_j^*(z_j)=\mathop{\arg\max}\limits_{q_j}{{\rm ELBO}(q)}
qj∗(zj)=qjargmaxELBO(q)
考虑关于 jjj 的 ELBO:
ELBO(q)=Eqlogp(z,x)−Eqlog∏i=1mqi(zi)=Eqlogp(z,x)−Eqjlogqj(zj)−∑i≠jEqilogqi(zi)=∫zq(z)logp(z,x)dz−∫zjqj(zj)logqj(zj)dzj+C=∫zjqj(zj)dzj[∫z′=z−{zj}q(z′)logp(z,x)dz′]−∫zjqj(zj)logqj(zj)dzj+C=∫zjqj(zj)Eq−jlogp(z,x)dzj−∫zjqj(zj)logqj(zj)dzj+C=−∫zjqj(zj)logqj(zj)expEq−jlogp(z,x)dzj+C=−KL(qj(zj)∥expEq−jlogp(z,x))+C
\begin{aligned}
{\rm ELBO}(q)
&=\mathop{{\bf E}_q}{\log{p({\bold z},{\bold x})}}-\mathop{{\bf E}_q}{\log{\prod_{i=1}^m{q_i(z_i)}}} \\
&=\mathop{{\bf E}_q}{\log{p({\bold z},{\bold x})}}-\mathop{{\bf E}_{q_j}}{\log{q_j(z_j)}}-\sum_{i\ne j}{\mathop{{\bf E}_{q_i}}{\log{q_i(z_i)}}} \\
&=\int_{\bold z}{q({\bold z})\log{p({\bold z},{\bold x})}{{\rm d}{\bold z}}}
-\int_{z_j}{q_j(z_j)\log{q_j(z_j)}{{\rm d}z_j}}+C \\
&=\int_{z_j}{q_j(z_j){{\rm d}z_j}\left[\int_{{\bold z'}={\bold z}-\{z_j\}}{q({\bold z'})\log{p({\bold z},{\bold x})}{{\rm d}{\bold z'}}}\right]}
-\int_{z_j}{q_j(z_j)\log{q_j(z_j)}{{\rm d}z_j}}+C \\
&=\int_{z_j}{q_j(z_j)\mathop{{\bf E}_{q_{-j}}}{\log{p({\bold z},{\bold x})}}{{\rm d}z_j}}
-\int_{z_j}{q_j(z_j)\log{q_j(z_j)}{{\rm d}z_j}}+C \\
&=-\int_{z_j}{q_j(z_j)\log\frac{q_j(z_j)}{\exp\mathop{{\bf E}_{q_{-j}}}{\log{p({\bold z},{\bold x})}}}{{\rm d}z_j}}+C \\
&=-\mathop{\rm KL}{\left(q_j(z_j)\middle\Vert \exp\mathop{{\bf E}_{q_{-j}}}{\log{p({\bold z},{\bold x})}}\right)}+C
\end{aligned}
ELBO(q)=Eqlogp(z,x)−Eqlogi=1∏mqi(zi)=Eqlogp(z,x)−Eqjlogqj(zj)−i=j∑Eqilogqi(zi)=∫zq(z)logp(z,x)dz−∫zjqj(zj)logqj(zj)dzj+C=∫zjqj(zj)dzj[∫z′=z−{zj}q(z′)logp(z,x)dz′]−∫zjqj(zj)logqj(zj)dzj+C=∫zjqj(zj)Eq−jlogp(z,x)dzj−∫zjqj(zj)logqj(zj)dzj+C=−∫zjqj(zj)logexpEq−jlogp(z,x)qj(zj)dzj+C=−KL(qj(zj)expEq−jlogp(z,x))+C
从而:
qj∗(zj)=argminqjKL(qj(zj)∥expEq−jlogp(z,x))=expEq−jlogp(z,x)
q_j^*(z_j)=\mathop{\arg\min}\limits_{q_j}{\mathop{\rm KL}{\left(q_j(z_j)\middle\Vert \exp\mathop{{\bf E}_{q_{-j}}}{\log{p({\bold z},{\bold x})}}\right)}}=\exp\mathop{{\bf E}_{q_{-j}}}{\log{p({\bold z},{\bold x})}}
qj∗(zj)=qjargminKL(qj(zj)expEq−jlogp(z,x))=expEq−jlogp(z,x)
由于 q(z)q({\bold z})q(z) 整体需要满足概率分布,对每个分量进行归一化:
qj′(zj)=qj∗(zj)∫jqj∗(zj)dzj
q_j'(z_j)=\frac{q_j^*(z_j)}{\displaystyle\int_{j}{q_j^*(z_j){{\rm d}z_j}}}
qj′(zj)=∫jqj∗(zj)dzjqj∗(zj)