import tensorflow as tf
import numpy as np
# Demo data: a dim1 x dim2 matrix holding 0 .. dim1 * dim2 - 1 in row-major order.
dim1 = 10
dim2 = 5
array = np.arange(dim1 * dim2).reshape(dim1, dim2)
print(array)
# Turn the matrix into a tf.data pipeline of its rows, grouped into batches of 5 rows.
dataset = tf.data.Dataset.from_tensor_slices(array)
dataset = dataset.batch(5)
class Iterator(object):
    """Iterator that splits each element of ``dataset`` into slices of ``stride`` rows.

    Every element drawn from ``dataset`` is cut along its first axis into
    consecutive slices ``items[0:stride]``, ``items[stride:2*stride]``, ...
    Slices never span two elements of ``dataset``. The last slice taken from
    an element is therefore shorter than ``stride`` whenever the element's
    length is not a multiple of ``stride``. Elements of length 0 produce no
    slices.

    The iterator is single-pass: ``iter(dataset)`` is taken once in
    ``__init__`` and is consumed as slices are requested.

    Args:
        dataset: Iterable whose elements expose ``.shape[0]`` and support
            ``element[a:b]`` slicing (e.g. NumPy arrays, or eager tensors
            from a batched ``tf.data.Dataset``).
        stride: Maximum number of rows per slice; an integer >= 1.

    Raises:
        ValueError: If ``stride`` is less than 1.
    """

    def __init__(self, dataset, stride):
        # Validate before touching ``dataset`` so a bad stride fails fast.
        if stride < 1:
            raise ValueError(f"stride must be >= 1, but got {stride}")
        self.dataset = iter(dataset)
        self.stride = stride
        # Lazily evaluated stream of slices; ``__next__`` pulls from it.
        self._slices = self._generate_slices()

    def _generate_slices(self):
        """Yield the stride-sized slices of every element, in order."""
        for items in self.dataset:
            # Slicing clamps the stop index, so the final slice of an
            # element is simply whatever rows remain.
            for start in range(0, items.shape[0], self.stride):
                yield items[start:start + self.stride]

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next slice; raise StopIteration once ``dataset`` is exhausted."""
        return next(self._slices)


class Dataset(object):
    """Iterable wrapper that yields the stride-sized slices produced by ``Iterator``.

    Like ``Iterator``, it is single-pass: once the slices have been consumed,
    iterating again yields nothing.

    Args:
        dataset: Iterable of batches, forwarded to ``Iterator``.
        stride: Maximum number of rows per slice, forwarded to ``Iterator``.

    Raises:
        ValueError: If ``stride`` is less than 1 (raised by ``Iterator``).
    """

    def __init__(self, dataset, stride):
        self.Iterator = Iterator(dataset, stride)

    def __iter__(self):
        # ``Iterator`` implements the iterator protocol itself.
        return self.Iterator
if __name__ == "__main__":
    # Print each stride-2 slice of the batched demo dataset.
    for items in Dataset(dataset, stride=2):
        print(items)