GraphSAGE 代码示例,演示如何在大规模图上通过邻居采样(Neighbor Sampling)来降低计算复杂度。
代码使用 PyTorch + torch_geometric(PyG),因为它已经把邻居采样、分布式训练等常用技巧做得比较好。但为了让你更直观地看到实现细节,我也在关键处写了大量注释,并给出了可选的纯 PyTorch 实现思路。
⚠️ 运行前请先安装依赖
pip install torch torchvision torchaudio torch-geometric==2.5.3 \
torch-scatter==2.1.2 torch-sparse==0.6.18 torch-cluster==1.6.3 \
torch-spline-conv==1.2.1 tqdm
⚠️ 如果你使用的是 RTX 3090 / A100 等显卡,建议把
torch_scatter、torch_sparse用对应的 CUDA wheel 安装,以获得更快的稀疏矩阵运算。
1. 数据准备
我们以常用的 Cora 文档分类数据集为例(图大小约 2,708 篇文献,5,429 条边),你可以把下面的数据加载逻辑直接改成自己的大规模图文件。
import torch
from torch_ge
GraphSAGE 代码示例详解
订阅专栏 解锁全文
104

被折叠的 条评论
为什么被折叠?



