torch.argmax()
是 PyTorch 中的一个函数,用于返回张量中最大值的索引。这个函数在处理分类问题或者需要从一组数值中找到最大值的索引时非常有用。
使用方法
torch.argmax(input, dim=None, keepdim=False)
input
:输入张量,可以是任意形状的。dim
:指定在哪个维度上进行操作。如果不指定(即dim=None
),则默认在整个张量上寻找最大值,并返回其扁平化后的索引。keepdim
:一个布尔值,指定是否保持输出的维度与输入相同。如果为True
,则输出的张量将具有与输入张量相同的维度(但在所选维度上的大小为 1)。如果为False
(默认值),则输出张量将比输入张量少一维。
返回值
返回一个包含最大值索引的长整型(LongTensor)张量。如果 dim
被指定,则返回的索引是沿着该维度的最大值索引。如果 dim=None
,则返回的是整个张量中的最大值索引(扁平化后的索引)。
示例
import torch
# 一维张量
x = torch.tensor([1.0, 2.0, 3.0])
# 返回整个张量中的最大值索引
max_idx = torch.argmax(x)
print(f"Max index in x: {max_idx}") # 输出: Max index in x: 2
# 二维张量
x_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
# 在每列中寻找最大值的索引
max_idx_col = torch.argmax(x_2d, dim=0)
# 在每行中寻找最大值的索引
max_idx_row = torch.argmax(x_2d, dim=1)
print(f"Max indices in columns: {max_idx_col}") # 输出: Max indices in columns: tensor([1, 1])
print(f"Max indices in rows: {max_idx_row}") # 输出: Max indices in rows: tensor([1, 0]),注意这里可能看起来有点反直觉,因为索引是从0开始的
# 使用 keepdim 参数
max_idx_row_keepdim = torch.argmax(x_2d, dim=1, keepdim=True)
print(f"Max indices in rows with keepdim: {max_idx_row_keepdim}") # 输出: Max indices in rows with keepdim: tensor([[1], [0]])
注意,在二维张量的例子中,max_idx_col
返回的是在每列中最大值的索引(注意索引是从 0 开始的),而 max_idx_row
返回的是在每行中最大值的索引。这里可能会有点反直觉,特别是当最大值位于张量的右上角时(如上例所示),因为通常我们可能期望索引是从左上角开始计数的,但实际上索引是从 0 开始的,并且是按行优先的顺序排列的。
另外,keepdim
参数允许我们控制输出张量的维度是否与输入张量相同。在这个例子中,当 keepdim=True
时,max_idx_row_keepdim
的形状与 x_2d
在 dim=1
上的维度相同,但在该维度上的大小为 1。