PyG:PyTorch Geometric Library

PyTorch Geometric 深度学习图神经网络详解
PyTorch Geometric(PyG)是一个强大的库,用于处理图数据和图神经网络(GNN)。本文详细介绍了如何使用PyG构建数据集、Data对象、DataLoader,以及GNN模型,如TransformerEncoder,探讨了注意力机制、Seq2Seq模型和自注意力模型的工作原理。此外,还展示了在图生成任务中应用LayoutGAN++的模型结构。

PyG是一个基于PyTorch用与处理部规则数据(比如图)的库,是一个用于在图等数据上快速实现表征学习的框架,是当前最流行和广泛使用的GNN(Graph Neural Networks, GNN 图神经网络)库。

Graph Neural Networks,GNN,称为图神经网络,是深度学习中近年来比较受关注的领域,GNN通过对信息的传递、转换和聚合实现特征的提取,类似与传统的CNN,只是CNN只能处理规则的输入,如图像等输入的高、宽和通道数都是固定的,而GNN可以处理部规则的输入,如点云等。

安装

pip install torch-geometric
pip install torch-sparse
pip install torch-scatter
pip install pytorch-fid

torch_geometric.data.Data

节点和节点之间的边构成了图,在PyG中,构建图需要两个要素:节点和边。PyG提供了torch_geometric.data.Data(简称Data)用于构建图,包括5个属性,每一个属性都部是必须的,可以为空。

  • x:用于存储每个节点的特征,形状是[num_nodes, num_node_features].
  • edge_index:用于存储节点之间的边,形状是[2, num_edges]。
  • pos:存储节点的坐标,形状是[num_nodes, num_dimensions]。
  • y:存储样本标签。如果是每个节点都有标签,那么形状是[num_nodes, *];如果是整张图只有一个标签,那么形状是[1, *]。
  • edge_attr:存储边的特征。形状是[num_edges, num_edge_features]。

Data对象不仅仅限制于这些属性,还可以通过data.face来扩展Data,以张量保存三维网格中三角形的连接性。

和P有Torch稍有不同,Data里包含了样本的label,在PyTorch中,重写Dataset的__getitem__(),根据index返回对应的样本和label。在PyG中,在get()函数中根据index返回torch_geometric.data.Data类型的数据,在Data里包含了数据和label。

例如:未加权无向图(未加权指边上没有权值),包括3个节点和4条边:(0->1),(1->0),(1->2),(2->1),每个节点都有一维特征。
在这里插入图片描述

import torch 
from torch_geometric.data import Data

#由于是无向图,有四条边:(0->1),(1->0),(1->2),(2->1)
#方式一:常用方式,edge_index中边的存储方式有两个list,第一个list是边的起始点,第二个list是边的目标节点。
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)

#节点的特征
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)

