(Python)numpy的argmax用法

这篇博客详细解释了numpy库中argmax函数的用法,特别是在一维、二维和三维数组中的应用。argmax函数返回数组中最大值的索引,当设置axis参数时,可以指定在哪个轴上查找最大值。例如,axis=0时,函数返回每一列的最大值索引;axis=1时,返回每一行的最大值索引。通过实例展示了如何使用argmax来获取最大值的索引,并且深入到三维数组的处理,帮助读者更好地理解和运用该函数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

numpy.argmax 是 NumPy 中用于返回数组中最大值索引的函数,常用于查找最大值的所在位置(如分类任务中预测概率最大的类别)。以下是其用法详解:


语法

numpy.argmax(a, axis=None, out=None, keepdims=False)
  • a: 输入数组。
  • axis (可选): 沿指定轴查找,默认 (None) 表示将数组展平后查找。
  • keepdims (可选): 是否保持维度,默认为 False

基本用法

1. 一维数组
import numpy as np

arr = np.array([3, 1, 4, 1, 5, 9])
index = np.argmax(arr)
print(index)  # 输出: 5 (最大值为9,索引为5)
2. 二维数组(指定轴)​
arr = np.array([[1, 5, 3], [4, 2, 6]])

# 沿行查找(axis=1):每行的最大值索引
row_max_indices = np.argmax(arr, axis=1)
print(row_max_indices)  # 输出: [1, 2]

# 沿列查找(axis=0):每列的最大值索引
col_max_indices = np.argmax(arr, axis=0)
print(col_max_indices)  # 输出: [1, 0, 1]
3. 高维数组
arr = np.random.rand(2, 3, 4)  # 三维数组

# 沿第三个维度(axis=2)查找
result = np.argmax(arr, axis=2)
print(result.shape)  # 输出: (2, 3)

参数详解

​**axis 参数**
  • ​**axis=None**​(默认): 将数组展平为一维后查找。

    arr = np.array([[1, 2], [3, 4]])
    print(np.argmax(arr))  # 输出: 3 (展平后索引为3,对应值4)
  • ​**axis=0**: 沿列查找(纵向)。

  • ​**axis=1**: 沿行查找(横向)。

​**keepdims=True**

保持输出维度与原数组一致:

arr = np.array([[1, 5], [3, 2]])
result = np.argmax(arr, axis=1, keepdims=True)
print(result)  # 输出: [[1], [0]] (保持二维形状)

实际应用场景

1. 分类任务中获取预测类别
# 假设模型输出为3个样本的3个类别概率
probabilities = np.array([[0.1, 0.8, 0.1],
                          [0.3, 0.2, 0.5],
                          [0.7, 0.1, 0.2]])

predicted_classes = np.argmax(probabilities, axis=1)
print(predicted_classes)  # 输出: [1, 2, 0]
2. 找到最大值的坐标(多维数组)​
arr = np.array([[3, 7, 2], [4, 1, 6]])

# 找到全局最大值的坐标
max_index_flat = np.argmax(arr)
row = max_index_flat // arr.shape[1]  # 行索引
col = max_index_flat % arr.shape[1]   # 列索引
print((row, col))  # 输出: (0, 1) (对应值7)

注意事项

  1. 多个相同最大值: 返回第一个出现的索引。

    arr = np.array([5, 3, 5, 2])
    print(np.argmax(arr))  # 输出: 0
  2. 空数组: 对空数组使用会抛出 ValueError

  3. 数据类型: 返回值是整数类型(np.int64)。


相关函数

  • ​**np.max**: 返回最大值。
  • ​**np.argmin**: 返回最小值索引。
  • ​**np.unravel_index**: 将展平后的索引转换为多维坐标。
    max_index_flat = np.argmax(arr)
    coords = np.unravel_index(max_index_flat, arr.shape)
    print(coords)  # 输出元组,如 (0, 1)

掌握 argmax 可以高效处理数组中的极值问题,尤其在机器学习和数据分析中应用广泛!

参考:

  1. (Python)numpy的argmax用法 - 冂冋冏囧 - 博客园
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浩瀚之水_csdn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值