上一篇SDE文章中,从去噪分数匹配、分数网络出发,逐渐将SDE引入到扩散过程表征中,主要集中在原理讲解和推导中,本文则进一步补充能在扩散过程中实际使用的SDE公式实例。
一种SDE定义
相信很多读者自己学习或在看完上一篇文章后,对SDE实际如何使用情况还是不清楚。其实,因为SDE是随机微分方程,所以其与DDPM中加噪的方差其实都是人为设定的,相当于超参数。所以在训练前就需要将SDE具体形式定义好,在此将SDE定义为
d
x
=
σ
t
d
ω
t
∈
[
0
,
1
]
(1)
dx=\sigma^t d\omega \quad t \in [0,1] \tag1
dx=σtdωt∈[0,1](1)
公式(1)中丢弃了SDE定义中的漂移部分,只设置了扩散部分。此时已将时间区间正则化,即范围为[0,1],也知道了
d
x
dx
dx随时间变化的公式,因为当前加噪是连续过程,故
x
(
t
)
=
x
(
0
)
+
∫
0
t
d
x
=
x
(
0
)
+
∫
0
t
σ
s
d
ω
(
s
)
(2)
x(t) = x(0) + \int_0^tdx=x(0)+\int_0^t\sigma^sd\omega(s) \tag2
x(t)=x(0)+∫0tdx=x(0)+∫0tσsdω(s)(2)
我们知道
ω
(
s
)
\omega(s)
ω(s)是标准布朗运动,其符合高斯分布
N
(
0
,
t
I
)
N(0,tI)
N(0,tI);而
∫
0
t
σ
s
d
ω
(
s
)
\int_0^t\sigma^sd\omega(s)
∫0tσsdω(s)是一个随机积分,在时间区间[0,t]范围内的一个加权和;根据随机微分方程性质,
∫
0
t
σ
s
d
ω
(
s
)
\int_0^t\sigma^sd\omega(s)
∫0tσsdω(s)也服从高斯分布,其均值为0,方差计算如下:
V
a
r
[
∫
0
t
σ
s
d
ω
(
s
)
]
=
∫
0
t
σ
2
s
d
s
=
1
2
log
σ
σ
2
s
∣
0
t
=
1
2
log
σ
(
σ
2
t
−
1
)
\begin{align*} Var[\int_0^t\sigma^sd\omega(s)] &= \int_0^t\sigma^{2s}ds \\ & = \frac{1}{2 \log \sigma}\sigma^{2s}|^t_0 \\ & = \frac{1}{2 \log \sigma}(\sigma^{2t}-1) \tag3 \end{align*}
Var[∫0tσsdω(s)]=∫0tσ2sds=2logσ1σ2s∣0t=2logσ1(σ2t−1)(3)
即
∫
0
t
σ
s
d
ω
(
s
)
∼
N
(
0
,
1
2
log
σ
(
σ
2
t
−
1
)
I
)
\int_0^t\sigma^sd\omega(s) \sim N(0, \frac{1}{2 \log \sigma}(\sigma^{2t}-1)I)
∫0tσsdω(s)∼N(0,2logσ1(σ2t−1)I)。
因为
ω
(
s
)
\omega(s)
ω(s)符合高斯分布,
p
0
t
(
x
(
t
)
∣
x
(
0
)
)
p_{0t}(x(t)|x(0))
p0t(x(t)∣x(0))也符合高斯分布,整合上述推导有
p
0
t
(
x
(
t
)
∣
x
(
0
)
)
=
N
(
x
(
t
)
;
x
(
0
)
,
1
2
log
σ
(
σ
2
t
−
1
)
I
)
(4)
p_{0t}(x(t)|x(0)) = N(x(t);x(0),\frac{1}{2 \log \sigma}(\sigma^{2t}-1)I) \tag4
p0t(x(t)∣x(0))=N(x(t);x(0),2logσ1(σ2t−1)I)(4)
在上一篇SDE文档中提到过,当
p
0
t
(
x
(
t
)
∣
x
(
0
)
)
p_{0t}(x(t)|x(0))
p0t(x(t)∣x(0))为高斯分布时,
λ
(
t
)
=
方差
=
1
2
log
σ
(
σ
2
t
−
1
)
\lambda(t)=方差=\frac{1}{2 \log \sigma}(\sigma^{2t}-1)
λ(t)=方差=2logσ1(σ2t−1)。
p
0
t
p_{0t}
p0t是一个联合分布,
p
t
=
0
p_{t=0}
pt=0和
p
t
=
1
p_{t=1}
pt=1属于边缘分布,因为是连续概率,
p
t
=
1
p_{t=1}
pt=1概率等于
p
t
=
0
p_{t=0}
pt=0分布下所有值的概率积分,重新用
y
,
x
y,x
y,x分别表示
x
(
0
)
,
x
(
t
)
x(0),x(t)
x(0),x(t),
p
t
=
1
p_{t=1}
pt=1的计算公式如下
p
t
=
1
=
∫
p
0
(
y
)
N
(
x
;
y
,
1
2
log
σ
(
σ
2
−
1
)
I
)
d
y
(5)
p_{t=1}=\int p_0(y)N(x;y,\frac{1}{2 \log \sigma}(\sigma^2-1)I)dy \tag5
pt=1=∫p0(y)N(x;y,2logσ1(σ2−1)I)dy(5)
因为
t
=
1
t=1
t=1,所以公式(5)中的分布为
N
(
x
;
y
,
1
2
log
σ
(
σ
2
−
1
)
I
)
N(x;y,\frac{1}{2 \log \sigma}(\sigma^2-1)I)
N(x;y,2logσ1(σ2−1)I);当
σ
\sigma
σ很大时,该分布中的均值y也可以将其看是趋近于0。故
N
(
x
;
y
,
1
2
log
σ
(
σ
2
−
1
)
I
)
≈
N
(
x
;
0
,
1
2
log
σ
(
σ
2
−
1
)
I
)
N(x;y,\frac{1}{2 \log \sigma}(\sigma^2-1)I) \approx N(x;0,\frac{1}{2 \log \sigma}(\sigma^2-1)I)
N(x;y,2logσ1(σ2−1)I)≈N(x;0,2logσ1(σ2−1)I)。将其带入公式(5)有
p
t
=
1
≈
∫
p
0
(
y
)
N
(
x
;
0
,
1
2
log
σ
(
σ
2
−
1
)
I
)
d
y
=
N
(
x
;
0
,
1
2
log
σ
(
σ
2
−
1
)
I
)
∫
p
0
(
y
)
d
y
=
N
(
x
;
0
,
1
2
log
σ
(
σ
2
−
1
)
I
)
\begin{align*} p_{t=1} &\approx \int p_0(y)N(x;0,\frac{1}{2 \log \sigma}(\sigma^2-1)I)dy \\ & = N(x;0,\frac{1}{2 \log \sigma}(\sigma^2-1)I)\int p_0(y)dy \\ & = N(x;0,\frac{1}{2 \log \sigma}(\sigma^2-1)I) \tag6 \end{align*}
pt=1≈∫p0(y)N(x;0,2logσ1(σ2−1)I)dy=N(x;0,2logσ1(σ2−1)I)∫p0(y)dy=N(x;0,2logσ1(σ2−1)I)(6)
上述大费周章推动出公式(6)的原因是
p
t
=
1
p_{t=1}
pt=1是采样起始的先验分布,并且公式(6)中的
p
t
=
1
p_{t=1}
pt=1符合高斯分布,也与数据分布无关。在进行逆SDE过程求解时,可基于公式(6)采样初始值。
欧拉数值采样
常规SDE公式定义为
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
w
dx=f(x,t)dt+g(t)dw
dx=f(x,t)dt+g(t)dw,其对应的逆SDE公式为
d
x
=
[
f
(
x
,
t
)
−
g
2
(
t
)
∇
x
log
p
t
(
x
)
]
d
t
+
g
(
t
)
d
w
ˉ
dx=[f(x,t)-g^2(t)\nabla_x \log p_t(x)]dt+g(t)d\bar{w}
dx=[f(x,t)−g2(t)∇xlogpt(x)]dt+g(t)dwˉ,当前人为定义的SDE为公式(1),则其对应的逆SDE公式如下
d
x
=
−
σ
2
t
∇
x
log
p
t
(
x
)
d
t
+
σ
t
d
w
ˉ
(7)
dx=-\sigma^{2t} \nabla_x \log p_t(x)dt+\sigma^td\bar{w} \tag7
dx=−σ2t∇xlogpt(x)dt+σtdwˉ(7)
训练过程使得分数网络
s
θ
(
x
,
t
)
s_\theta(x,t)
sθ(x,t)近似分数函数,即
s
θ
(
x
,
t
)
=
∇
x
log
p
t
(
x
)
d
t
s_\theta(x,t) = \nabla_x \log p_t(x)dt
sθ(x,t)=∇xlogpt(x)dt,公式(7)进而转换为
d
x
=
−
σ
2
t
s
θ
(
x
,
t
)
d
t
+
σ
t
d
w
ˉ
(8)
dx=-\sigma^{2t} s_\theta(x,t)dt+\sigma^td\bar{w} \tag8
dx=−σ2tsθ(x,t)dt+σtdwˉ(8)
在加噪过程中
p
t
=
0
p_{t=0}
pt=0对应原始数据分布,
p
t
=
1
p_{t=1}
pt=1对应加噪结束后的先验分布,前文已推导出该先验分布,即公式(6)。逆SDE过程先从先验分布中随机采样一个样本,然后使用Euler-Maruyama等数值方法求解逆SDE。该方法基于 SDE 的简单离散化,用
Δ
t
\Delta_t
Δt替代
d
t
dt
dt,用
z
∼
N
(
0
,
g
2
(
t
)
Δ
t
I
)
z \sim N(0,g^2(t)\Delta{t}I)
z∼N(0,g2(t)ΔtI)替代
d
ω
d\omega
dω。当应用于公式(7)中可得到以下迭代采样公式
x
t
−
Δ
t
=
x
t
+
σ
2
t
s
θ
(
x
,
t
)
Δ
t
+
σ
t
Δ
t
z
t
(9)
x_{t-\Delta{t}} = x_t + \sigma^{2t} s_\theta(x,t)\Delta{t}+\sigma^t\sqrt{\Delta{t}}z_t \tag9
xt−Δt=xt+σ2tsθ(x,t)Δt+σtΔtzt(9)
其中
x
t
∈
N
(
0
,
I
)
x_t \in N(0,I)
xt∈N(0,I)。
Predictor-Corrector Sampler
上述欧拉采样速度较快,但是质量不高。既然已通过分数匹配方法训练得到了可估计分数的分数网络,故也可以运用一些基于分数模拟采样的方法从分布中采样,如朗之万动力学采样。将逆SDE的数值解法与模拟采样相结合,可以进一步提高采样质量。具体步骤是使用数值解法–欧拉采样先获得一个预测值,然后使用模拟采样–朗之万动力学采样进行校正;数值解法称为Predictor,模拟采样方法称为Corrector,此种融合方式称为Predictor-Corrector Sampler,或简称为PC-Sampler。
当分数已知时,基于分数的朗之万MCMC采样方法可从
p
(
x
)
p(x)
p(x)中生成样本,具体公式如下
x
i
+
1
=
x
i
+
ϵ
∇
x
log
p
(
x
i
)
+
2
ϵ
z
i
,
(10)
x_{i+1} = x_i + \epsilon\nabla_x \log p(x_i)+\sqrt{2\epsilon}z_i,\tag{10}
xi+1=xi+ϵ∇xlogp(xi)+2ϵzi,(10)
其中
z
i
∈
N
(
0
,
I
)
z_i \in N(0,I)
zi∈N(0,I),
ϵ
>
0
\epsilon > 0
ϵ>0是步长。
PC-Sampler的伪代码如下所示,采样过程还是很直观的。首先应用逆SDE的数值解法从
x
t
x_t
xt中获取
x
t
−
Δ
t
x_{t-\Delta{t}}
xt−Δt,此过程称为predictor步骤。再使用朗之万动力学采样精炼
x
t
−
Δ
t
x_{t-\Delta{t}}
xt−Δt,使得
x
t
−
Δ
t
x_{t-\Delta{t}}
xt−Δt成为分布
p
t
−
Δ
t
(
x
)
p_{t-\Delta{t}}(x)
pt−Δt(x)更准确的样本,此步骤称为corrector步骤,能降低数值解法中的误差。

