torch.scatter中的scatter_sum能做什么事

torch.scatter中的scatter_sum能做什么事

TL;DL: 可以计算一个一个 tensor 中不同元素的数量,存成一个一维 tensor,新 tensor 中的每个元素就是 sum 出来的不同元素的数量。

最早是在 MEAN (CONDITIONAL ANTIBODY DESIGN AS 3D EQUIVARIANT GRAPH TRANSLATIO) 这篇工作的code中看到的:
对于一个batch中的sample没有划分成 [bsz, seqlen, dim] 形式,而是沿着seq维度拼起来变成 [n_all_token, dim] 形式,这里 n_all_token 就是每个sample的长度之和。
需要另外一个一维tensor记录着batch中不同sample起始位置的id,MEAN这里是用cumsum先变出一个长度为 n_all_token 的一维向量,每个位置是该token所属的sample id,这个一维向量称为 batch_id
4421 即 n_all_token,0-15 因为 batchsize=16
在这里插入图片描述
这里来用到了 scatter_sum:
lengths = scatter_sum(torch.ones_like(batch_id), batch_id)
这样 lengths 即:
在这里插入图片描述
lengths 中的每个元素就是batch中每个sample的长度

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值