scatter_方法源码解析
源码位置
torch_geometric\utils\scatter.py
源码及注释
代码相当于在原有torch_scatter包的基础上进行二次开发,1.4.3之后版本移除了,相当于只要理解torch.scatter(…, reduce=‘add’) 就好了。
下面是开发者说的原话:
“The scatter_ call got removed in one of the more recent versions, and you can now simply use torch.scatter(…, reduce=‘add’) for the same effect.”
import torch_scatter
def scatter_(name, src, index, dim=0, dim_size=None):
r"""Aggregates all values from the :attr:`src` tensor at the indices
specified in the :attr:`index` tensor along the first dimension.
If multiple indices reference the same location, their contributions
are aggregated according to :attr:`name` (either :obj:`"add"`,
:obj:`"mean"` or :obj:`"max"`).
Args:
name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
:obj:`"min"`, :obj:`"max"`).
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index. (default: :obj:`0`)
dim_size (int, optional): Automatically create output tensor with size
:attr:`dim_size` in the first dimension. If set to :attr:`None`, a
minimal sized output tensor is returned. (default: :obj:`None`)
:rtype: :class:`Tensor`
"""
# 例行断言
assert name in ['add', 'mean', 'min', 'max']
# 返回对象属性值
op = getattr(torch_scatter, 'scatter_{}'.format(name))
# 获取对象属性后返回值可直接使用,看具体案例
out = op(src, index, dim, None, dim_size)
out = out[0] if isinstance(out, tuple) else out
# 限制取最值后数据约束
if name == 'max':
out[out < -10000] = 0
elif name == 'min':
out[out > 10000] = 0
return out
torch_scatter方法
结合官网和以下网址大家可以看得懂(看不懂没关系,多看几遍,我也是看了好几遍才懂的,不得不佩服其设计),这里不做过多解释。
官网 Docs > torch.Tensor > torch.Tensor.scatter_
scatter_方法解析
常规运行
在理解上述原有开发包的基础上,我们再来看一下此方法的作用。
在上一篇博客里面用一句话介绍过,按传入方式聚合节点信息。
其是用于创建消息传递层的基类。
为了便于理解,这里给出一个具体例子:
假设我们有一个图,图中有四个节点,节点特征维度为3。
其关系如下图所示(为方便理解我画成了单向图):
以下代码可以在scatter_方法调试中使用:
#src 节点特征
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
#index 索引
b = torch.tensor([2, 3, 0, 1])
#dim 维度信息
c = 1
#dim_size 返回的维度
e = 4
# op = getattr(torch_scatter, 'scatter_{}'.format(name))
# out = op(src, index, dim, None, dim_size)
qq = getattr(torch_scatter, 'scatter_{}'.format("add"))
qq(a, b, c, None, e)
tensor([[2., 2., 2.],
[4., 4., 4.],
[1., 1., 1.],
[3., 3., 3.]])
可以看到他进行了一次消息传递,以节点1为例,他接收了来自节点3的消息,所以特征变为[2., 2., 2.]。
节点特征维度
此样例说明节点信息聚合时,各个维度相互独立。
a = torch.tensor([[1.0, 0.0, 1.0], [2.0, 0.0, 2.0], [3.0, 0.0, 3.0], [4.0, 0.0, 4.0]])
b = torch.tensor([2, 3, 0, 1])
c = 1
e = 4
qq = getattr(torch_scatter, 'scatter_{}'.format("add"))
qq(a, b, c, None, e)
tensor([[2.0, 0.0, 2.0],
[4.0, 0.0, 4.0],
[1.0, 0.0, 1.0],
[3.0, 0.0, 3.0]])
dim_size(运算后返回的维度)参数
只修改dim_size ,可以看到返回维度增加。
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 3, 0, 1])
c = 1
e = 8
qq = getattr(torch_scatter, 'scatter_{}'.format("add"))
qq(a, b, c, None, e)
tensor([[2., 2., 2.],
[4., 4., 4.],
[1., 1., 1.],
[3., 3., 3.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
index(索引)参数
这里相当于修改了节点的指向,修改后图:
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 0, 0, 1])
c = 1
e = 4
qq = getattr(torch_scatter, 'scatter_{}'.format("add"))
qq(a, b, c, None, e)
tensor([[5., 5., 5.],
[4., 4., 4.],
[1., 1., 1.],
[0., 0., 0.]])
因为这里使用的“add”方法,所以同时传递到一个节点的信息被累加。
name(聚合方式)
很显然,当我把聚合方式改为均值,除了节点1有两个值传入被平均,其他保持不变;同理最值一样道理。
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 0, 0, 1])
c = 1
e = 4
qq = getattr(torch_scatter, 'scatter_{}'.format("mean"))
qq(a, b, c, None, e)
tensor([[2.5000, 2.5000, 2.5000],
[4.0000, 4.0000, 4.0000],
[1.0000, 1.0000, 1.0000],
[0.0000, 0.0000, 0.0000]])
PS:torch_scatter(函数库)转torch.scatter_add(内置原生函数)
要转换的代码:
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
scatter_add 是 PyTorch 函数库 torch_scatter 中的一个函数,而 torch.scatter_add 是 PyTorch 内置的原生函数。两者的调用方式略有不同。以下是将 scatter_add 转换为 torch.scatter_add 的重写方法。
假设代码中的参数如下:
edge_weight:包含要累加的数据的张量(源张量)。
row:指定目标索引位置的张量(即 index 参数)。
dim=0:表示累加操作发生在第 0 维。
dim_size=num_nodes:目标张量在第 0 维的大小。
转换代码:
要将 scatter_add 转换为 torch.scatter_add,我们首先需要创建一个目标张量 output,然后在该张量上执行 torch.scatter_add 操作。
import torch
# 假设已知 edge_weight、row 和 num_nodes
edge_weight = torch.tensor([0.5, 1.0, 1.5, 2.0]) # 举例的源张量
row = torch.tensor([0, 1, 0, 2]) # 索引张量
num_nodes = 3 # 输出张量的第 0 维大小
# 创建目标张量 output,初始化为零
output = torch.zeros(num_nodes, dtype=edge_weight.dtype)
# 使用 torch.scatter_add
deg = torch.scatter_add(output,dim=0,index=row,src=edge_weight))
output 是大小为 (num_nodes,) 的张量,初始化为零,以便累加 edge_weight 中的值。
scatter_add(0, row, edge_weight) 表示在 output 张量的第 0 维按 row 中的索引位置累加
edge_weight 中的值。
感谢及参考博文
部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
官方说明 Docs > torch.Tensor > torch.Tensor.scatter
https://pytorch.org/docs/master/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_
官网说明 Docs » Scatter
https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
参考博文1 PyTorch笔记之scatter()函数的使用
https://www.zhangshengrong.com/p/rG1V7g5ka3/
参考博文2 pytorch中的 scatter_()函数使用和详解
https://blog.youkuaiyun.com/t20134297/article/details/105755817
参考GPT4o