源代码:GitHub - tkipf/pygcn: Graph Convolutional Networks in PyTorch
依据记忆,自己敲了源代码,可能与源代码有细微差别,但也可以正常运行。
注释中给出了一些debug时的流程。
一、utils代码分析
主要由两部分构成:数据预处理load_data()+精度计算accuracy()
load_data()
1)features
提取featues -> 转化为csr_matrix存储 -> 归一化 -> 转变为张量torch.from_numpy()
2)labels
提取labels -> onehot编码-> 找到节点的标签np.where(labels) -> 转变为张量
注:代码中第20行将前两步合成一步,43行将后两步合成了一步。
3)建立图
论文的idx是杂乱无章的,表示起来不方便,所以构建字典给论文编号。
建立图的过程较复杂,要生成idx, edges和adj.
| 图的组成 | 步骤 |
| idx | 提取idx -> 生成字典,给节点编号 |
| edges | 读取边 -> 用节点编号表示边 |
| adj | 生成adj -> 对称化 -> 归一化 -> 转变为稀疏张量 |
accuracy()
计算精度, 在train()中被调用。
import numpy as np
import scipy.sparse as sp
import torch
def encode_onehot(labels):
classes=set(labels) #求出labels的类
classes_dict={c:np.identity(len(classes))[i,:] for i,c in enumerate(classes)} #生成类别字典
labels_onehot=np.array(list(map(classes_dict.get, labels)), dtype=np.int32) #类型还没加上
return labels_onehot
def load_data(path='../data/cora/', dataset='cora'):
print("Loading {} data...".format(dataset))
idx_features_labels = np.genfromtxt('{}{}.content'.format(path, dataset), dtype=np.dtype(str))
#idx_features_labels=np.array(np.genfromtxt('{}{}.content'.format(path, dataset)), dtype=np.dtype(str))
#error:得到的features是浮点型,得到的labels是nan
'''
也可以写作sp.coo_matrix,只不过变成了col,row作为指示
'''
features=sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32) #提取第一列至倒数第二列
labels=encode_onehot(idx_features_labels[:,-1])
'''build graph'''
'''①给结点编号'''
idx=np.array(idx_features_labels[:,0], dtype=np.int32)
#idx=idx_features_labels[:,0]有bug,原因:提取出来的是str类型的数据,idx:['31336' '1061127']
# 得到的idx_dict:{'31336': 0,'1061127': 1,...}在第30行map()会出错!
idx_dict={j:i for i,j in enumerate(idx)}
'''②处理边,把边的关系用编号表示'''
edges_unordered=np.genfromtxt('{}{}.cites'.format(path, dataset), dtype=np.int32)
edges=np.array(list(map(idx_dict.get, edges_unordered.flatten())),
dtype=np.int32).reshape(edges_unordered.shape)
'''③生成邻接矩阵'''
adj=sp.csr_matrix((np.ones(edges.shape[0]),(edges[:,0], edges[:,1])),
shape=(labels.shape[0], labels.shape[0]), dtype=np.int32)
# row,

本文详细解析了PyTorch中实现图卷积网络GCN的代码,包括数据预处理、模型构建、训练过程。在数据预处理中,涉及特征提取、onehot编码和邻接矩阵的构建。模型部分介绍了GCN类的定义,包括图卷积层的实现。在训练阶段,讲解了损失计算、反向传播和优化器的使用。此外,还探讨了nn.Module与nn.functional的区别以及优化器的作用。
最低0.47元/天 解锁文章
24万+





