CyclesDataset(helpers.dataset_classes文件中的cycles_dataset.py)

任务类型:分类任务
用途:`CyclesDataset` 生成循环图,用于图分类任务。每个图的标签是基于其拓扑结构的类别,模型通过训练来预测每个环形图的类别,因此这是一个典型的图分类任务。

from helpers.dataset_classes.cycles_dataset import CyclesDataset

import torch
from torch_geometric.data import Data
from typing import List
from torch import Tensor


def make_undirected(edge_index: Tensor) -> Tensor:
    edge_index_other_direction = torch.stack((edge_index[1], edge_index[0]), dim=0)
    edge_index = torch.cat((edge_index_other_direction, edge_index), dim=1)
    return edge_index


def create_cycle(max_cycle: int) -> List[Data]:
    data_list = []
    for cycle_size in range(6, max_cycle + 1):
        if cycle_size < (max_cycle + 1 - 6) / 3 + 6:
            train_mask = torch.ones(size=(1,), dtype=torch.bool)
            val_mask = torch.zeros(size=(1,), dtype=torch.bool)
            test_mask = torch.zeros(size=(1,), dtype=torch.bool)
        elif cycle_size < 2 * (max_cycle + 1 - 6) / 3 + 6:
            train_mask = torch.zeros(size=(1,), dtype=torch.bool)
            val_mask = torch.ones(size=(1,), dtype=torch.bool)
            test_mask = torch.zeros(size=(1,), dtype=torch.bool)
        else:
            train_mask = torch.zeros(size=(1,), dtype=torch.bool)
            val_mask = torch.zeros(size=(1,), dtype=torch.bool)
            test_mask = torch.ones(size=(1,), dtype=torch.bool)

        x = torch.ones(size=(cycle_size, 1))
        edge_index1 = torch.tensor([list(range(cycle_size)),
                                    list(range(1, cycle_size)) + [0]])
        edge_index1 = make_undirected(edge_index=edge_index1)
        edge_index2 = torch.tensor([[0, 1, 2] + list(range(3, cycle_size)),
                                    [1, 2, 0] + list(range(4, cycle_size)) + [3]])
        edge_index2 = make_undirected(edge_index=edge_index2)

        data_list.append(Data(x=x, edge_index=edge_index1, y=torch.tensor([0], dtype=torch.long),
                              train_mask=train_mask, val_mask=val_mask, test_mask=test_mask))
        data_list.append(Data(x=x, edge_index=edge_index2, y=torch.tensor([1], dtype=torch.long),
                              train_mask=train_mask, val_mask=val_mask, test_mask=test_mask))

    return data_list


class CyclesDataset(object):

    def __init__(self):
        super().__init__()
        self.data = create_cycle(max_cycle=13)


if __name__ == '__main__':
    data = CyclesDataset()

1. 导入模块<

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值