代码git地址:https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling.py
dgl在最近的4月份更新的0.4.3版本中增加了dgl.sampling.sample_neighbors的模块,用来采样邻居节点,之前的版本里面实现的graphsage没有做采样,这里记录一下实现的代码过程,对一些地方加了注释:
有几个比较重要的点:
1.RedditDataset数据集232965个节点,节点向量维度是602,edges=114848857,n_classes=41,train_nid 13w训练样本的id
2.dataloader每次yeild一个batch的seed是依次从节点id里取的(见DataLoader立马的batchsamper),然后自定义的collate_fn函数利用这个seed来采样k阶的邻居节点,这里是采样了2阶,采样是取的边数,1000个点采样10000条边,生成第一个block1000-9640的一个二部图,利用9640采样25倍24100的边,生成第二个block9640-10w+节点的二部图,注意的是为了方便计算这里的二部图里src节点里面包含了dst的节点。
3.训练的时候,SAGE网络框架里定义了两层的dglnn.SAGEConv,第一层是602 * 16,第二层转换是16 * 41,其中第一层SAGEConv处理的是block9640-10w+节点的二部图,把src节点的向量发送到dst节点并取均值,加上dst的节点的原始特征向量后,利用全连接转换602 * 16 , 第二层SAGEConv处理
block1000-9640的二部图,同样取均值转换再转换16 * 41。
4.在Inference时候,因为需要采样所以预测过程有点复杂,先算完一层结果再算另一层的结果;
5.采样的函数中是实现了一篇论文中的方法,看起来挺复杂@_@
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
import tqdm
import traceback
#### Neighbor sampler
class NeighborSampler(object):
def __init__(self, g, fanouts):
self.g = g
self.fanouts = fanouts
def sample_blocks(self, seeds):
seeds = th.LongTensor(np.asarray(seeds)) #这个seed一开始是dataloader里的batchsampler,按照batch大小依次把graph的id一个个yeild出来
blocks = []
for fanout in self.fanouts: #[10,25]
# For each seed node, sample ``fanout`` neighbors. 这里的sampler是在v0.4.3版本新加入
frontier = dgl.sampling.sample_neighbors(g, seeds, fanout, replace=True) #利用1000个seeds得到的10000个边?节点
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds) # to_black操作是把将采样的子图转换为适合计算的二部图,这里特殊的地方在于block.srcdata中的id是包含了dstnodeid的
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]
# 一个种子的长度是1000,就是一个batch的索引,1000个一个batch,采样10个邻居,得到10000边9640个点,再采样25个点,得到241000个边,105693个点,Blocks里面是两个子图
blocks.insert(0, block)
return blocks
class SAGE(nn.Module):
def __init__(self,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
self.dropout = nn