TFRecord tf.train.Feature

本文详细介绍了TFRecord文件格式及其在TensorFlow中的应用。包括定义、特点、如何存储和读取复杂数据,以及通过实例演示如何持久化MNIST数据集并解析TFRecord文件。
部署运行你感兴趣的模型镜像

一、定义
在这里插入图片描述
事先将数据编码为二进制的TFRecord文件,配合TF自带的多线程API,读取效率最高,且跨平台,适合规范化存储复杂的数据。上图为TFRecord的pb格式定义,可发现每个TFRecord由许多Example组成。

Example官方定义:An Example is a mostly-normalized data format for storing data for training and inference.
一个Example代表一个封装的数据输入,比如包含一张图片、图片的宽高、图片的label等信息。而每个信息用键值对的方式存储。因此一个Example包含了一个Features(Features 包含多个 feature)。

这种约定好的TFRecord格式,可以应用于所有数据集的制作。

二、Feature
官方定义:

// A Feature contains Lists which may hold zero or more values. These
// lists are the base values BytesList, FloatList, Int64List.
//
// Features are organized into categories by name. The Features message
// contains the mapping from name to Feature.、

eatures是Feature的字典合集,key为String,而value为tf.train.Feature(),value必须符合特定的三种格式之一:字符串(BytesList)、实数列表(FloatList)或者整数列表(Int64List)。

tf.train.Feature(**options) 
options可以选择如下三种数据格式:
bytes_list = tf.train.BytesList(value = 输入)#输入的元素的数据类型为string
int64_list = tf.train.Int64List(value = 输入)#输入的元素的数据类型为int(int32,int64)
float_list = tf.trian.FloatList(value = 输入)#输入的元素的数据类型为float(float32,float64)
注:value必须是list(向量)

原始数据为矩阵或张量(比如图片格式)不管哪种方式存储都会使数据丢失形状信息,所以在向该样本中写入feature时应该额外加入shape信息作为额外feature。shape信息是int类型,建议采用原feature名字+’_shape’来指定shape信息的feature名。这样读取操作可获取到shape信息进行还原。

以下是两种存储矩阵的方式,都需要额外存储shape信息以便还原:(第二种更方便)

将矩阵或张量fatten成list(向量),再根据元素的数据类型选择使用哪个数据格式存储。
将矩阵或张量用.tostring()转换成string类型,再用tf.train.Feature(bytes_list=tf.train.BytesList(value=[input.tostring()]))来存储。

# 定义函数转化变量类型。
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 将每一个数据转化为tf.train.Example格式。
def _make_example(pixels, label, image):
    image_raw = image.tostring()  # np.array ---> String byte
    example = tf.train.Example(features=tf.train.Features(feature={
        'pixels': _int64_feature(pixels),
        'label': _int64_feature(np.argmax(label)),
        'image_raw': _bytes_feature(image_raw)
    }))
    return example

三、完整的持久化mnist数据为TFRecord

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

# 定义函数转化变量类型。
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 将数据转化为tf.train.Example格式。
def _make_example(pixels, label, image):
    image_raw = image.tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
        'pixels': _int64_feature(pixels),
        'label': _int64_feature(np.argmax(label)),
        'image_raw': _bytes_feature(image_raw)
    }))
    return example

def save_tfrecords():
    # 读取mnist训练数据。
    mnist = input_data.read_data_sets("../../datasets/MNIST_data",dtype=tf.uint8, one_hot=True)
    images = mnist.train.images  # (55000, 784) <class 'numpy.ndarray'>
    labels = mnist.train.labels  # (55000, 10) <class 'numpy.ndarray'>
    pixels = images.shape[1]  # 784 = 28 * 28
    num_examples = mnist.train.num_examples

    # 输出包含训练数据的TFRecord文件。
    with tf.python_io.TFRecordWriter("output.tfrecords") as writer:
        for index in range(num_examples):
            # 生成一个Example并序列化后写入pb
            example = _make_example(pixels, labels[index], images[index])
            writer.write(example.SerializeToString())
    print("TFRecord训练文件已保存。")

四、读取解析TFRecord
读取解析的步骤中,需要根据编码时候的定义,来指定解码时候的规则和还原的dtype,如image需要指定tf.string格式,之后再去解析成uint8。注意,这里的parse等op操作都是在graph中定义一些运算op,并没有运行。sess.run()的时候才会真正多线程开始读取解析。这种读取二进制了流文件的速度,多线程加持下远远超过读取硬盘中的原生图片。

def test_tfrecords():
    # 读取文件。
    print(len(tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)))  # 0
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer(["output.tfrecords"])  # 队列默认自动添加进collection
    print(len(tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)))   # 1
    _, serialized_example = reader.read(filename_queue)

    # 解析读取的样例。
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'pixels': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        })

    images = tf.decode_raw(features['image_raw'], tf.uint8)
    labels = tf.cast(features['label'], tf.int32)
    pixels = tf.cast(features['pixels'], tf.int32)

    sess = tf.Session()

    # 启动多线程处理输入数据。
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(5):
        image, label, pixel = sess.run([images, labels, pixels])
        print(label)

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

