读写tfrecord文件


在训练模型的时候,一般会将数据预处理转换成tfrecord格式,负责I/O操作的CPU和进行数值运行计算的GPU相互之间可以并行工作,保证GPU高的利用率。以下是对特征是定长和变长读写tfrecord方式。

1 写tfrecor方式

一般会将数据按照模型训练所需要的方式对输入x和label标签进行tfrecord格式转换。主要有定长和变长两种方式,根据实际应用和需求决定。若输入的每个example的input 是变长的,比如每个example的输入特征索引个数不是相同的,则可以按照变长的方式转换,否则按照定长的方式转换。

1.1 变长特征转tfrecord

import collections
writer = tf.python_io.TFRecordWriter('data.tfrecord')
def toTF(data):
	''' 
	data是一个dict,假设其中key有input_x和input_y,
	对应的value是索引list
	'''
	features = collections.OrderedDict()
	input_x = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x"])))
	features["input_x"] = tf.train.FeatureList(feature=input_x)
	input_y = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y"])))
	features["input_y"] = tf.train.FeatureList(feature=input_y)
	sequence_example = tf.train.SequenceExample(feature_lists=tf.train.FeatureLists(feature_list=features))
	writer.write(sequence_example.SerializeToString())

以下方式实现与上面方式等价:

def toTF_v2(data)
	sequence_example = tf.train.SequenceExample()
	input_x = sequence_example.feature_lists.feature_list["input_x"]
	input_y = sequence_example.feature_lists.feature_list["input_y"]
	for x in data["input_x"]:
		input_x.feature.add().int64_list.value.append(x)
	for y in data["input_y"]:
		input_y.feature.add().int64_list.value.append(y)
	writer.write(sequence_example.SerializeToString())

1.2 定长特征转tfrecord

def toTF_fixed(data):
	features = collections.OrderedDict()
	features["input_x"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x")))
	features["input_y"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y")))
	example = tf.train.Example(features=tf.train.Features(feature=features))
	write.write(example.SerializeToString())

2 读tfrecord

和写trrecord一样,也分定长和变长方式,如果写tfrecord是定长方式,则读tfrecord也需要定长方式。读写方式需要保持一致。

2.1 变长方式读tfrecord

需要定义特征的格式,如果是变长则定义tf.FixedLenSequenceFeature类型特征

import tensorflow as tf
features = {
			'input_x': tf.FixedLenSequenceFeature([], tf.int64)
			'input_y': tf.FixedLenSequenceFeature([], tf.int64)
			}

2.2 定长方式读tfrecord

定长方式用tf.FixedLenFeature类型

seq_length = 10
features = {
		'input_x': tf.FixedLenFeature([seq_length], tf.int64).
		'input_y': tf.FixedLenFeature([seq_length], tf.int64
		}

3 从hdfs中读取批量tfrecord文件

当训练数据量级很大时,一般转tfrecord试用分布式方式处理数据,提高效率。训练模型的时候,可以从远程,例如hdfs上读取批量文件。以下是从hdfs上批量读取tfrecord文件。

def input_fn_builder(file_path, num_cpu_threads, seq_length, num_class, batch_size):
	'''
	其中file_path是hdfs上文件的路径,比如data目录下的所有tfrecord文件
	读的是定长的feature
	'''
	features = {
			'input_x': tf.FixedLenFeature([seq_length], tf.int64),
			'input_y': tf.FixedLenFeature([seq_length], tf.int64),
	}
	def _decode_record(record):
		# 一个样本解析
		example = tf.io.parse_single_example(record, features)
		multi_label_enc = tf.one_hot(indices=example["input_y"], depth=num_class)
		example["input_y"] = tf.reduce_sum(multi_label_enc, axis=0)
		return example

	def _decode_batch_record(batch_record):
		# 一个batch样本解析
		batch_example = tf.io.parse_example(serialized=batch_record, features=features)
		multi_label_enc = tf.one_hot(indices=batch_example["input_y"], depth=num_class)
		batch_example["input_y"] = tf.reduce_sum(multi_label_enc, axis=1)
		return batch_example

	def input_fn(params):
		# d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
		d = tf.data.Dataset.list_files(file_path)
		d = d.repeat()
		d = d.shuffle(buffer_size=100)
		d = d.appley(
			tf.contrib.data.parallel_interleave(
				tf.data.TFRecordDataset,
				sloppy=True,
				cycle_length=num_cpu_threads))
		d = d.apply(
			tf.contrib.data.map_and_batch(
					lambda record: _decode_record(record),
					batch_size = batch_size,
					num_parallel_batches=num_cpu_threads,
					drop_remainder=True))
		return d

	def input_fn_v2(params):
		d = tf.data.Dataset.list_files(file_path)
		d = d.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=num_cpu_threads, block_length=128).\
		batch(batch_size).map(_decode_batch_record, num_parallel_calls=tf.data.experimental.AUTOTRUE).prefetch(
			tf.data.experimental.AUTOTUNE).repeat()
		return d
	return input_fn
	#return input_fn_v2

上面提供了两个解析函数,input_fn和input_fn_v2两种方式都可行,配合estimator方式训练,可以使得CPU读取数据与GPU训练数据之间可以并行处理,减少等待时间,提高GPU的利用率,加快训练速度。解析tfrecord文件时,有下面四种方式,根据自己具体的数据格式进行选择:

  • 解析单个样本,定长特征:tf.io.parse_single_example()
  • 解析单个样本,变长特征:tf.io.parse_single_sequence_example()
  • 解析批量样本,定长特征:tf.io.parse_example()
  • 解析批量样本,定长特征:tf.io.parse_sequence_example()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值