pytorch一致数据增强

本文讨论了在进行图像分割任务时,如何确保图像和标签的变换保持一致性,提出使用MultiCompose扩展torchvision.transforms.Compose,以支持同时处理多幅图像和标签,并重置随机种子以保证结果的一致性。作者还提到了torchvision.transforms中的ColorJitter等仅对图像操作的变换未被考虑的情况。
部署运行你感兴趣的模型镜像

(2024.1.14)本文的例程有个错误:segmentation label 的 resize interpolation 应只能用 nearest 模式,否则可能会出现错误 label(如 3、5 二类插出 4 类,但那一片本来没有 4 类的 pixels),而本文例程用了默认的 bilinear。改进版 MultiCompose 见后作 [5],通过支持各输入同用、独用 augmentation 修正此错误。


分割任务对 image 做(某些)transform 时,要对 label(segmentation mask)也做对应的 transform,如 Resize、RandomRotation 等。如果对 image、label 分别用 transform 处理一遍,则涉及随机操作的可能不一致,如 RandomRotation 将 image 转了 a 度、却将 label 转了 b 度。

MONAI 有个 ArrayDataset 实现了这功能,思路是每次 transform 前都重置一次 random seed 先。对 monai 订制 transform 的方法不熟,torchvision.transforms 的订制接口比较简单,考虑基于 pytorch 实现。要改两个东西:

  • 扩展 torchvison.transforms.Compose,使之支持多个输入(image、label);
  • 一个 wrapper,扩展 transform,使之支持多输入。

思路也是重置 random seed,参考 [1-4]。

Code

  • to_multi:将处理单幅图的 transform 扩展成可处理多幅;
  • MultiCompose:扩展 torchvision.transforms.Compose,可输入多幅图。内部调用 to_multi 扩展传入的 transforms。
import random, os
import numpy as np
import torch

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

def to_multi(trfm):
    """wrap a transform to extend to multiple input with synchronised random seed
    Input:
        trfm: transformation function/object (custom or from torchvision.transforms)
    Output:
        _multi_transform: function
    """
    # numpy.random.seed range error:
    #   ValueError: Seed must be between 0 and 2**32 - 1
    min_seed = 0 # - 0x8000_0000_0000_0000
    max_seed = min(2**32 - 1, 0xffff_ffff_ffff_ffff)
    def _multi_transform(*images):
        """images: [C, H, W]"""
        if len(images) == 1:
            return trfm(images[0])
        _seed = random.randint(min_seed, max_seed)
        res = []
        for img in images:
            seed_everything(_seed)
            res.append(trfm(img))
        return tuple(res)

    return _multi_transform


class MultiCompose:
    """Extension of torchvision.transforms.Compose that accepts multiple input.
    Usage is the same as torchvision.transforms.Compose. This class will wrap input
    transforms with `to_multi` to support simultaneous multiple transformation.
    This can be useful when simultaneously transforming images & segmentation masks.
    """
    def __init__(self, transforms):
        """transforms should be wrapped by `to_multi`"""
        self.transforms = [to_multi(t) for t in transforms]

    def __call__(self, *images):
        for t in self.transforms:
            images = t(*images)
        return images

test

测试一致性,用到预处理过的 verse’19 数据集、一些工具函数、一个订制 transform:

  • verse’19 数据集及预处理见 iTomxy/data/verse
  • digit_sort_key:数据文件排序用;
  • get_palettecolor_segblend_seg:可视化用;
  • MyDataset:看其中 __getitem__ 的 transform 用法,即同时传入 image 和 label;
  • ResizeZoomPad:一个订制的 transform;
import os, os.path as osp, random
from glob import glob
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F


def digit_sort_key(s, num_pattern=re.compile('([0-9]+)')):
    """natural sort,数据排序用"""
    return [int(text) for text in num_pattern.split(s) if text.isdigit()]


def get_palette(n_classes, pil_format=True):
    """创建调色盘,可视化用"""
    n = n_classes
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3

    if pil_format:
        return palette

    res = []
    for i in range(0, len(palette), 3):
        res.append(tuple(palette[i: i+3]))
    return res


def color_seg(label, n_classes=0):
    """segmentation mask 上色,可视化用"""
    if n_classes < 1:
        n_classes = math.ceil(np.max(label)) + 1
    label_rgb = Image.fromarray(label.astype(np.int32)).convert("L")
    label_rgb.putpalette(get_palette(n_classes))
    return label_rgb.convert("RGB")


