#安装依赖
#!pip install paddlepaddle==1.8.5
!pip install pgl
2. GraphSage采样函数实现
GraphSage的作者提出了采样算法来使得模型能够以Mini-batch的方式进行训练,算法伪代码见论文附录A。
- 假设我们要利用中心节点的k阶邻居信息,则在聚合的时候,需要从第k阶邻居传递信息到k-1阶邻居,并依次传递到中心节点。
- 采样的过程刚好与此相反,在构造第t轮训练的Mini-batch时,我们从中心节点出发,在前序节点集合中采样 Nt个邻居节点加入采样集合。
- 接着将邻居节点作为新的中心节点继续进行第t-1轮训练的节点采样,以此类推。
- 最后将采样到的节点和边一起构造得到子图。
下面请将GraphSage的采样函数补充完整。
%%writefile userdef_sample.py
import numpy as np
def traverse(item):
"""traverse
"""
if isinstance(item, list) or isinstance(item, np.ndarray):
for i in iter(item):
for j in traverse(i):
yield j
else:
yield item
def flat_node_and_edge(nodes):
"""flat_node_and_edge
"""
nodes = list(set(traverse(nodes)))
return nodes
def my_graphsage_sample(graph, batch_train_samples, samples):
"""
输入:graph - 图结构 Graph
batch_train_samples - 中心节点 list (batch_size,)
samples - 采样时的最大邻节点数列表 list
输出:被采样节点下标的集合
对当前节点进行k阶采样后得到的子图
"""
st

本文详细介绍了GraphSage的采样算法和聚合函数实现,包括从中心节点出发的反向采样过程,以及Mean Aggregator和MaxPool Aggregator的示例代码。通过与PGL库的对比,帮助读者理解不同聚合策略的差异。
最低0.47元/天 解锁文章
923

被折叠的 条评论
为什么被折叠?



