- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- Dataloader类
- mnist手写数据集的了解
作业:了解下cifar数据集,尝试获取其中一张图片
当数据集的规模较大时,显存通常无法一次性存储所有数据,因此需要使用分批训练的方法。这个时候就可以使用在pytorch中两个与数据加载有关的核心类,Dataset和Dataloader。
(1)Dataset:负责封装数据及其标签,提供按索引访问样本的能力,
(2)Dataloader:提供批量(batch)、打乱(shuffle)、多进程加载等功能,数据管道的构建
Dataset类
完成”是什么“——定义单个样本如何获取和总数量
Dataset类的核心思想:负责定义数据的来源以及如何从数据集中获取单个样本。不负责一次性加载所有数据,而是提供一种标准化的接口,让数据加载器Dataloader能够获取数据。
在Dataset中有两个核心方法:__getitem__ 和 __len__ ,任何继承torch.utils.data.Datasets(自定义) 的子类都需要实现这两个方法。这是为了统一接口(使用模式相同),以确保被Dataloader兼容,否则pytorch无法识别。
这两个方法与之前的__call__方法类似,也是魔术方法:不是被直接调用的,而是在特定语法或操作触发时由 Python 解释器自动调用,用于定义类的行为。 比如,在这里,__getitem__和__len__就实现了内置类型(list)的功能。
核心方法
__getitem__ 方法
让实例对象支持索引操作,当使用“[]”访问对象元素时,会自动调用该方法,最终实现与内置的列表一样使用索引获取元素的效果。
class MyList:
def __init__(self):
self.data = [10, 20, 30, 40, 50]
def __getitem__(self, idx):
return self.data[idx]
# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2]) # 输出:30
__len__ 方法
返回实例对象中元素的数量,当len()函数作用于对象时,自动调用,最终实现像普通列表一样被len()函数调用获取长度。
class MyList:
def __init__(self):
self.data = [10, 20, 30, 40, 50]
def __len__(self):
return len(self.data)
# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj)) # 输出:5
自定义Dataset
主要有三个部分:
- 初始化:加载文件列表、标签、设置 transforms 等
- __len__:返回数据集总样本数
- __getitem__:根据 idx 加载并返回一个样本
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None):
self.img_labels = pd.read_csv(annotations_file) # CSV文件路径,包含图片文件名和标签
self.img_dir = img_dir # 包含所有图片的目录路径
self.transform = transform # 变换函数,应用于样本(如图像预处理、数据增强)。
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx): # 获取指定索引的样本
# 获取指定索引的图像和标签
img, target = self.data[idx], self.targets[idx]
# 应用图像预处理(如ToTensor、Normalize)
if self.transform is not None: # 如果有预处理操作
img = self.transform(img) # 转换图像格式
# 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化
return img, target # 返回处理后的图像和标签,是一个元组 (img, target)
# python中逗号表示元组,所以返回的 img, target 是一个元组
数据预处理时的写法类似前面的pipeline管道,这里使用的是torchvision库中transforms模块,它提供了常用的图像预处理操作。
注意:先归一化再标准化。因为标准化使用的统计量是基于归一化后的数据计算得到的。如果先标准化再归一化,可能会出现极端数值,出现训练稳定性差等问题。
#调用
# 定义变换管道,先归一化,再标准化
transform = transforms.Compose([
transforms.ToTensor(), # 将PIL图像或numpy.ndarray转换为Tensor,并自动缩放到[0.0, 1.0]
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet统计信息
])
# 实例化dataset
dataset = CustomImageDataset(
annotations_file='labels.csv',
img_dir='images/',
transform=transform
)
sample_image, sample_label = dataset[0] # 获取第一个样本(每次返回元组)
Dataloader类
完成”怎么用“——定义如何批量、并行、高效地提供这些样本
Dataloader的主要功能包括:
- 批量加载(Batching):将多个样本组合成一个 batch,batch_size。
- 打乱顺序(Shuffling):在每个 epoch 开始时随机打乱数据顺序(训练时常用),shuffle。
- 多进程并行加载(Multiprocessing):利用多个子进程(num_workers)预加载数据,避免 I/O 成为训练瓶颈。
- 自动转换为 Tensor:配合 Dataset 返回的数据,自动将其转为 torch.Tensor。
- 自定义采样(Sampler):如类别平衡采样、分布式采样等。
# 实例化,创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
shuffle=True # 随机打乱数据
)
test_loader = DataLoader(
test_dataset,
batch_size=1000 # 每个批次1000张图片
# shuffle=False # 测试时不需要打乱数据
)
Mnist手写数据集
MNIST数据集主要用于手写数字识别任务,它的训练集包含60000张图像,测试集包含10000张图像。图像的格式是灰度图,尺寸为28*28像素。总共有0到9这个10个类别。
这里直接使用的是torchvision库进行加载数据和处理,这里是直接调用的,实际上内部的逻辑与上面的应该是类似的。
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets,transforms
import torch
import matplotlib.pyplot as plt
# 设置随机数种子,保证重现性
torch.manual_seed(42)
# 1-预处理,创建变换器
transform = transforms.Compose([
transforms.ToTensor(), # 归一化
transforms.Normalize(mean=(0.1307,), std=(0.3081,)) # 标准化
])
# 2-dataset
# 训练集的读取与预处理
train_dataset = datasets.MNIST(
root='./data', # 数据存储的根目录,如果没有则会自动下载
train=True,
download=True, # 如果数据集不存在则下载
transform=transform
)
# 测试集的读取与预处理
test_dataset = datasets.MNIST(
root='./data',
train=False,
transform=transform
)
# 获取,随机选择一张图片,可以重复运行,每次都会随机选择
idx = torch.randint(low=0,high=len(train_dataset),size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image,label = train_dataset[idx] # 获取图片和标签
# 可视化原始图像(需要反归一化)
def imshow(img):
img = img * 0.3081 + 0.1307 # 反标准化
npimg = img.numpy() # (通道数, 高度, 宽度)
# imshow希望输入:(高度, 宽度, 通道数)(彩色),黑白取(高度, 宽度)
#从单通道的三维数据 (1,28,28) 中提取出二维的像素数据 (28,28),cmap指定颜色映射为灰度
plt.imshow(npimg[0], cmap='gray') # 显示灰度图像,索引0是因为MNIST是单通道图像
plt.show()
print(f"Label: {label}")
imshow(image)
# 3-dataloader
train_dataloader = DataLoader(
dataset=train_dataset,
batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
shuffle=True, # 随机打乱数据
)
test_dataloader = DataLoader(
dataset=test_dataset,
batch_size=1000, # 每个批次1000张图片
shuffle=False # 测试时不需要打乱数据
)
# 可视化原始图像(需要反归一化)
def imshow(img):
img = img * 0.3081 + 0.1307 # 反标准化
npimg = img.numpy()
plt.imshow(npimg[0], cmap='gray') # 显示灰度图像,索引0是因为MNIST是单通道图像
plt.show()
print(f"Label: {label}") # 有随机种子,标签结果为6
imshow(image)

总结
如果将Dataset类比作厨师的话,那么DataLoader就是服务员。厨师的工作是制作菜品(包括切菜、烹饪、调味等),对应预处理步骤;而服务员主要是将菜品组合并上桌,对应batch、multiprocessing等步骤。


作业——cifar-10图片获取
下面是摘自官网的介绍:
CIFAR-10 数据集由 10 类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。

注意:直接使用download=True有点慢,要么直接在官网下载后导入,更快的话是复制下载链接到迅雷下载。
根据上面的知识,尝试获取其中一张图片(预处理仅采用了归一化操作),并且不包含Dataloader操作,代码如下:
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets,transforms
import torch
import matplotlib.pyplot as plt
import numpy as np
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 预处理transform
transform = transforms.Compose([
transforms.ToTensor(), # 归一化
#transforms.Normalize(mean=(0.1307,), std=(0.3081,)) # 标准化
])
# dataset创建
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
test_dataloader = datasets.CIFAR10(
root='./data',
train=False,
transform=transform
)
# 获取标签
sample_idx = torch.randint(0,len(train_dataset),size=(1,)).item()
image,label = train_dataset[sample_idx]
# 可视化
print('Label:{}'.format(label))
#img = img * 0.3081 + 0.1307 # 反标准化
#npimg = img.numpy() # (通道数, 高度, 宽度)
plt.imshow(np.transpose(image.numpy(),(1,2,0))) # (1,2,0)为轴顺序参数
plt.show()
设置了随机种子,结果固定为标签6的图像。

1213

被折叠的 条评论
为什么被折叠?



