TensorFlow程序读取数据一共有3种方法:
- 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
- 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
- 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
目录
数据读取
一文件读取流程
1.构造文件名队列
将文件名列表交给tf.train.string_input_producer
函数.string_input_producer
来生成一个先入先出的队列, 文件阅读器会需要它来读取数据
tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, name=None)
string_tensor:含有文件名+路径的1阶张量
num_epochs:过几遍数据,默认过无限编
Returns: 文件队列
2.读取与解码
根据你的文件格式, 选择对应的文件阅读器, 然后将文件名队列提供给阅读器的read
方法。阅读器的read
方法会输出一个key来表征输入的文件和其中的纪录(对于调试非常有用),同时得到一个字符串标量, 这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。
key, value =阅读器.read(queue, name=None)
key
: 文件名value
: 样本
文本格式/CSV 文件
读取 tf.TextLineReader()
解码 tf.decode_csv()
图片格式
读取 tf.WholeFileReader()
解码 不同的图片格式有不同的解码方式
tf.image.decode_png()
tf.image.decode_jpeg()
二进制
读取 tf.FixedLengthRecordReader()
TFRecords文件
3.批处理
在数据输入管线的末端, 我们需要有另一个队列来执行输入样本的训练,评价和推理。因此我们使用f.train.batch、
tf.train.shuffle_batch
函数来对队列中的样本进行乱序处理
tf.train.batch(tensor_list, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, name=None)
功能:读取指定大小(个数)的张量
tensor_list
: The list of tensors to enqueue.batch_size
: 从列表中读取的批处理器大小num_threads
: 进入队列的线程数capacity
: An integer. The maximum number of elements in the queue.enqueue_many
: Whether each tensor intensor_list
is a single example.shapes
: (Optional) The shapes for each example. Defaults to the inferred shapes fortensor_list
.name
: (Optional) A name for the operations.
A list of tensors with the same number and types as tensor_list
.
4.手动开启线程
tf.train.start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection='queue_runners')
收集图中所有队列线程,同时默认启动线程
Args:
sess
:Session
used to run the queue ops. Defaults to the default session.coord
: OptionalCoordinator
for coordinating the started threads.daemon
: Whether the threads should be marked asdaemons
, meaning they don't block program exit.start
: Set toFalse
to only create the threads, not start them.collection
: AGraphKey
specifying the graph collection to get the queue runners from. Defaults toGraphKeys.QUEUE_RUNNERS
.
Returns:
A list of threads.
class tf.train.Coordinator()
线程协调员,对线程进行管理和协调
request_stop() 请求停止
should_stop() 询问是否结束
join(threads, stop_grace_period_secs=120)回收线程
Returns: 线程协调器实例
http://www.tensorfly.cn/tfdoc/images/AnimatedFileQueues.gif