Denoising Diffusion Probabilistic Models

这篇文章就是所谓的DDPM

前向扩散过程看作是一阶马尔可夫链,只和前一步有关,所以写成条件概率的形式。扩散过程是图像和标准高斯噪声I的加权,认为方差全部来自I,并且多步可以通过连乘合并为一步:

反向的过程也是类似的形式:

并且由贝叶斯公式,并且贝叶斯中三个概率都是高斯分布,可以得到:

GaussianDiffusionTrainer

首先明确扩散时的一步转移公式。表现形式为信号以某一系数进行衰减,同时加一个高斯噪声(高斯噪声为加性信号无关的高斯噪声)。

因为本质就是信号与高斯噪声的alpha blending,所以就需要考虑权重的选择。特别的是前一个状态的信号 x_{t-1} 和噪声的权重之和不是1,而是他们两个平方和才是1。因为这里关心的不是像素值,而是方差,而方差的变化与系数是平方关系。

对不同时刻的转移,权重系数是不同的,但是所使用的高斯噪声是固定的。再加上alpha blending本质是线性操作,所以多步转移可以合并为一个:

简写为:

扩散过程可以压缩为一步,每步的衰减系数连乘。可以看到代码中的beta表示的是噪声部分的权重,是等差递增的。利用这个等差数组计算连乘,得到每个时刻的权重:

class GaussianDiffusionTrainer(nn.Module): 
    def __init__(self, model, beta_1, beta_T, T): 
    super().__init__() 
    self.model = model 
    self.T = T 
    self.register_buffer( 'betas', torch.linspace(beta_1, beta_T, T).double()) #beta_T取0.02,T取1000,噪声的去噪betas是递增等差数列 
    alphas = 1. - self.betas #意味着\alpha_t是递减的,这是信号的权重,加上噪声的权重方差和是1                 
    alphas_bar = torch.cumprod(alphas, dim=0) #计算累乘,得到\sqrt(\hat(\alpha_t)),从0直接到t的累积权重 
# calculations for diffusion q(x_t | x_{t-1}) and others 
self.register_buffer( 'sqrt_alphas_bar', torch.sqrt(alphas_bar)) self.register_buffer( 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar)) # 得到原始信号和噪声的权重,并且注册到内存中 
def forward(self, x_0): 
""" Algorithm 1. """ 
    t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) # t的区间是    [0,T),x_0.shape[0]指的就是batchsize 
    noise = torch.randn_like(x_0) # x_0是原始信号,所以噪声也要是相同尺寸的高斯分布 
    x_t = ( extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +     extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise) #时刻t的信号 
    loss = F.mse_loss(self.model(x_t, t), noise, reduction='none') 计算模型的预测与使用的高斯噪声的mse loss 
    return loss

beta的范围是[0.0001,0.02],意味着信号部分的权重alpha是[0.9999,0.98]。虽然信号部分权重很接近于1,但你要知道指数的力量:

np.power(0.99,1000)=4.3e-5。这意味着beta的取值会使得1000步之后几乎完全是一个高斯噪声。事实上,1000步之后高斯噪声的权重已经达到0.9999.

解释一下forward函数的含义。x_0表示一个batch的图像送入,forward的时候会先随机生成长度为batch的t,表示这批batch图的不同样本会经历不同时长的扩散,这些时长是在(0,1000)中随机取的,这样就可以模拟不同程度的扩散。因为使用向量运算,可以同时得到一个batch中所有图的扩散结果x_t.

除了扩散步长是随机的,扩散中所使用的噪声在不同batch之间也是随机的。这意味着我们模拟了同一幅图在不同噪声水平下,不同扩散步长下的扩散结果。

loss的计算。GaussianDiffusionTrainer还有一个成员函数model。model通常是一个unet,它的输入是x_t和t,是为了估计扩散时所使用的noise。所以计算loss是在model的输出和扩散过程所使用的噪声之间计算mse,因为网络就是来估计这个噪声的,这个噪声直接决定了反向过程的计算,详细原因可以看下面小节。

GaussianDiffusionSampler

后验概率也是高斯分布

后向转移其实就是求后验概率,所以可以使用贝叶斯公式:

上式中每一项概率都可以用x_0及扩散时的系数表示出来,并且每一项都是高斯分布:

贝叶斯公式中的概率都是高斯分布,所以可以认为P(x_{t-1}|x_t,x_0)也是高斯分布

前一步均值是首尾的加权和

既然是高斯分布,把上面式子化简为高斯分布的格式,其实就是得到x_{t-1}均值和方差:

