max(tensor, a)[b] 或 tensor.max(a)[b]
tensor:要进行操作的数据(张量)
a:要进行操作的维度
b:要输出的内容类型;b=0时输出最大值数据,b=1时输出最大值索引
-----------------------------------------------------------------------------------------------
例1:
import torch
tensor = torch.randn(3, 5)
print('tensor: ', tensor)
print('tensor.max(0)[0]: ', tensor.max(0)[0]) # 每一列的最大值
print('tensor.max(0)[1]: ', tensor.max(0)[1]) # 每一列最大值对应的索引
print('tensor.max(1)[0]: ', tensor.max(1)[0]) # 每一行的最大值
print('tensor.max(1)[1]: ', tensor.max(1)[1]) # 每一行最大值对应的索引
运行结果:
tensor: tensor([[ 1.0385, -1.4738, 0.0067, 1.7649, -0.9139],
[-0.9097, -0.2890, 2.0857, 0.7173, -1.1650],
[-0.6127, -2.2111, -0.5820, 0.2978, 0.9971]])
tensor.max(0)[0]: tensor([ 1.0385, -0.2890, 2.085