tensorflow数据读取机制
tensorflow
中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算。
具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程负责计算任务,所需数据直接从内存队列中获取。
tf在内存队列之前,还设立了一个文件名队列,文件名队列存放的是参与训练的文件名,要训练 N个epoch,则文件名队列中就含有N个批次的所有文件名。 示例图如下:
在 N N N个epoch
的文件名最后是一个结束标志,当tf
读到这个结束标志的时候,会抛出一个 OutofRange
的异常,外部捕获到这个异常之后就可以结束程序了。而创建tf
的文件名队列就需要使用到tf.train.slice_input_producer
函数。
tf.train.slice_input_producer官方说明
tf.train.slice_input_producer
解释
tf.train.slice_input_producer
是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。
- 第一个参数
tensor_list
:包含一系列tensor的列表,表中tens