torch_geometric 使用 batch

基于GCNConv, SAGEConv

问题描述

基于torch.geometric的data具有node和edge的特征,其中node的维度表示为[num_nodes, node features],edge的维度表示为[2, edge_features]

torch.geometric的batch不支持CNN中广泛使用的batch维度。CNN中使用batch的数据具有[batch_size, channel, h, w]的维度。如果使用[batch_size, num_nodes, node_features][batch_size, 2, edge_features],代码会有如下报错:

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size XX but got size XXX for tensor number 1 in the list.

解决方案:

解决方案

完整问题

更多细节

在对角线上重复edge_index特征

batch_edge_index = Batch.from_data_list([Data(edge_index=edge_index)]*batch_size)

使用node_dim

conv = GCNConv(in_channels, out_channels, node_dim=1)
conv(x,edge_index)
# 在这种情况下,x(node)的维度是[batch_size, num_nodes, num_features]
# edge_index的维度是(2, edge_features)
# edge_index holds indices < num_nodes

edge_index always need to be two-dimensional, so the mini-batching of [batch_size, num_nodes, num_features] node features is designed to operate on the same underlying graph of shape [2, num_edges]. If your graph is not static across examples, you will have to use PyG’s approach of diagonal adjacency matrix stacking.

⚠️值得注意的是 edge_index中起始点和终点的index都必须小于nodes的个数(因为默认从0开始)

基于 GATConv

问题描述

与GCNConv, SAGEConv不同的是,当使用

x(node)的维度是[batch_size, num_nodes, num_features]
edge_index的维度是(2, edge_features)

会遇到如下报错:Static graphs not supported in GATConv

解决方案

解决方案,可以使用Batch.from_data_list(data_list)将数据转换成Batch格式,或者手动将node concate成[batchsize*num_nodes, num_features],edge_index concate成[2, batch_size*edge_features]
具体维度解析

edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

edge_index_3 = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x_3 = torch.randn(4, 16)

data1= Data(x=x_s,edge_index=edge_index_s)
data2= Data(x=x_t,edge_index=edge_index_t)
data3= Data(x=x_3,edge_index=edge_index_3)
#上面是构建3张Data图对象
# * `Batch(Data)` in case `Data` objects are batched together
#* `Batch(HeteroData)` in case `HeteroData` objects are batched together

data_list = [data1, data2,data3]


loader = Batch.from_data_list(data_list)#调用该函数data_list里的data1、data2、data3 三张图形成一张大图,也就是batch
————————————————
版权声明:本文为优快云博主「知行合一,至于至善」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_41800917/article/details/120444534
### 关于 PyTorch Geometric 和 Kaggle 的教程与资源 PyTorch Geometric 是一个用于处理图结构数据的扩展库,它基于 PyTorch 构建并支持高效的图形神经网络 (GNN) 训练[^1]。以下是关于 PyTorch Geometric 以及其在 Kaggle 竞赛中的应用的相关教程、示例代码和资源: #### 教程与文档 官方文档提供了详细的入门指南和技术细节,适合初学者快速上手: - 官方文档地址:https://pytorch-geometric.readthedocs.io/en/latest/ #### 示例代码 以下是一些常见的 GitHub 存储库和 Jupyter Notebook 示例,展示了如何使用 PyTorch Geometric 进行模型训练和数据分析: ```python import torch from torch_geometric.datasets import TUDataset from torch_geometric.data import DataLoader from torch.nn import Linear import torch.nn.functional as F from torch_geometric.nn import GCNConv, global_mean_pool # 加载数据集 dataset = TUDataset(root='data/TU', name='MUTAG') class GNN(torch.nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(GNN, self).__init__() self.conv1 = GCNConv(input_dim, hidden_dim) self.conv2 = GCNConv(hidden_dim, hidden_dim) self.lin = Linear(hidden_dim, output_dim) def forward(self, x, edge_index, batch): x = self.conv1(x, edge_index) x = x.relu() x = self.conv2(x, edge_index) x = global_mean_pool(x, batch) x = F.dropout(x, p=0.5, training=self.training) x = self.lin(x) return F.log_softmax(x, dim=-1) model = GNN(dataset.num_node_features, 64, dataset.num_classes) print(model) ``` 上述代码片段展示了一个简单的图卷积网络 (GCN),适用于节点分类任务。 #### Kaggle 竞赛实例 一些 Kaggle 竞赛涉及图学习领域,例如 Open Graph Benchmark 或者 Zindi 平台上的比赛。这些竞赛通常提供预处理好的图数据集,并鼓励参赛者利用工具如 PyTorch Geometric 提交解决方案。具体可以参考以下链接获取更多实战经验: - https://www.kaggle.com/c/open-graph-benchmark-lsc #### 数据集推荐 除了标准的数据集外,还可以尝试以下公开可用的图数据集来练习 PyTorch Geometric 技能: - Cora 引文网络数据集 - Reddit 图形评论数据集 - OGB(Open Graph Benchmark) 以上提到的内容均来源于已知的最佳实践和社区贡献。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值