Mxnet 实现自己的dataiter

本文介绍如何在MXNet中实现自定义的数据迭代器,以满足特定深度学习模型的需求。通过继承MXNet内置的DataIter类并进行扩展,可以灵活地调整数据加载方式和预处理步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

实现深度学习模型的时候,有时候dataiter会不能满足自己的需求,所以需要继承下来,自己写一下。


mxnet写customed dataiter

代码也比较简单,直接上代码了

import mxnet as mx

class custom_iter(mx.io.DataIter):
    def __init__(self, data_iter):
        super(custom_iter,self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        provide_label = self.data_iter.provide_label[0]

        #return [('softmax_label', provide_label[1]), \
                # ('other_loss_label', provide_label[1])]

        return [('softmax_label', provide_label[1])]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        label = batch.label[0]

        return mx.io.DataBatch(data=batch.data, label=[label,label], \
                pad=batch.pad, index=batch.index)



import numpy as np
eigval = np.array([55.46, 4.794, 1.148])
eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                [-0.5808, -0.0045, -0.8140],
                [-0.5836, -0.6948, 0.4203]])
shape_=112
shape=(3,shape_,shape_)


aug_list_test=[mx.image.ForceResizeAug(size=(shape_,shape_)),
                #mx.image.ResizeAug(size=shape_+32),
                mx.image.CenterCropAug((shape_,shape_)),
          ]
aug_list_train=[
                #mx.image.ResizeAug(size=shape_+32),
                mx.image.ForceResizeAug(size=(shape_,shape_)),
                mx.image.RandomCropAug((shape_,shape_)),
                mx.image.HorizontalFlipAug(0.5),
                mx.image.CastAug(),
                mx.image.ColorJitterAug(0.0, 0.1, 0.1),
                mx.image.HueJitterAug(0.5),
                mx.image.LightingAug(0.1, eigval, eigvec),
          ]

def get_iterator(batch_size):
    """return train and val iterators for training"""
    
    train_iter = mx.image.ImageIter(batch_size=batch_size,
                                    data_shape=shape,
                                    label_width=1,
                                    aug_list=aug_list_train,
                                    shuffle=True,
                                    path_root='',
                                    path_imglist='/you/path/train.lst'
                                    )
    val_iter = mx.image.ImageIter(batch_size=batch_size,
                                  data_shape=shape,
                                  label_width=1,
                                  shuffle=False,
                                  aug_list=aug_list_test,
                                  path_root='',
                                  path_imglist='/you/path/val.lst'
                                 )

    return (custom_iter(train_iter), custom_iter(val_iter))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值