tensorflow中的读取数据的队列,不论是对正喜欢Python开发的人来说,还是对已经走入工作中的Python开发工程师来说都是比较难懂的。原因可能之大家对这方面的经验不足吧。下面扣丁学堂Python培训小编就和大家分享一下关于Tensorflow中的tf.train.batch函数的使用。

tensorflow中的读取数据的队列,简单的说,就是计算图是从一个管道中读取数据的,录入管道是用的现成的方法,读取也是。为了保证多线程的时候从一个管道读取数据不会乱吧,所以这种时候读取的时候需要线程管理的相关操作。今天给大家分享一个简单的操作,就是给一个有序的数据,看看读出来是不是有序的,结果发现是有序的,所以直接给代码:
import tensorflow as tf
import numpy as np
def generate_data():
num = 25
label = np.asarray(range(0, num))
images = np.random.random([num, 5, 5, 3])
print('label size :{}, image size {}'.format(label.shape, images.shape))
return label, images
def get_batch_data():
label, images = generate_data()
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
return image_batch, label_batch
image_batch, label_batch = get_batch_data()
with tf.Session() as sess:
sess.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)
i = 0
try:
while not coord.should_stop():
image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
i += 1
for j in range(10):
print(image_batch_v.shape, label_batch_v[j])
except tf.errors.OutOfRangeError:
print("done")
finally:
coord.request_stop()
coord.join(threads)
记得那个slice_input_producer方法,默认是要shuffle的哈。
此外,我想评论一下这段代码。
1:slice_input_producer中有一个参数“ num_epochs”,它控制slice_input_producer方法可以工作多少个epochs。 当此方法运行指定的epochs时,它将报告OutOfRangeRrror。 我认为这对我们控制训练时期很有用。
2:此方法的输出是一个单一图像,我们可以使用tensorflow API操作该单一图像,例如归一化,作物等,然后将此单一图像馈入批处理方法,将一批图像进行训练或测试 将会收到。
tf.train.batch和tf.train.shuffle_batch的区别用法
tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)
[example, label]表示样本和样本标签,这个可以是一个样本和一个样本标签,batch_size是返回的一个batch样本集的样本个数。capacity是队列中的容量。这主要是按顺序组合成一个batch
tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue)
这里面的参数和上面的一样的意思。不一样的是这个参数min_after_dequeue,一定要保证这参数小于capacity参数的值,否则会出错。这个代表队列中的元素大于它的时候就输出乱的顺序的batch。也就是说这个函数的输出结果是一个乱序的样本排列的batch,不是按照顺序排列的。
上面的函数返回值都是一个batch的样本和样本标签,只是一个是按照顺序,另外一个是随机的。
注意:tf.train.batch这个函数的实现是使用queue,queue的QueueRunner被添加到当前计算图的"QUEUE_RUNNER"集合中,所在使用初始化器的时候,需要使用tf.initialize_local_variables(),如果使用tf.global_varialbes_initialize()时,会报: Attempting to use uninitialized value
本文探讨了TensorFlow中tf.train.batch函数的应用,通过实例演示如何处理数据队列,确保多线程环境下数据读取的有序性。文章对比了tf.train.batch与tf.train.shuffle_batch的区别,并介绍了slice_input_producer参数的作用。
740

被折叠的 条评论
为什么被折叠?



