【代码详解】用 Pytorch 实现DDPM(1. Unet 类)

前言

本博客采用代码:https://github.com/lucidrains/denoising-diffusion-pytorch
本博客将分成多节,详细阐述DDPM in Pytorch 的代码细节,如有缺漏、错误,请多指教

根据Github中的第二个实现,即:

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    flash_attn = True
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of steps
    sampling_timesteps = 250    # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
)

trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True,                       # turn on mixed precision
    calculate_fid = True              # whether to calculate fid during training
)

trainer.train()

其中的Unet类为自定义类,继承自 PyTorch 的 Module 类。有如下特点:

  1. 多尺度特征提取:通过下采样和上采样路径,提取不同尺度的特征。
  2. 残差连接(ResNet Blocks):使用残差块来加深网络,同时缓解梯度消失问题。
  3. 注意力机制:在某些层引入注意力机制,提升模型的表达能力。
  4. 时间嵌入(Time Embedding):处理时间步的信息,可能用于扩散模型或时序数据。
  5. 自条件(Self-Conditioning):在输入中包含模型自身的输出,以增强模型性能。

下面将对代码进行详细解析,包括各个组件和前向传播过程。

1. 类的定义和初始化

class Unet(Module):
    def __init__(self, ...):
        super().__init__()
        ...

dim:基础维度,决定了网络中各层的通道数基数。
init_dim:初始卷积层的输出维度,默认为 dim。
out_dim:输出层的维度,默认为 channels。
dim_mults:一个元组,指定每个阶段的通道数倍增系数。
channels:输入图像的通道数,默认为 3(RGB 图像)。
self_condition:是否使用自条件,将模型的预测作为下一次输入的一部分。
learned_variance:是否学习方差,可能用于扩散模型的反向过程。
learned_sinusoidal_cond 和 random_fourier_features:用于时间嵌入的选项,决定是否使用学习的正弦嵌入或随机傅里叶特征。
learned_sinusoidal_dim:学习的正弦嵌入的维度。
sinusoidal_pos_emb_theta:正弦位置嵌入的尺度参数。
dropout:在残差块中的 dropout 率。
attn_dim_head:注意力机制中每个头的维度。
attn_heads:注意力机制中的头数。
full_attn:指定在哪些层使用全局注意力,默认为最内层。
flash_attn:是否使用 Flash Attention,提高注意力计算的效率。

2. 模型组件的初始化

2.1 输入卷积层

self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding=3)

输入通道数:input_channels,如果使用自条件,则为 channels * 2,否则为 channels。
卷积核大小:7x7,大感受野,有助于捕获全局信息。
填充:padding=3,保持输出尺寸与输入一致。

2.2 通道数设置

dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

dims:每个阶段的通道数,初始为 init_dim,之后按 dim_mults 倍增。
in_out:每个阶段的输入和输出通道数对。

2.3 时间嵌入

time_dim = dim * 4

时间嵌入维度:将时间步编码为高维向量,方便模型利用时间信息。

2.3.1 时间嵌入方式

if self.random_or_learned_sinusoidal_cond:
    sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
    fourier_dim = learned_sinusoidal_dim + 1
else:
    sinu_pos_emb = SinusoidalPosEmb(dim, theta=sinusoidal_pos_emb_theta)
    fourier_dim = dim

选择嵌入方式:
随机或学习的正弦嵌入:如果 learned_sinusoidal_cond 或 random_fourier_features 为真。
固定的正弦嵌入:否则,使用固定的正弦位置嵌入。

2.3.2 时间嵌入的 MLP

self.time_mlp = nn.Sequential(
    sinu_pos_emb,
    nn.Linear(fourier_dim, time_dim),
    nn.GELU(),
    nn.Linear(time_dim, time_dim)
)

嵌入层:将时间步编码为向量。
MLP:两层全连接网络,激活函数为 GELU,进一步处理时间嵌入。

2.4 注意力机制的设置

if not full_attn:
    full_attn = (*((False,) * (len(dim_mults) - 1)), True)

