【多线程管理】TensorFlow线程管理器——Coordinator

TensorFlow线程管理实战
本文介绍TensorFlow中的线程管理器Coordinator,并通过实例演示其使用方法。此外,还介绍了如何利用QueueRunner和RandomShuffleQueue改进多线程管理,确保程序在遇到异常时能正确关闭所有线程。
部署运行你感兴趣的模型镜像

TensorFlow线程管理器——Coordinator

  简介

       再现是过程中,有效的运行多个线程可能会更加复杂。线程应该能够正常停止(避免“僵尸”线程,或当一个线程出故障时,要一起关闭所有线程),停止后需要关闭线程队列,还有很多的问题需要解决。
       
        TensorFlow 配备了工具帮我们完成这个过程。 CoordinatorTensorFlow 的一个多线程管理器,用来管理在 Session 中的多个线程,可以用来同时停止多个工作线程。除此之外,它还能向等待所有工作线程终止的程序报告异常,该线程捕获到异常后就会终止所有线程。
       

  代码示例

       可以通过下面的代码学习 Coordinator 的使用方法:

import tensorflow as tf
import threading
import time

gen_random_normal = tf.random_normal(shape=())
queue = tf.FIFOQueue(capacity=100, dtypes=[tf.float32], shapes=())
enque = queue.enqueue(gen_random_normal)

with tf.Session() as sess:
    def add(coord, i):
        try:
            while not coord.should_stop():
                sess.run(enque)
                if i == 9:
                    # 抛出异常做测试
                    raise Exception("Invalid id: !", i)
        except Exception:
            print("Invalid id: !", i)
        finally:
            # 请求关闭所有线程
            coord.request_stop()


    coord = tf.train.Coordinator()
    threads = [threading.Thread(target=add, args=(coord, i)) for i in range(10)]
    coord.join(threads)

    for t in threads:
        t.start()

    print(sess.run(queue.size()))
    time.sleep(0.003)
    print(sess.run(queue.size()))
    time.sleep(0.01)
    print(sess.run(queue.size()))

    #把开启的线程加入主线程,等待threads结束
    coord.join(threads)
    print('所有进程终止!')

  测试结果

	0
	Invalid id: ! 9
	30
	33
	所有进程终止!

       

  使用QueueRunner,RandomShuffleQueue改进

       虽然我们可以创建多个线程重复运行入队操作,但最好使用内置的 tf.train.QueueRunner ,这会完成我们需要偶的动作,且遇到异常会关闭队列
       
       改进代码如下:

import tensorflow as tf

gen_random_normal = tf.random_normal(shape=())
queue = tf.RandomShuffleQueue(capacity=100, dtypes=[tf.float32], min_after_dequeue=1)
enqueue = queue.enqueue(gen_random_normal)

qr = tf.train.QueueRunner(queue, [enqueue] * 4)
coord = tf.train.Coordinator()

with tf.Session() as sess:

    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)

    for i in range(5):
        my_dequeue = sess.run(queue.dequeue())
        print(my_dequeue)

    coord.request_stop()
    coord.join(enqueue_threads)

       在这个例子中,我们使用 tf.RandomShuffleQuene 而不是FIFO队列。RandomShuffleQuene 就是一个带有出队操作的队列,它一随机的顺序弹出项。当使用随机梯度下降优化来训练深度神经网络使,这是非常有用的,当然我们需要对数据进行随机排序。
       
       min_after_dequeue参数指定在调用出队操作之后将保留在队列中的最小数目——更大的数字需要更好的混合(随机采样),也需要更多的内存。
       
       输出结果:

	-0.5699217
	0.23737773
	0.018851127
	0.8081313
	1.6825532
	

       

本文示例参考《TensorFlow学习指南——深度学习系统构建详解》第八章第三节

       
       欢迎各位大佬交流讨论!

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值