在前面的文章中,我们简单介绍了在 MegEngine imperative 中的各模块以及它们的作用。对于新用户而言可能不太了解各个模块的使用方法,对于模块的结构和原理也是一头雾水。Python 作为现在深度学习领域的主流编程语言,其相关的模块自然也是深度学习框架的重中之重。
模块串讲将对 MegEngine 的 Python 层相关模块分别进行更加深入的介绍,会涉及到一些原理的解释和代码解读。Python 层模块串讲共分为上、中、下三个部分,本文介绍 Python 层的 data 模块。读者将通过本文了解到要构建数据 pipeline 所需要的对象,以及如何高效地构建 pipeline。
构建数据处理 pipeline —— data 模块
神经网络需要数据才可以训练,数据源文件可能是各种格式,读取数据需要定义采样规则、数据变换规则、数据合并策略等,这些和数据相关的模块都封装在 data 下。
在 MegEngine 中训练模型读取数据的 pipeline 一般是:
- 创建一个 Dataset 对象;
- 按照训练场景的需求可能需要对数据做一些变换或合并的处理,这里可能需要创建 Sampler、Transform、Collator 等对象来完成相应操作;
- 创建一个 DataLoader 对象;
- 将数据分批加载到
DataLoader里,迭代DataLoader对象进行训练。
下面我们看一下这几个对象的实现。
Dataset
在 MegEngine 中,数据集是一个可迭代的对象,所有的 Dataset 对象都继承自 class Dataset,都需要实现自己的 __getitem__() 方法和 __len__() 方法,这两个方法分别是用来获取给定索引的对应的数据样本和返回数据集的大小。
根据对数据集访问方式的区别,MegEngine 中的数据集类型主要分为两种:ArrayDataset 和 StreamDataset,前者支持随机访问数据样本,而后者只可以顺序访问。二者的主要区别见下表:

Dataset
Dataset 支持对数据集的随机访问,访问类型是 Map-style 的,也就是可以从索引映射到数据样本,使用时需要实现 __getitem__() 方法和 __len__() 方法。
下面是一个使用 Dataset 生成一个由 0 到 5 的数组成的数据集的例子:
from megengine.data.dataset import Dataset
class CustomMapDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
使用起来如下:
>>> data = list(range(0, 5))
>>> map_dataset = CustomMapDataset(data)
>>> print(len(map_dataset))
5
>>> print(map_dataset[2])
2
可以发现 Dataset 最大的特点就是可以根据给定索引随机访问数据集中对应下标的数据。
ArrayDataset
对于 Numpy ndarray 类型的数据,MegEngine 中对 Dataset 进一步封装实现了 ArrayDataset,使用 ArrayDataset 无需实现 __getitem__() 方法和 __len__() 方法。
下面的例子随机生成了一个具有 100 个样本、每张样本为 32 × 32 像素的 RGB 图片的数据集:
import numpy as np
from megengine.data.dataset import ArrayDataset
data = np.random.random((100, 3, 32, 32))
target = np.random.random(

最低0.47元/天 解锁文章





