pytorch基操04-比较运算符

本文详细介绍了PyTorch中的比较运算符,如`equal`、`gt`、`ge`、`lt`、`le`、`ne`,并展示了它们在张量操作中的应用。此外,还讲解了`sort`、`topk`和`kthvalue`等排序相关函数,以及`isinf`、`isfinite`和`isnan`用于检查数值状态的方法。通过实例解析,帮助读者深入理解这些操作在深度学习和张量处理中的作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 torch中的比较运算符

为了演示不同比较运算符的作用,先初始化两个tensor a和b。

# 创建a,b tensor并用均值为20,方差为10的高斯分布采样赋值
a=torch.empty(size=(2,2)).normal_(20,10).floor_() # floor 向下取整
b=torch.empty(size=(2,2)).normal_(20,10).floor_()
a[0,0],b[0,0]=10,10 # 强制修改[0,0]位置元素相同
a:
 tensor([[10., 24.],
        [ 9., 22.]])
b:
 tensor([[10., 25.],
        [22., 26.]])

1.1 torch.equal

equal方法只有当a和b形状元素值都相等才返回True。

print('equal:',torch.equal(a,b))

这里除了[0,0]位置其他元素都不相同,因此返回False.

equal: False

1.2 torch.equal

# a==b
print('eq:',torch.eq(a,b))
eq: tensor([[ True, False],
        [False, False]])

1.3 torch.gt

# a > b , gt=Greater Than
print('a>b:',torch.gt(a,b))

每个a的元素都小于等于在b对应位置的元素,因此都返回False。

a>b: tensor([[False, False],
        [False, False]])

1.4 torch.ge

# a>=b , ge=Greater than or Equal to
print('a>=b:',torch.ge(a,b))
a>=b: tensor([[ True, False],
        [False, False]])

1.5 torch.lt

# a<b , lt=Less Than
print('a<b:',torch.lt(a,b))
a<b: tensor([[False,  True],
        [ True,  True]])

1.6 torch.le

# a<=b , le=Less than or Equal to
print('a<=b:',torch.le(a,b))
a<=b: tensor([[True, True],
        [True, True]])

1.7 torch.ne

# a!=b , ne=Not Equal
print('ne:',torch.ne(a,b))
ne: tensor([[False,  True],
        [ True,  True]])

1.8 torch.sort

创建用于排序的张量。

print('\n'+'-'*8+'排序'+'-'*8+'\n')
a=torch.empty(size=(2,5)).normal_(20,10).floor_()
print('a:',a)
print('b:',b)
a: tensor([[28., 17., 18., 16., 19.],
        [44., 30., 10.,  9., 29.]])

排序操作,返回排序后的张量,以及排序后的原元素的索引。此处沿着维度1进行排序。

# torch.sort
print('-'*8+'sort'+'-'*8)
a_sort,a_idx=torch.sort(a,dim=1)
print('a_sort:\n',a_sort)
print('idx_sort:\n',a_idx)
--------sort--------
a_sort:
 tensor([[ 3.,  5.,  8., 12., 22.],
        [-3.,  7., 19., 33., 33.]])
idx_sort:
 tensor([[3, 2, 0, 4, 1],
        [4, 1, 2, 0, 3]])

1.9 torch.topk

topk可以获取对应维度上,topk大或者topk小的元素。还是使用上上面的a作为例子。

# torch.topk 前k大个元素
print('-'*8+'topk'+'-'*8)
topk_res=torch.topk(a,dim=1,k=2,largest=True)
print(topk_res)

可以看到,维度1上最大的两个元素为(22,12)和(33,33)。

a: tensor([[ 8., 22.,  5.,  3., 12.],
        [33.,  7., 19., 33., -3.]])
--------topk--------
torch.return_types.topk(
values=tensor([[22., 12.],
        [33., 33.]]),
indices=tensor([[1, 4],
        [0, 3]]))

1.10 torch.kthvalue

kthvalue可以获取在指定维度上第k小的元素,只返回一个元素,且只能是第k小的(大的也不行)。

# torch.kthvalue
# get the k-th smallest values 第k小的元素
print('-'*8+'kthvalue'+'-'*8)
print(torch.kthvalue(a,k=3,dim=1))
a: tensor([[ 8., 22.,  5.,  3., 12.],
        [33.,  7., 19., 33., -3.]])

torch.return_types.kthvalue(
values=tensor([ 8., 19.]),
indices=tensor([0, 2]))

1.11 torch.isinf

isinf用于判断元素是否是无界的,这里故意除以0来制造无界的元素。

# torch.isinf 是否无界
print(torch.isinf(a/0))
tensor([[True, True, True, True, True],
        [True, True, True, True, True]])

1.12 torch.isfinite

isfinite用于判断是否有界。

# 是否有界
print(torch.isfinite(a/0))
tensor([[False, False, False, False, False],
        [False, False, False, False, False]])

1.13 torch.nan

# 是否是nan
a[0,0]=np.NAN
print(torch.isnan(a))
tensor([[ True, False, False, False, False],
        [False, False, False, False, False]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值