TFRecord Data Generation and Parsing
Configuration file
A JSON-format configuration file, used both for generating and for reading the TFRecord files.
{
    "name": "config",
    "features": [
        {
            "feature_name": "user_id",
            "feature_type": "id",
            "value_type": "string",
            "feature_size": 100,
            "feature_length": 1
        },
        {
            "feature_name": "item_id",
            "feature_type": "id",
            "value_type": "string",
            "feature_size": 1000,
            "feature_length": 1
        },
        {
            "feature_name": "recall_item_id",
            "feature_type": "discrete",
            "value_type": "int64",
            "feature_size": 1000,
            "feature_length": 1
        },
        {
            "feature_name": "is_play",
            "feature_type": "label",
            "value_type": "int64",
            "feature_size": 2,
            "feature_length": 1
        },
        {
            "feature_name": "play_duration",
            "feature_type": "label",
            "value_type": "float32",
            "feature_size": 1,
            "feature_length": 1
        },
        {
            "feature_name": "age",
            "feature_type": "discrete",
            "value_type": "int64",
            "feature_size": 10,
            "feature_length": 1
        },
        {
            "feature_name": "gender",
            "feature_type": "discrete",
            "value_type": "int64",
            "feature_size": 3,
            "feature_length": 1
        }
    ]
}
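To make the fields concrete: in the generator code below, feature_size is used as the exclusive upper bound for the random values (effectively the vocabulary size for id/discrete features), feature_length is the number of values stored per example, and value_type decides both the tf.train.Feature wrapper used when writing and the parsing spec used when reading. A minimal sketch for the "age" entry (illustrative only, not part of the script below):

import numpy as np
import tensorflow as tf

# "age": value_type=int64, feature_size=10, feature_length=1
value = np.random.randint(low=0, high=10, size=1)                        # one random value in [0, 10)
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=value))   # how the value is written
parse_spec = tf.io.FixedLenFeature([1], tf.int64)                        # how the same field is parsed back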
Parse the config file and generate TFRecord files
- From the config file, build feature_list: the "features" array loaded as a list of dicts, one dict per feature.
- From feature_list, build tf.train.Example protos: generate_fake_data(data_dir, feature_list, data_count) writes fake records, where data_dir is the full path including the file name, feature_list is the list of feature dicts, and data_count is the number of records to generate; the files are written with GZIP compression.
- tfrecord_reader_dataset(filenames, feature_list, n_reader=5, batch_size=32, n_parse_threads=5, shuffle_buffer_size=10000, compression_type="GZIP") reads the TFRecord files; filenames is a list of file paths. The resulting dataset is shuffled first, then batched, and finally repeated; see the comments in the code, and the sketch of the yielded batches right after this list.
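Each element yielded by the resulting dataset is a dict mapping feature names to dense tensors of shape [batch_size, feature_length]. Roughly, for the config above (shapes and dtypes shown for illustration only):

# one batch from tfrecord_reader_dataset(...), with batch_size=32 and feature_length=1
# {
#     "user_id":        <tf.Tensor, shape=(32, 1), dtype=string>,
#     "item_id":        <tf.Tensor, shape=(32, 1), dtype=string>,
#     "recall_item_id": <tf.Tensor, shape=(32, 1), dtype=int64>,
#     "is_play":        <tf.Tensor, shape=(32, 1), dtype=int64>,
#     "play_duration":  <tf.Tensor, shape=(32, 1), dtype=float32>,
#     "age":            <tf.Tensor, shape=(32, 1), dtype=int64>,
#     "gender":         <tf.Tensor, shape=(32, 1), dtype=int64>
# }
# (the last batch can be smaller than 32 when the record count is not a multiple of batch_size)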
import json
import os
import numpy as np
import tensorflow as tf
import collections
import collections.abc
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if not isinstance(value, collections.abc.Iterable):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def _float_feature(value):
    """Returns a float_list from a float / double."""
    if not isinstance(value, collections.abc.Iterable):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    if not isinstance(value, collections.abc.Iterable):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
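# For example (illustrative calls, not part of the original script):
#   _int64_feature(3)                -> int64_list { value: 3 }
#   _float_feature(np.array([0.5]))  -> float_list { value: 0.5 }
#   _bytes_feature([b"user_42"])     -> bytes_list { value: "user_42" }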
def generate_fake_data(data_dir, feature_list, data_count):
    with tf.io.TFRecordWriter(data_dir, options="GZIP") as writer:
        for _ in range(int(data_count)):
            feature_dict = collections.OrderedDict()
            for feature in feature_list:
                feature_name = feature["feature_name"]
                feature_type = feature["feature_type"]
                value_type = feature["value_type"]
                feature_size = feature["feature_size"]
                feature_length = feature["feature_length"]
                if value_type == 'string':
                    value = np.random.randint(low=0, high=feature_size, size=feature_length).astype('bytes')
                    feature_dict[feature_name] = _bytes_feature(value)
                elif value_type == 'int64':
                    value = np.random.randint(low=0, high=feature_size, size=feature_length)
                    feature_dict[feature_name] = _int64_feature(value)
                else:
                    value = np.random.random(size=feature_length)
                    feature_dict[feature_name] = _float_feature(value)
                # print(feature_dict[feature_name])
            example_proto = tf.train.Example(features=tf.train.Features(feature=feature_dict)).SerializeToString()
            writer.write(example_proto)
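# Note: the files are written GZIP-compressed (options="GZIP" above), so any reader must pass
# compression_type="GZIP", which is what tfrecord_reader_dataset does below.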
def generate_feature_description(feature_list):
    feature_description = collections.OrderedDict()
    for feature in feature_list:
        feature_name = feature["feature_name"]
        value_type = feature["value_type"]
        feature_length = feature["feature_length"]
        if value_type == 'int64':
            feature_description[feature_name] = tf.io.FixedLenFeature([feature_length], tf.int64)
        elif value_type == 'float32':
            feature_description[feature_name] = tf.io.FixedLenFeature([feature_length], tf.float32)
        elif value_type == 'string':
            feature_description[feature_name] = tf.io.FixedLenFeature([feature_length], tf.string)
    return feature_description
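# For the config above, generate_feature_description yields, for example:
#   {"user_id":       tf.io.FixedLenFeature([1], tf.string),
#    "is_play":       tf.io.FixedLenFeature([1], tf.int64),
#    "play_duration": tf.io.FixedLenFeature([1], tf.float32),
#    ...}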
def _parse_example_function(example_proto, feature_description):
    # Parse the input tf.Example proto using the feature_description dict above.
    return tf.io.parse_single_example(example_proto, feature_description)
# n_reader: number of files read in parallel
# n_parse_threads: number of threads used to parse examples in parallel
# shuffle_buffer_size: size of the shuffle buffer
def tfrecord_reader_dataset(filenames, feature_list, n_reader=5, batch_size=32, n_parse_threads=5,
                            shuffle_buffer_size=10000,
                            compression_type="GZIP"):
    feature_description = generate_feature_description(feature_list)
    dataset = tf.data.Dataset.list_files(filenames)
    # interleave(): read the listed files in parallel and merge their records into one dataset
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(filename, compression_type=compression_type),
        cycle_length=n_reader
    )
    # The usual order during training: shuffle first, then batch, then repeat
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(lambda x: _parse_example_function(x, feature_description),
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    # repeat(): with no argument the dataset repeats indefinitely.
    # Purpose: during training the data is consumed for several epochs; repeat(1) keeps the
    # dataset size unchanged, repeat(2) doubles it, and so on.
    dataset = dataset.repeat(1)
    return dataset
if __name__ == '__main__':
    # Read model_config.json
    with open('D:/pythonProject/study0614/config/model_config.json', 'r') as fb:
        model_config = json.load(fb)
    # # Save model_config1.json
    # with open('D:/pythonProject/study0614/config/model_config.json', 'w') as fb:
    #     json.dump(feature_dict, fb)
    feature_list = model_config['features']
    data_count = 100
    train_count = int(data_count * 0.8)
    val_count = data_count - train_count
    # Write the TFRecord files
    generate_fake_data("D:/pythonProject/study0614/data/train/train.tfrecords", feature_list, train_count)
    generate_fake_data("D:/pythonProject/study0614/data/val/val.tfrecords", feature_list, val_count)
    # Read the TFRecord files back
    path = "D:/pythonProject/study0614/data/train"
    filenames = [os.path.join(path, sub_path) for sub_path in os.listdir(path)]
    # dataset = tf.data.TFRecordDataset(filenames, compression_type='GZIP')
    dataset = tfrecord_reader_dataset(filenames=filenames, feature_list=feature_list, batch_size=32)
    for data in dataset:
        print(data)