numpy.argmax
是 NumPy 中用于返回数组中最大值索引的函数,常用于查找最大值的所在位置(如分类任务中预测概率最大的类别)。以下是其用法详解:
语法
numpy.argmax(a, axis=None, out=None, keepdims=False)
- a: 输入数组。
- axis (可选): 沿指定轴查找,默认 (
None
) 表示将数组展平后查找。 - keepdims (可选): 是否保持维度,默认为
False
。
基本用法
1. 一维数组
import numpy as np
arr = np.array([3, 1, 4, 1, 5, 9])
index = np.argmax(arr)
print(index) # 输出: 5 (最大值为9,索引为5)
2. 二维数组(指定轴)
arr = np.array([[1, 5, 3], [4, 2, 6]])
# 沿行查找(axis=1):每行的最大值索引
row_max_indices = np.argmax(arr, axis=1)
print(row_max_indices) # 输出: [1, 2]
# 沿列查找(axis=0):每列的最大值索引
col_max_indices = np.argmax(arr, axis=0)
print(col_max_indices) # 输出: [1, 0, 1]
3. 高维数组
arr = np.random.rand(2, 3, 4) # 三维数组
# 沿第三个维度(axis=2)查找
result = np.argmax(arr, axis=2)
print(result.shape) # 输出: (2, 3)
参数详解
**axis
参数**
-
**
axis=None
**(默认): 将数组展平为一维后查找。arr = np.array([[1, 2], [3, 4]]) print(np.argmax(arr)) # 输出: 3 (展平后索引为3,对应值4)
-
**
axis=0
**: 沿列查找(纵向)。 -
**
axis=1
**: 沿行查找(横向)。
**keepdims=True
**
保持输出维度与原数组一致:
arr = np.array([[1, 5], [3, 2]])
result = np.argmax(arr, axis=1, keepdims=True)
print(result) # 输出: [[1], [0]] (保持二维形状)
实际应用场景
1. 分类任务中获取预测类别
# 假设模型输出为3个样本的3个类别概率
probabilities = np.array([[0.1, 0.8, 0.1],
[0.3, 0.2, 0.5],
[0.7, 0.1, 0.2]])
predicted_classes = np.argmax(probabilities, axis=1)
print(predicted_classes) # 输出: [1, 2, 0]
2. 找到最大值的坐标(多维数组)
arr = np.array([[3, 7, 2], [4, 1, 6]])
# 找到全局最大值的坐标
max_index_flat = np.argmax(arr)
row = max_index_flat // arr.shape[1] # 行索引
col = max_index_flat % arr.shape[1] # 列索引
print((row, col)) # 输出: (0, 1) (对应值7)
注意事项
-
多个相同最大值: 返回第一个出现的索引。
arr = np.array([5, 3, 5, 2]) print(np.argmax(arr)) # 输出: 0
-
空数组: 对空数组使用会抛出
ValueError
。 -
数据类型: 返回值是整数类型(
np.int64
)。
相关函数
- **
np.max
**: 返回最大值。 - **
np.argmin
**: 返回最小值索引。 - **
np.unravel_index
**: 将展平后的索引转换为多维坐标。max_index_flat = np.argmax(arr) coords = np.unravel_index(max_index_flat, arr.shape) print(coords) # 输出元组,如 (0, 1)
掌握 argmax
可以高效处理数组中的极值问题,尤其在机器学习和数据分析中应用广泛!
参考: