【深度学习】一文搞懂DDPM+代码实现

Computer drawing 计算机绘画

这个方法最早出现在2015年,知道2019年,OpenAI才把这个方法用到了图像生成领域。人们才发现这真的是个宝。

DDPM:Denoise Diffusion Probabilistic Model

我们拿画一幅画做例子。一般人画画都是先画一个草图,然后逐渐上颜色,然后填充细节,然后上颜色,然后不断修饰,不断润色,直至画稿完成。

计算机绝不是这么画的。计算机的画画是一个采样的过程。 好比我们打算画一个人,计算机会拿到关于这个人不同的照片,计算机认为这些不同的照片是从这个人分布中采样出来的一个个样本。而不同的人分布自然也不尽相同,就好比你的眉毛眼睛鼻子的轮廓都和别人不一样。

计算机如果想要画一个人,它得先学习,就从这些个分布里采出来的样本进行学习,从而炼出一个采样器(Sampler / Ramdom Number Generator),这个采样器能够产生出我要画的那个对象所对应的分布的样本。

  • 小结: Computer drawing ⇔ Ramdom Number Sampling \text{Computer drawing} \Leftrightarrow \text{Ramdom Number Sampling} Computer drawingRamdom Number Sampling

DDPM 直观理解

接下来我们来深入认识一下计算机如何训练这个采样器:

假如现在要学习的样本是一张山峦的图片。计算机怎么能够知道这个山峦是个什么分布呢?换句话说计算机怎么能够知道山峦对应的每个像素点灰度取值的概率呢?

Diffusion提供了这样一种思路:一个陌生的分布我们不好把握,但是高斯分布我们研究得非常清楚,能够很好地把握,对高斯进行采样简直不要太简单。 那么我能否在这个陌生分布和高斯分布之间架起一道桥梁,使得从桥梁的一端到桥的另一端能够畅行无阻? 一旦能够建立起这样的桥梁,我就可以在高斯分布的一端进行采样,然后把样本运送到桥梁的另一端变成这个陌生分布的样本。这样一来,我就得到了我想要的目标分布采样器。即“桥梁”本身和高斯分布采样器共同构成的这个陌生分布的采样器。

于是问题便转化成:这道桥梁怎么构建?

假设桥的左边是一个轮廓清晰的世界(山峦),桥的右边是个一片混沌的世界(高斯白噪声)。

从左到右,应该逐渐变得混沌,换句话说,应该逐步添加噪声。为什么要逐步?砖石需要一块一块铺,桥要一段一段地连接,分而治之永远是这个世界解决困难问题的一个普遍真理。

x 0 → + N 1 x 1 → + N 2 x 2 → + N 3 x 3 → . . . → + N T x T ∼ N ( 0 , σ ε I ) , ( T 足够大 ) x_0\stackrel{+N_1}{\to}x_1\stackrel{+N_2}{\to}x_2\stackrel{+N_3}{\to}x_3\to ...\stackrel{+N_T}{\to} x_T\sim N(0, \sigma^{\varepsilon} I), (T足够大) x0+N1x1+N2x2+N3x3...+NTxTN(0,σεI),(T足够大)

这一步,被称为直接扩散(Direct Diffusion),因为没有任何的神经网络。

下一步从右到左,我们前面所铺的这些“砖头”就起作用了。我们可以站在每一块“砖头”上用神经网络对我们的“桥梁”进行学习加固:神经网络沿着我们来时的路一步一砖头地走回去,学习该如何去噪声,也就是反向推理(Reverse Inference)

x 0 ⟵ N N . . . ⟵ N N x t − 1 ⟵ N N x t ⟵ N N . . . ⟵ N N x T ∼ N ( 0 , σ ε I ) x_0\stackrel{NN}{\longleftarrow}...\stackrel{NN}{\longleftarrow}x_{t-1}\stackrel{NN}{\longleftarrow} x_t\stackrel{NN}{\longleftarrow}...\stackrel{NN}{\longleftarrow}x_T\sim N(0, \sigma^{\varepsilon} I) x0NN...NNxt1NNxtNN...NNxTN(0,σεI)

通过不断地来回训练,神经网络最终会明白山是怎么一回事。

直观层面,我们认识了Diffusion是怎么一回事,从数学层面,我们还需要进行细致地剖析。

DDPM 数学建模与公式推导

1. 前向扩散过程(Forward Diffusion Process)

目标:通过逐步添加高斯噪声,将数据 x 0 \mathbf{x}_0 x0 逐渐破坏为纯噪声 x T \mathbf{x}_T xT

数学建模

  1. 单步噪声添加
    定义马尔可夫链,每一步的条件分布为:
    q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}\left(\mathbf{x}_t; \sqrt{1-\beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}\right) q(xtxt1)=N(xt;1βt xt1,βtI) 其中 β t ∈ ( 0 , 1 ) \beta_t \in (0,1) βt(0,1) 是噪声调度参数(逐渐增大)。

  2. 闭合解(直接计算 x t \mathbf{x}_t xt
    通过递归展开,得到从 ( \mathbf{x}_0 ) 到 ( \mathbf{x}_t ) 的闭合解:
    x t = α ˉ t x 0 + 1 − α ˉ t ϵ , ϵ ∼ N ( 0 , I ) \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I}) xt=αˉt x0+1αˉt ϵ,ϵN(0,I) 其中:

    • α t = 1 − β t \alpha_t = 1 - \beta_t αt=1βt
    • α ˉ t = ∏ s = 1 t α s \bar{\alpha}_t = \prod_{s=1}^t \alpha_s αˉt=s=1tαs

    推导
    通过一个参数trick,将多步噪声叠加合并为单步高斯分布:
    X t = α t X t − 1 + 1 − α t ε , ε ∼ N ( 0 , I ) X_t=\sqrt{\alpha_t}X_{t-1}+\sqrt{1-\alpha_t}\varepsilon, \quad \varepsilon \sim N(0,I) Xt=αt Xt1+1αt ε,εN(0,I)
    X t = α t X t − 1 + 1 − α t ε = α t ( α t − 1 X t − 2 + 1 − α t − 1 ε 1 ) + 1 − α t ε 2 = α t α t − 1 X t − 2 + α t 1 − α t ε 1 + 1 − α t ε 2 = α t α t − 1 X t − 2 + 1 − α t α t − 1 ε , ( ε 1 和 ε 2 独立同分布于 N ( 0 , I ) ) = . . . = ∏ k = 1 t α k X 0 + 1 − ∏ k = 1 t α k ε = α t ‾ X 0 + 1 − α t ‾ ε , ( ∏ k = 1 t α k 记作 α t ‾ ) \begin{aligned}X_t &=\sqrt{\alpha_t}X_{t-1}+\sqrt{1-\alpha_t}\varepsilon \\&=\sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}X_{t-2}+\sqrt{1-\alpha_{t-1}}\varepsilon_1 )+\sqrt{1-\alpha_t}\varepsilon_2\\&=\sqrt{\alpha_t \alpha_{t-1}}X_{t-2}+\sqrt{\alpha_t}\sqrt{1-\alpha_t}\varepsilon_1+\sqrt{1-\alpha_t}\varepsilon_2\\&=\sqrt{\alpha_t \alpha_{t-1}}X_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\varepsilon,\quad(\varepsilon_1和\varepsilon_2独立同分布于N(0,I))\\&=...\\&=\sqrt{\prod_{k=1}^{t}\alpha_k}X_0+\sqrt{1-\prod_{k=1}^t\alpha_k}\varepsilon \\&=\sqrt{\overline{\alpha_t}}X_0+\sqrt{1-\overline{\alpha_t}}\varepsilon ,\quad (\prod_{k=1}^t\alpha_k记作\overline{\alpha_t})\end{aligned} Xt=αt Xt1+1αt ε=αt (αt1 Xt2+1αt1 ε1)+1αt ε2=αtαt1 Xt2+αt 1αt ε1+1αt ε2=αtαt1 Xt2+1αtαt1 ε,(ε1ε2独立同分布于N(0,I))=...=k=1tαk X0+1k=1tαk ε=αt X0+1αt ε,(k=1tαk记作αt)

