本文来源公众号“天才程序员周弈帆”,仅用于学术分享,侵权删,干货满满。
原文链接:Stable Diffusion 解读(三):原版实现源码解读
天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作-优快云博客
天才程序员周弈帆 | Stable Diffusion 解读(二):论文精读-优快云博客
看完了Stable Diffusion的论文,在最后这几篇文章里,我们来学习Stable Diffusion的代码实现。具体来说,我们会学习Stable Diffusion官方仓库及Diffusers开源库中有关采样算法和U-Net的代码,而不会学习有关训练、VAE、text encoder (CLIP) 的代码。如今大多数工作都只会用到预训练的Stable Diffusion,只学采样算法和U-Net代码就能理解大多数工作了。
受字数限制,Diffusers的介绍会放到下一篇文章里。
建议读者在阅读本文之前了解DDPM、ResNet、U-Net、Transformer。
本文用到的Stable Diffusion版本是v1.5。Diffusers版本是0.25.0。为了提升可读性,本文对源代码做了一定的精简,部分不会运行到的分支会被略过。
1 算法梳理
在正式读代码之前,我们先用伪代码梳理一下Stable Diffusion的采样过程,并回顾一下U-Net架构的组成。实现Stable Diffusion的代码库有很多,各个库之间的API差异很大。但是,它们实际上都是在描述同一个算法,同一个模型。如果我们理解了算法和模型本身,就可以在学习时主动去找一个算法对应哪一段代码,而不是被动地去理解每一行代码在干什么。
1.1 LDM 采样算法
让我们从最早的DDPM开始,一步一步还原Latent Diffusion Model (LDM)的采样算法。DDPM的采样算法如下所示:
def ddpm_sample(image_shape):
ddpm_scheduler = DDPMScheduler()
unet = UNet()
xt = randn(image_shape)
T = 1000
for t in T ... 1:
eps = unet(xt, t)
std = ddpm_scheduler.get_std(t)
xt = ddpm_scheduler.get_xt_prev(xt, t, eps, std)
return xt
在DDPM的实现中,一般会有一个类专门维护扩散模型的alpha,beta等变量。我们这里把这个类称为DDPMScheduler
。此外,DDPM会用到一个U-Net神经网络unet
,用于计算去噪过程中图像应该去除的噪声eps
。准备好这两个变量后,就可以用randn()
从标准正态分布中采样一个纯噪声图像xt
。它会被逐渐去噪,最终变成一幅图片。去噪过程中,时刻t
会从总时刻T
遍历至1
(总时刻T
一般取1000
)。在每一轮去噪步骤中,U-Net会根据这一时刻的图像xt
和当前时间戳t
估计出此刻应去除的噪声eps
,根据xt
和eps
就能知道下一步图像的均值。除了均值,我们还要获取下一步图像的方差,这一般可以从DDPM调度类中直接获取。有了下一步图像的均值和方差,我们根据DDPM的公式,就能采样出下一步的图像。反复执行去噪循环,xt
会从纯噪声图像变成一幅有意义的图像。
DDIM对DDPM的采样过程做了两点改进:1) 去噪的有效步数可以少于T
步,由另一个变量ddim_steps
决定;2) 采样的方差大小可以由eta
决定。
因此,改进后的DDIM算法可以写成这样:
def ddim_sample(image_shape, ddim_steps = 20, eta = 0):
ddim_scheduler = DDIMScheduler()
unet = UNet()
xt = randn(image_shape)
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
for t in timesteps:
eps = unet(xt, t)
std = ddim_scheduler.get_std(t, eta)
xt = ddim_scheduler.get_xt_prev(xt, t, eps, std)
return xt
其中,ddim_steps
是去噪循环的执行次数。根据ddim_steps
,DDIM调度器可以生成所有被使用到的t
。比如对于T=1000, ddim_steps=20
,被使用到的就只有[1000, 950, 900, ..., 50]
这20个时间戳,其他时间戳就可以跳过不算了。eta
会被用来计算方差,一般这个值都会设成0
。
DDIM是早期的加速扩散模型采样的算法。如今有许多比DDIM更好的采样方法,但它们多数都保留了
steps
和eta
这两个参数。因此,在使用所有采样方法时,我们可以不用关心实现细节,只关注多出来的这两个参数。
在DDIM的基础上,LDM从生成像素空间上的图像变为生成隐空间上的图像。隐空间图像需要再做一次解码才能变回真实图像。从代码上来看,使用LDM后,只需要多准备一个VAE,并对最后的隐空间图像zt
解码。
def ldm_ddim_sample(image_shape, ddim_steps = 20, eta = 0):
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
for t in timesteps:
eps = unet(zt, t)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt
而想用LDM实现文生图,则需要给一个额外的文本输入text
。文本编码器会把文本编码成张量c
,输入进unet
。其他地方的实现都和之前的LDM一样。
def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0):
ddim_scheduler = DDIMScheduler()
vae = VAE()
unet = UNet()
zt = randn(image_shape)
T = 1000
timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
text_encoder = CLIP()
c = text_encoder.encode(text)
for t = timesteps:
eps = unet(zt, t, c)
std = ddim_scheduler.get_std(t, eta)
zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
xt = vae.decoder.decode(zt)
return xt
最后这个能实现文生图的LDM就是我们熟悉的Stable Diffusion。Stable Diffusion的采样算法看上去比较复杂,但如果能够从DDPM开始把各个功能都拆开来看,理解起来就不是那么困难了。
1.2 U-Net 结构组成
Stable Diffusion代码实现中的另一个重点是去噪网络U-Net的实现。仿照上一节的学习方法,我们来逐步学习Stable Diffusion中的U-Net是怎么从最经典的纯卷积U-Net逐渐发展而来的。
最早的U-Net的结构如下图所示:
可以看出,U-Net的结构有以下特点:
-
整体上看,U-Net由若干个大层组成。特征在每一大层会被下采样成尺寸更小的特征,再被上采样回原尺寸的特征。整个网络构成一个U形结构。
-
下采样后,特征的通道数会变多。一般情况下,每次下采样后图像尺寸减半,通道数翻倍。上采样过程则反之。
-
为了防止信息在下采样的过程中丢失,U-Net每一大层在下采样前的输出会作为额外输入拼接到每一大层上采样前的输入上。这种数据连接方式类似于ResNet中的「短路连接」。
DDPM则使用了一种改进版的U-Net。改进主要有两点:
-
原来的卷积层被替换成了ResNet中的残差卷积模块。每一大层有若干个这样的子模块。对于较深的大层,残差卷积模块后面还会接一个自注意力模块。
-
原来模型每一大层只有一个短路连接。现在每个大层下采样部分的每个子模块的输出都会额外输入到其对称的上采样部分的子模块上。直观上来看,就是短路连接更多了一点,输入信息更不容易在下采样过程中丢失。
最后,LDM提出了一种给U-Net添加额外约束信息的方法:把U-Net中的自注意力模块换成交叉注意力模块。具体来说,DDPM的U-Net的自注意力模块被换成了标准的Transformer模块。约束信息可以作为Cross Attention的K, V输入进模块中。
Stable Diffusion的U-Net还在结构上有少许修改,该U-Net的每一大层都有Transformer块,而不是只有较深的大层有。
至此,我们已经学完了Stable Diffusion的采样原理和U-Net结构。接下来我们来看一看它们在不同框架下的代码实现。
2 Stable Diffusion 官方 GitHub 仓库
2.1 安装
克隆仓库后,照着官方Markdown文档安装即可。
git clone git@github.com:CompVis/stable-diffusion.git
先用下面的命令创建conda环境,此后ldm
环境就是运行Stable Diffusiion的conda环境。
conda env create -f environment.yaml
conda activate ldm
之后去网上下一个Stable Diffusion的模型文件。比较常见一个版本是v1.5,该模型在Hugging Face上:https://huggingface.co/runwayml/stable-diffusion-v1-5 (推荐下载v1-5-pruned.ckpt
)。下载完毕后,把模型软链接到指定位置。
mkdir -p models/ldm/stable-diffusion-v1/
ln -s <path/to/model.ckpt> models/ldm/stable-diffusion-v1/model.ckpt
准备完毕后,只要输入下面的命令,就可以生成实现文生图了。
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse"
在默认的参数下,“一幅骑着马的飞行员的照片”的绘制结果会被保存在outputs/txt2img-samples
中。你也可以通过--outdir <dir>
参数来指定输出到的文件夹。我得到的一些绘制结果为:
【说明】如果你在安装时碰到了错误,可以在搜索引擎上或者GitHub的issue里搜索,一般都能搜到其他人遇到的相同错误。
2.2 主函数
接下来,我们来探究一下scripts/txt2img.py
的执行过程。为了方便阅读,我们可以简化代码中的命令行处理,得到下面这份精简代码。(你可以把这份代码复制到仓库根目录下的一个新Python脚本里并直接运行。别忘了修改代码中的模型路径)
import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from pytorch_lightning import seed_everything
from torch import autocast
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.cuda()
model.eval()
return model
def main():
seed = 42
config = 'configs/stable-diffusion/v1-inference.yaml'
ckpt = 'ckpt/v1-5-pruned.ckpt'
outdir = 'tmp'
n_samples = batch_size = 3
n_rows = batch_size
n_iter = 2
prompt = 'a photograph of an astronaut riding a horse'
data = [batch_size * [prompt]]
scale = 7.5
C = 4
f = 8
H = W = 512
ddim_steps &