Denoising Diffusion Probabilistic Models (DDPM): Code Walkthrough
Paper: Denoising Diffusion Probabilistic Models (Ho et al., 2020), https://arxiv.org/abs/2006.11239
Code: https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-
Preface
This post walks through the PyTorch training pipeline and model construction for DDPM and DDPM+condition, using CIFAR-10 (60,000 32x32 color images across 10 classes) as the training dataset. Details are explained in code comments; it is recommended to skim the code flow yourself first.
Theory overview
A diffusion model is a parameterized Markov chain, trained with variational inference, that produces samples matching the data after a finite number of steps. Real data $x_0$ is corrupted over $T$ steps of Gaussian noising into $x_T$, which follows the standard normal distribution $N(0, I)$. The model then learns the reverse transition probability $p_\theta(x_{t-1}|x_t)$: starting from $x_T$ sampled from the standard normal, it generates real data step by step. The training objective is for the distribution of generated data to match the real data distribution as closely as possible.
Formula derivations are not covered in detail here; if you are interested in them, see https://kxz18.github.io/2022/06/19/Diffusion/
Forward process
Goal: transform the complex real data distribution into a simple, tractable one by progressively adding noise.
Gaussian noise is gradually added to the data according to a variance schedule $\beta_1, \dots, \beta_T$:
$q(x_{1:T}|x_0) := \prod_{t=1}^{T} q(x_t|x_{t-1})$
$q(x_t|x_{t-1}) := N(x_t; \sqrt{1-\beta_t}\,x_{t-1}, \beta_t I)$
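A property the training code below relies on: the marginal $q(x_t|x_0)$ has a closed form, so $x_t$ can be sampled from $x_0$ in a single step. Writing $\alpha_t := 1 - \beta_t$ and $\bar{\alpha}_t := \prod_{s=1}^{t} \alpha_s$:
$q(x_t|x_0) = N(x_t; \sqrt{\bar{\alpha}_t}\,x_0, (1-\bar{\alpha}_t) I)$, i.e. $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$ with $\epsilon \sim N(0, I)$.
This is exactly the computation behind sqrt_alphas_bar and sqrt_one_minus_alphas_bar in the GaussianDiffusionTrainer further below.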
Reverse process
Diffusion models are latent-variable models of the form $p_\theta(x_0) := \int p_\theta(x_{0:T})\,dx_{1:T}$, where $x_1, \dots, x_T$ are latents with the same dimensionality as the data $x_0 \sim q(x_0)$. The joint distribution $p_\theta(x_{0:T})$ is called the reverse process; it is defined as a learnable Markov chain starting from the Gaussian $p(x_T) = N(x_T; 0, I)$:
$p_\theta(x_{0:T}) := p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1}|x_t)$
$p_\theta(x_{t-1}|x_t) := N(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))$
The training objective is for the model's generated distribution $p_\theta(x_0)$ to be as close as possible to the real data distribution $q(x_0)$ (concretely, the UNet is trained so that the noise it predicts at each step matches the noise actually added by the forward process $q$).
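This objective comes from a variational bound on the negative log-likelihood; the DDPM paper simplifies it to a plain mean-squared error between the true noise and the predicted noise, which is the loss the trainer below implements:
$L_{\text{simple}}(\theta) := \mathbb{E}_{t, x_0, \epsilon}\left[\left\| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\, t) \right\|^2\right]$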
Training pipeline
Diffusion code
1. Load the dataset
# Load PyTorch's built-in CIFAR-10 dataset, pointing root at a local path and selecting the training split
# Each image is first randomly flipped horizontally (RandomHorizontalFlip),
# then converted to a tensor (ToTensor), and finally normalized (Normalize)
# with per-channel mean/std 0.5, mapping pixel values from [0, 1] to [-1, 1]
dataset = CIFAR10(
    root='./CIFAR10', train=True, download=True,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]))
# Create the data loader, which batches the dataset for training
dataloader = DataLoader(
    dataset, batch_size=modelConfig["batch_size"], shuffle=True,
    num_workers=4, drop_last=True, pin_memory=True)
# batch_size: number of samples per mini-batch.
# shuffle=True: reshuffle the data before each epoch to add randomness to training.
# num_workers: number of subprocesses used for data loading, so preprocessing and loading run in parallel with training.
# drop_last=True: drop the final incomplete batch if the dataset size is not divisible by the batch size.
# pin_memory=True: load data into page-locked (pinned) memory, which speeds up host-to-GPU transfer.
2. Model configuration
# modelConfig is defined up front
modelConfig = {
    "state": "train",  # or eval
    "epoch": 200,
    "batch_size": 100,
    "T": 1000,
    "channel": 128,
    "channel_mult": [1, 2, 3, 4],
    "attn": [2],
    "num_res_blocks": 2,
    "dropout": 0.15,
    "lr": 1e-4,
    "multiplier": 2.,
    "beta_1": 1e-4,
    "beta_T": 0.02,
    "img_size": 32,
    "grad_clip": 1.,
    "device": "cuda:1",  ### MAKE SURE YOU HAVE A GPU !!!
    "training_load_weight": None,
    "save_weight_dir": "./Checkpoints/",
    "test_load_weight": "ckpt_199_.pt",
    "sampled_dir": "./SampledImgs/",
    "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
    "sampledImgName": "SampledNoGuidenceImgs.png",
    "nrow": 8
}
device = torch.device(modelConfig["device"])
# UNet: T=1000 steps, base channel 128, channel multipliers [1, 2, 3, 4], attention at stage index [2],
# 2 res blocks per stage; these settings drive the UNet construction described later
net_model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                 num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
if modelConfig["training_load_weight"] is not None:
    net_model.load_state_dict(torch.load(os.path.join(
        modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device))
# AdamW optimizer, a variant of Adam with decoupled weight decay
optimizer = torch.optim.AdamW(net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
# PyTorch's cosine-annealing LR scheduler, which adjusts the learning rate over the course of training
# T_max sets the annealing period (here, the total number of epochs), eta_min is the LR lower bound,
# and last_epoch is the index of the previous epoch (-1 means start from scratch)
cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
# GradualWarmupScheduler is a custom scheduler that ramps the learning rate up at the start of
# training to help the model settle into a good region (a sketch of this class follows below).
# multiplier sets the peak LR as a multiple of the configured initial LR,
# warm_epoch is the number of warmup epochs, and after_scheduler takes over once warmup ends.
warmUpScheduler = GradualWarmupScheduler(
    optimizer=optimizer, multiplier=modelConfig["multiplier"],
    warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
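Since GradualWarmupScheduler is custom rather than a stock PyTorch class, here is a minimal sketch of what such a scheduler may look like, assuming only the interface used above (multiplier, warm_epoch, after_scheduler); the repo's actual implementation may differ in detail:

from torch.optim.lr_scheduler import _LRScheduler

class GradualWarmupScheduler(_LRScheduler):
    def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
        self.multiplier = multiplier
        self.warm_epoch = warm_epoch
        self.after_scheduler = after_scheduler
        self.finished = False  # becomes True once warmup hands over to after_scheduler
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.warm_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # rebase the after_scheduler on the post-warmup (peak) learning rate
                    self.after_scheduler.base_lrs = [lr * self.multiplier for lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [lr * self.multiplier for lr in self.base_lrs]
        # linear ramp from base_lr up to multiplier * base_lr over warm_epoch epochs
        return [lr * ((self.multiplier - 1.) * self.last_epoch / self.warm_epoch + 1.)
                for lr in self.base_lrs]

    def step(self, epoch=None):
        if self.finished and self.after_scheduler:
            self.after_scheduler.step(epoch)
        else:
            super().step(epoch)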
# GaussianDiffusionTrainer wraps net_model together with beta_1, beta_T and T
trainer = GaussianDiffusionTrainer(
    net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
3. Training loop
# start training
for e in range(modelConfig["epoch"]):
    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for images, labels in tqdmDataLoader:
            # train
            optimizer.zero_grad()
            x_0 = images.to(device)
            loss = trainer(x_0).sum() / 1000.  # forward pass computes the per-pixel loss; sum and rescale
            loss.backward()  # compute gradients of the loss w.r.t. the model parameters
            torch.nn.utils.clip_grad_norm_(
                net_model.parameters(), modelConfig["grad_clip"])  # clip gradients to prevent explosion
            optimizer.step()  # update the parameters with the computed gradients
            tqdmDataLoader.set_postfix(ordered_dict={
                "epoch": e,
                "loss: ": loss.item(),
                "img shape: ": x_0.shape,
                "LR": optimizer.state_dict()['param_groups'][0]["lr"]
            })  # update the progress bar: current epoch, loss, image shape and learning rate
    warmUpScheduler.step()  # step the scheduler; it switches to cosine annealing once warmup ends
    torch.save(net_model.state_dict(), os.path.join(
        modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))
4. GaussianDiffusion model
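The trainer uses an extract helper (defined elsewhere in the repo) to gather each sample's per-timestep coefficient and reshape it so it broadcasts against an image batch. A typical implementation of such a helper looks like the following sketch; the repo's version may differ slightly:

def extract(v, t, x_shape):
    # v: 1-D tensor of per-timestep coefficients, shape [T]
    # t: integer timestep for each sample in the batch, shape [B]
    # returns v[t] reshaped to [B, 1, 1, 1] so it broadcasts over images of shape x_shape
    out = torch.gather(v, index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))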
class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model
        self.T = T
        # register a buffer named betas holding T evenly spaced values from beta_1 to beta_T
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Given x_0, sample a random timestep t for each sample, compute the noised x_t
        in closed form, and train the UNet to predict the added noise from (x_t, t)
        """
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)  # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss
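At inference time, the learned reverse process is applied step by step from pure noise. As a complement to the trainer, here is a minimal sketch of a DDPM sampler using the posterior mean $\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)$ and fixed variance $\sigma_t^2 = \beta_t$; the repo's GaussianDiffusionSampler follows the same formula, though its implementation details may differ:

class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model
        self.T = T
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # coefficients of mu_theta(x_t, t) = coeff1 * x_t - coeff2 * eps_theta(x_t, t)
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))

    def forward(self, x_T):
        x_t = x_T
        for time_step in reversed(range(self.T)):
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            eps = self.model(x_t, t)
            mean = (extract(self.coeff1, t, x_t.shape) * x_t -
                    extract(self.coeff2, t, x_t.shape) * eps)
            # no noise is added at the final step (t == 0)
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            var = extract(self.betas, t, x_t.shape)
            x_t = mean + torch.sqrt(var) * noise
        # clip back to the normalized image range [-1, 1]
        return torch.clip(x_t, -1, 1)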
5. UNet model construction
class UNet(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        # TimeEmbedding computes a sin/cos positional encoding of the timestep
        # (a sketch of this module follows the class below)
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        # downsampling path: torch.Size([80, 128, 32, 32]) --> torch.Size([80, 512, 4, 4])
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling skip connections
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
        # middle: two res blocks
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        # upsampling path
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0
        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)
        assert len(hs) == 0
        return h
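The TimeEmbedding module referenced in __init__ builds a fixed sin/cos table over the T timesteps and projects each entry through a small MLP. A minimal sketch under that assumption (the repo's version may differ in detail):

import math

class Swish(nn.Module):
    # x * sigmoid(x), the activation used throughout the repo
    def forward(self, x):
        return x * torch.sigmoid(x)

class TimeEmbedding(nn.Module):
    def __init__(self, T, d_model, dim):
        super().__init__()
        assert d_model % 2 == 0
        # fixed sinusoidal table of shape [T, d_model]:
        # table[t, 2i] = sin(t / 10000^(2i/d_model)), table[t, 2i+1] = cos(...)
        freqs = torch.exp(-math.log(10000) * torch.arange(0, d_model, 2).float() / d_model)
        args = torch.arange(T).float()[:, None] * freqs[None, :]  # [T, d_model // 2]
        table = torch.stack([torch.sin(args), torch.cos(args)], dim=-1).view(T, d_model)
        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(table),  # frozen lookup by integer timestep
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        return self.timembedding(t)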
Diffusion condition code
Compared with the plain diffusion code above, the conditional version adds a label component during training.
1. Training loop
for e in range(modelConfig["epoch"]):
    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for images, labels in tqdmDataLoader:
            # train
            b = images.shape[0]
            optimizer.zero_grad()
            x_0 = images.to(device)
            # labels are shifted by +1 so that label 0 is reserved as the "no condition" token
            labels = labels.to(device) + 1
            # with probability 0.1, drop the condition (classifier-free guidance training)
            if np.random.rand() < 0.1:
                labels = torch.zeros_like(labels).to(device)
            loss = trainer(x_0, labels).sum() / b ** 2.  # per the repo, the loss is summed and scaled by the squared batch size
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                net_model.parameters(), modelConfig["grad_clip"])
            optimizer.step()
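The random label dropout above is classifier-free guidance training: a single network learns both the conditional noise predictor $\epsilon_\theta(x_t, t, c)$ and the unconditional one $\epsilon_\theta(x_t, t, \varnothing)$ (label 0). At sampling time the two predictions are blended with a guidance weight $w$, with larger $w$ trading sample diversity for fidelity to the condition:
$\tilde{\epsilon} = (1 + w)\,\epsilon_\theta(x_t, t, c) - w\,\epsilon_\theta(x_t, t, \varnothing)$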
2. Corresponding model changes
In the GaussianDiffusionTrainer, the forward computation that noises $x_0$ toward $x_T$ is unchanged; the UNet construction changes accordingly: a conditional embedding is added.
class UNet(nn.Module):
    def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout):
        super().__init__()
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling skip connections
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=False))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0
        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

    def forward(self, x, t, labels):
        # Timestep embedding
        temb = self.time_embedding(t)
        cemb = self.cond_embedding(labels)  # conditional embedding
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb, cemb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb, cemb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb, cemb)
        h = self.tail(h)
        assert len(hs) == 0
        return h
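ConditionalEmbedding is the only genuinely new module: it maps an integer label (with 0 reserved for "no condition") to a vector of the same dimensionality as the time embedding. A minimal sketch, assuming its structure mirrors TimeEmbedding but with a learnable embedding table; the repo's implementation may differ:

class ConditionalEmbedding(nn.Module):
    def __init__(self, num_labels, d_model, dim):
        super().__init__()
        self.cond_embedding = nn.Sequential(
            # num_labels + 1 entries: index 0 is the "no condition" token used during label dropout
            nn.Embedding(num_embeddings=num_labels + 1, embedding_dim=d_model, padding_idx=0),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, labels):
        return self.cond_embedding(labels)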