tensorflow笔记 协调器tf.train.Coordinator

本文介绍了TensorFlow中的Coordinator和QueueRunner在多线程管理中的作用。Coordinator负责同步终止线程和处理异常,而QueueRunner用于启动tensor入队线程,确保数据被正确推入内存供计算。只有调用tf.train.start_queue_runners,数据流图才会开始运行。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

原理

TensorFlow中有两个函数管理Session中的多线程:Coordinator和 QueueRunner。

同一个Session中可以创建多个线程,但所有线程必须能被同步终止,异常必须能被正确捕获并报告。当会话终止的时候, 队列必须能被正确地关闭。

Coordinator用来管理在Session中的多个线程,可以用来同时停止多个工作线程,同时报告异常,当程序捕捉到这个异常后之后就会终止所有线程。

QueueRunner用来启动tensor的入队线程,可以启动多个工作线程将多个tensor推送入文件名称队列中。 只有运行 tf.train.start_queue_runners 后,才会真正把tensor推入内存序列中,供计算单元调用,否则数据流图会处于一直等待状态。

在这里插入图片描述
图片来自Reference[2]

示例

产生测试数据的方法请参考 tensorflow笔记 tfrecord创建及读取

import tensorflow as tf

# k = 1时可以看到Coordinator在队列清空时抛出的异常
k = 0
# 偷懒用了之前的数据,格式: feature是一个1x5向量,label是0或1.

tfrecord_path = 'data.record' 

def _parse_function(example_proto): # 解析函数
    dics = {  
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是标量,一定要在这里说明数组的长度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    parsed_example = tf.parse_single_example(example_proto, dics)
    parsed_example['sample'] = tf.cast(parsed_example['sample'], tf.float32)
    parsed_example['label'] = tf.cast(parsed_example['label'], tf.float32)

    return parsed_example

def read_dataset(tfrecord_path = tfrecord_path):
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    new_dataset = dataset.map(_parse_function)
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000) # 打乱顺序
    batch_dataset = shuffle_dataset.batch(2) # 按batch数量输出
    prefetch_dataset = batch_dataset.prefetch(2000) # 数据提前进入队列,速度会快很多
    iterator = prefetch_dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    return next_element

next_element = read_dataset()
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord) # 文件名开始进入文件名队列和内存
    
    if k == 1:
        while not coord.should_stop():
            print('dataset:', sess.run([next_element['sample'], next_element['label']]))
    else:
        try:
            while not coord.should_stop():
                print('dataset:', sess.run([next_element['sample'], next_element['label']]))
        except tf.errors.OutOfRangeError:
            print("queue is empty")
    
    coord.request_stop()
    coord.join(threads)

'''
k = 0
dataset: [array([[4., 8., 8., 3., 8.],
       [2., 0., 0., 5., 8.]], dtype=float32), array([1., 1.], dtype=float32)]
dataset: [array([[5., 0., 2., 4., 9.],
       [0., 1., 0., 3., 0.]], dtype=float32), array([0., 0.], dtype=float32)]
dataset: [array([[0., 3., 4., 2., 5.],
       [0., 4., 7., 7., 3.]], dtype=float32), array([0., 1.], dtype=float32)]
dataset: [array([[3., 5., 7., 8., 7.],
       [5., 2., 7., 9., 9.]], dtype=float32), array([1., 0.], dtype=float32)]
dataset: [array([[8., 3., 7., 5., 1.],
       [3., 5., 2., 7., 7.]], dtype=float32), array([1., 0.], dtype=float32)]
queue is empty
'''

'''
k = 1
OutOfRangeError (see above for traceback): End of sequence
'''

Reference:

[1] tf.train.input_producer
[2] tensorflow中协调器 tf.train.Coordinator 和入队线程启动器 tf.train.start_queue_runners

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值