最近没有更新博客了,可能是变懒了吧……持续学习的确是一件比较辛苦的事情,废话好像又多了那么一点点。
今天主要记录一下如何读取 TFRecord 文件,并将其解析成 tf.data.Dataset 数据集(使用 TensorFlow 1.x 的 API)。
上代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ProjectName : 07_create_tfrecord_dataset.py
# @DateTime : 2019-12-08 21:09
# @Author : 皮皮虾
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# Print the installed TensorFlow version for reference.
# NOTE(review): the code below uses TF 1.x graph-mode APIs (tf.Session,
# tf.parse_single_example, Dataset.make_one_shot_iterator); confirm a 1.x
# install (or an equivalent tf.compat.v1 setup) is in use.
print(tf.__version__)
def dataset(directory, size, batchsize, num_classes=2):
    """Build a batched tf.data pipeline of (image, label) pairs from TFRecords.

    Args:
        directory: TFRecord file path or list of paths, passed unchanged to
            ``tf.data.TFRecordDataset`` (despite the name, not a directory).
        size: shape each decoded image is reshaped to, e.g. ``[256, 256, 3]``.
            The byte length of every ``img_raw`` feature must equal the
            product of this shape, otherwise the reshape fails at run time.
        batchsize: number of examples per batch.
        num_classes: depth of the one-hot label encoding. Defaults to 2,
            the value that used to be hard-coded.

    Returns:
        A dataset of ``(image, label)`` batches. ``image`` is float32 scaled
        from [0, 255] to [-0.5, 0.5]; ``label`` is a one-hot encoding of the
        stored integer label.
    """

    def _parseone(example_proto):
        """Parse one serialized example into a normalized image and one-hot label."""
        # Feature spec: a scalar int64 "label" and a raw-bytes "img_raw".
        features = {
            "label": tf.FixedLenFeature(shape=[], dtype=tf.int64),
            "img_raw": tf.FixedLenFeature(shape=[], dtype=tf.string),
        }
        parsed_example = tf.parse_single_example(serialized=example_proto, features=features)
        # The raw bytes are decoded as a flat uint8 vector, then given the image shape.
        image = tf.decode_raw(bytes=parsed_example["img_raw"], out_type=tf.uint8)
        image = tf.reshape(tensor=image, shape=size)
        # Map pixel values [0, 255] -> [-0.5, 0.5].
        image = tf.cast(x=image, dtype=tf.float32) * (1. / 255) - 0.5
        label = tf.cast(x=parsed_example["label"], dtype=tf.int32)
        # on_value is a Python int, so the one-hot tensor is integer-typed.
        label = tf.one_hot(indices=label, depth=num_classes, on_value=1)
        return image, label

    ds = tf.data.TFRecordDataset(directory)
    ds = ds.map(_parseone)
    ds = ds.batch(batch_size=batchsize)
    # After batch(), one dataset element is a whole batch, so prefetch's
    # buffer_size counts batches, not examples. Keeping one batch ready
    # overlaps input with consumption without buffering batchsize batches.
    ds = ds.prefetch(buffer_size=1)
    return ds
def show_single_image(subplot, label, image):
    """Render ``image`` in subplot cell ``subplot``, captioned with ``label``.

    ``subplot`` is passed unchanged to ``plt.subplot`` (for example the
    three-digit code 131 = 1 row, 3 columns, first cell). The cell's axis
    lines and ticks are hidden.
    """
    cell = plt.subplot(subplot)
    cell.axis("off")
    # Use pyplot's imshow (not cell.imshow) so the image also becomes pyplot's current image.
    plt.imshow(image)
    cell.set_title(label)
def show_batch_image(label, image, top):
    """Show up to ``top`` images from a batch side by side in one figure row.

    Args:
        label: per-image labels, indexable in parallel with ``image``; each
            one is used as the title of its subplot.
        image: indexable batch of images in a form ``plt.imshow`` accepts.
        top: maximum number of images to display. It is capped at 9 because
            the three-digit subplot code used below cannot address more than
            9 columns. It is also capped at the batch length, so a short
            (e.g. final, partial) batch does not raise IndexError.
    """
    plt.figure(figsize=(20, 10))
    count = min(top, 9, len(label), len(image))
    for i in range(count):
        # Three-digit subplot code: 1 row, `count` columns, cell i + 1.
        show_single_image(subplot=100 + 10 * count + 1 + i, label=label[i], image=image[i])
    plt.show()
def get_one(dataset):
    """Return the "next element" op of a fresh one-shot iterator over ``dataset``.

    In the ``__main__`` block the returned value is evaluated with
    ``sess.run``, and exhaustion is handled by catching
    ``tf.errors.OutOfRangeError``.
    """
    return dataset.make_one_shot_iterator().get_next()
if __name__ == '__main__':
    # Build the pipeline from one TFRecord file and display the first batch.
    sample_dir = ["my_tfrecord_data.tfrecords"]
    size = [256, 256, 3]
    batchsize = 10
    tf_dataset = dataset(directory=sample_dir, size=size, batchsize=batchsize)
    print(tf_dataset.output_types)
    print(tf_dataset.output_shapes)
    one_element = get_one(dataset=tf_dataset)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        try:
            images, labels = sess.run(one_element)
        except tf.errors.OutOfRangeError:
            print("finish!!!")
        else:
            # Undo the [-0.5, 0.5] normalization. Round before casting: a bare
            # float -> uint8 cast truncates, so float error (e.g. 254.99998)
            # would drop a pixel by one. Clip guards the uint8 range.
            pixels = np.clip(np.rint((images + 0.5) * 255), 0, 255).astype(np.uint8)
            show_batch_image(label=labels, image=pixels, top=10)