最近看代码时遇到了两个函数,查阅pytorch官方文档后一时半会儿也没弄懂,现在写篇笔记来加深一下印象。
gather
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
沿着给定的维度dim
,将输入input
指定位置的值聚合起来,指定位置由index
决定。
index
和input
必须有相同数量的维度,且满足1 <= index[dim] <= input[dim]
、index[other_dims] == input[other_dims]
对于3维的张量,公式为:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
下面给出一个例子:
import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.],
[ 10., 28., 22., 27., 0.]],
[[ 26., 10., 20., 29., 18.],
[ 5., 24., 26., 21., 3.],
[ 10., 29., 10., 0., 22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,