【Diffusers】三、扩散模型的训练流程

三、训练扩散模型

训练的基础

不能直接训练pipeline,得单独训练组件。

Train a diffusion model

通常,最好的结果是通过在特定数据集上微调预训练模型来获得的。可以在Hub上找到许多这样的检查点. 本教程将教如何在Smithsonian Butterflies数据集的子集上从头开始训练UNet2DModel,以生成自己的 🦋 蝴蝶 🦋。

训练前设置

在开始之前,请确保安装数据集来加载和预处理图像数据集,并安装了加速,以简化在任意数量的 GPU 上的训练。以下命令还将安装**TensorBoard以可视化训练指标(您还可以使用权重和偏差**来跟踪您的训练)。

pip install differs[training]

上传模型:从笔记本登录并在出现提示时输入您的令牌。确保您的令牌具有写入角色。

from huggingface_hub import notebook_login
notebook_login()
# 或者从终端登录:
huggingface-cli login

由于模型检查点非常大,因此安装**Git-LFS**来对这些大文件进行版本控制:

!sudo apt -qq install git-lfs
!git config --global credential.helper store

训练配置

为了方便起见,创建一个TrainingConfig包含训练超参数的类(随意调整它们):

from dataclasses import dataclass

@dataclass
class TrainingConfig:
    image_size = 128  # the generated image resolution
    train_batch_size = 16
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 50
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub

    push_to_hub = True  # whether to upload the saved model to the HF Hub
    hub_model_id = "<your-username>/<my-awesome-model>"  # the name of the repository to create on the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0

config = TrainingConfig()

加载数据集

使用 🤗 数据集库轻松加载**史密森尼蝴蝶数据集**

from datasets import load_dataset

config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")

**可以从HugGan 社区活动中找到其他数据集,也可以通过[ImageFolder](https://huggingface.co/docs/datasets/image_dataset#imagefolder)**创建本地config.dataset_name 如果数据集来自 HugGan 社区活动,或者imagefolder您使用自己的图像,则设置为数据集的存储库 ID。

快速可视化

Datasets 使用**Image功能自动解码图像数据并将其加载为[PIL.Image](https://pillow.readthedocs.io/en/stable/reference/Image.html)**我们可以可视化的图像:

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]["image"]):
    axs[i].imshow(image)
    axs[i].set_axis_off()
fig.show()

预处理

不过,这些图像的大小不同,因此您需要先对它们进行预处理:

  • Resizeconfig.image_size将图像大小更改为 中定义的大小。
  • RandomHorizontalFlip通过随机镜像图像来增强数据集。
  • Normalize将像素值重新缩放到 [-1, 1] 范围内非常重要,这正是模型所期望的。
from torchvision import transforms

preprocess = transforms.Compose(
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值