用法跟上面torch.sort()函数一样,不同的是torch.argsort()返回只是排序后的值所对应原输入input的下标,即torch.sort()返回的indices
dim = 1 表示对每行中的元素进行降序排序,descending=True表示降序排序,输出结果为返回排序后的值所对应原输入input的下标indices
x = torch.randn(3, 4)
indices = torch.argsort(x,dim=1,descending=True)
x,indices
输出结果如下:
(tensor([[-0.6069, -0.9252, -0.9177, 0.6997],
[ 0.3245, -0.0665, 0.4600, 0.0722],
[-1.0662, 2.2669, -0.1171, -0.9208]]),
tensor([[3, 0, 2, 1],
[2, 0, 3, 1],
[1, 2, 3, 0]]))
torch.argsort()函数与torch.sort()类似,但返回的是输入tensor在排序后的索引。当dim=1且descending=True时,它对每行元素进行降序排序,返回排序后元素在原输入中的下标。例如,对于一个3x4的tensorx,argsort(dim=1,descending=True)会返回每个元素在排序后的行中的位置。
2795

被折叠的 条评论
为什么被折叠?



