定义自己的dataloader
在使用自己数据集训练网络时,往往需要定义自己的dataloader。
1 定义datalaoder
一般将dataloader封装为一个类,这个类继承自 torch.utils.data.dataset
from torch.utils.data import dataset
class LoadData(Dataset): # 注意父类的名称,不能写dataset
pass
需要注意的是dataset是模块名,而Dataset是类名,在python中模块名和类名是完全独立的命名空间,因此这里的父类需要写成 dataset.Dataset。
在我们定义的LoadData中,至少需要有三个方法:
- __init__方法,主要用来定义数据的预处理
- __getitem__方法,返回数据的item和label
- __len__方法,返回数据个数
整体大致架构:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class LoadData(dDataset):
def __init__(self):
pass
def __getitem__(self,index):
pass
def __len__(self):
pass
dataset = Loaddata()
train_loader = DataLoader(dataset = dataset,