超大规模数据集类的创建
在一些应用场景中,数据集规模超级大,我们很难有足够大的内存完全存下所有数据。需要一个按需加载样本到内存的数据集类。
Dataset基类简介
在PyG中,我们通过继承torch_geometric.data.Dataset基类来自定义一个按需加载样本到内存的数据集类。
继承torch_geometric.data.InMemoryDataset基类要实现的方法,继承此基类同样要实现,此外还需要实现以下方法:
len():返回数据集中的样本的数量。get():实现加载单个图的操作。注意:在内部,__getitem__()返回通过调用get()来获取Data对象,并根据transform参数对它们进行选择性转换。
下面让我们通过一个简化的例子看继承torch_geometric.data.Dataset基类的规范:
import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
# Download to `self.raw_dir`.
path = download_url(url, self.raw_dir)
...
def process(self):
i = 0
for raw_path in self.raw_paths:
# Read data from `raw_path`.
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
/data/anaconda3/lib/python3.7/site-packages/numba/decorators.py:146: RuntimeWarning: Caching is not available when the 'parallel' target is in use. Caching is now being disabled to allow execution to continue.
warnings.warn(msg, RuntimeWarning)
其中,每个Data对象在process()方法中单独被保存,并在get()中通过指定索引进行加载。
跳过download/process
对于无需下载数据集原文件的情况,我们不重写(override)download方法即可跳过下载。对于无需对数据集做预处理的情况,我们不重写process方法即可跳过预处理。
无需定义Dataset类
通过下面的方式,我们可以不用定义一个Dataset类,而直接生成一个Dataloader对象,直接用于训练:
from torch_geometric.data import Data, DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
我们也可以通过下面的方式将一个列表的Data对象组成一个batch:
from torch_geometric.data import Data, Batch
data_list = [Data(...), ..., Data(...)]
loader = Batch.from_data_list(data_list, batch_size=32)
图样本封装成批(BATCHING)与DataLoader类
合并小图组成大图
图可以有任意数量的节点和边,它不是规整的数据结构,因此对图数据封装成批的操作与对图像与序列等数据封装成批的操作不同。PyTorch Geometric中采用的将多个图封装成批的方式是,将小图作为连通组件(connected component)的形式合并,构建一个大图。于是小图的邻接矩阵存储在大图邻接矩阵的对角线上。大图的邻接矩阵、属性矩阵、预测目标矩阵分别为:
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲\mathbf{A} = \b…
此方法有以下关键的优势:
-
依靠消息传递方案的GNN运算不需要被修改,因为消息仍然不能在属于不同图的两个节点之间交换。
-
没有额外的计算或内存的开销。例如,这个批处理程序的工作完全不需要对节点或边缘特征进行任何填充。请注意,邻接矩阵没有额外的内存开销,因为它们是以稀疏的方式保存的,只保留非零项,即边。
通过torch_geometric.data.DataLoader类,多个小图被封装成一个大图。torch_geometric.data.DataLoader是PyTorch的DataLoader的子类,它覆盖了collate()函数,改函数定义了一列表的样本是如何封装成批的。因此,所有可以传递给PyTorch DataLoader的参数也可以传递给PyTorch Geometric的 DataLoader,例如,num_workers。
小图的属性增值与拼接
将小

本文介绍了如何使用PyTorch Geometric创建处理大规模数据集的定制Dataset类,包括按需加载、数据预处理、图的批量处理和特殊场景下的图匹配与二部图封装。通过实例PCQM4M-LSC数据集,展示了如何实现高效的数据加载和处理策略。
最低0.47元/天 解锁文章
1651





