!!!在数据集框架中,每一个数据集代表一个数据来源,其数据来源有一下几种:张量,TFRecord文件,文本文件,sharding文件等等。。。
一.数据集Dataset的常用构造方法:
(1)从一个tensor中构造数据集:dataset=tf.data.Dataset.from_tensor_slices(tensor)
(2)利用硬盘上的文件构建数据集:dataset=tf.date.TextLineDataset(filename),#filename可以是多个文件路径组成的文件路径列 表
(3)在图像任务相关的任务重,输入数据通常以TFRecord形式存储,这时可以利用TFRecordDataset来读取数据:
dataset=tf.data.TFRecordDataset(filename)
***filename可以是多个文件路径组成的文件路径列
***tf.data.TFRecordDataset()读取到的每一个TFRecord格式的数据都有不同的feature格式,因此,在利 用tf.data.TFRecordDataset()读取文件以后,还要定义一个parse(record)函数来对读取到的数据进行解析。然后利用map()函数来完成文件解析。具体如下:
def parse(record):
......... #........内容为解析TFRecord格式文件的样式
dataset=dataset.map(parse)
二.读取数据集的三个基本步骤:
3.使用get_next()获取tensor
具体实例如下:
- 定义数据集的构造方法,根据具体上文中数情况,利用上文中对应构造方法来进行定义。
- 定义遍历器,常用的有一下两种:
-
(1)数据集的所有参数已经确定的情况下(文件的路径,文件内容已经确定),经常使用one_shot_iterator。
-
格式如下:iterator=dataset.make_one_shot_iterator()
-
(2)在需要使用placeholder来初始化数据集时,用initializable_iterator来遍历数据集。
-
格式如下:iterator=dataset.make_initializable_iterator()
-
#与one_shot_iterator不同的是,initializable_iterator()在会话中首先要对其进行初始化:sess.run(iterator.initializer,......)
-
import tensorflow as tf def parser(record): features = tf.parse_single_example( record, features={ 'feat1':tf.FixedLenFeature([],tf.int64), 'feat2':tf.FixedLenFeature([],tf.int64) } ) return features['feat1'],features['feat2'] #数据集可以是一个tensor,或者文本文件 #若是tensor,则使用tf.data.from_tensor_slices(input_data) #若是文本文件,则使用tf.data.TextLineDataset(input_files) input_files = ['file1','file2'] dataset = tf.data.TFRecordDataset(input_files) #由于tfrecords读取出来的是二进制数据,需要对每个数据进行解析,得到想要的格式 #这里使用映射函数对每个数据进行解析 dataset = dataset.map(parser) #通过一个迭代器获取数据 iterator = dataset.make_one_shot_iterator() feat1,feat2 = iterator.get_next() with tf.Session() as sess: for i in range(10): print(sess.run([feat1,feat2]))