pytorch TensorDataset与DataLoader类

读取数据 Dataset类

Dataset 是一个读取数据抽象类,所有自定义的数据集类需要继承该类

该类主要实现以下三个功能

①如何获取每一个数据及其label --> 抽象方法__getitem()__设置通过对象[索引]的方式获取每一个样本及其label

②告知一共有多少数据 --> 抽象方法__len__()方法是返回所有样本的数量

使用

1.继承torch.utils.data.Dataset抽象类,定义从自己的数据源中创建数据集。

2.__init__用于初始化数据集,这里可以加载数据文件(定义读取的数据集地址)、初始化变量等。

3.实现__len__(self)方法,返回数据集中的样本数量

4.实现__getitem__(self, idx)方法,通过索引返回一个样本及其label

Dataset类代码实践

任意下载一个数据集(密码5suq),我下载hymenoptera_data数据集,该数据集用于蚂蚁和蜜蜂的。打开我们之前pytorch环境的pycharm,在项目目录中放入数据集。

按照Datase方法的说明,我们先自定义一个数据集的类,该类需要继承torch.utils.data.Dataset,整体框架如下

from torch.utils.data import Dataset
class MyDataset(Dataset):
    # 在构造器中,获取数据集
    def __init__(self):
    # 获取数据集中的一个样本
    def __getitem__(self):
    # 获取数据集的样本数量
    def __len__(self):

__init__构造器:获取数据集所在的路径

我们把文件夹名字ants看作是蚂蚁的labelbees看作是蜜蜂的label,根地址为hymenoptera_data/train

-> 蚂蚁和蜜蜂数据集的地址分别为根地址+对应label,所以构造器方法需要根地址和label两个参数

# 需要引入 operating system 操作系统库
import os

class MyDataset(Dataset):
    # 在构造函数中,获取数据集
    def __init__(self,root_dir,lable_dir):
        self.root_dir = root_dir
        self.lable_dir = lable_dir
        # os.path.join()函数用于路径拼接文件路径
        self.path = os.path.join(self.root_dir,self.lable_dir)
        # os.listdir返回参数路径包含的文件或文件夹的名字的列表。这个列表以字母顺序
        self.images_path = os.listdir(self.path)

__getitem__方法:获取数据集中的一个样本

我们知道数据集中的路径以及数据集中所有样本名字组成的列表后,可以通过拼接路径+样本名字找到每个样本的路径,然后打开该路径获取样本。

# 需要引入Python Image Library 图像处理库
from PIL import Image 

    def __getitem__(self, index):
        # 样本的名字
        image_name = self.images_path[index]
        # 样本的路径
        img_item_path = os.path.join(self.path, image_name)
        # 读取一个样本
        img = Image.open(img_item_path)
        # 样本的label
        label = self.lable_dir
        return img, label

len()方法:返回数据集中样本的数量

    def __len__(self):
        return len(self.images_path)

完整代码

from torch.utils.data import Dataset
from PIL import Image #Python Image Library 图像处理库
import os #operating system 操作系统

class MyDataset(Dataset):
    # 在构造函数中,获取数据集
    def __init__(self,root_dir,lable_dir):
        self.root_dir = root_dir
        self.lable_dir = lable_dir
        # os.path.join()函数用于路径拼接文件路径
        self.path = os.path.join(self.root_dir,self.lable_dir)
        # os.listdir返回参数包含的文件或文件夹的名字的列表。这个列表以字母顺序
        self.images_path = os.listdir(self.path)
    def __getitem__(self, index):
        # 每一个图片的名字
        image_name = self.images_path[index]
        # 每一个图片的路径
        img_item_path = os.path.join(self.path, image_name)
        # 读取一个样本
        img = Image.open(img_item_path)
        label = self.lable_dir
        return img, label
    def __len__(self):
        return len(self.images_path)

root_dir="hymenoptera_data/train"
ants_label_dir ="ants"
bees_label_dir ="bees"
ants_dataset = MyDataset(root_dir,ants_label_dir)
bees_dataset = MyDataset(root_dir,bees_label_dir)
# 打开蚂蚁数据集中的第一张
first_img,first_label = ants_dataset[0] 
first_img.show()

TensorDataset类 - 打包数据

data.TensorDataset(*tensors) 是 PyTorch 中用于将多个张量(Tensor)打包成一个数据集(Dataset)的工具。<它提供了一种方便的方法将数据封装为适合DataLoade 处理的格式。

  • 数据封装:将数据的特征和标签封装到一个张量数据集中,每个元素都是一个样本。
  • 简化索引:允许通过索引直接访问数据集中的任何点,简化了数据的访问和处理。

说明

  1. *tensors:多个张量,所有张量的第一个维度(样本数)必须相同。
  2. 通过TensorDataset返回值<[索引],可以访问特定索引的样本,返回元组(比如(x[i],y[i]))
  3. TensorDatasetDataset 的子类,TensorDataset继承自 PyTorch 的抽象基类 torch.utils.data.Dataset,所以实例可以传入给DataLoader类。

