从扩散模型开始的生成模型范式演变--SDE(2)

部署运行你感兴趣的模型镜像

上一篇SDE文章中,从去噪分数匹配、分数网络出发,逐渐将SDE引入到扩散过程表征中,主要集中在原理讲解和推导中,本文则进一步补充能在扩散过程中实际使用的SDE公式实例。

一种SDE定义

相信很多读者自己学习或在看完上一篇文章后,对SDE实际如何使用情况还是不清楚。其实,因为SDE是随机微分方程,所以其与DDPM中加噪的方差其实都是人为设定的,相当于超参数。所以在训练前就需要将SDE具体形式定义好,在此将SDE定义为
d x = σ t d ω t ∈ [ 0 , 1 ] (1) dx=\sigma^t d\omega \quad t \in [0,1] \tag1 dx=σtdωt[0,1](1)
公式(1)中丢弃了SDE定义中的漂移部分,只设置了扩散部分。此时已将时间区间正则化,即范围为[0,1],也知道了 d x dx dx随时间变化的公式,因为当前加噪是连续过程,故
x ( t ) = x ( 0 ) + ∫ 0 t d x = x ( 0 ) + ∫ 0 t σ s d ω ( s ) (2) x(t) = x(0) + \int_0^tdx=x(0)+\int_0^t\sigma^sd\omega(s) \tag2 x(t)=x(0)+0tdx=x(0)+0tσsdω(s)(2)
我们知道 ω ( s ) \omega(s) ω(s)是标准布朗运动,其符合高斯分布 N ( 0 , t I ) N(0,tI) N(0,tI);而 ∫ 0 t σ s d ω ( s ) \int_0^t\sigma^sd\omega(s) 0tσsdω(s)是一个随机积分,在时间区间[0,t]范围内的一个加权和;根据随机微分方程性质, ∫ 0 t σ s d ω ( s ) \int_0^t\sigma^sd\omega(s) 0tσsdω(s)也服从高斯分布,其均值为0,方差计算如下:

V a r [ ∫ 0 t σ s d ω ( s ) ] = ∫ 0 t σ 2 s d s = 1 2 log ⁡ σ σ 2 s ∣ 0 t = 1 2 log ⁡ σ ( σ 2 t − 1 ) \begin{align*} Var[\int_0^t\sigma^sd\omega(s)] &= \int_0^t\sigma^{2s}ds \\ & = \frac{1}{2 \log \sigma}\sigma^{2s}|^t_0 \\ & = \frac{1}{2 \log \sigma}(\sigma^{2t}-1) \tag3 \end{align*} Var[0tσsdω(s)]=0tσ2sds=2logσ1σ2s0t=2logσ1(σ2t1)(3)
∫ 0 t σ s d ω ( s ) ∼ N ( 0 , 1 2 log ⁡ σ ( σ 2 t − 1 ) I ) \int_0^t\sigma^sd\omega(s) \sim N(0, \frac{1}{2 \log \sigma}(\sigma^{2t}-1)I) 0tσsdω(s)N(0,2logσ1(σ2t1)I)
因为 ω ( s ) \omega(s) ω(s)符合高斯分布, p 0 t ( x ( t ) ∣ x ( 0 ) ) p_{0t}(x(t)|x(0)) p0t(x(t)x(0))也符合高斯分布,整合上述推导有
p 0 t ( x ( t ) ∣ x ( 0 ) ) = N ( x ( t ) ; x ( 0 ) , 1 2 log ⁡ σ ( σ 2 t − 1 ) I ) (4) p_{0t}(x(t)|x(0)) = N(x(t);x(0),\frac{1}{2 \log \sigma}(\sigma^{2t}-1)I) \tag4 p0t(x(t)x(0))=N(x(t);x(0),2logσ1(σ2t1)I)(4)
在上一篇SDE文档中提到过,当 p 0 t ( x ( t ) ∣ x ( 0 ) ) p_{0t}(x(t)|x(0)) p0t(x(t)x(0))为高斯分布时, λ ( t ) = 方差 = 1 2 log ⁡ σ ( σ 2 t − 1 ) \lambda(t)=方差=\frac{1}{2 \log \sigma}(\sigma^{2t}-1) λ(t)=方差=2logσ1(σ2t1)

