DGL笔记1——用DGL表示图
DGL笔记2——用DGL识别节点
DGL笔记3——自己写一个GNN模型
之前我们学习了 DGL 怎么表示一个图,然后怎么写一个简单的 GCN 模型进行节点识别。但是有时候我们的模型不仅仅是简单地堆叠现有的 GNN 模块。 比如我们现在想发明一种考虑节点重要性或边权重来聚合邻域信息的新方法,该怎么办?
所以我们现在将要学习:
-
DGL 的消息传递 API。
-
自己实现 GraphSAGE 卷积模块。
记得在看这篇之前先看上一篇 GNN 分类哈~
首先导入相关包:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
消息传递与GNNs
DGL 遵循 由 Gilmer 等人提出的 Message Passing Neural Network 中启发产生的消息传递范式(message passing paradigm )。本质上,很多GNN模型都符合以下框架:
m u → v = M ( l ) ( h v ( l − 1 ) , h u ( l − 1 ) , e u → v ( l − 1 ) ) \large m_{u\rightarrow v}=M^{(l)}\left( h_v^{(l-1)},h_u^{(l-1)},e_{u\rightarrow v}^{(l-1)} \right) mu→v=M(l)(hv(l−1),hu(l−1),eu→v(l−1))
m u = Σ u ∈ N ( v ) m u → v ( l ) \large m_{u}=\Sigma_{u\in\mathcal N(v)}m_{u\rightarrow v}^{(l)} mu=Σu∈N(v)mu→v(l)
h v ( l ) = U (