TensorFlow入门(十-II)tfrecord 可变长度的序列数据

本例代码:https://github.com/yongyehuang/Tensorflow-Tutorial/tree/master/python/the_use_of_tfrecord

关于 tfrecord 的使用,分别介绍 tfrecord 进行三种不同类型数据的处理方法。
- 维度固定的 numpy 矩阵
- 可变长度的 序列 数据
- 图片数据

在 tf1.3 及以后版本中,推出了新的 Dataset API, 之前赶实验还没研究,可能以后都不太会用下面的方式写了。这些代码都是之前写好的,因为注释中都写得比较清楚了,所以直接上代码。

tfrecord_2_sequence_writer.py

# -*- coding:utf-8 -*- 

import tensorflow as tf
import numpy as np
from tqdm import tqdm

'''tfrecord 写入序列数据,每个样本的长度不固定。
和固定 shape 的数据处理方式类似,前者使用 tf.train.Example() 方式,而对于变长序列数据,需要使用 
tf.train.SequenceExample()。 在 tf.train.SequenceExample() 中,又包括了两部分:
context 来放置非序列化部分;
feature_lists 放置变长序列。

refer: 
https://github.com/tensorflow/magenta/blob/master/magenta/common/sequence_example_lib.py
https://github.com/dennybritz/tf-rnn
http://leix.me/2017/01/09/tensorflow-practical-guides/
https://github.com/siavash9000/im2txt_demo/blob/master/im2txt/im2txt/ops/inputs.py
'''

# **1.创建文件
writer1 = tf.python_io.TFRecordWriter('../../data/seq_test1.tfrecord')
writer2 = tf.python_io.TFRecordWriter('../../data/seq_test2.tfrecord')

# 非序列数据
labels = [1, 2, 3, 4, 5, 1, 2, 3, 4]
# 长度不固定的序列
frames = [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],
          [1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]


writer = writer1
for i in tqdm(xrange(len(labels))):  # **2.对于每个样本
    if i == len(labels) / 2:
        writer = writer2
        print('\nThere are %d sample writen into writer1' % i)
    label = labels[i]
    frame = frames[i]
    # 非序列化
    label_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
    # 序列化
    frame_feature = [
        tf.train.Feature(int64_list=tf.train.Int64List(value=[frame_])) for frame_ in frame
    ]

    seq_example = tf.train.SequenceExample(
        # context 来放置非序列化部分
        context=tf.train.Features(feature={
            "label": label_feature
        }),
        # feature_lists 放置变长序列
        feature_lists=tf.train.FeatureLists(feature_list={
            "frame": tf.train.FeatureList(feature=frame_feature),
        })
    )

    serialized = seq_example.SerializeToString()
    writer.write(serialized)  # **4.写入文件中

print('Finished.')
writer1.close()
writer2.close()

tfrecord_2_sequence_reader.p

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值