- 官方文档 可供参考
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
- Gathers values along an
axis
specified bydim
. - Parameters:
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
- For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
- Example:
t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
对该 example 的解释如下:
tensor([[ ? ← t [0, (0)], ? ← t [0, (1)] ], [ ? ← t [1, (0)], ? ← t [1, (1)]]])
其中,已填写的索引是本身该位置的索引,()占位的是需要 index 参数指定的,也就是:
tensor([[ 1 ← t [0, 0
], 1 ← t [0, 0
] ], [ 4 ← t [1, 1
], 3← t [1, 0
]]])