【Code】GraphSAGE 源码解析

本文深入解析GraphSAGE的源码,重点介绍了SAGEConv层的不同聚合方式(MEAN、GCN、POOL、LSTM)以及Neighbor Sampler的实现。通过阅读代码,详细阐述了每种聚合器的工作原理,并提供了相关算法的伪代码和实际代码片段。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.GraphSAGE

本文代码源于 DGL 的 Example 的,感兴趣可以去 github 上面查看。

阅读代码的本意是加深对论文的理解,其次是看下大佬们实现算法的一些方式方法。当然,在阅读 GraphSAGE 代码时我也发现了之前忽视的 GraphSAGE 的细节问题和一些理解错误。比如说:之前忽视了 GraphSAGE 的四种聚合方式的具体实现,对 Alogrithm 2 的算法理解也有问题,再回头看那篇 GraphSAGE 的推文时,实在惨不忍睹= =。

进入正题,简单回顾一下 GraphSAGE。

核心算法:

2.SAGEConv

dgl 已经实现了 SAGEConv 层,所以我们可以直接导入。

有了 SAGEConv 层后,GraphSAGE 实现起来就比较简单。

和基于 GraphConv 实现 GCN 的唯一区别在于把 GraphConv 改成了 SAGEConv:

class GraphSAGE(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.g = g
        # input layer
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type,
                                    feat_drop=dropout, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type,
                                        feat_drop=dropout, activation=activation))
        # output layer
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type,
                                    feat_drop=dropout, activation=None)) # activation None
        
    def forward(self, features):
        h = features
        for layer in self.layers:
            h = layer(self.g, h)
        return h

来看一下 SAGEConv 是如何实现的

SAGEConv 接收七个参数:

  • in_feat:输入的特征大小,可以是一个整型数,也可以是两个整型数。如果用在单向二部图上,则可以用整型数来分别表示源节点和目的节点的特征大小,如果只给一个的话,则默认源节点和目的节点的特征大小一致。需要注意的是,如果参数 aggregator_type 为 gcn 的话,则源节点和目的节点的特征大小必须一致;
  • out_feats:输出特征大小;
  • aggregator_type:聚合类型,目前支持 mean、gcn、pool、lstm,比论文多一个 gcn 聚合,gcn 聚合可以理解为周围所有的邻居结合和当前节点的均值;
  • feat_drop=0.:特征 drop 的概率,默认为 0;
  • bias=True:输出层的 bias,默认为 True;
  • norm=None:归一化,可以选择一个归一化的方式,默认为 None
  • activation=None:激活函数,可以选择一个激活函数去更新节点特征,默认为 None。
class SAGEConv(nn.Module):
    

    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        # expand_as_pair 函数可以返回一个二维元组。
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """初始化参数
        这里的 gain 可以从 calculate_gain 中获取针对非线形激活函数的建议的值
        用于初始化参数
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {
   'neigh': rst.squeeze(0)}

    def forward(self, graph, feat):
        """ SAGE 层的前向传播
        接收 DGLGraph 和 Tensor 格式的节点特征
        """
        # local_var 会返回一个作用在内部函数中使用的 Graph 对象
        # 外部数据的变化不会影响到这个 Graph
        # 可以理解为保护数据不被意外修改
        graph = graph.local_var()

        if isinstance(feat, tuple):
            feat_src = self.feat_drop(feat[0])
            feat_dst = self.feat_drop(feat[1])
        else:
            feat_src = feat_dst = self.feat_drop(feat)

        h_self = feat_dst

        # 根据不同的聚合类型选择不同的聚合方式
        # 值得注意的是,论文在 3.3 节只给出了三种聚合方式
        # 而这里却多出来一个 gcn 聚合
        # 具体原因将在后文给出
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            # check_eq_shape 用于检查源节点和目的节点的特征大小是否一致
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst     # same as above if homogeneous
            graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
            # divide in_degrees
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值