diffusion model 简单demo

写在前面
强烈推荐这个:
https://huggingface.co/datasets/HuggingFace-CN-community diffusion 教程
扩散模型是如何工作的:从零开始的数学原理

以下内容参考自:

Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读
知乎:diffusion model 最近在图像生成领域大红大紫,如何看待它的风头开始超过 GAN ?
diffusion 简单demo
扩散模型之DDPM
Diffusion model 原理剖析
张振虎-扩散概率模型
生成扩散模型漫谈(一):DDPM = 拆楼 + 建楼

核心公式和逻辑

在这里插入图片描述
优化目标:
在这里插入图片描述
在这里插入图片描述
然后:
在这里插入图片描述
然后:
在这里插入图片描述
在这里插入图片描述

核心公式:

在这里插入图片描述

训练阶段

在这里插入图片描述
实际上是根据加噪后的图 和时间步 t 去预测噪声
在这里插入图片描述

q_x 计算公式,后面会用到:
在这里插入图片描述

推理

在这里插入图片描述

代码

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve, make_swiss_roll
from PIL import Image
import torch
import io

# get data
# s_curve, _ = make_s_curve(10**4 , noise=0.1)
# s_curve = s_curve[:, [0, 2]] / 10.0

swiss_roll, _ = make_swiss_roll(10**4,noise=0.1)
s_curve = swiss_roll[:, [0, 2]]/10.0

print('shape of moons: ', np.shape(s_curve))

data = s_curve.T
fix, ax = plt.subplots()
ax.scatter(*data, color='red', edgecolors='white', alpha=0.5)

ax.axis('off')

# plt.show()
plt.savefig('./s_curve.png')

dataset = torch.Tensor(s_curve).float()

# set params
num_steps = 100

betas = torch.linspace(-6, 6, num_steps)    # # 逐渐递增
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5    # β0,β1,...,βt

print('beta: ', betas)

alphas = 1 - betas
alphas_pro = torch.cumprod(alphas, 0)   # αt^ = αt的累乘

# αt^往右平移一位, 原第t步的值维第t-1步的值, 第0步补1
alphas_pro_p = torch.cat([torch.tensor([1]).float(), alphas_pro[:-1]], 0)   # p表示previous, 即 αt-1^


alphas_bar_sqrt = torch.sqrt(alphas_pro)    # αt^ 开根号
one_minus_alphas_bar_log = torch.log(1 - alphas_pro)    # log (1 - αt^)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_pro)  # 根号下(1-αt^)

assert alphas.shape == alphas_pro.shape == alphas_pro_p.shape == alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == one_minus_alphas_bar_sqrt.shape

print('beta: shape ', betas
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值