# 方式二:需要先转置然后使用contiguous()方法。
edge_index = torch.tensor([[0, 1],[1, 0], [1, 2], [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())

PyTorch中的contiguous
contiguous是形容词,表示连续的,PyTorch提供了is_contiguous、contiguous(形容词动用)两个方法,分别用于判断Tensor是否是contiguous的,以及保证Tensor是contiguous的。
is_contiguous直观的解释是Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致。
Tensor多维数组底层实现是使用一块连续内存的1维数组(行优先顺序存储),Tensor在元信息里保存了多为数组的形状,在访问元素时,通过多维度索引转化成1维数组相对于数组起始位置的偏移量即可找到对应的数据。某些Tensor操作(如transpose、permute、narrow、expand)与原Tensor是共享内存中的数据,不会改变底层数组的存储,但原来在语义上相邻、内存里也相邻的元素在执行这样的操作后,在语义上相邻,但在内存不相邻,即不连续了。
如果像要变得连续,使用contiguous方法,如果Tensor不是连续的,则会重新开辟一块内存空间保证数据是在内存中是连续的,如果Tensor是连续的,则contiguous无操作。
行优先
C/C++中使用的是行优先(raw major),Matlab、Fortran使用的是列优先(column major),PyTorch中Tensor底层实现是C,也是使用行优先顺序。
t = torch.arange(12).reshape(3, 4)
在这里插入图片描述
数组t在内存中实际以一维数组形式存储,通过flatten方法查看t的一维展开形式,实际存储形式与一维展开一致。
t.flatten()
在这里插入图片描述
列优先的存储逻辑结构
在这里插入图片描述
使用列优先存储时,一维数组中元素顺序:
在这里插入图片描述
图1、图2、图3、图4中颜色相同的数据表示在同一行,不论是行优先顺序、或是列优先顺序,如果要访问矩阵中的下一个元素都是通过偏移来实现,这个偏移量称为步长(stride)。在行优先的存储方式下,访问行中相邻元素物理结构需要偏移1个位置,在列优先存储方式下偏移3个位置。

例如:有向图有4个节点,每个节点有两个特征,有自己的类别标签。
在这里插入图片描述

import torch
from torch_geometric.data import Data

x = torch.tensor([[2, 1], [5, 6], [3, 7], [12, 0]], dtype=troch.float)
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)
#与节点对应顺序无关,顺序怎么写都性
edge_index = torch.tensor([[0, 1, 2, 0, 3], [1, 0, 1, 3, 2]], dtype=torch.long)
data = Data(x=x, y=y, edge_index=edge_index)

Dataset与DataLoader

有了Data就可以创建自己的Dataset,读取并返回Data了。

自定义Dataset

尽管PyG包含了许多有用的数据集,也可以通过继承torch_geometric.data.Dataset使用自己的数据集。提供2种不同的Dataset
:

  • InMemoryDataset:使用这个Dataset会一次性把数据全部加载到内存中。
  • Dataset:使用这个Dataset每次加载一个数据到内存中,比较常用。

需要在自定义的Dataset的初始化方法中传入数据存放的路径root,然后PyG会在这个路径下再划分2个文件夹:

  • raw_dir:存放原始数据的路径,一般是csv、mat等格式。
  • processed_dir:存放处理后的数据,一般是pt格式,由重写process()方法实现。

除了root,类初始化的init函数还接收三个函数参数transform, pre_transform 和pre_filter,这些参数的默认值都是None。transform函数用于动态的转换数据对象。pre_transform函数在数据保存到硬盘之前进行一次转换。pre_filter用于过滤某些数据对象。

保存在内存中的数据集

为了创建InMemoryDataset,需要实现下面四个方法:

  • raw_file_names():该函数返回文件名需要在raw_dir文件夹下找到才可以跳过下载过程。
  • processed_file_names():该函数返回的文件名需要在processed_dir中找到才可以跳过处理过程。
  • download():下载文件到raw_dir。
  • process():处理原始数据并保存在processed_dir。

在process():函数中,需要读入并创建一个Data对象列表之后将所有Data类型的对象保存在processed_dir文件夹中。由于无法将全部数据保存到内存中,需要在数据固化之前通过collate()函数保存Data对象的索引,此外,该函数还会返回一个slices字典用于从本地重建单个样例对象。于是在数据集对象new的时候,需要从本地读取self.data和self.slices对象。

创建更大规模的数据集

有一些数据的规模太大,无法一次性加载到内存中,需要自己实现torch_geometric.data.Dataset,只需要额外实现两个方法:

  • len():返回数据集的长度
  • get():自定义加载Graph的方法

在PyTorch中,是没有raw和processed这两个文件夹的,这两个文件夹在PyG中的实际意义和处理逻辑。

torch_geometric.data.Dataset继承自torch.utils.data.Dataset,在初始化方法__init__()中,会调用_download()方法和_process()方法。

_download()方法如下,首先检查self.raw_paths列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.download()方法下载文件。

_process()方法如下,首先在self.processed_dir中有pre_transform,那么判断这个pre_transform和传进来的pre_transform是否一致,如果不一致,那么警告提示用户先删除self.processed_dir文件夹。pre_filter同理。

然后检查self.processed_paths列表中的文件是否存在;如果存在,则返回;如果不存在,则调用self.process()生成文件。

一般来说不用实现downloand()方法。

如果你直接把处理好的 pt 文件放在了self.processed_dir中,那么也不用实现process()方法。

在 Pytorch 的dataset中,需要实现__getitem__()方法,根据index返回样本和标签。在这里torch_geometric.data.Dataset中,重写了__getitem__()方法,其中调用了get()方法获取数据。

需要实现的是get()方法,根据index返回torch_geometric.data.Data类型的数据。

process()方法存在的意义是原始的格式可能是 csv 或者 mat,在process()函数里可以转化为 pt 格式的文件,这样在get()方法中就可以直接使用torch.load()函数读取 pt 格式的文件,返回的是torch_geometric.data.Data类型的数据,而不用在get()方法做数据转换操作 (把其他格式的数据转换为 torch_geometric.data.Data类型的数据)。当然也可以提前把数据转换为 torch_geometric.data.Data类型,使用 pt 格式保存在self.processed_dir中。

#torch_geometric/data/dataset.py
from typing import List, Optional, Callable, Union, Any, Tuple

import sys
import re
import copy
import warnings
import numpy as np
import os.path as osp
from collections.abc import Sequence

import torch.utils.data
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data.makedirs import makedirs

IndexType = Union[slice, Tensor, np.ndarray, Sequence]


class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_dataset.html>`__ for the accompanying tutorial.

    Args:
        root (string, optional): Root directory where the dataset should be
            saved. (optional: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """
    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files in the :obj:`self.raw_dir` folder that must
        be present in order to skip downloading."""
        raise NotImplementedError

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files in the :obj:`self.processed_dir` folder that
        must be present in order to skip processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def len(self) -> int:
        r"""Returns the number of graphs stored in the dataset."""
        raise NotImplementedError

    def get(self, idx: int) -> Data:
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(self, root: Optional[str] = None,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        super().__init__()

        if isinstance(root, str):
            root = osp.expanduser(osp.normpath(root))

        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        self._indices: Optional[Sequence] = None

        if 'download' in self.__class__.__dict__:
            self._download()

        if 'process' in self.__class__.__dict__:
            self._process()

    def indices(self) -> Sequence:
        return range(self.len()) if self._indices is None else self._indices

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, 'processed')

    @property
    def num_node_features(self) -> int:
        r"""Returns the number of features per node in the dataset."""
        data = self[0]
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_node_features'):
            return data.num_node_features
        raise AttributeError(f"'{
     
     data.__class__.__name__}' object has no "
                             f"attribute 'num_node_features'")

    @property
    def num_features(self) -> int:
        r"""Returns the number of features per node in the dataset.
        Alias for :py:attr:`~num_node_features`."""
        return self.num_node_features

    @property
    def num_edge_features(self) -> int:
        r"""Returns the number of features per edge in the dataset."""
        data = self[0]
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_edge_features'):
            return data.num_edge_features
        raise AttributeError(f"'{
     
     data.__class__.__name__}' object has no "
                             f"attribute 'num_edge_features'")

    @property
    def raw_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        downloading."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _process(se
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值