1.简介
首先我们要知道DataSet模块和DataLoader模块,这两个模块帮助你高效地加载和处理数据
DataSet模块和DataLoader模块都在torch.utils.data下面
2.DataSet介绍
DataSet是pytorch下的数据集抽象类。可以继承它下面的__len__和__getitem__这两个方法,来定义数据集
3. 自定义数据集详解
自定义 Dataset
的关键步骤:
- 继承
torch.utils.data.Dataset
:你的自定义数据集类必须继承这个类。 - 实现
__init__
:这个方法用来初始化数据集,通常用于数据路径的加载、文件名的存储或预处理方法的定义。 - 实现
__len__
:这个方法返回数据集中的样本数量。DataLoader
会使用这个方法来获取数据集的大小。 - 实现
__getitem__
:这个方法返回一个样本。DataLoader
会调用此方法来按批次加载数据。在这个方法中,你可以进行必要的数据预处理、转换或增强。
import os
from torch.utils.data import Dataset
import cv2
from PIL import Image
class CustomDataset(Dataset):
def __init__(self, root_dir, transform=None):
"""
自定义数据集初始化
:param root_dir: 数据集的根目录
:param transform: 数据增强或预处理操作
"""
self.root_dir = root_dir
self.transform = transform
self.samples = []
# 遍历根目录下的所有子文件夹
for class_idx, class_name in enumerate(os.listdir(root_dir)):
class_path = os.path.join(root_dir, class_name)
if os.path.isdir(class_path):
for img_name in os.listdir(class_path):
img_path = os.path.join(class_path, img_name)
self.samples.append((img_path, class_idx))
def __len__(self):
# 返回数据集中的样本数量
return len(self.samples)
def __getitem__(self, idx):
# 返回单个样本和标签
img_path, label = self.samples[idx]
image = cv2.imread(img_path)
if image is None:
raise ValueError(f"Failed to load image at path: {img_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # 灰度化处理
image = Image.fromarray(image) # 将 np 数组转换为 PIL 图片格式
if self.transform:
image = self.transform(image)
return image, label
在上面的例子中:
__init__
:负责初始化数据集,读取文件路径并存储每个图片路径与对应标签的元组。__len__
:返回数据集的大小,即samples
中样本的数量。__getitem__
:加载指定索引的图片数据,并进行必要的处理(如灰度化和转换为 PIL 图像)。
4.DataLoader介绍
DataLoader可以批量加载数据,可以控制每批次多少样本,可以打乱数据顺序,同时还支持多线程并行加载数据,并且会调用DataSet对象的__getitem__方法,来获取数据。
DataLoader
的常用参数:
dataset
:传入自定义的Dataset
实例。batch_size
:设置每个批次加载的样本数量。shuffle
:是否在每个 epoch 开始前打乱数据集。num_workers
:设置加载数据的子进程数量,通常为 CPU 核心数。
from torch.utils.data import DataLoader
from torchvision import transforms
# 数据预处理操作
transform = transforms.Compose([
transforms.Resize((28, 28)), # 调整图片大小
transforms.ToTensor(), # 转换为张量
])
# 创建自定义数据集
train_set = CustomDataset(root_dir="data_set/training_set", transform=transform)
# 使用 DataLoader 加载数据
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
# 迭代数据
for images, labels in train_loader:
print(images.shape, labels)