上下文学习是指预训练模型在不优化任何参数的情况下,通过提示-prompt examples 来适应新的、多样化的下游任务。虽然LLM已经证明了这种能力,但如何在图上执行上下文学习尚未探索。在本文中,作者开发了预训练多元上下文图系统(PRODIGY,Pretraining Over Diverse In-Context Graph Systems),这是第一个能够在图上进行上下文学习的预训练框架。该框架的关键思想是用一种新颖的提示图表示来制定图的上下文学习,它将提示示例和查询联系起来。
并且,作者提出了一个基于提示图和相应上下文预训练的图神经网络架构。使用PRODIGY,预训练模型可以通过上下文学习直接在未见过的图上执行新的下游任务。在所有设置中,该方法比具有硬编码的对比预训练基线的上下文学习准确性平均高出18%。此外,在上下文学习中,它还比使用有限数据的标准微调平均高出33%。
来自:PRODIGY: Enabling In-context Learning Over Graphs, NeurIPS, 2023
背景概述
In-context learning是语言模型的一种新颖且最有趣的功能。它指的是预训练模型在预测时只需要几个提示示例就可以直接执行新颖和多样化任务,而无需更新模型权重。例如,一个人可以使用自然语言描述新任务,并通过几个提示示例向语言模型演示。然后,语言模型直接执行任务,而不需要任何模型训练或微调。
然而,如何在不同的图学习任务中实现上下文学习,例如识别社交网络中的虚假信息传播者和在线电子商务网站上的产品推荐,仍然没有被探索和挑战。图的上下文学习器应该能够解决新图上的新任务。例如,在训练了亚马逊购买图之后,能够在Spotify上给出音乐产品推荐。
这里的第一个挑战是如何使用统一的任务表示方式来描述图上的节点级、边缘级和图级任务,从而允许模型在不需要重新参数调整的情况下解决不同的任务。换句话说,关键的挑战是:图机器学习任务的自然语言提示的模拟是什么?第二个挑战是如何设计模型架构和预训练目标,使模型能够在统一的任务表示中实现跨不同任务和不同图的上下文学习能力。现有的图预训练方法只是为了学习一个好的图编码器,还需要进行微调以适应不同的任务,而现有的图元学习方法只是为了在同一图内的不同任务之间进行泛化。
在这里,作者提出了解决图上分类任务的这两个挑战的一般方法:(1)prompt graph,一个上下文中的图任务表示;(2)PRODIGY,一个在prompt graph上预训练的上下文学习框架。
作者提出提示图-prompt graph(图1),以提供统一的方式来表示不同的节点级、边缘级和图级任务。提示图首先将我们进行预测的输入节点/边(包括提示示例和查询)上下文化,然后将它们与额外的标签节点连接起来,这样提示示例与查询相互连接。这样一个统一的表示允许我们为同一个模型指定不同的图任务。
- 图1
- A.给定源图 G G G,我们提供提示示例 S S S,该示例由输入头-尾节点及其标签以及查询组成。
- B.对于提示示例和查询中的每个数据点,我们首先通过从源图 G G G中检索上下文来构建Data Graph G D G^D GD。
- C.然后我们创建一个Task Graph来捕获每个数据点和每个标签之间的连接,其中包括每个数据点的数据节点 v x v_x vx和 Y Y Y中的每个标签的标签节点 v y v_y vy。每对数据和标签节点都与对应于其二进制标签的边属性相连。
然后PRODIGY设计模型架构和预训练目标,使用提示图,模型就可以在广泛的任务和图中进行预训练,并且可以继续开箱即用。作者设计了一个图架构,利用图神经网络来学习节点/边缘表示,并设计了一个注意力机制来通过提示图进行通信。此外,作者提出了一组基于提示图的上下文预训练目标。特别地,这包括一种新的自监督预训练任务,邻居匹配----分类节点或边属于哪个邻域。
作者使用PRODIGY框架在引文网络(MAG240M)和知识图谱(Wiki)上进行预训练。然后,证明了这种模型(没有任何再训练)在上下文论文类别分类和知识图补全任务上提供了很强的性能,并且它从未在新图上训练过(arXiv, ConceptNet, FB15K-237, NELL)。
图上下文学习引理
在这项工作中,特别关注具有少量提示的图上节点和边缘分类任务的上下文学习,这是最标准和最重要的图任务形式。在本节中,将介绍图上的具体分类任务。
图上的分类
定义一个图 G = ( V , E , R ) G=(V,E,R) G=(V,E,R),其中 V , E , R V,E,R V,E,R分别代表节点集,边集和关系集。一条边 e = ( u , r , v ) ∈ E e=(u,r,v)\in E e=(u,r,v)∈E由 u ∈ V , r ∈ R , v ∈ V u\in V,r\in R,v\in V u∈V,r∈R,v∈V组成。给定一个类别集合 Y Y Y,一个标准的分类任务是预测每个输入 x ∈ X x\in X x∈X的标签 y ∈ Y y\in Y y∈Y。节点分类任务与之类似,但每个输入是图 G G G中的一个单独节点,即 X = V X=V X=V。类似地,边级别的分类任务是预测由任意一对节点形成的潜在边的最佳标签,输入空间为 X = V × V X=V\times V X=V×V。
在统一的表述中,输入应该由图组成: x i ∈ X x_{i}\in X xi∈X, x i = ( V i