版权声明:本文为博主原创文章,转载请注明出处:https://blog.youkuaiyun.com/ling620/article/details/97789853
文章目录
之前的文章介绍了如何使用Bert的extract_features.py
去提取特征向量,本文对源码进一步的分析。
BERT之提取特征向量 及 bert-as-server的使用
代码位于: bert/extract_features.py
本文主要包含两部分内容:
- 对源码进行分析
- 对源码进行简化
源码分析
1. 输入参数
必选参数,如下:
input_file
:数据存放路径vocab_file
:字典文件的地址bert_config_file
:配置文件init_checkpoint
:模型文件output_file
:输出文件
if __name__ == "__main__":
flags.mark_flag_as_required("input_file")
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("init_checkpoint")
flags.mark_flag_as_required("output_file")
tf.app.run()
其他参数:
在文件最开始部分
layers
:获取的层数索引, 默认值是 [-1, -2, -3, -4] 即表示倒数第一层、倒数第二层、倒数第三层和倒数第四层max_seq_length
:输入序列的最大长度,大于此值则截断,小于此值则填充0batch_size
:预测的batch大小use_tpu
:是否使用TPU
use_one_hot_embeddings
:是否使用独热编码
flags.DEFINE_string("layers", "-1,-2,-3,-4", "")
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
flags.DEFINE_bool(
"use_one_hot_embeddings", False,
"If True, tf.one_hot will be used for embedding lookups, otherwise "
"tf.nn.embedding_lookup will be used. On TPUs, this should be True "
"since it is much faster.")
2. 主流程
主要有以下几个步骤:
-
读取配置文件,构建
BertConfig
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-
获取tokenization的对象
tokenization.py是对输入的句子处理,包含两个主要类:BasickTokenizer, FullTokenizer
BasickTokenizer
会对每个字做分割,会识别英文单词,对于数字会合并,例如:query: 'Jack,请回答1988, UNwant\u00E9d,running' token: ['jack', ',', '请', '回', '答', '1988', ',', 'unwanted', ',', 'running']
FullTokenizer
会对英文字符做n-gram匹配,会将英文单词拆分,例如running会拆分为run、##ing,主要是针对英文。query: 'UNwant\u00E9d,running' token: ["un", "##want", "##ed", ",", "runn", "##ing"]
-
获取RunConfig对象,作为TPUEstimator的输入参数
run_config = tf.contrib.tpu.RunConfig()
-
读取输入文件,处理为
InputExample
类型的列表examples.append(InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
-
将输入文件转化为
InputFeatures
类型的列表
features = convert_examples_to_features()
-
构造model
-
构造
Estimator
对象 -
构造输入input_fn
-
进行预测,获取结果,存入json文件中
results = estimator.predict(input_fn, yield_single_examples=True)
依次将结果读取
for result in estimator.predict(input_fn, yield_single_examples=True): for (i, token) in enumerate(feature.tokens): all_layers = [] for (j, layer_index) in enumerate(layer_indexes): layer_output = result["layer_output_%d" % j] layers = collections.OrderedDict() layers["index"] = layer_index layers["values"] = [ round(float(x), 6) for x in layer_output[i:(i + 1)].flat ] all_layers.append(layers)
上面这部分代码的意思是只取出输入 经过
tokenize
之后的长度 的向量。
即如果max_seq_lenght
设为128, 如果输入的句子为我爱你,则经过tokenize之后的输入tokens=[["CLS"], '我', '爱','你',["SEP"]]
,实际有效长度为5,而其余128-5位均填充0。
上面代码就是只取出有效长度的向量。
layer_output
的维度是(128, 768),layers["values"]
的维度是是(5,768)
这在文章BERT之提取特征向量 及 bert-as-server的使用中提到。
上述几个流程详细内容见下一小节。
源码及注释如下:
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
# 读取配置文件,构建BertConfig
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
# 对句子进行处理,拆分
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
# 获取RunConfig对象,作为TPUEstimator的输入参数
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig(
master=FLAGS.master,
tpu_config=tf.contrib.tpu.TPUConfig(
num_shards=FLAGS.num_tpu_cores,
per_host_input_for_training=is_per_host))
# 读取输入文件,处理为InputExample类型的列表
examples = read_examples(FLAGS.input_file)
# 将输入文件转化为InputFeatures类型的列表
features = convert_examples_to_features(
examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
# 构造id到特征的映射字典
unique_id_to_feature = {
}
for feature in features:
unique_id_to_feature[feature.unique_id] = feature
# 构造model
model_fn = model_fn_builder(
bert_config=bert_config,
init_checkpoint=FLAGS.init_checkpoint,
layer_indexes=layer_indexes,