最近在学TensorFlow,在此把学到的东西记录一下。
在学习别人代码时,遇到多线程训练的问题,代码截取部分如下:
image, label = ReadMyOwnData.read_and_decode("dog_and_cat_train.tfrecords")
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
example = np.zeros((batch_size,128,128,3))
l = np.zeros((batch_size,1))
try:
for i in range(50):
for epoch in range(batch_size):
example[epoch], l[epoch] = sess.run([image,label])#在会话中取出image和label
train_step.run(feed_dict={x: example, y_: l, keep_prob: 0.5})
print(accuracy.eval(feed_dict={x: example, y_: l, keep_prob: 1.0})) #eval函数类似于重新run一遍,验证,同时修正
except tf.errors.OutOfRangeError:
print('done!')
finally:
coord.request_stop()
coord.join(threads)
这边用到了多线程,之前没有接触过,于是查询了网上的一些教程,参考网址如下,第一个为原文网站,后极客学院做了翻译
http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/threading_and_queues.html
一个典型的输入结构:是使用一个RandomShuffleQueue
来作为模型训练的输入:
- 多个线程准备训练样本,并且把这些样本推入队列。
- 一个训练线程执行一个训练操作,此操作会从队列中移除最小批次的样本(mini-batches)。
TensorFlow的Session
对象是可以支持多线程的,因此多个线程可以很方便地使用同一个会话(Session)并且并行地执行操作。然而,在Python程序实现这样的并行运算却并不容易。所有线程都必须能被同步终止,异常必须能被正确捕获并报告,回话终止的时候, 队列必须能被正确地关闭。
所幸TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。从设计上这两个类必须被一起使用。Coordinator
类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常。QueueRunner
类用来协调多个工作线程同时将多个张量推入同一个队列中。
Coordinator
Coordinator类用来帮助多个线程协同工作,多个线程同步终止。 其主要方法有:
should_stop()
:如果线程应该停止则返回True。request_stop(<exception>)
: 请求该线程停止。join(<list of threads>)
:等待被指定的线程终止。
首先创建一个Coordinator
对象,然后建立一些使用Coordinator
对象的线程。这些线程通常一直循环运行,一直到should_stop()
返回True时停止。 任何线程都可以决定计算什么时候应该停止。它只需要调用request_stop()
,同时其他线程的should_stop()
将会返回True
,然后都停下来。Coordinator可以管理线程去做不同的事情,还支持捕捉和报告异常。
QueueRunner
QueueRunner
类会创建一组线程, 这些线程可以重复的执行Enquene操作, 他们使用同一个Coordinator来处理线程同步终止。此外,一个QueueRunner会运行一个closer thread,当Coordinator收到异常报告时,这个closer thread会自动关闭队列。
您可以使用一个queue runner,来实现上述结构。 首先建立一个TensorFlow图表,这个图表使用队列来输入样本。增加处理样本并将样本推入队列中的操作。增加training操作来移除队列中的样本。
在Python的训练程序中,创建一个QueueRunner
来运行几个线程, 这几个线程处理样本,并且将样本推入队列。创建一个Coordinator
,让queue runner使用Coordinator
来启动这些线程,创建一个训练的循环, 并且使用Coordinator
来控制QueueRunner
的线程们的终止。
异常处理
通过queue runners启动的线程不仅仅只处理推送样本到队列。他们还捕捉和处理由队列产生的异常,包括OutOfRangeError
异常,这个异常是用于报告队列被关闭。 使用Coordinator
的训练程序在主循环中必须同时捕捉和报告异常。 下面是对上面训练循环的改进版本。