Task07 图预测任务

本文介绍了在处理大规模图数据集时,如何创建一个按需加载的Dataset类,以及如何利用PyTorchGeometric进行图的批处理。通过合并小图形成大图,实现了在不增加额外计算或内存开销的情况下进行有效的图神经网络运算。同时,展示了如何自定义Data类以实现属性的增值与拼接,并给出了二部图和一般图的实例。最后,提到了argparse库在命令行参数解析中的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Task07 图预测任务

一、超大规模数据集类的创建

  • 数据集规模超级大,我们很难有足够大的内存完全存下所有数据
  • 需要一个按需加载样本到内存的数据集类
1.1 Dataset基类
  • 通过继承torch_geometric.data.Dataset基类来自定义一个按需加载样本到内存的数据集类。

  • 还需要实现:len() 、get()方法。

  • 无需下载数据集原文件的情况,我们不重写(override)download方法即可跳过下载。对于无需对数据集做预处理的情况,我们不重写process方法即可跳过预处理。

二、图样本封装成批(BATCHING)与DataLoader

2.1 合并小图组成大图
  • 大图的邻接矩阵、属性矩阵、预测目标矩阵分别为:

KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲\mathbf{A} = \b…

此方法有以下关键的优势**:

  • 依靠消息传递方案的GNN运算不需要被修改,因为消息仍然不能在属于不同图的两个节点之间交换。

  • 没有额外的计算或内存的开销。例如,这个批处理程序的工作完全不需要对节点或边缘特征进行任何填充。请注意,邻接矩阵没有额外的内存开销,因为它们是以稀疏的方式保存的,只保留非零项,即边。

2.2 小图的属性增值与拼接
def __inc__(self, key, value):
    if 'index' in key or 'face' in key:
        return self.num_nodes
    else:
        return 0

def __cat_dim__(self, key, value):
    if 'index' in key or 'face' in key:
        return 1
    else:
        return 0
2.2.1 小图的属性增值与拼接
class PairData(Data):
    def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t

    def __inc__(self, key, value):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        else:
            return super().__inc__(key, value)
2.2.2 二部图
class BipartiteData(Data):
    def __init__(self, edge_index, x_s, x_t):
        super(BipartiteData, self).__init__()
        self.edge_index = edge_index
        self.x_s = x_s
        self.x_t = x_t
        
def __inc__(self, key, value):
    if key == 'edge_index':
        return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
    else:
        return super().__inc__(key, value)
    
edge_index = torch.tensor([
    [0, 0, 1, 1],
    [0, 1, 1, 2],
])
x_s = torch.randn(2, 16)  # 2 nodes.
x_t = torch.randn(3, 16)  # 3 nodes.

data = BipartiteData(edge_index, x_s, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
# Batch(edge_index=[2, 8], x_s=[4, 16], x_t=[6, 16])

print(batch.edge_index)
# tensor([[0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 3, 4, 4, 5]])
2.2.3 在新的维度上做拼接
 class MyData(Data):
     def __cat_dim__(self, key, item):
         if key == 'foo':
             return None
         else:
             return super().__cat_dim__(key, item)

edge_index = torch.tensor([
   [0, 1, 1, 2],
   [1, 0, 2, 1],
])
foo = torch.randn(16)

data = MyData(edge_index=edge_index, foo=foo)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
# Batch(edge_index=[2, 8], foo=[2, 16])

PS: argparse库的使用

  • argparse 是 Python 内置的一个用于命令项选项与参数解析的模块。

    主要有三个步骤:

    • 创建 ArgumentParser() 对象
    • 调用 add_argument() 方法添加参数
    • 使用 parse_args() 解析添加的参数
  • add_argument() 方法定义如何解析命令行参数:

ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])

每个参数解释如下:

  • name or flags - 选项字符串的名字或者列表,例如 foo 或者 -f, –foo。
  • action - 命令行遇到参数时的动作,默认值是 store。
    • store_const,表示赋值为const;
    • append,将遇到的值存储成列表,也就是如果参数重复则会保存多个值;
    • append_const,将参数规范中定义的一个值保存到一个列表;
    • count,存储遇到的次数;此外,也可以继承 argparse.Action 自定义参数解析;
  • nargs - 应该读取的命令行参数个数,可以是具体的数字,或者是?号,当不指定值时对于 Positional argument 使用 default,对于 Optional argument 使用 const;或者是 * 号,表示 0 或多个参数;或者是 + 号表示 1 或多个参数。
  • const - 一个在 action 和 nargs 选项所需的常量值。
  • default - 不指定参数时的默认值。
  • type - 命令行参数应该被转换成的类型。
  • choices - 参数可允许的值的一个容器。
  • required - 可选参数是否可以省略 (仅针对可选参数)。
  • help - 参数的帮助信息,当指定为 argparse.SUPPRESS 时表示不显示该参数的帮助信息.
  • metavar - 在 usage 说明中的参数名称,对于必选参数默认就是参数名称(上面的 name or flags),对于可选参数默认是全大写的参数名称.
  • dest - parse_args() 方法返回的对象所添加的属性的名称。默认情况下,对于可选参数选取最长的名称,中划线转换为下划线.

参考资料

DataWhale开源学习资料:https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值