1、概述
在tensorflow中的输入数据会有很多形式一般有一下几种形式
- 数据以tf.constant的实行直接嵌入到graph中。在这种情况下一般数据量不会很大,应用场景也比较单一
- 以tf.placeholder与feed_dic的形式存在
在这种场景下,往往也需要将数据全部读入到内存,转换成tf的张量集合然后再进行处理。在进行大量数据处理时显得的力不从心。
- 以pipeline的方式从文件中直接读取数据并且采用多线程异步形式来解决IO的瓶颈。
本章内容主要讨论第三种方式。
2、tf.data API
tensorflow的tf.data类可以让我们使用简单的可重用的代码来构建复杂的输入管道。并将数据构建、打乱数据、生成批量数据的功能,整合其中。同时tf.data提供了文本文件输入模型与图像输入模型用于处理不同形式的输入数据。
tf.data引入了2个新的概念
1、tf.data.Dataset
tf.data.Dataset表示一个元素序列,其中每个元素包含一个或多个张量,例如,在图像流水线中,元素可以是单个训练示例,其中一对张量表示图像数据和标签。创建数据集有两种不同的方法:
- 从数据集中直接创建一个dataset对象
- 从一个已有的dataset对象转换一个新的dataset对象
2、tf.data.Iterator
使用该api可以构造一个迭代器从Dataset中提取数据。我们可以使用Iterator.get_next()产生Dataset执行时的下一个元素,并且通常充当输入管道代码和模型之间的接口。最简单的迭代器是一个“一次性迭代器”,它与一个特定的Dataset迭代器相关联并迭代一次。如果需要构造更为复杂的迭代器可以使用Iterator.initializer传递不同的参数来进行构建。
3、基本原理
- 要启动输入管道,您必须定义源。
- 一旦有了Dataset对象,就可以通过对对象进行链接方法调用将其转换为新对象。
- 创建一个迭代器从Dataset对象中获取数据
4、Dataset的基本结构
Dataset必须是具有相同结构的元素集合。每个元素至少包含一个Tensor对象,每一个元素被称之为“组件”。
每一个组件都包含下面两个非常重要的属性:
- tf.Dtype:用来表示组件中每一个元素的数据类型
- tf.TensorShape:用来表示每个元素的静态形状
而就数据集本身来讲也有两个非常重要的属性,我们更多的情况下是关注整个数据集的情况,而不是单个组件的情况:
- Dataset.output_types:整体数据数据集中所包含的数据类型
- Dataset.output_shapes:数据集的整体形状综述
请参照如下代码:
input_data = tf.random_uniform([4,10])
dataset1 = tf.data.Dataset.from_tensor_slices(input_data)
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
with tf.Session() as sess:
print(sess.run(input_data))
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
也可以对dataset进行任意组合:
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
5、创建一个迭代器
要从数据集里读取数据,就要使用迭代器,tensorflow提供了下面几种迭代器:
- one-shot
- initializable
- reinitializable
- feedable
one-shot迭代器是所有的迭代器中最简单的一个。他只支持一次性迭代,而且无需初始化操作。one-shot迭代器几乎处理现有基于队列的输入管道支持的所有情况,但它们不支持参数化。请参照下面代码:
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(100):
value = sess.run(next_element)
print(value)
initializable迭代器可以使用placeholder对迭代器进行初始化,请参照如下代码:
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value
reinitializable迭代器可以使用不同的已经被dataset来初始化迭代器。如果我们需要通过在训练的时候,同时进行交叉验证,那么此时我们就会用到此类的迭代器,请参考一下代码:
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
sess = tf.Session()
for _ in range(20):
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element)
sess.run(validation_init_op)
for _ in range(50):
print(sess.run(next_element))
feedable迭代器采用feed_dic机制在session.run时使用类似place_holder的机制来初始化不同的迭代器。他提供的功能与reinitializable迭代器类似,但并不需要在使用数据集之前就初始化迭代器。我们可以使用feedable迭代器实现上面类似的功能。
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
sess = tf.Session()
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle})
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})