[Scene Graph] 图神经网络的核心方法——Message Passing

本文深入探讨了图神经网络(GNN)如何通过Message Passing方法学习图数据的特征。GNN借鉴了CNN的思想,但针对非结构化数据如图和点集进行信息提取。Message Passing涉及节点间信息的传递和融合,通过Faggregate和Fcombine函数实现邻接节点特征的聚合和组合。在数学上,Message Passing可以用一个公式概括,描述了节点特征的更新过程。此外,文章还介绍了PytorchGeometric库中GCNConv类的实现,展示了一个具体的Message Passing实例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

GNN中的Message Passing方法解析
一、GNN中是如何实现特征学习的?

深度学习方法的兴起是从计算图像处理(Computer Vision)领域开始的。以卷积神经网络(CNN)为代表的方法会从邻近的像素中获取信息。这种方式对于结构化数据(structured data)十分有效,例如,图像和体素数据。但是,CNN的处理方式对于类似图(graph)数据则并不适用。对于一个图而言,类似图像像素的邻近关系是不存在的。考虑图中任意一个节点,我们几乎不可能始终找到8个有序邻近节点,因此类CNN卷积核将不再适用非结构数据(unstructured data)特征提取,这种非结构数据包括图(graph)和点集(set),例如点云数据(point cloud)和场景图(scene graph)。

那么对于图来说,GNN是如何学习到有效的特征呢?这就是这篇博文将要讨论的十分重要的方法——Message Passing。如果知道Message Passing准确的中文翻译的小伙伴可以给我留言,这篇博文中我将用缩写MP代表Message Passing。

大致来说,GNN与CNN在思想上是相通的。GNN也希望能够从“邻接”的元素中获取信息,并学习到固定的且可以固化在向量中的特征。因此,在字面意义上来说MP的思想就是,节点与节点之间可以相互传播(pass)信息(message),通过对于邻接节点信息的处理和原本节点特征的融合,来更新每一个点的特征,最后得到节点级别(node-level)或者图级别(graph-level)的序列化模式表达,这种表达可以完成节点分类,图分类,图生成等十分基本的任务。

二、如何用数学定义Message Passing过程?

小伙伴们不用害怕复杂的数学公式,只需要跟着思路走就可以,我会使用可以直观理解的变量命名方式的。首先,先在脑海中构建一个抽象的图结构G\mathcal{G}G,这个图的节点集合为V\mathcal{V}V,边的集合为E\mathcal{E}E。重点讨论的变量我列在了下面:

u,u∈Vu, u\in \mathcal{V}u,uV:我们重点讨论节点,可以是图中的任意一个节点。
{v∣v∈N(u)}\{v| v \in N(u)\}{vvN(u)}:节点uuu的邻接节点集合,与uuu有有效连接的节点的集合。
xux_{u}xu:节点uuu的初始特征向量,如果是节点vvv的特征则标记为xvx_vxv
huh_{u}hu:节点uuu的隐层特征向量,如果是节点vvv的隐层特征则标记为hvh_vhv。若GNN中包含许多中间层,那么使用hulh^{l}_{u}hul来表述第lll层的节点uuu的隐层特征。

MP的过程可以由下述一个公式概括,这里借用文献[1]中的表述:
hul=Fcombine(hul−1,Faggregate(hvl−1))h^{l}_{u}=F_{combine}(h^{l-1}_u, F_{aggregate}(h^{l-1}_{v}))hul=Fcombine(hul1,Faggregate(hvl1))

简单来说Faggregate(⋅)F_{aggregate}(·)Faggregate()的作用就是从uuu的邻接节点中汇集(aggregate)特征并加以变换,得到节点uuu在本次得到的信息,因此,也可以写作:
mul=Faggregate(hvl−1)m^l_u=F_{aggregate}(h^{l-1}_{v})mul=Faggregate(hvl1)

在节点uuu获得到由周围节点提供的信息之后,Fcombine(⋅)F_{combine}(·)Fcombine()将信息mulm^l_umul和本节点上一层的信息整合(combine)到一起,更新在这次操作之后节点uuu新的特征。

