torch.argsort()&&torch.sort()用法

本文介绍了PyTorch中的torch.argsort()和torch.sort()两个函数,详细阐述了它们的参数、作用及使用示例。torch.argsort()返回输入张量按照指定维度排序后的下标,而torch.sort()则返回排序后的张量及其对应的下标。通过示例展示了这两个函数在降序排列时的输出结果,并验证了torch.argsort()与torch.sort()返回的下标矩阵的一致性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

函数原型1:

torch.argsort(input, dim=- 1, descending=False) → LongTensor

参数:

input:(Tensor)输入张量
dim:(int类型)要排序的维度
descending:(布尔类型),升序还是降序。默认升序。

作用:返回按照指定维度排序后的值对应排序前的下标。

该函数其实是torch.sort()返回的第二个元素,第一个元素是排序后的Tensor。

函数原型2

torch.sort(input, dim=- 1, descending=False, stable=False, *, out=None)

参数:

input ( Tensor ) :输入张量。
dim ( int , optional ) : 要排序的维度
descending ( bool , optional ) : 控制顺序(升序或降序)
stable ( bool , optional ) : 使排序更加稳定,这保证了等价元素的顺序得以保留。

作用:

将输入张量的元素按照给定的维度按值升序排序。返回一个元组。

实例

x = torch.randint(10, size=(4, 3))
print(f'x:{x}')
x1 = torch.argsort(x, dim=-1, descending=True)#降序
print(x1)
values, indices = torch.sort(x, dim=-1, descending=True)
print(values)
print(indices==x1)

输出结果:

x:tensor([[0, 6, 7],
        [8, 5, 5],
        [3, 4, 9],
        [4, 6, 5]])
tensor([[2, 1, 0],
        [0, 1, 2],
        [2, 1, 0],
        [1, 2, 0]])
tensor([[7, 6, 0],
        [8, 5, 5],
        [9, 4, 3],
        [6, 5, 4]])
tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值