BERT 提取特征 (extract_features.py) 源码分析 代码简化

版权声明:本文为博主原创文章,转载请注明出处:https://blog.youkuaiyun.com/ling620/article/details/97789853

之前的文章介绍了如何使用Bert的extract_features.py去提取特征向量,本文对源码进一步的分析。
BERT之提取特征向量 及 bert-as-server的使用

代码位于: bert/extract_features.py

本文主要包含两部分内容:

  1. 对源码进行分析
  2. 对源码进行简化

源码分析

1. 输入参数

必选参数,如下:

  • input_file:数据存放路径
  • vocab_file:字典文件的地址
  • bert_config_file:配置文件
  • init_checkpoint:模型文件
  • output_file:输出文件
if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("init_checkpoint")
    flags.mark_flag_as_required("output_file")
    tf.app.run()

其他参数:

在文件最开始部分

  • layers:获取的层数索引, 默认值是 [-1, -2, -3, -4] 即表示倒数第一层、倒数第二层、倒数第三层和倒数第四层
  • max_seq_length:输入序列的最大长度,大于此值则截断,小于此值则填充0
  • batch_size:预测的batch大小
  • use_tpu:是否使用TPU
  • use_one_hot_embeddings:是否使用独热编码
flags.DEFINE_string("layers", "-1,-2,-3,-4", "")
flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")
    
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_bool(
    "use_one_hot_embeddings", False,
    "If True, tf.one_hot will be used for embedding lookups, otherwise "
    "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
    "since it is much faster.")

2. 主流程

主要有以下几个步骤:

  1. 读取配置文件,构建BertConfig
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  2. 获取tokenization的对象
    tokenization.py是对输入的句子处理,包含两个主要类:BasickTokenizer, FullTokenizer

    BasickTokenizer会对每个字做分割,会识别英文单词,对于数字会合并,例如:

    query: 'Jack,请回答1988, UNwant\u00E9d,running'
    token: ['jack', ',', '请', '回', '答', '1988', ',', 'unwanted', ',', 'running']
    

    FullTokenizer会对英文字符做n-gram匹配,会将英文单词拆分,例如running会拆分为run、##ing,主要是针对英文。

    query: 'UNwant\u00E9d,running'
    token: ["un", "##want", "##ed", ",", "runn", "##ing"]
    
  3. 获取RunConfig对象,作为TPUEstimator的输入参数
    run_config = tf.contrib.tpu.RunConfig()

  4. 读取输入文件,处理为InputExample类型的列表

    examples.append(InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
    
  5. 将输入文件转化为InputFeatures类型的列表
    features = convert_examples_to_features()

  6. 构造model

  7. 构造Estimator对象

  8. 构造输入input_fn

  9. 进行预测,获取结果,存入json文件中

    results =  estimator.predict(input_fn, yield_single_examples=True)
    

    依次将结果读取

    for result in estimator.predict(input_fn, yield_single_examples=True):
    	for (i, token) in enumerate(feature.tokens):
            all_layers = []
            for (j, layer_index) in enumerate(layer_indexes):
                layer_output = result["layer_output_%d" % j]
                layers = collections.OrderedDict()
                layers["index"] = layer_index
                layers["values"] = [
                        round(float(x), 6) for x in layer_output[i:(i + 1)].flat
                    ]
                all_layers.append(layers)
    

上面这部分代码的意思是只取出输入 经过tokenize之后的长度 的向量。
即如果max_seq_lenght设为128, 如果输入的句子为我爱你,则经过tokenize之后的输入tokens=[["CLS"], '我', '爱','你',["SEP"]],实际有效长度为5,而其余128-5位均填充0。
上面代码就是只取出有效长度的向量。
layer_output的维度是(128, 768), layers["values"]的维度是是(5,768)

这在文章BERT之提取特征向量 及 bert-as-server的使用中提到。

上述几个流程详细内容见下一小节。

源码及注释如下:

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
	# 读取配置文件,构建BertConfig
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
	
	# 对句子进行处理,拆分
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
	
	# 获取RunConfig对象,作为TPUEstimator的输入参数
    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        master=FLAGS.master,
        tpu_config=tf.contrib.tpu.TPUConfig(
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))
	# 读取输入文件,处理为InputExample类型的列表
    examples = read_examples(FLAGS.input_file)
	# 将输入文件转化为InputFeatures类型的列表
    features = convert_examples_to_features(
        examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
	# 构造id到特征的映射字典
    unique_id_to_feature = {
   }
    for feature in features:
        unique_id_to_feature[feature.unique_id] = feature
	# 构造model
    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        layer_indexes=layer_indexes,
       
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值