PyG是一个基于PyTorch用与处理部规则数据(比如图)的库,是一个用于在图等数据上快速实现表征学习的框架,是当前最流行和广泛使用的GNN(Graph Neural Networks, GNN 图神经网络)库。
Graph Neural Networks,GNN,称为图神经网络,是深度学习中近年来比较受关注的领域,GNN通过对信息的传递、转换和聚合实现特征的提取,类似与传统的CNN,只是CNN只能处理规则的输入,如图像等输入的高、宽和通道数都是固定的,而GNN可以处理部规则的输入,如点云等。
安装
pip install torch-geometric
pip install torch-sparse
pip install torch-scatter
pip install pytorch-fid
torch_geometric.data.Data
节点和节点之间的边构成了图,在PyG中,构建图需要两个要素:节点和边。PyG提供了torch_geometric.data.Data(简称Data)用于构建图,包括5个属性,每一个属性都部是必须的,可以为空。
- x:用于存储每个节点的特征,形状是[num_nodes, num_node_features].
- edge_index:用于存储节点之间的边,形状是[2, num_edges]。
- pos:存储节点的坐标,形状是[num_nodes, num_dimensions]。
- y:存储样本标签。如果是每个节点都有标签,那么形状是[num_nodes, *];如果是整张图只有一个标签,那么形状是[1, *]。
- edge_attr:存储边的特征。形状是[num_edges, num_edge_features]。
Data对象不仅仅限制于这些属性,还可以通过data.face来扩展Data,以张量保存三维网格中三角形的连接性。
和P有Torch稍有不同,Data里包含了样本的label,在PyTorch中,重写Dataset的__getitem__(),根据index返回对应的样本和label。在PyG中,在get()函数中根据index返回torch_geometric.data.Data类型的数据,在Data里包含了数据和label。
例如:未加权无向图(未加权指边上没有权值),包括3个节点和4条边:(0->1),(1->0),(1->2),(2->1),每个节点都有一维特征。

import torch
from torch_geometric.data import Data
#由于是无向图,有四条边:(0->1),(1->0),(1->2),(2->1)
#方式一:常用方式,edge_index中边的存储方式有两个list,第一个list是边的起始点,第二个list是边的目标节点。
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
#节点的特征
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
# 方式二:需要先转置然后使用contiguous()方法。
edge_index = torch.tensor([[0, 1],[1, 0], [1, 2], [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())
PyTorch中的contiguous
contiguous是形容词,表示连续的,PyTorch提供了is_contiguous、contiguous(形容词动用)两个方法,分别用于判断Tensor是否是contiguous的,以及保证Tensor是contiguous的。
is_contiguous直观的解释是Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致。
Tensor多维数组底层实现是使用一块连续内存的1维数组(行优先顺序存储),Tensor在元信息里保存了多为数组的形状,在访问元素时,通过多维度索引转化成1维数组相对于数组起始位置的偏移量即可找到对应的数据。某些Tensor操作(如transpose、permute、narrow、expand)与原Tensor是共享内存中的数据,不会改变底层数组的存储,但原来在语义上相邻、内存里也相邻的元素在执行这样的操作后,在语义上相邻,但在内存不相邻,即不连续了。
如果像要变得连续,使用contiguous方法,如果Tensor不是连续的,则会重新开辟一块内存空间保证数据是在内存中是连续的,如果Tensor是连续的,则contiguous无操作。
行优先
C/C++中使用的是行优先(raw major),Matlab、Fortran使用的是列优先(column major),PyTorch中Tensor底层实现是C,也是使用行优先顺序。
t = torch.arange(12).reshape(3, 4)
数组t在内存中实际以一维数组形式存储,通过flatten方法查看t的一维展开形式,实际存储形式与一维展开一致。
t.flatten()
列优先的存储逻辑结构
使用列优先存储时,一维数组中元素顺序:
图1、图2、图3、图4中颜色相同的数据表示在同一行,不论是行优先顺序、或是列优先顺序,如果要访问矩阵中的下一个元素都是通过偏移来实现,这个偏移量称为步长(stride)。在行优先的存储方式下,访问行中相邻元素物理结构需要偏移1个位置,在列优先存储方式下偏移3个位置。
例如:有向图有4个节点,每个节点有两个特征,有自己的类别标签。

import torch
from torch_geometric.data import Data
x = torch.tensor([[2, 1], [5, 6], [3, 7], [12, 0]], dtype=troch.float)
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)
#与节点对应顺序无关,顺序怎么写都性
edge_index = torch.tensor([[0, 1, 2, 0, 3], [1, 0, 1, 3, 2]], dtype=torch.long)
data = Data(x=x, y=y, edge_index=edge_index)
Dataset与DataLoader
有了Data就可以创建自己的Dataset,读取并返回Data了。
自定义Dataset
尽管PyG包含了许多有用的数据集,也可以通过继承torch_geometric.data.Dataset使用自己的数据集。提供2种不同的Dataset
:
- InMemoryDataset:使用这个Dataset会一次性把数据全部加载到内存中。
- Dataset:使用这个Dataset每次加载一个数据到内存中,比较常用。
需要在自定义的Dataset的初始化方法中传入数据存放的路径root,然后PyG会在这个路径下再划分2个文件夹:
- raw_dir:存放原始数据的路径,一般是csv、mat等格式。
- processed_dir:存放处理后的数据,一般是pt格式,由重写process()方法实现。
除了root,类初始化的init函数还接收三个函数参数transform, pre_transform 和pre_filter,这些参数的默认值都是None。transform函数用于动态的转换数据对象。pre_transform函数在数据保存到硬盘之前进行一次转换。pre_filter用于过滤某些数据对象。
保存在内存中的数据集
为了创建InMemoryDataset,需要实现下面四个方法:
- raw_file_names():该函数返回文件名需要在raw_dir文件夹下找到才可以跳过下载过程。
- processed_file_names():该函数返回的文件名需要在processed_dir中找到才可以跳过处理过程。
- download():下载文件到raw_dir。
- process():处理原始数据并保存在processed_dir。
在process():函数中,需要读入并创建一个Data对象列表之后将所有Data类型的对象保存在processed_dir文件夹中。由于无法将全部数据保存到内存中,需要在数据固化之前通过collate()函数保存Data对象的索引,此外,该函数还会返回一个slices字典用于从本地重建单个样例对象。于是在数据集对象new的时候,需要从本地读取self.data和self.slices对象。
创建更大规模的数据集
有一些数据的规模太大,无法一次性加载到内存中,需要自己实现torch_geometric.data.Dataset,只需要额外实现两个方法:
- len():返回数据集的长度
- get():自定义加载Graph的方法
在PyTorch中,是没有raw和processed这两个文件夹的,这两个文件夹在PyG中的实际意义和处理逻辑。
torch_geometric.data.Dataset继承自torch.utils.data.Dataset,在初始化方法__init__()中,会调用_download()方法和_process()方法。
_download()方法如下,首先检查self.raw_paths列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.download()方法下载文件。
_process()方法如下,首先在self.processed_dir中有pre_transform,那么判断这个pre_transform和传进来的pre_transform是否一致,如果不一致,那么警告提示用户先删除self.processed_dir文件夹。pre_filter同理。
然后检查self.processed_paths列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.process()生成文件。
一般来说不用实现downloand()方法。
如果你直接把处理好的 pt 文件放在了self.processed_dir中,那么也不用实现process()方法。
在 Pytorch 的dataset中,需要实现__getitem__()方法,根据index返回样本和标签。在这里torch_geometric.data.Dataset中,重写了__getitem__()方法,其中调用了get()方法获取数据。
需要实现的是get()方法,根据index返回torch_geometric.data.Data类型的数据。
process()方法存在的意义是原始的格式可能是 csv 或者 mat,在process()函数里可以转化为 pt 格式的文件,这样在get()方法中就可以直接使用torch.load()函数读取 pt 格式的文件,返回的是torch_geometric.data.Data类型的数据,而不用在get()方法做数据转换操作 (把其他格式的数据转换为 torch_geometric.data.Data类型的数据)。当然也可以提前把数据转换为 torch_geometric.data.Data类型,使用 pt 格式保存在self.processed_dir中。
#torch_geometric/data/dataset.py
from typing import List, Optional, Callable, Union, Any, Tuple
import sys
import re
import copy
import warnings
import numpy as np
import os.path as osp
from collections.abc import Sequence
import torch.utils.data
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.data.makedirs import makedirs
IndexType = Union[slice, Tensor, np.ndarray, Sequence]
class Dataset(torch.utils.data.Dataset):
r"""Dataset base class for creating graph datasets.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
create_dataset.html>`__ for the accompanying tutorial.
Args:
root (string, optional): Root directory where the dataset should be
saved. (optional: :obj:`None`)
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean
value, indicating whether the data object should be included in the
final dataset. (default: :obj:`None`)
"""
@property
def raw_file_names(self) -> Union[str, List[str], Tuple]:
r"""The name of the files in the :obj:`self.raw_dir` folder that must
be present in order to skip downloading."""
raise NotImplementedError
@property
def processed_file_names(self) -> Union[str, List[str], Tuple]:
r"""The name of the files in the :obj:`self.processed_dir` folder that
must be present in order to skip processing."""
raise NotImplementedError
def download(self):
r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
raise NotImplementedError
def process(self):
r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
raise NotImplementedError
def len(self) -> int:
r"""Returns the number of graphs stored in the dataset."""
raise NotImplementedError
def get(self, idx: int) -> Data:
r"""Gets the data object at index :obj:`idx`."""
raise NotImplementedError
def __init__(self, root: Optional[str] = None,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None):
super().__init__()
if isinstance(root, str):
root = osp.expanduser(osp.normpath(root))
self.root = root
self.transform = transform
self.pre_transform = pre_transform
self.pre_filter = pre_filter
self._indices: Optional[Sequence] = None
if 'download' in self.__class__.__dict__:
self._download()
if 'process' in self.__class__.__dict__:
self._process()
def indices(self) -> Sequence:
return range(self.len()) if self._indices is None else self._indices
@property
def raw_dir(self) -> str:
return osp.join(self.root, 'raw')
@property
def processed_dir(self) -> str:
return osp.join(self.root, 'processed')
@property
def num_node_features(self) -> int:
r"""Returns the number of features per node in the dataset."""
data = self[0]
data = data[0] if isinstance(data, tuple) else data
if hasattr(data, 'num_node_features'):
return data.num_node_features
raise AttributeError(f"'{
data.__class__.__name__}' object has no "
f"attribute 'num_node_features'")
@property
def num_features(self) -> int:
r"""Returns the number of features per node in the dataset.
Alias for :py:attr:`~num_node_features`."""
return self.num_node_features
@property
def num_edge_features(self) -> int:
r"""Returns the number of features per edge in the dataset."""
data = self[0]
data = data[0] if isinstance(data, tuple) else data
if hasattr(data, 'num_edge_features'):
return data.num_edge_features
raise AttributeError(f"'{
data.__class__.__name__}' object has no "
f"attribute 'num_edge_features'")
@property
def raw_paths(self) -> List[str]:
r"""The absolute filepaths that must be present in order to skip
downloading."""
files = to_list(self.raw_file_names)
return [osp.join(self.raw_dir, f) for f in files]
@property
def processed_paths(self) -> List[str]:
r"""The absolute filepaths that must be present in order to skip
processing."""
files = to_list(self.processed_file_names)
return [osp.join(self.processed_dir, f) for f in files]
def _download(self):
if files_exist(self.raw_paths): # pragma: no cover
return
makedirs(self.raw_dir)
self.download()
def _process(se

PyTorch Geometric(PyG)是一个强大的库,用于处理图数据和图神经网络(GNN)。本文详细介绍了如何使用PyG构建数据集、Data对象、DataLoader,以及GNN模型,如TransformerEncoder,探讨了注意力机制、Seq2Seq模型和自注意力模型的工作原理。此外,还展示了在图生成任务中应用LayoutGAN++的模型结构。




最低0.47元/天 解锁文章
2266

被折叠的 条评论
为什么被折叠?



