TFrecord:write&read

本文介绍如何利用TensorFlow的TFRecord格式处理图像数据,以加速卷积神经网络(CNN)的训练过程。文中详细展示了创建TFRecord文件的方法及从TFRecord中读取数据的具体实现,包括使用多线程提高读取效率。

参考了这位仁兄的博客

概述

在训练卷积神经网络时,将图片提前处理好并缓存在磁盘上,通过中间文件随机调用访问可以明显提高训练速度,并且可以减少重复处理图片的工作。

write

通过tf.train.Example Protocol Buffer
下面代码源于本人写的一个函数

def create_tfrecord(result, sess):
    """
    create tfrecord files for train,validation,test
    Args:
        result: the dictionary of images
        sess: the session

    """
    path = FLAGS.tfrecord_dir
    if not tf.gfile.Exists(path):
        tf.gfile.MakeDirs(path)
    tf_filename = os.path.join(path,'validation.tfrecord')

    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding()

    writer = tf.python_io.TFRecordWriter(tf_filename)
    #print(len(result['validation']))   
    for index_val,file in enumerate(result['validation']):
        tf.logging.info("write the %d in validation"%index_val)
        name,_ = os.path.splitext(file)
        label= get_labels_array(name + '.txt')
        input_image_array = create_input_tensor(file, sess, jpeg_data_tensor, decoded_image_tensor)
        input_image_string = input_image_array.tostring()
        label_string = label.tostring()
        example = tf.train.Example(features = tf.train.Features(
                feature = {
                        'label': tf.train.Feature(bytes_list = tf.train.BytesList(value = [label_string])),
                        'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [input_image_string]))
                        }))
        writer.write(example.SerializeToString())
    writer.close()

read

读比较麻烦,还要建立线程什么的
注意函数使用了多线程

def read_tfrecord(file_name,batch):
    filename_queue = tf.train.string_input_producer([file_name],)
    reader = tf.TFRecordReader()
    _, serialize_example = reader.read(filename_queue)
    feature = tf.parse_single_example(serialize_example,
                                       features = {
                                               'label': tf.FixedLenFeature([], tf.string),
                                               'image': tf.FixedLenFeature([], tf.string),
                                               })
    labels = tf.decode_raw(feature['label'],tf.int64)
    labels = tf.reshape(labels, [26])
    images = tf.decode_raw(feature['image'],tf.float32)
    images = tf.reshape(images, [1080, 1440, 3])
    #coord = tf.train.Coordinator()
    #threads = tf.train.start_queue_runners(sess = sess,coord = coord)
    #images = tf.squeeze(images)
    images = tf.image.convert_image_dtype(images,tf.int8)
    if batch > 1:
        images, labels = tf.train.shuffle_batch([images,labels],
                                                batch_size=batch,
                                                capacity=500,
                                                num_threads=2,
                                                min_after_dequeue=10)

    return images,labels

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    """
    result = create_image_lists(FLAGS.image_dir,FLAGS.test_dir,30)
    label = get_labels_path(result['testing'])
    """
    #label = get_labels_array(r'G:\GraduateStudy\Smoke Recognition\Newdata\Train\10830004.txt')
    #result = create_image_lists(FLAGS.image_dir, FLAGS.test_dir, 10)
    file_name = r'G:\GraduateStudy\Smoke Recognition\Newdata\Tfrecord\validation.tfrecord'
    image,label = read_tfrecord(file_name,8)
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        #create_tfrecord(result,sess)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord = coord)
        try:
            for i in range(2):
                img,labe = sess.run([image,label])
                #cv2.imwrite('image' + str(i) + '.jpg',img)             
                print(img.shape, labe.shape)
        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            coord.request_stop()

        coord.join(threads)
### TFRecord 的使用指南 #### 什么是 TFRecordTFRecord 是 TensorFlow 提供的一种高效存储和读取大数据集的二进制文件格式。相比其他常见的数据格式(如 CSV 或 JSON),TFRecord 更加紧凑,能够显著提升 I/O 性能,尤其适用于大规模机器学习任务中的数据处理。 --- #### 数据写入 TFRecord 文件 要将原始数据保存为 TFRecord 格式的文件,通常需要以下几个步骤: 1. **定义协议缓冲区消息 (ProtoBuf)** 使用 `tf.train.Example` 来封装每条记录的数据。 2. **创建 TFRecordWriter 对象** 创建一个用于写入 TFRecord 文件的对象。 3. **逐条写入数据** 将每一条数据编码成 Example 协议缓冲区消息,并将其写入到 TFRecord 文件中。 以下是完整的代码示例: ```python import tensorflow as tf def _bytes_feature(value): """Returns a bytes_list from a string / byte.""" if isinstance(value, type(tf.constant(0))): value = value.numpy() # BytesList won't unpack a string from an EagerTensor. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _float_feature(value): """Returns a float_list from a float / double.""" return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) def _int64_feature(value): """Returns an int64_list from a bool / enum / integer.""" return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # 构造单个样例 def create_example(image_string, label): feature = { 'image': _bytes_feature(image_string), 'label': _int64_feature(label) } return tf.train.Example(features=tf.train.Features(feature=feature)) # 写入 TFRecord 文件 def write_tfrecord(output_path, image_strings, labels): with tf.io.TFRecordWriter(output_path) as writer: for img_str, lbl in zip(image_strings, labels): example = create_example(img_str, lbl) writer.write(example.SerializeToString()) # 示例调用 images = [open('path/to/image1.jpg', 'rb').read(), open('path/to/image2.jpg', 'rb').read()] labels = [0, 1] write_tfrecord('output.tfrecord', images, labels) ``` [^1] --- #### 数据读取 TFRecord 文件 为了从 TFRecord 文件中读取数据,可以按照以下流程操作: 1. **解析函数定义** 定义一个解析器来解码 TFRecord 中的每一项数据。 2. **创建 Dataset** 利用 TensorFlow 的 `Dataset` API 加载并解析 TFRecord 文件。 下面是具体的实现代码: ```python def parse_function(proto): # Define the features to be parsed feature_description = { 'image': tf.io.FixedLenFeature([], tf.string), 'label': tf.io.FixedLenFeature([], tf.int64), } # Parse the input `tf.Example` proto using the dictionary above. example = tf.io.parse_single_example(proto, feature_description) # Decode and preprocess data image = tf.image.decode_jpeg(example['image'], channels=3) image = tf.cast(image, tf.float32) / 255.0 # Normalize pixel values label = example['label'] return image, label # Load TFRecord file into dataset dataset = tf.data.TFRecordDataset(['output.tfrecord']) parsed_dataset = dataset.map(parse_function) for image, label in parsed_dataset.take(1): # Print first sample print(f'Image shape: {image.shape}, Label: {label}') ``` [^2] --- #### 应用场景与优势 - **大规模数据处理**:对于包含数百万甚至数十亿样本的大规模数据集,TFRecord 能够有效减少磁盘占用空间并加速 IO 操作。 - **分布式训练支持**:在分布式环境中,TFRecord 可以轻松被多个工作节点共享访问。 - **灵活性高**:除了图像外,还可以用来存储音频、文本等多种类型的数据。 [^3] --- #### 常见问题排查 如果遇到无法正常读取或写入 TFRecord 文件的情况,请检查以下几点: 1. 是否正确设置了特征描述符 (`feature_description`)。 2. 图像或其他复杂数据类型的编码/解码过程是否有误。 3. 确保输入路径和输出路径均无错误。 [^4] ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值