dataset与dataloder
1.dataset:
torch.utils.data.Dataset()
Dataset抽象类,所有自定义的dataset需要继承它,
getitem:接受一个索引,返回一个样本
class Dataset(object):
def __init__(self, ):
def __len__(self):
def __getitem__(self, ):
例如:
class RMBDataset(Dataset):
def __init__(self, data_dir, transform=None):
"""
:param data_dir: str, 数据集所在路径
:param transform: torch.transform,数据预处理
"""
self.label_name = {"1": 0, "100": 1}
self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform = transform
def __getitem__(self, index): #根据index索引返回img,label
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.data_info)
2.dataloder
torch.utils.data.DataLoder(
dataset, dataset类,决定数据从哪读取以及如何读取
batch_size=1, 批大小
shuffle=False, 每个epoch是否乱序
sampler,
batch_sampler,
num_workers=0,是否多进程读取数据
collate_fn,
pin_memory
drop_last=False,当样本数不能被batchsize整除时,是否舍弃最后一批数据
timeout,
worker_init_fn
multiprocessing_context
)
所有训练样本都已经输入到模型中,称为一个epoch
一批样本输入到模型中,称之为一个iteration
批大小,决定一个epoch有多少个iteration
如:样本总数:80,batchsize:8
1 epoch=10 iteration
这两个我自己看的似懂非懂的样子,以后再补充