三、Message Passing之后的输出过程(不感兴趣可以略过,并不是本篇详细讨论的问题)

经过多层的MP过程之后,对于任意一个节点,都有一系列的隐层特征{hu0,...,hul}\{h^0_u,...,h^l_u\}{hu0,...,hul}。一般情况下可以通过sum,mean,max或者weighted sum这四种方式进行pooling操作,得到节点级别的输出。文章[2] 中的消融实验(ablation study)证明了learning weighted sum比sum和mean的池化方法更加好。不过,仅在[2]的数据集上测试过,我并不把这个现象看做普遍规律,还是需要平衡参数量和精度需求选择池化方式。对于图级别的输出,由于个人研究方向的限制,并没有很大的兴趣深究,以后再说啦。

四、常用的Message Passing方法

想要快速看懂大多数的GNN文章,掌握在第二节里介绍MP数学形式就基本足够了。之后就可以整理现不同GNN方法的优劣。更重要的是,也可以按照以下步骤来构建自己的MP算法。我自己使用的步骤如下:

  1. 注意图节点的连接方式,包括单向图、无向图(双向图)、匀质图(图中节点的性质是一致的)、异质图、稀疏图、稠密图等。构建图的邻接矩阵。
  2. FaggregateF_{aggregate}Faggregate:获取信息的方法主要以邻近节点的特征为输入,通过sum,concatenate,attention等机制融合邻接节点的特征。通过线性(linear)或者非线性(一般使用mlp)变换得到message特征。
  3. FcombineF_{combine}Fcombine:将当前节点的特征与message融合,完成当前轮的MP迭代。简单的方式可以使用MLP来实现融合。另外,可以使用LSTM和GRU门控的方式保留部分之前节点的信息来更新当前节点的特征。
  4. 节点输出或者图输出。
五、Pytorch Geometric示例代码,理论和实践结合

Pytorch Geometric库中已经写好了MP算法的蓝图,而且成熟的GNN的卷积层都已经写在torch_geometric.nn.conv头文件里,可以随时查阅或扩展。

下面展示了Pytorch Geometric中的GCN[3]论文中MP算法的实现过程。先说明一下能看懂这段代码所需要的的基本知识:

  1. GCNConv这个类别是继承了MessagePassing父类的,其中的4个函数 __ init__, forward, message和update都是MessagePassing类中定义好且需要我们自己扩展的函数。先放一下完整的代码。
  2. __ init__函数用来声明并初始化变量的,可以通过aggr变量控制特征的融合方式,通过flow变量来控制信息的流动方向,包括“source_to_target”和“target_to_source”。
  3. forward函数就是函数的运行的主体,使用pytorch的小伙伴一定不会陌生。这里提一下的是,这个函数最后需要调用self.propagate(edge_index, x, args)函数来进行MP算法。其中edge_index是邻接矩阵的稀疏表达,x是节点的特征矩阵,args就是你想在MP算法中调用的其他变量。
  4. message函数,顾名思义,就是用来计算message的函数。它的输入是self.propagate函数传入的变量们,基本上这个函数就是可以自由发挥的地方了。
  5. update函数是用来更新节点特征的函数,也就是combine的过程了。由于我平时喜欢额外调用GRUcell来计算,因此这个函数没有深入研究过。

完整代码如下:

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
 
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
 
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
 
        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
 
        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)
 
        # Step 3-5: Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
 
    def message(self, x_j, edge_index, size):
        # x_j has shape [E, out_channels]
 
        # Step 3: Normalize node features.
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
 
        return norm.view(-1, 1) * x_j
 
    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]
 
        # Step 5: Return new node embeddings.
        return aggr_out
参考文献

以下参考文献非正规格式,仅提供能查阅到文献的必要信息,刚写论文的小伙伴不要学习哦。
[1] V. Thost, J. Chen. Directed Acyclic Graph Neural Networks. ICLR 2021.
[2] D. Xu, Y. Zhu, et al. Scene Graph Generation by Iterative Message Passing. CVPR 2017.
[3] Kipf, Welling. Semi-Supervised Classification With Graph Convolutional Networks. ICLR 2017.

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值