三、训练扩散模型
训练的基础
不能直接训练pipeline,得单独训练组件。
通常,最好的结果是通过在特定数据集上微调预训练模型来获得的。可以在Hub上找到许多这样的检查点. 本教程将教如何在Smithsonian Butterflies数据集的子集上从头开始训练UNet2DModel,以生成自己的 🦋 蝴蝶 🦋。
训练前设置
在开始之前,请确保安装数据集来加载和预处理图像数据集,并安装了加速,以简化在任意数量的 GPU 上的训练。以下命令还将安装**TensorBoard以可视化训练指标(您还可以使用权重和偏差**来跟踪您的训练)。
pip install differs[training]
上传模型:从笔记本登录并在出现提示时输入您的令牌。确保您的令牌具有写入角色。
from huggingface_hub import notebook_login
notebook_login()
# 或者从终端登录:
huggingface-cli login
由于模型检查点非常大,因此安装**Git-LFS**来对这些大文件进行版本控制:
!sudo apt -qq install git-lfs
!git config --global credential.helper store
训练配置
为了方便起见,创建一个TrainingConfig
包含训练超参数的类(随意调整它们):
from dataclasses import dataclass
@dataclass
class TrainingConfig:
image_size = 128 # the generated image resolution
train_batch_size = 16
eval_batch_size = 16 # how many images to sample during evaluation
num_epochs = 50
gradient_accumulation_steps = 1
learning_rate = 1e-4
lr_warmup_steps = 500
save_image_epochs = 10
save_model_epochs = 30
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
output_dir = "ddpm-butterflies-128" # the model name locally and on the HF Hub
push_to_hub = True # whether to upload the saved model to the HF Hub
hub_model_id = "<your-username>/<my-awesome-model>" # the name of the repository to create on the HF Hub
hub_private_repo = False
overwrite_output_dir = True # overwrite the old model when re-running the notebook
seed = 0
config = TrainingConfig()
加载数据集
使用 🤗 数据集库轻松加载**史密森尼蝴蝶数据集**
from datasets import load_dataset
config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")
**可以从HugGan 社区活动中找到其他数据集,也可以通过[ImageFolder](https://huggingface.co/docs/datasets/image_dataset#imagefolder)
**创建本地config.dataset_name
如果数据集来自 HugGan 社区活动,或者imagefolder
您使用自己的图像,则设置为数据集的存储库 ID。
快速可视化
Datasets 使用**Image功能自动解码图像数据并将其加载为[PIL.Image](https://pillow.readthedocs.io/en/stable/reference/Image.html)
**我们可以可视化的图像:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]["image"]):
axs[i].imshow(image)
axs[i].set_axis_off()
fig.show()
预处理
不过,这些图像的大小不同,因此您需要先对它们进行预处理:
Resizeconfig.image_size
将图像大小更改为 中定义的大小。RandomHorizontalFlip
通过随机镜像图像来增强数据集。Normalize
将像素值重新缩放到 [-1, 1] 范围内非常重要,这正是模型所期望的。
from torchvision import transforms
preprocess = transforms.Compose(