TFRecord Data Generation and Parsing

Configuration file

A configuration file in JSON format, used both for generating and for reading the TFRecord files:

{
  "name": "config",
  "features": [
    {
      "feature_name": "user_id",
      "feature_type": "id",
      "value_type": "string",
      "feature_size": 100,
      "feature_length": 1
    },
    {
      "feature_name": "item_id",
      "feature_type": "id",
      "value_type": "string",
      "feature_size": 1000,
      "feature_length": 1
    },
    {
      "feature_name": "recall_item_id",
      "feature_type": "discrete",
      "value_type": "int64",
      "feature_size": 1000,
      "feature_length": 1
    },
    {
      "feature_name": "is_play",
      "feature_type": "label",
      "value_type": "int64",
      "feature_size": 2,
      "feature_length": 1
    },
    {
      "feature_name": "play_duration",
      "feature_type": "label",
      "value_type": "float32",
      "feature_size": 1,
      "feature_length": 1
    },
    {
      "feature_name": "age",
      "feature_type": "discrete",
      "value_type": "int64",
      "feature_size": 10,
      "feature_length": 1
    },
    {
      "feature_name": "gender",
      "feature_type": "discrete",
      "value_type": "int64",
      "feature_size": 3,
      "feature_length": 1
    }
  ]
}
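In this schema, value_type selects the underlying TFRecord type (string, int64, or float32), feature_size is the value range / vocabulary size used when generating data, and feature_length is the number of values per record (which also becomes the FixedLenFeature shape at parse time). As a minimal sketch (load_feature_list and the required-key check are illustrative, not part of the code below), the config can be sanity-checked after loading:

import json

REQUIRED_KEYS = {"feature_name", "feature_type", "value_type",
                 "feature_size", "feature_length"}


def load_feature_list(config_path):
    """Loads the JSON config and returns its 'features' list after a key check."""
    with open(config_path, "r") as f:
        feature_list = json.load(f)["features"]
    for feature in feature_list:
        missing = REQUIRED_KEYS - feature.keys()
        if missing:
            raise ValueError(f"feature entry is missing keys: {missing}")
    return feature_list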

Parsing the config file and generating the TFRecord files

  1. Parse the config file into feature_list, a list of feature dicts.
  2. Build a tf.train.Example for each record from feature_list; generate_fake_data(data_dir, feature_list, data_count) writes fake data, where data_dir is the full path plus file name, feature_list is the list of feature dicts, and data_count is the number of records. The file is written with GZIP compression.
  3. tfrecord_reader_dataset(filenames, feature_list, n_reader=5, batch_size=32, n_parse_threads=5, shuffle_buffer_size=10000, compression_type="GZIP") reads the TFRecord files, where filenames is a list, and builds a dataset that is first shuffled, then batched, and finally repeated; see the comments in the code below.
import collections
import json
import os
from collections.abc import Iterable

import numpy as np
import tensorflow as tf


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    # A plain str/bytes is itself iterable but should be wrapped as one value.
    if isinstance(value, (str, bytes)) or not isinstance(value, Iterable):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    if not isinstance(value, Iterable):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    if not isinstance(value, Iterable):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
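# Quick sanity check of the helpers (a standalone illustration with made-up
# values, not part of the pipeline): each call wraps a scalar or an iterable
# into the matching tf.train.Feature proto.
assert _int64_feature(3).int64_list.value == [3]
assert _float_feature(0.5).float_list.value == [0.5]
assert _bytes_feature(b"42").bytes_list.value == [b"42"]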


def generate_fake_data(data_dir, feature_list, data_count):
    """Writes data_count random records described by feature_list to data_dir (GZIP)."""
    with tf.io.TFRecordWriter(data_dir, options="GZIP") as writer:
        for _ in range(int(data_count)):
            feature_dict = collections.OrderedDict()
            for feature in feature_list:
                feature_name = feature["feature_name"]
                value_type = feature["value_type"]
                feature_size = feature["feature_size"]
                feature_length = feature["feature_length"]
                if value_type == 'string':
                    # Random ids in [0, feature_size), rendered as bytes.
                    value = np.random.randint(low=0, high=feature_size, size=feature_length).astype('bytes')
                    feature_dict[feature_name] = _bytes_feature(value)
                elif value_type == 'int64':
                    value = np.random.randint(low=0, high=feature_size, size=feature_length)
                    feature_dict[feature_name] = _int64_feature(value)
                else:
                    # float32 features: uniform random values in [0, 1).
                    value = np.random.random(size=feature_length)
                    feature_dict[feature_name] = _float_feature(value)
            example_proto = tf.train.Example(features=tf.train.Features(feature=feature_dict)).SerializeToString()
            writer.write(example_proto)
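# Optional verification (a minimal sketch; print_first_record is an
# illustrative helper, not part of the original pipeline): read one raw
# record back and decode it to confirm the file was written correctly.
def print_first_record(data_dir):
    raw_dataset = tf.data.TFRecordDataset(data_dir, compression_type="GZIP")
    for raw_record in raw_dataset.take(1):
        print(tf.train.Example.FromString(raw_record.numpy()))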


def generate_feature_description(feature_list):
    """Builds the tf.io.FixedLenFeature parsing spec from feature_list."""
    feature_description = collections.OrderedDict()
    for feature in feature_list:
        feature_name = feature["feature_name"]
        value_type = feature["value_type"]
        feature_length = feature["feature_length"]
        # Use `==` rather than `in`: `value_type in 'int64'` is a substring
        # test and would wrongly match e.g. 'int' or the empty string.
        if value_type == 'int64':
            feature_description[feature_name] = tf.io.FixedLenFeature([feature_length], tf.int64)
        elif value_type == 'float32':
            feature_description[feature_name] = tf.io.FixedLenFeature([feature_length], tf.float32)
        elif value_type == 'string':
            feature_description[feature_name] = tf.io.FixedLenFeature([feature_length], tf.string)
    return feature_description
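# For the config above, generate_feature_description yields entries like the
# following (an illustration of the resulting parsing spec; every feature in
# the sample config has feature_length 1):
#   'user_id'       -> tf.io.FixedLenFeature([1], tf.string)
#   'is_play'       -> tf.io.FixedLenFeature([1], tf.int64)
#   'play_duration' -> tf.io.FixedLenFeature([1], tf.float32)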


def _parse_example_function(example_proto, feature_description):
    # Parse the input tf.Example proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, feature_description)


# n_reader: number of files to read in parallel
# n_parse_threads: number of parallel calls used when parsing records
# shuffle_buffer_size: size of the shuffle buffer
def tfrecord_reader_dataset(filenames, feature_list, n_reader=5, batch_size=32, n_parse_threads=5,
                            shuffle_buffer_size=10000,
                            compression_type="GZIP"):
    feature_description = generate_feature_description(feature_list)
    dataset = tf.data.Dataset.list_files(filenames)
    # interleave(): read the files concurrently and merge them into one dataset.
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(filename, compression_type=compression_type),
        cycle_length=n_reader
    )
    # The usual training order: shuffle first, then batch, then repeat.
    # Note: shuffle() returns a new dataset; reassign the result, otherwise
    # the shuffle silently has no effect.
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(lambda x: _parse_example_function(x, feature_description),
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    # repeat(): with no argument, the dataset repeats indefinitely. Training
    # usually passes over the data more than once and stops via epochs;
    # repeat(1) leaves the size unchanged, repeat(2) doubles it.
    dataset = dataset.repeat(1)

    return dataset
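# A common throughput variant (a sketch under the same schema, not used in
# the main block below): batch the raw serialized records first, then parse
# each batch in one call with tf.io.parse_example instead of calling
# tf.io.parse_single_example per record.
def tfrecord_batch_parse_dataset(filenames, feature_list, batch_size=32,
                                 compression_type="GZIP"):
    feature_description = generate_feature_description(feature_list)
    dataset = tf.data.TFRecordDataset(filenames, compression_type=compression_type)
    dataset = dataset.batch(batch_size)
    return dataset.map(
        lambda batch: tf.io.parse_example(batch, feature_description),
        num_parallel_calls=tf.data.AUTOTUNE)  # tf.data.AUTOTUNE requires TF >= 2.4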


if __name__ == '__main__':
    # Load model_config.json (forward slashes avoid backslash-escape issues on Windows).
    with open('D:/pythonProject/study0614/config/model_config.json', 'r') as fb:
        model_config = json.load(fb)

    feature_list = model_config['features']

    data_count = 100
    train_count = int(data_count * 0.8)
    val_count = data_count - train_count

    # Write the TFRecord files.
    generate_fake_data("D:/pythonProject/study0614/data/train/train.tfrecords", feature_list, train_count)
    generate_fake_data("D:/pythonProject/study0614/data/val/val.tfrecords", feature_list, val_count)

    # Read the TFRecord files back.
    path = "D:/pythonProject/study0614/data/train"
    filenames = [os.path.join(path, sub_path) for sub_path in os.listdir(path)]
    dataset = tfrecord_reader_dataset(filenames=filenames, feature_list=feature_list, batch_size=32)

    for data in dataset:
        print(data)
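To feed the parsed dataset into a Keras model, the feature dict it yields can be split into (features, label) pairs. A minimal sketch, assuming is_play serves as the training label (split_label is an illustrative helper, not part of the code above):

def split_label(parsed, label_name="is_play"):
    # Separate the label column from the remaining feature columns.
    features = {name: tensor for name, tensor in parsed.items() if name != label_name}
    return features, parsed[label_name]

train_dataset = dataset.map(split_label)
# train_dataset now yields (feature_dict, label) pairs suitable for model.fit.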