上述展开充分说明,通过精心设计的参数Trick,我能够从 X 0 X_0 X0一步跳到任意 X t X_t Xt,即这个分布我现在是清楚的: P ( X t ∣ X 0 ) P(X_t|X_0) P(XtX0),这为之后的Denoise做了充足的准备。


2. 反向去噪过程(Reverse Denoising Process)

目标:学习一个神经网络 ϵ θ \epsilon_\theta ϵθ,逐步从 x t \mathbf{x}_t xt 去噪生成 x t − 1 \mathbf{x}_{t-1} xt1

数学建模

  1. 真实反向分布 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) q(xt1xt,x0)
    通过贝叶斯定理推导:
    q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ t , β ~ t I ) q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}\left(\mathbf{x}_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t \mathbf{I}\right) q(xt1xt,x0)=N(xt1;μ~t,β~tI) 其中均值和方差为:
    μ ~ t = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 , β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \tilde{\mu}_t = \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0, \quad \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t μ~t=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0,β~t=1αˉt1αˉt1βt

  2. 参数化反向过程 p θ p_\theta pθ
    定义模型预测均值和方差:
    p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}\left(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_t, t), \Sigma_\theta(\mathbf{x}_t, t)\right) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))
    关键参数化技巧
    μ θ \mu_\theta μθ 表示为噪声预测形式:
    μ θ = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) \mu_\theta = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(\mathbf{x}_t, t) \right) μθ=αt 1(xt1αˉt βtϵθ(xt,t))
    方差通常设为固定值:
    Σ θ = σ t 2 I , σ t 2 = β t  或  1 − α ˉ t − 1 1 − α ˉ t β t \Sigma_\theta = \sigma_t^2 \mathbf{I}, \quad \sigma_t^2 = \beta_t \ \text{或} \ \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t Σθ=σt2I,σt2=βt  1αˉt1αˉt1βt
    详细推导过程(略长):
    P ( X t − 1 ∣ X t , X 0 ) = P ( X t ∣ X t − 1 , X 0 ) P ( X t − 1 ∣ X 0 ) P ( X t ∣ X 0 ) P(X_{t-1}|X_t,X_0)=P(X_t|X_{t-1},X_0)\frac{P(X_{t-1}|X_0)}{P(X_t|X_0)} P(Xt1Xt,X0)=P(XtXt1,X0)P(XtX0)P(Xt1X0)
    P ( X t ∣ X t − 1 , X 0 ) ∼ N ( α t X t − 1 , ( 1 − α t ) I ) P ( X t ∣ X 0 ) ∼ N ( α t ‾ X 0 , ( 1 − α t ‾ ) I ) P ( X t − 1 ∣ X 0 ) ∼ N ( α t − 1 ‾ X 0 , ( 1 − α t − 1 ‾ ) I ) \begin{aligned}&P(X_t|X_{t-1},X_0)\sim N(\sqrt{\alpha_t}X_{t-1},(1-\alpha_t)I)\\&P(X_t|X_0)\sim N(\sqrt{\overline{\alpha_t}}X_0, (1-\overline{\alpha_t})I)\\&P(X_{t-1}|X_0)\sim N(\sqrt{\overline{\alpha_{t-1}}}X_0, (1-\overline{\alpha_{t-1}})I) \end{aligned} P(XtXt1,X0)N(αt Xt1,(1αt)I)P(XtX0)N(αt X0,(1αt)I)P(Xt1X0)N(αt1 X0,(1αt1)I)接下来我们只需要把这三个高斯分布的概率密度展开计算即可得到关于 X t − 1 X_{t-1} Xt1的概率密度。注意由于是高斯分布,我只关心指数上方的那个二次型,并且为了书写简便,我会用一维来代替(比如协方差矩阵的逆写到分母上、分子的向量相乘写成平方;事实上这么做并不损失任何精神实质):
    − 1 2 ( ( X t − α t X t − 1 ) 2 1 − α t + ( X t − 1 − α t − 1 ‾ X 0 ) 2 1 − α t − 1 ‾ − ( ( X t − α t ‾ X 0 ) 2 1 − α t ‾ ) -\frac{1}{2}(\frac{(X_t -\sqrt{\alpha_t}X_{t-1})^2}{1-\alpha_t}+ \frac{(X_{t-1} -\sqrt{\overline{\alpha_{t-1}}}X_0)^2}{1-\overline{\alpha_{t-1}}}- (\frac{(X_t -\sqrt{\overline{\alpha_t}}X_{0})^2}{1-\overline{\alpha_t}}) 21(1αt(Xtαt Xt1)2+1αt1(Xt1αt1 X0)2(1αt(Xtαt X0)2)
    接下来我们只需展开做一个配方:注意,我们是对 X t − 1 X_{t-1} Xt1做的配方,因为 X t − 1 X_{t-1} Xt1才是主元,此时 X t 和 X 0 X_t和X_0 XtX0可以看作常数。
    − 1 2 ( X t 2 − 2 α t X t X t − 1 + α t X t − 1 2 1 − α t + X t − 1 2 − 2 α t − 1 ‾ X t − 1 X 0 + α t − 1 ‾ X 0 2 1 − α t − 1 ‾ − X t 2 − 2 α t ‾ X 0 X t + α t ‾ X 0 2 1 − α t ‾ ) = − 1 2 ( ( α t 1 − α t + 1 1 − α t − 1 ‾ ) X t − 1 2 − 2 X t − 1 ( α t 1 − α t X t + α t − 1 ‾ 1 − α t − 1 ‾ X 0 ) + C ( X 0 , X t ) ) -\frac{1}{2}(\frac{X_t^2 -2\sqrt{\alpha_t}X_tX_{t-1}+\alpha_tX_{t-1}^2}{1-\alpha_t}+\frac{X_{t-1}^2 -2\sqrt{\overline{\alpha_{t-1}}}X_{t-1}X_0+\overline{\alpha_{t-1}}X_0^2}{1-\overline{\alpha_{t-1}}}- \frac{X_t^2 -2\sqrt{\overline{\alpha_t}}X_{0}X_t+\overline{\alpha_t}X_0^2}{1-\overline{\alpha_t}})\\= -\frac{1}{2}((\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\overline{\alpha_{t-1}}})X_{t-1}^2-2X_{t-1}(\frac{\sqrt{\alpha_t}}{1-\alpha_t}X_t+\frac{\sqrt{\overline{\alpha_{t-1}}}}{1-\overline{\alpha_{t-1}}}X_0)+C(X_0,X_t)) 21(1αtXt22αt XtXt1+αtXt12+1αt1Xt122αt1 Xt1X0+αt1X021αtXt22αt X0Xt+αtX02)=21((1αtαt+1αt11)Xt122Xt1(1αtαt Xt+1αt1αt1 X0)+C(X0,Xt))至此,我们可以从上面这个二次型中直接写出 X t − 1 X_{t-1} Xt1方差 X t − 1 2 X_{t-1}^2 Xt12前面的系数的逆):
    V a r = ( α t 1 − α t + 1 1 − α t − 1 ‾ ) − 1 = ( α t ( 1 − α t − 1 ‾ ) + 1 − α t ( 1 − α t ) ( 1 − α t − 1 ‾ ) ) − 1 = ( 1 − α t ‾ ( 1 − α t ) ( 1 − α t − 1 ‾ ) ) − 1 = 1 − α t − 1 ‾ 1 − α t ‾ β \begin{aligned}Var&=(\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\overline{\alpha_{t-1}}})^{-1} \\[8pt]&=(\frac{\alpha_t(1-\overline{\alpha_{t-1}})+1-\alpha_t}{(1-\alpha_t)(1-\overline{\alpha_{t-1}})})^{-1}\\[8pt] &=(\frac{1-\overline{\alpha_t}}{(1-\alpha_t)(1-\overline{\alpha_{t-1}})})^{-1}\\&=\frac{1-\overline{\alpha_{t-1}}}{1-\overline{\alpha_t}}\beta\end{aligned} Var=(1αtαt+1αt11)1=((1αt)(1αt1)αt(1αt1)+1αt)1=((1αt)(1αt1)1αt)1=1αt1αt1β均值
    M e a n = α t 1 − α t X t + α t − 1 ‾ 1 − α t − 1 ‾ X 0 1 − α t ‾ ( 1 − α t ) ( 1 − α t − 1 ‾ ) = ( 1 − α t − 1 ‾ ) α t 1 − α t ‾ X t + ( 1 − α t ) α t − 1 ‾ 1 − α t ‾ X 0 \begin{aligned}Mean=&\frac{\frac{\sqrt{\alpha_t}}{1-\alpha_t}X_t+\frac{\sqrt{\overline{\alpha_{t-1}}}}{1-\overline{\alpha_{t-1}}}X_0}{\frac{1-\overline{\alpha_t}}{(1-\alpha_t)(1-\overline{\alpha_{t-1}})}}\\\\ =&\frac{(1-\overline{\alpha_{t-1}})\sqrt{\alpha_t}}{1-\overline{\alpha_t}}X_t+\frac{(1-\alpha_t)\sqrt{\overline{\alpha_{t-1}}}}{1-\overline{\alpha_t}}X_0\end{aligned} Mean==(1αt)(1αt1)1αt1αtαt Xt+1αt1αt1 X01αt(1αt1)αt Xt+1αt(1αt)αt1 X0到这里我们就明白了,神经网络在每一步的过程中其实是在利用 X t X_{t} Xt X t − 1 X_{t-1} Xt1做拟合。然而,在实际推断的时候, X 0 X_0 X0作为Ground Truth,其背后的分布我们根本不知道(它只能作为最终效果的评判标准),于是现在变成了我想用 X t X_t Xt估计 X t − 1 X_{t-1} Xt1那么首先得估计出 X 0 X_0 X0!如果神经网络能够这么简单估计出 X 0 X_0 X0的话,我还要估计 X t − 1 X_{t-1} Xt1作甚?俗话说的好,一口吃不成胖子:直接估计 X 0 X_0 X0这个任务太困难,我们需要把 X 0 X_0 X0替换成一个别的目标进行估计。那么怎么替换呢?这时候别忘了我们在前头Diffusion Forward阶段有这么一个结论:
    X t = α t ‾ X 0 + 1 − α t ‾ ε ⇒ X 0 = 1 α t ‾ ( X t − 1 − α t ‾ ε ) X_t=\sqrt{\overline{\alpha_t}}X_0+\sqrt{1-\overline{\alpha_t}}\varepsilon\\\Rightarrow X_0=\frac{1}{\sqrt{\overline{\alpha_t}}}(X_t-\sqrt{1-\overline{\alpha_t}}\varepsilon) Xt=αt X0+1αt εX0=αt 1(Xt1αt ε)我们可以把 X 0 X_0 X0替换成噪声,而用神经网络来估计这个噪声!这才是合理的做法,因为前面我们也提到了,神经网络在Denoise的阶段学习的是如何去噪声,那么估计噪声不就显得十分合理。
    M e a n = ( 1 − α t − 1 ‾ ) α t 1 − α t ‾ X t + ( 1 − α t ) α t − 1 ‾ 1 − α t ‾ ( 1 α t ‾ ( X t − 1 − α t ‾ ε ) ) = ( ( 1 − α t − 1 ‾ ) α t 1 − α t ‾ + ( 1 − α t ) α t − 1 ‾ ( 1 − α t ‾ ) α t ‾ ) X t − ( ( 1 − α t ) α t − 1 ‾ ( 1 − α t ‾ ) α t ‾ 1 − α t ‾ ) ε = α t α t − 1 ‾ ( 1 − α t − 1 ‾ ) + ( 1 − α t ) α t − 1 ‾ ( 1 − α t ‾ ) α t ‾ X t + . . . = α t − 1 ‾ α t ‾ X t − ( 1 − α t 1 − α t ‾ α t ) ε = 1 α t X t − ( 1 − α t 1 − α t ‾ α t ) ε = 1 α t ( X t − β t 1 − α t ‾ ε ) \begin{aligned}Mean=&\frac{(1-\overline{\alpha_{t-1}})\sqrt{\alpha_t}}{1-\overline{\alpha_t}}X_t+\frac{(1-\alpha_t)\sqrt{\overline{\alpha_{t-1}}}}{1-\overline{\alpha_t}}(\frac{1}{\sqrt{\overline{\alpha_t}}}(X_t-\sqrt{1-\overline{\alpha_t}}\varepsilon))\\[8pt] =&(\frac{(1-\overline{\alpha_{t-1}})\sqrt{\alpha_t}}{1-\overline{\alpha_t}}+\frac{(1-\alpha_t)\sqrt{\overline{\alpha_{t-1}}}}{(1-\overline{\alpha_t})\sqrt{\overline{\alpha_t}}})X_t -(\frac{(1-\alpha_t)\sqrt{\overline{\alpha_{t-1}}}}{(1-\overline{\alpha_t})\sqrt{\overline{\alpha_t}}}\sqrt{1-\overline{\alpha_t}})\varepsilon\\[8pt] =& \frac{\alpha_t\sqrt{\overline{\alpha_{t-1}}}(1-\overline{\alpha_{t-1}})+(1-\alpha_t)\sqrt{\overline{\alpha_{t-1}}}}{(1-\overline{\alpha_t})\sqrt{\overline{\alpha_t}}}X_t+...\\[8pt]=&\frac{\sqrt{\overline{\alpha_{t-1}}}}{\sqrt{\overline{\alpha_t}}}X_t-(\frac{1-\alpha_t}{\sqrt{1-\overline{\alpha_t}}\sqrt{\alpha_t}})\varepsilon \\[8pt]=&\frac{1}{\sqrt{\alpha_t}}X_t-(\frac{1-\alpha_t}{\sqrt{1-\overline{\alpha_t}}\sqrt{\alpha_t}})\varepsilon\\ =&\frac{1}{\sqrt{\alpha_t}}\bigg(X_t-\frac{\beta_t}{\sqrt{1-\overline{\alpha_t}}}\varepsilon\bigg) \end{aligned} Mean======1αt(1αt1)αt Xt+1αt(1αt)αt1 (αt 1(Xt1αt ε))(1αt(1αt1)αt +(1αt)αt (1αt)αt1 )Xt((1αt)αt (1αt)αt1 1αt )ε(1αt)αt αtαt1 (1αt1)+(1αt)αt1 Xt+...αt αt1 Xt(1αt αt 1αt)εαt 1Xt(1αt αt 1αt)εαt 1(Xt1αt βtε)至此,大家在文献里头一贯越过的复杂运算都已经算清楚了。


