1.用 tf.train.match_filenames_once() 获取存储训练数据的文件列表。(数据已转为 TFRecord 格式的多个文件)
2.用 tf.train.string_input_producer() 创建输入文件队列,可以将输入文件顺序随机打乱(shuffle = True)
3.用 tf.TFRecordReader() 读取TFrecords文件中的数据。
4.用 tf.parse_single_example() 解析数据
5.对数据进行解码及预处理(使用图像处理函数)
6.用 tf.train.shuffle_batch() 将数据组合成 batch
7.将 batch 用于训练。
在这一部分中,将自己的数据转为TFrecords文件还不是很会。下面贴一个书上的示例:
import tensorflow as tf
import numpy as np
files = tf.train.match_filenames_once("/path/to/data.tfrecords-*") # 此函数获取一个符合正则表达式的所有文件
file_queue = tf.train.string_input_producer(files, shuffle=False)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(file_queue)
# 解析数据
features = tf.parse_single_example(serialized_example, features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channel': tf.FixedLenFeature([], tf.int64),
})
image, label = features['image'], features['label']
height, width = features['heiht'], features['width']
channel = features['channel']
image_decode = tf.decode_raw(image, tf.uint8)
image_decode.set_shape([height, width, channel])
# 假设神经网络输入层的图片大小为300
image_size = 300
distorted_image = preprocess_for_train(image_decode, image_size, image_size, None)
# 将处理过后的图像和标签数据通过tf.train.shuffle_batch 整理成神经网络训练时需要的batch
min_after_dequeue = 1000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_size
# tf.train.shuffle_batch函数的入队操作就是数据处理以及预处理的过程
image_batch, label_batch = tf.train.shuffle_batch([distorted_image, label], batch_size, capacity, min_after_dequeue, num_threads=5)
# 定义神经网络的优化结构以及优化过程
logit = inference(image_batch)
loss = cal_loss(loss, label_batch)
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
#启动线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord=coord)
for i in range(TRAINING_EPOCHS):
_, loss = sess.run([train_op, loss])
# -------停止线程
coord.request_stop()
coord.join(threads)
本文介绍如何使用TensorFlow从TFRecord文件加载数据,并通过多个步骤处理这些数据以供神经网络训练使用。包括获取文件列表、创建文件队列、读取文件、解析数据、图像预处理、组织数据批(batch)等。
2539

被折叠的 条评论
为什么被折叠?



