记录一下torch.gather函数
用法:torch.gather(input: Tensor, dim: int, index: LongTensor, *, sparse_grad=False, out=None) -> Tensor
功能:指定张量index,根据其元素的值来获取输入矩阵input上的值。
注意
index需要与input有相同的维度,并且对d!=dim时要求index.size(d)<=input.size(d)
意思就是说如果input的size为(2, 3, 4),如果dim指定为1,那么需要index.size(0)<=2以及index.size(2)<=4.- 函数输出的Tensor与index的shape相同
举个例子:
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1, 1],
[ 4, 3]])
怎么得到这个结果的呢,可以这样记忆:index现在是[[0, 0], [1, 0]],它的每个元素在index中都有其索引,比如元素1索引是[1, 0](index[1, 0]=1),由于现在指定的dim=1,那么就用1代替[1, 0]中dim=1处的0,变成[1, 1],即获取到input[1, 1] ,如下图所示。

对于多维矩阵也是一样的流程,用index的每个元素的值代替该元素在index上的索引在dim维度上的值,便能得到在input上的索引。
也就是官方举的例子:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
再有一个例子
>>> input_ = [[2, 3, 4, 5, 0, 0],
[1, 4, 3, 0, 0, 0],
[4, 2, 2, 5, 7, 0],
[1, 0, 0, 0, 0, 0]]
>>> input_ = torch.tensor(input_)
>>> index = torch.LongTensor([[3],[2],[4],[0]])
>>> torch.gather(input_, 1, index)
tensor([[5],
[3],
[7],
[1]])
本文详细介绍了PyTorch中的gather函数使用方法及注意事项。通过具体示例解释了如何根据索引从多维张量中选取元素,适用于深度学习模型中的特征选择等场景。
2966

被折叠的 条评论
为什么被折叠?



