TensorFlow：使用 TFRecord 构建自己的数据集并读取
参考文章:
http://blog.csdn.net/freedom098/article/details/56011858
以及优酷上 Kevin 大神的视频
目标:1、将自己的数据集以TFRecord格式存储。
2、从TFRecord中读取数据,并使用画图工具,以图片形式展现。
以一张图片为例：
一、将图片存储为 TFRecord 文件
# Build an int64 Feature for a tf.train.Example.
def _int64_feature(value):
    """Return a tf.train.Feature holding `value` as an Int64List.

    `value` may be a single integer or a list of integers; anything that is
    not a list is treated as a single element.
    """
    values = value if isinstance(value, list) else [value]
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)
# Build a bytes Feature for a tf.train.Example.
def _byte_feature(value):
    """Return a tf.train.Feature whose BytesList contains exactly `value`."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# Store image files (their encoded bytes, unchanged) and labels into one TFRecord file.
def convert_to_tfrecord(images, labels, save_dir, name):
    """Convert all images and labels to one TFRecord file.

    Each image file is stored as its file bytes (still PNG-encoded) under the
    'image_raw' key, together with its integer label under 'label'. Files
    that cannot be read are reported and skipped.

    Args:
        images: list of image file paths (strings).
        labels: list of labels, each convertible with int(); must have the
            same length as ``images``.
        save_dir: directory to write the file into, e.g. '/home/folder1/'.
        name: base name of the TFRecord file, e.g. 'train'; the output is
            ``<save_dir>/<name>.tfrecords``.

    Returns:
        None.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length.

    Note:
        Converting needs some time, be patient...
    """
    import os  # local import: no module-level import block is visible here

    # os.path.join also works when save_dir has no trailing separator.
    filename = os.path.join(save_dir, name + '.tfrecords')
    n_samples = len(labels)
    # Every image needs exactly one label.
    if len(images) != n_samples:
        raise ValueError('Images size %d does not match label size %d.'
                         % (len(images), n_samples))

    writer = tf.python_io.TFRecordWriter(filename)
    print('\nTransform start......')
    try:
        for image_path, raw_label in zip(images, labels):
            try:
                # Binary mode: PNG data is not text.
                with tf.gfile.FastGFile(image_path, 'rb') as f:
                    image_raw_data = f.read()
            except (IOError, tf.errors.OpError) as e:
                # tf.gfile reports a missing file as tf.errors.NotFoundError
                # (an OpError, not an IOError), so both are caught here.
                print('Could not read: %s' % image_path)
                print('error: %s' % e)
                print('Skip it!\n')
                continue
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': _int64_feature(int(raw_label)),
                'image_raw': _byte_feature(image_raw_data),
            }))
            writer.write(example.SerializeToString())
    finally:
        # Always flush and close the output, even if a write or int() fails.
        writer.close()
二、读取数据,并绘图
# Read one example back from a TFRecord file and plot it.
def read_and_decode(tfrecords_file, image_shape=(465, 315, 3)):
    """Read one (image, label) example from a TFRecord file, print and plot it.

    Args:
        tfrecords_file: path of the .tfrecords file to read.
        image_shape: expected (rows, cols, channels) of the stored image. The
            decoded image is reshaped to this, so a size mismatch raises an
            error. Defaults to the (465, 315, 3) size of the original example.

    Returns:
        (image, label): the decoded image as a uint8 numpy array of shape
        ``image_shape`` and its label as an int32 value.
    """
    # Queue of input file names consumed by the TFRecordReader.
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Parse one serialized tf.train.Example into its two features.
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })

    # 'image_raw' holds encoded PNG file bytes (convert_to_tfrecord stores the
    # file contents as read), so it must be PNG-decoded. tf.decode_raw would
    # return the compressed bytes, which cannot be reshaped to image_shape.
    image = tf.image.decode_png(img_features['image_raw'],
                                channels=image_shape[-1])
    image = tf.reshape(image, list(image_shape))
    label = tf.cast(img_features['label'], tf.int32)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            image_value, label_value = sess.run([image, label])
        finally:
            # Stop the queue-runner threads before the session is closed.
            coord.request_stop()
            coord.join(threads)

    print(image_value)
    plt.imshow(image_value)
    plt.show()
    return image_value, label_value
# Run the reader on the example TFRecord file.
read_and_decode(tfrecords_file='/home/tensor/Desktop/tia.tfrecords')
三、结果