2021SC@SDUSC
前几周的源代码分析中,我们已经了解了drfact是如何对语料库进行预处理的,也了解了drfact模型算法的前几步都做了什么事情。但这一周的源代码分析我不会对具体的源代码进行分析,原因在于我在本周进行源代码分析,并回顾了过往的源代码分析内容时,注意到drfact模型对其他模型进行了一定程度的借鉴,这一点尤其体现在其核心源代码之中——调用了其他模型中编写好的函数。因此,体现在源代码之中的内容也就不再是仅仅只要关注到drfact这一个项目包即可,而是需要对整个OpenCSR项目的其他源代码也进行审视。
某种意义上,我对OpenCSR这个项目源代码的核心产生了一定程度错误的评判,这也意味着我需要花上更多的力气对这个项目进行更深层次的理解。因此,在我对源代码的整体结构以及各模型之间的相互勾连达成宏观层面的理解之前,我暂且不会对drfact模型中具体实现算法以及其它进行数据处理的代码中具体细微的详细描述进行进一步的分析与探究,而是转头分析各个模型之间可以互通的模块以及函数特点。这些可以互通的模块以及函数特点会一并体现在这篇博客之中。
一、DrFact模型与其对DrKit模型函数的调用情况
要分析诸多模型之间的关联,尤其是这些关联对我们的主要研究对象——DrFact模型究竟产生了怎样的影响,有着怎样的作用,首先便是从DrFact模型本身入手。而从DrFact模型本身的源代码中不难看出,与DrFact模型有着直接关联的便是DrKit模型,在DrFact模型的许多源文件中,均可见来自于DrKit模型的python源文件模块被导入,源代码中也调用了许多来源于DrKit模型的函数。其中包含有以下的源文件。
1.1 convert_dpr_index.py源文件
在convert_dpr_index.py源文件中,导入了来自于DrKit模型的search_utils.py,并且调用了其中的write_to_checkpoint()函数。这个源文件正是在上一周的代码分析中分析到的源文件,正是在这次源码分析的过程中,让我意识到了我对于这些源文件之间的关联并不是特别明晰,为此我中断了之前的单篇源文件分析的工作,转头去研究这些源文件之间的关联程度。
具体调用代码如下所示:
from language.labs.drkit import search_utils
with tf.device("/cpu:0"):
search_utils.write_to_checkpoint("fact_db_emb", fact_emb, tf.float32, output_index_path)
1.2 fact2fact_index.py源文件
在fact2fact_index.py源文件中,同样也导入了来自DrKit模型的search_utils.py,并且调用了其中的write_ragged_to_checkpoint()函数。从名字上来看,这个函数的功能应该是比较类似于在之前的convert_dpr_index.py源文件调用到的write_to_checkpoint()函数。
具体调用代码如下所示:
from language.labs.drkit import search_utils
search_utils.write_ragged_to_checkpoint(
"fact2fact", sp_fact2fact,
os.path.join(FLAGS.fact2fact_index_dir, "fact2fact.npz"))
1.3 index_corpus.py源文件
在index_corpus.py源文件中,同样导入了来自DrKit模型的search_utils.py,并且多次调用了其中的write_to_checkpoint()函数和write_ragged_to_checkpoint()函数。这两个函数都是我们在上述两个源文件中使用到的函数,在这个源文件中竟然又被用到了,可想而知,这两个函数有着举足轻重的重要意义,值得我们后续对其进行深挖。
此外,在该源文件中,还导入了bert_utils_v2.py和hotpotqa.index.py,分别调用了BERTPredictor()函数与get_sub_paras()函数。
具体调用代码如下所示:
from language.labs.drkit import bert_utils_v2
from language.labs.drkit import search_utils
from language.labs.drkit.hotpotqa import index as index_util
# 以下是导入search_utils后调用的实例
search_utils.write_to_checkpoint(
"coref", np.array([m[0] for m in mentions], dtype=np.int32), tf.int32,
os.path.join(FLAGS.index_result_path, "coref.npz"))
search_utils.write_ragged_to_checkpoint(
"ent2ment", sp_entity2mention,
os.path.join(FLAGS.index_result_path, "ent2ment.npz"))
search_utils.write_ragged_to_checkpoint(
"ent2fact", sp_entity2fact,
os.path.join(FLAGS.index_result_path,
"ent2fact_%d.npz" % FLAGS.max_facts_per_entity))
search_utils.write_ragged_to_checkpoint(
"fact2ent", sp_fact2entity,
os.path.join(FLAGS.index_result_path, "fact_coref.npz"))
search_utils.write_to_checkpoint(
"entity_ids", entity_ids, tf.int32,
os.path.join(FLAGS.index_result_path, "entity_ids"))
search_utils.write_to_checkpoint(
"entity_mask", entity_mask, tf.float32,
os.path.join(FLAGS.index_result_path, "entity_mask"))
with tf.device("/cpu:0"):
search_utils.write_to_checkpoint(
"db_emb_%d" % ns, mention_emb, tf.float32,
os.path.join(FLAGS.index_result_path,
"%s_mention_feats_%d" % (embed_prefix, ns)))
tf_db = search_utils.load_database(
db_emb_str + "_%d" % i, var_to_shape_map[db_emb_str + "_%d" % i],
ckpt_path)
search_utils.write_to_checkpoint(
db_emb_str, np_db, tf.float32,
os.path.join(FLAGS.index_result_path, embed_feats_str))
with tf.device("/cpu:0"):
search_utils.write_to_checkpoint(
"fact_db_emb_%d" % ns, fact_emb, tf.float32,
os.path.join(FLAGS.index_result_path,
"%s_fact_feats_%d" % (embed_prefix, ns)))
# 以下是导入bert_utils_v2后调用的实例
bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)
# 以下是导入hotpotqa.index后调用的实例
sub_para_objs = index_util.get_sub_paras(orig_para, tokenizer,
FLAGS.max_seq_length,
FLAGS.doc_stride, total_sub_paras)
1.4 input_fns.py源文件
在input_fns.py源文件中,则导入了来自DrKit模型的input_fns.py,并且调用了其中的get_tokens_and_mask()函数。从两个源文件的名字完全一致可以看出,这两个源文件的功能意义应该是相近的。而我们根据当前DrFact模型的input_fns.py源文件的注释可以得知,这个源文件的主要功能在于将不同的数据集处理成一种通用格式的类,因此可以在此做出一定的猜测,即DrKit模型中的input_fns.py源文件的功能也基本相同,是一个特殊数据集转通用的类的定义文件。
具体调用代码如下所示:
from language.labs.drkit import input_fns as input_utils
(qry_input_ids, qry_input_mask,
qry_tokens) = input_utils.get_tokens_and_mask(example.question_text,
tokenizer, max_query_length)
1.5 model_fns.py源文件
在model_fns.py源文件中,同样导入了来自DrKit模型的search_utils.py,并且多次调用了其中的create_mips_searcher()函数。
此外,在该源文件中,还导入了来自DrKit模型的model_fns.py,调用了许多函数,如下所示:
- entity_emb()函数,
- sparse_ragged_mul()函数,
- ensure_values_in_mat()函数,
- convert_search_to_vector()函数,
- sp_sp_matmul()函数,
- rescore_sparse()函数,
- aggregate_sparse_indices()函数,
- shared_qry_encoder_v2()函数,
- layer_qry_encoder()函数,
- remove_from_sparse()函数,
- batch_multiply()函数,
- compute_loss_from_sptensors()函数。
类似于input_fns.py文件,在model_fns.py文件中调用的来自于DrKit模型的模块也具有相同的名字model_fns。因此,我们在此同样做出假设,DrKit模型中的model_fns.py源文件的功能也与DrFact模型中的基本相同,是一个实现不同多跳变量的模型函数集合的定义文件。
具体调用代码如下所示:
from language.labs.drkit import model_fns as model_utils
from language.labs.drkit import search_utils
# 以下是导入search_utils调用的函数实例
with tf.device("/cpu:0"):
tf_fact_db, fact_mips_search_fn = search_utils.create_mips_searcher(
fact_mips_config.ckpt_var_name,
# [fact_mips_config.num_facts, fact_mips_config.emb_size],
fact_mips_config.ckpt_path,
fact_mips_config.num_neighbors,
local_var_name="scam_init_barrier_fact")
# 以下是导入model_fns调用的部分函数实例
batch_entity_emb = model_utils.entity_emb(entity_ind, entity_word_ids,
entity_word_masks, word_emb_table,
word_weights)
sp_mention_vec = model_utils.sparse_ragged_mul(
batch_entities,
ent2ment_ind,
ent2ment_val,
batch_size,
mips_config.num_mentions,
qa_config.sparse_reduce_fn, # max or sum
threshold=qa_config.entity_score_threshold,
fix_values_to_one=qa_config.fix_sparse_to_one)
if is_training and qa_config.ensure_answer_dense:
ret_mention_ids = model_utils.ensure_values_in_mat(
ret_mention_ids, ensure_index, tf.int32)
dense_mention_vec = model_utils.convert_search_to_vector(
ret_mention_scs, ret_mention_ids, tf.cast(batch_size, tf.int32),
mips_config.num_neighbors, mips_config.num_mentions)
if qa_config.sparse_strategy == "dense_first":
ret_mention_vec = model_utils.sp_sp_matmul(dense_mention_vec,
sp_mention_vec)
with tf.device("/cpu:0"):
ret_mention_vec = model_utils.rescore_sparse(sp_mention_vec, tf_db,
scam_qrys)
uniq_fact_ids, uniq_fact_scs = model_utils.aggregate_sparse_indices(
sp_fact_vec.indices, sp_fact_vec.values, sp_fact_vec.dense_shape,
"max")
qry_seq_emb, word_emb_table, qry_hidden_size = model_utils.shared_qry_encoder_v2(
qry_input_ids, qry_input_mask, is_training, use_one_hot_embeddings,
bert_config, qa_config)
综上所述,在本文中我们主要介绍了DrFact模型中的各个模块是如何与DrKit模型产生勾连的,并描述了它们是如何调用这些函数,这些函数调用的频率高低,以及这些函数功能的粗浅定义。在下一篇中,我将对这些被调用的函数进行更加详尽的分析。