前面其实对输入tensorflow数据集的构造和输入那一块的认知比较模糊,所以抽了点时间解析了一下官方代码。
大概顺序如下:
1.输入所需图片的地址,然后放到tf.train.string_input_producer中进行管理,注意tf.train.string_input_producer中
只是图片的地址,不是图片的值。
2.然后用各种读取器读取地址中的数据(图片,标签),用的是
reader=tf.FixedLengthRecordReader(record_bytes=record_bytes),这是按字节来读的
还有
reader = tf.WholeFileReader()这是直接读完的
前者适用于一个文件中有很多组数据的,后者适用于一个文件一个数据组的
然后就是统一的流程:
result.key, value = reader.read(filename_queue)
从这里读出来的value其实数据形态是string类型的,因为trfrecord就是这么转换保存的,所以还要把读出来的数据进行一些处理(比如把label和image分开,然后转换成各自所需的数据类型)
record_bytes = tf.decode_raw(value, tf.uint8)
3.为了增强图像的稳健性,可以对图像进行一系列的操作,比如旋转,亮度,对比度,区域裁剪等操作。对了,一定要注意标准化,不然准确率上不去!!!!!
4.此时的图像都是tensor值得形式了,然后就是放入tf.train.batch或者tf.train.shuffle_batch(打乱)中进行管理,以便数据提取。也就是在这里才引入线程的
images, label_batch = tf.train.batch(
[image, label],
batch_size=batch_size,
num_threads=num_preprocess_threads,#线程数
capacity=min_queue_examples + 3 * batch_size)