Day 38 Dataset和Dataloader类

今日任务:
  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. mnist手写数据集的了解

作业:了解下cifar数据集,尝试获取其中一张图片

数据集的规模较大时,显存通常无法一次性存储所有数据,因此需要使用分批训练的方法。这个时候就可以使用在pytorch中两个与数据加载有关的核心类,Dataset和Dataloader。

(1)Dataset:负责封装数据及其标签,提供按索引访问样本的能力,

(2)Dataloader:提供批量(batch)、打乱(shuffle)、多进程加载等功能,数据管道的构建

Dataset类

完成”是什么“——定义单个样本如何获取和总数量

Dataset类的核心思想:负责定义数据的来源以及如何从数据集中获取单个样本。不负责一次性加载所有数据,而是提供一种标准化的接口,让数据加载器Dataloader能够获取数据。

在Dataset中有两个核心方法__getitem__ __len__ ,任何继承torch.utils.data.Datasets(自定义) 的子类都需要实现这两个方法。这是为了统一接口(使用模式相同),以确保被Dataloader兼容,否则pytorch无法识别。

这两个方法与之前的__call__方法类似,也是魔术方法:不是被直接调用的,而是在特定语法或操作触发时由 Python 解释器自动调用,用于定义类的行为。 比如,在这里,__getitem__和__len__就实现了内置类型(list)的功能

核心方法

__getitem__ 方法

让实例对象支持索引操作,当使用“[]”访问对象元素时,会自动调用该方法,最终实现与内置的列表一样使用索引获取元素的效果。

class MyList:
    def __init__(self):
        self.data = [10, 20, 30, 40, 50]

    def __getitem__(self, idx):
        return self.data[idx]
# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30

__len__ 方法

返回实例对象中元素的数量,当len()函数作用于对象时,自动调用,最终实现像普通列表一样被len()函数调用获取长度

class MyList:
    def __init__(self):
        self.data = [10, 20, 30, 40, 50]

    def __len__(self):
        return len(self.data)
# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj))  # 输出:5

自定义Dataset

主要有三个部分:

  • 初始化:加载文件列表、标签、设置 transforms 等
  • __len__:返回数据集总样本数
  • __getitem__:根据 idx 加载并返回一个样本
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = pd.read_csv(annotations_file) # CSV文件路径,包含图片文件名和标签
        self.img_dir = img_dir # 包含所有图片的目录路径
        self.transform = transform # 变换函数,应用于样本(如图像预处理、数据增强)。

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx): # 获取指定索引的样本
        # 获取指定索引的图像和标签
        img, target = self.data[idx], self.targets[idx]
        
        # 应用图像预处理(如ToTensor、Normalize)
        if self.transform is not None: # 如果有预处理操作
            img = self.transform(img) # 转换图像格式
        # 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化
            
        return img, target  # 返回处理后的图像和标签,是一个元组 (img, target)
    # python中逗号表示元组,所以返回的 img, target 是一个元组

数据预处理时的写法类似前面的pipeline管道,这里使用的是torchvision库中transforms模块,它提供了常用的图像预处理操作

注意:先归一化再标准化。因为标准化使用的统计量基于归一化后的数据计算得到的。如果先标准化再归一化,可能会出现极端数值,出现训练稳定性差等问题。

#调用
# 定义变换管道,先归一化,再标准化
transform = transforms.Compose([
    transforms.ToTensor(), # 将PIL图像或numpy.ndarray转换为Tensor,并自动缩放到[0.0, 1.0]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet统计信息
])
# 实例化dataset
dataset = CustomImageDataset(
    annotations_file='labels.csv',
    img_dir='images/',
    transform=transform
)

sample_image, sample_label = dataset[0] # 获取第一个样本(每次返回元组)

Dataloader类

完成”怎么用“——定义如何批量、并行、高效地提供这些样本

Dataloader的主要功能包括:

  • 批量加载(Batching):将多个样本组合成一个 batch,batch_size。
  • 打乱顺序(Shuffling):在每个 epoch 开始时随机打乱数据顺序(训练时常用),shuffle。
  • 多进程并行加载(Multiprocessing):利用多个子进程(num_workers)预加载数据,避免 I/O 成为训练瓶颈。
  • 自动转换为 Tensor:配合 Dataset 返回的数据,自动将其转为 torch.Tensor。
  • 自定义采样(Sampler):如类别平衡采样、分布式采样等。
