combines consecutive elements of this dataset into batches. 两种情况取决于是否要最后那个个数不足的batch
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
print(list(dataset.as_numpy_iterator()))
输出
[array([0, 1, 2], dtype=int64), array([3, 4, 5], dtype=int64), array([6, 7], dtype=int64)]
dataset2 = tf.data.Dataset.range(8)
dataset2 = dataset2.batch(3,drop_remainder=True)
print(list(dataset2.as_numpy_iterator()))
输出
[array([0, 1, 2], dtype=int64), array([3, 4, 5], dtype=int64)] 这里6,7没有了
本文详细介绍了TensorFlow中用于批量处理数据集的`tf.data.Dataset.batch`方法。通过使用该方法,可以将数据集连续的元素组合成批次,同时讨论了在处理最后可能不足批量大小的数据时的两种不同策略。
402

被折叠的 条评论
为什么被折叠?



