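batch() combines consecutive elements of this dataset into batches. There are two cases, depending on whether you want to keep the final batch when it has fewer than batch_size elements.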
import tensorflow as tf

# Case 1: keep the final, smaller batch (the default)
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
print(list(dataset.as_numpy_iterator()))
Output:
[array([0, 1, 2], dtype=int64), array([3, 4, 5], dtype=int64), array([6, 7], dtype=int64)]
dataset2 = tf.data.Dataset.range(8)
dataset2 = dataset2.batch(3, drop_remainder=True)
print(list(dataset2.as_numpy_iterator()))
Output:
[array([0, 1, 2], dtype=int64), array([3, 4, 5], dtype=int64)]
Here the leftover elements 6 and 7 are dropped, because they cannot fill a complete batch of 3.
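The two cases above can be mimicked in plain Python to make the rule explicit. This is only a rough sketch of the same grouping behaviour, not how TensorFlow implements it; the helpers `batch` and `num_batches` are hypothetical names made up for this illustration.

```python
import math


def batch(items, batch_size, drop_remainder=False):
    """Group consecutive items into lists of batch_size (hypothetical helper, not TensorFlow API)."""
    items = list(items)
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    # Discard the last batch only if it is incomplete and the caller asked for that.
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


def num_batches(n, batch_size, drop_remainder=False):
    """Number of batches produced from n elements under the same rule."""
    return n // batch_size if drop_remainder else math.ceil(n / batch_size)


print(batch(range(8), 3))                       # [[0, 1, 2], [3, 4, 5], [6, 7]]
print(batch(range(8), 3, drop_remainder=True))  # [[0, 1, 2], [3, 4, 5]]
print(num_batches(8, 3), num_batches(8, 3, drop_remainder=True))  # 3 2
```

So with 8 elements and a batch size of 3, keeping the remainder gives ceil(8 / 3) = 3 batches, and dropping it gives floor(8 / 3) = 2 batches.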