3. 损失函数(Loss Function)

扩散模型的损失函数源自变分推断中的 证据下界(Evidence Lower Bound, ELBO) 的优化。以下是详细推导过程:

1. 目标:最大化数据似然

扩散模型的目标是最大化训练数据 x 0 \mathbf{x}_0 x0 的似然 p θ ( x 0 ) p_\theta(\mathbf{x}_0) pθ(x0)。由于直接计算似然困难,转而优化其下界(ELBO)。


2. 变分下界(ELBO)

通过引入隐变量 x 1 : T \mathbf{x}_{1:T} x1:T,ELBO 分解为:
log ⁡ p θ ( x 0 ) ≥ ELBO = E q [ log ⁡ p θ ( x 0 : T ) q ( x 1 : T ∣ x 0 ) ] \log p_\theta(\mathbf{x}_0) \geq \text{ELBO} = \mathbb{E}_q \left[ \log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} | \mathbf{x}_0)} \right] logpθ(x0)ELBO=Eq[logq(x1:Tx0)pθ(x0:T)]
推导:
ELBO(Evidence Lower Bound,证据下界)是通过变分推断(Variational Inference)构造出来的,其核心思想是用一个简单的分布去近似复杂的后验分布,从而将难以直接优化的对数似然问题转化为一个可计算的下界优化问题。以下是 ELBO 的构造过程及其推导的详细解释。

