import tensorflow as tf
def _int64_feature(value):
    """Wrap an integer, or a sequence of integers, in a ``tf.train.Feature``.

    Args:
        value: A single integer, or a list/tuple of integers.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds the given value(s).
    """
    # isinstance (rather than an exact type comparison) also recognises
    # tuples and list subclasses, which must not be wrapped a second time.
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(value)))


def _bytes_feature(value):
    """Wrap a byte string, or a sequence of byte strings, in a ``tf.train.Feature``.

    Args:
        value: A single byte string, or a list/tuple of byte strings.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the given value(s).
    """
    # Same scalar-vs-sequence normalisation as in _int64_feature.
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=list(value)))
# Feature map reused for every record; its 'video_id' entry is overwritten
# before each Example is built.
feature_config = {}

# One variable-length list of video ids per record to be written.
video_id_lists = [[3, 2], [3, 2, 1, 2, 3], [3, 2, 9]]

# Open the TFRecords file. The context manager closes the writer even if
# building or writing one of the records raises.
with tf.python_io.TFRecordWriter("./tfrecord.test") as writer:
    for video_ids in video_id_lists:
        # feature
        feature_config['video_id'] = _int64_feature(video_ids)
        example = tf.train.Example(features=tf.train.Features(feature=feature_config))
        # Serialize to string and write on the file
        writer.write(example.SerializeToString())
# Read the records back through the queue-based input pipeline and print the
# dense feature-column representation of 'video_id' for three batches.
with tf.Session() as sess:
    # Create a list of filenames and pass it to a queue
    filename_queue = tf.train.string_input_producer(["./tfrecord.test"], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Categorical column over ids in [0, 10), wrapped in an indicator column
    # so that input_layer can turn it into a dense tensor.
    video_id = tf.feature_column.categorical_column_with_identity(
        key='video_id', num_buckets=10, default_value=0)
    video_id2 = tf.feature_column.indicator_column(video_id)
    columns = [video_id2]

    # Batch one serialized Example at a time, then parse it with the spec
    # derived from the feature columns.
    video_out = tf.train.shuffle_batch(
        [serialized_example], batch_size=1, capacity=30, num_threads=1,
        min_after_dequeue=3)
    features = tf.parse_example(
        video_out, features=tf.feature_column.make_parse_example_spec(columns))

    # indicator
    heihei = tf.feature_column.input_layer(features, columns)

    # Initialize all global and local variables. The local initializer is
    # required because num_epochs=1 creates a local epoch counter.
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)

    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        for batch_index in range(3):
            x = sess.run([heihei])
            # Parenthesised form prints identically on Python 2 and 3.
            print(x)
    finally:
        # Stop the threads, even if evaluating a batch failed
        coord.request_stop()
        # Wait for threads to stop
        coord.join(threads)
    # No explicit sess.close(): leaving the with-block closes the session.
TensorFlow TFRecord save and read demo
最新推荐文章于 2022-04-17 16:13:19 发布
本文介绍如何使用TensorFlow中的TFRecord进行数据的序列化存储与读取,并演示了如何定义特征列以用于后续的数据处理流程。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
TensorFlow-v2.15
TensorFlow
TensorFlow 是由 Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。它提供了一个灵活的平台,用于构建和训练各种机器学习模型。

1158

被折叠的 条评论
为什么被折叠?



