1. 比较运算：等于（eq / equal）、大于等于（ge）、大于（gt）、小于等于（le）、小于（lt）、不等于（ne）
import torch

# Build the inputs with torch.tensor (the recommended data -> tensor factory)
# instead of the legacy torch.Tensor(...) constructor, whose meaning changes
# with the argument (torch.Tensor(3) is an uninitialized size-3 tensor).
# Float literals make torch.tensor infer the default float dtype, so the
# printed results are the same as with the legacy constructor.
x = torch.tensor([[2.0, 3.0, 5.0], [4.0, 7.0, 9.0]])
y = torch.tensor([[2.0, 4.0, 5.0], [4.0, 8.0, 9.0]])
z = torch.tensor([[2.0, 3.0, 5.0], [4.0, 7.0, 9.0]])

# Element-wise equality: a bool tensor with the same shape as the inputs.
print(torch.eq(x, y))
# Whole-tensor comparison: a single Python bool (True only if every element matches).
print(torch.equal(x, z))
print(torch.equal(x, y))
# Element-wise >=, >, <=, <, != -- each returns a bool tensor.
print(torch.ge(x, y))
print(torch.gt(x, y))
print(torch.le(x, y))
print(torch.lt(x, y))
print(torch.ne(x, y))
--------------------------------------------------------------------------------
result:
tensor([[ True, False, True],
[ True, False, True]])
True
False
tensor([[ True, False, True],
[ True, False, True]])
tensor([[False, False, False],
[False, False, False]])
tensor([[True, True, True],
[True, True, True]])
tensor([[False, True, False],
[False, True, False]])
tensor([[False, True, False],
[False, True, False]])
2. 最大值、最小值（max / min）
import torch

# torch.tensor with float literals replaces the legacy torch.Tensor(...)
# constructor; the default float dtype is inferred, so values print as 9., 2., ...
x = torch.tensor([[2.0, 3.0, 5.0], [4.0, 7.0, 9.0]])

# No dim: reduce over all elements and return a 0-dim tensor.
print(torch.max(x))
# With dim: return a (values, indices) named tuple; dim=1 gives one entry per row.
print(torch.max(x, dim=1))
print(torch.min(x))
# dim=0 reduces down the rows, giving one (value, index) per column.
print(torch.min(x, dim=0))
--------------------------------------------------------------------------------
result:
tensor(9.)
torch.return_types.max(
values=tensor([5., 9.]),
indices=tensor([2, 2]))
tensor(2.)
torch.return_types.min(
values=tensor([2., 3., 5.]),
indices=tensor([0, 0, 0]))
3. 排序（sort）
import torch

# torch.tensor with float literals replaces the legacy torch.Tensor(...)
# constructor while keeping the same default float dtype in the output.
x = torch.tensor([[10.0, 3.0, 5.0], [4.0, 20.0, 9.0]])

# Defaults: dim=-1 (each row sorted independently), ascending order.
# Every call returns a (values, indices) named tuple.
print(torch.sort(x))
print(torch.sort(x, descending=True))
# dim=0 sorts each column independently.
print(torch.sort(x, dim=0, descending=True))
print(torch.sort(x, dim=0, descending=False))
# NOTE: x has no duplicate values. With ties, pass stable=True if the order
# of equal elements' indices matters.
--------------------------------------------------------------------------------
result:
torch.return_types.sort(
values=tensor([[ 3., 5., 10.],
[ 4., 9., 20.]]),
indices=tensor([[1, 2, 0],
[0, 2, 1]]))
--------------
torch.return_types.sort(
values=tensor([[10., 5., 3.],
[20., 9., 4.]]),
indices=tensor([[0, 2, 1],
[1, 2, 0]]))
-------------
torch.return_types.sort(
values=tensor([[10., 20., 9.],
[ 4., 3., 5.]]),
indices=tensor([[0, 1, 1],
[1, 0, 0]]))
--------------
torch.return_types.sort(
values=tensor([[ 4., 3., 5.],
[10., 20., 9.]]),
indices=tensor([[1, 0, 0],
[0, 1, 1]]))
4. topk（取前 k 个最大值或最小值）
import torch

# torch.tensor with float literals replaces the legacy torch.Tensor(...)
# constructor while keeping the same default float dtype in the output.
x = torch.tensor([[10.0, 3.0, 5.0], [4.0, 20.0, 9.0]])

# Defaults: dim=-1, largest=True, sorted=True -> the 2 largest values per row.
print(torch.topk(x, k=2))
# largest=False selects the k smallest values instead.
print(torch.topk(x, k=2, largest=False))
# dim=0 selects down each column; x has exactly 2 rows, so k=2 keeps every row.
print(torch.topk(x, k=2, dim=0))
# For this 2-D tensor, dim=1 is the last dim, so this matches the first call.
print(torch.topk(x, k=2, dim=1))
--------------------------------------------------------------------------------
result:
torch.return_types.topk(
values=tensor([[10., 5.],
[20., 9.]]),
indices=tensor([[0, 2],
[1, 2]]))
torch.return_types.topk(
values=tensor([[3., 5.],
[4., 9.]]),
indices=tensor([[1, 2],
[0, 2]]))
torch.return_types.topk(
values=tensor([[10., 20., 9.],
[ 4., 3., 5.]]),
indices=tensor([[0, 1, 1],
[1, 0, 0]]))
torch.return_types.topk(
values=tensor([[10., 5.],
[20., 9.]]),
indices=tensor([[0, 2],
[1, 2]]))