torch.scatter与torch_scatter库使用整理

本文介绍了PyTorch中的torch.scatter和torch._scatter函数,用于根据索引修改张量信息,以及torch_scatter库的scatter方法,支持按索引分组进行计算。详细解释了scatter函数的工作原理,并通过实例展示了如何进行分组求和等操作。torch_scatter库提供了更高效的分组处理,如segment_coo和segment_csr方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

最近在做图结构相关的算法,scatter能把邻接矩阵里的信息修改,或者把邻居分组算个sum或者reduce,挺方便的,简单整理一下。

torch.scatter 与 tensor._scatter

Pytorch自带的函数,用来将作为src的tensor根据index的描述填充到input中,形式如下:

ouput = torch.scatter(input, dim, index, src)
# 或者是
input.scatter_(dim, index, src)

两个方法的功能是相同的,而带下划线的_scatter方法是将原tensor input直接修改了,不带的则会返回一个新的tensor outputinput不变。

其中dim决定index对应值是沿着哪个维度进行修改。而src为数据来源,当其为tensor张量时,shape要和index相同,这样index中每个元素都能对应src中对应位置的信息。

理解scatter方法主要是要理解index实现的srcinput之间的位置对应关系,举个例子:

dim = 0
index = torch.tensor(
	[[0, 2, 2], 
	[2, 1, 0]]
)

dim为0时,遵循的映射原则为:input[index[i][j]][j] = src[i][j].
也就是说,将位置 (i, j) 中dim对应的位置改为 index[i][j] 的值。如位置(1,0),index[1][0]为2,则映射后的位置为(2,0),意味着input中(2,0)的位置被更改为src中(1,0)位置的值。

我个人形象理解是这些值会沿着dim方向滑动,上面例子中src[1][0]位置的值滑到2,成为input中的新值,这样理解起来更形象一点。

基本理解了上面这个例子,多维情况和不同dim的情况都可以类推了。

需要注意:src和input的dtype需要相同,不然会报Expected self.dtype to be equal to src.dtype,不一样就先转换再使用。

t = torch.arange(6).view(2, 3)
t = t.to(torch.float32)
print(t)

output = torch.scatter(torch.zeros((3
dgl.scatter_add和torch_scatter的作用都是在张量的某个维度上对指定的索引进行聚合操作,但是它们的实现方式略有不同。 dgl.scatter_add是Deep Graph Library(DGL)中提供的函数,用于在图上对节点特征进行聚合。它的具体使用方式如下: ``` import dgl import torch # 构造一个包含3个节点的图 g = dgl.graph(([0, 1, 2], [1, 2, 0])) # 构造一个3x4的特征矩阵 feat = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) # 定义聚合的方式为相加 def reducer(nodes): return {'sum': torch.sum(nodes.mailbox['m'], dim=1)} # 对节点的特征进行聚合,聚合方式为相加 g.update_all(dgl.function.copy_src('feat', 'm'), dgl.function.sum('m', 'feat')) feat = g.ndata['feat'] # 使用scatter_add对特征进行聚合 idx = torch.tensor([0, 1, 2, 0]) val = torch.tensor([1, 2, 3, 4]) feat = torch.zeros(3, 5) feat.scatter_add_(1, idx.unsqueeze(1).repeat(1, 4), val.unsqueeze(1).repeat(1, 4)) ``` 其中,scatter_add_函数的第一个参数表示要聚合的维度,这是第1维,即按行进行聚合。第二个参数是一个形状为(N, M)的张量,表示要聚合的索引,这是idx.unsqueeze(1).repeat(1, 4),即将idx扩展为形状为(N, M),其中M为feat的第1维大小。第三个参数是一个形状为(N, M)的张量,表示要聚合的值,这是val.unsqueeze(1).repeat(1, 4),即将val扩展为形状为(N, M)。 torch_scatter中的scatter函数使用方式也类似,具体使用方式如下: ``` import torch from torch_scatter import scatter idx = torch.tensor([0, 1, 2, 0]) val = torch.tensor([1, 2, 3, 4]) feat = torch.zeros(3, 5) feat = scatter(val.unsqueeze(1).repeat(1, 4), idx.unsqueeze(1).repeat(1, 4), dim=1, out=feat, reduce='add') ``` 其中,scatter函数的第一个参数表示要聚合的值,第二个参数表示要聚合的索引,第三个参数表示要聚合的维度,这是第1维,即按行进行聚合。第四个参数表示输出的张量,这是feat,最后一个参数表示聚合方式,这是相加。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值