在看DGL中GraphSAGE实现源码的时候,发现里面出现多次srcdata,dstdata。一开始以为是图中边的所有源节点和目标节点,但是后来发现不是,因为在有向图上测试可以得到srcdata和dstdata完全是一样的,这就有点蛋疼了。
后来追踪DGLGraph.num_src_nodes()函数看到下面的注释后才明白,简单来说只有当图是二分图,且节点可以分成A、B两种类型,所有的边都是从A指向B的时候,A就是source,B就是destination。其他时候src和dst都是指全部的结点。官方代码注释如下:
def num_src_nodes(self, ntype=None):
"""Return the number of source nodes in the graph.
If the graph can further divide its node types into two subsets A and B where
all the edeges are from nodes of types in A to nodes of types in B, we call
this graph a *uni-bipartite* graph and the nodes in A being the *source*
nodes and the ones in B being the *destination* nodes. If the graph is not
uni-bipartite, the source and destination nodes are just the entire set of
nodes in the graph.
"""
在GraphSAGE的源码中,srcdata和dstdata的更新以及feat_src,feat_dst的设置都是为了特殊的二分图设置的,看起来是真蛋疼。。下面是GCN Aggregator部分对应的源码:
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().to(feat_dst)
# 公式中给出集合并集,这里做 element-wise 的和,问题也不大。
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
rst = self.fc_neigh(h_neigh)
DGL中srcdata和dstdata含义解析
在DGL的GraphSAGE实现中,srcdata和dstdata的概念可能初看容易混淆。实际上,它们仅在图是二分图,节点分为A、B两类且边从A指向B时有意义,此时A为source,B为destination。在其他情况下,src和dst都代表所有节点。GraphSAGE的源码中,srcdata和dstdata的处理主要针对这种特殊二分图结构,体现在节点特征的更新和设置上。
1754





