Tensorflow数据读取方式
关于tensorflow(简称TF)数据读取方式,官方给出了三种:
供给数据(Feeding):在TF程序运行的每一步,让python代码来供给数据。
从文件读取数据:在TF图的起始,让每一个管线从文件中读取数据。
预加载数据 :在TF图中定义常量或者变量来保存数据(使用数据量较小的情况)。
的
一、供给数据
TF的数据供给机制允许在TF运算图中将数据注入发到任一张量(tensor)。
通过run(),eval()函数输入到feed_dict()中,如:
with tf.Session() as sess:
sess.run(init)
......
train_accuracy = accuracy.eval(feed_dict={x_input: batch_x, y_labels: batch_y, keep_prob: 1.0})
......
train_step.run(feed_dict={x_input: batch_x, y_labels: batch_y, keep_prob: 0.5})
......
上述代码中,x_input,y_labels为张量。虽然可以使用常量和变量来代替张量,但是在TF中,最好还是使用 op节点
x_input = tf.placeholder(tf.float32, [None, 32,32,3],name='Mul')
y_labels = tf.placeholder(tf.float32,[None,62])
上述代码声明了x_input,y_labels张量,但是张量未被初始化,也不包含数据。
二、从文件读取数据
Kaggle比赛中最常见的数据格式是CSV文件,以读取CSV文件为例进行说明。直接上代码:
(代码来至http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html#Feeding)
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.concat(0, [col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col5])
coord.request_stop()
coord.join(threads)
首先将文件名列表交给tf.train.string_input_producer函数来生成一个先进先出的队列。
文件列表的表达方式为[“file0.csv”, “file1.csv”],也可以用[(“file%d” % i) for i in range(2)] 或者使用tf.train.match_filenames_once函数来生成。
使用TF读取CSV文件需要使用tf.TextLineReader和tf.decode_csv两个函数,每次read操作都会从文件中读取一行内容,而decode_csv会将这一行内容转为张量列表,如果输入的参数有缺失,record_default参数会根据张量的类型来设置默认值。
在调用run和eval去执行read时,需要使用tf.train_start_queue_runners来将文件名填充到队列。否则,read操作会被阻塞,直到文件名队列中有值为止。(从二进制文件中读取固定长度记录,可以使用tf.FixedLengthRecordReader的tf.decode_raw操作,tf.decode_raw可以将一个字符串转换成Uint8的张量)
三、预加载数据
当数据量较小时,一般选择直接将数据加载到内存中,然后再分Batch的输入到网络中。
这边举个简单的例子:使用python读取图片数据
不同类别的数据存储在四个不同的文件夹下。使用如下代码进行读取。
def load_data(data_dir):
directories = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]
labels = []
images = []
for d in directories:
label_dir = os.path.join(data_dir, d)
file_names = [os.path.join(label_dir, f) for f in os.listdir(label_dir) if f.endswith(".jpg")]
for f in file_names:
img = skimage.data.imread(f)
img299 = skimage.transform.resize(img,(299,299))
images.append(img299)
labels.append(int(d))
return images, labels
调用上述代码得到的是list数据,需要调用如下代码转变成array类型。
images299 = [image for image in images]
images_x = np.array(images299)
labels_x= np.array(labels)
在数据量较大时,预加载数据就不现实了,因为太耗内存。所以这时就是使用上诉三种方法中的第二种:从文件读取数据。
如果要读取图片数据,可以将其转换成TF中的标准支持格式tfrecords,它是一种二进制文件,能够很好的利用内存,且方便复制和移动。
直接上代码(主要参考:http://blog.youkuaiyun.com/u012759136/article/details/52232266)
import os
import tensorflow as tf
from PIL import Image
cwd = os.getcwd()
writer = tf.python_io.TFRecordWriter("train.tfrecords")
classes = ['0','1','2','3']
for index, name in enumerate(classes):
class_path = cwd + '\\' + name + '\\'
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((100, 100))
img_raw = img.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()