def blend_seg(image, label, n_classes=0, alpha=0.7, rescale=False, transparent_bg=True, save_file=""):
    """融合 image 和其 segmentation mask,可视化用"""
    if rescale:
        denom = image.max() - image.min()
        if 0 != denom:
            image = (image - image.min()) / denom * 255
        image = np.clip(image, 0, 255).astype(np.uint8)
    img_pil = Image.fromarray(image).convert("RGB")
    lab_pil = color_seg(label, n_classes)
    blended_image = Image.blend(img_pil, lab_pil, alpha)
    if transparent_bg:
        blended_image = Image.fromarray(np.where(
            (0 == label)[:, :, np.newaxis],
            np.asarray(img_pil),
            np.asarray(blended_image)
        ))
    if save_file:
        blended_image.save(save_file)
    return blended_image


class MyDataset(torch.utils.data.Dataset):
    """订制 dataset,看 __getitem__ 处 transform 的调法"""
    def __init__(self, image_list, label_list, transform=None):
        assert len(image_list) == len(label_list)
        self.image_list = image_list
        self.label_list = label_list
        self.transform = transform
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, index):
        img = np.load(self.image_list[index]) # [h, w]
        lab = np.load(self.label_list[index])
        img = torch.from_numpy(img).unsqueeze(0).float() # -> [c=1, h, w]
        lab = torch.from_numpy(lab).unsqueeze(0).int()
        if self.transform is not None:
            img, lab = self.transform(img, lab) # 同时传入 image、label
        return img, lab


class ResizeZoomPad:
    """订制 resize"""
    def __init__(self, size, interpolation="bilinear"):
        if isinstance(size, int):
            assert size > 0
            self.size = [size, size]
        elif isinstance(size, (tuple, list)):
            assert len(size) == 2 and size[0] > 0 and size[1] > 0
            self.size = size

        if isinstance(interpolation, str):
            assert interpolation.lower() in {"nearest", "bilinear", "bicubic", "box", "hamming", "lanczos"}
            interpolation = {
                "nearest": F.InterpolationMode.NEAREST,
                "bilinear": F.InterpolationMode.BILINEAR,
                "bicubic": F.InterpolationMode.BICUBIC,
                "box": F.InterpolationMode.BOX,
                "hamming": F.InterpolationMode.HAMMING,
                "lanczos": F.InterpolationMode.LANCZOS
            }[interpolation.lower()]
        self.interpolation = interpolation

    def __call__(self, image):
        """image: [C, H, W]"""
        scale_h, scale_w = float(self.size[0]) / image.size(1), float(self.size[1]) / image.size(2)
        scale = min(scale_h, scale_w)
        tmp_size = [ # clipping to ensure size
            min(int(image.size(1) * scale), self.size[0]),
            min(int(image.size(2) * scale), self.size[1])
        ]
        image = F.resize(image, tmp_size, self.interpolation)
        assert image.size(1) <= self.size[0] and image.size(2) <= self.size[1]
        pad_h, pad_w = self.size[0] - image.size(1), self.size[1] - image.size(2)
        if pad_h > 0 or pad_w > 0:
            pad_left, pad_right = pad_w // 2, (pad_w + 1) // 2
            pad_top, pad_bottom = pad_h // 2, (pad_h + 1) // 2
            image = F.pad(image, (pad_left, pad_top, pad_right, pad_bottom))
        return image


# 读数据文件
data_path = os.path.expanduser("~/data/verse/processed-verse19-npy-horizontal")
train_images, train_labels, val_images, val_labels = [], [], [], []
for d in os.listdir(osp.join(data_path, "training")):
    if d.endswith("_ct"):
        img_p = osp.join(data_path, "training", d)
        lab_p = osp.join(data_path, "training", d[:-3]+"_seg-vert_msk")
        assert osp.isdir(lab_p)
        train_labels.extend(glob(os.path.join(lab_p, "*.npy")))
        train_images.extend(glob(os.path.join(img_p, "*.npy")))
for d in os.listdir(osp.join(data_path, "validation")):
    if d.endswith("_ct"):
        img_p = osp.join(data_path, "validation", d)
        lab_p = osp.join(data_path, "validation", d[:-3]+"_seg-vert_msk")
        assert osp.isdir(lab_p)
        val_labels.extend(glob(os.path.join(lab_p, "*.npy")))
        val_images.extend(glob(os.path.join(img_p, "*.npy")))

