torch.argmax 的使用

torch.argmax函数用于返回张量在指定维度上的最大值索引。输入是一个张量,可以指定维度,keepdim参数决定是否保留输出张量的维度。示例中展示了在一个(3,4)张量上找第0维最大值索引的过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在 PyTorch 中,torch.argmax 是一个函数,用于返回张量中某个维度上最大值的索引,即张量中在某个维度上具有最大值的元素所在的位置。该函数的语法如下:

torch.argmax(input, dim=None, keepdim=False)

其中:

  • input:要在其中查找最大元素的张量。
  • dim:如果指定了此参数,则在指定的维度上查找最大元素。否则,将在整个张量上查找最大元素。
  • keepdim:如果将此参数设置为 True,则将输出张量的形状保持与输入张量的形状相同。

例如,如果我们有一个形状为 (3, 4) 的张量 x,并且想在第 0 个维度上查找最大元素,则可以使用以下代码:

import torch

x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])
max_indices = torch.argmax(x, dim=0)
print(max_indices)  # tensor([2, 2, 2, 2])

在这里,第一个维度的大小为 3,第二个维度的大小为 4。我们在第 0 个维度上使用 torch.argmax 函数来查找 x 中最大值所在的索引。函数返回的张量 max_indices 包含四个元素,分别代表第 0 个维度上最大值所在的位置(在本例中,每个元素的值都是 2,表示第 2 行是第一列、第二列、第三列和第四列中的最大值所在的位置)。

如果指定了 keepdim=True,则输出张量的形状将与输入张量的形状相同,只是在指定的维度上将大小设置为 1:

max_indices = torch.argmax(x, dim=0, keepdim=True)
print(max_indices.size())  # torch.Size([1, 4])
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值