Tensorflow Dataset API介绍

前言

tf.data.Dataset曾是放在tf.contrib包中的,从tensorflow 1.4开始转正。主要作用是数据读取,相比原来的读取数据方法(placeholde读内存数据和queue读硬盘数据)更加简洁,且是TensorFlow最新的Eager模式必须的读取数据方法。

1. 使用流程

1.1 实例化Dataset

dataset = tf.data.Dataset.from_tensor_slice((A, B,,,N))

  • 该方法的作用是切分传入Tensor的第一个维度
    假设A是包含5000张图片的数据集,shape =(5000, 32, 32, 3),经过切分后就变成5000张(32, 32, 3)的独立图片,并且和B切分后的数据也是对应的,不难想象,在此场景下B就是图片标签。
  • 除了这种方法还有另外三种创建Dataset的方法,比如tf.data.TFRecordDataset()可以读取tfrecords文件,此处不展开

1.2 按需对dataset进行变换(transformation)

最常用的transformation非batch莫属,batch的作用是将数据集分成小批,后面训练的时候就可以方便的一次取一个batch出来了。
dataset = dataset.batch(128)
除此之外常用的还有:

  • repeat: 重复数据集,相当于epoch
  • map: 接收一个函数,对原始数据做一些简单处理
  • shuffle: 打乱数据

1.3 创建Iterator,取出数据

创建Iterator的方法也有几种,这里我们选择可以使用placeholder的initializable iterator:
iterator = dataset.make_initializable_iterator() 创建迭代器
A, B = iterator.get_next() 获得一个批次的数据
注意后面需要在会话中初始化,并放入placeholder的实际数据:
sess.run(iterator.initializer, feed_dict={A: x_train, B: y_train})

1.4 实例

# 创建dataset对象
dataset = tf.data.Dataset.from_tensor_slices((image_holder, label_holder))
# 训练周期
dataset = dataset.repeat(1)
# 批次大小
dataset = dataset.batch(128bat	)
# 初始化迭代器
iterator = dataset.make_initializable_iterator()
# 获得一个批次数据和标签
data_batch, label_batch = iterator.get_next()
# 此时的data_batch和label_batch就可以输入到模型中了
,
,
,
# 开启会话训练时即可从迭代器中拿数据
sess.run(iterator.initializer, feed_dict={image_holder: x_train, label_holder: y_train})

参考

TensorFlow全新的数据读取方式:Dataset API入门教程

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值