# 数据文件名排序
train_images = sorted(train_images, key=lambda f: digit_sort_key(os.path.basename(f)))
train_labels = sorted(train_labels, key=lambda f: digit_sort_key(os.path.basename(f)))
val_images = sorted(val_images, key=lambda f: digit_sort_key(os.path.basename(f)))
val_labels = sorted(val_labels, key=lambda f: digit_sort_key(os.path.basename(f)))

# transform
# 用 MultiCompose,其内部调用 to_multi 将 transforms wrap 成支持多输入的
train_trans = MultiCompose([
    ResizeZoomPad((224, 256)),
    transforms.RandomRotation(30),
])

# 测试:读数据,可试化 image 和 label
check_ds = MyDataset(train_images, train_labels, train_trans)
check_loader = torch.utils.data.DataLoader(check_ds, batch_size=10, shuffle=True)
for images, labels in check_loader:
    print(images.size(), labels.size())
    for i in range(images.size(0)):
        # print(i, end='\r')
        img = images[i][0].numpy()
        lab = labels[i][0].numpy()
        print(np.unique(lab))
        img = (255 * (img - img.min()) / (img.max() - img.min())).astype(np.uint8)
        seg_img = blend_seg(img, lab)
        img = np.asarray(Image.fromarray(img).convert("RGB"))
        lab = np.asarray(color_seg(lab))
        comb = np.concatenate([img, lab, seg_img], axis=1)
        Image.fromarray(comb).save(f"test-dataset-{i}.png")
    break

效果:
test-dataset-7.png
可见,image 和 label 转了同一个随机角度。

Limits

有些 augmentations 是只对 image 做而不对 label 做的,如 ColorJitter,这里没有考虑怎么处理。

References

  1. How to Set Random Seeds in PyTorch and Tensorflow
  2. ihoromi4/seed_everything.py
  3. Reproducibility
  4. What is the max seed you can set up?
  5. pytorch一致数据增强—独用增强

您可能感兴趣的与本文相关的镜像

PyTorch 2.9

PyTorch 2.9

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

