基于图神经网络的图表征学习方法
本文主要参考DataWhale图神经网络组队学习
文中有两类任务:节点分类和图分类。
对于节点分类问题,节点在最后一层的表示
h
v
(
L
)
h_{v}^{(L)}
hv(L)就可以用于预测。
对于图分类问题,需要将graph中所有节点特征转变成graph特征,整个图的表示
h
G
h_{G}
hG如下:
h
G
=
READOUT
(
{
h
v
(
L
)
∣
v
∈
G
}
)
h_{G}=\operatorname{READOUT}\left(\left\{h_{v}^{(L)} \mid v \in G\right\}\right)
hG=READOUT({hv(L)∣v∈G})
READOUT表示一个置换不变性函数(permutation invariant function),也可以是一个图级pooling函数。
我们要学习图级别的特征用于图分类任务。
基于图同构网络(GIN)的图表征网络的实现
基于图同构网络的图表征学习主要包含以下两个过程:
- 首先计算得到节点表征;
- 其次对图上各个节点的表征做图池化(Graph Pooling),或称为图读出 (Graph Readout),得到图的表征(Graph Representation)。
GIN模型的公式如下: x i ′ = h Θ ( ( 1 + ϵ ) ⋅ x i + ∑ j ∈ N ( i ) x j ) \mathbf{x}_{i}^{\prime}=h_{\Theta}\left((1+\epsilon) \cdot \mathbf{x}_{i}+\sum_{j \in \mathcal{N}(i)} \mathbf{x}_{j}\right) xi′=hΘ((1+ϵ)⋅xi+∑j∈N(i)xj)
图同构网络(Graph Isomorphism Network,GIN),其分辨图的同构性的能力与表示图的能力与WL Test相当。
图同构测试 (WL Test)
两个图是同构的,意思是两个图拥有一样的拓扑结构,也就是说,我们可以通过重 新标记节点从一个图转换到另外一个图。Weisfeiler-Lehman 图的同构性测试算法,简称WL Test,是一种用于测试两个图是否同构的算法。
WL Test 的一维形式,类似于图神经网络中的邻接节点聚合。WL Test 1)迭代地聚 合节点及其邻接节点的标签,然后 2)将聚合的标签散列(hash)成新标签,该过程形式化为下方的公式所示:
L
u
h
←
hash
(
L
u
h
−
1
+
∑
v
∈
N
(
U
)
L
v
h
−
1
)
L_{u}^{h} \leftarrow \operatorname{hash}\left(L_{u}^{h-1}+\sum_{v \in \mathcal{N}(U)} L_{v}^{h-1}\right)
Luh←hash(Luh−1+∑v∈N(U)Lvh−1)
在上方的公示中,
L
u
h
L_{u}^{h}
Luh表示节点
u
u
u的第
h
h
h次迭代的标签,第0次迭代的标签为节点原始标签。
在迭代过程中,发现两个图之间的节点的标签不同时,就可以确定这两个图是非同构的。需要注意的是节点标签可能的取值只能是有限个数。
Weisfeiler-Leman Test 算法举例说明
Weisfeiler-Leman Test 算法通过重复执行以下给节点打标签的过程来实现图是否同构的判断:
- 聚合自身与邻接节点的标签得到一串字符串,自身标签与邻接节点的标签中间 用,分隔,邻接节点的标签按升序排序。排序的原因在于要保证单射性,即保证输出的结果不因邻接节点的顺序改变而改变。
- 标签散列,即标签压缩,将较长的字符串映射到一个简短的标签。
- 给节点重新打上标签。
每重复一次以上的过程,就完成一次节点自身标签与邻接节点标签的聚合。
当出现两个图相同节点标签的出现次数不一致时,即可判断两个图不相似。如果上述的步骤重复一定的次数后,没有发现有相同节点标签的出现次数不一致的情况,那么我们无法判断两个图是否同构。
代码实现
实现GIN层
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder
### GIN convolution along the graph structure
class GINConv(MessagePassing):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''
super(GINConv, self).__init__(aggr = "add")
self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
self.eps = nn.Parameter(torch.Tensor([0]))
self.bond_encoder = BondEncoder(emb_dim = emb_dim)
def forward(self, x, edge_index, edge_attr):
edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边嵌入
out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
return out
def message(self, x_j, edge_attr):
return F.relu(x_j + edge_attr)
def update(self, aggr_out):
return aggr_out
实现利用GIN层得到节点嵌入的模型
import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F
# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
"""
Output:
node representations
"""
def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
"""GIN Node Embedding Module
采用多层GINConv实现图上结点的嵌入。
"""
super(GINNodeEmbedding, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
# add residual connection or not
self.residual = residual
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.atom_encoder = AtomEncoder(emb_dim)
# List of GNNs
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layers):
self.convs.append(GINConv(emb_dim))
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
def forward(self, batched_data):
x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr
# computing input node embedding
h_list = [self.atom_encoder(x)] # 先将类别型原子属性转化为原子嵌入
for layer in range(self.num_layers):
h = self.convs[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layers - 1:
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
if self.residual:
h += h_list[layer]
h_list.append(h)
# Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "sum":
node_representation = 0
for layer in range(self.num_layers + 1):
node_representation += h_list[layer]
return node_representation
实现图池化层
import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding
class GINGraphPooling(nn.Module):
def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
"""GIN Graph Pooling Module
此模块首先采用GINNodeEmbedding模块对图上每一个节点做嵌入,然后对节点嵌入做池化得到图的嵌入,最后用一层线性变换得到图的最终的表示(graph representation)。
Args:
num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了图表示的维度,dimension of graph representation).
num_layers (int, optional): number of GINConv layers. Defaults to 5.
emb_dim (int, optional): dimension of node embedding. Defaults to 300.
residual (bool, optional): adding residual connection or not. Defaults to False.
drop_ratio (float, optional): dropout rate. Defaults to 0.
JK (str, optional): 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。Defaults to "last".
graph_pooling (str, optional): pooling method of node embedding. 可选的值为"sum","mean","max","attention"和"set2set"。 Defaults to "sum".
Out:
graph representation
"""
super(GINGraphPooling, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.emb_dim = emb_dim
self.num_tasks = num_tasks
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
# Pooling function to generate whole-graph embeddings
if graph_pooling == "sum":
self.pool = global_add_pool
elif graph_pooling == "mean":
self.pool = global_mean_pool
elif graph_pooling == "max":
self.pool = global_max_pool
elif graph_pooling == "attention":
self.pool = GlobalAttention(gate_nn=nn.Sequential(
nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
elif graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, processing_steps=2)
else:
raise ValueError("Invalid graph pooling type.")
if graph_pooling == "set2set":
self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
else:
self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
def forward(self, batched_data):
h_node = self.gnn_node(batched_data)
h_graph = self.pool(h_node, batched_data.batch)
output = self.graph_pred_linear(h_graph)
if self.training:
return output
else:
# At inference time, relu is applied to output to ensure positivity
return torch.clamp(output, min=0, max=50)
作业
请画出下方图片中的6号、3号和5号节点的从1层到3层到WL子树。
6号的WL子树:
3号的WL子树:
5号的WL子树: