TFRecord格式
TFRecord
内部使用了“Protocol Buffer”
二进制数据编码 方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord
文件,来提高处理效率。
写文件
使用TFRecord生成器以及样本Example模块。
writer = tf.python_io.TFRecordWriter(output_file)
tf_example = tf.train.Example(
features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
上述writer
是TFrecord生成器,通过writer.write(tf_example.SerializeToString())
来生成tfrecord文件。
tf_example.SerializeToString()
是将Example中的map压缩为二进制文件,更好的节省空间。
Example协议如下:
message Example {
Features features = 1;
};
message Features {
map<string, Feature> feature = 1;
};
tf.train.Features(feature = None)
这里的feature是以 字典 的形式存在。
key:要保存数据的名字,value:要保存的数据,格式必须符合tf.train.Feature实例要求。
读取
- 从
tfrecord
文件创建TFRecordDataset - 通过解析器tf.parse_single_example将的example解析出来,即序列化后的
tf.train.Example
,输入参数是name_to_features = { "input_ids": tf.FixedLenFeature([seq_length], tf.int64), "input_mask": tf.FixedLenFeature([seq_length], tf.int64), "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), "label_ids": tf.FixedLenFeature([], tf.int64), "is_real_example": tf.FixedLenFeature([], tf.int64), } d = tf.data.TFRecordDataset(input_file) example = tf.parse_single_example(record, name_to_features)
第一种:TFRecord类型
该种方法在训练模型文件中使用run_classifier.py
将数据文件,保存为TFRecord类型的文件,使用时再从TFRecord文件中读取/解码出来。
-
将输入文本处理为
InputExample
类的形式
调用:predict_examples = get_test_examples(test_file)
函数实现:
def get_test_examples(data_file): """See base class.""" # file_path = os.path.join(data_dir, 'test_1.csv') examples = [] with open(data_file, encoding='utf-8') as f: reader = f.readlines() for i, line in enumerate(reader): guid = "train-%d" % (i) split_line = line.strip().split(",") text_a = tokenization.convert_to_unicode(split_line[1]) text_b = None # text_b = tokenization.convert_to_unicode(split_line[2]) # label = tokenization.convert_to_unicode(line[2]) label = str(split_line[0]) examples