p 0 t p_{0t} p0t是一个联合分布, p t = 0 p_{t=0} pt=0 p t = 1 p_{t=1} pt=1属于边缘分布,因为是连续概率, p t = 1 p_{t=1} pt=1概率等于 p t = 0 p_{t=0} pt=0分布下所有值的概率积分,重新用 y , x y,x y,x分别表示 x ( 0 ) , x ( t ) x(0),x(t) x(0),x(t) p t = 1 p_{t=1} pt=1的计算公式如下
p t = 1 = ∫ p 0 ( y ) N ( x ; y , 1 2 log ⁡ σ ( σ 2 − 1 ) I ) d y (5) p_{t=1}=\int p_0(y)N(x;y,\frac{1}{2 \log \sigma}(\sigma^2-1)I)dy \tag5 pt=1=p0(y)N(x;y,2logσ1(σ21)I)dy(5)
因为 t = 1 t=1 t=1,所以公式(5)中的分布为 N ( x ; y , 1 2 log ⁡ σ ( σ 2 − 1 ) I ) N(x;y,\frac{1}{2 \log \sigma}(\sigma^2-1)I) N(x;y,2logσ1(σ21)I);当 σ \sigma σ很大时,该分布中的均值y也可以将其看是趋近于0。故 N ( x ; y , 1 2 log ⁡ σ ( σ 2 − 1 ) I ) ≈ N ( x ; 0 , 1 2 log ⁡ σ ( σ 2 − 1 ) I ) N(x;y,\frac{1}{2 \log \sigma}(\sigma^2-1)I) \approx N(x;0,\frac{1}{2 \log \sigma}(\sigma^2-1)I) N(x;y,2logσ1(σ21)I)N(x;0,2logσ1(σ21)I)。将其带入公式(5)有
p t = 1 ≈ ∫ p 0 ( y ) N ( x ; 0 , 1 2 log ⁡ σ ( σ 2 − 1 ) I ) d y = N ( x ; 0 , 1 2 log ⁡ σ ( σ 2 − 1 ) I ) ∫ p 0 ( y ) d y = N ( x ; 0 , 1 2 log ⁡ σ ( σ 2 − 1 ) I ) \begin{align*} p_{t=1} &\approx \int p_0(y)N(x;0,\frac{1}{2 \log \sigma}(\sigma^2-1)I)dy \\ & = N(x;0,\frac{1}{2 \log \sigma}(\sigma^2-1)I)\int p_0(y)dy \\ & = N(x;0,\frac{1}{2 \log \sigma}(\sigma^2-1)I) \tag6 \end{align*} pt=1p0(y)N(x;0,2logσ1(σ21)I)dy=N(x;0,2logσ1(σ21)I)p0(y)dy=N(x;0,2logσ1(σ21)I)(6)
上述大费周章推动出公式(6)的原因是 p t = 1 p_{t=1} pt=1是采样起始的先验分布,并且公式(6)中的 p t = 1 p_{t=1} pt=1符合高斯分布,也与数据分布无关。在进行逆SDE过程求解时,可基于公式(6)采样初始值。

欧拉数值采样

常规SDE公式定义为 d x = f ( x , t ) d t + g ( t ) d w dx=f(x,t)dt+g(t)dw dx=f(x,t)dt+g(t)dw,其对应的逆SDE公式为 d x = [ f ( x , t ) − g 2 ( t ) ∇ x log ⁡ p t ( x ) ] d t + g ( t ) d w ˉ dx=[f(x,t)-g^2(t)\nabla_x \log p_t(x)]dt+g(t)d\bar{w} dx=[f(x,t)g2(t)xlogpt(x)]dt+g(t)dwˉ,当前人为定义的SDE为公式(1),则其对应的逆SDE公式如下
d x = − σ 2 t ∇ x log ⁡ p t ( x ) d t + σ t d w ˉ (7) dx=-\sigma^{2t} \nabla_x \log p_t(x)dt+\sigma^td\bar{w} \tag7 dx=σ2txlogpt(x)dt+σtdwˉ(7)
训练过程使得分数网络 s θ ( x , t ) s_\theta(x,t) sθ(x,t)近似分数函数,即 s θ ( x , t ) = ∇ x log ⁡ p t ( x ) d t s_\theta(x,t) = \nabla_x \log p_t(x)dt sθ(x,t)=xlogpt(x)dt,公式(7)进而转换为
d x = − σ 2 t s θ ( x , t ) d t + σ t d w ˉ (8) dx=-\sigma^{2t} s_\theta(x,t)dt+\sigma^td\bar{w} \tag8 dx=σ2tsθ(x,t)dt+σtdwˉ(8)
在加噪过程中 p t = 0 p_{t=0} pt=0对应原始数据分布, p t = 1 p_{t=1} pt=1对应加噪结束后的先验分布,前文已推导出该先验分布,即公式(6)。逆SDE过程先从先验分布中随机采样一个样本,然后使用Euler-Maruyama等数值方法求解逆SDE。该方法基于 SDE 的简单离散化,用 Δ t \Delta_t Δt替代 d t dt dt,用 z ∼ N ( 0 , g 2 ( t ) Δ t I ) z \sim N(0,g^2(t)\Delta{t}I) zN(0,g2(t)ΔtI)替代 d ω d\omega dω。当应用于公式(7)中可得到以下迭代采样公式
x t − Δ t = x t + σ 2 t s θ ( x , t ) Δ t + σ t Δ t z t (9) x_{t-\Delta{t}} = x_t + \sigma^{2t} s_\theta(x,t)\Delta{t}+\sigma^t\sqrt{\Delta{t}}z_t \tag9 xtΔt=xt+σ2tsθ(x,t)Δt+σtΔt zt(9)
其中 x t ∈ N ( 0 , I ) x_t \in N(0,I) xtN(0,I)

