Pytorch 读取目录中的数据

这篇博客介绍了如何使用PyTorch自定义数据加载器`GetLoader`和`ReadDataFromDir`来处理图像数据集。`GetLoader`类用于随机生成数据,而`ReadDataFromDir`类则针对MSTAR数据集,读取指定目录下的图像并按类别随机选取样本。文章提供了详细代码示例,展示了如何处理图像路径、标签、数据增强以及数据预处理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

之前写了一篇通过读取保存了 (img_path, label) 的 csv 文件来读取数据集,但是个人感觉这种方式一点麻烦

之前阅读了大佬的博客,读取自己的数据集需要继承 Dataset 重写一个类:

import torch
import numpy as np


# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
	# 初始化函数,得到数据
    def __init__(self, data_root, data_label):
        self.data = data_root
        self.label = data_label
    # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return data, labels
    # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
    def __len__(self):
        return len(self.data)

# 随机生成数据,大小为10 * 20列
source_data = np.random.rand(10, 20)
# 随机生成标签,大小为10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
torch_data = GetLoader(source_data, source_label)

其中 data_root, data_label 分别是原始的特征数据集合和label集合,__getitem__ 函数方便 Dataset 通过 index 获取一个样本,__len__ 返回样本的数目


如果我们想要读取 MSTAR 数据集中的图片,SARImage 包括 TEST 和 TRAIN 两个子目录,TRAIN 中有 10 个目录,对应 10 个类别,每个目录内存有原始的 300 张对应类别 SAR 图像

我们可以在 __init__ 中传入数据的根目录(10类的上一级),需要的类别,类别的标签,每个类别多少张,transform:

class ReadDataFromDir(Dataset):
    def __init__(self, root_dir, classes_name, classes_label, samples_per_cls, transform):
        # Transforms
        self.transform = transform
        self.label_dict = dict(list(zip(classes_name, classes_label)))

        self.img_label_pairs = []
        for class_name in classes_name:
            class_dir = root_dir + '/' + class_name     # 类别路径
            img_names = os.listdir(class_dir)           # 类内图片名
            random.shuffle(img_names)                   # 随机打乱选取 samples_per_cls
            img_paths = [class_dir + '/' + name for name in img_names][:samples_per_cls]
            labels = [self.label_dict[class_name] for i in range(samples_per_cls)]
            self.img_label_pairs = self.img_label_pairs + list(zip(img_paths, labels))

        self.data_len = len(self.img_label_pairs)
        
    def __getitem__(self, index):
        # get image
        img_path, label = self.img_label_pairs[index]
        img = Image.open(img_path)
        img_to_tensor = self.transform(img)

        return (img_to_tensor, label)


    def __len__(self):
        return self.data_len

其中,通过 os.listdir(class_dir) 读取了每个类别目录内的所有图片文件名,将图片位置和对应标签存储到一个元组构成的列表中。

测试一下:

if __name__ == '__main__':
    
    cls = ['2S1', 'BMP2', 'BRDM_2', 'BTR70', 'BTR_60', 'D7', 'T62', 'T72', 'ZIL131', 'ZSU_23_4']

    import torchvision.transforms as transforms

    TargetDataset = ReadDataFromDir(
        root_dir='../SARImages/TRAIN', 
        classes_name=['2S1', 'BRDM_2', 'BTR_60', 'D7', 'T72'], 
        classes_label=[i for i in range(5)], 
        samples_per_cls=5, 
        transform=transforms.Compose([
            transforms.Resize(5), 
            transforms.ToTensor(), 
            transforms.Normalize([0.5], [0.5])]
        ),
    )

    print(TargetDataset.data_len)
    print(TargetDataset.label_dict)
    for item in TargetDataset.img_label_pairs:
        print(item[0], '\t', item[1])

    print(TargetDataset.__getitem__(0)[0].shape)

    print(len(TargetDataset))
    print(TargetDataset[0])

    classes_name=['2S1', 'BRDM_2', 'BTR_60', 'D7', 'T72']
    for (idx, (img, label)) in enumerate(TargetDataset):
        print(idx, img.shape, classes_name[label])

完结 ~

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值