【GNN框架系列】DGL第一讲:使用Deep Graph Library实现GNN进行节点分类

本文详细介绍了使用DGL和Pytorch实现GCN模型在Cora引文数据集上进行节点分类的过程,包括数据概述、模型构建和训练步骤。

作者:CHEONG

公众号:AI机器学习与知识图谱

研究方向:自然语言处理与知识图谱


本文先简单概述GNN节点分类任务,然后详细介绍如何使用Deep Graph Library + Pytorch实现一个简单的两层GNN模型在Cora引文数据上实现节点分类任务。若需获取模型的完整代码,可关注公众号后回复:DGL第一讲完整代码



一、GNN节点分类概述


节点分类是图/图谱数据上常被采用的一个学习任务,既是用模型预测图中每个节点的类别。在GNN模型被提出之前,常用的模型如DeepWalk,Node2Vec等,都是借助序列属性和节点自身特性进行预测,但显然图数据不像NLP中的文本数据那样具有序列依赖性。相比之下,GNN系列模型是利用节点的邻接子图,使用子图汇聚的方式先获得节点表征,再对节点类别进行预测。例如,在2017年Kipf et al.等提出的GCN模型将图的节点分类问题看作一个半监督学习任务。即只利用图中一小部分节点,模型就可以准确预测其他节点的类别。

接下来的实验将通过构建GCN模型,在Cora数据集上进行半监督节点分类任务的训练和预测。Cora数据集是一个引文网络,其中节点是代指某篇论文,节点之间的边代表论文之间的相互引用关系。

NumNodes NumEdges NumFeats NumClasses
2708 10556 1433 7
Num Training Samples Num Validation Samples Num Test Samples
140 500 1000

如上表格所示,Cora引文网络共包含2708个节点,10556个边,其中每个节点由1433维特征组成,每个特征代表词库中的一个Word,如果此篇论文中包含这个Word则这一维特征为1,否则这一维特征为0。在训练数据划分上,其中训练集140个样本节点,验证集500个,测试集1000个。目的是训练模型少标签半监督任务的预测能力。Cora引文网络中节点共分为七类,因此节点分类任务是个七分类问题。



二、DGL实现GNN节点分类


接下来使用DGL框架实现GNN模型进行节点分类任务,对代码进行逐行解释。

1 import dgl
2 import torch
3 import torch.nn as nn
4 import torch.nn.functional as F

首先,上述四行代码,先加载需要使用的dgl库和pytorch库;

1 import dgl.data
2 dataset = dgl.data.CoraGraphDataset()
3 print('Number of categories:', dataset.num_classes)
4 g = dataset[0]

上面第二行代码,加载dgl库提供的Cora数据对象,第四行代码,dgl库中Dataset数据集可能是包含多个图的,所以加载的dataset对象是一个list,list中的每个元素对应该数据的一个graph,但Cora数据集是由单个图组成,因此直接使用dataset[0]取出graph。

print('Node features: ', g.ndata)
print('Edge features: ', g.edata)

看上面两行代码,需要说明DGL库中一个Graph对象是使用字典形式存储了其Node Features和Edge Features,其中第一行g.ndata使用字典结构存储了节点特征信息,第二行g.edata使用字典结构存储了边特征信息。对于Cora数据集的graph来说,Node Features共包含以下五个方面:

\1. train_mask: 指示节点是否在训练集中的布尔张量

\2. val_mask: 指示节点是否在验证集中的布尔张量

\3. test_mask: 指示节点是否在测试机中的布尔张量

\4. label: 每个节点的真实类别

\5. feat: 节点自身的属性

1  from dgl.nn import GraphConv
2  
3  class GCN(nn.Module):
4      def __init__(self, in_feats, h_feats, num_classes):
5          super(GCN, self).__init__()
6          self.conv1 = GraphConv(in_feats, h_feats)
7          self.conv2 = GraphConv(h_feats, num_classes)
8  
9      def forward(self, g, in_feat):
10         # 这里g代表的Cora数据Graph信息,一般就是经过归一化的邻接矩阵
11         # in_feat表示的是node representation,即节点初始化特征信息
12         h = self.conv1(g, in_feat)
13         h = F.relu(h)
14         h = self.conv2
评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值