在对张量操作的过程中,dim的含义尤为重要,参考了两篇博客的讲解,现梳理如下
视作对某个维度操作
假设定义了某张量a
a = torch.tensor([[1,4],[3,2]])
很容易看出它的形状是2×2的,即有两个维度:第零个维度包含[1,4], [3,2]两个tensor,第一个维度包含了4个(两对)标量。
若想沿某个维度找到最大值,则执行
torch.max(a, dim)
若令dim=0,则是从第零个维度中找到最大值。
tensor([[1, 4],
[3, 2]])
张量之间比较大小,是要比较 每一列元素 的大小,例如有向量x=[x1,x2]x=[x_1,x_2]x=[x1,x2]和x=[y1,y2]x=[y_1,y_2]x=[y1,y2],比较xxx与yyy之间的大小即max(x,y)=[max(x1,y1),max(x2,y2)]\max(x,y)=[\max(x_1,y_1),\max(x_2,y_2)]max(x,y)=[max(x1,y1),max(x2,y2)]。若x1x_1x1仍为向量,则 继续递归地调用自身,直到遇到标量为止。
此时,[max(1,3),max(4,2)][\max(1,3),\max(4,2)][max(1,3),max(4,2)],得到[3,4]
若令dim=1,即第一个维度。此时需要从两对标量中找最大值。即[max(1,4),max(3,2)][\max(1,4),\max(3,2)][max(1,4),max(3,2)],得到[4,3]
对于三维数组,若
b = torch.tensor([[[9, 2], [5, 4]], [[7, 6], [3, 8]]])
当dim=0时,比较[[9, 2], [5, 4]]和[[7, 6], [3, 8]]两个张量之间的大小,即max([[9,2],[5,4]],[[7,6],[3,8]])=[max([9,2],[7,6]),max([5,4],[3,8])]\max([[9, 2], [5, 4]],[[7, 6], [3, 8]])=[\max([9, 2],[7, 6]),\max([5, 4],[3, 8])]max([[9,2],[5,4]],[[7,6],[3,8]])=[max([9,2],[7,6]),max([5,4],[3,8])],可得[[9,6],[5,8]]
当dim=1时,比较[9, 2]、[5, 4]、[7, 6]、[3, 8]四个张量(两对张量)之间的大小,即[max([9,2],[5,4]),max([7,6],[3,8])][\max([9,2],[5,4]),\max([7,6],[3,8])][max([9,2],[5,4]),max([7,6],[3,8])],可得[[9,4],[7,8]]
当dim=2时,比较[9, 2]、[5, 4]、[7, 6]、[3, 8]八个标量(四对张量)之间的大小,即[[max(9,2),max(5,4)],[max(7,6),max(3,8)]][[\max(9,2),\max(5,4)],[\max(7,6),\max(3,8)]][[max(9,2),max(5,4)],[max(7,6),max(3,8)]],可得[[9,5],[7,8]]
https://www.cnblogs.com/flix/p/11262606.html
视为某个维度塌缩
只有dim指定的维度是可变的,其他都是固定不变的。
dim即操作的方向,dim = 0,在行之间操作,列不变。理解成:同一列中每一行之间的比较或者操作,是每一行的比较,因为行是可变的。
https://blog.youkuaiyun.com/qq_41375609/article/details/106078474