import tensorflow as tf
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
if __name__=="__main__":
filename0="file0.tfrecords"
writer=tf.python_io.TFRecordWriter(filename0)
for index in range(10):
example=tf.train.Example(features=tf.train.Features(feature={
'v1':_int64_feature(index),
'v2':_int64_feature(index+1)}))
writer.write(example.SerializeToString())
write.close()
filename1="file.tfrecords"
writer=tf.python_io.TFRecordWriter(filename1)
for index in range(10,20):
example=tf.train.Example(features=tf.train.Features(feature={
'v1': _int64_feature(index),
'v2':_int64_feature(index+1)}))
writer.write(example.SerializeToString())
writer.close()
filename_queue =tf.train.string_input_producer(["file0.tfrecords",
"file1.tfrecords"],shuffle=True,num_epochs=2)
reader=tf.TFRecordReader()
_,serialized_example=reader.read(filename_queue)
features=tf.parse_single_example(
serialized_example,
features={'v1':tf.FixedLenFeature([],tf.int64),'v2':tf.FixedLenFeature([],tf.int64)}
)
v1=tf.cast(features['v1'],tf.int32)
v2=tf.cast(features['v2'],tf.int32)
v_mul=tf.multiply(v1,v2)
init_op=tf.global_variables_initializer()
local_init_op=tf.local_variables_initializer()
sess=tf.Session()
sess.run(init_op)
sess.run(local_init_op)
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
try:
while not coord.should_stop():
value1,value2,mul_result=sess.run([v1,v2,v_mul])
print("%f\t%f\t%f"%(value1,value2,mul_result))
except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
finally:
coord.request_stop()
coord.join(threads)
sess.close()
tensorflow 读取TFRecord格式数据并进行计算代码
最新推荐文章于 2019-11-22 16:14:24 发布
本文详细介绍使用TensorFlow进行TFRecords文件的创建、读取及数据解析过程,演示了如何将整数特征写入TFRecords,并从文件中读取这些数据进行乘法运算。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
TensorFlow-v2.15
TensorFlow
TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型
561

被折叠的 条评论
为什么被折叠?