1. 问题背景
假设我们有一个生成模型,其联合分布为 p θ ( x , z ) p_\theta(\mathbf{x}, \mathbf{z}) pθ(x,z),其中:

  • x \mathbf{x} x 是观测数据。
  • z \mathbf{z} z 是隐变量。
  • θ \theta θ 是模型参数。

我们的目标是最大化观测数据的对数似然 log ⁡ p θ ( x ) \log p_\theta(\mathbf{x}) logpθ(x),但由于隐变量 z \mathbf{z} z 的存在,直接计算 log ⁡ p θ ( x ) \log p_\theta(\mathbf{x}) logpθ(x) 通常非常困难(涉及到对隐变量积分)。

2. 引入变分分布
为了近似后验分布 p θ ( z ∣ x ) p_\theta(\mathbf{z} | \mathbf{x}) pθ(zx),我们引入一个变分分布 q ϕ ( z ∣ x ) q_\phi(\mathbf{z} | \mathbf{x}) qϕ(zx),其中 ϕ \phi ϕ 是变分参数。这个分布是人为设计的,通常选择简单的分布(如高斯分布),以便于计算。

3. 构造 ELBO
通过引入变分分布,我们可以将 log ⁡ p θ ( x ) \log p_\theta(\mathbf{x}) logpθ(x) 分解为两部分:
log ⁡ p θ ( x ) = E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x ) ] ( 这里 x 被条件住了, 所以其随机性可以看作暂时消失了, 从而对 z 做期望的时候可以看作一个常数 ) \begin{aligned} \log p_\theta(\mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x})} \left[ \log p_\theta(\mathbf{x}) \right]_{(这里x被条件住了,\\所以其随机性可以看作暂时消失了,\\从而对z做期望的时候可以看作一个常数)}\end{aligned} logpθ(x)=Eqϕ(zx)[logpθ(x)](这里x被条件住了,所以其随机性可以看作暂时消失了,从而对z做期望的时候可以看作一个常数)利用条件概率公式 p θ ( x , z ) = p θ ( x ∣ z ) p θ ( z ) p_\theta(\mathbf{x}, \mathbf{z}) = p_\theta(\mathbf{x} | \mathbf{z}) p_\theta(\mathbf{z}) pθ(x,z)=pθ(xz)pθ(z),将上式改写为:
log ⁡ p θ ( x ) = E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x , z ) p θ ( z ∣ x ) ] \log p_\theta(\mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x})} \left[ \log \frac{p_\theta(\mathbf{x}, \mathbf{z})}{p_\theta(\mathbf{z} | \mathbf{x})} \right] logpθ(x)=Eqϕ(zx)[logpθ(zx)pθ(x,z)]

p θ ( z ∣ x ) p_\theta(\mathbf{z} | \mathbf{x}) pθ(zx) 替换为变分分布 q ϕ ( z ∣ x ) q_\phi(\mathbf{z} | \mathbf{x}) qϕ(zx),并引入一个 KL 散度项:
log ⁡ p θ ( x ) = E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x , z ) q ϕ ( z ∣ x ) ] + D K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ) ) \log p_\theta(\mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x})} \left[ \log \frac{p_\theta(\mathbf{x}, \mathbf{z})}{q_\phi(\mathbf{z} | \mathbf{x})} \right] + D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) \| p_\theta(\mathbf{z} | \mathbf{x})) logpθ(x)=Eqϕ(zx)[logqϕ(zx)pθ(x,z)]+DKL(qϕ(zx)pθ(zx))

