任务二:
-
打卡内容:
-
学习笔记
-
作业:
-
基础:fine-tune 一个 fashion-mnist 类别引导的图像生成模型,并生成对应的图像
-
进阶:使用 upscaler 等超分模块高清化生成的图像
-
-
引导与微调
总结: 主要讨论了如何调整和引导现有的扩散模型。其中包括两种主要方法:
- 微调 (Fine-Tuning): 重新训练现有模型以更改其输出。
- 引导 (Guidance): 在推理时引导现有模型的生成过程,以获得更多的控制权。
能力总结:
- 创建一个采样循环并使用新的调度器更快地生成样本。
- 在新数据上微调现有的扩散模型。
- 使用其他损失函数指导采样过程,以增加对现有模型的控制。
学习目的:
- 理解微调和引导的基本概念。
- 学习如何使用新的调度器更快地生成样本。
- 了解如何在新数据上微调现有的扩散模型。
- 学习如何使用额外的损失函数来指导采样过程。
学习总结:
- 微调是重新训练模型的方法,以使其产生不同的输出。
- 引导允许我们在推理时对模型的生成过程进行更精确的控制。
- 采样循环可以被优化,以更快地生成样本。
- 还可以使用Weights and Biases工具来监控训练过程。
DDIM
[2010.02502] Denoising Diffusion Implicit Models (arxiv.org)
在生成图像的每一步中,模型都会接收一个带有噪声的输入,并且需要预测这个噪声,以此来估计没有噪声的完整图像是什么。这个过程被称为采样过程,在Diffusers库中,采样通过调度器控制的,之前的文章中介绍过DDPMScheduler调度器,本文介绍的DDIMScheduler可以通过更少的迭代周期来产生很好的采样样本(1000多步采样不是必须的)。
# 创建一个新的调度器并设置推理迭代次数
scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(num_inference_steps=40)
scheduler.timesteps
下面使用4幅随机噪声图像进行循环采样,并观察每一步的输入与输出的”去噪“图像,代码如下:
# 从随机噪声开始
x = torch.randn(4, 3, 256, 256).to(device)
# batch size为4,三通道,长、宽均为256像素的一组图像
# 循环一整套时间步
for i, t in tqdm(enumerate(scheduler.timesteps)):
# 准备模型输入:给“带躁”图像加上时间步信息
model_input = scheduler.scale_model_input(x, t)
# 预测噪声
with torch.no_grad():
noise_pred = image_pipe.unet(model_input, t)["sample"]
# 使用调度器计算更新后的样本应该是什么样子
scheduler_output = scheduler.step(noise_pred, t, x)
# 更新输入图像
x = scheduler_output.prev_sample
# 时不时看一下输入图像和预测的“去噪”图像
if i % 10 == 0 or i == len(scheduler.timesteps) - 1:
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
grid = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0)
axs[0].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)
axs[0].set_title(f"Current x (step {i})")
pred_x0 = (
scheduler_output.pred_original_sample
)
grid = torchvision.utils.make_grid(pred_x0, nrow=4).
permute(1, 2, 0)
axs[1].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)
axs[1].set_title(f"Predicted denoised images (step {i})")
plt.show()
通过step不同的对比可以看到,经过逐步优化后的模型输出已经在较好的预测噪声,从而能够使得最终的x0变得清晰。
第二步生成图像的采样器是DDPMScheduler,我们可以使用新的DDIMScheduler来代替DDPMScheduler看看image_pipe生成的效果是否有提升,代码如下:
image_pipe.scheduler = scheduler
images = image_pipe(num_inference_steps=40).images
images[0]
Fine-Tuning
当你