Tensorflow 提供队列读取数据构造 batch 的方法。队列可以加快数据度的速度,先进先出等等的基本特性就不赘述了,直接介绍tensorflow 是如何做的。
首先,用给的文件名列表生成一个文件名队列(file name queue)。
然后,从文件名队列中,按照指定的方式按个读取单个数据。
最后,再把读出的数据放入另外一个队列中,该队列 dequeue 用以构造 batch 。
可以参考下面的代码:
def read_my_file_format(filename_queue):
reader = tf.SomeReader()
key, record_string = reader.read(filename_queue)
example, label = tf.some_decoder(record_string)
processed_example = some_processing(example)
return processed_example, label
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
# min_after_dequeue defines how big a buffer we will randomly sample
# from -- bigger means better shuffling but slower start up and more
# memory used.
# capacity must be larger than min_after_dequeue and the amount larger
# determines the maximum we will prefetch. Recommendation:
# min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch
tf.train.string_input_producer(string_tensor, num_epochs=none, shuffle=true, seed=none, capacity=32, shared_name=none, name=none, cancel_op=none) 返回一个 queue
string_tensor: 存储的是要构造 queue 的一组文件名
num_epochs: 如果没有指定,则在这些文件名中一直循环不停。若指定,则在每一个 string 都被生成指定次数后产 生 out_of_range 错误
shuffle: 是否开启乱序,默认开启
tf.SomeReader() 有多种类型,下面只介绍一种。
tf.FixedLengthRecordReader(record_bytes, hearder_bytes, footer_bytes, name=none) 返回一个 reader
record_bytes: 实际有效数据长度,单位为 byte
hearder_bytes: 头
footer_bytes: 尾
tf.FixedLengthRecordReader.read(queue, name=none) 返回单个 (key, value)-- (名字,值)
queue: 输入 string 队列
tf.train.shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=none, enqueue_many=false, shapes=none, allow_smaller_final_batch=false, shared_name=none, name=none)
tensors: 可以输入多个 tensor ,若输入的形状是 [x, y, z] ,输出为 [batch_size, x, y, z]
batch_size: batch 大小
capacity: 控制队列的总大小
min_after_dequeue: 队列维持的最小长度
num_threads: 线程数,可设置多个
更多关于 reading data 和 queue 的介绍参看官方教程。