Predictor-Corrector Sampler

上述欧拉采样速度较快,但是质量不高。既然已通过分数匹配方法训练得到了可估计分数的分数网络,故也可以运用一些基于分数模拟采样的方法从分布中采样,如朗之万动力学采样。将逆SDE的数值解法与模拟采样相结合,可以进一步提高采样质量。具体步骤是使用数值解法–欧拉采样先获得一个预测值,然后使用模拟采样–朗之万动力学采样进行校正;数值解法称为Predictor,模拟采样方法称为Corrector,此种融合方式称为Predictor-Corrector Sampler,或简称为PC-Sampler。

当分数已知时,基于分数的朗之万MCMC采样方法可从 p ( x ) p(x) p(x)中生成样本,具体公式如下
x i + 1 = x i + ϵ ∇ x log ⁡ p ( x i ) + 2 ϵ z i , (10) x_{i+1} = x_i + \epsilon\nabla_x \log p(x_i)+\sqrt{2\epsilon}z_i,\tag{10} xi+1=xi+ϵxlogp(xi)+2ϵ zi,(10)
其中 z i ∈ N ( 0 , I ) z_i \in N(0,I) ziN(0,I) ϵ > 0 \epsilon > 0 ϵ>0是步长。

PC-Sampler的伪代码如下所示,采样过程还是很直观的。首先应用逆SDE的数值解法从 x t x_t xt中获取 x t − Δ t x_{t-\Delta{t}} xtΔt,此过程称为predictor步骤。再使用朗之万动力学采样精炼 x t − Δ t x_{t-\Delta{t}} xtΔt,使得 x t − Δ t x_{t-\Delta{t}} xtΔt成为分布 p t − Δ t ( x ) p_{t-\Delta{t}}(x) ptΔt(x)更准确的样本,此步骤称为corrector步骤,能降低数值解法中的误差。
在这里插入图片描述
回忆上一篇SDE文章中的退火朗之万动力学采样过程,也是一个嵌套for循环形式,外层循环是从大到小的噪声等级,内层循环是与公式(10)一致的朗之万动力学采样。想到此时,又是一阵豁然开朗。因为逆SDE的数值采样方法本质就是通过离散化的方式求解连续数值,而退火朗之万动力学采样中的不同加噪等级就是有限的、离散的加噪过程。

后续研究表明,DDPM和Score-based Generative Model均可融入到基于分数的SDE中,并且采样方法可以整合到PC-Sampler框架中

基于ODE数值采样

对于所有扩散过程,在求解逆SDE时,存在一个相应的确定性过程,其轨迹与SDE具有相同的边缘概率密度,即逆向求解的结果和逆SDE求解出来的样本服从相同分布。这个确定性过程如下,称为概率流常微分方程:
d x = [ f ( x , t ) − 1 2 g ( t ) 2 ∇ x log ⁡ p t ( x ) ] d t (11) dx=[f(x,t)-\frac{1}{2}g(t)^2\nabla_x \log p_t(x)]dt \tag{11} dx=[f(x,t)21g(t)2xlogpt(x)]dt(11)
在这里插入图片描述
以上示意图显示了概率流 ODE 的轨迹与 SDE 轨迹的不同,但仍然从相同的分布中采样。因此,也可以从 p t = 1 p_{t=1} pt=1分布中的样本出发,以逆时间方向对ODE进行积分,最终可获得符合 p t = 0 p_{t=0} pt=0的样本。

按照前文公式(1)中定义的SDE过程,其对应的ODE方程如下
d x = − 1 2 σ 2 t s θ ( x , t ) d t (12) dx=-\frac{1}{2}\sigma^{2t} s_\theta(x,t)dt \tag{12} dx=21σ2tsθ(x,t)dt(12)
基于公式(12),使用一些数值解法就能进行样本采样。

您可能感兴趣的与本文相关的镜像

Qwen-Image-Edit-2509

Qwen-Image-Edit-2509

图片编辑
Qwen

Qwen-Image-Edit-2509 是阿里巴巴通义千问团队于2025年9月发布的最新图像编辑AI模型,主要支持多图编辑,包括“人物+人物”、“人物+商品”等组合玩法

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值