前言
tf.data.Dataset
曾是放在tf.contrib
包中的,从tensorflow 1.4
开始转正。主要作用是数据读取,相比原来的读取数据方法(placeholde读内存数据和queue读硬盘数据)更加简洁,且是TensorFlow最新的Eager模式必须的读取数据方法。
1. 使用流程
1.1 实例化Dataset
dataset = tf.data.Dataset.from_tensor_slice((A, B,,,N))
- 该方法的作用是切分传入Tensor的第一个维度
假设A是包含5000张图片的数据集,shape =(5000, 32, 32, 3),经过切分后就变成5000张(32, 32, 3)的独立图片,并且和B切分后的数据也是对应的,不难想象,在此场景下B就是图片标签。 - 除了这种方法还有另外三种创建Dataset的方法,比如
tf.data.TFRecordDataset()
可以读取tfrecords文件,此处不展开
1.2 按需对dataset进行变换(transformation)
最常用的transformation非batch
莫属,batch
的作用是将数据集分成小批,后面训练的时候就可以方便的一次取一个batch出来了。
dataset = dataset.batch(128)
除此之外常用的还有:
- repeat: 重复数据集,相当于epoch
- map: 接收一个函数,对原始数据做一些简单处理
- shuffle: 打乱数据
1.3 创建Iterator,取出数据
创建Iterator的方法也有几种,这里我们选择可以使用placeholder的initializable iterator:
iterator = dataset.make_initializable_iterator()
创建迭代器
A, B = iterator.get_next()
获得一个批次的数据
注意后面需要在会话中初始化,并放入placeholder的实际数据:
sess.run(iterator.initializer, feed_dict={A: x_train, B: y_train})
1.4 实例
# 创建dataset对象
dataset = tf.data.Dataset.from_tensor_slices((image_holder, label_holder))
# 训练周期
dataset = dataset.repeat(1)
# 批次大小
dataset = dataset.batch(128bat )
# 初始化迭代器
iterator = dataset.make_initializable_iterator()
# 获得一个批次数据和标签
data_batch, label_batch = iterator.get_next()
# 此时的data_batch和label_batch就可以输入到模型中了
,
,
,
# 开启会话训练时即可从迭代器中拿数据
sess.run(iterator.initializer, feed_dict={image_holder: x_train, label_holder: y_train})