总结如何自定义数据集

1.简介

首先我们要知道DataSet模块和DataLoader模块,这两个模块帮助你高效地加载和处理数据

DataSet模块和DataLoader模块都在torch.utils.data下面

2.DataSet介绍

DataSet是pytorch下的数据集抽象类。可以继承它下面的__len__和__getitem__这两个方法,来定义数据集

3. 自定义数据集详解

自定义 Dataset 的关键步骤
  1. 继承 torch.utils.data.Dataset:你的自定义数据集类必须继承这个类。
  2. 实现 __init__:这个方法用来初始化数据集,通常用于数据路径的加载、文件名的存储或预处理方法的定义。
  3. 实现 __len__:这个方法返回数据集中的样本数量。DataLoader 会使用这个方法来获取数据集的大小。
  4. 实现 __getitem__:这个方法返回一个样本。DataLoader 会调用此方法来按批次加载数据。在这个方法中,你可以进行必要的数据预处理、转换或增强。
import os
from torch.utils.data import Dataset
import cv2
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        自定义数据集初始化
        :param root_dir: 数据集的根目录
        :param transform: 数据增强或预处理操作
        """
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # 遍历根目录下的所有子文件夹
        for class_idx, class_name in enumerate(os.listdir(root_dir)):
            class_path = os.path.join(root_dir, class_name)
            if os.path.isdir(class_path):
                for img_name in os.listdir(class_path):
                    img_path = os.path.join(class_path, img_name)
                    self.samples.append((img_path, class_idx))

    def __len__(self):
        # 返回数据集中的样本数量
        return len(self.samples)

    def __getitem__(self, idx):
        # 返回单个样本和标签
        img_path, label = self.samples[idx]
        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Failed to load image at path: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # 灰度化处理
        image = Image.fromarray(image)  # 将 np 数组转换为 PIL 图片格式
        if self.transform:
            image = self.transform(image)
        return image, label

在上面的例子中:

  • __init__:负责初始化数据集,读取文件路径并存储每个图片路径与对应标签的元组。
  • __len__:返回数据集的大小,即 samples 中样本的数量。
  • __getitem__:加载指定索引的图片数据,并进行必要的处理(如灰度化和转换为 PIL 图像)。

4.DataLoader介绍

DataLoader可以批量加载数据,可以控制每批次多少样本,可以打乱数据顺序,同时还支持多线程并行加载数据,并且会调用DataSet对象的__getitem__方法,来获取数据。

DataLoader 的常用参数
  • dataset:传入自定义的 Dataset 实例。
  • batch_size:设置每个批次加载的样本数量。
  • shuffle:是否在每个 epoch 开始前打乱数据集。
  • num_workers:设置加载数据的子进程数量,通常为 CPU 核心数。
from torch.utils.data import DataLoader
from torchvision import transforms

# 数据预处理操作
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # 调整图片大小
    transforms.ToTensor(),  # 转换为张量
])

# 创建自定义数据集
train_set = CustomDataset(root_dir="data_set/training_set", transform=transform)

# 使用 DataLoader 加载数据
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

# 迭代数据
for images, labels in train_loader:
    print(images.shape, labels)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

serenity宁静

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值