由于 KL 散度 D K L D_{KL} DKL 是非负的,因此:
log ⁡ p θ ( x ) ≥ E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x , z ) q ϕ ( z ∣ x ) ] \log p_\theta(\mathbf{x}) \geq \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x})} \left[ \log \frac{p_\theta(\mathbf{x}, \mathbf{z})}{q_\phi(\mathbf{z} | \mathbf{x})} \right] logpθ(x)Eqϕ(zx)[logqϕ(zx)pθ(x,z)]右边的表达式就是 ELBO(证据下界):
ELBO = E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x , z ) q ϕ ( z ∣ x ) ] \text{ELBO} = \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x})} \left[ \log \frac{p_\theta(\mathbf{x}, \mathbf{z})}{q_\phi(\mathbf{z} | \mathbf{x})} \right] ELBO=Eqϕ(zx)[logqϕ(zx)pθ(x,z)]


3. 变分下界(ELBO)分解

展开后得到:
ELBO = E q [ log ⁡ p θ ( x 0 ∣ x 1 ) ] ⏟ 重建项 − ∑ t = 2 T E q [ D K L ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) ] ⏟ KL散度项 − D K L ( q ( x T ∣ x 0 ) ∥ p ( x T ) ) ⏟ 先验匹配项 \text{ELBO} = \underbrace{\mathbb{E}_{q}[\log p_\theta(\mathbf{x}_0|\mathbf{x}_1)]}_{\text{重建项}} - \sum_{t=2}^T \underbrace{\mathbb{E}_q \left[ D_{KL}(q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) \| p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)) \right]}_{\text{KL散度项}} - \underbrace{D_{KL}(q(\mathbf{x}_T|\mathbf{x}_0) \| p(\mathbf{x}_T))}_{\text{先验匹配项}} ELBO=重建项 Eq[logpθ(x0x1)]t=2TKL散度项 Eq[DKL(q(xt1xt,x0)pθ(xt1xt))]先验匹配项 DKL(q(xTx0)p(xT))
推导:
ELBO 可以进一步分解为两项:
ELBO = E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x ∣ z ) ] − D K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ) ) \text{ELBO} = \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x})} \left[ \log p_\theta(\mathbf{x} | \mathbf{z}) \right] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) \| p_\theta(\mathbf{z})) ELBO=Eqϕ(zx)[logpθ(xz)]DKL(qϕ(zx)pθ(z))

(1)重建项(Reconstruction Term)
E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x ∣ z ) ] \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x})} \left[ \log p_\theta(\mathbf{x} | \mathbf{z}) \right] Eqϕ(zx)[logpθ(xz)]

  • 衡量从隐变量 z \mathbf{z} z 重建数据 x \mathbf{x} x 的能力。
  • 在生成模型中,通常建模为高斯分布或伯努利分布。

(2)KL 散度项(KL Divergence Term)
D K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ) ) D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) \| p_\theta(\mathbf{z})) DKL(qϕ(zx)pθ(z))

  • 衡量变分分布 q ϕ ( z ∣ x ) q_\phi(\mathbf{z} | \mathbf{x}) qϕ(zx) 与先验分布 p θ ( z ) p_\theta(\mathbf{z}) pθ(z) 的差异。
  • 用于正则化隐变量的分布,避免其偏离先验。
  • KL散度项又可以根据条件概率公式拆解为多个时间步的KL散度项,最终拼接上一个先验匹配项,也就有了原式。

4. 关键项简化
(1) 重建项

对应最后一步从 x 1 \mathbf{x}_1 x1 生成 x 0 \mathbf{x}_0 x0,通常简化为均方误差或交叉熵。

(2) 先验匹配项

由于前向过程的最终分布 q ( x T ∣ x 0 ) ≈ N ( 0 , I ) q(\mathbf{x}_T|\mathbf{x}_0) \approx \mathcal{N}(0, \mathbf{I}) q(xTx0)N(0,I),与先验 p ( x T ) = N ( 0 , I ) p(\mathbf{x}_T) = \mathcal{N}(0, \mathbf{I}) p(xT)=N(0,I) 一致,此项趋近于零。

(3) KL散度项

核心在于 反向过程分布 p θ p_\theta pθ 与真实后验 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) q(xt1xt,x0) 的匹配


5. KL散度的闭合解

对于一般的高斯分布 q q q p θ p_\theta pθ,KL散度有闭合解:
D K L ( q ∥ p θ ) = 1 2 ( ∥ μ ~ t − μ θ ∥ 2 σ t 2 + log ⁡ σ t 2 σ ~ t 2 − 1 + σ ~ t 2 σ t 2 ) D_{KL}(q \| p_\theta) = \frac{1}{2} \left( \frac{\|\tilde{\mu}_t - \mu_\theta\|^2}{\sigma_t^2} + \log \frac{\sigma_t^2}{\tilde{\sigma}_t^2} - 1 + \frac{\tilde{\sigma}_t^2}{\sigma_t^2} \right) DKL(qpθ)=21(σt2μ~tμθ2+logσ~t2σt21+σt2σ~t2)
其中:

  • μ ~ t \tilde{\mu}_t μ~t: 真实后验均值 μ θ = 1 α t ( x t − β t 1 − α ˉ t ϵ ) \mu_\theta = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon \right) μθ=αt 1(xt1αˉt βtϵ)
  • μ θ \mu_\theta μθ: 模型预测均值 μ θ = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) \mu_\theta = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(\mathbf{x}_t, t) \right) μθ=αt 1(xt1αˉt βtϵθ(xt,t))
  • σ ~ t 2 \tilde{\sigma}_t^2 σ~t2: 真实后验方差 β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t β~t=1αˉt1αˉt1βt
  • σ t 2 \sigma_t^2 σt2: 模型方差(通常固定为 β t \beta_t βt β ~ t \tilde{\beta}_t β~t

代入KL散度公式后:
D K L ( q ∥ p θ ) ∝ ∥ μ ~ t − μ θ ∥ 2 σ t 2 = ∥ β t α t ( 1 − α ˉ t ) ( ϵ − ϵ θ ( x t , t ) ) ∥ 2 σ t 2 D_{KL}(q \| p_\theta) \propto \frac{\|\tilde{\mu}_t - \mu_\theta\|^2}{\sigma_t^2} = \frac{\left\| \frac{\beta_t}{\sqrt{\alpha_t (1-\bar{\alpha}_t)}} (\epsilon - \epsilon_\theta(\mathbf{x}_t, t)) \right\|^2}{\sigma_t^2} DKL(qpθ)σt2μ~tμθ2=σt2 αt(1αˉt) βt(ϵϵθ(xt,t)) 2

化简后:

D K L ( q ∥ p θ ) ∝ β t 2 2 σ t 2 α t ( 1 − α ˉ t ) ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 D_{KL}(q \| p_\theta) \propto \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1-\bar{\alpha}_t)} \| \epsilon - \epsilon_\theta(\mathbf{x}_t, t) \|^2 DKL(qpθ)2σt2αt(1αˉt)βt2ϵϵθ(xt,t)2