这是求前一步均值和方差的相关代码,其中均值的表达式是初始时刻和扩散结果x_0,x_t的加权和,所以需要先计算两个权重系数:

    self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0) # t时刻的累计乘
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]  #前面补1,相对于整体右移了,得到t-1时刻的累计乘


    # variance for posterior q(x_{t-1} | x_t, x_0)
    self.register_buffer(
            'posterior_var',
            self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
    # below: log calculation clipped because the posterior variance is 0 at
    # the beginning of the diffusion chain
    self.register_buffer(
            'posterior_log_var_clipped',
            torch.log(
                torch.cat([self.posterior_var[1:2], self.posterior_var[1:]])))
# 因为后验的方差涉及到alpha_bar_prev,做法是把alpha_bar右移一位,前面补0。
# 这样的话的方差就会是0,所以把所求出的方差构成的list的第一个元素使用第二个取代
    
    # mean for posterior q(x_{t-1} | x_t, x_0)
    self.register_buffer(
            'posterior_mean_coef1',
            torch.sqrt(alphas_bar_prev) * self.betas / (1. - alphas_bar))  # x_0的系数
        self.register_buffer(
            'posterior_mean_coef2',
            torch.sqrt(alphas) * (1. - alphas_bar_prev) / (1. - alphas_bar))  # x_t的系数
    
    

可以看到上面求均值和方差时基本上都是和衰减系数相关的。比如连乘alphas_bar 和alphas_bar_prev,当然还有t时刻的信号权重alpha_t.

齿轮转动需要x_0

注意到上式在求均值时需要用到x_0,而x_0其实是我们最终要复原的。这无异于鸡生蛋蛋生鸡的问题。

其中一个解决办法是先随机选取一个点,然后不停地去迭代更新。比如牛顿迭代法,EM算法,K-means算法都是这个思想。不过当然初始值越准确越好,这里可以先将x_0表示为:

进一步简化:x_0=\sqrt(\frac{1}{\bar\alpha_t} )x_t-\sqrt(\frac{1}{\bar\alpha_t}-1) \varepsilon

发现x_0又是扩散终点状态x_t和噪声\varepsilon的加权和。从而有下面的代码,计算权重来估计x_0:

# calculations for diffusion q(x_t | x_{t-1}) and others
# x_t和eps的系数分别是sqrt(1/alphas_bar)和sqrt(1. / alphas_bar - 1))

self.register_buffer(
            'sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar))
self.register_buffer(
            'sqrt_recipm1_alphas_bar', torch.sqrt(1. / alphas_bar - 1))

def predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.sqrt_recip_alphas_bar, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_bar, t, x_t.shape) * eps
        )

需要的其实是eps

把x_0的计算公式再代入上面的后向一步转移概率,得到从下面的式子。可以看出,后向转移概率的均值和方差都要知道扩散时的权重,而均值还需要知道diffusion过程中使用的高斯噪声eps。扩散时的权重是提前设定的,所以是已知的,x_t也是已知的,所以现在的关键就是求取噪声eps。

对于不同batch图像的diffusion,转移的权重是固定的list(可以认为是只和时间相关的),而高斯噪声eps是每次随机得到的。从这个角度说,噪声和图像又有某些抽象的联系,如何从寻找一个最优的标准高斯噪声,我们可以使用unet来学习得到eps。

恢复上一步信号

结合上面两个分别求x_0和求权重的代码块,可以得到:

def q_mean_variance(self, x_0, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior
        q(x_{t-1} | x_t, x_0)
        """
        assert x_0.shape == x_t.shape
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_0 +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_log_var_clipped = extract(
            self.posterior_log_var_clipped, t, x_t.shape)
        return posterior_mean, posterior_log_var_clipped

得到均值和方差之后,知道了均值和方差,就可以构建上一时刻的信号。一步步迭代,就可以起到恢复图像的效果:

def forward(self, x_T):
        """
        Algorithm 2.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)): # 注意这里的reversed
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step  # 得到一个batch的time_step
            mean, log_var = self.p_mean_variance(x_t=x_t, t=t)
            # no noise when t == 0
            if time_step > 0:
                noise = torch.randn_like(x_t) # 引入(0,1)的高斯噪声
            else:
                noise = 0
            x_t = mean + torch.exp(0.5 * log_var) * noise  # 噪声的权重为var的开根号
        x_0 = x_t
        return torch.clip(x_0, -1, 1)

注意均值其实就是恢复出的信号,再按照估计的方差大小,叠加对应的随机高斯噪声。这样做的好处是保持了生成的可能性和多样性。

估计完上一步之后,还要估计上上一步,仍然需要计算均值和方差,而这就要网络估计eps。这也就是为什么训练阶段就需要扩散到不同程度的原因,这样网络才可以从不同扩散时刻的信号中估计出噪声eps。

实验

010000140001800074000

疑问:

  1. 变分推理,求x_0需要先知道x_0,取代积分?和熵+KL优化的区别?
  2. 前向后向都是高斯的依据
  3. 训练求eps,也可以训练求x_t-1?
  4. 后验概率方差求cat和log?cat的原因是代码中的注释所写的,求log是为了避免溢出?
  5. eps可以认为是退化核?unet的作用是寻找最优的?最符合这个图的核?
  6. 渐进的有损解压progressive lossy decompression

    自回归解码的泛化generalization of autoregressive decoding

  7. 和传统去噪算法对比:

    f(原始信号,noise) GT:干净图像,估计的噪声和图像内容强相关。

    f(原始信号, t ,noise) GT:高斯噪声。

    残差的时候都可以看作是学习噪声分布。

    区别: 1. diffusion还有时间t的影响。

    2.diffusion的噪声分布是高斯的,信号无关的。

    3.去噪的时候可以直接拿到带噪声的信号,生成的时候输入是标准高斯,加入文本模型的指导也是高斯?但是扩散的时候1000步之后不一定是高斯吧

    4.去噪可以直接由残差得到干净图,生成因为是多步的,只能根据高斯噪声一步步转移回去。5.扩散的阶段是使用同一个高斯噪声,采样的阶段不是同一个。

reference:

1.pytorch-ddpm/diffusion.py at master · w86763777/pytorch-ddpm · GitHub

2.https://zhuanlan.zhihu.com/p/666552214

3.Diffusion Models:生成扩散模型

4.https://zhuanlan.zhihu.com/p/682840224

5.https://sailing-mbzuai.github.io/assets/pdf/Diffusion_Model_Slides.pdf

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值