PyG 中Message Passing机制详解

本文详细介绍了消息传递范式在图神经网络中的应用,通过PyG库中的MessagePassing基类,阐述了如何构建和理解图神经网络。内容涵盖消息传递的基本流程,包括节点信息的变换、聚合和更新,以及如何通过继承MessagePassing基类定义自己的图神经网络层。文中以GCNConv为例,展示了如何实现图卷积层,并进一步分析了MessagePassing基类的内部工作原理,包括message()、aggregate()、message_and_aggregate()和update()方法的使用。最后,文章强调了通过实践来深入理解图神经网络构建的重要性。

消息传递图神经网络

一、引言

在开篇中我们介绍了,为节点生成节点表征(Node Representation)是图计算任务成功的关键,我们要利用神经网络来学习节点表征。消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接。消息传递范式因为简单、强大的特性,于是被人们广泛地使用。遵循消息传递范式的图神经网络被称为消息传递图神经网络。本节中,

  • 首先我们将学习图神经网络生成节点表征的范式–消息传递(Message Passing)范式
  • 接着我们将初步分析PyG中的MessagePassing基类,通过继承此基类我们可以方便地构造一个图神经网络。
  • 然后我们以继承MessagePassing基类的GCNConv类为例,学习如何通过继承MessagePassing基类来构造图神经网络。
  • 再接着我们将对MessagePassing基类进行剖析。
  • 最后我们将学习在继承MessagePassing基类的子类中覆写message(),aggreate(),message_and_aggreate()update(),这些方法的规范。

二、 消息传递范式介绍

下方图片展示了基于消息传递范式的聚合邻接节点信息来更新中心节点信息的过程

  1. 图中黄色方框部分展示的是一次邻接节点信息传递到中心节点的过程:B节点的邻接节点(A,C)的信息经过变换聚合到B节点,接着B节点信息与邻接节点聚合信息一起经过变换得到B节点的新的节点信息。同时,分别如红色和绿色方框部分所示,遵循同样的过程,C、D节点的信息也被更新。实际上,同样的过程在所有节点上都进行了一遍,所有节点的信息都更新了一遍。
  2. 这样的“邻接节点信息传递到中心节点的过程”会进行多次。如图中蓝色方框部分所示,A节点的邻接节点(B,C,D)的已经发生过一次更新的节点信息,经过变换、聚合、再变换产生了A节点第二次更新的节点信息。多次更新后的节点信息就作为节点表征。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-h1dqI2J9-1648002261843)(images/image-20210516110407207.png)]

图片来源于:Graph Neural Network • Introduction to Graph Neural Networks

消息传递图神经网络遵循上述的“聚合邻接节点信息来更新中心节点信息的过程”,来生成节点表征。用 x i ( k − 1 ) ∈ R F \mathbf{x}^{(k-1)}_i\in\mathbb{R}^F xi(k1)RF表示 ( k − 1 ) (k-1) (k1)层中节点 i i i的节点表征, e j , i ∈ R D \mathbf{e}_{j,i} \in \mathbb{R}^D ej,iRD 表示从节点 j j j到节点 i i i的边的属性,消息传递图神经网络可以描述为
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) , \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)),
其中 □ \square 表示可微分的、具有排列不变性(函数输出结果与输入参数的排列无关)的函数。具有排列不变性的函数有,sum()函数、mean()函数和max()函数。 γ \gamma γ ϕ \phi ϕ表示可微分的函数,如MLPs(多层感知器)。此处内容来源于CREATING MESSAGE PASSING NETWORKS

注(1):神经网络的生成节点表征的操作称为节点嵌入(Node Embedding),节点表征也可以称为节点嵌入。为了统一此次组队学习中的表述,我们规定节点嵌入只代指神经网络生成节点表征的操作

注(2):未经过训练的图神经网络生成的节点表征还不是好的节点表征,好的节点表征可用于衡量节点之间的相似性。通过监督学习对图神经网络做很好的训练,图神经网络才可以生成好的节点表征。我们将在第5节介绍此部分内容。

注(3),节点表征与节点属性的区分:遵循被广泛使用的约定,此次组队学习我们也约定,节点属性data.x是节点的第0层节点表征,第 h h h层的节点表征经过一次的节点间信息传递产生第 h + 1 h+1 h+1层的节点表征。不过,节点属性不单指data.x,广义上它就指节点的属性,如节点的度等。

三、MessagePassing基类初步分析

Pytorch Geometric(PyG)提供了MessagePassing基类,它封装了“消息传递”的运行流程。通过继承MessagePassing基类,可以方便地构造消息传递图神经网络。构造一个最简单的消息传递图神经网络类,我们只需定义**message()方法( ϕ \phi ϕupdate()方法( γ \gamma

评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值