pytorch torch.argmax()函数

torch.argmax() 是 PyTorch 中的一个函数,用于返回张量中最大值的索引。这个函数在处理分类问题或者需要从一组数值中找到最大值的索引时非常有用。

使用方法

torch.argmax(input, dim=None, keepdim=False)

  • input:输入张量,可以是任意形状的。
  • dim:指定在哪个维度上进行操作。如果不指定(即 dim=None),则默认在整个张量上寻找最大值,并返回其扁平化后的索引。
  • keepdim:一个布尔值,指定是否保持输出的维度与输入相同。如果为 True,则输出的张量将具有与输入张量相同的维度(但在所选维度上的大小为 1)。如果为 False(默认值),则输出张量将比输入张量少一维。

返回值

返回一个包含最大值索引的长整型(LongTensor)张量。如果 dim 被指定,则返回的索引是沿着该维度的最大值索引。如果 dim=None,则返回的是整个张量中的最大值索引(扁平化后的索引)。

示例

import torch  
  
# 一维张量  
x = torch.tensor([1.0, 2.0, 3.0])  
# 返回整个张量中的最大值索引  
max_idx = torch.argmax(x)  
print(f"Max index in x: {max_idx}")  # 输出: Max index in x: 2  
  
# 二维张量  
x_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  
# 在每列中寻找最大值的索引  
max_idx_col = torch.argmax(x_2d, dim=0)  
# 在每行中寻找最大值的索引  
max_idx_row = torch.argmax(x_2d, dim=1)  
print(f"Max indices in columns: {max_idx_col}")  # 输出: Max indices in columns: tensor([1, 1])  
print(f"Max indices in rows: {max_idx_row}")     # 输出: Max indices in rows: tensor([1, 0]),注意这里可能看起来有点反直觉,因为索引是从0开始的  
  
# 使用 keepdim 参数  
max_idx_row_keepdim = torch.argmax(x_2d, dim=1, keepdim=True)  
print(f"Max indices in rows with keepdim: {max_idx_row_keepdim}")  # 输出: Max indices in rows with keepdim: tensor([[1], [0]])


注意,在二维张量的例子中,max_idx_col 返回的是在每列中最大值的索引(注意索引是从 0 开始的),而 max_idx_row 返回的是在每行中最大值的索引。这里可能会有点反直觉,特别是当最大值位于张量的右上角时(如上例所示),因为通常我们可能期望索引是从左上角开始计数的,但实际上索引是从 0 开始的,并且是按行优先的顺序排列的。

另外,keepdim 参数允许我们控制输出张量的维度是否与输入张量相同。在这个例子中,当 keepdim=True 时,max_idx_row_keepdim 的形状与 x_2d 在 dim=1 上的维度相同,但在该维度上的大小为 1。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浩瀚之水_csdn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值