基于分数的随机微分方程生成模型

基于分数的随机微分方程生成模型

Score-based generative modeling through stochastic differential equations

摘要

从数据中产生噪声很容易;从噪声中创建数据是生成性建模。我们提出了一个随机微分方程(SDE),它通过缓慢地注入噪声将复杂的数据分布平滑地转换为已知的先验分布,以及一个相应的反向时间SDE,它通过慢慢地去除噪声将先验分布转换回数据分布。至关重要的是,反向时间SDE仅取决于扰动数据分布的时间依赖梯度场(即分数)。通过利用基于分数的生成模型的进步,我们可以使用神经网络准确估计这些分数,并使用数值SDE求解器生成样本。我们表明,该框架封装了基于分数的生成建模和扩散概率建模中的先前方法,允许新的采样过程和新的建模能力。特别地,我们引入了一个预测器-校正器框架来校正离散化反向时间SDE演化中的误差。我们还导出了一个等效的神经ODE,它从与SDE相同的分布中采样,但另外还实现了精确的似然计算,并提高了采样效率。此外,我们提供了一种使用基于分数的模型来解决逆问题的新方法,如在类条件生成、图像修复和着色方面的实验所证明的。结合多个架构改进,我们在CIFAR-10上实现了创纪录的无条件图像生成性能,初始分数为9.89,FID为2.20,竞争可能性为2.99 bits/dim,并首次展示了基于分数的生成模型中1024×1024图像的高保真生成。

1简介

