2021SC@SDUSC
前篇背景介绍:前几周的源代码分析中,我们已经了解了drfact是如何对语料库进行预处理的,也了解了drfact模型算法的前几步都做了什么事情。但这一周的源代码分析我不会对具体的源代码进行分析,原因在于我在本周进行源代码分析,并回顾了过往的源代码分析内容时,注意到drfact模型对其他模型进行了一定程度的借鉴,这一点尤其体现在其核心源代码之中——调用了其他模型中编写好的函数。因此,体现在源代码之中的内容也就不再是仅仅只要关注到drfact这一个项目包即可,而是需要对整个OpenCSR项目的其他源代码也进行审视。
承接上文,在上一周的源代码分析中,我主要描述了DrFact模型与DrKit模型之间的勾连借鉴。这一点主要体现在了DrFact模型的各个模块对于DrKit模型中某些模块的调用,其中尤其以input_fns.py和model.fns.py源文件为典型,在这两个模型中,都出现了名字如上述一般的源文件,由此可见二者在功能定位上应该有相近之处。此外,在对DrFact模型中模块的分析中我们可以认识到,有些来自于DrKit模型的函数在DrFact模型中被反复用到,且是跨越多个模块被调用,这让我们可以确定这些函数具有重要的分析价值。因此,在本周的这篇源代码分析中,我将主要阐述这些DrKit模型中的函数定义,希望通过更加细微的刻画来展现这些函数的具体用途。
二、DrKit模型的已定义函数
2.1 BERT与bert_utils_v2.py模块
要提及DrKit模型中对DrFact模型中产生贡献的模块,首先要谈到BERT。BERT在前面的源代码分析中有提及过,它的全称为Bidirectional Encoder Representation from Transformers,是一个预训练的语言表征模型。它强调了不再像以往一样采用传统的单向语言模型或者把两个单向语言模型进行浅层拼接的方法进行预训练,而是采用新的masked language model(MLM),以致能生成深度的双向语言表征。
有了上文做一些铺垫,我们可以转头去看有关DrKit模型中定义的有关BERT的模块。有上一篇的源码分析内容可以很容易知道,与BERT最直接也是唯一相关的一个模块是bert_utils_v2.py这个模块,在这个模块中定义有一个BERTPredictor类。根据其注释易知,这个类就是一个封装了BERT模型的编码器。
BERTPredictor类的构造函数如下所示,可以很清晰的知道,这个类共有七个成员变量,其中的序列最大长度max_seq_length,查询最大长度max_qry_length,实体最大长度max_seq_length,词向量大小emb_dim和批大小(这里的batch即是神经网络模型中的那个)batch_size都是通过源代码中的flags定义的参数传进来的,而分词器tokenizer则是通过构造函数的形参传进来的。
此外,在构造函数的形参中还有一个缺省值为None的形参estimator,当它是None时,即没有一个明确的estimator时,构造函数则会根据flags定义的bert_config_file参数指向的文件为bert模型配置参数,而运行配置run_config则采用tensorflow中的estimator配置,至于QA系统的配置则使用DrKit中写好的run_dualencoder_qa模块中的QAConfig()函数,传入flags定义好的参数,从而形成这里的QA系统配置(由于继续深挖下去就没完没了了,也不属于核心代码范畴,因此对于run_dualencoder_qa模块就不细究了,不过望文生义来说,这个模块应该是用来运行QA系统的双工编码器的)。model_fn亦同理,采用了DrKit中定义好的模块内容进行初始化,最后根据上述model_fn和配置内容,通过tensorflow中TUREstimator构造函数来构建estimator,再将这个estimator作为FastPredictor的参数,构建出一个BERTPredictor类的最后一个成员变量fast_predictor,由此,这样一个BERT模型就构建完成了。
class BERTPredictor:
"""Wrapper around a BERT model to encode text."""
def __init__(self, tokenizer, init_checkpoint, estimator=None):
"""Setup BERT model."""
self.max_seq_length = FLAGS.max_seq_length
self.max_qry_length = FLAGS.max_query_length
self.max_ent_length = FLAGS.max_entity_length
self.batch_size = FLAGS.predict_batch_size
self.tokenizer = tokenizer
if estimator is None:
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
run_config = tf.estimator.tpu.RunConfig()
qa_config = run_dualencoder_qa.QAConfig(
doc_layers_to_use=FLAGS.doc_layers_to_use,
doc_aggregation_fn=FLAGS.doc_aggregation_fn,
qry_layers_to_use=FLAGS.qry_layers_to_use,
qry_aggregation_fn=FLAGS.qry_aggregation_fn,
projection_dim=FLAGS.projection_dim,
normalize_emb=FLAGS.normalize_emb,
share_bert=True,
exclude_scopes=None)
mod