在项目中涉及迭代器相关概念,不是很懂,因此写了简单的测试代码帮助理解,代码如下:
import tensorflow as tf
t = tf.constant(
[0,1,2,3,4,4,3,2,1,0,0,0,0,0,0,1,1,1,1,1]
)
# print(t)
train_db = tf.data.Dataset.from_tensor_slices(t).batch(3)
train_iter = iter(train_db) #迭代器
sample = next(train_iter)
# print('batch:', sample[0].shape, sample[1].shape)
print(train_db)
print('\n')
print(train_iter)
print('\n')
print(sample)
print(sample)
print(sample)
print('\n')
print(next(train_iter))
print(next(train_iter))
print(next(train_iter))
可以看到,对于一个列表,可以首先通过slice函数进行切片,切片的单位取决于batch中的参数;
将列表切片后,使用iter函数生成迭代器;
使用next函数不断获得迭代器中的迭代参数;
代码运行后输出如下:
分析结果可以看到,前两个参数的打印输出为乱码,我也不知道为什么,可能是不能直接使用吧。
从之后的打印中可以分析出,不停的对next函数进行调用,就可以不断输出迭代器中的数据,直到所有数据输出完毕。