"""Write two small TFRecord files, then read them back and multiply fields.

Uses the TensorFlow 1.x queue-based input pipeline
(``tf.train.string_input_producer`` + ``tf.TFRecordReader``).
"""
import tensorflow as tf

# Output file names paired with the ``v1`` values written into each file.
# The reader below consumes exactly these names, so writer and reader can
# never disagree about which files exist.
_RECORD_RANGES = (
    ("file0.tfrecords", range(0, 10)),
    ("file1.tfrecords", range(10, 20)),
)


def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` holding an Int64List."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _write_records(filename, indices):
    """Write one Example per index with features v1=index and v2=index+1."""
    # The context manager closes (and flushes) the writer even if
    # serialization fails part-way through.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index in indices:
            example = tf.train.Example(features=tf.train.Features(feature={
                'v1': _int64_feature(index),
                'v2': _int64_feature(index + 1),
            }))
            writer.write(example.SerializeToString())


def _build_pipeline(filenames, num_epochs):
    """Return int32 tensors (v1, v2, v1 * v2) read from ``filenames``."""
    filename_queue = tf.train.string_input_producer(
        filenames, shuffle=True, num_epochs=num_epochs)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'v1': tf.FixedLenFeature([], tf.int64),
            'v2': tf.FixedLenFeature([], tf.int64),
        })
    v1 = tf.cast(features['v1'], tf.int32)
    v2 = tf.cast(features['v2'], tf.int32)
    return v1, v2, tf.multiply(v1, v2)


def main():
    """Write the TFRecord files, then print v1, v2 and v1*v2 for every record."""
    for filename, indices in _RECORD_RANGES:
        _write_records(filename, indices)

    filenames = [filename for filename, _ in _RECORD_RANGES]
    v1, v2, v_mul = _build_pipeline(filenames, num_epochs=2)

    init_op = tf.global_variables_initializer()
    # string_input_producer(num_epochs=...) tracks epochs in a local
    # variable, so the local initializer must run before reading.
    local_init_op = tf.local_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        sess.run(local_init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                value1, value2, mul_result = sess.run([v1, v2, v_mul])
                # The values are int32, so format them as integers.
                print("%d\t%d\t%d" % (value1, value2, mul_result))
        except tf.errors.OutOfRangeError:
            # Raised once the producer has emitted every file num_epochs times.
            print('Done training -- epoch limit reached')
        finally:
            # Stop and join the queue-runner threads before the session closes.
            coord.request_stop()
            coord.join(threads)


if __name__ == "__main__":
    main()
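# tensorflow 读取TFRecord格式数据并进行计算代码
# (TensorFlow: code for reading TFRecord-format data and computing on it.)
# 最新推荐文章于 2019-11-22 16:14:24 发布
# (Latest recommended article published 2019-11-22 16:14:24.)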