BERT 两种输入数据处理方式

TFRecord格式

TFRecord内部使用了“Protocol Buffer” 二进制数据编码 方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。

写文件

使用TFRecord生成器以及样本Example模块。

writer = tf.python_io.TFRecordWriter(output_file)
tf_example = tf.train.Example(
            features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()

上述writer是TFrecord生成器,通过writer.write(tf_example.SerializeToString())来生成tfrecord文件。
tf_example.SerializeToString()是将Example中的map压缩为二进制文件,更好的节省空间。

Example协议如下:

message Example {
   
  Features features = 1;
};

message Features {
   
  map<string, Feature> feature = 1;
};

tf.train.Features(feature = None)这里的feature是以 字典 的形式存在。
key:要保存数据的名字,value:要保存的数据,格式必须符合tf.train.Feature实例要求。

读取
  1. tfrecord文件创建TFRecordDataset
  2. 通过解析器tf.parse_single_example将的example解析出来,即序列化后的tf.train.Example,输入参数是
    name_to_features = {
         
            "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
            "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
            "is_real_example": tf.FixedLenFeature([], tf.int64),
        }
        
    d = tf.data.TFRecordDataset(input_file)
    
    example = tf.parse_single_example(record, name_to_features)
    

第一种:TFRecord类型

该种方法在训练模型文件中使用run_classifier.py

将数据文件,保存为TFRecord类型的文件,使用时再从TFRecord文件中读取/解码出来。

  1. 将输入文本处理为InputExample类的形式
    调用:

    predict_examples = get_test_examples(test_file)
    

    函数实现:

    def get_test_examples(data_file):
        """See base class."""
        # file_path = os.path.join(data_dir, 'test_1.csv')
        examples = []
        with open(data_file, encoding='utf-8') as f:
            reader = f.readlines()
        for i, line in enumerate(reader):
            guid = "train-%d" % (i)
            split_line = line.strip().split(",")
            text_a = tokenization.convert_to_unicode(split_line[1])
            text_b = None
            # text_b = tokenization.convert_to_unicode(split_line[2])
            # label = tokenization.convert_to_unicode(line[2])
            label = str(split_line[0])
            examples
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值