num_stages = len(dim_mults)
full_attn  = cast_tuple(full_attn, num_stages)
attn_heads = cast_tuple(attn_heads, num_stages)
attn_dim_head = cast_tuple(attn_dim_head, num_stages)

assert len(full_attn) == len(dim_mults)

full_attn:指定哪些阶段使用全局注意力,默认只有最内层(最后一个)使用。
cast_tuple:将参数转换为与阶段数匹配的元组。

2.5 定义残差块和注意力类型

FullAttention = partial(Attention, flash=flash_attn)
resnet_block = partial(ResnetBlock, time_emb_dim=time_dim, dropout=dropout)

FullAttention:注意力机制类,可能使用 Flash Attention。
resnet_block:残差块的部分应用,固定了时间嵌入维度和 dropout 率。

3. 构建网络层次

3.1 下采样路径(Encoder)

self.downs = ModuleList([])
...
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
    is_last = ind >= (num_resolutions - 1)
    attn_klass = FullAttention if layer_full_attn else LinearAttention
    self.downs.append(ModuleList([
        resnet_block(dim_in, dim_in),
        resnet_block(dim_in, dim_in),
        attn_klass(dim_in, dim_head=layer_attn_dim_head, heads=layer_attn_heads),
        Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1)
    ]))

self.downs:保存下采样路径的模块列表,每个阶段包含以下模块:
残差块1:输入维度 dim_in。
残差块2:输入维度 dim_in。
注意力层:根据 layer_full_attn 选择全局注意力或线性注意力。
下采样层:使用 Downsample 模块或卷积实现,输出维度 dim_out。

3.2 中间层(Bottleneck)

mid_dim = dims[-1]
self.mid_block1 = resnet_block(mid_dim, mid_dim)
self.mid_attn = FullAttention(mid_dim, heads=attn_heads[-1], dim_head=attn_dim_head[-1])
self.mid_block2 = resnet_block(mid_dim, mid_dim)

中间残差块和注意力层:连接编码器和解码器,处理最深层的特征。

3.3 上采样路径(Decoder)

self.ups = ModuleList([])
...
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
    is_last = ind == (len(in_out) - 1)
    attn_klass = FullAttention if layer_full_attn else LinearAttention
    self.ups.append(ModuleList([
        resnet_block(dim_out + dim_in, dim_out),
        resnet_block(dim_out + dim_in, dim_out),
        attn_klass(dim_out, dim_head=layer_attn_dim_head, heads=layer_attn_heads),
        Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1)
    ]))

self.ups:保存上采样路径的模块列表,每个阶段包含:
残差块1:输入维度为 dim_out + dim_in(因为有跳跃连接)。
残差块2:输入维度为 dim_out + dim_in。
注意力层:同样根据 layer_full_attn 选择。
上采样层:使用 Upsample 模块或卷积实现,输出维度 dim_in。

3.4 输出层

default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)
self.final_res_block = resnet_block(init_dim * 2, init_dim)
self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

default_out_dim:如果不学习方差,输出维度为 channels,否则为 channels * 2。
self.final_res_block:最终的残差块,输入维度为 init_dim * 2(因为有跳跃连接)。
self.final_conv:1x1 卷积,将通道数调整为 self.out_dim。

4. 前向传播过程

def forward(self, x, time, x_self_cond=None):
    ...

x:输入图像张量。
time:时间步,用于时间嵌入。
x_self_cond:自条件输入,默认为 None。

4.1 输入处理

if self.self_condition:
    x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
    x = torch.cat((x_self_cond, x), dim=1)

自条件处理:如果启用自条件,将 x_self_cond 与 x 在通道维度上拼接。

4.2 初始卷积

x = self.init_conv(x)
r = x.clone()

r:保存初始特征,用于后面的跳跃连接。

4.3 时间嵌入

t = self.time_mlp(time)

t:时间步的嵌入向量。

4.4 下采样路径

h = []

for block1, block2, attn, downsample in self.downs:
    x = block1(x, t)
    h.append(x)

    x = block2(x, t)
    x = attn(x) + x
    h.append(x)

    x = downsample(x)

