知识表示学习 (一) —— Point-Wise Space之2
文章目录
上一篇博客 知识表示学习 (一) —— Point-Wise Space之1中以TransE以及其变体和衍生为主的 点向量平移模型,这篇博客中对其中的一部分模型进行简单的实现。
由于在进行评估时,需要处理mean rank,这个参数需要对所有的实体数量进行筛查和排序。因此我在实验中使用的数据集是Freebase知识图谱的子集FB15k,包含14951个实体,1345种关系,总共483142个三元组。FB15k中所包含的实体数是常用的数据集中最少的。
一、数据处理
数据处理在pre_process_data.py中完成,其主要作用:是将公开数据集中的三元组进行整理和编号,每一个entity或者relation对应一个唯一的index;然后将
- 数据转化:由于公开数据集中的三元组均是字符串的形式,因此需要将数据集中的三元组进行整理和编号,每一个entity或者relation对应一个唯一的id,把对应的字符串映射到相应的数字index上;
- 负样本生成:在TransE等模型中,损失函数的计算中采用了负采样的方式构建负样例,因此需要对每一个三元组构建一个负样例
- DataSet生成:为后面的DataLoader做准备,提供每一个batch所得到的数据内容
- DataLoader生成:对之前处理好的数据进行分批,为后续的训练做准备
在load_data函数中完成数据转化的工作,将训练集的三元组转化为对应的id,并保存映射关系,为后面测试集的转化提供参照。并记录了entity、relation的数量等统计信息。
def load_data(self):
file_pathname = self.filepath + self.filename
train_df = pd.read_csv(filepath_or_buffer=file_pathname,
sep='\t',
header=None,
names=['head', 'relation', 'tail'],
keep_default_na=False,
encoding='utf-8')
train_df = train_df.applymap(lambda x: x.strip()) # 每一个单元格进行切分
# 统计每一类的数量 dict存储
head_count = Counter(train_df['head'])
tail_count = Counter(train_df['tail'])
relation_count = Counter(train_df['relation'])
# 记录entity和relation的key
entity_list = list((head_count + tail_count).keys())
relation_list = list(relation_count.keys())
# 构造entity和relation的(key,index)结构
entity_dict = dict([(word, idx) for idx, word in enumerate(entity_list)])
relation_dict = dict([(word, idx) for idx, word in enumerate(relation_list)])
# 将df中的key转化为index
train_df['head'] = train_df['head'].apply(lambda cell_key: entity_dict[cell_key])
train_df['tail'] = train_df['tail'].apply(lambda cell_key: entity_dict[cell_key])
train_df['relation'] = train_df['relation'].apply(lambda cell_key: relation_dict[cell_key])
return train_df.values, entity_dict, relation_dict
在generate_neg函数中生成负样本,对每一个三元组,随机对其头或尾结点进行修改,并且判断其合理性。
def generate_neg(self):
neg_candidates, i = [], 0
neg_data = []
population = list(range(self.entity_num))
for idx, triple in enumerate(self.pos_data):
while True:
if i == len(neg_candidates):
i = 0
neg_candidates = random.choices(population=population, k=int(1e4))
neg, i = neg_candidates[i], i + 1
if random.randint(0, 1) == 0:
# replace head
if neg not in self.related_dict[triple[2]]:
neg_data.append([neg, triple[1], triple[2]])
break