torch.max()——数组的最大值
torch.max()有两种形式
形式Ⅰ
torch.max(input) → Tensor
功能:输出数组的最大值
注意:
- 只有一个输入,只需要输入一个数组
- 该方式也可以通过
a.max()
实现,后者是求数组a
的最大值
形式Ⅱ
torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)
功能:按指定维度判断,返回数组的最大值以及最大值处的索引
输入:
input
:待判定的数组dim
:给定的维度keepdim
:如果指定为True
,则输出的张量数组维数和输入一致,并且除了dim
维度是1,其余的维度大小和输入数组维度大小一致。如果改为False
,则相当于将True
的结果压缩了(删去了大小是1的dim
维度)。两者的差别就在于是否保留dim
维度。
注意:
- 如果在指定的维度中,有多个重复的最大值,则返回第一个最大值的索引
- 该函数返回由最大值以及最大值处的索引组成元组
(max,max_indices)
- 该函数也可以通过
a.max()
实现,后者是求数组a
的最大值,只需要再指明dim
以及keepdim
即可 - 输出的索引是对应的
dim
维度上的索引,注意含义。
上述两种函数形式本质区别就是有没有指出dim
,如果未指出dim
,则返回整个数组的最大值,不返回索引。如果指出了dim
,则在指定的维度上搜索最大值,返回最大值以及索引。
代码案例
一般用法
import torch
a=torch.arange(10).reshape(2,5)
b=torch.max(a)
c=torch.max(a,0)
print(a)
print(b)
print(c)
输出
# 原数组
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
# 不输入dim时,返回数组的最大值
tensor(9)
# 输入dim时,返回指定维度的最大值及索引
torch.return_types.max(
# 第一个是值,按dim=0比较得到的
values=tensor([5, 6, 7, 8, 9]),
# 第二个是值对应的索引,对应维度dim=0上的索引
indices=tensor([1, 1, 1, 1, 1]))
keepdim
定为True
或者Flase
的区别
import torch
a=torch.arange(10).reshape(2,5)
b=torch.max(a,0,True)
c=torch.max(a,0)
print(a.shape)
# 这里定为0和1都一样,值与索引具有同样的形状
print(b[0].shape)
print(c[0].shape)
输出,只有维度不同
torch.Size([2, 5])
torch.Size([1, 5])
torch.Size([5])
不同的dim
对结果的影响,这里以input
的维数是3为例,维数更多的可以类推,首先定义好数组input
。
import torch
a=torch.arange(32).reshape(2,4,4)
print(a)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]])
dim
为0
b=torch.max(a,0)
print(b)
输出:
torch.return_types.max(
values=tensor([[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]),
indices=tensor([[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]]))
dim
为1
c=torch.max(a,1)
print(c)
输出
torch.return_types.max(
values=tensor([[12, 13, 14, 15],
[28, 29, 30, 31]]),
indices=tensor([[3, 3, 3, 3],
[3, 3, 3, 3]]))
dim
为2
d=torch.max(a,2)
print(d)
输出
torch.return_types.max(
values=tensor([[ 3, 7, 11, 15],
[19, 23, 27, 31]]),
indices=tensor([[3, 3, 3, 3],
[3, 3, 3, 3]]))
扩展
在分类任务中经常用到该函数,softmax
函数输出得到的概率再经过torch.max
函数得到最终的预测结果(预测结果一般和索引值一一对应,因此可以用索引值来表示预测结果),可以进一步与标签做比较,得到准确率。
import numpy as np
import torch
# 假设batch_size=16,每批次有20个概率值(即20分类)
# 网络结构会得到16行20列的数组
a = torch.tensor(np.random.rand(16, 20))
# 索引号代表预测结果
pre = torch.max(a,dim=1)[1]
print(pre)
print(pre.shape)
输出,共有16个结果(对应batch_size=16)
tensor([11, 10, 1, 9, 19, 19, 15, 14, 5, 2, 17, 14, 13, 15, 15, 17])
torch.Size([16])
官方文档
torch.max():https://pytorch.org/docs/stable/generated/torch.max.html?highlight=torch%20max#torch.max