Diffusion Model Basics: A DDPM Code Walkthrough

Denoising Diffusion Probabilistic Models (DDPM): Code Walkthrough

Paper: (link)
Code: https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-

Preface

This post walks through the PyTorch training flow and model construction of DDPM and conditional DDPM, using CIFAR-10 (60,000 32x32 color images across 10 classes) as the example dataset. Detail-level explanations appear as comments in the code; it is recommended to read through the code flow yourself first.

Theory Overview

[Figure: the forward (noising) and reverse (denoising) Markov chains of DDPM]
A diffusion model is a parameterized Markov chain, trained with variational inference, that produces samples matching the data after a finite number of steps. As shown in the figure above, real data $x_0$ is noised over $T$ Gaussian steps into $x_T$, which follows the standard normal distribution $N(0, I)$. The model learns the reverse transition probability $p_\theta(x_{t-1}|x_t)$: starting from $x_T$ sampled from the standard normal, it generates real data step by step. The training objective is for the generated data distribution to match the real data distribution as closely as possible.

Detailed derivations are omitted here; if you are interested, see https://kxz18.github.io/2022/06/19/Diffusion/

Forward Process

Goal: gradually convert the complex real data distribution into a simple, tractable one by repeatedly adding noise.
Gaussian noise is added to the data step by step according to a variance schedule $\beta_1, \dots, \beta_T$:

$q(x_{1:T}|x_0) := \prod_{t=1}^{T} q(x_t|x_{t-1})$

$q(x_t|x_{t-1}) := N(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I)$
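Composing these Gaussian steps yields the well-known closed form $q(x_t|x_0) = N(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)I)$ with $\bar\alpha_t = \prod_{s=1}^{t}(1-\beta_s)$, which is what the trainer code later exploits. A minimal sketch (the function name `q_sample` is mine, not from the repo):

```python
import torch

def q_sample(x_0, t, betas):
    """Sample x_t ~ q(x_t | x_0) in one shot:
    sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
    alphas_bar = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t = prod_s (1 - beta_s)
    noise = torch.randn_like(x_0)
    x_t = alphas_bar[t].sqrt() * x_0 + (1.0 - alphas_bar[t]).sqrt() * noise
    return x_t, noise
```

For small $t$, $\bar\alpha_t \approx 1$ and $x_t$ stays close to $x_0$; as $t \to T$, $\bar\alpha_t \to 0$ and $x_t$ approaches pure Gaussian noise.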

Reverse Process

Diffusion models are latent variable models of the form $p_\theta(x_0) := \int p_\theta(x_{0:T})\,dx_{1:T}$, where $x_1, \dots, x_T$ are latents with the same dimensionality as the data $x_0 \sim q(x_0)$. The joint distribution $p_\theta(x_{0:T})$ is called the reverse process; it is defined as a learned Markov chain with Gaussian transitions, starting from $p(x_T) = N(x_T; 0, I)$:

$p_\theta(x_{0:T}) := p(x_T)\prod_{t=1}^{T} p_\theta(x_{t-1}|x_t)$

$p_\theta(x_{t-1}|x_t) := N(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t))$

The training objective is to make the model's generated distribution $p_\theta(x_0)$ as close as possible to the real data distribution $q(x_0)$ (the U-Net is trained so that the noise it predicts at each step matches the noise added by the forward process $q$).
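In the DDPM parameterization the model predicts the noise $\epsilon$, and the reverse mean becomes $\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\big)$. One reverse step can then be sketched as follows (assuming the simple variance choice $\sigma_t^2 = \beta_t$, one of the two options in the paper; the function is illustrative, not repo code):

```python
import torch

def p_sample_step(eps_pred, x_t, t, betas):
    """One reverse step: compute the posterior mean from the predicted noise,
    then add sigma_t * z with sigma_t^2 = beta_t (no noise at the final step)."""
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    mean = (x_t - betas[t] / (1.0 - alphas_bar[t]).sqrt() * eps_pred) / alphas[t].sqrt()
    if t == 0:
        return mean  # x_0 estimate: no noise added at the last step
    return mean + betas[t].sqrt() * torch.randn_like(x_t)
```

Iterating this from $t = T-1$ down to $0$, starting at $x_T \sim N(0, I)$, produces a sample.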

Training Procedure

[Figure: DDPM training procedure]

Diffusion Code

1. Loading the Dataset

    # Load PyTorch's built-in CIFAR-10 dataset: specify the root path and the training split.
    # Each image is first randomly flipped horizontally (RandomHorizontalFlip),
    # then converted to a tensor (ToTensor), and finally normalized (Normalize).
    dataset = CIFAR10(
        root='./CIFAR10', train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    # DataLoader batches the dataset:
    # batch_size: number of samples per mini-batch.
    # shuffle=True: reshuffle the data before each epoch to add training randomness.
    # num_workers: number of subprocesses for data loading, so preprocessing runs in parallel with training.
    # drop_last=True: drop the final incomplete batch if the dataset size is not divisible by batch_size.
    # pin_memory=True: load batches into CUDA pinned memory to speed up host-to-GPU transfer.
    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)

2. Model Configuration

    # The modelConfig dictionary, defined in advance
    modelConfig = {
        "state": "train", # or eval
        "epoch": 200,
        "batch_size": 100,
        "T": 1000,
        "channel": 128,
        "channel_mult": [1, 2, 3, 4],
        "attn": [2],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:1", ### MAKE SURE YOU HAVE A GPU !!!
        "training_load_weight": None,
        "save_weight_dir": "./Checkpoints/",
        "test_load_weight": "ckpt_199_.pt",
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8
        }
    # U-Net: T=1000 steps, base channel 128, channel multipliers [1, 2, 3, 4], attention at stage index 2,
    # 2 residual blocks per stage; these settings drive the U-Net construction shown later.
    net_model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
    if modelConfig["training_load_weight"] is not None:
        net_model.load_state_dict(torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device))
            
    # AdamW optimizer, a variant of Adam with decoupled weight decay and better convergence behavior
    optimizer = torch.optim.AdamW(net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    
    # PyTorch's cosine annealing LR scheduler: adjusts the learning rate over the course of training.
    # T_max is the annealing period in epochs, eta_min the LR lower bound, last_epoch the index of the previous epoch.
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
  
    # GradualWarmupScheduler, a custom scheduler: gradually increases the LR at the start of
    # training to help the model settle into a good region faster.
    # multiplier is the peak LR as a multiple of the configured base LR;
    # warm_epoch is the number of warmup epochs; after_scheduler takes over once warmup ends.
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer, multiplier=modelConfig["multiplier"], warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
    
    # GaussianDiffusionTrainer: the training wrapper, given the model net_model plus beta_1, beta_T and T
    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
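To make the warmup-plus-cosine schedule concrete, here is a rough sketch of the resulting learning-rate curve with the config above (my approximation of GradualWarmupScheduler's behavior, not the repo's exact code): the LR ramps linearly from the base value toward multiplier × base over the first epoch // 10 epochs, then cosine-anneals down to eta_min = 0.

```python
import math

def lr_at_epoch(e, base_lr=1e-4, multiplier=2.0, warm=20, total=200):
    """Approximate LR at epoch e: linear warmup, then cosine annealing."""
    if e < warm:
        # linear ramp from base_lr toward multiplier * base_lr
        return base_lr * (1.0 + (multiplier - 1.0) * (e + 1) / warm)
    peak = base_lr * multiplier
    # cosine decay from the peak down to 0 over the remaining epochs
    return peak * (1.0 + math.cos(math.pi * (e - warm) / (total - warm))) / 2.0
```

The peak (2e-4 here) is reached at the end of warmup, and the LR is nearly zero by the final epoch.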

3. Training Loop

    # start training
    for e in range(modelConfig["epoch"]):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                optimizer.zero_grad()
                x_0 = images.to(device)   
                loss = trainer(x_0).sum() / 1000.   # forward pass through the trainer to compute the loss
                loss.backward()   # compute gradients of the loss w.r.t. the model parameters
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"]) # clip gradients to prevent them from exploding
                optimizer.step()    # update the model parameters with the computed gradients
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })   # update the progress bar: current epoch, loss, image shape, and current learning rate
        warmUpScheduler.step()   # step the LR scheduler; after warmup it hands over to cosine annealing
        torch.save(net_model.state_dict(), os.path.join(
            modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))

4. The GaussianDiffusion Model

class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T
        
        # register a buffer 'betas' holding T values evenly spaced from beta_1 to beta_T
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Sample a random timestep t for each example, noise x_0 to x_t in
        closed form, and train the U-Net to predict the added noise.
        """
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)  # compute x_t from x_0 in closed form
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')  # per-element MSE between predicted and true noise
        return loss
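The `extract` helper used above is defined elsewhere in the repo. Roughly, it gathers the coefficient for each example's sampled timestep and reshapes it so it broadcasts over the image dimensions; a sketch:

```python
import torch

def extract(v, t, x_shape):
    """Gather v[t] for a batch of timesteps t and reshape to [B, 1, 1, ...]
    so it broadcasts against a tensor of shape x_shape (e.g. [B, C, H, W])."""
    out = torch.gather(v, index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
```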

5. Building the U-Net

class UNet(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        # TimeEmbedding computes sinusoidal (sin/cos) position encodings for the timestep.
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        # build the downsampling path: torch.Size([80, 128, 32, 32]) -> torch.Size([80, 512, 4, 4])
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling skip connections
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
           
        # middle: two residual blocks (the first with attention)
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        
        # build the upsampling path
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h
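The sin/cos encoding that TimeEmbedding computes follows the transformer positional encoding; a sketch of the raw encoding before the module's MLP projection (layout details may differ from the repo's implementation):

```python
import math
import torch

def sinusoidal_embedding(t, dim):
    """Encode integer timesteps t as [sin(t * w_i), cos(t * w_i)] with
    geometrically spaced frequencies w_i, producing a [B, dim] tensor."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
```

Each timestep gets a unique, smoothly varying vector, which the ResBlocks add to their feature maps after a linear projection.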

Conditional Diffusion Code

Compared with the plain diffusion code, the conditional version adds a label module during training.

1. Training Loop

    for e in range(modelConfig["epoch"]):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                b = images.shape[0]
                optimizer.zero_grad()
                x_0 = images.to(device)
                # shift labels from 0..9 to 1..10, reserving 0 as the "no condition" (null) label
                labels = labels.to(device) + 1
                # with probability 0.1 drop the condition, so the model also learns an unconditional branch
                if np.random.rand() < 0.1:
                    labels = torch.zeros_like(labels).to(device)
                loss = trainer(x_0, labels).sum() / b ** 2.
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"])
                optimizer.step()
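Randomly zeroing the labels with probability 0.1 is classifier-free guidance training: one network learns both conditional and unconditional noise predictions. At sampling time the two are combined with a guidance weight w, where w = 0 recovers the plain conditional model; a sketch of the combination (the function and parameter names are mine):

```python
import torch

def guided_eps(eps_cond, eps_uncond, w):
    """Classifier-free guidance: extrapolate from the unconditional toward
    the conditional noise prediction by a factor of (1 + w)."""
    return (1.0 + w) * eps_cond - w * eps_uncond
```

Larger w strengthens the condition's influence on the sample at some cost in diversity.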

2. Corresponding Model Changes

In the GaussianDiffusionTrainer, the closed-form noising of $x_0$ is unchanged; the U-Net is adjusted accordingly,
adding a conditional embedding:

class UNet(nn.Module):
    def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout):
        super().__init__()
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling skip connections
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=False))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
 

    def forward(self, x, t, labels):
        # Timestep embedding
        temb = self.time_embedding(t)
        cemb = self.cond_embedding(labels)  # conditional embedding
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb, cemb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb, cemb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb, cemb)
        h = self.tail(h)

        assert len(hs) == 0
        return h
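ConditionalEmbedding is not shown in the excerpt; a plausible minimal sketch (my assumption, not the repo's exact code) is an embedding table with num_labels + 1 entries, where index 0 is the null label that the training loop substitutes when dropping the condition, projected to the same dimension as the time embedding:

```python
import torch
from torch import nn

class ConditionalEmbedding(nn.Module):
    """Embed class labels (0 = null/unconditional) into the time-embedding dimension."""
    def __init__(self, num_labels, d_model, dim):
        super().__init__()
        # num_labels + 1 entries: index 0 is the dropped / unconditional label
        self.embedding = nn.Embedding(num_labels + 1, d_model, padding_idx=0)
        self.proj = nn.Sequential(
            nn.Linear(d_model, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, labels):
        return self.proj(self.embedding(labels))
```

This is why the training loop shifts CIFAR-10's labels from 0..9 to 1..10 before optionally zeroing them.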