取期望后:

D K L ∝ E [ β t 2 2 σ t 2 α t ( 1 − α ˉ t ) ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] D_{KL} \propto \mathbb{E} \left[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1-\bar{\alpha}_t)} \| \epsilon - \epsilon_\theta(\mathbf{x}_t, t) \|^2 \right] DKLE[2σt2αt(1αˉt)βt2ϵϵθ(xt,t)2]


6. 最终损失函数

忽略常数系数后,得到简化的均方误差损失:
L = E t , x 0 , ϵ [ ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] \mathcal{L} = \mathbb{E}_{t, \mathbf{x}_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(\mathbf{x}_t, t) \|^2 \right] L=Et,x0,ϵ[ϵϵθ(xt,t)2]
其中:

  • x t = α ˉ t x 0 + 1 − α ˉ t ϵ \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t} \epsilon xt=αˉt x0+1αˉt ϵ(前向过程采样)
  • ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, \mathbf{I}) ϵN(0,I)

最终损失函数为:
L = E t , x 0 , ϵ [ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 ] \mathcal{L} = \mathbb{E}_{t, \mathbf{x}_0, \epsilon} \left[ \left\| \epsilon - \epsilon_\theta\left( \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t} \epsilon, t \right) \right\|^2 \right] L=Et,x0,ϵ[ ϵϵθ(αˉt x0+1αˉt ϵ,t) 2]
这是扩散模型训练的核心目标,即最小化噪声预测误差。

代码实现:

  • UNet模型(实现噪声预测网络 ϵ θ ( x t , t ) ϵ_θ(x_t,t) ϵθ(xt,t)):
class Block(nn.Module):
    def forward(self, x, t):
        # 时间步嵌入对应公式中的t输入
        time_emb = self.act(self.time_mlp(t))  # 将标量t映射为高维向量
        h = h + time_emb.reshape(...)  # 将时间信息注入每个空间位置

class UNet(nn.Module):
    def forward(self, x, t):
        t = self.time_mlp(t.unsqueeze(-1).float())  # 时间编码器
        # 下采样-上采样结构用于捕捉多尺度特征
        ...
        # 跳跃连接帮助梯度流动(与原始UNet设计一致)
        ...

其中,时间步嵌入对应扩散步数t的条件输入

  • 训练循环(实现损失函数)
# 随机采样时间步(均匀分布)
t = torch.randint(0, timesteps, (images.shape[0],)).to(device)  # t ∼ Uniform(0, T-1)

# 前向加噪过程
noisy_images, noise = scheduler.add_noise(images, t)  # 计算x_t

# 噪声预测
pred_noise = model(noisy_images, t)  # ε_θ(x_t, t)

# 损失计算
loss = nn.MSELoss()(pred_noise, noise)  # ||ε - ε_θ||²

其中,随机时间步采样用来实现期望的蒙特卡洛估计


4. 采样过程(Sampling)

目标:从噪声 x T ∼ N ( 0 , I ) \mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I}) xTN(0,I) 出发,逐步去噪生成数据 x 0 \mathbf{x}_0 x0

DDPM采样步骤
对于 t = T , T − 1 , … , 1 t = T, T-1, \dots, 1 t=T,T1,,1

  1. 预测噪声:
    ϵ θ = ϵ θ ( x t , t ) \epsilon_\theta = \epsilon_\theta(\mathbf{x}_t, t) ϵθ=ϵθ(xt,t)
  2. 计算均值:
    μ θ = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ) \mu_\theta = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta \right) μθ=αt 1(xt1αˉt βtϵθ)
  3. 添加噪声(当 ( t > 1 ) 时):
    x t − 1 = μ θ + σ t z , z ∼ N ( 0 , I ) \mathbf{x}_{t-1} = \mu_\theta + \sigma_t \mathbf{z}, \quad \mathbf{z} \sim \mathcal{N}(0, \mathbf{I}) xt1=μθ+σtz,zN(0,I) t = 1 t = 1 t=1 时, z = 0 \mathbf{z} = 0 z=0

代码实现:

  • 噪声调度器(对应前向过程公式)
class NoiseScheduler:
    def __init__(self, beta_start=1e-4, beta_end=0.02):
        # 对应 β_t 的线性调度
        self.betas = torch.linspace(beta_start, beta_end, timesteps)  # 公式中的{β_t}
        
        # 计算 α 系列
        self.alphas = 1. - self.betas                      # α_t = 1 - β_t
        self.alpha_cumprod = torch.cumprod(self.alphas, 0)  # ᾱ_t = ∏α_s
        
        # 预计算系数(优化用)
        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod)          # √ᾱ_t
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1. - self.alpha_cumprod)  # √(1-ᾱ_t)

    def add_noise(self, x, t):
        noise = torch.randn_like(x)  # ϵ ∼ N(0, I)
        # 对应公式 x_t = √ᾱ_t x_0 + √(1-ᾱ_t)ϵ
        sqrt_alpha = self.sqrt_alpha_cumprod[t].reshape(-1, 1, 1, 1)  # 广播维度
        sqrt_one_minus_alpha = self.sqrt_one_minus_alpha_cumprod[t].reshape(-1, 1, 1, 1)
        return sqrt_alpha * x + sqrt_one_minus_alpha * noise, noise
  • 采样过程(实现反向扩散)
def sample(num_images=16):
    x = torch.randn((num_images, 1, 28, 28))  # x_T ∼ N(0, I)
    
    for t in reversed(range(timesteps)):
        pred_noise = model(x, t_tensor)  # 预测噪声
        
        # 对应反向过程均值计算
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * pred_noise)
        
        # 添加噪声项(最后一步不加)
        if t > 0:
            x += torch.sqrt(beta) * torch.randn_like(x)

5. 代码与公式对应关系

数学公式代码实现
x t = α ˉ t x 0 + 1 − α ˉ t ϵ \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t} \epsilon xt=αˉt x0+1αˉt ϵNoiseScheduler.add_noise
ϵ θ ( x t , t ) \epsilon_\theta(\mathbf{x}_t, t) ϵθ(xt,t)UNet 前向传播
L = E [ ∣ ϵ − ϵ θ ∣ 2 ] \mathcal{L} = \mathbb{E}[| \epsilon - \epsilon_\theta |^2] L=E[ϵϵθ2]nn.MSELoss(pred_noise, noise)
反向采样公式sample 函数中的循环

