1. Overview
Today we take a quick look at the basics of queues and multithreading in TensorFlow, focusing on tf.train.Coordinator and tf.train.QueueRunner.
2. Details
Let's start with tf.train.Coordinator. Its job is to coordinate the stopping of multiple threads, and it provides three methods: should_stop, request_stop, and join. When should_stop returns True, the current thread should exit. Any started thread can call request_stop to ask all other threads to exit; once it has been called, should_stop returns True and the remaining threads stop. Before starting the threads, declare a tf.train.Coordinator instance, as in the following code:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# author:Dr.Shang
import tensorflow as tf
import numpy as np
import threading
import time


def print_myid(coord, my_id):
    """
    Body of each worker thread: every second it checks whether should_stop
    returns True and prints the ID of the current thread.
    :param coord: the shared tf.train.Coordinator instance
    :param my_id: the ID of this thread
    :return: None
    """
    # Keep running until the coordinator signals that the thread should stop
    while not coord.should_stop():
        # Randomly decide to stop all threads
        if np.random.rand() < 0.5:
            print("Stopping from id: %d" % my_id)
            coord.request_stop()
        else:
            # Print the ID of the current thread
            print("Working on id %d" % my_id)
        # Sleep for one second
        time.sleep(1)


# Declare a tf.train.Coordinator instance to coordinate the threads
coord = tf.train.Coordinator()
# Create 5 threads
threads = [
    threading.Thread(target=print_myid, args=(coord, i)) for i in range(5)
]
# Start the threads
for t in threads:
    t.start()
# Wait for all threads to exit
coord.join(threads)
The output may look like:
Working on id 0
Working on id 1
Working on id 2
Stopping from id: 3
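As a side note, tf.train.Coordinator can also propagate exceptions raised inside worker threads back to the main thread. Below is a minimal sketch, not part of the example above, with an illustrative worker function: its body is wrapped in coord.stop_on_exception(), so an uncaught exception triggers request_stop and is re-raised by coord.join.

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import threading
import time
import tensorflow as tf


def worker(coord, my_id):
    # stop_on_exception() records any uncaught exception in the coordinator,
    # calls request_stop(), and lets coord.join() re-raise it in the main thread
    with coord.stop_on_exception():
        if my_id == 3:
            raise RuntimeError("worker %d failed" % my_id)  # simulated failure
        while not coord.should_stop():
            time.sleep(0.1)  # placeholder for real work


coord = tf.train.Coordinator()
threads = [threading.Thread(target=worker, args=(coord, i)) for i in range(5)]
for t in threads:
    t.start()
try:
    coord.join(threads)  # re-raises the RuntimeError reported by worker 3
except RuntimeError as e:
    print("Caught from coordinator:", e)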
Next, tf.train.QueueRunner. It is mainly used to start multiple threads that operate on the same queue, and those threads can be managed with the tf.train.Coordinator just introduced. The code is as follows:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# author:Dr.Shang
import tensorflow as tf

# Declare a FIFO queue that holds at most 100 float elements
queue = tf.FIFOQueue(100, "float")
# Define the enqueue operation
enqueue_op = queue.enqueue([tf.random_normal([1])])
# Create a QueueRunner that uses 5 threads, each repeatedly running enqueue_op
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
# Add qr to a graph collection of TensorFlow; if no collection is specified,
# it is added to the default collection tf.GraphKeys.QUEUE_RUNNERS
tf.train.add_queue_runner(qr)
# Define the dequeue op
out_tensor = queue.dequeue()

with tf.Session() as sess:
    # Use tf.train.Coordinator to coordinate the started threads
    coord = tf.train.Coordinator()
    # tf.train.start_queue_runners starts the threads registered via
    # tf.train.add_queue_runner; without it nothing runs the enqueue op
    # and the dequeue below would block forever
    threads = tf.train.start_queue_runners(sess, coord)
    for _ in range(5):
        print(sess.run(out_tensor))
    coord.request_stop()
    coord.join(threads)
The output may look like:
[-1.4820653]
[0.71643615]
[0.7950299]
[-1.2422307]
[2.3382144]
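For completeness, a tf.train.QueueRunner can also be started directly, without going through a graph collection, via its create_threads method. The sketch below is an illustrative variation of the example above (same queue and enqueue op), not part of the original walkthrough:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import tensorflow as tf

queue = tf.FIFOQueue(100, "float")
enqueue_op = queue.enqueue([tf.random_normal([1])])
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
out_tensor = queue.dequeue()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # create_threads builds and (with start=True) starts the enqueue threads
    # for this single QueueRunner, bound to the given session and coordinator
    threads = qr.create_threads(sess, coord=coord, start=True)
    for _ in range(5):
        print(sess.run(out_tensor))
    coord.request_stop()
    coord.join(threads)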