tensorflow - 队列、多线程

本文介绍了在TensorFlow中如何使用队列和多线程,重点关注tf.Coordinator和tf.QueueRunner。tf.Coordinator负责协调多线程的停止,提供should_stop、request_stop和join方法。当should_stop返回True时,线程结束。tf.QueueRunner用于启动多个线程处理同一队列,可结合tf.Coordinator管理线程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、总览
  今天为小主们简单得介绍一下tensorflow中有关队列与多线程的基本使用。主要讲解tf.Coordinator 和 tf.QueueRunner.

二、具体介绍
  首先来说一说tf.Coordinator,它的作用是协同多个线程一起停止,并提供了should_stop、request_stop、join三个函数,当should_stop函数的返回值为True时,当前线程退出;每一个启动过的线程都可以调用request_stop函数通知其他所有线程退出,当调用该函数后,should_stop的返回值就便成了True,其余的所有线程都会退出。在启动线程之前,需要先声明一个tf.CoorDinator类,具体代码如下:

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
# author:Dr.Shang

import tensorflow as tf
import numpy as np
import threading
import time


def print_myid(coord, my_id):
    """
    线程中所运行的程序,每隔1秒判断should_stop是否为True,并且打印当前线程的编号
    :param coord: 传入的tf.Coordinator实例
    :param my_id: 每一个线程编号
    :return: none
    """
    # 判断是否需要停止当前线程
    while not coord.should_stop():
        # 随机停止所有线程
        if np.random.rand() < 0.5:
            print("Stoping from id: %d" %  my_id)
            coord.request_stop()
        else:
            # 打印当前线程的ID
            print("Working on id %d" % my_id)
        # 暂停一秒
        time.sleep(1)


# 声明一个tf.train.Coordinator类的实例来协同多个线程
coord = tf.train.Coordinator()
# 创建5个线程
threads = [
    threading.Thread(target=print_myid, args=(coord, i,)) for i in range (5)
]
# 启动线程
for i in threads: i.start()
# waiting for all threads to exit
coord.join(threads)

output may be like:
Working on id 0
Working on id 1
Working on id 2
Stoping from id: 3

  接着,来说一说“tf.QueueRunner”,它主要用于启动多个线程来操作同一个队列,启动的多个线程可以通过刚刚介绍的tf.Coordinator来管理。代码如下:

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
# author:Dr.Shang


import tensorflow as tf

# declare a FIFO queue which less than 100 float elements
queue = tf.FIFOQueue(100, "float")
# define the enqueue operation
enqueue_op = queue.enqueue([tf.random_normal([1])])

# create  multiple queues to run the enqueue op
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)

# add the defined qr into the calculate graph collection of TensorFlow
# if you don't set a collection, it will be added into the Default calculate graph collection
# - tf.GraphKeys.Queue_RUNNERS.
tf.train.add_queue_runner(qr)
# define the dequeue op
out_tensor = queue.dequeue()

with tf.Session() as sess:
    # use tf.train.Coordinator to coordinate started threads
    coord = tf.train.Coordinator()
    # use the tf.train.start_queue_runners to start threads(which are defined in qr) before use tf.train.QueueRunner
    threads = tf.train.start_queue_runners(sess, coord)
    for _ in range(5): print(sess.run(out_tensor))

    coord.request_stop()
    coord.join(threads)

output may be like:
[-1.4820653]
[0.71643615]
[0.7950299]
[-1.2422307]
[2.3382144]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值