Preface
Key idea of consistency models:
Learn a model that maps any point in the latent space back to the initial data point of its trajectory; i.e., points that lie on the same probability flow trajectory are mapped to the same initial data point.
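In equation form, this is the self-consistency property from [1]:

$$
f(x_t, t) = f(x_{t'}, t') \quad \text{for all } t, t' \in [\epsilon, T]
$$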
I. EMA
Why does using EMA at test time usually improve model performance?
An exponentially weighted moving average (EMA) of the weights can make a model more robust on test data: when training a neural network with stochastic gradient descent, keeping a moving average of the weights often improves the final model's performance on test data.
Concretely, apply a moving average to the network weights weights to obtain the corresponding shadow variables shadow_weights. Training still uses the original, non-averaged weights to compute the next update of weights, and the updated weights are then used to update shadow_weights. At test time, shadow_weights replace weights as the network weights, which tends to work better on test data because the updates to shadow_weights are smoother. Specifically:
- For stochastic gradient descent, the smoother updates mean the parameters do not drift far from the optimum;
- For (full-)batch gradient descent, the shadow variables may not help much, since each update already follows the true gradient direction and the loss is guaranteed to decrease;
- For mini-batch gradient descent, the moving average is worth trying, because mini-batch updates of the parameters also fluctuate.
For example, with decay = 0.999: intuitively, during the final 1000 or so training steps the model has essentially converged and is merely oscillating, and the moving average effectively averages out these oscillations, yielding more robust weights. A minimal sketch of the shadow-weight update is given below.
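The sketch assumes a PyTorch-style model whose parameters are updated elsewhere by the optimizer; the names update_ema and ema_model are illustrative, not from any particular library:

```python
import torch

def update_ema(ema_model, model, decay=0.999):
    """shadow = decay * shadow + (1 - decay) * weight, applied parameter-wise."""
    with torch.no_grad():
        for shadow, weight in zip(ema_model.parameters(), model.parameters()):
            shadow.mul_(decay).add_(weight, alpha=1 - decay)

# Typical usage (sketch):
#   ema_model = copy.deepcopy(model)      # shadow copy, never trained directly
#   for step in range(total_steps):
#       ...train model for one step...
#       update_ema(ema_model, model)
#   evaluate(ema_model)                   # at test time, use the shadow weights
```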
II. Consistency Models [1]
Definition
Given a diffusion trajectory $\{x_t\}_{t \in [\epsilon, T]}$, we define a consistency function $f : (x_t, t) \mapsto x_{\epsilon}$.
We can then train a consistency model $f_{\theta}(\cdot, \cdot)$ to approximate the consistency function. A key property of the consistency function is the boundary condition $f(x_{\epsilon}, \epsilon) = x_{\epsilon}$. To achieve this, we parameterize the consistency model using skip connections as in [2]:

$$
f_{\theta}(x_t, t) = c_{\text{skip}}(t)\, x_t + c_{\text{out}}(t)\, F_{\theta}(x_t, t)
$$

where $c_{\text{skip}}(\epsilon) = 1$, $c_{\text{out}}(\epsilon) = 0$, and $F_{\theta}(\cdot, \cdot)$ is the neural network.
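A sketch of this parameterization follows. The particular forms of c_skip and c_out below (with a data-scale hyperparameter sigma_data) follow the choice reported in [1]; treat them as one possible instantiation rather than the only one:

```python
import math

def c_skip(t, eps=0.002, sigma_data=0.5):
    """Equals 1 at t = eps, so the boundary condition f_theta(x_eps, eps) = x_eps holds."""
    return sigma_data ** 2 / ((t - eps) ** 2 + sigma_data ** 2)

def c_out(t, eps=0.002, sigma_data=0.5):
    """Equals 0 at t = eps, so the network output has no effect at the boundary."""
    return sigma_data * (t - eps) / math.sqrt(sigma_data ** 2 + t ** 2)

def consistency_model(F_theta, x_t, t):
    """f_theta(x_t, t) = c_skip(t) * x_t + c_out(t) * F_theta(x_t, t)."""
    return c_skip(t) * x_t + c_out(t) * F_theta(x_t, t)
```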
1. Sampling
Starting from initial random noise $\hat{x}_T \sim \mathcal{N}(0, T^2 I)$, the consistency model can generate a sample in a single step: $\hat{x}_{\epsilon} = f_{\theta}(\hat{x}_T, T)$. For iterative refinement, the following algorithm can be used:
```
# Generate an initial sample from the initial random noise
sample = consistency_model(x_T, T)
sample = clamp?(sample)

# Iterative refinement: re-noise the current sample and denoise it again
for t in timesteps:
    noise = standard_gaussian_noise()
    noisy_sample = sample + square_root(square(t) - square(ϵ)) * noise
    sample = consistency_model(noisy_sample, t)
    sample = clamp?(sample)
```
where consistency_model $= f_{\theta}(\cdot, \cdot)$, clamp? is a function that optionally clips values to a given range, and timesteps $= [N-1, \dots, \epsilon]$.
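As a concrete (and hedged) instantiation of these helpers: the decreasing noise-level sequence below follows the schedule of [2] with ρ = 7, and the clipping range [-1, 1] assumes images normalized to that range; both are assumptions rather than requirements of the algorithm:

```python
import torch

def karras_timesteps(N=40, eps=0.002, T=80.0, rho=7.0):
    """Noise levels from [2], eps = t_1 < ... < t_N = T, returned in decreasing order without T."""
    i = torch.arange(N, dtype=torch.float64)
    t = (eps ** (1 / rho) + i / (N - 1) * (T ** (1 / rho) - eps ** (1 / rho))) ** rho
    return t.flip(0)[1:]  # T is used for the initial sample; refine at the remaining levels

def clamp(x):
    """Optionally clip pixel values to [-1, 1] after each denoising step."""
    return x.clamp(-1.0, 1.0)
```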
2. Training
To train the model, use the following algorithm:
```
for step in range(total_steps):
    # Sample a clean data point and Gaussian noise
    data = data_distribution()
    noise = standard_gaussian_noise()
    # Sample a noise-level index from the current timestep schedule
    timestep = uniform_distribution(start=1, end=timestep_schedule(step) - 1)
    # Perturb the same data point at two adjacent noise levels with the same noise
    current_noisy_data = data + timestep * noise
    next_noisy_data = data + (timestep + 1) * noise
    # The online model sees the noisier sample, the EMA target model the less noisy one
    loss = distance_metric(consistency_model(next_noisy_data, timestep + 1),
                           ema_consistency_model(current_noisy_data, timestep))
    loss.backward()
    optimizer_step()  # update consistency_model_params with the gradients just computed
    # Update the EMA target parameters outside the gradient graph
    with no_grad():
        ema_consistency_model_params = ema_decay_schedule(step) * ema_consistency_model_params \
            + (1 - ema_decay_schedule(step)) * consistency_model_params
```
The distance metric in the loss can be the LPIPS perceptual similarity, an L1 loss, or an MSE loss; a sketch is given below.
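A minimal PyTorch sketch of these options; the lpips package is one common LPIPS implementation, and its use here (as well as the VGG backbone) is an assumption, not something specified in the paper:

```python
import torch.nn.functional as F

def distance_metric(x, y, kind="l2"):
    """Distance between the online model output x and the EMA target output y."""
    if kind == "l2":
        return F.mse_loss(x, y)
    if kind == "l1":
        return F.l1_loss(x, y)
    if kind == "lpips":
        import lpips  # pip install lpips (assumed available)
        # In practice, build the LPIPS module once and reuse it; shown inline for brevity.
        loss_fn = lpips.LPIPS(net="vgg").to(x.device)
        return loss_fn(x, y).mean()  # expects images scaled to [-1, 1]
    raise ValueError(f"unknown distance metric: {kind}")
```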
Experimental Results
Single-step generation:
- ImageNet 64 x 64 (FID: 6.37)
- LSUN Bedroom 256 x 256 (FID: 8.26)
As a new member of the diffusion-model family, consistency models fundamentally address DDPM's slow-sampling problem, and their generation quality is better than that of other single-step non-adversarial generative models. Although the quality does not yet match GANs, they allow trading off quality against sampling time to reach a satisfactory result.
Summary
This week, after working through how consistency models work, I focused on reproducing the image-generation results, dissecting the code module by module, and studying how each hyperparameter affects model performance under the two training methods described in the paper, CD (consistency distillation) and CT (consistency training). Next, I plan to try applying the model beyond image generation, to see whether it can be used for downstream tasks such as text2image or image2image.
References
[1] Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. arXiv preprint arXiv:2303.01469.
[2] Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the design space of diffusion-based generative models. arXiv preprint arXiv:2206.00364.