遍历每个下采样阶段:
block1:残差块1,输入 x 和时间嵌入 t。
h.append(x):将输出保存到列表 h,用于跳跃连接。
block2:残差块2,同样输入 x 和 t。
attn:注意力层,处理 x,并与输入 x 相加(残差连接)。
再次保存 x 到 h。
downsample:下采样层,减小空间尺寸,增加通道数。

4.5 中间层

x = self.mid_block1(x, t)
x = self.mid_attn(x) + x
x = self.mid_block2(x, t)

mid_block1:残差块。
mid_attn:全局注意力层,处理 x,并残差连接。
mid_block2:残差块。

4.6 上采样路径

for block1, block2, attn, upsample in self.ups:
    x = torch.cat((x, h.pop()), dim=1)
    x = block1(x, t)

    x = torch.cat((x, h.pop()), dim=1)
    x = block2(x, t)
    x = attn(x) + x

    x = upsample(x)

遍历每个上采样阶段:
拼接跳跃连接:从 h 中弹出对应的特征,与 x 在通道维度上拼接。
block1:残差块,输入拼接后的特征和时间嵌入 t。
重复拼接和残差块:再次拼接特征,经过 block2。
attn:注意力层,处理 x,并残差连接。
upsample:上采样层,增加空间尺寸,减小通道数。

4.7 最终输出

x = torch.cat((x, r), dim=1)
x = self.final_res_block(x, t)
return self.final_conv(x)

与初始特征拼接:将 x 与初始的 r 拼接,形成完整的特征。
self.final_res_block:最终的残差块,处理拼接后的特征。
self.final_conv:通过 1x1 卷积,调整通道数,得到输出。

5. 模型实例化

model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    flash_attn=False
)

dim=64:基础维度为 64。
dim_mults=(1, 2, 4, 8):通道数倍增系数,表示每个阶段的通道数为 64, 128, 256, 512。
flash_attn=False:不使用 Flash Attention。

后记

此次代码解析也只是略略而言,博主正在做相关的扩散模型代码解析视频,有需要请在评论区留言,如果留言人数较多,博主将加快更新,感谢大家支持!

### 关于 DDPM实现代码与教程 对于希望深入了解并实践 Denoising Diffusion Probabilistic Models (DDPM),存在多种资源可以帮助理解和复现该模型。这些资源不仅提供了理论上的解释,还包含了实际操作所需的代码。 #### GitHub 项目推荐 一个广泛认可且详尽的 PyTorch 版本实现可以在 `pytorch-stable-diffusion` 中找到[^2]。此仓库提供了一个易于理解的框架用于构建和训练基于扩散的概率模型,并附带详细的文档说明如何配置环境以及运行实验。 此外,在 优快云 博客上有一篇针对初学者的文章介绍了整个过程及其论文要点[^1]。这篇文章除了给出概念性的介绍外,也分享了一些实用技巧帮助读者更好地掌握这一领域内的技术细节。 #### 官方及社区支持材料 为了更深入地研究 DDPM 及其变体,可以参考原始作者发布的官方资料或者活跃的研究者们维护的相关开源项目。例如: - **Ho et al.** 提出的经典 DDPM 模型在其原版 TensorFlow 实现在 Google Research GitHub 上公开可用。 - 社区贡献方面,则有许多个人开发者创建了自己的版本,其中不乏高质量的作品如上述提到的 pytorch-stable-diffusion 库。 通过以上途径获取的信息能够满足不同层次的学习需求——无论是想要快速搭建原型还是细致探究算法内部机制都能得到相应的指导和支持。 ```python import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader from model import UNet # 假设这是定义好的UNet架构文件路径 device = 'cuda' if torch.cuda.is_available() else 'cpu' transform = transforms.Compose([ transforms.ToTensor(), ]) dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) dataloader = DataLoader(dataset, batch_size=64, shuffle=True) model = UNet().to(device) for epoch in range(num_epochs): for i, data in enumerate(dataloader): inputs, _ = data ... ``` 这段简单的 Python 代码展示了加载 MNIST 数据集并与自定义 U-Net 架构配合使用的流程片段。这只是一个起点;完整的 DDPM 训练脚本会更加复杂,涉及更多组件如噪声调度器、损失函数计算等部分。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值