dgl框架实现graphsage代码流程梳理

本文档详细介绍了如何使用DGL 0.4.3版本实现GraphSAGE模型,重点讨论了Reddit数据集、节点采样、dataloader的定制、SAGE网络的构建以及预测阶段的处理。在训练过程中,DGL的sample_neighbors模块用于采样邻居节点,形成二部图,并通过SAGEConv层进行特征传播和聚合。在推断阶段,由于需要多次采样,预测过程相对复杂。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

代码git地址:https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling.py
dgl在最近的4月份更新的0.4.3版本中增加了dgl.sampling.sample_neighbors的模块,用来采样邻居节点,之前的版本里面实现的graphsage没有做采样,这里记录一下实现的代码过程,对一些地方加了注释:
有几个比较重要的点:
1.RedditDataset数据集232965个节点,节点向量维度是602,edges=114848857,n_classes=41,train_nid 13w训练样本的id
2.dataloader每次yeild一个batch的seed是依次从节点id里取的(见DataLoader立马的batchsamper),然后自定义的collate_fn函数利用这个seed来采样k阶的邻居节点,这里是采样了2阶,采样是取的边数,1000个点采样10000条边,生成第一个block1000-9640的一个二部图,利用9640采样25倍24100的边,生成第二个block9640-10w+节点的二部图,注意的是为了方便计算这里的二部图里src节点里面包含了dst的节点
3.训练的时候,SAGE网络框架里定义了两层的dglnn.SAGEConv,第一层是602 * 16,第二层转换是16 * 41,其中第一层SAGEConv处理的是block9640-10w+节点的二部图,把src节点的向量发送到dst节点并取均值,加上dst的节点的原始特征向量后,利用全连接转换602 * 16 , 第二层SAGEConv处理
block1000-9640的二部图,同样取均值转换再转换16 * 41。
4.在Inference时候,因为需要采样所以预测过程有点复杂,先算完一层结果再算另一层的结果;
5.采样的函数中是实现了一篇论文中的方法,看起来挺复杂@_@
图中第二步计算的是block1

import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
import tqdm
import traceback

#### Neighbor sampler

class NeighborSampler(object):
    def __init__(self, g, fanouts):
        self.g = g
        self.fanouts = fanouts

    def sample_blocks(self, seeds):
        seeds = th.LongTensor(np.asarray(seeds)) #这个seed一开始是dataloader里的batchsampler,按照batch大小依次把graph的id一个个yeild出来
        blocks = []
        for fanout in self.fanouts: #[10,25]
            # For each seed node, sample ``fanout`` neighbors. 这里的sampler是在v0.4.3版本新加入
            frontier = dgl.sampling.sample_neighbors(g, seeds, fanout, replace=True) #利用1000个seeds得到的10000个边?节点
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds) # to_black操作是把将采样的子图转换为适合计算的二部图,这里特殊的地方在于block.srcdata中的id是包含了dstnodeid的
            # Obtain the seed nodes for next layer.
            seeds = block.srcdata[dgl.NID]
            # 一个种子的长度是1000,就是一个batch的索引,1000个一个batch,采样10个邻居,得到10000边9640个点,再采样25个点,得到241000个边,105693个点,Blocks里面是两个子图
            blocks.insert(0, block)
        return blocks

class SAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值