# 实例化,创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1000 # 每个批次1000张图片
    # shuffle=False # 测试时不需要打乱数据
)

Mnist手写数据集

MNIST数据集主要用于手写数字识别任务,它的训练集包含60000张图像,测试集包含10000张图像。图像的格式是灰度图,尺寸为28*28像素。总共有0到9这个10个类别。

这里直接使用的是torchvision库进行加载数据和处理,这里是直接调用的,实际上内部的逻辑与上面的应该是类似的。

from torch.utils.data import Dataset,DataLoader
from torchvision import datasets,transforms
import torch
import matplotlib.pyplot as plt

# 设置随机数种子,保证重现性
torch.manual_seed(42)

# 1-预处理,创建变换器
transform = transforms.Compose([
    transforms.ToTensor(), # 归一化
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)) # 标准化
])

# 2-dataset
# 训练集的读取与预处理
train_dataset = datasets.MNIST(
    root='./data', # 数据存储的根目录,如果没有则会自动下载
    train=True,
    download=True, # 如果数据集不存在则下载
    transform=transform
)
# 测试集的读取与预处理
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

# 获取,随机选择一张图片,可以重复运行,每次都会随机选择
idx = torch.randint(low=0,high=len(train_dataset),size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image,label = train_dataset[idx] # 获取图片和标签

# 可视化原始图像(需要反归一化)
def imshow(img):
    img = img * 0.3081 + 0.1307  # 反标准化
    npimg = img.numpy() # (通道数, 高度, 宽度) 
    # imshow希望输入:(高度, 宽度, 通道数)(彩色),黑白取(高度, 宽度)
    #从单通道的三维数据 (1,28,28) 中提取出二维的像素数据 (28,28),cmap指定颜色映射为灰度
    plt.imshow(npimg[0], cmap='gray') # 显示灰度图像,索引0是因为MNIST是单通道图像
    plt.show()

print(f"Label: {label}")
imshow(image)

# 3-dataloader
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True, # 随机打乱数据
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=1000, # 每个批次1000张图片
    shuffle=False # 测试时不需要打乱数据
)
# 可视化原始图像(需要反归一化)
def imshow(img):
    img = img * 0.3081 + 0.1307  # 反标准化
    npimg = img.numpy()
    plt.imshow(npimg[0], cmap='gray') # 显示灰度图像,索引0是因为MNIST是单通道图像
    plt.show()

print(f"Label: {label}") # 有随机种子,标签结果为6
imshow(image)

总结

如果将Dataset类比作厨师的话,那么DataLoader就是服务员。厨师的工作是制作菜品(包括切菜、烹饪、调味等),对应预处理步骤;而服务员主要是将菜品组合并上桌,对应batchmultiprocessing等步骤。

作业——cifar-10图片获取

下面是摘自官网的介绍:

CIFAR-10 数据集由 10 类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。

注意:直接使用download=True有点慢,要么直接在官网下载后导入,更快的话是复制下载链接到迅雷下载

根据上面的知识,尝试获取其中一张图片(预处理仅采用了归一化操作),并且不包含Dataloader操作,代码如下:

from torch.utils.data import Dataset,DataLoader
from torchvision import datasets,transforms
import torch
import matplotlib.pyplot as plt
import numpy as np

# 设置随机种子,确保结果可复现
torch.manual_seed(42)

# 预处理transform
transform = transforms.Compose([
    transforms.ToTensor(), # 归一化
    #transforms.Normalize(mean=(0.1307,), std=(0.3081,)) # 标准化
])

# dataset创建
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataloader = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform
)

# 获取标签
sample_idx = torch.randint(0,len(train_dataset),size=(1,)).item()
image,label = train_dataset[sample_idx]

# 可视化
print('Label:{}'.format(label))
#img = img * 0.3081 + 0.1307  # 反标准化
#npimg = img.numpy() # (通道数, 高度, 宽度) 
plt.imshow(np.transpose(image.numpy(),(1,2,0))) # (1,2,0)为轴顺序参数
plt.show()

设置了随机种子,结果固定为标签6的图像。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值