·mxnet.io.ImageRecordIter接口是MXNet框架中用于图像算法相关的数据读取接口,只需将准备好的record文件(后缀是.rec)作为该接口的输入即可训练模型。mxnet.io.ImageRecordIter接口的源码是mxnet.io.MXDataIter类,该类同样继承自MXNet框架下的数据读取基础类mxnet.io.DataIter并重写了其中的一些方法。在了解了比较抽象和基础的mxnet.io.DataIter类后(参考博客:MXNet源码解读:数据读取基础类—mxnet.io.DataIter),再来看实际构造数据迭代器时候使用的mxnet.io.MXDataIter类就会更加清晰。总的来说了解mxnet.io.MXDataIter类的实现细节对于了解MXNet框架的图像数据读取以及灵活封装更high level的API非常有帮助。
mxnet.io.MXDataIter类的源码地址
源码如下。从注释可以看出该接口是C++底层数据迭代器的python封装,继承自mxnet.io.DataIter基础类。当你初始化一个mxnet.io.ImageRecordIter类时会得到一个MXDataIter实例,然后当你调用该实例的时候就会调用MXDataIter类的底层C++数据迭代器读取数据(后面会介绍是通过next方法实现的)。MXDataIter类是个非常基础的类,许多high level的数据读取接口都是调用该类来实现的,比如用于图像分类的mxnet.io.ImageRecordIter接口,用于目标检测的mxnet.io.ImageDetRecordIter接口以及CSVIter、MNISTIter等接口。
class MXDataIter(DataIter):
"""A python wrapper a C++ data iterator.
This iterator is the Python wrapper to all native C++ data iterators, such
as `CSVIter`, `ImageRecordIter`, `MNISTIter`, etc. When initializing
`CSVIter` for example, you will get an `MXDataIter` instance to use in your
Python code. Calls to `next`, `reset`, etc will be delegated to the
underlying C++ data iterators.
Usually you don't need to interact with `MXDataIter` directly unless you are
implementing your own data iterators in C++. To do that, please refer to
examples under the `src/io` folder.
Parameters
----------
handle : DataIterHandle, required
The handle to the underlying C++ Data Iterator.
data_name : str, optional
Data name. Default to "data".
label_name : str, optional
Label name. Default to "softmax_label".
See Also
--------
src/io : The underlying C++ data iterator implementation, e.g., `CSVIter`.
"""
# init方法中self.first_batch = None,self.first_batch = self.next()这两行是
# 对第一个batch数据的初始化,返回的是DataBatch类数据,后面会介绍。然后
# self.provide_data = [DataDesc(data_name, data.shape, data.dtype)],
# self.provide_label = [DataDesc(label_name, label.shape, label.dtype)]通
# 过DataDesc类来保存数据相关的信息,具体而言是通过调用mxnet.io.DataDesc类的
# __new__方法来保存数据的name、shape、type、layout等信息,该类在
# mxnet.io.DataIter博客中已经介绍了。需要注意的是self.provide_data和
# self.provide_label都是列表形式,因此假设你新增一个层需要额外的label信息时,可
# 以直接在self.provide_label列表中添加。因为init方法在数据集初始化的时候就会调
# 用,得到的self.provide_data和self.provide_label在网络结构bind的时候会用到,
# 因此如果提供的数据信息不对,会在bind的时候报错。
def __init__(self, handle, data_name='data', label_name='softmax_label', **_):
super(MXDataIter, self).__init__()
self.handle = handle
# debug option, used to test the speed with io effect eliminated
self._debug_skip_load = False
# load the first batch to get shape information
self.first_batch = None
self.first_batch = self.next()
data = self.first_batch.data[0]
label = self.first_batch.label[0]
# properties
self.provide_data = [DataDesc(data_name, data.shape, data.dtype)]
self.provide_label = [DataDesc(label_name, label.shape, label.dtype)]
self.batch_size = data.shape[0]