在 PyTorch 中实现的图形神经网络 (GNN),用于节点分类任务。如果您不是技术奇才,请不要担心;我会让事情变得简单有趣。
什么是图神经网络?
想象一下你有一个朋友网络。每个朋友 (节点) 都有特定的特征,而友谊 (边缘) 将它们连接起来。GNN 通过查看每个朋友的特征以及他们的联系方式来帮助我们了解这个网络。把它想象成根据谁认识谁和他们的个性来弄清楚谁受欢迎。
GNN 类似于卷积神经网络 (CNN),但它们不是处理图像,而是处理图形 — 互连点的网络。它们通过在节点之间传递消息来学习,根据每个节点的邻居更新他们对每个节点的理解。
项目概况
我将项目整齐地组织成不同的模块,如下所示:
简单_gnn/
├── GNN/
│ ├── __init__.py
│ ├── model.py
│ ├── data.py
│ ├── train.py
│ └── evaluate.py
├── main.py
└── 许可证└
── README.md└── requirements.txt
它是如何工作的?
- 消息传递:每个节点都从其邻居那里收集信息。
- 更新表示:神经网络根据收集到的信息更新每个节点的理解。
- 重复:此过程将持续几轮,以在整个网络中传播信息。
数据集:Cora Citation Network
Cora 数据集是科学出版物的集合。以下是一些相关信息:
- 节点:代表论文。
- 边缘:表示论文之间的引文。
- 节点功能:论文的属性(如关键字)。
- 类:每篇论文都属于几个类别之一(例如,神经网络、规则学习)。
使用这个数据集,我们可以根据论文的连接方式和特征将论文分类到各自的类别中。
模型组件
我们的 GNN 模型有三个主要部分:
- 图卷积层 (GCNConv):此层处理消息传递。
- ReLU 激活:为我们的模型添加了一些非线性。
- 线性层:将更新的节点表示映射到我们想要的输出。
运行 GNN
在 Windows 上:
首先,克隆存储库并设置虚拟环境:
git clone https://github.com/Spartan-119/simple_gnn.git
cd simple_gnn
python -m venv gnn_venv
gnn_venv\Scripts\activate
pip install -r requirements.txt
python main.py
在 MacOS/Linux 上:
同样,克隆存储库并设置您的环境:
git clone https://github.com/Spartan-119/simple_gnn.git
cd simple_gnn
python3 -m venv gnn_venv
源 gnn_venv/bin/activate
pip install -r requirements.txt
python3 main.py
这将下载 Cora 数据集、训练 GNN 并评估其性能。
代码分解
在继续之前,请确保您已下载了上一节中的所有依赖项。让我们快速看一下项目中的每个文件的作用:pip install -r requirements.txt
GNN/model.py
这就是魔法发生的地方。在这里,我们定义了我们的 GNN 模型。将其视为我们网络的大脑,我们在其中设置层以及它们如何交互。
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv as gcn
class GNN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super(GNN, self).__init__()
self.conv1 = gcn(in_channels, hidden_channels)
self.conv2 = gcn(hidden_channels, out_channels)
self.linear = torch.nn.Linear(out_channels, out_channels)
def forward(self, x, edge_index): x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = self.linear(x)
return F.log_softmax(x, dim = 1)
该文件使用 PyTorch 定义了我们的图形神经网络 (GNN)。它会导入必要的库(从 PyTorch Geometry 导入、 和 )。该类继承自 ,其方法设置两个图卷积层和一个线性层。该方法指定了输入数据如何通过这些层:首先通过 ReLU 激活,然后通过另一个 ReLU,最后通过 log-softmax 的线性层以产生输出概率。这种设置允许模型通过聚合来自节点邻居的信息来学习。gnn/model.py
torch
torch.nn.functional
GCNConv
GNN
torch.nn.Module
__init__
forward
conv1conv2
GNN/data.py
此文件处理数据加载。它就像引入数据集的守门人(如 Cora 引文网络),确保我们的模型拥有需要学习的数据。
def load_data(dataset_name = 'Cora', data_dir = '/tmp/Planetoid'):
dataset = Planetoid(root = data_dir, name = dataset_name)
data = dataset[0]
train_mask = data.train_mask
val_mask = data.val_mask
test_mask = data.test_mask
返回数据、train_mask、val_mask test_mask
该文件用于加载数据集。它从 PyTorch Geometric 导入 and。该函数从给定目录获取指定的数据集(默认为 'Cora')。它检索用于训练、验证和测试的数据和掩码。此功能简化了模型的数据加载和准备。gnn/data.py
torch
Planetoid
load_data
GNN/train.py
在这里,我们定义了训练过程。这就是我们的模型变得更智能的地方。我们向它提供数据,让它做出预测,将这些预测与实际结果进行比较,并根据它的错误进行调整。
def train_model(model, data, train_mask, val_mask, optimizer, criterion, num_epochs = 200):
model.train()
for epoch in range(num_epochs):
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out[train_mask], data.y[train_mask])
loss.backward()
optimizer.step()
如果纪元 % 10 == 0:
模型。带有 torch.no_grad() 的 eval(
):
out = model(data.x, data.edge_index)
val_acc = torch。sum(out[val_mask].argmax(dim = 1) == data.y[val_mask]) / val_mask.sum()
print(f'Epoch: {epoch}, 验证精度: {val_acc:.4f}')
返回模型
该文件处理 GNN 的训练。它导入并定义 ,后者在指定数量的 epoch 上训练模型。每个 epoch 都会将梯度归零,执行正向传递以计算输出和损失,反向传播以更新权重,并定期评估验证准确性。gnn/train.py
torch
train_model
GNN/evaluate.py
一旦我们的模型经过训练,我们就需要看看它的性能如何。此文件评估模型,通过将其预测与实际数据进行比较来告诉我们模型的准确性。
def evaluate_model(model, data, test_mask):
模型。带有 torch.no_grad() 的 eval(
):
out = model(data.x, data.edge_index)
test_acc = torch。sum(out[test_mask].argmax(dim=1) == data.y[test_mask]) / test_mask.sum()
print(f'测试精度: {test_acc:.4f}')
该文件评估经过训练的 GNN。它导入并定义 ,将模型设置为评估模式,在不更新权重的情况下计算预测,并通过将预测与实际标签进行比较来计算测试准确性。gnn/evaluate.py
torch
evaluate_model
main.py
这是将所有内容联系在一起的主要剧本。当您运行此文件时,它会启动整个过程 — 加载数据、训练模型并对其进行评估。这就像我们小节目的导演。
def main():
# 加载数据
数据, train_mask, val_mask, test_mask = load_data()
# 获取类
的数量 num_classes = data.y.max().item() + 1 # 假设类从 0
开始索引 # 定义模型、优化器和损失函数
model = GNN(data.num_node_features, 16, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
# 训练模型
trained_model = train_model(model, data, train_mask, val_mask, optimizer, criterion)
# 评估模型
evaluate_model(trained_model, data, test_mask)
if __name__ == “__main__”:
main()
“main.py”文件编排了加载数据、训练和评估图神经网络 (GNN) 的整个过程。它首先通过从 'gnn' 目录导入必要的模块来设置环境。该脚本首先从“data.py”调用 “load_data” 函数来加载 Cora 数据集及其训练、验证和测试掩码。
接下来,它使用适当的 input、hidden 和 output 维度初始化 'model.py' 中定义的 GNN 模型。还设置了一个优化器和一个损失函数来训练模型。然后调用 'train.py' 中的 'train_model' 函数,该函数迭代数据,更新模型的权重并定期打印验证准确性。训练后,调用 evaluate.py 中的 'evaluate_model' 函数来评估模型在测试数据上的性能,从而输出测试准确性。该文件以一种内聚的方式将数据加载、模型设置、训练和评估联系在一起,为运行 GNN 提供了一个简单的管道。
本文代码可看下图获取
定制
想要调整 GNN?您可以通过编辑几个文件来调整模型、数据集和超参数:
- model.py:更改体系结构。
- data.py:加载不同的数据集或预处理数据。
- main.py:修改超参数,如学习率或隐藏维度。
GNN 的应用
GNN 不仅适用于学术论文。它们具有广泛的应用范围:
- 社交网络:分析用户连接以推荐朋友或检测社区。
- 生物网络:了解蛋白质相互作用或基因调控网络。
- 推荐系统:根据用户-商品交互图推荐产品。
- 交通网络:通过分析道路网络来预测交通流量。
- 知识图谱:通过连接相关概念来增强搜索引擎。
灵感和未来用途
这个项目的灵感来自于更有效地处理图形结构数据的需求。与传统神经网络不同,GNN 可以:
- 捕获关系:不仅要了解数据点,还要了解它们是如何连接的。
- 概括:适用于各种类型的图形,无论是社交网络、分子还是运输系统。
GNN 与 CNN 有何不同
虽然 CNN 非常适合网格状数据(如图像),但它们不适用于连接不规则且变化的图表。GNN 在这里介入,能够处理图形的复杂性:
- 灵活性:可以使用任何图形结构,而不仅仅是网格。
- 消息传递:节点通过聚合来自邻居的信息来更新其状态,这与 CNN 不同,CNN 在固定位置上使用固定过滤器。
- 动态:可以适应不同的图形大小和形状,而 CNN 具有固定的输入大小。
未来方向
GNN 领域正在迅速发展。以下是一些令人兴奋的未来方向:
- 可扩展性:使 GNN 有效地处理更大的图形。
- 动态图形:处理随时间变化的图表。
- 可解释性:了解 GNN 如何做出决策,对于医学等领域至关重要。
结论
这就是一个简单而强大的 GNN,用于使用 PyTorch 进行节点分类。通过本指南,您可以开始试验 GNN 并探索图形结构数据的迷人世界。