TFrecords 制作与读取

本文详细介绍了如何使用TensorFlow库制作TFrecords文件,并演示了如何从TFrecords文件中读取数据。通过实例代码,展示了如何将图像数据转换为TFrecords格式,以及如何解析这些数据用于深度学习模型训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TFrecords 制作与读取

  • 制作tfrecords
import tensorflow as tf
import os
from PIL import Image

# add为XM、MX文件夹的上层目录
add = "/home/"
# 设定图片分类,与要制作的图片所属文件名相同
classes = {'0','1','2','3','4','5','6','7','8','9'}
# 给出输出的位置以及输出的文件名
writer = tf.python_io.TFRecordWriter("/home/crxm/test.tfrecords")
for index, name in enumerate(classes):
    # 写入图片文件路径
    class_path = add + name + '/'
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name
        Img = Image.open(img_path)
        # 将图片转化为固定大小
        img = Img.resize((28, 28))
        # 将图片转化为二进制格式
        img_raw = img.tobytes()
        # 创建样本
        example = tf.train.Example(
            features = tf.train.Features(
                feature = {
                # 设定标签和图像数据的属性
                    "label": tf.train.Feature(int64_list = tf.train.Int64List(value = [index])),
                    "img_raw": tf.train.Feature(bytes_list = tf.train.BytesList(value = [img_raw]))
                }))
        writer.write(example.SerializeToString())
writer.close()
print('TFRecords File Generation Finished!')
  • tf.data读取tfrecords文件
import tensorflow as tf
import matplotlib.pyplot as plt

def _parse_function(example_proto):
    features = {'label':tf.FixedLenFeature([], tf.int64),
              'img_raw':tf.FixedLenFeature([], tf.string)}
    parsed_features = tf.parse_single_example(example_proto, features)
    img = tf.decode_raw(parsed_features['img_raw'], tf.uint8)
    img = tf.reshape(img, [28, 28, 3])
    # 在流中抛出img张量和label张量
    img = tf.cast(img, tf.float32) / 255
    label = tf.cast(parsed_features['label'], tf.int32)
    return img, label

filenames = ["/home/test.tfrecords"]  #tfrecords文件
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
dataset = tf.dataset.shuffle(buffer_size=100).batch(32).repeat(10)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        while True:
            image, label = sess.run(next_element)
            plt.imshow(image)
            plt.show()
    except tf.errors.OutOfRangeError:
        print("end")
        pass

参考:袁明奇老先生

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值