使用案例

import torch
from torch.utils.data import DataLoader, TensorDataset

# 假设我们有一些输入数据 X 和标签 Y
X = torch.randn(100, 3)  # 100个样本,每个样本3个特征
Y = torch.randn(100, 1)  # 100个样本的标签

# 创建 TensorDataset
dataset = TensorDataset(X, Y)
# 读取第一个样本
x, y = dataset[0]  # tensor([-1.1590, -0.3770,  0.3824]) tensor([-1.7849])
# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# 迭代 DataLoader
for i, (x, y) in enumerate(dataloader):
    print(f"Batch {i}:")
    print(f"Features: {x.size()}, Labels: {y.size()}")
    # 在这里,x 和 y 将是批次的特征和标签
# Batch 0:
# Features: torch.Size([10, 3]), Labels: torch.Size([10, 1])
# Batch 1:
# Features: torch.Size([10, 3]), Labels: torch.Size([10, 1])
# Batch 2:
# Features: torch.Size([10, 3]), Labels: torch.Size([10, 1])
# .....

torch.utils.data.DataLoader类 - 加载数据

函数作用: PyTorch 中用于高效加载数据的核心工具,支持批量处理、随机打乱、多进程加速等功能,适用于训练神经网络时的数据流水线构建。

参数说明

  1. dataset (必需): 用于加载数据的数据集,通常是torch.utils.data.Dataset的子类实例。
  2. batch_size (可选): 每个批次的数据样本数。默认值为1。
  3. shuffle (可选): 是否在每个周期开始时打乱数据,默认为False。训练时建议启用,验证/测试时禁用
  4. sampler (可选): 定义从数据集中抽取样本的策略。如果指定,则忽略shuffle参数。
  5. batch_sampler (可选): 与sampler类似,但一次返回一个批次的索引。不能与batch_size、shuffle和sampler同时使用。
  6. num_workers (可选): 用于数据加载的子进程数量。默认为0,意味着数据将在主进程中加载。
  7. collate_fn (可选): 如何将多个数据样本整合成一个批次,常用于处理变长数据(如文本)。通常不需要指定。
  8. drop_last (可选): 如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次,默认为False。

返回值:可迭代的数据装载器(DataLoader)对象

from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,          # 数据集对象(必须继承自 torch.utils.data.Dataset)
    batch_size=1,     # 每个批次的样本数
    shuffle=False,    # 是否打乱数据顺序(通常训练时设为 True,验证/测试时设为 False)
    num_workers=0,    # 加载数据的子进程数(建议设为 CPU 核心数)
    drop_last=False,  # 是否丢弃最后一个不足 batch_size 的批次
    collate_fn=None   # 自定义批次数据的合并逻辑(例如填充序列)
)

数据加载流程

  1. DataLoader通过调用 dataset__len__() 获取总样本数。
  2. 根据 batch_sizeshuffle 参数生成批次索引。
  3. 对每个批次索引,调用 dataset__getitem__() 方法获取数据。
  4. 将多个样本的元组合并为一个批次的张量(例如,将多个 (X[i], Y[i]) 合并为 (batch_X, batch_Y)) => 每次返回一个批次的数据

案例

*在函数调用中的解包作用,可以解包一个可迭代对象(如列表、元组),将其元素作为位置参数传递。*data_arrays这里表示将传入的元组(features, labels)解包为features、labels两个参数。

def load_array(data_arrays, batch_size, is_train=True):  #@save
    """构造一个PyTorch数据迭代器"""
    # 得到pytorch的dataseat
    dataset = data.TensorDataset(*data_arrays)
    # 每次随机挑选batch_size个样本出来
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
# (features, labels) 组合成一个list传递
data_iter = load_array((features, labels), batch_size) 
print(data_iter) #<torch.utils.data.dataloader.DataLoader object at >
print(next(iter(data_iter))) 

获取可迭代的数据装载器(DataLoader)对象里的数据

  1. 使用for循环遍历
    训练/评估时逐批次处理数据。
for X, y in train_iter:
    # X: 输入数据, y: 标签 
    print("Batch shape:", X.shape, y.shape)
    break  # 仅演示第一个批次
  1. 手动调用iter()和next()
    手动获取单个批次(如调试或抽样检查)。
# 获取迭代器对象
data_iter = iter(train_iter)

# 手动获取一个批次
try:
    X, y = next(data_iter)
    print("Batch shape:", X.shape, y.shape)
except StopIteration:
    print("数据遍历完毕")
  1. 直接访问原始数据集
    获取未分批的原始数据
# 获取原始数据集(非批次数据)
dataset = train_iter.dataset

# 访问第0个样本
X, y = dataset[0]
print("Sample shape:", X.shape, y)
  1. 从批次中拆解,提取单个样本
    可视化或调试单个样本
for X, y in train_iter:
    # 从第一个批次提取第0个样本
    single_X = X[0]  # 形状 [1, 28, 28]
    single_y = y[0]  # 标量值
    print("Sample shape:", single_X.shape, "Label:", single_y)
    break
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值