<think>首先,用户的问题是基于提供的引用内容:"请问如何在PyTorch中实现自定义数据集加载类"。引用内容讨论了在PyTorch中创建自定义数据集和加载器,特别提到了继承`torch.utils.data.Dataset`类。 关键点从引用中: - 引用[1]:PyTorch 提供了`torch.utils.data.DataLoader`和`torch.utils.data.Dataset`来使用自定义数据。 - 引用[2]:自定义数据集通过继承`Dataset`类实现,需要定义三个魔法方法:`__init__`、`__len__`、`__getitem__`。 - 引用[3]:详细说明了如何实现子类,包括初始化、长度和获取项目。 - 引用[4]:提到了创建自定义数据集类并生成标签文件,提供了一个数据集下载链接。 系统级指令: - 所有行内数学表达式用`$...$`格式。 - 独立公式用`$$...$$`格式并单独成段。 - LaTeX语法正确。 - 使用中文回答。 - 生成相关问题。 - 回答中引用的段落末尾自然地添加引用标识,例如[^1]。 用户可见层: - 回答结构清晰,帮助用户逐步解决问题。 - 保证回答真实可靠。 用户的问题:如何在PyTorch中实现自定义数据集加载类?这应该包括创建自定义数据集类和使用数据加载器。 结构回答: 1. **介绍**:简要解释PyTorch中自定义数据集的重要性。 2. **步骤**:逐步说明如何实现自定义数据集类。 - 继承`Dataset`类。 - 实现`__init__`方法:初始化数据路径、标签等。 - 实现`__len__`方法:返回数据集大小。 - 实现`__getitem__`方法:加载和预处理单个样本。 3. **使用数据加载器**:展示如何将自定义数据集传递给`DataLoader`。 4. **示例代码**:提供一个简单的代码示例。 5. **引用**:在相关部分末尾添加引用标识。 6. **生成相关问题**:在回答末尾添加相关问题。 确保使用中文回答。 生成相关问题: - 基于主题,生成3-4个相关问题,如优化、性能分析等。 示例代码:基于引用,提供一个自定义数据集类的示例。例如,使用一个图像数据集,如蜜蜂蚂蚁数据集(从引用[4])。 完整回答: - 开头:直接回答问题。 - 主体:逐步解释。 - 结尾:总结并添加相关问题。 引用标识:在回答中,当提到引用内容时,在段落末尾添加[^n]。例如,引用[2]讨论了三个魔法方法。 最终输出格式: - 回答部分。 - §§相关问题§§:列出相关问题。</think>### 在PyTorch中实现自定义数据集加载类 在PyTorch中,自定义数据集加载类允许用户处理非标准数据(如图像、文本或时间序列),通过继承`torch.utils.data.Dataset`类并实现关键方法来实现。这确保了数据能被高效加载和预处理,再通过`DataLoader`进行批处理和迭代。以下是基于PyTorch最佳实践的逐步实现指南,结合了引用内容的核心要点[^1][^2][^3]。 #### 步骤1: 理解自定义数据集类的结构 自定义数据集类必须继承`torch.utils.data.Dataset`,并覆盖三个核心方法: - `__init__(self, ...)`:初始化数据集参数(如数据路径、标签文件或预处理函数)。 - `__len__(self)`:返回数据集样本总数。 - `__getitem__(self, index)`:根据索引加载单个样本,并进行预处理(如归一化或数据增强)。 这些方法确保了数据集的兼容性,使其能被`DataLoader`无缝集成[^2][^3]。 #### 步骤2: 实现自定义数据集类 以下是一个完整示例,使用图像数据集(如蜜蜂蚂蚁数据集[^4])作为场景。假设数据集包含图像文件和对应的标签CSV文件。 ```python import torch from torch.utils.data import Dataset from PIL import Image import pandas as pd import os class CustomImageDataset(Dataset): def __init__(self, img_dir, label_file, transform=None): """ 初始化数据集。 :param img_dir: 图像文件夹路径 :param label_file: 标签文件路径(CSV格式) :param transform: 数据预处理函数(如torchvision.transforms) """ super(CustomImageDataset, self).__init__() self.img_dir = img_dir self.labels = pd.read_csv(label_file) # 读取标签文件 self.transform = transform # 可选的数据增强 def __len__(self): """返回数据集大小。""" return len(self.labels) # 确保标签文件行数等于样本数[^3] def __getitem__(self, idx): """加载并预处理单个样本。""" img_path = os.path.join(self.img_dir, self.labels.iloc[idx, 0]) # 第一列为图像文件名 image = Image.open(img_path).convert('RGB') # 加载图像 label = self.labels.iloc[idx, 1] # 第二列为标签值 if self.transform: image = self.transform(image) # 应用预处理(如Resize, ToTensor) return image, label # 返回样本和标签 ``` **关键点说明**: - **初始化 (`__init__`)**:设置数据路径和预处理函数。标签文件通常为CSV,格式如`image_name, label`[^4]。 - **长度 (`__len__`)**:返回样本数,确保索引范围有效。 - **获取项目 (`__getitem__`)**:加载图像并应用预处理(如使用`torchvision.transforms`),输出张量格式的样本和标签[^2][^3]。 - **预处理函数**:可使用PyTorch的`transforms`模块,例如: ```python from torchvision import transforms transform = transforms.Compose([ transforms.Resize((128, 128)), transforms.ToTensor(), # 转换为张量并归一化到[0,1] transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化 ]) ``` #### 步骤3: 使用DataLoader加载数据集 创建自定义数据集实例后,传递给`DataLoader`以实现批处理、洗牌和多线程加载。 ```python from torch.utils.data import DataLoader # 创建数据集实例 dataset = CustomImageDataset( img_dir='path/to/images', label_file='path/to/labels.csv', transform=transform ) # 创建DataLoader dataloader = DataLoader( dataset, batch_size=32, # 批大小 shuffle=True, # 训练时洗牌数据 num_workers=4 # 多线程加载,加速数据读取 ) # 示例:迭代数据集 for images, labels in dataloader: # 在此处进行模型训练 pass ``` **优点**: - **高效加载**:`DataLoader`支持并行数据加载,减少I/O瓶颈。 - **灵活性**:可处理各种数据类型,包括非图像数据(如修改`__getitem__`加载文本或音频)。 - **兼容性**:与PyTorch模型无缝集成[^1][^3]。 #### 常见问题与优化 - **标签文件生成**:如果无现成标签,可编写脚本生成CSV文件(如遍历文件夹并分配标签)[^4]。 - **性能优化**:使用`num_workers`加速加载,或使用内存映射技术处理大文件。 - **错误处理**:在`__getitem__`中添加异常捕获(如文件缺失时跳过样本)。 通过以上步骤,您可以在PyTorch中高效实现自定义数据集加载类,适用于各种机器学习任务。自定义数据集的核心是确保数据格式一致性和预处理标准化,这能提升模型训练效果[^1][^2]。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值