Official documentation
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
Parameters:
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
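The official docs state the indexing rule for a 3-D tensor. For the 2-D case used in the examples below, it reduces to this: each output element is read from input at the position given by index[i][j] along dim, with the other coordinate kept as is.

```latex
\text{out}[i][j] =
\begin{cases}
\text{input}\big[\,\text{index}[i][j]\,\big][j], & \text{dim} = 0 \\
\text{input}[i]\big[\,\text{index}[i][j]\,\big], & \text{dim} = 1
\end{cases}
```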
Example
>>> t = torch.Tensor([[1,2,3],[4,5,6]])
>>> print(t)
tensor([[1., 2., 3.],
        [4., 5., 6.]])
>>> index_a = torch.LongTensor([[0,0],[0,1]])
>>> print(index_a)
tensor([[0, 0],
        [0, 1]])
>>> print(torch.gather(t,dim=1,index=index_a))
tensor([[1., 1.],
        [4., 5.]])
>>> index_b = torch.LongTensor([[0,1,1],[1,0,0]])
>>> print(index_b)
tensor([[0, 1, 1],
        [1, 0, 0]])
>>> print(torch.gather(t,dim=0,index=index_b))
tensor([[1., 5., 6.],
        [4., 2., 3.]])
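The output has the same shape as index, not as input: every output element out[i][j] is produced from exactly one index entry index[i][j].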
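To make the rule concrete without needing PyTorch installed, here is a minimal pure-Python sketch for 2-D inputs stored as nested lists. `gather_2d` is a hypothetical helper written for illustration, not a PyTorch API. It reproduces the two REPL results above and shows that an index with a different shape from t yields an output shaped like the index.

```python
# Minimal pure-Python sketch of torch.gather's rule for 2-D inputs.
# gather_2d is a hypothetical helper for illustration, not part of PyTorch.

def gather_2d(inp, dim, index):
    """Return out with out[i][j] read from inp along `dim` at index[i][j]."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    # Valid range for index values is the size of inp along `dim`.
    bound = len(inp) if dim == 0 else len(inp[0])
    out = []
    for i, row in enumerate(index):
        out_row = []
        for j, k in enumerate(row):
            if not 0 <= k < bound:
                raise IndexError(f"index {k} is out of bounds for dim {dim} with size {bound}")
            if dim == 0:
                out_row.append(inp[k][j])  # row picked by index, column j kept
            else:
                out_row.append(inp[i][k])  # row i kept, column picked by index
        out.append(out_row)
    return out  # one output element per index entry, so same shape as index


t = [[1, 2, 3], [4, 5, 6]]
index_a = [[0, 0], [0, 1]]
index_b = [[0, 1, 1], [1, 0, 0]]

print(gather_2d(t, 1, index_a))   # [[1, 1], [4, 5]]
print(gather_2d(t, 0, index_b))   # [[1, 5, 6], [4, 2, 3]]
print(gather_2d(t, 0, [[1, 0]]))  # [[4, 2]]  (1x2 index -> 1x2 output)
```

The last call uses a 1x2 index on the 2x3 input and returns a 1x2 result, matching the shape rule above. Unlike PyTorch, this sketch does not check that index.size(d) <= input.size(d) for every d != dim; it only validates the values along dim.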