原理
TensorFlow中有两个函数管理Session中的多线程:Coordinator和 QueueRunner。
同一个Session中可以创建多个线程,但所有线程必须能被同步终止,异常必须能被正确捕获并报告。当会话终止的时候, 队列必须能被正确地关闭。
Coordinator用来管理在Session中的多个线程,可以用来同时停止多个工作线程,同时报告异常,当程序捕捉到这个异常后之后就会终止所有线程。
QueueRunner用来启动tensor的入队线程,可以启动多个工作线程将多个tensor推送入文件名称队列中。 只有运行 tf.train.start_queue_runners 后,才会真正把tensor推入内存序列中,供计算单元调用,否则数据流图会处于一直等待状态。
图片来自Reference[2]
示例
产生测试数据的方法请参考 tensorflow笔记 tfrecord创建及读取
import tensorflow as tf
# k = 1时可以看到Coordinator在队列清空时抛出的异常
k = 0
# 偷懒用了之前的数据,格式: feature是一个1x5向量,label是0或1.
tfrecord_path = 'data.record'
def _parse_function(example_proto): # 解析函数
dics = {
'sample': tf.FixedLenFeature([5], tf.int64), # 如果不是标量,一定要在这里说明数组的长度
'label': tf.FixedLenFeature([], tf.int64)
}
parsed_example = tf.parse_single_example(example_proto, dics)
parsed_example['sample'] = tf.cast(parsed_example['sample'], tf.float32)
parsed_example['label'] = tf.cast(parsed_example['label'], tf.float32)
return parsed_example
def read_dataset(tfrecord_path = tfrecord_path):
dataset = tf.data.TFRecordDataset(tfrecord_path)
new_dataset = dataset.map(_parse_function)
shuffle_dataset = new_dataset.shuffle(buffer_size=20000) # 打乱顺序
batch_dataset = shuffle_dataset.batch(2) # 按batch数量输出
prefetch_dataset = batch_dataset.prefetch(2000) # 数据提前进入队列,速度会快很多
iterator = prefetch_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
return next_element
next_element = read_dataset()
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord) # 文件名开始进入文件名队列和内存
if k == 1:
while not coord.should_stop():
print('dataset:', sess.run([next_element['sample'], next_element['label']]))
else:
try:
while not coord.should_stop():
print('dataset:', sess.run([next_element['sample'], next_element['label']]))
except tf.errors.OutOfRangeError:
print("queue is empty")
coord.request_stop()
coord.join(threads)
'''
k = 0
dataset: [array([[4., 8., 8., 3., 8.],
[2., 0., 0., 5., 8.]], dtype=float32), array([1., 1.], dtype=float32)]
dataset: [array([[5., 0., 2., 4., 9.],
[0., 1., 0., 3., 0.]], dtype=float32), array([0., 0.], dtype=float32)]
dataset: [array([[0., 3., 4., 2., 5.],
[0., 4., 7., 7., 3.]], dtype=float32), array([0., 1.], dtype=float32)]
dataset: [array([[3., 5., 7., 8., 7.],
[5., 2., 7., 9., 9.]], dtype=float32), array([1., 0.], dtype=float32)]
dataset: [array([[8., 3., 7., 5., 1.],
[3., 5., 2., 7., 7.]], dtype=float32), array([1., 0.], dtype=float32)]
queue is empty
'''
'''
k = 1
OutOfRangeError (see above for traceback): End of sequence
'''
Reference:
[1] tf.train.input_producer
[2] tensorflow中协调器 tf.train.Coordinator 和入队线程启动器 tf.train.start_queue_runners