tensorflow编程一些需要知道的 - 3

本文介绍了TensorFlow中实现异步数据读取的队列机制,包括FIFOQueue和RandomShuffleQueue等组件的使用方法。并通过实例展示了如何利用tf.Coordinator和tf.QueueRunner来管理多线程操作,提高训练效率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


做训练时,我们往往要处理大批量的数据,这时如果有个可以异步读取的方式,那么处理程序会更加灵活和高效。FIFOQueue 、RandomShuffleQueue 便是tensorflow提供的一些通过队列来做异步数据存取的方法,并且它是多线程的(tf.Session对象就是多线程的)。这个框架如下图




由于整个程序是多线程的,因此我们可以在同一个session里并行跑多个ops。为此,tf一共了tf.Coordinator 和 tf.QueueRunner来帮助我们编写这样的程序,使得我们可以方便地处理多个线程的启动、停止、异常捕获等。Coordinator类让我们的多个线程同时停止、把异常抛给调用的地方。QueueRunner类针对enqueue ops创建多个线程,而这些线程可以通过Coordinator类来同时停止,并当有异常发生时通过一个关闭线程来自动关闭这些线程。下面是一个对上述的示例

import tensorflow as tf

example = ...ops to create one example...
#Step 1. 为输入构建一个queue,并将这些输入enqueue进去
queue = tf.RandomShuffleQueue(...)
enqueue_op = queue.enqueue(example)

#Step 2. 从queue里dequeue样本进行训练 
inputs = queue.dequeue_many(batch_size)
train_op = ...use 'inputs' to build the training part of the graph...

#Note: 以上可以通过tf.train.string_input_producer来完成

#Step 3. 通过4个线程来并发enqueue_op样本
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)

# Launch the graph.
sess = tf.Session()
#构建Coordinator, 启动QueueRunner
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
# 跑训练迭代。这些线程除了跑enqueue_op, 而且捕获并处理异常, 通过coordinator来在主循环中处理。
try:
    for step in xrange(1000000):
        if coord.should_stop():
            break
        sess.run(train_op)
except Exception, e:
    # 将异常抛给coordinator,通知线程停止
    coord.request_stop(e)
finally:
    coord.request_stop()
    coord.join(enqueue_threads)





评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

mao_feng

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值