✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。
我是Srlua小谢,在这里我会分享我的知识和经验。🎥
希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮
记得先点赞👍后阅读哦~ 👏👏
📘📚 所属专栏:传知代码论文复现
欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙
目录
本文所有资源均可在该地址处获取。
1. 论文概述
预训练语言模型(PLMs)在概念提取中往往依赖于文本中的共现关联,而不是实际的因果关系,导致提取结果存在偏差和低精确度。为了解决这个问题,本文提出了通过知识引导提示来干预PLM的概念提取过程。这个提示利用现有知识图谱中的知识,帮助PLM聚焦于相关概念,减少对虚假共现的依赖,从而提高提取精度。实验结果表明,这种方法有效减少了偏差,显著提升了概念提取的性能。
论文链接:链接
2. 论文方法
本文提出了一个名为KPCE的概念提取(CE)框架,并讨论了如何通过提示来缓解概念偏差。KPCE框架包括两个主要模块:提示构造器和概念提取器。下面详细介绍这两个模块。
本文提出了一个名为KPCE的概念提取(CE)框架,并讨论了如何通过提示来缓解概念偏差。KPCE框架包括两个主要模块:提示构造器和概念提取器。下面详细介绍这两个模块。
2.1 提示构造器
提示构造器使用知识图谱(KGs)中的实体主题作为知识引导提示,旨在减少概念偏差。以下是提示构造的过程:
2.1.1 获取典型概念集:
从CN-DBpedia中随机抽取一百万个实体及其现有概念,选择拥有最多实体的前100个概念构成典型概念集。
2.1.2 聚类典型概念:
使用自适应K-means算法和谱聚类方法将这些典型概念聚成多个组,每组对应一个主题。通过计算Silhouette Coefficient(SC)和Calinski Harabaz Index(CHI)确定最佳聚类数目为17。
2.1.3 构建训练数据:
随机抽取40,000个实体及其摘要文本和现有概念,根据概念聚类结果为每个实体分配主题。
2.1.4 训练主题分类器:
采用Transformer编码器和两层感知器(MLP)组成的主题分类器来预测输入文本的主题提示,分类准确率超过97.8%。
2.2 概念提取器
概念提取器是一个基于BERT的模型,结合了构造的提示,通过指针网络提取多层次概念。以下是提取过程:
2.2.1 输入构造:
将提示和输入文本序列拼接,并通过多头自注意力机制处理。
2.2.2 指针网络:
使用指针网络预测每个token作为概念起始位置和结束位置的概率。通过softmax操作,得到每个token的起始和结束位置的概率向量。
2.2.3 置信度评分:
根据起始和结束位置的概率计算候选概念的置信度评分,并保留置信度大于阈值的概念。
2.2.4 训练损失:
采用交叉熵损失函数,通过Adam优化器进行模型训练,最终优化整体损失函数。
3. 实验部分
3.1 数据集
3.1.1CN-DBpedia:
从最新版本的中文知识图谱CN-DBpedia(Xu等,2017)和维基百科中获取样本池。每个样本由一个实体及其概念和摘要文本组成。然后,从样本池中随机抽取500个样本作为测试集,并按照9:1的比例将其余样本划分为训练集和验证集。
3.1.2 Probase:
从Probase和维基百科中获取英语样本池,共包含50,000个样本。训练集、验证集和测试集的构建方式与中文数据集相同。
3.2 实验步骤
3.2.1 配置环境
Requirements
torch == 1.4.0
transformers == 4.2.0
3.2.1 cd prompt_code
进入到此文件夹
3.2.2 python main.py
进行训练
4.核心代码
class Prompt_MRCModel(nn.Module):
def __init__(self, model_name, tokenizer, max_topic_len=35, max_seq_len = 256):
super(Prompt_MRCModel, self).__init__()
self.model = BertForQuestionAnswering.from_pretrained(model_name)
self.topic_model = BertForQuestionAnswering.from_pretrained(model_name)
self.tokenizer = tokenizer
self.max_topic_len = max_topic_len
def generate_default_inputs(self, batch, topic_embed, device):
input_ids = batch['input_ids']
bz = batch['input_ids'].shape[0]
block_flag = 1
raw_embeds = self.model.bert.embeddings.word_embeddings(input_ids.to(device)).squeeze(1)
topic_embeds = self.topic_model.bert.embeddings.word_embeddings(topic_embed.to(device)).squeeze(1)
input_embeds = torch.cat((topic_embeds,raw_embeds),1)
inputs = {'inputs_embeds': raw_embeds.to(device), 'attention_mask': batch['attention_mask'].squeeze(1).to(device)}
inputs['token_type_ids'] = batch['token_type_ids'].squeeze(1).to(device)
return inputs
def forward(self, inputs_embeds=None, attention_mask=None, token_type_ids=None, labels=None):
return self.model(inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
token_type_ids=token_type_ids)
def mlm_train_step(self, batch, topic_embed, start_positions, end_positions, device):
inputs_prompt = self.generate_default_inputs(batch, topic_embed, device)
bert_out = self.model(**inputs_prompt, start_positions=start_positions, end_positions=end_positions)
return bert_out
希望对你有帮助!加油!
若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!