torch.scatter中的scatter_sum
能做什么事
TL;DL: 可以计算一个一个 tensor 中不同元素的数量,存成一个一维 tensor,新 tensor 中的每个元素就是 sum 出来的不同元素的数量。
最早是在 MEAN (CONDITIONAL ANTIBODY DESIGN AS 3D EQUIVARIANT GRAPH TRANSLATIO) 这篇工作的code中看到的:
对于一个batch中的sample没有划分成 [bsz, seqlen, dim] 形式,而是沿着seq维度拼起来变成 [n_all_token, dim] 形式,这里 n_all_token 就是每个sample的长度之和。
需要另外一个一维tensor记录着batch中不同sample起始位置的id,MEAN这里是用cumsum先变出一个长度为 n_all_token 的一维向量,每个位置是该token所属的sample id,这个一维向量称为 batch_id
4421 即 n_all_token,0-15 因为 batchsize=16
这里来用到了 scatter_sum:
lengths = scatter_sum(torch.ones_like(batch_id), batch_id)
这样 lengths 即:
lengths
中的每个元素就是batch中每个sample的长度