1. argmax在多维张量中返回最大值元素的索引,从而确定最大值在张量中的位置
2.模型输出一个概率分布,用argmax来获取概率最大的类别
import torch
x = torch.tensor([[1, 2], [3, 4]])
max_value = torch.max(x)
max_index = torch.argmax(x)
print("Max value:", max_value)
print("Max index:", max_index)
输出结果如下
Max value: tensor(4)
Max index: tensor(3)
---> 这里的3表示在二维向量中最大值在索引[1, 1]处,按顺序索引为3(从0开始计数)
---------------------------------------------------------------------------------------------------------------------------------也可以固定维度(按列获取最大值索引和按行获取最大值索引)
max_index_dim0 = torch.argmax(x, dim=0)
max_index_dim1 = torch.argmax(x, dim=1)
print("Max index along dimension 0:", max_index_dim0)
print("Max index along dimension 1:", max_index_dim1)
dim = 0表示沿着第一维度(行)来查找索引,dim = 1表示沿着第二维度(列)来查找索引
结果如下:
Max index along dimension 0: tensor([1, 1])
Max index along dimension 1: tensor([1, 1])
--->[1, 1] 表示在每一行中最大值的索引,[1, 1] 表示在每一列中最大值的索引。
这种方法在返回概率最大的索引时经常使用,一般是按照列来确定类型索引(dim = 1)
1万+

被折叠的 条评论
为什么被折叠?



