TensorFlow——TFRecords文件

一、什么是TFRecords文件

TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件

使用步骤:

1)获取数据

2)将数据填入到Example协议内存块(protocol buffer)

3)将协议内存块序列化为字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

  • 文件格式 *.tfrecords

二、Example结构解析

  • tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)
  • Features包含了一个Feature字段
  • Feature中包含要写入的数据、并指明数据类型。
    • 这是一个样本的结构,批数据需要循环存入这样的结构
 example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }))
  • tf.train.Example(features=None)
    • 写入tfrecords文件
    • features:tf.train.Features类型的特征实例
    • return:example格式协议块
  • tf.train.Features(feature=None)
    • 构建每个样本的信息键值对
    • feature:字典数据,key为要保存的名字
    • value为tf.train.Feature实例
    • return:Features类型
  • tf.train.Feature(options)
    • options:例如
      • bytes_list=tf.train. BytesList(value=[Bytes])
      • int64_list=tf.train. Int64List(value=[Value])
    • 支持存入的类型如下
    • tf.train.Int64List(value=[Value])
    • tf.train.BytesList(value=[Bytes])
    • tf.train.FloatList(value=[value])

这种结构很好地实现了数据和标签(训练的类别标签)或者其他属性数据存储在同一个文件中

三、案例:CIFAR10数据存入TFRecords文件

1.分析

  • 构造存储实例,tf.python_io.TFRecordWriter(path)

    • 写入tfrecords文件
    • path:TFRecords文件的路径
    • return:写文件
      • method方法
        • write(record):向文件中写入一个example
        • close():关闭文件写入器
  • 循环将数据填入到Example协议内存块(protocol buffer)

2.代码

对于每一个图片样本数据,都需要写入到example当中,所以这里需要取出每一样本进行构造存入

def write_to_tfrecords(self, image_batch, label_batch):
    """
        将数据存进tfrecords,方便管理每个样本的属性
        :param image_batch: 特征值
        :param label_batch: 目标值
        :return: None
        """
    # 1、构造tfrecords的存储实例
    writer = tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir)

    # 2、循环将每个样本写入到文件当中
    for i in range(10):

        # 一个样本一个样本的处理写入
        # 准备特征值,特征值必须是bytes类型 调用tostring()函数
        # [10, 32, 32, 3] ,在这里避免tensorflow的坑,取出来的不是真正的值,而是类型,所以要运行结果才能存入
        # 出现了eval,那就要在会话当中去运行该行数
        image = image_batch[i].eval().tostring()

        # 准备目标值,目标值是一个Int类型
        # eval()-->[6]--->6
        label = label_batch[i].eval()[0]

        # 绑定每个样本的属性
        example = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))

        # 写入每个样本的example
        writer.write(example.SerializeToString())

        # 文件需要关闭
        writer.close()
        return None

    # 开启会话打印内容
    with tf.Session() as sess:
        # 创建线程协调器
        coord = tf.train.Coordinator()

        # 开启子线程去读取数据
        # 返回子线程实例
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # 获取样本数据去训练
        print(sess.run([image_batch, label_batch]))

        # 存入数据
        cr.write_to_tfrecords(image_batch, label_batch )

        # 关闭子线程,回收
        coord.request_stop()

        coord.join(threads)

四、读取TFRecords文件API

读取这种文件整个过程与其他文件一样,只不过需要有个解析Example的步骤。从TFRecords文件中读取数据, 可以使用tf.TFRecordReadertf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

# 多了解析example的一个步骤
feature = tf.parse_single_example(values, features={
    "image": tf.FixedLenFeature([], tf.string),
    "label": tf.FixedLenFeature([], tf.int64)
})
  • tf.parse_single_example(serialized, features=None, name=None)

    • 解析一个单一的Example原型
    • serialized:标量字符串Tensor,一个序列化的Example
    • features:dict字典数据,键为读取的名字,值为FixedLenFeature
    • return:一个键值对组成的字典,键为读取的名字
  • tf.FixedLenFeature(shape, dtype)

    • shape:输入数据的形状,一般不指定,为空列表
    • dtype:输入数据类型,与存储进文件的类型要一致
    • 类型只能是float32, int64, string

五、案例:读取CIFAR的TFRecords文件

1.分析

  • 使用tf.train.string_input_producer构造文件队列
  • tf.TFRecordReader 读取TFRecords数据并进行解析
    • tf.parse_single_example进行解析
  • tf.decode_raw解码
    • 类型是bytes类型需要解码
    • 其他类型不需要
  • 处理图片数据形状以及数据类型,加入批处理队列
  • 开启会话线程运行

2.代码

