图深度学习框架PyG(Pytorch-Geometric)代码实战

本文介绍了如何使用PyG进行图神经网络的实践,包括GCN和GraphSAGE的代码框架,以及空手道俱乐部数据集和Cora数据集上的应用。同时,文章讲解了如何构建PyG数据格式,特别是对于多图情景的处理,并提供了雅虎电商数据集的加载示例。最后,展示了如何在PyG中实现GraphSAGE模型并进行训练。

Github地址:PyG-pytoch_geometric
官方文档地址:PyG开发文档

手动防爬虫,作者优快云:总是重复名字我很烦啊,联系邮箱daledeng123@163.com

PyG安装

1、PyG is available for Python 3.7 to Python 3.11.(安装3.7以上的python版本,推荐2017年10月3.7版本,兼容性最好)
2、You can now install PyG via Anaconda for all major OS/PyTorch/CUDA combinations 🤗 If you have not yet installed PyTorch, install it via conda as described in the official PyTorch documentation. (安装pytorch,相关教程可以直接在torch官网看,也可以去其他博主看,有cuda版本和cpu版本,我的版本如下)
在这里插入图片描述
3、From PyG 2.3 onwards, you can install and use PyG without any external library required except for PyTorch. (除torch外,不需要额外安装其他依赖,直接pip安装)

!pip install torch_geometric -i https://pypi.tuna.tsinghua.edu.cn/simple

在本文代码运行过程中,如果出现其他库没有的,请自行安装。


注意:
2.3版本以上的PyG由于取消了依赖,因此有些经典代码的运行会出问题。如果安装低于2.3版本的PyG则需要手动安装依赖,具体操作可以参考知乎文章:PyG手动安装依赖步骤

图神经网络的通用代码框架撰写指南

GCN代码框架

PyG依然使用的Torch框架,因此需要对torch有一定的使用经验。在理论部分首先介绍的是GCN(图卷积模型),这个模型本质上就是嵌入向量和邻接混淆矩阵的不断左乘,因此在框架中的写法非常简单:

在框架搭建部分:
self.conv1 = GCNConv(输入节点嵌入维度, hidden_channel1)
self.conv2 = GCNConv(hidden_channel1, hidden_channel2)
...
self.convn = GCNConv(hidden_channen, 最终分类数量)
在前向传播部分:
x = self.conv1(x, 节点与节点之间的连接关系)

这里节点与节点的连接关系输入格式如下:

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
         33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
        [ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 17, 19, 21, 31,  0,  2,
         18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])

GraphSAGE代码框架

class Net(torch.nn.Module): #针对图进行分类任务
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = SAGEConv(embed_dim, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = SAGEConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = SAGEConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)
        ...
        self.lin = torch.nn.Linear(128, output_dim)
        
    def forward(self, data):
    	x, edge_index, batch = data.x, data.edge_index, data.batch
    	...

空手道俱乐部(karate club dataset)GCN代码实战

准备工作

%matplotlib inline
import matplotlib.pyplot as plt
import networkx as nx
import torch

from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx

# 两个画图函数
def visualize_graph(G, color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels = False, node_color=color, cmap='Set2')
    plt.show()
    
def visualize_embedding(h, color, epoch=None, loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0], h[:,1], s=140, c=color, cmap='Set2')
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch:  {
     
     epoch}, Loss:  {
     
     loss.item():.4f}', fontsize=16)
    plt.show()

读取数据

dataset = KarateClub()

print(dataset)
print(len(dataset))
print(dataset.num_features)
print(dataset.num_classes)
>>KarateClub()
1
34
4
data = dataset[0]
data

>>Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

这个表示34个样本量,每个样本34个特征,一共有156对关系,标签是34,预测的train_mask是34。

edge_index = data.edge_index
edge_index
>>tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,
          3,  3,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

总是重复名字我很烦啊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值