MindSpore提供了大部分常用数据集和标准格式数据集的加载接口,可以直接使用mindspore.dataset中对应的数据集加载类进行数据加载,如MNIST、CIFAR-10、CIFAR-100、VOC、COCO、ImageNet、CelebA、CLUE等, 以及业界标准格式的数据集,包括MindRecord、TFRecord、Manifest等。
常用数据集加载以cifar10为例,首先将cifar10数据集下载并解压到本地。
1、加载cifar10数据集:
DATA_DIR = "./cifar-10-batches-bin/"
sampler = ds.SequentialSampler(num_samples=5)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)
用create_dict_iterator创建数据迭代器访问数据:
for data in dataset.create_dict_iterator():
print("Image shape: {}".format(data['image'].shape), ", Label: {}".format(data['label']))

2、加载自定义图像分类数据集
使用mindspore加载自定义图像分类数据,可以使用mindspore.dataset.ImageFolderDataset接口进行加载。将相同类别的图像放在同一文件夹下,不同类别以不同文件夹区分,将所有分类的上级目录传入ImageFolderDataset接口,mindspore会自动加载图像数据并根据不同文件夹分配对应标签。


这里以TinyImageNet为例进行数据加载。首先,使用imageFolderDataset接口传入数据路径,通过num_parallel_worker可设置数据加载并行线程数,shuffle参数设置是否打乱数据顺序。另外需要通过map接口进行图像数据预处理,图像预处理接口mindspore.dataset.vision.c_transforms,通过c_transforms可进行图像解码,缩放归一化,矩阵转置等操作。
import mindspore
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import dtype as mstype
def create_dataset(data_path, batch_size=24, c_transforms
repeat_num=1):
"""定义数据集"""
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode == context.ParallelMode.DATA_PARALLEL:
data_set = ds.ImageFolderDataset(data_path, num_parallel_def create_dataset(data_path, batch_size=24, repeat_num=1):
"""定义数据集"""
data_set = ds.ImageFolderDataset(data_path, num_parallel_workers=8, shuffle=True)
image_size = [100, 100]
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
trans = [
CV.Decode(),
CV.Resize(image_size),
CV.Normalize(mean=mean, std=std),
CV.HWC2CHW()
]
# 实现数据的map映射、批量处理和数据重复的操作
type_cast_op = C.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
data_set = data_set.batch(batch_size, drop_remainder=True)
data_set = data_set.repeat(repeat_num)
return data_set
数据迭代,ImageFolderDataset通过create_tuple_iterator()接口对数据集进行迭代,每次迭代一个batch的数据。
if __name__ == '__main__':
datapath = 'D:/Sources/Data/datasets/TinyImageNet/val'
ds = create_dataset(datapath, batch_size=8)
iterator = ds.create_tuple_iterator()
for item in iterator:
print(f'images:{mindspore.Tensor(item[0]).shape},labels:{item[1]}')

本文介绍了MindSpore如何使用内置数据集接口加载MNIST、CIFAR-10/CIFAR-100、VOC、COCO等标准数据集,以及如何处理自定义图像分类任务,如TinyImageNet,通过ImageFolderDataset进行数据预处理和加载。
486

被折叠的 条评论
为什么被折叠?



