【论文精读】RELIEF: Reinforcement Learning Empowered Graph Feature Prompt Tuning

前言

一篇图prompt的前沿工作,利用强化学习的方法来探索图中需要添加prompt的节点以及prompt的规模,实现了对必要节点添加轻量prompt的过程,在多个下游任务上取得了SOTA的效果。文章思路清晰,逻辑严谨,深入浅出,实验丰富,是不可多得的值得深入学习的工作。
Paper https://arxiv.org/pdf/2408.03195
Code https://github.com/JasonZhujp

Abstract

“pre-train, prompt”的范式最近在图表征领域展现了其泛化性和数据高效性。一开始的图prompt tuning方法为GNN特定的训练策略设定,限制了其应用性,因此,通用的prompt方法通过直接将prompt输入到图的表征空间,去除了对预训练策略的依赖从而受到欢迎。然而,如何加以及加多少prompt是当前领域所存在的问题,受到NLP中充分预训练的模型处理下游任务时需要更少条件信号的启发,本文主张将必要且轻量的prompt策略性地加入到某些节点中,以增强下游任务的性能。这涉及到一个组合优化的问题, 需要往哪个节点上加prompt,以及具体要加多少。为此作者提出了RELIEF,利用RL的方法来解决这些问题。在每一步中,RL代理选择一个节点并确定prompt,旨在最大化累计性能增益。在小样本场景中通过和各种预训练策略方法的实验表明,RELIEF在分类性能和数据效率方面优于微调和其它基于prompt的方法。

Motivation

GNNs在知识图谱,社交媒体,推荐系统都有广泛应用,为了增强模型的泛化能力,很多工作都投身于预训练的GNN模型。但是这种基于预训练和微调的范式有如下问题:

  1. 上下游任务不一致,导致负迁移。
  2. 小样本场景模型泛化能力不够。

借鉴NLP领域Prompt学习取得的巨大成功,现有工作将“pre-train, prompt”范式扩展到图领域。现有方法可以分为两类:

  • 依赖预训练策略。但是在多任务、多自监督技术的预训练下,Prompt方法可能会失败。
  • 预训练策略无关。兼容性强,通用且高效。

但是现有预训练无关的工作对Prompt learning没有深入思考,在NLP中,强大的模型只需要合适的条件信号就能够符合下游任务的要求,这对于遵循消息传递机制的GNN模型来说更是如此。

Solution

作者推测对于一个充分预训练的GNN模型,补充合适的条件信号就足够应用于下游任务了,对每个节点都加Prompt反而会导致过拟合。因此,策略性地将必要且轻量的prompt加入到原始图中,可以让GNN释放最大的预训练能力,从而泛化到各种下游任务。

选择什么节点、加入多少prompt是一个组合优化的问题,为此,作者采用可以高效搜索的RL方法,并提出基于RL增强的图表征prompt方法,名为RELIEF。作者将注入prompt的过程建模为序列决策问题,整个过程如下:

  1. RL代理选择需要进行prompt的节点。
  2. RL代理决定prompt的内容。
  3. prompt后的图输入预训练好的GNN中进行评估。
  4. 接着,RL代理生成的新prompt加入先前的图中,如此反复直到最大步数。

方法的目标是最大化下游任务预期的累积性能提升,此外还加入策略泛化的技术来保证训练的稳定性和高效性。作者还设计了两个metrics:prompt coverage ratio 和 average prompt magnitude,来量化prompt对原始输入的影响力。

RELIEF

Incorporating Feature Prompts as MDP

在强化学习领域中,环境通常用一个MDP建模。本方法将prompt构建成MDP,设计细节如下。

Action Space

给定具有n个节点的图 G \mathcal{G} G,一个离散的动作 a a a用于从节点集合 { v 1 , … , v n } \{v_1, \dots, v_n\} { v1,,vn}中挑选 v a v_a va,一个连续的动作 z ∈ R 1 × D z \in \mathbb{R}^{1 \times D} zR1×D用于决定赋予节点 v a v_a va的值向量。因此, t t t时刻的prompt(混合动作)可以表示为 ( a t , z t ) = p t a , z (a_t, z_t) = p^{a,z}_t (at,zt)=pta,z,进一步,prompt矩阵可以定义为 P = { p 1 , … , p n } ∈ R n × D \mathbf{P} = \{p_1, \dots, p_n\} \in \mathbb{R}^{n \times D} P={ p1,,pn}Rn×D,对于经过prompt后的图,其prompted特征 X ∗ \mathbf{X}^\ast X通过 X ∗ = X + P \mathbf{X}^\ast = \mathbf{X} + \mathbf{P} X=X+P更新。

State Transition

状态空间被定义为经过预训练GNN后图的节点表征, t t t时刻的状态被表示为:

s t : = f θ ( G t − 1 ∗ ) = f θ ( X t − 1 ∗ ,   A ) = f θ ( X + P t − 1 ,   A ) = { h 1 , t − 1 ∗ , … , h n , t − 1 ∗ } ∈ R n × d \begin{aligned} s_{t} & :=f_{\theta}\left(\mathcal{G}_{t-1}^{*}\right)=f_{\theta}\left(\mathrm{X}_{t-1}^{*}, \mathrm{~A}\right)=f_{\theta}\left(\mathrm{X}+\mathrm{P}_{t-1}, \mathrm{~A}\right) \\ & =\left\{h_{1, t-1}^{*}, \ldots, h_{n, t-1}^{*}\right\} \in \mathbb{R}^{n \times d} \end{aligned} st:=fθ(Gt1)=fθ(Xt1, A)=fθ(X+Pt1, A)={ h1,t1,,hn,t1}Rn×d

其中 h i , t − 1 ∗ h^*_{i,t-1} hi,t1是图 G t − 1 ∗ \mathcal{G}^*_{t-1} Gt1中节点 v i v_i vi的表征, d d d是表征的维度。当前的状态基于先前的步骤。为了解决不同图包含不同数量节点影响batch训练的情况,作者设置了一个最大节点数量 N N N,节点不足用零向量填充。因此, t t t时刻的状态可以表示为:

s t : = f θ ( X t − 1 ∗ ,   A ) ∥ 0 ( N − n ) × d = { h 1 , t − 1 ∗ , … , h n , t − 1 ∗ , 0 n + 1 , … , 0 N } ∈ R N × d \begin{aligned} s_t & :=f_\theta\left(\mathrm{X}_{t-1}^*, \mathrm{~A}\right) \| \mathbf{0}_{(N-n) \times d} \\ & =\left\{h_{1, t-1}^*, \ldots, h_{n, t-1}^*, 0_{n+1}, \ldots, 0_N\right\} \in \mathbb{R}^{N \times d} \end{aligned} st:=fθ(Xt1, A)0(N

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

HERODING77

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值