物理学视角讲解diffusion生成模型——混合高斯扩散模型

学习评分函数

想要通过逆向扩散从某个目标分布中抽样——其功能形式未知,我们只能通过抽样来学习——但这需要我们知道对应于目标分布的评分函数。知道评分函数,即这个分布对数的梯度,似乎等同于知道分布本身。我们如何学习评分函数呢?

定义评分学习的目标函数

首先,让我们写下一个合理的目标函数。假设我们有一些参数化的得分函数sθ(x,t)\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}, t)sθ(x,t),它依赖于一组参数θ\boldsymbol{\theta}θ。我们希望准确地近似所有x\mathbf{x}x值和所有t值的得分函数,因此我们可能尝试写下如下的目标函数:
J(θ):=?12∫dxdt [sθ(x,t)−∇xlog⁡p(x,t)]2 .J(\boldsymbol{\theta}) \stackrel{?}{:=} \frac{1}{2} \int d\mathbf{x} dt \ \left[ \mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}, t) - \nabla_{\mathbf{x}} \log p(\mathbf{x}, t) \right]^2 \ .J(θ):=?21dxdt [sθ(x,t)xlogp(x,t)]2 .
这个目标函数的问题在于它没有优先考虑任何特定的x\mathbf{x}x值。我们特别感兴趣的是对高概率值的得分函数进行准确的近似,因此对上述目标函数的一个合理修改是:
J(θ):=?12∫dxdt p(x,t) [sθ(x,t)−∇xlog⁡p(x,t)]2 .J(\boldsymbol{\theta}) \stackrel{?}{:=} \frac{1}{2} \int d\mathbf{x} dt \ p(\mathbf{x}, t) \ \left[ \mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}, t) - \nabla_{\mathbf{x}} \log p(\mathbf{x}, t) \right]^2 \ .J(θ):=?21dxdt p(x,t) [sθ(x,t)xlogp(x,t)]2 .
类似地,我们可能会考虑添加一个不同时间的权重因子,因为得分函数偏离精确值的规模可能随时间变化:
Jnaive(θ):=12∫dxdt λ(t) p(x,t) [sθ(x,t)−∇xlog⁡p(x,t)]2 .J_{naive}(\boldsymbol{\theta}) := \frac{1}{2} \int d\mathbf{x} dt \ \lambda(t) \ p(\mathbf{x}, t) \ \left[ \mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}, t) - \nabla_{\mathbf{x}} \log p(\mathbf{x}, t) \right]^2 \ .Jnaive(θ):=21dxdt λ(t) p(x,t) [sθ(x,t)xlogp(x,t)]2 .
这是一个完全合理的目标函数。但我们有一个重要问题:即很难估计p(x,t)p(\mathbf{x}, t)p(x,t)的对数的梯度,因为p(x,t)p(\mathbf{x}, t)p(x,t)可能强烈依赖于p(x,0)p(\mathbf{x}, 0)p(x,0)(即我们的目标分布)。而我们不知道我们的目标分布,这是我们做所有这些的原因!
此时,我们可以使用一个有趣的技巧。尽管上述目标函数相当合理,但它太难以处理;技巧是找到一个具有相同全局最小值的替代目标函数。这由下式提供:
Jmod(θ):=12∫dxdx(0)dt p(x,t∣x(0),0)p(x(0)) [sθ(x,t)−∇xlog⁡p(x,t∣x(0),0)]2 .J_{mod}(\boldsymbol{\theta}) := \frac{1}{2} \int d\mathbf{x} d\mathbf{x}^{(0)} dt \ p(\mathbf{x}, t | \mathbf{x}^{(0)}, 0) p(\mathbf{x}^{(0)}) \ \left[ \mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}, t) - \nabla_{\mathbf{x}} \log p(\mathbf{x}, t | \mathbf{x}^{(0)}, 0) \right]^2 \ .Jmod(θ):=21dxdx(0)dt p(x,tx(0),0)p(x(0)) [sθ(x,t)xlogp(x,tx(0),0)]2 .
注意到:
KaTeX parse error: {split} can be used only in display mode.
KaTeX parse error: {split} can be used only in display mode.
实际上,我们已经显示了更强的结果:这两个目标函数作为θ\boldsymbol{\theta}θ的函数是相同的,仅相差一个加法常数。
现在我们的目标函数涉及到估计过渡概率对数的梯度,这通常可以通过我们对前向随机过程的了解解析地计算,因此是可用的。
我们将采用的目标函数是最后一个(我们将去掉“mod”下标,以赋予它额外的重要性):
KaTeX parse error: {split} can be used only in display mode.

使用样本近似目标函数

将目标函数或损失函数用期望值来表示的一个好处是,这提示了一种使用样本来近似它的清晰策略:我们可以采取蒙特卡洛类型的方法。
给定一个来自我们目标分布的样本 x(0)\mathbf{x}^{(0)}x(0),我们可以做以下事情:

  1. 从 [0,T] 中均匀抽取一个时间 t。
  2. 利用我们对转移概率的了解,抽取 x∼p(x,t∣x(0),0)\mathbf{x} \sim p(\mathbf{x}, t | \mathbf{x}^{(0)}, 0)xp(x,tx(0),0)
  3. 利用我们对转移概率的了解,计算我们样本的 ∇xlog⁡p(x,t∣x(0),0)\nabla_{\mathbf{x}} \log p(\mathbf{x}, t | \mathbf{x}^{(0)}, 0)xlogp(x,tx(0),0)

然后我们就得到了一个近似值
J(θ)≈12λ(t)[ sθ(x,t)−∇xlog⁡p(x,t∣x(0),0) ]2 .J(\boldsymbol{\theta}) \approx \frac{1}{2} \lambda(t) \left[ \ \mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}, t) - \nabla_{\mathbf{x}} \log p(\mathbf{x}, t | \mathbf{x}^{(0)}, 0) \ \right]^2 \ .J(θ)21λ(t)[ sθ(x,t)xlogp(x,tx(0),0) ]2 .
更一般地,如果我们有一批 S 个样本,我们可以对每一个样本都遵循这个程序来构建近似值
J(θ)≈12S∑j=1Sλ(tj)[ sθ(xj,tj)−∇xlog⁡p(xj,t∣xj(0),0) ]2 .J(\boldsymbol{\theta}) \approx \frac{1}{2 S} \sum_{j = 1}^S \lambda(t_j) \left[ \ \mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}_j, t_j) - \nabla_{\mathbf{x}} \log p(\mathbf{x}_j, t | \mathbf{x}^{(0)}_j, 0) \ \right]^2 \ .J(θ)2S1j=1Sλ(tj)[ sθ(xj,tj)xlogp(xj,txj(0),0) ]2 .
幸运的是,对于我们将要使用的解析可行的前向过程,过渡概率的对数通常具有特别简单的形式。例如,对于 VE SDE(见第2节),其定义是通过
x˙=d[σ2(t)]dt η(t) ,\dot{\mathbf{x}} = \sqrt{ \frac{d[ \sigma^2(t) ]}{dt} } \ \boldsymbol{\eta}(t) \ ,x˙=dtd[σ2(t)]  η(t) ,
相应的转移概率是
KaTeX parse error: {split} can be used only in display mode.
所以过渡概率的对数的梯度是
∇xlog⁡p(x,t∣x(0),0)=−[x−x(0)]σ2(t) .\nabla_{\mathbf{x}} \log p(\mathbf{x}, t | \mathbf{x}^{(0)}, 0) = - \frac{\left[ \mathbf{x} - \mathbf{x}^{(0)} \right]}{\sigma^2(t)} \ .xlogp(x,tx(0),0)=

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值