关于.sum(axis=)的解释
axis = 0的代码:
import torch
c = torch.Tensor([[[1,2],[3,4]],[[11,12],[13,14]]])
print(c.shape)
print(c.sum(axis=0))
输出:
torch.Size([2, 2, 2])
tensor([[12., 14.],
[16., 18.]])
设计一个2X2X2的Tensor,axis=0求和就是各个矩阵对应的位置求和(两个单位块做块与块之间的运算)
做完加法后本应是[[[12, 14], [16, 18]]],但是移除最外层[]后,原来的三层[]变成两层[],所以返回结果为[[12, 14], [16, 18]]
axis = 1的代码:
c = torch.Tensor([[[1,2],[3,4]],[[11,12],[13,14]]])
print(c.shape)
print(c.sum(axis=1))
输出:
tensor([[ 4., 6.],
[24., 26.]])
axis = 1就是单个矩阵列相加(对第二外层[]里的最大单位块做块与块之间的运算,同时移除第二外层[])
axis = 2的代码:
c = torch.Tensor([[[1,2],[3,4]],[[11,12],[13,14]]])
print(c.shape)
print(c.sum(axis=2))
结果:
tensor([[ 3., 7.],
[23., 27.]])
就是单个矩阵的行相加