from os import listdir
from os.path import isfile,join
import tensorflow as tf
import matplotlib.pyplot as plt
dataset_path = 'I:/raincode/TensorFlow/CVPR17_training_code/TrainData/input'#路径
with tf.Session() as sess:
filenames = [join(dataset_path,f) for f in listdir(dataset_path) if isfile(join(dataset_path,f))]
print('number of images:',len(filenames))
filename_queue = tf.train.string_input_producer(filenames,shuffle=False,num_epochs=1)
reader = tf.WholeFileReader()
name,img_bytes = reader.read(filename_queue)
image = tf.image.decode_jpeg(img_bytes,channels=3)
dataname = tf.train.batch([name],2,dynamic_pad=True)
init = tf.global_variables_initializer()
inits = tf.local_variables_initializer()
sess.run([init,inits])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord = coord)
try:
while not coord.should_stop():
print(sess.run(dataname))
except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
finally:
coord.request_stop()
coord.join(threads)
结果:
首先产生一个存储路径下所有文件名的List,然后用它作为输入产生字符串队列,这是pipeline的前端。得到了出队的字符串序列,我们可以使用tensorflow中的filereader,将文件内容读取出来,reader.read函数返回{文件名,文件内容}键值对,文件内容即是我们需要处理的对象(这里为了直观,我们使用了文件名作为输出),这个阶段是pipeline的中端。最后得到了文件名序列,再用batch函数提取一个batch,这个作为pipeline的末端。这样一个程序就完成了。