tensorflow使用range_input_producer多线程读取数据

本文介绍如何使用TensorFlow中的range_input_producer方法创建队列,并通过调整参数num_epochs来控制队列元素的数量,实现数据集的批次划分。具体展示了不同参数设置下,如何从文本文件中读取数据并进行批次划分。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

转载:http://blog.youkuaiyun.com/lyg5623/article/details/69387917
有少许代码修改

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()  
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])  

原理解析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2…
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

代码:

# encoding: UTF-8
import tensorflow as tf
import codecs

BATCH_SIZE = 6
NUM_EXPOCHES = 5


def input_producer():
    array = codecs.open("test.txt").readlines()
    print(array)
    array = list(map(lambda line: line.strip('\n'), array))
    print(array)
    i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
    inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
    return inputs


class Inputs(object):
    def __init__(self):
        self.inputs = input_producer()


def main(*args, **kwargs):
    inputs = Inputs()
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    sess = tf.Session()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(init)
    try:
        index = 0
        while not coord.should_stop() and index<10:
            datalines = sess.run(inputs.inputs)
            index += 1
            print("step: %d, batch data: %s" % (index, str(datalines)))
    except tf.errors.OutOfRangeError:
        print("Done traing:-------Epoch limit reached")
    except KeyboardInterrupt:
        print("keyboard interrput detected, stop training")
    finally:
        coord.request_stop()
    coord.join(threads)
    sess.close()
    del sess

if __name__ == "__main__":
    main()

输出:

step: 1, batch data: [b'1' b'2' b'3' b'4' b'5' b'6']
step: 2, batch data: [b'7' b'8' b'9' b'10' b'11' b'12']
step: 3, batch data: [b'13' b'14' b'15' b'16' b'17' b'18']
step: 4, batch data: [b'19' b'20' b'21' b'22' b'23' b'24']
step: 5, batch data: [b'25' b'26' b'27' b'28' b'29' b'30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: [b'1' b'2' b'3' b'4' b'5' b'6']
step: 2, batch data: [b'7' b'8' b'9' b'10' b'11' b'12']
step: 3, batch data: [b'13' b'14' b'15' b'16' b'17' b'18']
step: 4, batch data: [b'19' b'20' b'21' b'22' b'23' b'24']
step: 5, batch data: [b'25' b'26' b'27' b'28' b'29' b'30']
step: 6, batch data: [b'1' b'2' b'3' b'4' b'5' b'6']
step: 7, batch data: [b'7' b'8' b'9' b'10' b'11' b'12']
step: 8, batch data: [b'13' b'14' b'15' b'16' b'17' b'18']
step: 9, batch data: [b'19' b'20' b'21' b'22' b'23' b'24']
step: 10, batch data: [b'25' b'26' b'27' b'28' b'29' b'30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6  
     [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]  

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值