统一数据格式TFRecord

本文详细介绍如何使用TensorFlow的TFRecord格式来保存和读取数据集,包括图像和标签数据的序列化与反序列化过程。TFRecord是一种高效的文件格式,适用于存储大量训练数据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

背景

机器学习的数据可以通过很多种方式进行存储,例如csv文件,excel文件,txt文件等。为了能够将各种种类的数据进行统一,我们将采用TRRECORD的格式进行统一,这种格式比较便于对训练数据的属性进行管理,同时也便于进行数据的多线程输入。

介绍

TFRecord数据通过tf.train.Example Protocol Buffer格式存储的。Protocol Buffer可以上网查询。可以把Protocol buffer看作是一些message的类。

//一个数据样本
message Example{
    Features feature =1;
};
//一个特征
message Features{
    map<string,Feature> feature =1;
};
//对应的特征值类型
message Features{
    oneof kind{
        BytesList bytes_list=1;
        FloatList float_list=2;
        Int64List int64_list=3;
    }
};

保存数据集

保存一个TFRecord的数据集大致的步骤如下:

  • 解析出一个训练集的图像部分和标签部分
  • 新建一个TFRecord文件
  • 循环写入样本
    • 图像数据字符串化
    • 生成example实例,图像像素、标签、图像数据
    • example序列化后写入文件

下面是保存Mnist数据的代码。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

#定义两个初始化函数
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

mnist = input_data.read_data_sets("dataset/", dtype=tf.uint8, one_hot=True)
#解析出图片
images = mnist.train.images
#解析出标签
labels = mnist.train.labels
#解析出像素
pixels = images.shape[1]
num_examples = mnist.train.num_examples 
#保存的文件格式是*.tfrecords
filename = "./tfrecord/output.tfrecords" 
#定义写入对象writer
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
    image_raw = images[index].tostring()
    #定义example实例,根据类型选定函数进行初始化
    example = tf.train.Example(features=tf.train.Features(feature={
        'pixels': _int64_feature(pixels),
        'label': _int64_feature(np.argmax(labels[index])),
        'image_raw' : _bytes_feature(image_raw)}))
    writer.write(example.SerializeToString())
writer.close()

写入数据集

读取数据

import tensorflow as tf
import scipy.misc 
#定义读取实例
reader = tf.TFRecordReader()
#定义文件队列
files = ['./tfrecord/output.tfrecords']
file_queue = tf.train.string_input_producer(files)

#定义读取一个样本的操作
_,one_example = reader.read(file_queue)
features = tf.parse_single_example(one_example,features={
    'image_raw':tf.FixedLenFeature([],tf.string),
    'pixels':tf.FixedLenFeature([],tf.int64),
    'label':tf.FixedLenFeature([],tf.int64),
    })
#解析图像的操作
images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(10):
        image,label,pixel = sess.run([images,labels,pixels])
        #注意这里一定要reshape
        test = tf.reshape(image,[28,28])
        scipy.misc.imsave('./pics/'+str(i)+'.png',sess.run(test))
    coord.request_stop()
    coord.join(threads)

这里写图片描述

### TFRecord 数据格式及其使用方法 #### 什么是 TFRecordTFRecord 是 TensorFlow 提供的一种用于高效存储和读取数据的二进制文件格式[^3]。它能够统一存储多种形式的数据,包括但不限于图片、文本以及对应的标签信息。这种格式的主要优势在于提高 I/O 效率、减少磁盘占用并简化数据预处理流程。 #### 创建 TFRecord 文件 创建 TFRecord 文件通常涉及以下几个方面: 1. **导入必要的库** 需要先引入 `tensorflow` 库,并定义写入器对象。 ```python import tensorflow as tf writer = tf.io.TFRecordWriter("example.tfrecord") ``` 2. **构建 Example 协议缓冲区 (Protocol Buffer)** 将数据转换为 `tf.train.Example` 对象以便于序列化。以下是具体操作方式: ```python feature = { 'feature_name': tf.train.Feature(float_list=tf.train.FloatList(value=[value])) } example_proto = tf.train.Example(features=tf.train.Features(feature=feature)) serialized_example = example_proto.SerializeToString() ``` 上述代码片段展示了如何将浮点数列表封装成一个特征字段,并将其进一步打包至协议缓冲区内存表示形式中[^1]。 3. **写入数据到 TFRecord 文件** 利用之前实例化的写入器完成实际记录过程。 ```python writer.write(serialized_example) writer.close() # 关闭资源释放句柄 ``` #### 读取 TFRecord 文件 对于已存在的 TFRecord 文件,可以通过如下手段解析其中的内容: 1. **初始化 Dataset 实例** 借助 `tf.data.TFRecordDataset` 类加载目标路径下的单个或者多个 TFRecord 文件集合。 ```python dataset = tf.data.TFRecordDataset(["example.tfrecord"]) ``` 2. **自定义解析函数** 设计适合特定需求的解码逻辑,比如恢复原始数值型数组或者其他复杂结构体。 ```python def _parse_function(example_proto): keys_to_features = {'feature_name': tf.io.FixedLenFeature([], dtype=tf.float32)} parsed_features = tf.io.parse_single_example(example_proto, keys_to_features) return parsed_features['feature_name'] ``` 3. **应用映射变换** 把上述定制好的处理器绑定给整个批次流水线执行链路末端位置上。 ```python mapped_dataset = dataset.map(_parse_function) iterator = iter(mapped_dataset) next_element = iterator.get_next() ``` 以上步骤实现了从本地磁盘加载预先加工完毕后的样本队列直至最终可供模型消费的状态转变全过程描述[^2]。 #### 存储与分发优化建议 当面对大规模分布式计算场景时,可考虑借助 Hadoop 分布式文件系统(HDFS)来管理海量 TFRecord 资料集。例如,在阿里云 EMR 平台上部署相关作业流,则能充分利用集群内部节点间通信特性加速整体运算效率[^4]。 此外值得注意的一点是在不同操作系统环境下运行程序可能会遇到版本兼容性问题。如果希望在 Windows 或 Linux 下稳定支持 GPU 加速功能的话,可能还需要额外关注对应驱动安装状况以及手动编译指定分支源码等工作细节[^5]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值