TFRecords 制作与读取
- 制作 TFRecords 文件
import tensorflow as tf
import os
from PIL import Image
# Root directory containing one sub-folder per class (folders named '0' ... '9').
add = "/home/"
# Class names; each must equal the name of a sub-folder under `add`.
# NOTE: this must be an ordered sequence, not a set. The label written for
# each image is the class's position in this list (via enumerate), and the
# iteration order of a set of strings changes between interpreter runs
# (hash randomization), which would silently scramble the labels.
classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# Output location and file name of the TFRecords file.
# The `with` block guarantees the writer is closed (and the file flushed)
# even if one of the images fails to load.
with tf.python_io.TFRecordWriter("/home/crxm/test.tfrecords") as writer:
    for index, name in enumerate(classes):
        # Directory holding every image of this class.
        class_path = os.path.join(add, name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            # Force 3-channel RGB so each record holds exactly 28*28*3 bytes,
            # matching the [28, 28, 3] reshape done when reading; grayscale or
            # RGBA sources would otherwise yield records of a different size.
            # The context manager releases the source file handle.
            with Image.open(img_path) as src:
                img = src.convert('RGB').resize((28, 28))
            # Raw pixel bytes of the resized image.
            img_raw = img.tobytes()
            # One Example per image: integer class label + raw image bytes.
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "label": tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[index])),
                        "img_raw": tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[img_raw])),
                    }))
            writer.write(example.SerializeToString())
print('TFRecords File Generation Finished!')
- 使用 tf.data 读取 TFRecords 文件
import tensorflow as tf
import matplotlib.pyplot as plt
def _parse_function(example_proto):
    """Turn one serialized ``tf.train.Example`` into an ``(img, label)`` pair.

    Intended to be passed to ``Dataset.map``.

    Args:
        example_proto: string tensor holding one serialized Example with an
            int64 ``'label'`` feature and a bytes ``'img_raw'`` feature.

    Returns:
        img: float32 tensor of shape [28, 28, 3], pixel values divided by 255.
        label: int32 scalar tensor.
    """
    schema = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(example_proto, schema)

    # img_raw is decoded as raw uint8 values; it must contain exactly
    # 28*28*3 of them or the reshape fails at run time.
    pixels = tf.reshape(tf.decode_raw(parsed['img_raw'], tf.uint8), [28, 28, 3])
    img = tf.cast(pixels, tf.float32) / 255

    return img, tf.cast(parsed['label'], tf.int32)
# TFRecords file(s) to read. This must be the file the writing script
# produced; its TFRecordWriter saves to /home/crxm/test.tfrecords.
filenames = ["/home/crxm/test.tfrecords"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
# Shuffle within a 100-record buffer, group into batches of 32 and make
# 10 passes over the data. These are methods of the Dataset object itself
# (there is no `tf.dataset` module).
dataset = dataset.shuffle(buffer_size=100).batch(32).repeat(10)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        while True:
            # Each step yields a whole batch: `images` has shape
            # (up to 32, 28, 28, 3) and `labels` has shape (up to 32,).
            images, labels = sess.run(next_element)
            # imshow needs a single HxWxC image, not a 4-D batch, so show the
            # first image of the batch, titled with its label.
            plt.imshow(images[0])
            plt.title(str(labels[0]))
            plt.show()
    except tf.errors.OutOfRangeError:
        # Raised once the iterator has consumed all repeated epochs.
        print("end")
参考：袁明奇老先生