一、数据准备
总结:
RandomDataset :用于验证 (val)
BatchDataset:用于训练 (train)
BalancedBatchSampler:决定如何采样样本,不是简单的在Dataloader中设置一个batch_size了
1)导入的包类:
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler#此处的BatchSampler相当于在Dataloader中设置的batch_size
2)读取图片函数
学习点:考虑读取图片失败情况,使用try…except结构,并将读取失败的图路径存储下来,并返回一个全为白色的同尺度新图。
def default_loader(path):
'''
打开图片并转化为RGB,打开失败,则记录下来,并返回一个新图
'''
try:
img = Image.open(path).convert('RGB')
except:
with open('read_error.txt', 'a') as fid:
fid.write(path+'\n')
return Image.new('RGB', (224,224), 'white')
return img
3)RandomDataset类
此类用于选取指定index的样本,返回的是一张图片以及对应的标签。
批量获取是在Dataloader中设置的。
class RandomDataset(Dataset):
def __init__(self, transform=None, dataloader=default_loader):#此处dataloader用不上
self.transform = transform
self.dataloader = dataloader
with open('val.txt', 'r') as fid:#将图片路径以及标签读取出来
self.imglist = fid.readlines()
def __getitem__(self, inde