首先需要导入库文件
from torch.utils.data import Dataset, DataLoader
1、读取数据和标签
将数据和标签分别进行读入。
2、转换成 tensor
pytorch处理tensor,对于其他类型数据需要进行转换,如将numpy.ndarray转换为tensor类型可用函数:torch.from_numpy函数
3、定义init()和getitem() 和len() 函数
定义类CreateDataset,需要继承Dataset类。
class CreateDataset(Dataset):#需要继承data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
4、初始化类
实例化类CreateDataset
train_data = CreateDataset()
5、将类实例传给 DataLoader
trainloader = DataLoader(train_data, batch_size = 10, shuffle = False)
参考链接:
https://www.pytorchtutorial.com/pytorch-note4-input-data-pipeline/
自定义数据集例程