文章目录
读取数据 Dataset类
Dataset 是一个读取数据抽象类,所有自定义的数据集类需要继承该类。
该类主要实现以下三个功能
①如何获取每一个数据及其label --> 抽象方法__getitem()__设置通过对象[索引]的方式获取每一个样本及其label
②告知一共有多少数据 --> 抽象方法__len__()方法是返回所有样本的数量
使用
1.继承torch.utils.data.Dataset抽象类,定义从自己的数据源中创建数据集。
2.__init__用于初始化数据集,这里可以加载数据文件(定义读取的数据集地址)、初始化变量等。
3.实现__len__(self)方法,返回数据集中的样本数量
4.实现__getitem__(self, idx)方法,通过索引返回一个样本及其label
Dataset类代码实践
任意下载一个数据集(密码5suq),我下载hymenoptera_data数据集,该数据集用于蚂蚁和蜜蜂的。打开我们之前pytorch环境的pycharm,在项目目录中放入数据集。

按照Datase方法的说明,我们先自定义一个数据集的类,该类需要继承torch.utils.data.Dataset,整体框架如下
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 在构造器中,获取数据集
def __init__(self):
# 获取数据集中的一个样本
def __getitem__(self):
# 获取数据集的样本数量
def __len__(self):
__init__构造器:获取数据集所在的路径
我们把文件夹名字ants看作是蚂蚁的label,bees看作是蜜蜂的label,根地址为hymenoptera_data/train
-> 蚂蚁和蜜蜂数据集的地址分别为根地址+对应label,所以构造器方法需要根地址和label两个参数
# 需要引入 operating system 操作系统库
import os
class MyDataset(Dataset):
# 在构造函数中,获取数据集
def __init__(self,root_dir,lable_dir):
self.root_dir = root_dir
self.lable_dir = lable_dir
# os.path.join()函数用于路径拼接文件路径
self.path = os.path.join(self.root_dir,self.lable_dir)
# os.listdir返回参数路径包含的文件或文件夹的名字的列表。这个列表以字母顺序
self.images_path = os.listdir(self.path)
__getitem__方法:获取数据集中的一个样本
我们知道数据集中的路径以及数据集中所有样本名字组成的列表后,可以通过拼接路径+样本名字找到每个样本的路径,然后打开该路径获取样本。
# 需要引入Python Image Library 图像处理库
from PIL import Image
def __getitem__(self, index):
# 样本的名字
image_name = self.images_path[index]
# 样本的路径
img_item_path = os.path.join(self.path, image_name)
# 读取一个样本
img = Image.open(img_item_path)
# 样本的label
label = self.lable_dir
return img, label
len()方法:返回数据集中样本的数量
def __len__(self):
return len(self.images_path)
完整代码
from torch.utils.data import Dataset
from PIL import Image #Python Image Library 图像处理库
import os #operating system 操作系统
class MyDataset(Dataset):
# 在构造函数中,获取数据集
def __init__(self,root_dir,lable_dir):
self.root_dir = root_dir
self.lable_dir = lable_dir

最低0.47元/天 解锁文章
1557

被折叠的 条评论
为什么被折叠?



