Tensorflow中的tf.train.batch函数的使用

本文探讨了TensorFlow中tf.train.batch函数的应用,通过实例演示如何处理数据队列,确保多线程环境下数据读取的有序性。文章对比了tf.train.batch与tf.train.shuffle_batch的区别,并介绍了slice_input_producer参数的作用。

tensorflow中的读取数据的队列,不论是对正喜欢Python开发的人来说,还是对已经走入工作中的Python开发工程师来说都是比较难懂的。原因可能之大家对这方面的经验不足吧。下面扣丁学堂Python培训小编就和大家分享一下关于Tensorflow中的tf.train.batch函数的使用。

tensorflow中的读取数据的队列,简单的说,就是计算图是从一个管道中读取数据的,录入管道是用的现成的方法,读取也是。为了保证多线程的时候从一个管道读取数据不会乱吧,所以这种时候读取的时候需要线程管理的相关操作。今天给大家分享一个简单的操作,就是给一个有序的数据,看看读出来是不是有序的,结果发现是有序的,所以直接给代码:

import tensorflow as tf
import numpy as np

def generate_data():
    num = 25
    label = np.asarray(range(0, num))
    images = np.random.random([num, 5, 5, 3])
    print('label size :{}, image size {}'.format(label.shape, images.shape))
    return label, images

def get_batch_data():
    label, images = generate_data()
    images = tf.cast(images, tf.float32)
    label = tf.cast(label, tf.int32)
    input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
    image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
    return image_batch, label_batch

image_batch, label_batch = get_batch_data()
with tf.Session() as sess:
    sess.run(tf.initialize_local_variables())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    i = 0
    try:
        while not coord.should_stop():
            image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
            i += 1
            for j in range(10):
                print(image_batch_v.shape, label_batch_v[j])
    except tf.errors.OutOfRangeError:
        print("done")
    finally:
        coord.request_stop()
    coord.join(threads)

记得那个slice_input_producer方法,默认是要shuffle的哈。

此外,我想评论一下这段代码。

1:slice_input_producer中有一个参数“ num_epochs”,它控制slice_input_producer方法可以工作多少个epochs。 当此方法运行指定的epochs时,它将报告OutOfRangeRrror。 我认为这对我们控制训练时期很有用。
2:此方法的输出是一个单一图像,我们可以使用tensorflow API操作该单一图像,例如归一化,作物等,然后将此单一图像馈入批处理方法,将一批图像进行训练或测试 将会收到。

tf.train.batch和tf.train.shuffle_batch的区别用法

tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)

[example, label]表示样本和样本标签,这个可以是一个样本和一个样本标签,batch_size是返回的一个batch样本集的样本个数。capacity是队列中的容量。这主要是按顺序组合成一个batch

tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue)

这里面的参数和上面的一样的意思。不一样的是这个参数min_after_dequeue,一定要保证这参数小于capacity参数的值,否则会出错。这个代表队列中的元素大于它的时候就输出乱的顺序的batch。也就是说这个函数的输出结果是一个乱序的样本排列的batch,不是按照顺序排列的。

上面的函数返回值都是一个batch的样本和样本标签,只是一个是按照顺序,另外一个是随机的。

注意:tf.train.batch这个函数的实现是使用queue,queue的QueueRunner被添加到当前计算图的"QUEUE_RUNNER"集合中,所在使用初始化器的时候,需要使用tf.initialize_local_variables(),如果使用tf.global_varialbes_initialize()时,会报: Attempting to use uninitialized value 

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值