TensorFlow官方文档标准TensorFlow格式

本文介绍如何使用TensorFlow的TFRecords格式存储和读取数据,实现高效的数据流水线处理。包括数据转换、预处理、批处理等关键步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

标准TensorFlow格式

另一种保存记录的方法可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriterclass写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。

艾伯特(http://www.aibbt.com/)国内第一家人工智能门户

从TFRecords文件中读取数据, 可以使用tf.TFRecordReadertf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocol buffer)解析为张量。 MNIST的例子就使用了convert_to_records 所构建的数据。 请参看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py, 您也可以将这个例子跟fully_connected_feed的版本加以比较。

预处理

你可以对输入的样本进行任意的预处理, 这些预处理不依赖于训练参数, 你可以在tensorflow/models/image/cifar10/cifar10.py找到数据归一化, 提取随机数据片,增加噪声或失真等等预处理的例子。

批处理

在数据输入管线的末端, 我们需要有另一个队列来执行输入样本的训练,评价和推理。因此我们使用tf.train.shuffle_batch 函数来对队列中的样本进行乱序处理

示例:

def read_my_file_format(filename_queue):
  reader = tf.SomeReader()
  key, record_string = reader.read(filename_queue)
  example, label = tf.some_decoder(record_string)
  processed_example = some_processing(example)
  return processed_example, label

def input_pipeline(filenames, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer(
      filenames, num_epochs=num_epochs, shuffle=True)
  example, label = read_my_file_format(filename_queue)
  # min_after_dequeue defines how big a buffer we will randomly sample
  #   from -- bigger means better shuffling but slower start up and more
  #   memory used.
  # capacity must be larger than min_after_dequeue and the amount larger
  #   determines the maximum we will prefetch.  Recommendation:
  #   min_after_dequeue + (num_threads + a small safety margin) * batch_size
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch

如果你需要对不同文件中的样子有更强的乱序和并行处理,可以使用tf.train.shuffle_batch_join 函数. 示例:

def read_my_file_format(filename_queue):
  # Same as above

def input_pipeline(filenames, batch_size, read_threads, num_epochs=None):
  filename_queue = tf.train.string_input_producer(
      filenames, num_epochs=num_epochs, shuffle=True)
  example_list = [read_my_file_format(filename_queue)
                  for _ in range(read_threads)]
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch_join(
      example_list, batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch

在这个例子中, 你虽然只使用了一个文件名队列, 但是TensorFlow依然能保证多个文件阅读器从同一次迭代(epoch)的不同文件中读取数据,知道这次迭代的所有文件都被开始读取为止。(通常来说一个线程来对文件名队列进行填充的效率是足够的)

另一种替代方案是: 使用tf.train.shuffle_batch 函数,设置num_threads的值大于1。 这种方案可以保证同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件。这种方案的优点是:

  • 避免了两个不同的线程从同一个文件中读取同一个样本。
  • 避免了过多的磁盘搜索操作。

你一共需要多少个读取线程呢? 函数tf.train.shuffle_batch*为TensorFlow图提供了获取文件名队列中的元素个数之和的方法。 如果你有足够多的读取线程, 文件名队列中的元素个数之和应该一直是一个略高于0的数。具体可以参考TensorBoard:可视化学习.

创建线程并使用QueueRunner对象来预取

简单来说:使用上面列出的许多tf.train函数添加QueueRunner到你的数据流图中。在你运行任何训练步骤之前,需要调用tf.train.start_queue_runners函数,否则数据流图将一直挂起。tf.train.start_queue_runners 这个函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。这种情况下最好配合使用一个tf.train.Coordinator,这样可以在发生错误的情况下正确地关闭这些线程。如果你对训练迭代数做了限制,那么需要使用一个训练迭代数计数器,并且需要被初始化。推荐的代码模板如下:

# Create the graph, etc.
init_op = tf.initialize_all_variables()

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (like the epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)

except tf.errors.OutOfRangeError:
    print 'Done training -- epoch limit reached'
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值