前言
本博客采用代码:https://github.com/lucidrains/denoising-diffusion-pytorch
本博客将分成多节,详细阐述DDPM in Pytorch 的代码细节,如有缺漏、错误,请多指教
根据Github中的第二个实现,即:
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
model = Unet(
dim = 64,
dim_mults = (1, 2, 4, 8),
flash_attn = True
)
diffusion = GaussianDiffusion(
model,
image_size = 128,
timesteps = 1000, # number of steps
sampling_timesteps = 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
)
trainer = Trainer(
diffusion,
'path/to/your/images',
train_batch_size = 32,
train_lr = 8e-5,
train_num_steps = 700000, # total training steps
gradient_accumulate_every = 2, # gradient accumulation steps
ema_decay = 0.995, # exponential moving average decay
amp = True, # turn on mixed precision
calculate_fid = True # whether to calculate fid during training
)
trainer.train()
其中的Unet类为自定义类,继承自 PyTorch 的 Module 类。有如下特点:
- 多尺度特征提取:通过下采样和上采样路径,提取不同尺度的特征。
- 残差连接(ResNet Blocks):使用残差块来加深网络,同时缓解梯度消失问题。
- 注意力机制:在某些层引入注意力机制,提升模型的表达能力。
- 时间嵌入(Time Embedding):处理时间步的信息,可能用于扩散模型或时序数据。
- 自条件(Self-Conditioning):在输入中包含模型自身的输出,以增强模型性能。
下面将对代码进行详细解析,包括各个组件和前向传播过程。
1. 类的定义和初始化
class Unet(Module):
def __init__(self, ...):
super().__init__()
...
dim:基础维度,决定了网络中各层的通道数基数。
init_dim:初始卷积层的输出维度,默认为 dim。
out_dim:输出层的维度,默认为 channels。
dim_mults:一个元组,指定每个阶段的通道数倍增系数。
channels:输入图像的通道数,默认为 3(RGB 图像)。
self_condition:是否使用自条件,将模型的预测作为下一次输入的一部分。
learned_variance:是否学习方差,可能用于扩散模型的反向过程。
learned_sinusoidal_cond 和 random_fourier_features:用于时间嵌入的选项,决定是否使用学习的正弦嵌入或随机傅里叶特征。
learned_sinusoidal_dim:学习的正弦嵌入的维度。
sinusoidal_pos_emb_theta:正弦位置嵌入的尺度参数。
dropout:在残差块中的 dropout 率。
attn_dim_head:注意力机制中每个头的维度。
attn_heads:注意力机制中的头数。
full_attn:指定在哪些层使用全局注意力,默认为最内层。
flash_attn:是否使用 Flash Attention,提高注意力计算的效率。
2. 模型组件的初始化
2.1 输入卷积层
self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding=3)
输入通道数:input_channels,如果使用自条件,则为 channels * 2,否则为 channels。
卷积核大小:7x7,大感受野,有助于捕获全局信息。
填充:padding=3,保持输出尺寸与输入一致。
2.2 通道数设置
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
dims:每个阶段的通道数,初始为 init_dim,之后按 dim_mults 倍增。
in_out:每个阶段的输入和输出通道数对。
2.3 时间嵌入
time_dim = dim * 4
时间嵌入维度:将时间步编码为高维向量,方便模型利用时间信息。
2.3.1 时间嵌入方式
if self.random_or_learned_sinusoidal_cond:
sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
fourier_dim = learned_sinusoidal_dim + 1
else:
sinu_pos_emb = SinusoidalPosEmb(dim, theta=sinusoidal_pos_emb_theta)
fourier_dim = dim
选择嵌入方式:
随机或学习的正弦嵌入:如果 learned_sinusoidal_cond 或 random_fourier_features 为真。
固定的正弦嵌入:否则,使用固定的正弦位置嵌入。
2.3.2 时间嵌入的 MLP
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
嵌入层:将时间步编码为向量。
MLP:两层全连接网络,激活函数为 GELU,进一步处理时间嵌入。
2.4 注意力机制的设置
if not full_attn:
full_attn = (*((False,) * (len(dim_mults) - 1)), True)
num_stages = len(dim_mults)
full_attn = cast_tuple(full_attn, num_stages)
attn_heads = cast_tuple(attn_heads, num_stages)
attn_dim_head = cast_tuple(attn_dim_head, num_stages)
assert len(full_attn) == len(dim_mults)
full_attn:指定哪些阶段使用全局注意力,默认只有最内层(最后一个)使用。
cast_tuple:将参数转换为与阶段数匹配的元组。
2.5 定义残差块和注意力类型
FullAttention = partial(Attention, flash=flash_attn)
resnet_block = partial(ResnetBlock, time_emb_dim=time_dim, dropout=dropout)
FullAttention:注意力机制类,可能使用 Flash Attention。
resnet_block:残差块的部分应用,固定了时间嵌入维度和 dropout 率。
3. 构建网络层次
3.1 下采样路径(Encoder)
self.downs = ModuleList([])
...
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
is_last = ind >= (num_resolutions - 1)
attn_klass = FullAttention if layer_full_attn else LinearAttention
self.downs.append(ModuleList([
resnet_block(dim_in, dim_in),
resnet_block(dim_in, dim_in),
attn_klass(dim_in, dim_head=layer_attn_dim_head, heads=layer_attn_heads),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1)
]))
self.downs:保存下采样路径的模块列表,每个阶段包含以下模块:
残差块1:输入维度 dim_in。
残差块2:输入维度 dim_in。
注意力层:根据 layer_full_attn 选择全局注意力或线性注意力。
下采样层:使用 Downsample 模块或卷积实现,输出维度 dim_out。
3.2 中间层(Bottleneck)
mid_dim = dims[-1]
self.mid_block1 = resnet_block(mid_dim, mid_dim)
self.mid_attn = FullAttention(mid_dim, heads=attn_heads[-1], dim_head=attn_dim_head[-1])
self.mid_block2 = resnet_block(mid_dim, mid_dim)
中间残差块和注意力层:连接编码器和解码器,处理最深层的特征。
3.3 上采样路径(Decoder)
self.ups = ModuleList([])
...
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
is_last = ind == (len(in_out) - 1)
attn_klass = FullAttention if layer_full_attn else LinearAttention
self.ups.append(ModuleList([
resnet_block(dim_out + dim_in, dim_out),
resnet_block(dim_out + dim_in, dim_out),
attn_klass(dim_out, dim_head=layer_attn_dim_head, heads=layer_attn_heads),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1)
]))
self.ups:保存上采样路径的模块列表,每个阶段包含:
残差块1:输入维度为 dim_out + dim_in(因为有跳跃连接)。
残差块2:输入维度为 dim_out + dim_in。
注意力层:同样根据 layer_full_attn 选择。
上采样层:使用 Upsample 模块或卷积实现,输出维度 dim_in。
3.4 输出层
default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)
self.final_res_block = resnet_block(init_dim * 2, init_dim)
self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)
default_out_dim:如果不学习方差,输出维度为 channels,否则为 channels * 2。
self.final_res_block:最终的残差块,输入维度为 init_dim * 2(因为有跳跃连接)。
self.final_conv:1x1 卷积,将通道数调整为 self.out_dim。
4. 前向传播过程
def forward(self, x, time, x_self_cond=None):
...
x:输入图像张量。
time:时间步,用于时间嵌入。
x_self_cond:自条件输入,默认为 None。
4.1 输入处理
if self.self_condition:
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x_self_cond, x), dim=1)
自条件处理:如果启用自条件,将 x_self_cond 与 x 在通道维度上拼接。
4.2 初始卷积
x = self.init_conv(x)
r = x.clone()
r:保存初始特征,用于后面的跳跃连接。
4.3 时间嵌入
t = self.time_mlp(time)
t:时间步的嵌入向量。
4.4 下采样路径
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x) + x
h.append(x)
x = downsample(x)
遍历每个下采样阶段:
block1:残差块1,输入 x 和时间嵌入 t。
h.append(x):将输出保存到列表 h,用于跳跃连接。
block2:残差块2,同样输入 x 和 t。
attn:注意力层,处理 x,并与输入 x 相加(残差连接)。
再次保存 x 到 h。
downsample:下采样层,减小空间尺寸,增加通道数。
4.5 中间层
x = self.mid_block1(x, t)
x = self.mid_attn(x) + x
x = self.mid_block2(x, t)
mid_block1:残差块。
mid_attn:全局注意力层,处理 x,并残差连接。
mid_block2:残差块。
4.6 上采样路径
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim=1)
x = block2(x, t)
x = attn(x) + x
x = upsample(x)
遍历每个上采样阶段:
拼接跳跃连接:从 h 中弹出对应的特征,与 x 在通道维度上拼接。
block1:残差块,输入拼接后的特征和时间嵌入 t。
重复拼接和残差块:再次拼接特征,经过 block2。
attn:注意力层,处理 x,并残差连接。
upsample:上采样层,增加空间尺寸,减小通道数。
4.7 最终输出
x = torch.cat((x, r), dim=1)
x = self.final_res_block(x, t)
return self.final_conv(x)
与初始特征拼接:将 x 与初始的 r 拼接,形成完整的特征。
self.final_res_block:最终的残差块,处理拼接后的特征。
self.final_conv:通过 1x1 卷积,调整通道数,得到输出。
5. 模型实例化
model = Unet(
dim=64,
dim_mults=(1, 2, 4, 8),
flash_attn=False
)
dim=64:基础维度为 64。
dim_mults=(1, 2, 4, 8):通道数倍增系数,表示每个阶段的通道数为 64, 128, 256, 512。
flash_attn=False:不使用 Flash Attention。
后记
此次代码解析也只是略略而言,博主正在做相关的扩散模型代码解析视频,有需要请在评论区留言,如果留言人数较多,博主将加快更新,感谢大家支持!