<think>嗯,用户想了解TensorFlow中tf.train.Feature的使用方法和示例。首先,我需要回顾一下tf.train.Feature的相关知识。记得tf.train.Feature是用于构建Example协议缓冲区的,这在处理TFRecord格式数据时非常重要。用户可能正在尝试将数据转换为TFRecord格式,所以需要知道如何正确使用Feature来序列化数据。 首先,我应该解释tf.train.Feature的基本概念,它是Example中的一个字段,用来存储不同类型的数据。然后,需要提到它支持的三种数据类型:BytesList、FloatList、Int64List。每个Feature对应一种类型,用户需要根据数据内容选择合适的类型。 接下来,应该给出具体的步骤,比如如何将数据转换为Feature,然后构建Example,最后写入TFRecord文件。示例代码会很有帮助,但要注意代码的正确性和兼容性,比如使用兼容TensorFlow 2.x的方法,因为现在大多数用户可能在使用较新版本。 还要注意用户可能的误区,比如数据类型的正确转换,比如图像数据通常保存为字节,标签为整数,可能需要用tf.io.serialize_tensor来处理张量。另外,提醒用户写入和读取时的流程,比如使用TFRecordWriter和TFRecordDataset,以及解析函数中的parse_single_example。 用户可能还需要知道如何在实际训练中使用这些TFRecord文件,比如在Dataset.map中使用解析函数。此外,可能需要提到性能优化的建议,如并行处理和预取,引用中的内容提到了使用tf.data.Dataset的管道机制,这可以提升性能,确保CPU处理数据的同时GPU进行训练。 还需要检查是否有相关的引用资料可以支持回答,比如引用3提到了从TFRecord文件中提取Example,并解析特征,这可以引用作为示例。引用4提到在数据预处理中使用TensorFlow指令,可能相关,但用户的问题更侧重于Feature的使用,所以重点在引用3的内容。 最后,生成相关问题时要围绕TFRecordFeature的使用,比如不同数据类型的处理、性能优化、与其他序列化方法的比较等。确保问题对用户深入学习有帮助。</think>### TensorFlow tf.train.Feature 使用方法及示例 `tf.train.Feature` 是 TensorFlow 中用于构建 `tf.train.Example` 协议缓冲区的核心组件,主要用于将结构化数据序列化为 TFRecord 格式[^3]。以下是其使用方法和示例: --- #### 1. **基本概念** - **功能**:`tf.train.Feature` 用于定义单个数据字段的存储格式,支持三种数据类型: - `bytes_list`:存储字节类型数据(如序列化后的图像、文本)。 - `float_list`:存储浮点数列表(如特征向量)。 - `int64_list`:存储整数列表(如标签、ID)。 - **层级关系**: ``` Example → Features → {key: Feature} → Feature → {bytes_list/float_list/int64_list} ``` --- #### 2. **使用步骤** ##### (1) 将数据转换为 `tf.train.Feature` ```python import tensorflow as tf # 示例数据:图像字节和标签 image_bytes = b"raw_image_data" # 实际应从文件读取 label = 5 # 转换为 Feature def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # 构建 Features 字典 feature_dict = { "image": _bytes_feature(image_bytes), "label": _int64_feature(label) } ``` ##### (2) 构建 `tf.train.Example` ```python example = tf.train.Example(features=tf.train.Features(feature=feature_dict)) ``` ##### (3) 写入 TFRecord 文件 ```python with tf.io.TFRecordWriter("data.tfrecord") as writer: writer.write(example.SerializeToString()) ``` --- #### 3. **从 TFRecord 读取数据** ```python # 定义解析函数 def _parse_function(example_proto): feature_description = { "image": tf.io.FixedLenFeature([], tf.string), "label": tf.io.FixedLenFeature([], tf.int64) } parsed_features = tf.io.parse_single_example(example_proto, feature_description) return parsed_features["image"], parsed_features["label"] # 创建 Dataset 管道 dataset = tf.data.TFRecordDataset(["data.tfrecord"]) dataset = dataset.map(_parse_function) # 解析 Example ``` --- #### 4. **复杂数据类型处理** - **张量序列化**:使用 `tf.io.serialize_tensor` 存储多维数据。 ```python tensor_data = tf.constant([[1.0, 2.0], [3.0, 4.0]]) serialized_tensor = tf.io.serialize_tensor(tensor_data) feature = _bytes_feature(serialized_tensor.numpy()) ``` --- #### 5. **最佳实践** - **性能优化**:结合 `tf.data.Dataset` 的 `shuffle`、`prefetch` 和并行处理提升效率[^2]。 ```python dataset = dataset.shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE) ``` - **类型匹配**:确保写入和解析时的数据类型一致(如 `tf.string` 对应 `bytes_list`)。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值