GNN中的Message Passing方法解析
一、GNN中是如何实现特征学习的?
深度学习方法的兴起是从计算图像处理(Computer Vision)领域开始的。以卷积神经网络(CNN)为代表的方法会从邻近的像素中获取信息。这种方式对于结构化数据(structured data)十分有效,例如,图像和体素数据。但是,CNN的处理方式对于类似图(graph)数据则并不适用。对于一个图而言,类似图像像素的邻近关系是不存在的。考虑图中任意一个节点,我们几乎不可能始终找到8个有序邻近节点,因此类CNN卷积核将不再适用非结构数据(unstructured data)特征提取,这种非结构数据包括图(graph)和点集(set),例如点云数据(point cloud)和场景图(scene graph)。
那么对于图来说,GNN是如何学习到有效的特征呢?这就是这篇博文将要讨论的十分重要的方法——Message Passing。如果知道Message Passing准确的中文翻译的小伙伴可以给我留言,这篇博文中我将用缩写MP代表Message Passing。
大致来说,GNN与CNN在思想上是相通的。GNN也希望能够从“邻接”的元素中获取信息,并学习到固定的且可以固化在向量中的特征。因此,在字面意义上来说MP的思想就是,节点与节点之间可以相互传播(pass)信息(message),通过对于邻接节点信息的处理和原本节点特征的融合,来更新每一个点的特征,最后得到节点级别(node-level)或者图级别(graph-level)的序列化模式表达,这种表达可以完成节点分类,图分类,图生成等十分基本的任务。
二、如何用数学定义Message Passing过程?
小伙伴们不用害怕复杂的数学公式,只需要跟着思路走就可以,我会使用可以直观理解的变量命名方式的。首先,先在脑海中构建一个抽象的图结构G\mathcal{G}G,这个图的节点集合为V\mathcal{V}V,边的集合为E\mathcal{E}E。重点讨论的变量我列在了下面:
u,u∈Vu, u\in \mathcal{V}u,u∈V:我们重点讨论节点,可以是图中的任意一个节点。
{v∣v∈N(u)}\{v| v \in N(u)\}{v∣v∈N(u)}:节点uuu的邻接节点集合,与uuu有有效连接的节点的集合。
xux_{u}xu:节点uuu的初始特征向量,如果是节点vvv的特征则标记为xvx_vxv。
huh_{u}hu:节点uuu的隐层特征向量,如果是节点vvv的隐层特征则标记为hvh_vhv。若GNN中包含许多中间层,那么使用hulh^{l}_{u}hul来表述第lll层的节点uuu的隐层特征。
MP的过程可以由下述一个公式概括,这里借用文献[1]中的表述:
hul=Fcombine(hul−1,Faggregate(hvl−1))h^{l}_{u}=F_{combine}(h^{l-1}_u, F_{aggregate}(h^{l-1}_{v}))hul=Fcombine(hul−1,Faggregate(hvl−1))
简单来说Faggregate(⋅)F_{aggregate}(·)Faggregate(⋅)的作用就是从uuu的邻接节点中汇集(aggregate)特征并加以变换,得到节点uuu在本次得到的信息,因此,也可以写作:
mul=Faggregate(hvl−1)m^l_u=F_{aggregate}(h^{l-1}_{v})mul=Faggregate(hvl−1)
在节点uuu获得到由周围节点提供的信息之后,Fcombine(⋅)F_{combine}(·)Fcombine(⋅)将信息mulm^l_umul和本节点上一层的信息整合(combine)到一起,更新在这次操作之后节点uuu新的特征。
三、Message Passing之后的输出过程(不感兴趣可以略过,并不是本篇详细讨论的问题)
经过多层的MP过程之后,对于任意一个节点,都有一系列的隐层特征{hu0,...,hul}\{h^0_u,...,h^l_u\}{hu0,...,hul}。一般情况下可以通过sum,mean,max或者weighted sum这四种方式进行pooling操作,得到节点级别的输出。文章[2] 中的消融实验(ablation study)证明了learning weighted sum比sum和mean的池化方法更加好。不过,仅在[2]的数据集上测试过,我并不把这个现象看做普遍规律,还是需要平衡参数量和精度需求选择池化方式。对于图级别的输出,由于个人研究方向的限制,并没有很大的兴趣深究,以后再说啦。
四、常用的Message Passing方法
想要快速看懂大多数的GNN文章,掌握在第二节里介绍MP数学形式就基本足够了。之后就可以整理现不同GNN方法的优劣。更重要的是,也可以按照以下步骤来构建自己的MP算法。我自己使用的步骤如下:
- 注意图节点的连接方式,包括单向图、无向图(双向图)、匀质图(图中节点的性质是一致的)、异质图、稀疏图、稠密图等。构建图的邻接矩阵。
- FaggregateF_{aggregate}Faggregate:获取信息的方法主要以邻近节点的特征为输入,通过sum,concatenate,attention等机制融合邻接节点的特征。通过线性(linear)或者非线性(一般使用mlp)变换得到message特征。
- FcombineF_{combine}Fcombine:将当前节点的特征与message融合,完成当前轮的MP迭代。简单的方式可以使用MLP来实现融合。另外,可以使用LSTM和GRU门控的方式保留部分之前节点的信息来更新当前节点的特征。
- 节点输出或者图输出。
五、Pytorch Geometric示例代码,理论和实践结合
Pytorch Geometric库中已经写好了MP算法的蓝图,而且成熟的GNN的卷积层都已经写在torch_geometric.nn.conv头文件里,可以随时查阅或扩展。
下面展示了Pytorch Geometric中的GCN[3]论文中MP算法的实现过程。先说明一下能看懂这段代码所需要的的基本知识:
- GCNConv这个类别是继承了MessagePassing父类的,其中的4个函数 __ init__, forward, message和update都是MessagePassing类中定义好且需要我们自己扩展的函数。先放一下完整的代码。
- __ init__函数用来声明并初始化变量的,可以通过aggr变量控制特征的融合方式,通过flow变量来控制信息的流动方向,包括“source_to_target”和“target_to_source”。
- forward函数就是函数的运行的主体,使用pytorch的小伙伴一定不会陌生。这里提一下的是,这个函数最后需要调用self.propagate(edge_index, x, args)函数来进行MP算法。其中edge_index是邻接矩阵的稀疏表达,x是节点的特征矩阵,args就是你想在MP算法中调用的其他变量。
- message函数,顾名思义,就是用来计算message的函数。它的输入是self.propagate函数传入的变量们,基本上这个函数就是可以自由发挥的地方了。
- update函数是用来更新节点特征的函数,也就是combine的过程了。由于我平时喜欢额外调用GRUcell来计算,因此这个函数没有深入研究过。
完整代码如下:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # "Add" aggregation.
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3-5: Start propagating messages.
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
def message(self, x_j, edge_index, size):
# x_j has shape [E, out_channels]
# Step 3: Normalize node features.
row, col = edge_index
deg = degree(row, size[0], dtype=x_j.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
# aggr_out has shape [N, out_channels]
# Step 5: Return new node embeddings.
return aggr_out
参考文献
以下参考文献非正规格式,仅提供能查阅到文献的必要信息,刚写论文的小伙伴不要学习哦。
[1] V. Thost, J. Chen. Directed Acyclic Graph Neural Networks. ICLR 2021.
[2] D. Xu, Y. Zhu, et al. Scene Graph Generation by Iterative Message Passing. CVPR 2017.
[3] Kipf, Welling. Semi-Supervised Classification With Graph Convolutional Networks. ICLR 2017.