组队学习-图预测 Taks 07

本文介绍了PyTorch Geometric库中 Dataset 基类的使用,包括自定义数据类型,以及如何处理图样本和二部图。通过PairData和BipartiteData类展示了图的匹配和二部图的邻接矩阵处理。此外,还提供了图预测任务的代码实践,包括参数设置和数据加载。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、知识点回顾

   (1)Dataset基类简介

    在PyG中,输入数据可以自定义数据类型,需要继承基类torch_geometric.data.Dataset,定义len()函数和get()函数。如果数据不需要下载,可以不定义函数download()和process().

    (2)图样本封装成批和与DataLoader类

     图分类和图预测的图像类型多种多样,想要将这些数据进行输入,就需要进行图样本的封装。在PyTorch Geometric中,程序将小图放在大图的对角线上构成大图。PyTorch Geometric使用torch_geometric.data.__inc__()和torch_geometric.data.__cat_dim__()进行相应的处理。

     图的匹配分为源图和目标图,分别使用x_s.size(0)和x_t.size(0)进行节点的增值。 

    (3)二部图。二部图的邻接矩阵定义两种类型的节点之间的连接关系。

 

class BipartiteData(Data):
    def __init__(self, edge_index, x_s, x_t):
        super(BipartiteData, self).__init__()
        self.edge_index = edge_index
        self.x_s = x_s
        self.x_t = x_t


def __inc__(self, key, value):
    if key == 'edge_index':
        return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
    else:
        return super().__inc__(key, value)


edge_index = torch.tensor([
    [0, 0, 1, 1],
    [0, 1, 1, 2],
])
x_s = torch.randn(2, 16)  # 2 nodes.
x_t = torch.randn(3, 16)  # 3 nodes.

data = BipartiteData(edge_index, x_s, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)

print(batch.edge_index)

 

Batch(batch=[6], edge_index=[2, 8], ptr=[3], x_s=[4, 16], x_t=[6, 16])
tensor([[0, 0, 1, 1, 3, 3, 4, 4],
        [0, 1, 1, 2, 3, 4, 4, 5]])

 2、代码实践

    (1)图的匹配

import torch
from torch_geometric.data import Dataset, DataLoader

class PairData(Data):
    def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t

class PairData(Data):
    def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t

    def __inc__(self, key, value):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        else:
            return super().__inc__(key, value)

edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
print(batch.edge_index_s)
print(batch.edge_index_t)
Batch(edge_index_s=[2, 8], edge_index_t=[2, 6], x_s=[10, 16], x_t=[8, 16])
tensor([[0, 0, 0, 0, 5, 5, 5, 5],
        [1, 2, 3, 4, 6, 7, 8, 9]])
tensor([[0, 0, 0, 4, 4, 4],
        [1, 2, 3, 5, 6, 7]])

     (2)二部图

3、作业:

    在此图预测任务实践中: 我们将前面所学的基于GIN的图表示学习神经网络和超大规模数据集类的创建方法付诸于实际应用;我们构建了一种很方便的设置不同参数进行试验的方法,不同试验的过程与结果信息通过简单的操作即可进行比较分析。

python main.py  --task_name GINGraphPooling\    # 为当前试验取名
                --device 0\                     
                --num_layers 5\                 # 使用GINConv层数
                --graph_pooling sum\            # 图读出方法
                --emb_dim 256\                  # 节点嵌入维度
                --drop_ratio 0.\
                --save_test\                    # 是否对测试集做预测并保留预测结果
                --batch_size 512\
                --epochs 100\
                --weight_decay 0.00001\
                --early_stop 10\                # 当有`early_stop`个epoches验证集结果没有提升,则停止训练
                --num_workers 4\
                --dataset_root dataset          # 存放数据集的根目录

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值