1.tfrecord 格式说明
1.1 tf.train.Example生成tfrecord格式
- 1.tf.train.Features : {“key”: tf.train.Feature}
- tf.train.Feature: tf.train.BytesList / FloatList / Int64List
(字符串类型 / 浮点数类型 / 整数类型)
- 2.利用Features建立Example
- 3.把Example存入文件中
- 新建文件夹
- 使用TFRecordWriter打开文件,写入Example
- 4.使用tf.data.TFRecordDataset读取tfrecord文件并解析
# Raw values for one record: book titles are encoded to UTF-8 bytes first.
favorite_books = [
    title.encode("utf-8") for title in ["machine learning", "cc150"]
]

# Wrap each group of values in the tf.train list type matching its kind
# (bytes / float / int64) and print the resulting message.
favorite_books_bytelist = tf.train.BytesList(value=favorite_books)
print(favorite_books_bytelist)

hours_floatlist = tf.train.FloatList(value=[15.5, 9.5, 7.0, 8.0])
print(hours_floatlist)

age_int64list = tf.train.Int64List(value=[42])
print(age_int64list)
# Map each feature name to a tf.train.Feature wrapping its typed value list.
features = tf.train.Features(
    feature={
        "favorite_books": tf.train.Feature(bytes_list=favorite_books_bytelist),
        "hours": tf.train.Feature(float_list=hours_floatlist),
        "age": tf.train.Feature(int64_list=age_int64list),
    }
)
print(features)

# Build the Example from the features, then serialize it so it can be
# written to a file.
example = tf.train.Example(features=features)
print(example)

serialized_example = example.SerializeToString()
print(serialized_example)
# Write the serialized Example three times into tfrecord_basic/test.tfrecords.
output_dir = 'tfrecord_basic'
# makedirs(exist_ok=True) creates the directory only when it is missing,
# avoiding a separate exists() check that can race with another process.
os.makedirs(output_dir, exist_ok=True)

filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
# Open the writer in a with-block so it is released when the block exits.
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for i in range(3):
        writer.write(serialized_example)
# Read the file back through TFRecordDataset and print every element as it
# is yielded by the dataset.
dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    print(serialized_example_tensor)
expected_features = {
"favorite_books": tf.io.VarLenFeature(dtype = tf.string),
"hours": tf.io.VarLenFeature(dtype = tf.float32