查了好多博客都似懂非懂,后来写了几个小例子,瞬间一目了然。
目录
一、torch.max()
import torch
a=torch.randn(3)
print("a:\n",a)
print('max(a):',torch.max(a))
b=torch.randn(3,4)
print("b:\n",b)
print('max(b,0):',torch.max(b,0))
print('max(b,1):',torch.max(b,1))
输出:
a:
tensor([ 0.9558, 1.1242, 1.9503])
max(a): tensor(1.9503)
b:
tensor([[ 0.2765, 0.0726, -0.7753, 1.5334],
[ 0.0201, -0.0005, 0.2616, -1.1912],
[-0.6225, 0.6477, 0.8259, 0.3526]])
max(b,0): (tensor([ 0.2765, 0.6477, 0.8259, 1.5334]), tensor([ 0, 2, 2, 0]))
max(b,1): (tensor([ 1.5334, 0.2616, 0.8259]), tensor([ 3, 2, 2]))
max(a),用于一维数据,求出最大值。
max(a,0),计算出数据中一列的最大值,并输出最大值所在的行号。
max(a,1),计算出数据中一行的最大值,并输出最大值所在的列号。
print('max(b,1):',torch.max(b,1)[1])