参考:https://blog.youkuaiyun.com/oHongHong/article/details/72772459
参考文章举的例子都非常好,我想把文章的内容更加细节化,方便读者对多维情况进行理解。
1. 函数原型
argmax(a, axis=None, out=None)
包含三个参数:
a—-输入array
axis—-轴,默认为none,即把n维数组的具体数值平铺到1维数组中
out—-结果写到这个array里面
函数作用:沿轴找数组的最大值对应的第一个下标
对于轴我想解释一下,n维数组就有n个轴(除了none),比较常见的:
一维数组:一个轴,x轴方向
二维数组:2个轴,x轴(行方向),y轴(列方向)
三维数组:3个轴,每个参数为一个轴,三个方向,你可以不用想象出来,意会即可。
而axis的取值,对于n维数组来说,则是从0~n-1。即二维数组axis=0或axis=1。三维数组:axis=0 or axis=1 or axis=2。以此类推。
以2x3x4的三维矩阵为例,当axis=0,你可以认为第一维被固定住了,即你需要在a[0][j][k],a[1][j][k] (j=0,1,2,k=0,1,2,3)中找最大值的索引。当axis=1,则是第二维是常数被固定住了,即你需要在a[i][0][k],a[i][1][k] ,a[i][2][k] (i=0,1,k=0,1,2,3)中找最大值的索引。当axis=2,则是第三维是常数被固定住了,即你需要在a[i][j][0],a[i][j][1],a[i][j][2],a[i][j][3] (i=0,1,j=0,1,2)中找最大值的索引。
如果上面的你理解了,你应该能知道最后输出的结果是几乘几的数组。首先输出肯定降一维,三维变二维,除了固定的那一维,剩下的组成新的降维后的矩阵。
2x3x4的三维矩阵,axis=0,最后输出的是3x4的数组。
2x3x4的三维矩阵,axis=1,最后输出的是2x4的数组。
2x3x4的三维矩阵,axis=2,最后输出的是2x3的数组。
如果还是不明白没关系,下面我会用2x3x4的数组进行举例具体解析的。
2. 一维数组举例:实质上就是找数组中最大值
import numpy as np
a = np.array([3, 1, 2, 4, 6, 1])
print(np.argmax(a))
3. 二维数组举例,轴可取none,0,1
1)轴取none:把n维数组压缩为一维数组,再找最大值的下标
import numpy as np
a = np.array([[1, 5, 5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]])
print(np.argmax(a)) #输出为4
本例中,数组压缩完变为:[1, 5, 5, 2,9, 6, 2, 8,3, 7, 9, 1],最大值为9,出现的第一个下标为4。
2)轴取0:
import numpy as np
a = np.array([[1, 5, 5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]])
print(np.argmax(a, axis=0)) #输出为(1,2,2,1)
相当于把第一维固定住了,即在a[0][j],a[1][j],a[2][j] (j=0,1,2,3)中找最大值索引。对于3x4数组,固定了第一维,我们可以提前知道输出的数组必定为一维数组,第一维被固定了,第二维为4,所以必定为4个元素的一维数组。
a[0][j]的含义为:第一行的所有元素
a[1][j]的含义为:第二行的所有元素
a[2][j]的含义为:第三行的所有元素
三组进行比较:即
a[0][j]: 1, 5, 5, 2
a[1][j]: 9, 6, 2, 8
a[2][j]: 3, 7, 9, 1
比较的结果: a[1]大 a[2]大 a[2]大 a[1]大
所以最后结果为(1,2,2,1)。
3)轴取1:
import numpy as np
a = np.array([[1, 5, 5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]])
print(np.argmax(a, axis=1)) #输出为(1,0,2)
相当于把第二维固定住了,即在a[i][0],a[i][1],a[i][2] ,a[i][3] (i=0,1,2)中找最大值索引。对于3x4数组,固定了第二维,我们可以提前知道输出的数组必定为降维成一维数组,固定了第二维,第一维是3,所以必定为3个元素的一维数组。
a[i][0]的含义为:第一列的所有元素
a[i][1]的含义为:第二列的所有元素
a[i][2]的含义为:第三列的所有元素
a[i][3]的含义为:第四列的所有元素
三组进行比较:即
a[i][0]: 1, 9, 3
a[i][1]: 5, 6, 7
a[i][2]: 5, 2, 9
a[i][3]: 2, 8, 1
比较的结果: a[1]大 a[0]大 a[2]大
所以最后结果为(1,0,2)。
4. 三维数组举例,轴可取none,0,1,2
1)轴取none
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[-1, 5, -5, 2],
[3, 7, 9, 1]
]
])
print(np.argmax(a)) #输出为4
本例中,数组压缩完变为:[1, 5, 5, 2,9, -6, 2, 8,-3, 7, -9, 1,-1, 5, -5, 2,-1, 5, -5, 2,3, 7, 9, 1],最大值为9,出现的第一个下标为4。
2)轴取0
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[-1, 5, -5, 2],
[3, 7, 9, 1]
]
])
print(np.argmax(a, axis=0)) #输出为[[0 0 0 0]
# [0 1 0 0]
# [1 0 1 0]]
相当于把第一维固定住了,即在a[0][j][k],a[1][j][k]
(j=0,1,2,k=0,1,2,3)中找最大值索引。对于2x3x4数组,固定了第一维,我们可以提前知道输出的数组必定为二维数组(必降一维),第一维被固定了,剩下两维是3x4,所以必定为3x4的二维数组。
a[0][j][k]的含义为:第一维,第一个的所有元素
a[1][j][k]的含义为:第一维,第二个的所有元素
两组进行比较:即
a[0][j][k]: [1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
a[1][j][k]: [-1, 5, -5, 2],
[-1, 5, -5, 2],
[3, 7, 9, 1]
即两个数组比较大小,对应位置两两比较,结果如下:
a[0]大 a[0]大 a[0]大 a[0]大
a[0]大 a[1]大 a[0]大 a[0]大
a[1]大 a[0]大 a[1]大 a[0]大
所以最后结果为:
[0 0 0 0]
[0 1 0 0]
[1 0 1 0]
3)轴取1
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[-1, 5, -5, 2],
[3, 7, 9, 1]
]
])
print(np.argmax(a, axis=1)) #输出为[[1 2 0 1]
# [1 2 2 1]]
相当于把第二维固定住了,即在a[i][0][k],a[i][1][k],a[i][2][k]
(i=0,1,k=0,1,2,3)中找最大值索引。对于2x3x4数组,固定了第二维,我们可以提前知道输出的数组必定为二维数组(必降一维),第二维被固定了,剩下两维是2x4,所以必定为2x4的二维数组。
a[i][0][k]的含义为:第一维第一个取第二维的第一行,第一维第二个取第二维的第一行,拼接而成
a[i][1][k]的含义为:第一维第一个取第二维的第二行,第一维第二个取第二维的第二行,拼接而成
a[i][2][k]的含义为:第一维第一个取第二维的第三行,第一维第二个取第二维的第三行,拼接而成
三组进行比较:即
a[i][0][k]: [1, 5, 5, 2],
[-1, 5, -5, 2]
a[i][1][k]: [9, -6, 2, 8],
[9, 6, 2, 8]
a[i][2][k]: [-3, 7, -9, 1],
[3, 7, 9, 1]
即三个数组比较大小,对应位置进行比较,结果如下:
a[1]大 a[2]大 a[0]大 a[1]大
a[1]大 a[2]大 a[2]大 a[1]大
所以最后结果为:
[1 2 0 1]
[1 2 2 1]
4)轴取2
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[-1, 5, -5, 2],
[3, 7, 9, 1]
]
])
print(np.argmax(a, axis=2)) #输出为[[1 0 1]
# [1 0 2]]
相当于把第三维固定住了,即在a[i][j][0],a[i][j][1],a[i][j][2],a[i][j][3]
(i=0,1,j=0,1,2)中找最大值索引。对于2x3x4数组,固定了第三维,我们可以提前知道输出的数组必定为二维数组(必降一维),第三维被固定了,剩下两维是2x3,所以必定为2x3的二维数组。
a[i][j][0]的含义为:第一维第一个取第二维的第一列,第一维第二个取第二维的第一列,拼接而成
a[i][j][1]的含义为:第一维第一个取第二维的第二列,第一维第二个取第二维的第二列,拼接而成
a[i][j][2]的含义为:第一维第一个取第二维的第三列,第一维第二个取第二维的第三列,拼接而成
a[i][j][3]的含义为:第一维第一个取第二维的第三列,第一维第二个取第二维的第三列,拼接而成
三组进行比较:即
a[i][j][0]: [1, 9, -3],
[-1, 9, 3]
a[i][j][1]: [5, -6, 7],
[5, 6, 7]
a[i][j][2]: [5, 2, -9],
[-5, 2, 9]
a[i][j][3]: [2, 8, 1],
[2, 8, 1]
即三个数组比较大小,对应位置进行比较,结果如下:
a[1]大 a[0]大 a[1]大
a[1]大 a[0]大 a[2]大
所以最后结果为:
[1 0 1]
[1 0 2]