pytorch TensorDataset与DataLoader类

读取数据 Dataset类

Dataset 是一个读取数据抽象类,所有自定义的数据集类需要继承该类

该类主要实现以下三个功能

①如何获取每一个数据及其label --> 抽象方法__getitem()__设置通过对象[索引]的方式获取每一个样本及其label

②告知一共有多少数据 --> 抽象方法__len__()方法是返回所有样本的数量

使用

1.继承torch.utils.data.Dataset抽象类,定义从自己的数据源中创建数据集。

2.__init__用于初始化数据集,这里可以加载数据文件(定义读取的数据集地址)、初始化变量等。

3.实现__len__(self)方法,返回数据集中的样本数量

4.实现__getitem__(self, idx)方法,通过索引返回一个样本及其label

Dataset类代码实践

任意下载一个数据集(密码5suq),我下载hymenoptera_data数据集,该数据集用于蚂蚁和蜜蜂的。打开我们之前pytorch环境的pycharm,在项目目录中放入数据集。

按照Datase方法的说明,我们先自定义一个数据集的类,该类需要继承torch.utils.data.Dataset,整体框架如下

from torch.utils.data import Dataset
class MyDataset(Dataset):
    # 在构造器中,获取数据集
    def __init__(self):
    # 获取数据集中的一个样本
    def __getitem__(self):
    # 获取数据集的样本数量
    def __len__(self):

__init__构造器:获取数据集所在的路径

我们把文件夹名字ants看作是蚂蚁的labelbees看作是蜜蜂的label,根地址为hymenoptera_data/train

-> 蚂蚁和蜜蜂数据集的地址分别为根地址+对应label,所以构造器方法需要根地址和label两个参数

# 需要引入 operating system 操作系统库
import os

class MyDataset(Dataset):
    # 在构造函数中,获取数据集
    def __init__(self,root_dir,lable_dir):
        self.root_dir = root_dir
        self.lable_dir = lable_dir
        # os.path.join()函数用于路径拼接文件路径
        self.path = os.path.join(self.root_dir,self.lable_dir)
        # os.listdir返回参数路径包含的文件或文件夹的名字的列表。这个列表以字母顺序
        self.images_path = os.listdir(self.path)

__getitem__方法:获取数据集中的一个样本

我们知道数据集中的路径以及数据集中所有样本名字组成的列表后,可以通过拼接路径+样本名字找到每个样本的路径,然后打开该路径获取样本。

# 需要引入Python Image Library 图像处理库
from PIL import Image 

    def __getitem__(self, index):
        # 样本的名字
        image_name = self.images_path[index]
        # 样本的路径
        img_item_path = os.path.join(self.path, image_name)
        # 读取一个样本
        img = Image.open(img_item_path)
        # 样本的label
        label = self.lable_dir
        return img, label

len()方法:返回数据集中样本的数量

    def __len__(self):
        return len(self.images_path)

完整代码

from torch.utils.data import Dataset
from PIL import Image #Python Image Library 图像处理库
import os #operating system 操作系统

class MyDataset(Dataset):
    # 在构造函数中,获取数据集
    def __init__(self,root_dir,lable_dir):
        self.root_dir = root_dir
        self.lable_dir = lable_dir
        
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值