torch.max()函数_跃跃的笔记本的博客-优快云博客_torch.max
简单记法,dim是哪一维,那一维就变1
例如shap是(3,2,5),dim=0时,结果就应该是(1,2,5)。dim=1时,结果应该是(3,1,5)
个人关于dim维度的理解:dim维可以看作最小比较单位,dim维度之后的维度统统视作一个整体参与比较,比较的方法就是对整体内每一个值比较取最大值。
torch.max()在三维张量中的使用_soulworker的博客-优快云博客
import torch
x = torch.rand(3, 2, 5) # 生成随机数
print(x)
>>> tensor([
[[0.2514, 0.7950, 0.9641, 0.0135, 0.2785],
[0.2575, 0.4410, 0.6829, 0.6668, 0.5850]],
[[0.4725, 0.2015, 0.3406, 0.6989, 0.3551],
[0.9674, 0.5781, 0.6250, 0.3404, 0.4238]],
[[0.2377, 0.3673, 0.3647, 0.1027, 0.9024],
[0.0047, 0.0106, 0.4600, 0.6851, 0.7389]]])
x_value_index = torch.max(x, dim=0, keepdim=True) # 最大值和对应索引
print(x_value_index)
>>> torch.return_types.max(
values=tensor([[
[0.4725, 0.7950, 0.9641, 0.6989, 0.9024],
[0.9674, 0.5781, 0.6829, 0.6851, 0.7389]
]]),
indices=tensor([[
[1, 0, 0, 1, 2],
[1, 1, 0, 2, 2]
]])
)
x_value = torch.max(x, 2, keepdim=True)[0] # 单独取出最大值
print(x_value)
>>> tensor([[[0.9641],[0.6829]],[[0.6989],[0.9674]],[[0.9024],[0.7389]]])
x_index = torch.max(x, 2, keepdim=True)[1] #单独取出最大值索引
print(x_index)
>>> tensor([[[2],[2]], [[3],[0]],[[4],[4]]])