pytorch的dataset和dataloader

本文详细介绍了PyTorch中的数据集(Dataset)和数据加载器(Dataloader)的使用。Dataset是数据集的抽象,用于存储和访问数据,可以根据需求自定义数据加载方式;Dataloader则定义了数据加载的策略,如批处理、洗牌和多线程加载。理解这两个组件对于高效地训练深度学习模型至关重要。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

简单说,dataset是数据集,dataloader是加载数据集的工具

dataset

pytorch提供了多样化的dataset方法。

  1. 如果你的数据集比较小, xxxyyy都可以load到内存里,可以直接使用pytorch的torch.utils.data.TensorDataset:
import torch
from torch.utils import data

# build a toy dataset, with a sequence of x and y using y = sin(x) + noise
T = 1000
x = torch.arange(1, T + 1, dtype=torch.float32)
y = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))

dataset = data.TensorDataset(x, y) # takes in x/y pairs as parameters
  1. 如果数据集存放于硬盘里(文件夹),或者有其他需求,比如在训练时需要从zip或者hdf5中提取数据,那么需要重载pytorch提供的dataset类(下面是pytorch的官方例子):
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

上面这个例子有3个函数:
__init__:参数可以自定义,主要目的是告诉这个类,到什么地方去找数据。上面的这个例子应该是一个图像方面的深度学习应用。了img_dir来指定图片存放的路径,annotations_file指定标注信息(y)的文件名。transform传入图片在fit之前所需进行的变换,比如augmentation(增强)、normalization(归一化)等。target_transform传入标注信息所需进行的变换,这个通常较少用到。注意,上面的这些参数并不是固定的,用户在定义Dataset类的时候,根据自己的需要进行设置。
__len__:返回数据集的长度(数据个数)
__getitem__:传入idx,返回对应的 x/y 对

这三个方法基本上可以满足不同数据集的需求。

dataloader

dataloader定义了如何加载dataset。函数定义原型如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

dataloader的参数虽然比较多,但理解和使用比较简单,可以直接参考官方的说明:
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader 这里重点说几个参数:
shuffle:是否将数据集顺序打乱(洗牌)。通常,如果应用在训练集training_set上,shuffle=True,否则,shuffle=False
num_workers:可以理解为处理数据所用的并行线程数量。实际可以尝试数值0~8。大于这个数字效果可能不理想
pin_memory:这个参数我一开始不太理解,查阅了不少资料。下面“参考”里面的最后一条SO上的回答写的比较详细。这个参数形象的理解,是(用图钉)“钉住内存”。简单的说,就是使用固定地址的内存,避免在CPU-CPU间拷贝来拷贝去,影响性能。此外,pin_memory还可以允许对CPU内存和GPU内存的操作变为“异步”:对CPU内存的处理和延时等不会影响GPU内存的处理,二者可以并行。

参考

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
https://deeplizard.com/learn/video/kWVgvsejXsE#:~:text=Different%20num_workers%20Values%3A%20Results%20%20%20%20run,%20%204%20%2014%20more%20rows%20
https://stackoverflow.com/questions/55563376/pytorch-how-does-pin-memory-work-in-dataloader

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值