学习目标
- 学习message passing范式
- 初步分析PyG中的MessagePassing基类
- 学习对MessagePassing基类的继承
消息传递范式
- 神经网络生成节点表征的操作称作node embedding
- node embedding应该可以衡量节点之间的相似性
- data.x是第0层的节点表征,第 h h h层节点表征经过一次节点间信息传递产生第 h + 1 h+1 h+1层的节点表征
MessagePassing基类分析
PyG提供了MessagePassing基类,它封装了消息传递的运行流程。构造一个最简单的消息传递图神经网络,我们只需要定义message() ( ϕ \phi ϕ), update() ( γ \gamma γ)和消息聚合方案。
- MessagePassing(aggr=“add”, flow=“source_to_target”, node_dim=-2)
aggr
:定义要使用的聚合方案(“add”、"mean "或 “max”)
flow
:定义消息传递的流向("source_to_target "或 “target_to_source”)
node_dim
:定义沿着哪个维度传播,默认值为-2
,也就是节点表征张量(Tensor)的哪一个维度是节点维度。节点表征张量x
形状为[num_nodes, num_features]
,其第0维度(也是第-2维度)是节点维度,其第1维度(也是第-1维度)是节点表征维度,所以我们可以设置node_dim=-2
。 - MessagePassing.propagate(edge_index, size=None, **kwargs)
开始传递消息的起始调用,在此方法中message
、update
等方法被调用。
propagate()
不仅限于基于形状为[N, N]
的对称邻接矩阵进行“消息传递过程”。基于非对称的邻接矩阵进行消息传递(当图为二部图时),需要传递参数size=(N, M)
。如果设置size=None
,则认为邻接矩阵是对称的。 - MessagePassing.message(…)
我们用 i i i表示消息传递的中心节点,用 j j