之前写了一篇通过读取保存了 (img_path, label) 的 csv 文件来读取数据集,但是个人感觉这种方式一点麻烦
之前阅读了大佬的博客,读取自己的数据集需要继承 Dataset
重写一个类:
import torch
import numpy as np
# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
# 初始化函数,得到数据
def __init__(self, data_root, data_label):
self.data = data_root
self.label = data_label
# index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
def __getitem__(self, index):
data = self.data[index]
labels = self.label[index]
return data, labels
# 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
def __len__(self):
return len(self.data)
# 随机生成数据,大小为10 * 20列
source_data = np.random.rand(10, 20)
# 随机生成标签,大小为10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
torch_data = GetLoader(source_data, source_label)
其中 data_root, data_label 分别是原始的特征数据集合和label集合,__getitem__
函数方便 Dataset 通过 index 获取一个样本,__len__
返回样本的数目
如果我们想要读取 MSTAR 数据集中的图片,SARImage 包括 TEST 和 TRAIN 两个子目录,TRAIN 中有 10 个目录,对应 10 个类别,每个目录内存有原始的 300 张对应类别 SAR 图像
我们可以在 __init__
中传入数据的根目录(10类的上一级),需要的类别,类别的标签,每个类别多少张,transform:
class ReadDataFromDir(Dataset):
def __init__(self, root_dir, classes_name, classes_label, samples_per_cls, transform):
# Transforms
self.transform = transform
self.label_dict = dict(list(zip(classes_name, classes_label)))
self.img_label_pairs = []
for class_name in classes_name:
class_dir = root_dir + '/' + class_name # 类别路径
img_names = os.listdir(class_dir) # 类内图片名
random.shuffle(img_names) # 随机打乱选取 samples_per_cls
img_paths = [class_dir + '/' + name for name in img_names][:samples_per_cls]
labels = [self.label_dict[class_name] for i in range(samples_per_cls)]
self.img_label_pairs = self.img_label_pairs + list(zip(img_paths, labels))
self.data_len = len(self.img_label_pairs)
def __getitem__(self, index):
# get image
img_path, label = self.img_label_pairs[index]
img = Image.open(img_path)
img_to_tensor = self.transform(img)
return (img_to_tensor, label)
def __len__(self):
return self.data_len
其中,通过 os.listdir(class_dir)
读取了每个类别目录内的所有图片文件名,将图片位置和对应标签存储到一个元组构成的列表中。
测试一下:
if __name__ == '__main__':
cls = ['2S1', 'BMP2', 'BRDM_2', 'BTR70', 'BTR_60', 'D7', 'T62', 'T72', 'ZIL131', 'ZSU_23_4']
import torchvision.transforms as transforms
TargetDataset = ReadDataFromDir(
root_dir='../SARImages/TRAIN',
classes_name=['2S1', 'BRDM_2', 'BTR_60', 'D7', 'T72'],
classes_label=[i for i in range(5)],
samples_per_cls=5,
transform=transforms.Compose([
transforms.Resize(5),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])]
),
)
print(TargetDataset.data_len)
print(TargetDataset.label_dict)
for item in TargetDataset.img_label_pairs:
print(item[0], '\t', item[1])
print(TargetDataset.__getitem__(0)[0].shape)
print(len(TargetDataset))
print(TargetDataset[0])
classes_name=['2S1', 'BRDM_2', 'BTR_60', 'D7', 'T72']
for (idx, (img, label)) in enumerate(TargetDataset):
print(idx, img.shape, classes_name[label])
完结 ~