大家好,小鱼🐟最近学习了DDPM,将一些干货提供给大家。
干货:原理讲解+网络讲解+代码讲解,三者结合在一起的哦。原文和推荐代码也有提供哦,都是可直接打开和运行的,当然,跑代码,环境要先设置好哦。小鱼🐟后面也有提到哦!
如果,小鱼🐟有不对的和错误的地方,望大家评论区提出哦,共同进步!谢谢哦!
一、简介
1.1原文
DDPM2020这个是原文地址,可直接下载哦!
本文,结合原文重点以及推荐代码,给大家进行了最重要的总结与归纳。
1.2推荐代码
推荐代码在文章最后,引用源于github资源。
1.3原理
原理分为两个部分:
(1)Diffusion Phase:对应训练Train过程【q过程】
(2)Reverse Phase:对应评估Evaluation过程【p过程】
二、算法
2.0前提准备
首先,DDPM,需要手动设定两个参数,
;以及步长数T
其次,,
,因此,这两个参数我们是已知的。
接着,
最后,Net_Z
2.1Diffusion Phase
需要注意的是:Net_Z就是训练结果
Net_Z的工作目标是:
Net_Z的Loss是:与
越近似越好
Net_Z的保存:相关训练参数,最后,用于Reverse
2.2Reverse Phase
需要注意的是:Reverse Phase的目标是求,即采样,还原信息。
输入仅存在最早的:纯Noisy
Reserve_Net的工作目标是:
,
Reserve_Net的输出是:预测的
三、网络
3.1UNet
UNet网络是专为图像分割任务设计的CNN,可有效提取图像的全局和局部信息。
3.2Pytorch
所给推荐代码的网络,继承于torch.Model
3.3Train
基本过程:梯度清零->计算损失->损失反向传播->梯度裁剪->更新权重
3.4Reverse
基本过程:噪声设置->调用Reverse_Net->保存结果
3.5环境配置
pytorch, tqdm,torchvision,scipy,numpy。
四、推荐的DDPM代码(二维)
DDPM代码这是二维图像的,最好是GPU跑。
五、调试修正为一维信号的DDPM代码
可以私信小鱼哦!
这是小鱼,最后用DDPM跑出来的生成一维信号。
DDPM:输入的是一个受严重噪声干扰的调制余弦波。