【PyG】DATA(torch_geometric.data.Data )之学习

文章介绍了torch_geometric库中的Data类,它是构建图数据结构的基础,包含节点特征x、边索引edge_index、节点位置pos、边属性edge_attr和样本标签y等属性。Data类不仅限于这些属性,还可以扩展以适应不同图结构。在PyG中,Data对象包含了样本的label,与PyTorch中的处理方式略有不同。文章通过示例展示了如何使用Data类创建无向图,并提到了edge_index的不同存储方式。
部署运行你感兴趣的模型镜像

1 torch_geometric.data.Data属性

torch_geometric.data.data — pytorch_geometric documentation

class Data (self, x: OptTensor = None, edge_index: OptTensor = None, edge_attr: OptTensor = None, y: OptTensor = None, pos: OptTensor = None, **kwargs):

x: 用于存储每个节点的特征,形状是[num_nodes, num_node_features]

edge_index: 用于存储节点之间的边,形状是 [2, num_edges]  使用稀疏的方式存储边关系(edge_index中边的存储方式,有两个list,第 1 个list是边的起始点,第 2 个list是边的目标节点));

pos: 存储节点的坐标,形状是[num_nodes, num_dimensions]

y: 存储样本标签。如果是每个节点都有标签,那么形状是[num_nodes, *];如果是整张图只有一个标签,那么形状是[1, *]

edge_attr: 存储边的特征。形状是[num_edges, num_edge_features]

 使用pyG的Data来建图:

from torch_geometric.data import Data
 
data = Data(x=x, edge_index=edge_index)

 在实际的应用场景中,图的形式多种多样,单纯的使用 x 和 edge index 是无法描述这众多的图结构的;

在 PyG 的Data 类当中,还提供了许多其他属性用于描述图的变量。

实际上,Data对象不仅仅限制于这些属性,我们可以通过data.face来扩展Data,以张量保存三维网格中三角形的连接性;

 在Data里包含了样本的 label,这意味和 PyTorch 稍有不同,在PyTorch中,我们重写Dataset的__getitem__(),根据 index 返回对应的样本和 label;

在 PyG 中,我们使用的不是这种写法,而是在get()函数中根据 index 返回torch_geometric.data.Data类型的数据,在Data里包含了数据和 label;

使用`torch_geometric.data.Data`

import torch
from torch_geometric.data import Data
# 由于是无向图,因此有 4 条边:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# 节点的特征                           
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
 
data = Data(x=x, edge_index=edge_index)

注意edge_index中边的存储方式,有两个list,第 1 个list是边的起始点,第 2 个list是边的目标节点

另一种存储edge_index的方式:

import torch
from torch_geometric.data import Data
 
edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
 
data = Data(x=x, edge_index=edge_index.t().contiguous())

这种情况edge_index需要先转置然后使用contiguous()方法。

有了Data,我们可以创建自己的Dataset,读取并返回Data。

参考:图神经网络 PyTorch Geometric 入门教程 - 掘金
 

有了data对象就可以快速开始了,PyG官方提供了许多图神经网络算法的接口

 

 

您可能感兴趣的与本文相关的镜像

PyTorch 2.6

PyTorch 2.6

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值