LEARNING ON LARGE-SCALE TEXT-ATTRIBUTED GRAPHS VIA VARIATIONAL INFERENCE
ICLR23
推荐指数:#paper/⭐⭐#
作者的写作手法感觉就是有点把模型往数学化的方式去写,内容其实就相当于一个LM与GCN互相提供伪标签监督。利用KL散度来优化有点意思
动机
关键挑战
- 传统方法同时训练LLMs和GNNs时,计算复杂度高,难以扩展到大规模图数据。
方法创新:GLEM框架
-
交替训练机制:
- E步:固定GNN参数,GNN预测的伪标签与观察到的标签一起用于LM训练,使其学习与图拓扑一致的文本表示。
- M步:固定LLM参数,LM提供文本嵌入和伪标签给GNN,提升结构建模能力。
-
模块解耦:避免同时训练两个大模型,显著降低计算开销。
-
互增强机制:LLM和GNN通过交替更新互相提供监督信号(伪标签),逐步提升整体性能。
模型细节
伪似然变分框架
这是一个基于伪似然变分框架的方法,该方法用于模型设计,提供了一种原则性和灵活性的公式化方式。具体来说,这个方法的目标是最大化观测到的节点标签的对数似然函数,即 p(yL∣sV,A)p(y_L | s_V, A)p(yL∣sV,A),其中 yLy_LyL 是已标记节点的标签集合,sVs_VsV 是所有节点的文本特征,AAA 是图的邻接矩阵。
直接优化这个函数通常是困难的,因为存在未观测到的节点标签 yUy_UyU。为了解决这个问题,该框架不直接优化对数似然函数,而是优化一个称为证据下界(Evidence Lower Bound,ELBO)的量。ELBO 的表达式如下:
logp(yL∣sV,A)≥Eq(yU∣sU)[logp(yL,yU∣sV,A)−logq(yU∣sU)]\\log p(y_L | s_V, A) \geq \mathbb{E}_{q(y_U | s_U)}[\log p(y_L, y_U | s_V, A) - \log q(y_U | s_U)]logp(yL∣sV,A)≥Eq(yU∣sU)[logp(yL,yU∣sV,A)−logq(yU∣sU)]
这里,q(yU∣sU)q(y_U | s_U)q(yU∣sU)是一个变分分布,上述不等式对任何 qqq都成立。ELBO 可以通过交替优化分布 qqq(即 E-step)和分布 ppp(即 M-step)来优化。
- E-step(期望步) :优化变分分布 qqq,目的是最小化 qqq 和后验分布 p(yU∣sV,A,yL)p(y_U | s_V, A, y_L)p(yU∣sV,A,yL) 之间的 Kullback-Leibler (KL) 散度,从而收紧上述下界。
- M-step(最大化步) :优化目标分布 ( p ),以最大化伪似然函数:
Eq(yU∣sU)[logp(yL,yU∣sV,A)]≈Eq(yU∣sU)[∑n∈Vlogp(yn∣sV,A,yV∖n)]\mathbb{E}_{q(y_U | s_U)}[\log p(y_{L}, y_{U} | s_{V}, A)] \approx \mathbb{E}_{q(y_U | s_U)}\left[\sum_{n \in V}\log p(y_{n} | s_{V}, A, y_{V \setminus n})\right]Eq(yU∣sU)[logp(yL,yU∣sV,A)]≈Eq(yU∣sU)[∑n∈Vlogp(yn∣sV,A,yV∖n)]
这个过程通过交替执行 E-step 和 M-step 来实现,直到收敛。这种方法允许模型在不需要直接处理未观测标签的情况下,有效地学习节点表示。简而言之,这种方法通过优化一个下界来间接优化对数似然函数,这个下界可以通过交替优化两个分布来逐步提高,从而使得模型能够更好地处理未观测数据,并提高学习效果。
具体的两个模型的介绍
在这部分内容中,文章详细阐述了GLEM方法中使用的两种分布——qqq 和ppp的参数化过程,以及它们如何用于节点标签分布的建模和优化。
分布 qqq的参数化
分布 qqq的目标是利用文本信息 sUs_UsU 来定义节点标签分布,这相当于一个语言模型(LM)。在 GLEM 中,采用均场形式(mean-field form),假设不同节点的标签是独立的,每个节点的标签只依赖于它自己的文本信息。这导致了以下形式的分解:
qθ(yU∣sU)=∏n∈Uqθ(yn∣sn).q_{\theta}(y_{U}|s_{U}) = \prod_{n \in U}q_{\theta}(y_{n}|s_{n}).qθ(yU∣sU)=∏n∈Uqθ(yn∣sn).这里,θ\thetaθ是模型参数**UUU是未标记节点的集合。在这里,每个项qθ(yn∣sn)q_{\theta}(y_{n}|s_{n})qθ(yn∣sn)可以通过基于 Transformer 的语言模型qθq_{\theta}qθ来建模,该模型通过注意力机制有效地模拟细粒度的标记交互。**
分布ppp的参数化
分布 ppp定义了一个条件分布 pϕ(yn∣sV,A,yV\n)p_{\phi}(y_{n}|s_{V}, A, y_{V\backslash n})pϕ(yn∣sV,A,yV\n),旨在利用节点特征sVs_{V}sV、图结构AAA和其他节点标签yV\ny_{V\backslash n}yV\n来表征每个节点的标签分布。因此,pϕ(yn∣sV,A,yV\n)p_{\phi}(y_{n}|s_{V}, A, y_{V\backslash n})pϕ(yn∣sV,A,yV\n)被建模为一个由 ϕ\phiϕ参数化的 GNNpϕGNN p_{\phi}GNNpϕ,以有效地模拟节点间的结构交互。
由于节点文本sVs_VsV是离散变量,不能直接被 GNN 使用,因此在实践中,我们首先使用 LMqθLM q_{\theta}LMqθ 对节点文本进行编码,然后使用获得的嵌入作为GNNpϕGNN p_{}\phiGNNpϕ的节点文本的替代。
E-STEP: LM OPTIMIZATION
- 目标:在E-step中,固定GNN,更新LM以最大化证据下界。这样做的目的是将不同节点之间的全局语义相关性提取到LM中。最大化关于 LM 的证据下限等同于最小化后验分布和变分分布之间的 KL 散度
- 优化方法:直接优化KL散度是困难的,因为KL散度依赖于难以处理的变分分布的熵。为了克服这个挑战,作者采用了wake-sleep算法来最小化反向KL散度,从而得到一个更易于处理的目标函数。
- 目标函数:目标函数是关于LMqθq_\thetaqθ的,目的是最大化这个函数。这个函数的形式是:(KL的前面,第一项相当于GNN,第二项相当于LM)
−KL(pϕ(yU∣sV,A,yL)∣∣qθ(yU∣sU))=∑n∈UEpϕ(yn∣sV,A,yL)[logqθ(yn∣sn)]+const-KL(p_\phi(y_U|s_V, A, y_L)||q_\theta(y_U|s_U)) = \sum_{n \in U} \mathbb{E}_{p_\phi(y_n|s_V, A, y_L)}[\log q_\theta(y_n|s_n)] + \text{const}−KL(pϕ(yU∣sV,A,yL)∣∣qθ(yU∣sU))=∑n∈UEpϕ(yn∣sV,A,yL)[logqθ(yn∣sn)]+const
这个目标函数更容易处理,因为我们不再需要考虑qθ(yU∣sU)q_\theta(y_U|s_U)qθ(yU∣sU)的熵。 - 分布计算:唯一的困难在于计算分布pϕ(yn∣sV,A,yL)p_\phi(y_n|s_V, A, y_L)pϕ(yn∣sV,A,yL)。在原始GNN中,我们基于周围节点标签 yV\ny_{V\backslash n}yV\n来预测节点 ( n ) 的标签分布。然而,在pϕ(yn∣sV,A,yL)p_\phi(y_n|s_V, A, y_L)pϕ(yn∣sV,A,yL)中,我们只基于观察到的节点标签 yLy_LyL,其他节点的标签是未指定的,因此我们不能直接用GNN计算这个分布。
- 解决方案:为了解决这个问题,作者提出用LM预测的伪标签来标注图中所有未标记的节点,从而可以近似分布:
pϕ(yn∣sV,A,yL)≈pϕ(yn∣sV,A,yL,y^U\n)p_\phi(y_n|s_V, A, y_L) \approx p_\phi(y_n|s_V, A, y_L, \hat{y}_{U\backslash n})pϕ(yn∣sV,A,yL)≈pϕ(yn∣sV,A,yL,y^U\n)
其中 y^U\n\hat{y}_{U\backslash n}y^U\n是未标记节点的伪标签集合。 - 最终目标函数:结合上述目标函数和标记节点,得到训练LM的最终目标函数:
O(q)=α∑n∈UEp(yn∣sV,A,yL,y^U\n)[logq(yn∣sn)]+(1−α)∑n∈Llogq(yn∣sn)\mathcal{O}(q) = \alpha \sum_{n \in U} \mathbb{E}_{p(y_n|s_V, A, y_L, \hat{y}_{U\backslash n})}[\log q(y_n|s_n)] + (1-\alpha) \sum_{n \in L} \log q(y_n|s_n)O(q)=α∑n∈UEp(yn∣sV,A,yL,y^U\n)[logq(yn∣sn)]+(1−α)∑n∈Llogq(yn∣sn)
其中α\alphaα是一个超参数。直观上,第二项是一个监督目标,使用给定的标记节点进行训练。同时,第一项可以看作是一个知识蒸馏过程,通过强制LM基于邻域文本信息预测标签分布来训练LM
M-STEP: GNN OPTIMIZATION
目标:在GNN阶段,目标是固定语言模型 qθq_\thetaqθ并优化图神经网络pϕp_\phipϕ以最大化伪似然(pseudo-likelihood)。
-
方法:
- 使用语言模型为所有节点生成节点表示 hVh_VhV,并将这些表示作为文本特征输入到图神经网络中进行消息传递。
- 利用语言模型qθq_\thetaqθ为每个未标记节点n∈Un \in Un∈U预测一个伪标签y^n\hat{y}_ny^n并将所有伪标签{y^n}n∈U\{\hat{y}_n\}_{n \in U}{y^n}n∈U组合成y^U\hat{y}_Uy^U。
-
伪似然重写:结合节点表示和LMqθq_\thetaqθ的伪标签,伪似然可以重写为:
O(ϕ)=β∑n∈Ulogpϕ(y^n∣sV,A,yL,y^U∖n)+(1−β)∑n∈Llogpϕ(yn∣sV,A,yL∖n,y^U)\mathcal{O}(\phi) = \beta \sum_{n \in U}\log p_{\phi}(\hat{y}_{n} | s_{V}, A, y_{L} , \hat{y}_{U \setminus n}) + (1 - \beta) \sum_{n \in L}\log p_{\phi}(y_{n} | s_{V} , A, y_{L \setminus n}, \hat{y}_{U})O(ϕ)=β∑n∈Ulogpϕ(y^n∣sV,A,yL,y^U∖n)+(1−β)∑n∈Llogpϕ(yn∣sV,A,yL∖n,y^U)
其中,β\betaβ是一个超参数,用于平衡两个项的权重。第一项:可以看作是一个知识蒸馏过程,通过所有伪标签将LM捕获的知识注入到GNN中。第二项:是一个监督损失,使用观察到的节点标签进行模型训练。
一旦训练完成,E-step中的LM(记为GLEM-LM)和M-step中的GNN(记为GLEM-GNN)都可以用来进行节点标签预测。
实验
-
结果:
- GLEM-GNN:在所有三个数据集上取得了新的最佳性能,证明了其在节点分类任务中的有效性。
- GLEM-LM:通过结合图结构信息,显著提升了语言模型的性能。
- 可扩展性:GLEM 能够适应大型语言模型(如 DeBERTa-large),并且在效率和性能之间取得了良好的平衡。