2021SC@SDUSC
在上一周中,我已经对DrFact模型算法的剩余部分做了一定的阐述,接下来需要做的是对最后一部分DrFact模型算法的阐述完成。
三、add_sup_facts.py脚本
在这次的源代码分析中,我将会把重心放在add_sup_facts.py脚本上。接下里就直接进入分析环节。
3.1 调用模块
首先,add_sup_facts.py文件调用了这些python模块。其中absl和tqdm已经是老熟人了,最后一个则是我们在上一周的分析过程中认真研究过的index_corpus模块。因为已经有了上一周的铺垫,我们将会在这一周的分析过程中少碰到许多困难。
import json
from absl import app
from absl import flags
from absl import logging
from tqdm import tqdm
from language.labs.drfact import index_corpus
3.2 flags参数与全局常量
老规矩,在介绍完模块之后,接下来开始介绍add_sup_facts.py文件中的flags参数。不过这次有所不同在于,除了介绍下述的六个flags参数,我还将会介绍将会在本次源文件中使用到的全局常量。
首先是这次定义的flags参数,这次六个参数中有两个参数的缺省值为非None——表示指向数据集文件路径的split的缺省值为"train",以及表示指向数据集文件路径的整型参数max_num_facts的缺省值为10。不过在我看来,这两个参数的用法显然并如同描述的那样。
此外,是这次将用到的全局常量COMMON_CONCEPTS,这是一个纯字符串列表,其中的内容分别是"make","cause","factor","person","need","use","people","part","system"。
FLAGS = flags.FLAGS
# flags.DEFINE_string("concept_vocab_file", "drfact_data/knowledge_corpus/gkb_best.vocab.txt", "Path to dataset file.")
flags.DEFINE_string("linked_qas_file", None, "Path to dataset file.")
flags.DEFINE_string("drfact_format_gkb_file", None, "Path to gkb corpus.")
flags.DEFINE_string("ret_result_file", None, "Path to dataset file.")
flags.DEFINE_string("output_file", None, "Path to dataset file.")
flags.DEFINE_string("split", "train", "Path to dataset file.")
flags.DEFINE_integer("max_num_facts", 10, "Path to dataset file.")
COMMON_CONCEPTS = ["make", "cause", "factor",
"person", "need", "use", "people", "part", "system"]
3.3 主函数main
在本源文件中,只有一个函数,即是主函数main(_)。接下来对主函数进行分析。
在主函数中,首先通过index_corpus中的load_concept_vocab函数将参数concept_vocab_file指向的概念词汇库读取出来,存放于conpet2id中。再打开参数drfact_format_gkb_file指向的gkb格式化后的文件,将其中内容按行拆分后遍历每行,将其内容解析并存于instance,再将instance的id值作键,当前遍历序号和它本身作为值,遍历构造出字典gkb_id_to_id和facts_dict。
def main(_):
"""Main funciton."""
logging.set_verbosity(logging.INFO)
concept2id = index_corpus.load_concept_vocab(FLAGS.concept_vocab_file)
with open(FLAGS.drfact_format_gkb_file) as f:
logging.info("Reading %s..." % f.name)
gkb_id_to_id = {}
facts_dict = {}
cur_fact_ind = 0
for line in f.read().split("\n"):
if line:
instance = json.loads(line)
gkb_id_to_id[instance["id"]] = cur_fact_ind
facts_dict[instance["id"]] = instance
cur_fact_ind += 1
接下来对参数ret_result_file和linked_qas_file指向的文件做拆解整合存储操作,存于ret_data和data。检测二者长度是否一致,不一致则报错。(这里的步骤和add_init_facts.py文件中的基本一样)
with open(FLAGS.ret_result_file) as f:
logging.info("Reading %s..." % f.name)
ret_data = [json.loads(line) for line in f.read().split("\n") if line]
logging.info("Reading QAS(-formatted) data...")
with open(FLAGS.linked_qas_file) as f:
jsonlines = f.read().split("\n")
data = [json.loads(jsonline) for jsonline in jsonlines if jsonline]
assert len(ret_data) == len(data)
然后对data进行遍历,拆解出其中的每个对象ins和它的索引ind,使用与add_init_facts.py文件的相同方法得到所有相关事实字典all_ret_facts,问题概念集合question_concepts和回答概念集合answer_concepts。
new_data = []
num_covered = 0
num_2nd = 0
for ind, ins in tqdm(
enumerate(data),
desc=FLAGS.linked_qas_file, total=len(data)):
all_ret_facts = ret_data[ind]["results"]["all_ret_facts"]
ins["sup_facts"] = []
question_concepts = set([c["kb_id"]
for c in ins["entities"]]) - set(COMMON_CONCEPTS)
# TODO: decomp?
answer_concepts = set([c["kb_id"] for c in ins["all_answer_concepts"]]
) - set(COMMON_CONCEPTS) - question_concepts
之后准备进行第一轮内循环,首先定义is_covered=F,original_rank=0,空集合concept_set和空列表question_only_facts,answer_only_facts。然后执行首轮内循环,对生成的all_ret_facts字典继续遍历,拆分出fid和s,将fid对应在facts_dict字典中的值的mentions值进行遍历,将其中的kb_id值存于fact_concepts集合中,然后将其与问题概念集合和回答概念集合取交,获得包含问题概念的集合contain_question和包含回答的集合contain_answer。
此时进行判断,如果contain_question非空而contain_answer为空,则将当前事实文本id映射,s,当前事实文本打包成为元组装入question_only_facts列表;如果contain_answer非空而contain_question为空则装入answer_only_facts列表;如果都不为空,则装入当前样例的sup_facts值对应的列表中。
如此循环多次之后,如果装满,即超出max_num_facts,则终止循环。
is_covered = False
concept_set = set()
original_rank = 0
question_only_facts = []
answer_only_facts = []
# first round
for fid, s in all_ret_facts:
original_rank += 1
fact = facts_dict[fid]
fact_concepts = set([m["kb_id"] for m in fact["mentions"]])
contain_question = fact_concepts & question_concepts
contain_answer = fact_concepts & answer_concepts
if contain_question and not contain_answer:
question_only_facts.append((gkb_id_to_id[fid], s, fact["context"]))
continue
if contain_answer and not contain_question:
answer_only_facts.append((gkb_id_to_id[fid], s, fact["context"]))
continue
if contain_answer and contain_question:
ins["sup_facts"].append((gkb_id_to_id[fid], s, fact["context"]))
if len(ins["sup_facts"]) >= FLAGS.max_num_facts:
break
concept_set.update(fact_concepts)
结束上述循环之后,开始下一个条件判定:即如果上述循环中获得的结果不超过10个的话,那么对all_ret_facts进行雷同上述的第二轮循环。
if len(answer_only_facts) + len(ins["sup_facts"]) <=10:
# second round
is_covered = False
concept_set = set()
original_rank = 0
question_only_facts = []
answer_only_facts = []
num_2nd += 1
answer_concepts = set([c["kb_id"] for c in ins["all_answer_concepts_decomp"]]
) - set(COMMON_CONCEPTS) - question_concepts
for fid, s in all_ret_facts:
original_rank += 1
fact = facts_dict[fid]
fact_concepts = set([m["kb_id"] for m in fact["mentions"]])
contain_question = fact_concepts & question_concepts
contain_answer = fact_concepts & answer_concepts
if contain_question and not contain_answer:
question_only_facts.append((gkb_id_to_id[fid], s, fact["context"]))
continue
if contain_answer and not contain_question:
answer_only_facts.append((gkb_id_to_id[fid], s, fact["context"]))
continue
if contain_answer and contain_question:
ins["sup_facts"].append((gkb_id_to_id[fid], s, fact["context"]))
if len(ins["sup_facts"]) >= FLAGS.max_num_facts:
break
if len(ins["sup_facts"]) > 1:
is_covered = True
concept_set.update(fact_concepts)
完成上述第二轮循环之后,进行一次条件判断:即如果根据上述二轮循环之后获得了不止一个支持的事实(当前样例的sup_facts值列表长度大于1),那么将is_covered设为T并为当前样例设置新的键值对——answer_only_facts值为刚刚循环获得的answer_only_facts列表,question_only_facts值为刚刚获得的question_only_facts列表,sup_facts_source的值即为参数ret_result_file代表的路径。
之后再根据is_covered判定知道本轮找到结果,num_covered+1,为当前样例添加num_mentioned_concepts值为concept_set长度的键值对。更新new_data。本轮循环结束。
if len(ins["sup_facts"]) > 1:
is_covered = True
ins["answer_only_facts"] = answer_only_facts[:FLAGS.max_num_facts]
ins["question_only_facts"] = question_only_facts[:FLAGS.max_num_facts]
ins["sup_facts_source"] = FLAGS.ret_result_file
# del ins["all_ret_facts"]
if is_covered:
num_covered += 1
ins["num_mentioned_concepts"] = len(concept_set)
new_data.append(ins)
在得到所有结果,之后,将结果new_data遍历存放写入参数output_file指向的输出文档中。
with open(FLAGS.output_file, "w") as f:
logging.info("num_2nd: %d", num_2nd)
logging.info("num_covered: %d", num_covered)
logging.info("len(new_data): %d", len(new_data))
logging.info("Coverage:%.2f", num_covered/len(new_data))
logging.info("Writing to %s", f.name)
f.write("\n".join([json.dumps(i) for i in new_data])+"\n")
logging.info("Done.")
综上,这就是对add_sup_facts.py文件的全部源代码分析了。