两类成功的概率生成模型涉及用缓慢增加的噪声顺序地破坏训练数据,然后学习扭转这种破坏,以形成数据的生成模型。Score matching with Langevin dynamics(与朗之万动力学的得分匹配) (SMLD)(Song&Ermon,2019)估计每个噪声尺度上的分数(即,对数概率密度相对于数据的梯度),然后使用Langevin动态从生成期间的一系列降低的噪声尺度中采样。去噪扩散概率建模(DDPM)(Sohl Dickstein等人,2015;Ho等人,2020)训练一系列概率模型,以逆转噪声破坏的每一步,使用反向分布的函数形式的知识,使训练变得易于处理。对于连续状态空间,DDPM训练目标隐式地计算每个噪声尺度的分数。因此,我们将这两个模型类统称为基于分数的生成模型。基于分数的生成模型和相关技术(Bordes等人,2017;Goyal等人,2017年;Du和Moratch,2019)已被证明在生成图像(Song&Ermon,2019;2020;Ho等人,2020)、音频(Chen等人,2020;Kong等人,2019)、图形(Niu等人,2020年)和形状(Cai等人,2020。为了实现新的采样方法并进一步扩展基于分数的生成模型的能力,我们提出了一个统一的框架,通过随机微分方程(SDE)的视角推广了以前的方法。具体地说,我们考虑的是一个连续的分布,它根据扩散过程随时间演变,而不是用有限数量的噪声分布来扰动数据。该过程将数据点逐渐扩散为随机噪声,并由不依赖于数据且无可训练参数的规定SDE给出。通过反转此过程,我们可以将随机噪声平滑地模制到数据中以生成样本。至关重要的是,该反向过程满足反向时间SDE(Anderson,1982),该SDE可以从给定作为时间函数的边际概率密度分数的正向SDE中导出。因此,我们可以通过训练时间相关的神经网络来估计分数,然后使用数值SDE求解器来生成样本,从而近似逆时间SDE。我们的主要想法如图1所示。我们提出的框架有几个理论和实践贡献:

灵活的采样和似然计算:我们可以使用任何通用SDE解算器来集成采样的反向时间SDE。此外,我们提出了两种对一般SDE不可行的特殊方法:(i)预测校正器(PC)采样器,将数值SDE求解器与基于分数的MCMC方法相结合,例如Langevin MCMC(Parisi,1981)和HMC(Neal等人,2011);以及(ii)基于概率流常微分方程(ODE)的确定性采样器。前者对基于分数的模型的现有采样方法进行了统一和改进。后者允许通过黑箱ODE解算器进行快速自适应采样,通过潜在代码进行灵活的数据处理,唯一可识别的编码,尤其是精确的似然计算。

可控生成:我们可以通过调节训练期间不可用的信息来调节生成过程,因为可以根据无条件分数有效地估计条件反转时间SDE。这使得类条件生成、图像修复、着色和其他逆问题等应用成为可能,所有这些都可以使用单个无条件的基于分数的模型来实现,而无需重新训练。

统一框架:我们的框架提供了一种统一的方式来探索和调整各种SDE,以改进基于分数的生成模型SMLD和DDPM的方法可以合并到我们的框架中,作为两个独立SDE的离散化。尽管DDPM(Ho等人,2020)最近被报道实现了比SMLD更高的样本质量(Song&Ermon,2019;2020),但我们表明,随着我们的框架允许的更好的架构和新的采样算法,后者可以赶上它,它在CIFAR-10上实现了新的最先进的初始分数(9.89)和FID分数(2.20),以及首次从基于分数的模型中高保真地生成1024色1024图像。此外,我们在我们的框架下提出了一种新的SDE,该SDE在均匀去量化的CIFAR-10图像上实现了2.99 bits/dim的似然值,创下了这项任务的新纪录。

2背景

2.1郎之万动力学的去噪分数匹配(SMLD)

2.2去噪扩散概率模型(DDPM)

3基于分数的SDES生成建模

具有多个噪声尺度的扰动数据是先前方法成功的关键。我们建议将这一想法进一步推广到无限数量的噪声尺度,这样,当噪声增强时,扰动数据分布会根据SDE演变。我们的框架概述如图2所示。

 图2:通过SDE进行的基于分数的生成建模概述。我们可以使用SDE(第3.1节)将数据映射到噪声分布(先验分布),并反转该SDE进行生成建模(第3.2节)。我们还可以反转相关的概率流ODE(第4.3节),这产生了一个确定性过程,该过程从与SDE相同的分布中采样。反向时间SDE和概率流 ODE都可以通过估计得分来获得(第3.3节)。

 3.1使用SDES干扰数据

 3.2通过反转SDE生成样本

3.3估计SDE的分数

 3.4例:VE、VP SDES和BEYOND

 

4解决反向SDE

训练了基于时间依赖分数的模型sθ之后,我们可以使用它来构建逆时间SDE,然后用数值方法对其进行模拟,以从p0生成样本。

4.1通用数字SDE求解器

数值求解器提供SDE的近似轨迹。存在许多用于求解SDE的通用数值方法,如Euler Maruyama和随机Runge-Kutta方法(Kloe-den&Platen,2013),它们对应于随机动力学的不同离散化。我们可以将它们中的任何一个应用于反向时间SDE以生成样本。祖先采样,即DDPM的采样方法(等式(4)),实际上对应于反向时间VP SDE的一个特殊离散化(等式(11))(见附录E)。然而,推导新SDE的祖先采样规则可能并非易事。为了弥补这一点,我们提出了反向扩散采样器(详见附录E),它以与正向离散化相同的方式离散化反向时间SDE,因此可以很容易地在给定正向离散化的情况下导出。如表1所示,对于CIFAR-10上的SMLD和DDPM模型,反向扩散采样器的性能略好于祖先采样(DDPM型祖先采样也适用于SMLD模型,见附录F)

4.2预测校正采样器

与一般SDE不同,我们有可用于改进解决方案的附加信息。由于我们有一个基于分数的模型,我们可以使用基于分数的MCMC方法,如Langevin MCMC(Parisi,1981;Grenander和Miller,1994)或HMC(Neal et al.,2011),直接从pt中采样,并校正数值SDE解算器的解

具体而言,在每个时间步,数值SDE解算器首先在下一时间步给出样本的估计,扮演“预测者”的角色。然后,基于分数的MCMC方法校正估计样本的边际分布,起到“校正器”的作用。这个想法类似于预测-校正方法,这是一系列用于求解方程组的数值连续技术(Allgower&Georg,2012),我们将我们的混合采样算法命名为预测-校正(PC)采样器。请在附录G中找到伪代码和完整的描述。PC采样器概括了SMLD和DDPM的原始采样方法:前者使用恒等函数作为预测器,退火Langevin动力学作为校正器,而后者使用祖先采样作为预测器和恒等式作为校正器。

我们在用等式(1)和(3)给出的原始离散目标训练的SMLD和DDPM模型(见附录G中的算法2和3)上测试PC采样器。这显示了PC采样器与用固定数量的噪声尺度训练的基于分数的模型的兼容性。我们在表1中总结了不同采样器的性能,其中概率流是第4.3节中讨论的预测器。详细的实验设置和额外的结果在附录G中给出。我们观察到,我们的反向扩散采样器总是优于祖先采样,和仅校正器的方法(C2000)在相同计算下的性能比其他竞争对手(P2000、PC1000)差(事实上,我们需要每个噪声尺度更多的校正器步骤,从而需要更多的计算,以匹配其他采样器的性能。此外,它通常比在不添加校正器的情况下将预测步长增加一倍(P2000)要好,因为我们必须以特殊方式(详见附录G)在SMLD/DPM模型的噪声尺度之间进行插值。在图9(附录G)中,我们还提供了用连续目标方程训练的模型的定性比较。(7)在256色256 LSUN图像和VE SDE上,当使用适当数量的校正器步骤时,在可比计算下,PC采样器明显超过仅预测采样器。

 表1:比较CIFAR-10上不同的反向时间SDE解算器。着色区域是通过相同的计算(分数函数评估的数量)获得的。平均值和标准偏差是在五次抽样中报告的。“P1000”或“P2000”:仅使用1000或2000步的预测器采样器。“C2000”:仅校正器采样器,使用2000步。“PC1000”:预测器-校正器(PC)采样器,使用1000个预测器和1000个校正器步骤。

4.3概率流和与神经节点的连接

基于分数的模型实现了求解逆时间SDE的另一种数值方法。对于所有扩散过程,存在一个相应的确定性过程,其轨迹与SDE共享相同的边际概率密度。该确定性过程满足ODE(更多细节见附录D.1):

一旦得分已知,就可以从SDE中确定。我们将方程(13)中的ODE命名为概率流ODE(常微分方程)当分数函数由基于时间的分数的模型(通常是神经网络)近似时,这是神经ODE的一个例子(Chen et al.,2018)。

精确似然计算利用与神经常微分方程的连接,我们可以计算方程(13)定义的密度。通过变量的瞬时变化公式。这使我们能够计算任何输入数据的准确似然性(详见附录D.2)。例如,我们在表2中报告了CIFAR-10数据集上以位/分为单位测量的负对数似然性(NLL)。我们在均匀去量化数据上计算对数似然,并且只与以相同方式评估的模型进行比较(省略了用变分去量化评估的模型或离散数据),除了DDPM(L/Lsimple),其ELBO值(用*注释)在离散数据上报告。主要结果:

  1. 对于Ho等人(2020)中的相同DDPM模型,我们获得了比ELBO更好的bits/dim,因为我们的可能性是准确的;
  2. 使用相同的架构,我们训练了另一个具有等式7中连续目标的DDPM模型。(即,DDPM续),这进一步提高了可能性;
  3. 对于sub-VP SDE,与VP SDE相比,我们总是获得更高的可能性;
  4. 通过改进的架构(即,DDPM++续,第4.4节中的详细信息)和sub-VP SDE,即使没有最大似然训练,我们也可以在均匀去量化的CIFAR-10上设置2.99的新记录bits/dim。

操纵潜在表示通过积分方程(13),我们可以将任何数据点x(0)编码到潜在空间x (T)中。解码可以通过对反向时间SDE的对应ODE进行积分来实现。正如其他可逆模型所做的那样,如神经常微分方程和归一化流(Dinh et al.,2016;Kingma和Dhariwal,2018),我们可以操纵这种潜在的表示用于图像编辑,如插值和温度缩放(见图3和附录D.4)。

唯一可识别编码与大多数当前的可逆模型不同,我们的编码是唯一可识别的,这意味着在有足够的训练数据、模型容量和优化精度的情况下,输入的编码由数据分布唯一确定(Roeder et al.,2020)。这是因为我们的前向SDE方程(5)没有可训练的参数,并且其相关的概率流ODE方程(13)在给出完美估计的分数的情况下提供了相同的轨迹。我们在附录D.5中提供了关于这一性质的额外经验验证。

有效采样与神经ODE一样,我们可以通过从不同的最终条件x pTq?pT中解方程(13)来采样xp0q?p0。使用固定离散化策略,我们可以生成竞争性样本,尤其是与校正器结合使用时(表1,“概率流采样器”,详细信息见附录D.3)。使用黑盒ODE解算器(Dormand&Prince,1980)不仅可以生成高质量样本(表2,详细信息详见附录D.4),还可以明确地权衡准确性和效率。在较大的误差容限下,函数评估的数量可以减少90%以上,而不会影响样本的视觉质量(图3)。

图3:概率流ODE支持具有自适应步长的快速采样(左),并在不损害质量的情况下减少了分数函数评估的数量(NFE)。从延迟对象到图像的可逆映射允许进行插值(右)

4.4架构改进

我们探索了使用VE和VP SDE的基于分数的模型的几种新架构设计(详细信息见附录H),其中我们训练具有与SMLD/DPM中相同离散目标的模型。由于VP SDE的相似性,我们直接将其架构转移到子VP SDE。我们的VE SDE最优架构(称为NCSN++)在带有PC采样器的CIFAR-10上实现了2.45的FID,而我们的VP SDE最优体系结构(称为DDPM++)实现了2.78。

通过切换到方程中的连续训练目标。(7),并增加网络深度,我们可以进一步提高所有模型的样本质量。在表3中,VE和VP/子VP SDE的最终架构分别表示为NCSN++cont.和DDPM++cont。表3中报告的结果是针对训练过程中FID最小的检查点,其中使用PC采样器生成样本。相反,表2中的FID分数和NLL值是针对最后一个训练检查点报告的,并且使用黑箱ODE解算器获得样本。如表3所示,VE SDE通常比VP/su-VP SDE提供更好的样本质量,但我们也从经验上观察到,它们的可能性比VP/su-VP SDE的同行更差。这表明从业者可能需要针对不同的领域和体系结构对不同的SDE进行实验。

我们的样本质量最佳模型NCSN++cont.(deep,VE)将网络深度提高了一倍,并在无条件生成CIFAR-10时创下了初始得分和FID的新记录。令人惊讶的是,我们可以在不需要标记数据的情况下实现比以前最好的条件生成模型更好的FID。在所有改进的基础上,我们还从基于分数的模型中获得了CelebA HQ 1024的第一组高保真度样本(见附录H.3)。我们的最佳可能性模型DDPM++cont.(deep,sub-VP)类似地将网络深度增加了一倍,并实现了2.99 bits/dim的对数似然性,具有方程中的连续目标。(7)。据我们所知,这是统一去量化CIFAR-10的最高可能性。

5可控发电

我们的框架的连续结构使我们不仅可以从p0产生数据样本,而且如果ptpy|xptqq已知,还可以从p0pxp0q|yq产生数据样本。给定等式(5)中的正向SDE,我们可以通过从pT pxpTq|yq开始并求解条件反向时间SDE来从ptpxptq|yq进行采样:

一般来说,我们可以使用等式(14)来解决基于分数的生成模型的一大类逆问题,一旦给出了正向过程的梯度的估计值,就可以使用方程(14)|x log ptpy|x ptqq。在某些情况下,可以训练一个单独的模型来学习正向过程log pt p y | x ptqq并计算其梯度。否则,我们可以使用启发式和领域知识来估计梯度。在附录I.4中,我们提供了一种广泛适用的方法,用于在不需要训练辅助模型的情况下获得这样的估计。

图4:左图:32个ˆ32 CIFAR-10上的类条件样本。上面四排是汽车,下面四排是马。右:在256ˆ256 LSUN上进行着色(下两行)的结果。第一列是原始图像,第二列是掩蔽/灰度图像,其余列是采样的图像补全或着色。

我们考虑了该方法的三种可控生成应用:类条件生成、图像插补和彩色化。当y表示类标签时,我们可以训练用于类条件采样的时间依赖分类器ptpy|xptqq。由于前向SDE是可处理的,我们可以通过首先从数据集中采样pxp0q,yq,然后采样x ptq“p0tpxptq|xp0qq,轻松创建时间相关分类器的训练数据p xptq,yq。然后,我们可以在不同的时间步长上使用交叉熵损失的混合物,如等式。(7),以训练时间相关分类器ptpy|xptqq。我们在图4(左)中提供了类条件CIFAR-10样本,并将更多细节和结果放在附录I中。

插补是条件抽样的一种特殊情况。假设我们有一个不完整的数据点y,Ω pyq是已知的。ppxp0q采样的插补量|Ωpyqq,我们可以使用无条件模型完成(见附录I.2)。着色是插补的一种特殊情况,除了已知的数据维度是耦合的。我们可以通过正交线性变换将这些数据维度解耦,并在变换后的空间中进行插补(详见附录I.3)。图4(右)显示了使用无条件时间依赖的基于分数的模型实现的修复和着色结果。

6结论

我们提出了一个基于SDE的基于分数的生成建模框架。我们的工作使我们能够更好地理解现有方法、新的采样算法、精确似然计算、唯一可识别编码、潜在代码操作,并为基于分数的生成模型家族带来新的条件生成能力。虽然我们提出的采样方法改进了结果并实现了更有效的采样,但在相同的数据集上,它们的采样速度仍然比GANs(Goodfellow et al.,2014)慢。确定将基于分数的生成模型的稳定学习与GAN等隐式模型的快速采样相结合的方法仍然是一个重要的研究方向。此外,当允许访问评分函数时,可以使用的采样器的广度引入了许多超参数。未来的工作将受益于自动选择和调整这些超参数的改进方法,以及对各种采样器的优点和局限性进行更广泛的研究。

VE, VP AND SUB-VP SDES

下面我们提供了详细的推导,以表明SMLD和DDPM的噪声扰动分别是方差分解(VE)和方差保持(VP)SDE的离散化。我们还引入了、sub-VP SDE,这是对VP SDE的一种修改,通常在样本质量和可能性方面都能获得更好的性能。

首先,当使用总共N个噪声尺度时,SMLD的每个扰动核pσi (x|x0)可以从以下马尔可夫链导出:

 由f'(x)=lim(h->0)[(f(x+h)-f(x))/h] ,

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值