PyTorch 深度学习实战(10):图神经网络(GNN)与节点分类

在之前的文章中,我们探讨了时间序列预测、图像分类、文本生成等多种深度学习任务。本文将介绍图神经网络(Graph Neural Networks, GNN),并基于 PyTorch 实现一个简单的 GNN 模型,用于节点分类任务。我们将使用经典的 Cora 数据集,帮助读者掌握处理图数据的基本方法。


一、图神经网络基础

图神经网络(GNN)是专门用于处理图结构数据的深度学习模型。与传统神经网络不同,GNN 能够捕捉节点之间的关系信息,广泛应用于社交网络分析、推荐系统、分子结构预测等领域。

1. 图的表示
  • 节点(Node):图中的实体(如用户、论文)。

  • 边(Edge):节点之间的关系(如用户关注、论文引用)。

  • 特征(Feature):每个节点的属性(如用户的年龄、论文的关键词)。

2. 常见的 GNN 层
  • 图卷积网络(GCN):通过聚合邻居节点的特征更新当前节点表示。

  • 图注意力网络(GAT):引入注意力机制,区分不同邻居的重要性。

  • 图同构网络(GIN):通过多层感知机(MLP)增强表达能力。

3. 节点分类任务

节点分类的目标是为图中的每个节点分配一个类别标签。例如,在论文引用网络中,预测论文的研究领域。


二、节点分类实战

我们将使用 Cora 数据集,包含 2708 篇机器学习论文及其引用关系,每篇论文被分为 7 个类别。目标是根据论文内容和引用关系预测类别。

1. 实现步骤
  1. 安装 PyTorch Geometric 库。

  2. 加载并预处理 Cora 数据集。

  3. 构建 GCN 模型。

  4. 训练模型并评估性能。

2. 代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import NormalizeFeatures
import matplotlib.pyplot as plt
​
# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei(黑体)
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
​
# 1. 加载数据集
dataset = Planetoid(root='./data/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]
​
# 2. 定义 GCN 模型
class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)
​
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
​
# 3. 初始化模型和优化器
model = GCN(input_dim=dataset.num_features, hidden_dim=16, output_dim=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.NLLLoss()
​
# 4. 训练函数
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()
​
# 5. 测试函数
def evaluate(mask):
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        correct = pred[mask] == data.y[mask]
        acc = correct.sum().item() / mask.sum().item()
    return acc
​
# 6. 训练模型
train_losses = []
val_accuracies = []
test_accuracies = []
​
for epoch in range(200):
    loss = train()
    train_losses.append(loss)
    val_acc = evaluate(data.val_mask)
    test_acc = evaluate(data.test_mask)
    val_accuracies.append(val_acc)
    test_accuracies.append(test_acc)
    if (epoch + 1) % 20 == 0:
        print(f'Epoch [{epoch+1}/200], Loss: {loss:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
​
# 7. 可视化结果
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.xlabel('训练轮数')
plt.ylabel('损失值')
plt.legend()
​
plt.subplot(1, 2, 2)
plt.plot(val_accuracies, label='Validation Accuracy')
plt.plot(test_accuracies, label='Test Accuracy')
plt.xlabel('训练轮数')
plt.ylabel('准确值')
plt.legend()
plt.show()

三、代码解析

  1. 数据加载

    • 使用 Planetoid 加载 Cora 数据集,包含节点特征 data.x、边索引 data.edge_index 和掩码 data.train_mask/data.val_mask/data.test_mask

    • NormalizeFeatures 对节点特征进行归一化。

  2. 模型结构

    • GCN 模型包含两个 GCNConv 层,中间使用 ReLU 激活和 Dropout 防止过拟合。

    • 输出层通过 log_softmax 计算对数概率。

  3. 训练过程

    • 使用负对数似然损失(NLLLoss)和 Adam 优化器。

    • 训练 200 个 epoch,每 20 个 epoch 打印损失和准确率。

  4. 结果可视化

    • 左图显示训练损失下降曲线。

    • 右图对比验证集和测试集的准确率。 


四、运行结果

运行代码后,你将看到:

  1. 训练损失从约 1.8 逐渐下降至 0.1 以下。

  2. 验证集和测试集的准确率最终达到约 80%。


五、改进建议

  1. 调整模型结构:增加隐藏层维度或添加更多 GCN 层。

  2. 使用其他 GNN 模型:如 GAT 或 GraphSAGE。

  3. 数据增强:通过随机删除边或节点提升鲁棒性。

  4. 超参数优化:调整学习率、Dropout 率或权重衰减系数。


六、总结

本文介绍了图神经网络的基本概念,并使用 PyTorch Geometric 实现了一个简单的 GCN 模型,完成了 Cora 数据集的节点分类任务。通过这个例子,我们学习了如何处理图数据、构建 GNN 模型以及进行训练和评估。

至此,《PyTorch 深度学习实战》系列文章已涵盖图像、文本、序列、图结构数据等多种场景,希望读者能通过这些实战案例掌握深度学习的核心技能。未来探索中,可继续关注模型优化、多模态学习等前沿方向!


代码实例说明

  1. 运行前需安装 PyTorch Geometric:

    pip install torch_geometric

  2. 完整代码依赖:torch, matplotlib, torch_geometric

  3. GPU 加速:将模型和数据移至 GPU(model.to('cuda'), data = data.to('cuda'))。

希望本系列文章为你提供了扎实的实践基础!如有疑问,欢迎在评论区交流。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值