def read_tfrecords(self):
    """
        读取tfrecords的数据
        :return: None
        """
    # 1、构造文件队列
    file_queue = tf.train.string_input_producer(["./tmp/cifar.tfrecords"])

    # 2、构造tfrecords读取器,读取队列
    reader = tf.TFRecordReader()

    # 默认也是只读取一个样本
    key, values = reader.read(file_queue)

    # tfrecords
    # 多了解析example的一个步骤
    feature = tf.parse_single_example(values, features={
        "image": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.int64)
    })

    # 取出feature里面的特征值和目标值
    # 通过键值对获取
    image = feature["image"]

    label = feature["label"]

    # 3、解码操作
    # 对于image是一个bytes类型,所以需要decode_raw去解码成uint8张量
    # 对于Label:本身是一个int类型,不需要去解码
    image = tf.decode_raw(image, tf.uint8)

    print(image, label)

    # # 从原来的[32,32,3]的bytes形式直接变成[32,32,3]
    # 不存在一开始我们的读取RGB的问题
    # 处理image的形状和类型
    image_reshape = tf.reshape(image, [self.height, self.width, self.channel])

    # 处理label的形状和类型
    label_cast = tf.cast(label, tf.int32)

    print(image_reshape, label_cast)

    # 4、批处理操作
    image_batch, label_batch = tf.train.batch([image_reshape, label_cast], batch_size=10, num_threads=1, capacity=10)

    print(image_batch, label_batch)
    return image_batch, label_batch

# 从tfrecords文件读取数据
image_batch, label_batch = cr.read_tfrecords()

# 开启会话打印内容
with tf.Session() as sess:
    # 创建线程协调器
    coord = tf.train.Coordinator()

完整代码:

import tensorflow as tf
import os


class Cifar(object):

    # 初始化
    def __init__(self):
        # 图像的大小
        self.height = 32
        self.width = 32
        self.channels = 3

        # 图像的字节数
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channels
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self, file_list):
        # 读取二进制文件
        # print("read_and_decode:\n", file_list)
        # 1、构造文件名队列
        file_queue = tf.train.string_input_producer(file_list)

        # 2、构造二进制文件阅读器
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(file_queue)

        print("key:\n", key)
        print("value:\n", value)
        # 3、解码
        decoded = tf.decode_raw(value, tf.uint8)
        print("decoded:\n", decoded)

        # 4、基本的数据处理
        # 切片处理,把标签值和特征值分开
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])

        print("label:\n", label)
        print("image:\n", image)
        # 改变图像的形状
        image_reshaped = tf.reshape(image, [self.channels, self.height, self.width])
        # 转置
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:\n", image_transposed)

        # 类型转换
        label_cast = tf.cast(label, tf.float32)
        image_cast = tf.cast(image_transposed, tf.float32)

        # 5、批处理
        label_batch, image_batch = tf.train.batch([label_cast, image_cast], batch_size=10, num_threads=1, capacity=10)
        return label_batch, image_batch


    def write_to_tfrecords(self, label_batch, image_batch):
        # 进行类型转换,转成tf.uint8
        # 为了节省空间
        label_batch = tf.cast(label_batch, tf.uint8)
        image_batch = tf.cast(image_batch, tf.uint8)
        # 构造tfrecords存储器
        with tf.python_io.TFRecordWriter("./cifar.tfrecords") as writer:
            for i in range(10):
                label = label_batch[i].eval()[0]
                image = image_batch[i].eval().tostring()
                print("tfrecords_label:\n", label)
                print("tfrecords_image:\n", image, type(image))
                # 构造example协议块
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train. Int64List(value=[label])),
                    "image": tf.train.Feature(bytes_list=tf.train. BytesList(value=[image]))
                }))
                # 写入序列化后的example
                writer.write(example.SerializeToString())


    def read_tfrecords(self):
        # 读取tfrecords文件
        # 1、构造文件名队列
        file_queue = tf.train.string_input_producer(["cifar.tfrecords"])

        # 2、构造tfrecords阅读器
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        # 3、解析example协议块
        example = tf.parse_single_example(value, features={
            "label": tf.FixedLenFeature(shape=[], dtype=tf.int64),
            "image": tf.FixedLenFeature(shape=[], dtype=tf.string)
        })
        label = example["label"]
        image = example["image"]
        print("read_tfrecords_label:\n", label)
        print("read_tfrecords_image:\n", image)

        # 4、解码
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("read_tfrecords_image_decoded:\n", image_decoded)

        # 5、基本的数据处理
        # 调整图像形状
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
        # 转换类型
        image_cast = tf.cast(image_reshaped, tf.float32)
        label_cast = tf.cast(label, tf.float32)
        print("read_records_image_cast:\n", image_cast)
        print("read_records_label_cast:\n", label_cast)

        # 6、批处理
        label_batch, image_batch = tf.train.batch([label_cast, image_cast], batch_size=10, num_threads=1, capacity=10)

        return label_batch, image_batch

if __name__ == "__main__":
    # 构造文件名列表
    file_name = os.listdir("./cifar-10-batches-bin")
    print("file_name:\n", file_name)
    file_list = [os.path.join("./cifar-10-batches-bin/", file) for file in file_name if file[-3:] == "bin"]
    print("file_list:\n", file_list)

    # 调用读取二进制文件的方法
    cf = Cifar()
    # label, image = cf.read_and_decode(file_list)
    label, image = cf.read_tfrecords()

    # 开启会话
    with tf.Session() as sess:
        # 创建线程协调器
        coord = tf.train.Coordinator()
        # 创建线程
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # 打印结果
        print("label:\n", sess.run(label))
        print("image:\n", sess.run(image))

        # cf.write_to_tfrecords(label, image)
        # 回收资源
        coord.request_stop()
        coord.join(threads)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值