transH算法实现知识图谱补全实验

本文介绍TransH算法在知识图谱补全任务中的应用,并详细解释其实现细节。实验使用PyTorch工具,基于FB15k数据集进行训练,通过引入关系映射超平面解决多对多关系问题。最后评估模型效果,给出Meanrank与hit@10值。

transH算法实现知识图谱补全实验

1. 目的

使用transH算法进行知识图谱补全实验

2. 数据集

本次实验采用freebase数据集的FB15k, 该数据集共有entity2id.txt,relation2id.txt,test.txt,train.txt和valid.txt共五个文件。实验过程中,训练时主要采用entity2id.txt,relation2id.txt,train.txt三个文件,测试集使用test.txt。

3. 方法

本次实验主要采用transH模型进行知识图谱补全实验,使用pytorch工具辅助算法实现。

  1. transH算法原理

TransH 模型在 TransE 的基础上为每个关系多学一个映射向量, 具体思路是将三元组中的关系(relation),抽象成一个向量空间中的超平面(Hyperplane),每次都是将头结点或者尾节点映射到这个超平面上,再通过超平面上的平移向量计算头尾节点的差值。

这样做一定程度上缓解了transE模型不能很好地处理一对多,多对一等关系属性的问题.

  1. 具体算法实现
  2. 将头节点h和尾节点t映射到超平面上,计算三元组的差值

Wr是超平面的法向量,dr是超平面上的平移向量

  1. 将头节点h和尾节点t映射到超平面上,计算三元组的差值
  2. 计算损失函数

其中[ x ]+ 看做 max(0, x),y为margin值用于区分正例与负例。

  1. 损失函数通过随机梯度下降法进行训练

  2. 代码实现过程:(见代码)

数据集训练:

  1. 数据集加载,得到实体集,关系集和三元组集

  2. 数据预处理,将实体集,关系集和三元组集初始化为向量,计算每个关系中每个头结点平均对应的尾节点数,以及每个尾结点平均对应的头节点数

  3. 初始化transH所需参数,包括向量维度,以及损失函数所需的各种参数

  4. 使用torch.Tensor()方法初始化实体向量,关系向量dr和关系超平面法向量Wr

  5. 开始分批训练,分成100批数据

  6. 计算tph/(tph+hpt),决定负例随机替换掉头节点还是尾节点,由此获得负例集

  7. 计算损失函数,使用随机梯度下降法调整向量来最小化损失函数

  8. 更新实体集和关系集中的向量

  9. 从第5步开始,重复训练100次,不断降低损失函数的值

  10. 得到归一化向量的实体集和关系集

4. 指标

测试验证:

  1. 求Mean rank值

​ 将每个testing_triple的尾节点用实体集中每一个实体代替,计算f函数,将得到的结果升序排列,将所有testing triple的排名做平均得到Mean rank

  1. 求hit@10值

按照上述进行f函数值排列,计算测试集中testing triple正确答案排在序列的前十的个数,除以总个数得到hit@10的值

5. 结论

Trans模型是是知识图谱补全算法中经典的算法,TransE模型最为经典但是无法很好解决一对多,多对一的问题,TransH和TransE算法类似,但增加了关系映射超平面,一定程度上缓解了不能很好地处理多映射属性关系的问题。

代码:

transH_torch.py

import torch
import torch.optim as optim
import torch.nn.functional as F

import codecs
import numpy as np
import copy
import time
import random

entity2id = {
   
   }
relation2id = {
   
   }
relation_tph = {
   
   }   #关系每个头结点平均对应的尾节点数
relation_hpt = {
   
   }   #关系每个尾结点平均对应的头节点数

'''
数据加载
entity2id: {entity1:id1,entity2:id2}
relation2id: {relation1:id1,relation2:id2}
'''

def data_loader(file):
    print("load file...")
    file1 = file + "train.txt"
    file2 = file + "entity2id.txt"
    file3 = file + "relation2id.txt"

    with open(file2, 'r') as f1, open(file3, 'r') as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()
        for line in lines1:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            entity2id[line[0]] = line[1]

        for line in lines2:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            relation2id[line[0]] = line[1]

    entity_set = set()      #训练集中的所有实体
    relation_set = set()    #训练集中的所有关系
    triple_list = []        #训练集中的所有三元组
    relation_head = {
   
   }      #训练集中的关系的所有头部和头部数量,格式:{r_:{head1:count1,head2:count2}}
    relation_tail = {
   
   }      #训练集中的关系的所有尾部和尾部数量,格式:{r_:{tail1:count1,tail2:count2}}

    with codecs.open(file1, 'r') as f:
        content = f.readlines()
        for line in content:
            triple = line.strip().split("\t")
            if len(triple) != 3:
                continue

            h_ = entity2id[triple[0]]
            t_ = entity2id[triple[1]]
            r_ = relation2id[triple[2]]

            triple_list.append([h_, t_, r_])

            entity_set.add(h_)
            entity_set.add(t_)

            relation_set.add(r_)
            if r_ in relation_head:
                if h_ in relation_head[r_]:
                    relation_head[r_][h_] += 1
                else:
                    relation_head[r_][h_] = 1
            else:
                relation_head[r_] = {
   
   }
                relation_head[r_][h_] = 1

            if r_ in relation_tail:
                if t_ in relation_tail[r_]:
                    relation_tail[r_][t_] += 1
                else:
                    relation_tail[r_][t_] = 1
            else:
                relation_tail[r_] = {
   
   }
                relation_tail[r_][t_] = 1
#计算关系中个头结点平均对应的尾节点数
    for r_ in relation_head:
        sum1, sum2 = 0, 0
        for head in relation_head[r_]:
            sum1 += 1
            sum2 += relation_head[r_][head]
        tph = sum2/sum1
        relation_tph[r_] = tph
#计算关系每个尾结点平均对应的头节点数
    for r_ in relation_tail:
        sum1, sum2 = 0, 0
        for tail in relation_tail[r_]:
            sum1 += 1
            sum2 += relation_tail[r_][tail]
        hpt = sum2/sum1
        relation_hpt[r_] = hpt

    print("Complete load. entity : %d , relation : %d , triple : %d" % (
        len(entity_set), len(relation_set), len(triple_list)))

    return entity_set, relation_set, triple_list


class TransH:
    def __init__(self, entity_set, relation_set, triple_list, embedding_dim
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值