回忆上一篇SDE文章中的退火朗之万动力学采样过程,也是一个嵌套for循环形式,外层循环是从大到小的噪声等级,内层循环是与公式(10)一致的朗之万动力学采样。想到此时,又是一阵豁然开朗。因为逆SDE的数值采样方法本质就是通过离散化的方式求解连续数值,而退火朗之万动力学采样中的不同加噪等级就是有限的、离散的加噪过程。
后续研究表明,DDPM和Score-based Generative Model均可融入到基于分数的SDE中,并且采样方法可以整合到PC-Sampler框架中
基于ODE数值采样
对于所有扩散过程,在求解逆SDE时,存在一个相应的确定性过程,其轨迹与SDE具有相同的边缘概率密度,即逆向求解的结果和逆SDE求解出来的样本服从相同分布。这个确定性过程如下,称为概率流常微分方程:
d
x
=
[
f
(
x
,
t
)
−
1
2
g
(
t
)
2
∇
x
log
p
t
(
x
)
]
d
t
(11)
dx=[f(x,t)-\frac{1}{2}g(t)^2\nabla_x \log p_t(x)]dt \tag{11}
dx=[f(x,t)−21g(t)2∇xlogpt(x)]dt(11)

以上示意图显示了概率流 ODE 的轨迹与 SDE 轨迹的不同,但仍然从相同的分布中采样。因此,也可以从
p
t
=
1
p_{t=1}
pt=1分布中的样本出发,以逆时间方向对ODE进行积分,最终可获得符合
p
t
=
0
p_{t=0}
pt=0的样本。
按照前文公式(1)中定义的SDE过程,其对应的ODE方程如下
d
x
=
−
1
2
σ
2
t
s
θ
(
x
,
t
)
d
t
(12)
dx=-\frac{1}{2}\sigma^{2t} s_\theta(x,t)dt \tag{12}
dx=−21σ2tsθ(x,t)dt(12)
基于公式(12),使用一些数值解法就能进行样本采样。

2218

被折叠的 条评论
为什么被折叠?



