1、重写Dataset类:
#源码
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
#这个函数就是根据索引,迭代的读取路径和标签。因此我们需要有一个路径和标签的 ‘容器’供我们读
def __getitem__(self, index):
raise NotImplementedError
#返回数据的长度
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
想制作自己的图像数据集供DataLoader拿取,首先要重写Datasets类,主要用来完成从哪里读取数据和标签的功能。主要是__getitem()__(返回数据集和标签)和__len__(返回数据的长度)这两个方法。
完成Datasets类的这两个主要功能后,训练的时候可以把数据集传送给DataLoader就可以获取自己想要的batch数据。
例1:通过包含 数据路径 和 标签 的TXT文件读取
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
#集成Dataset类
clas