6. 关键改进方向

  1. 噪声调度

    • 线性调度: β t \beta_t βt 线性增加。
    • 余弦调度: β t = cosine ( t / T ) \beta_t = \text{cosine}(t/T) βt=cosine(t/T),更平滑。
  2. 加速采样

    • DDIM(Denoising Diffusion Implicit Models):确定性采样,跳过部分步骤。
    • 采样公式:
      x t − 1 = α ˉ t − 1 ( x t − 1 − α ˉ t ϵ θ α ˉ t ) + 1 − α ˉ t − 1 ϵ θ \mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \left( \frac{\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t} \epsilon_\theta}{\sqrt{\bar{\alpha}_t}} \right) + \sqrt{1-\bar{\alpha}_{t-1}} \epsilon_\theta xt1=αˉt1 (αˉt xt1αˉt ϵθ)+1αˉt1 ϵθ
  3. 条件生成

    • 分类器引导:在损失中加入类别条件。
      L = E t , x 0 , y , ϵ [ ∥ ϵ − ϵ θ ( x t , t , y ) ∥ 2 ] \mathcal{L} = \mathbb{E}_{t, \mathbf{x}_0,y, \epsilon} \left[ \left\| \epsilon - \epsilon_\theta\left(\mathbf x_t, t,y \right) \right\|^2 \right] L=Et,x0,y,ϵ[ϵϵθ(xt,t,y)2]
for images, labels in dataloader:
    images = images.to(device)
    labels = labels.to(device)  # 类别标签
    
    # 随机采样时间步
    t = torch.randint(0, timesteps, (images.size(0),)).to(device)
    
    # 添加噪声
    noisy_images, noise = scheduler.add_noise(images, t)
    
    # 预测噪声(加入条件信息)
    pred_noise = model(noisy_images, t, labels)
    
    # 计算损失
    loss = nn.MSELoss()(pred_noise, noise)
    
    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

DDPM 完整代码实现

以下是使用 PyTorch 实现一个基础扩散模型的代码示例。该模型将在 MNIST 数据集上训练,并能够生成手写数字图像。

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

# 超参数设置
BATCH_SIZE = 128
NUM_EPOCHS = 10
TIMESTEPS = 1000
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
VALIDATION_RATIO = 0.1
CLASSIFIER_GUIDANCE_SCALE = 3.0  # 分类引导的强度

# 设备配置
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. 定义噪声调度器,采用余弦调度策略
class NoiseScheduler:
    def __init__(self, timesteps=TIMESTEPS):
        """
        初始化噪声调度器,使用余弦调度策略生成 betas 值。
        :param timesteps: 时间步数
        """
        s = 0.008
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        betas = torch.clip(betas, 0, 0.999)

        # 将所有张量转换为 float32 类型
        self.betas = betas.to(DEVICE).float()
        self.alphas = (1. - self.betas).float()
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0).float()
        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod).float()
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1. - self.alpha_cumprod).float()

    def add_noise(self, x, t):
        """
        向输入图像添加噪声。
        :param x: 输入图像
        :param t: 时间步
        :return: 添加噪声后的图像和噪声
        """
        # 确保噪声的类型和输入图像一致
        noise = torch.randn_like(x).type_as(x)
        sqrt_alpha = self.sqrt_alpha_cumprod[t].reshape(-1, 1, 1, 1)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alpha_cumprod[t].reshape(-1, 1, 1, 1)
        return sqrt_alpha * x + sqrt_one_minus_alpha * noise, noise

# 2. 定义 UNet 模型,添加残差块和注意力机制
class ResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        """
        初始化残差块。
        :param in_ch: 输入通道数
        :param out_ch: 输出通道数
        :param time_emb_dim: 时间嵌入维度
        """
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.act = nn.SiLU()
        self.residual = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t):
        """
        前向传播。
        :param x: 输入特征图
        :param t: 时间嵌入
        :return: 输出特征图
        """
        h = self.act(self.conv1(x))
        time_emb = self.act(self.time_mlp(t))
        h = h + time_emb.reshape(time_emb.shape[0], -1, 1, 1)
        h = self.act(self.conv2(h))
        return h + self.residual(x)

class AttentionBlock(nn.Module):
    def __init__(self, in_channels):
        """
        初始化注意力块。
        :param in_channels: 输入通道数
        """
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_channels)
        self.attention = nn.MultiheadAttention(in_channels, num_heads=1, batch_first=True)

    def forward(self, x):
        """
        前向传播。
        :param x: 输入特征图
        :return: 输出特征图
        """
        b, c, h, w = x.shape
        x = self.group_norm(x)
        x_flat = x.view(b, c, -1).permute(0, 2, 1)
        out, _ = self.attention(x_flat, x_flat, x_flat)
        # 确保恢复的形状和输入 x 一致
        out = out.permute(0, 2, 1).view(b, c, h, w)
        return out + x

class UNet(nn.Module):
    def __init__(self, num_classes=10):
        """
        初始化 UNet 模型。
        :param num_classes: 类别数量
        """
        super().__init__()
        self.time_emb_dim = 32
        self.time_mlp = nn.Sequential(
            nn.Linear(1, self.time_emb_dim),
            nn.SiLU(),
            nn.Linear(self.time_emb_dim, self.time_emb_dim)
        )
        self.class_embedding = nn.Embedding(num_classes, self.time_emb_dim)

        self.down1 = ResidualBlock(1, 64, self.time_emb_dim)
        self.down2 = ResidualBlock(64, 128, self.time_emb_dim)
        self.middle_res = ResidualBlock(128, 128, self.time_emb_dim)
        self.middle_attn = AttentionBlock(128)
        self.up2 = ResidualBlock(128 + 128, 64, self.time_emb_dim)
        self.up1 = ResidualBlock(64 + 64, 1, self.time_emb_dim)
        self.avgpool = nn.AvgPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, t, y=None):
        """
        前向传播。
        :param x: 输入图像
        :param t: 时间步
        :param y: 类别标签
        :return: 预测的噪声
        """
        t = self.time_mlp(t.unsqueeze(-1).float())
        if y is not None:
            class_emb = self.class_embedding(y)
            t = t + class_emb

        # 下采样
        x1 = self.down1(x, t)
        x = self.avgpool(x1)
        x2 = self.down2(x, t)
        x = self.avgpool(x2)

        # 中间层
        x = self.middle_res(x, t)
        x = self.middle_attn(x)

        # 上采样
        x = self.upsample(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x, t)
        x = self.upsample(x)
        x = torch.cat([x, x1], dim=1)
        return self.up1(x, t)

