消息传递图神经网络
GNN
GNN是一个邻居聚合策略,一个节点的表示向量,由它的邻居节点通过循环的聚合和转移表示向量计算得来。我们来想象人类学习知识的过程,在自身具有一定知识的基础上,我们会想要从周围的伙伴那里学习到更多的知识,然后将伙伴给予的信息与自身已有的知识组合起来,更新并获得更高阶的知识,这个过程就是一个消息传递过程。
MPNN
- 消息传递阶段 :共运行T个时间步,并包含以下两个子函数:
Aggregation Function:也称消息函数,作用是聚合邻居节点的特征,形成一个消息向量,准备传递给中心节点。
Combination Function:也称节点更新函数,作用是更新当前时刻的节点表示,组合当前时刻节点的表示以及从Aggregation Function中获得的消息。
2. 读出阶段:针对于图级别的任务(如图分类)仅仅获得节点级别的表示是不够的,需要通过读出函数(Readout Function)聚合节点级别的表示,计算获得图级别的表示。
因此,不同种类的GNN就是通过设计不同的Aggregation Function,Combination Function以及Readout Function实现的。也就是说,从邻居聚合的信息不同,邻居信息与自身信息的组合方式不同,或者节点的读出方式不同,最终都会导致我们最终获得的表示向量不同。不同的任务为了获得更准确的预测效果,会着重于抓取不同的信息
MessagePassing基类的运行流程。
基类MessagePassing提供参数agg来实现聚合函数,基类的两个成员函数message和update分别对应
ϕ
(
k
)
(
x
i
(
k
−
1
)
,
x
j
(
k
−
1
)
,
e
j
,
i
)
和
γ
(
k
)
(
x
i
(
k
−
1
)
,
…
)
\phi^{(k)}\left(\mathbf{x}{i}^{(k-1)}, \mathbf{x}{j}^{(k-1)}, \mathbf{e}{j, i}\right) 和 \gamma^{(k)}\left(\mathbf{x}{i}^{(k-1)},…\right)
ϕ(k)(xi(k−1),xj(k−1),ej,i)和γ(k)(xi(k−1),…)的实现,用户继承MessagePassing并实现具体的message和update函数以实现自定义的GNN,最后通过在forward方法中调用propagate方法驱动整个计算流程的进行