1.GraphSAGE
本文代码源于 DGL 的 Example 的,感兴趣可以去 github 上面查看。
阅读代码的本意是加深对论文的理解,其次是看下大佬们实现算法的一些方式方法。当然,在阅读 GraphSAGE 代码时我也发现了之前忽视的 GraphSAGE 的细节问题和一些理解错误。比如说:之前忽视了 GraphSAGE 的四种聚合方式的具体实现,对 Alogrithm 2 的算法理解也有问题,再回头看那篇 GraphSAGE 的推文时,实在惨不忍睹= =。
进入正题,简单回顾一下 GraphSAGE。
核心算法:
2.SAGEConv
dgl 已经实现了 SAGEConv 层,所以我们可以直接导入。
有了 SAGEConv 层后,GraphSAGE 实现起来就比较简单。
和基于 GraphConv 实现 GCN 的唯一区别在于把 GraphConv 改成了 SAGEConv:
class GraphSAGE(nn.Module):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout,
aggregator_type):
super(GraphSAGE, self).__init__()
self.layers = nn.ModuleList()
self.g = g
# input layer
self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type,
feat_drop=dropout, activation=activation))
# hidden layers
for i in range(n_layers - 1):
self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type,
feat_drop=dropout, activation=activation))
# output layer
self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type,
feat_drop=dropout, activation=None)) # activation None
def forward(self, features):
h = features
for layer in self.layers:
h = layer(self.g, h)
return h
来看一下 SAGEConv 是如何实现的
SAGEConv 接收七个参数:
- in_feat:输入的特征大小,可以是一个整型数,也可以是两个整型数。如果用在单向二部图上,则可以用整型数来分别表示源节点和目的节点的特征大小,如果只给一个的话,则默认源节点和目的节点的特征大小一致。需要注意的是,如果参数 aggregator_type 为 gcn 的话,则源节点和目的节点的特征大小必须一致;
- out_feats:输出特征大小;
- aggregator_type:聚合类型,目前支持 mean、gcn、pool、lstm,比论文多一个 gcn 聚合,gcn 聚合可以理解为周围所有的邻居结合和当前节点的均值;
- feat_drop=0.:特征 drop 的概率,默认为 0;
- bias=True:输出层的 bias,默认为 True;
- norm=None:归一化,可以选择一个归一化的方式,默认为 None
- activation=None:激活函数,可以选择一个激活函数去更新节点特征,默认为 None。
class SAGEConv(nn.Module):
def __init__(self,
in_feats,
out_feats,
aggregator_type,
feat_drop=0.,
bias=True,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
# expand_as_pair 函数可以返回一个二维元组。
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type != 'gcn':
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()
def reset_parameters(self):
"""初始化参数
这里的 gain 可以从 calculate_gain 中获取针对非线形激活函数的建议的值
用于初始化参数
"""
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'pool':
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
self.lstm.reset_parameters()
if self._aggre_type != 'gcn':
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
def _lstm_reducer(self, nodes):
"""LSTM reducer
NOTE(zihao): lstm reducer with default schedule (degree bucketing)
is slow, we could accelerate this with degree padding in the future.
"""
m = nodes.mailbox['m'] # (B, L, D)
batch_size = m.shape[0]
h = (m.new_zeros((1, batch_size, self._in_src_feats)),
m.new_zeros((1, batch_size, self._in_src_feats)))
_, (rst, _) = self.lstm(m, h)
return {
'neigh': rst.squeeze(0)}
def forward(self, graph, feat):
""" SAGE 层的前向传播
接收 DGLGraph 和 Tensor 格式的节点特征
"""
# local_var 会返回一个作用在内部函数中使用的 Graph 对象
# 外部数据的变化不会影响到这个 Graph
# 可以理解为保护数据不被意外修改
graph = graph.local_var()
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
h_self = feat_dst
# 根据不同的聚合类型选择不同的聚合方式
# 值得注意的是,论文在 3.3 节只给出了三种聚合方式
# 而这里却多出来一个 gcn 聚合
# 具体原因将在后文给出
if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
# check_eq_shape 用于检查源节点和目的节点的特征大小是否一致
check_eq_shape(feat)
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'