# 3. 数据准备,添加数据增强和数据集划分
def prepare_data():
    """
    准备 MNIST 数据集,包括数据增强和数据集划分。
    :return: 训练数据加载器、验证数据加载器
    """
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
    train_size = int((1 - VALIDATION_RATIO) * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_dataloader, val_dataloader

# 4. 初始化模型和优化器,添加学习率调度器
def initialize_model():
    """
    初始化 UNet 模型、优化器、噪声调度器和学习率调度器。
    :return: 模型、优化器、噪声调度器、学习率调度器
    """
    model = UNet().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = NoiseScheduler()
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    return model, optimizer, scheduler, lr_scheduler

# 5. 训练循环,添加验证过程
def train_model(model, optimizer, scheduler, lr_scheduler, train_dataloader, val_dataloader):
    """
    训练模型,包括训练过程和验证过程。
    :param model: UNet 模型
    :param optimizer: 优化器
    :param scheduler: 噪声调度器
    :param lr_scheduler: 学习率调度器
    :param train_dataloader: 训练数据加载器
    :param val_dataloader: 验证数据加载器
    """
    criterion = nn.MSELoss()
    best_val_loss = float('inf')
    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0
        with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} (Training)") as pbar:
            for i, (images, labels) in enumerate(train_dataloader):
                # 将图像数据转换为 float32 类型
                images = images.to(DEVICE).float()
                labels = labels.to(DEVICE)

                # 随机采样时间步
                t = torch.randint(0, TIMESTEPS, (images.shape[0],)).to(DEVICE)

                # 添加噪声
                noisy_images, noise = scheduler.add_noise(images, t)

                # 预测噪声
                pred_noise = model(noisy_images, t, labels)

                # 计算损失
                loss = criterion(pred_noise, noise)

                # 反向传播
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                pbar.set_postfix({
                    'Train Loss': f'{train_loss / (i + 1):.4f}',
                })
                pbar.update(1)

        # 验证过程
        model.eval()
        val_loss = 0
        with torch.no_grad():
            with tqdm(total=len(val_dataloader), desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} (Validation)") as pbar:
                for i, (images, labels) in enumerate(val_dataloader):
                    # 将图像数据转换为 float32 类型
                    images = images.to(DEVICE).float()
                    labels = labels.to(DEVICE)
                    t = torch.randint(0, TIMESTEPS, (images.shape[0],)).to(DEVICE)
                    noisy_images, noise = scheduler.add_noise(images, t)
                    pred_noise = model(noisy_images, t, labels)
                    loss = criterion(pred_noise, noise)
                    val_loss += loss.item()
                    pbar.set_postfix({
                        'Val Loss': f'{val_loss / (i + 1):.4f}',
                    })
                    pbar.update(1)

        # 更新学习率
        lr_scheduler.step()

        # 保存最佳模型
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

# 6. 采样生成图像,采用 DDIM 采样策略,添加分类引导
@torch.no_grad()
def sample(model, scheduler, num_images=16, eta=0.0, class_labels=None):
    """
    使用 DDIM 采样策略生成图像,支持分类引导。
    :param model: UNet 模型
    :param scheduler: 噪声调度器
    :param num_images: 生成图像的数量
    :param eta: DDIM 采样的超参数
    :param class_labels: 类别标签
    :return: 生成的图像
    """
    model.eval()
    x = torch.randn((num_images, 1, 28, 28)).to(DEVICE)
    # 使用 torch.flip 实现逆序排列
    timesteps = torch.flip(torch.arange(0, TIMESTEPS), dims=[0]).to(DEVICE)
    for i, t in enumerate(timesteps):
        t_tensor = torch.full((num_images,), t, device=DEVICE)

        if class_labels is not None:
            # 无引导的预测
            unguided_pred_noise = model(x, t_tensor, y=None)
            # 有引导的预测
            guided_pred_noise = model(x, t_tensor, y=class_labels)
            # 分类引导
            pred_noise = unguided_pred_noise + CLASSIFIER_GUIDANCE_SCALE * (guided_pred_noise - unguided_pred_noise)
        else:
            pred_noise = model(x, t_tensor)

        alpha = scheduler.alphas[t].to(DEVICE)
        alpha_cumprod = scheduler.alpha_cumprod[t].to(DEVICE)
        alpha_cumprod_prev = scheduler.alpha_cumprod[timesteps[i + 1]] if i < len(timesteps) - 1 else torch.tensor(1.0).to(DEVICE)

        sigma = eta * torch.sqrt((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)) * torch.sqrt(1 - alpha / alpha)
        pred_x0 = (x - torch.sqrt(1 - alpha_cumprod) * pred_noise) / torch.sqrt(alpha_cumprod)
        dir_xt = torch.sqrt(1 - alpha_cumprod_prev - sigma ** 2) * pred_noise
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = torch.sqrt(alpha_cumprod_prev) * pred_x0 + dir_xt + sigma * noise

    # 将生成图像转换到 [0, 1] 范围
    x = (x.clamp(-1, 1) + 1) / 2
    return x

# 7. 主函数
def main():
    # 准备数据
    train_dataloader, val_dataloader = prepare_data()

    # 初始化模型和优化器
    model, optimizer, scheduler, lr_scheduler = initialize_model()
    model.load_state_dict(torch.load('best_model.pth'))

    # 训练模型
    train_model(model, optimizer, scheduler, lr_scheduler, train_dataloader, val_dataloader)

    # 加载最佳模型
    model.load_state_dict(torch.load('best_model.pth'))

    # 选择要生成的类别
    class_labels = torch.tensor([i % 10 for i in range(16)], device=DEVICE)

    # 生成并显示图像
    generated = sample(model, scheduler, class_labels=class_labels)

    # 获取原始图像
    original_images, original_labels = next(iter(val_dataloader))
    selected_indices = []
    for label in class_labels.cpu().tolist():
        index = (original_labels == label).nonzero(as_tuple=True)[0][0].item()
        selected_indices.append(index)
    original_images = original_images[selected_indices].to(DEVICE).float()
    original_images = (original_images.clamp(-1, 1) + 1) / 2

    # 创建一个大的图像来显示原始图像和生成图像
    fig, axes = plt.subplots(8, 4, figsize=(8, 16))

    for i in range(16):
        # 显示原始图像
        axes[i // 2, i % 2 * 2].imshow(original_images[i].squeeze().cpu().numpy(), cmap='gray')
        axes[i // 2, i % 2 * 2].axis('off')
        if i % 2 == 0:
            axes[i // 2, i % 2 * 2].set_title('Original')

        # 显示生成图像
        axes[i // 2, i % 2 * 2 + 1].imshow(generated[i].squeeze().cpu().numpy(), cmap='gray')
        axes[i // 2, i % 2 * 2 + 1].axis('off')
        if i % 2 == 0:
            axes[i // 2, i % 2 * 2 + 1].set_title('Generated')

总结

扩散模型通过前向过程破坏数据,反向过程学习去噪,最终从噪声生成数据。其数学核心在于:

  1. 前向闭合解的高效计算,
  2. 反向过程的参数化与损失简化,
  3. 采样过程的迭代去噪。

代码实现时,需注意噪声调度的设计、时间步的嵌入方式,以及采样时的数值稳定性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值