作者:CHEONG
公众号:AI机器学习与知识图谱
研究方向:自然语言处理与知识图谱
本文先简单概述GNN节点分类任务,然后详细介绍如何使用Deep Graph Library + Pytorch实现一个简单的两层GNN模型在Cora引文数据上实现节点分类任务。若需获取模型的完整代码,可关注公众号后回复:DGL第一讲完整代码
一、GNN节点分类概述
节点分类是图/图谱数据上常被采用的一个学习任务,既是用模型预测图中每个节点的类别。在GNN模型被提出之前,常用的模型如DeepWalk,Node2Vec等,都是借助序列属性和节点自身特性进行预测,但显然图数据不像NLP中的文本数据那样具有序列依赖性。相比之下,GNN系列模型是利用节点的邻接子图,使用子图汇聚的方式先获得节点表征,再对节点类别进行预测。例如,在2017年Kipf et al.等提出的GCN模型将图的节点分类问题看作一个半监督学习任务。即只利用图中一小部分节点,模型就可以准确预测其他节点的类别。
接下来的实验将通过构建GCN模型,在Cora数据集上进行半监督节点分类任务的训练和预测。Cora数据集是一个引文网络,其中节点是代指某篇论文,节点之间的边代表论文之间的相互引用关系。
| NumNodes | NumEdges | NumFeats | NumClasses |
|---|---|---|---|
| 2708 | 10556 | 1433 | 7 |
| Num Training Samples | Num Validation Samples | Num Test Samples |
|---|---|---|
| 140 | 500 | 1000 |
如上表格所示,Cora引文网络共包含2708个节点,10556个边,其中每个节点由1433维特征组成,每个特征代表词库中的一个Word,如果此篇论文中包含这个Word则这一维特征为1,否则这一维特征为0。在训练数据划分上,其中训练集140个样本节点,验证集500个,测试集1000个。目的是训练模型少标签半监督任务的预测能力。Cora引文网络中节点共分为七类,因此节点分类任务是个七分类问题。
二、DGL实现GNN节点分类
接下来使用DGL框架实现GNN模型进行节点分类任务,对代码进行逐行解释。
1 import dgl
2 import torch
3 import torch.nn as nn
4 import torch.nn.functional as F
首先,上述四行代码,先加载需要使用的dgl库和pytorch库;
1 import dgl.data
2 dataset = dgl.data.CoraGraphDataset()
3 print('Number of categories:', dataset.num_classes)
4 g = dataset[0]
上面第二行代码,加载dgl库提供的Cora数据对象,第四行代码,dgl库中Dataset数据集可能是包含多个图的,所以加载的dataset对象是一个list,list中的每个元素对应该数据的一个graph,但Cora数据集是由单个图组成,因此直接使用dataset[0]取出graph。
print('Node features: ', g.ndata)
print('Edge features: ', g.edata)
看上面两行代码,需要说明DGL库中一个Graph对象是使用字典形式存储了其Node Features和Edge Features,其中第一行g.ndata使用字典结构存储了节点特征信息,第二行g.edata使用字典结构存储了边特征信息。对于Cora数据集的graph来说,Node Features共包含以下五个方面:
\1. train_mask: 指示节点是否在训练集中的布尔张量
\2. val_mask: 指示节点是否在验证集中的布尔张量
\3. test_mask: 指示节点是否在测试机中的布尔张量
\4. label: 每个节点的真实类别
\5. feat: 节点自身的属性
1 from dgl.nn import GraphConv
2
3 class GCN(nn.Module):
4 def __init__(self, in_feats, h_feats, num_classes):
5 super(GCN, self).__init__()
6 self.conv1 = GraphConv(in_feats, h_feats)
7 self.conv2 = GraphConv(h_feats, num_classes)
8
9 def forward(self, g, in_feat):
10 # 这里g代表的Cora数据Graph信息,一般就是经过归一化的邻接矩阵
11 # in_feat表示的是node representation,即节点初始化特征信息
12 h = self.conv1(g, in_feat)
13 h = F.relu(h)
14 h = self.conv2

本文详细介绍了使用DGL和Pytorch实现GCN模型在Cora引文数据集上进行节点分类的过程,包括数据概述、模型构建和训练步骤。
最低0.47元/天 解锁文章
1348





