Python 自定义 Iterator

import tensorflow as tf
import numpy as np

dim1 = 10
dim2 = 5

array = np.array(np.reshape([i for i in range(dim1 * dim2)], newshape=(dim1, dim2)))
print(array)

dataset = tf.data.Dataset.from_tensor_slices(array)
dataset = dataset.batch(5)

class Iterator(object):
    def __init__(self, dataset, stride):
        self.dataset = iter(dataset)
        self.stride = stride
        if (stride < 1):
            raise ValueError("stride must > 1, but got ", stride)
        self.begin = 0

    def __iter__(self):
        return self

    def __next__(self):
        try:
            for items in self.dataset:
                while True:
                    end = self.begin + self.stride
                    if (end >= items.shape[0]):
                        yield items[self.begin:]
                        self.begin = 0
                        break
                    else:
                        yield items[self.begin : end]
                        self.begin += self.stride
        except tf.errors.OutOfRangeError as error:
            raise error

class Dataset(object):
    def __init__(self, dataset, stride):
        self.Iterator = Iterator(dataset, stride)

    def __iter__(self):
        return next(self.Iterator)

if __name__ == "__main__":
    for step, items in enumerate(Dataset(dataset, stride=2)):
        print(items)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值