Task07 图预测任务
一、超大规模数据集类的创建
- 数据集规模超级大,我们很难有足够大的内存完全存下所有数据。
- 需要一个按需加载样本到内存的数据集类。
1.1 Dataset基类
-
通过继承
torch_geometric.data.Dataset
基类来自定义一个按需加载样本到内存的数据集类。 -
还需要实现:len() 、get()方法。
-
无需下载数据集原文件的情况,我们不重写(override)
download
方法即可跳过下载。对于无需对数据集做预处理的情况,我们不重写process
方法即可跳过预处理。
二、图样本封装成批(BATCHING)与DataLoader
类
2.1 合并小图组成大图
- 大图的邻接矩阵、属性矩阵、预测目标矩阵分别为:
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲\mathbf{A} = \b…
此方法有以下关键的优势**:
-
依靠消息传递方案的GNN运算不需要被修改,因为消息仍然不能在属于不同图的两个节点之间交换。
-
没有额外的计算或内存的开销。例如,这个批处理程序的工作完全不需要对节点或边缘特征进行任何填充。请注意,邻接矩阵没有额外的内存开销,因为它们是以稀疏的方式保存的,只保留非零项,即边。
2.2 小图的属性增值与拼接
- PyTorch Geometric允许我们通过覆盖
torch_geometric.data.__inc__()
和torch_geometric.data.__cat_dim__()
函数来实现我们希望的行为。
def __inc__(self, key, value): if 'index' in key or 'face' in key: return self.num_nodes else: return 0 def __cat_dim__(self, key, value): if 'index' in key or 'face' in key: return 1 else: return 0
2.2.1 小图的属性增值与拼接
class PairData(Data):
def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
super(PairData, self).__init__()
self.edge_index_s = edge_index_s
self.x_s = x_s
self.edge_index_t = edge_index_t
self.x_t = x_t
def __inc__(self, key, value):
if key == 'edge_index_s':
return self.x_s.size(0)
if key == 'edge_index_t':
return self.x_t.size(0)
else:
return super().__inc__(key, value)
2.2.2 二部图
class BipartiteData(Data):
def __init__(self, edge_index, x_s, x_t):
super(BipartiteData, self).__init__()
self.edge_index = edge_index
self.x_s = x_s
self.x_t = x_t
def __inc__(self, key, value):
if key == 'edge_index':
return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
else:
return super().__inc__(key, value)
edge_index = torch.tensor([
[0, 0, 1, 1],
[0, 1, 1, 2],
])
x_s = torch.randn(2, 16) # 2 nodes.
x_t = torch.randn(3, 16) # 3 nodes.
data = BipartiteData(edge_index, x_s, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))
print(batch)
# Batch(edge_index=[2, 8], x_s=[4, 16], x_t=[6, 16])
print(batch.edge_index)
# tensor([[0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 3, 4, 4, 5]])
2.2.3 在新的维度上做拼接
class MyData(Data):
def __cat_dim__(self, key, item):
if key == 'foo':
return None
else:
return super().__cat_dim__(key, item)
edge_index = torch.tensor([
[0, 1, 1, 2],
[1, 0, 2, 1],
])
foo = torch.randn(16)
data = MyData(edge_index=edge_index, foo=foo)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))
print(batch)
# Batch(edge_index=[2, 8], foo=[2, 16])
PS: argparse库的使用
-
argparse 是 Python 内置的一个用于命令项选项与参数解析的模块。
主要有三个步骤:
- 创建 ArgumentParser() 对象
- 调用 add_argument() 方法添加参数
- 使用 parse_args() 解析添加的参数
-
add_argument() 方法定义如何解析命令行参数:
ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])
每个参数解释如下:
- name or flags - 选项字符串的名字或者列表,例如 foo 或者 -f, –foo。
- action - 命令行遇到参数时的动作,默认值是 store。
- store_const,表示赋值为const;
- append,将遇到的值存储成列表,也就是如果参数重复则会保存多个值;
- append_const,将参数规范中定义的一个值保存到一个列表;
- count,存储遇到的次数;此外,也可以继承 argparse.Action 自定义参数解析;
- nargs - 应该读取的命令行参数个数,可以是具体的数字,或者是?号,当不指定值时对于 Positional argument 使用 default,对于 Optional argument 使用 const;或者是 * 号,表示 0 或多个参数;或者是 + 号表示 1 或多个参数。
- const - 一个在 action 和 nargs 选项所需的常量值。
- default - 不指定参数时的默认值。
- type - 命令行参数应该被转换成的类型。
- choices - 参数可允许的值的一个容器。
- required - 可选参数是否可以省略 (仅针对可选参数)。
- help - 参数的帮助信息,当指定为 argparse.SUPPRESS 时表示不显示该参数的帮助信息.
- metavar - 在 usage 说明中的参数名称,对于必选参数默认就是参数名称(上面的 name or flags),对于可选参数默认是全大写的参数名称.
- dest - parse_args() 方法返回的对象所添加的属性的名称。默认情况下,对于可选参数选取最长的名称,中划线转换为下划线.
参考资料
DataWhale开源学习资料:https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN