1. view/reshape
view reinterprets a tensor's shape without copying data, and the new shape must contain exactly the same total number of elements; reshape does the same but also works on non-contiguous tensors (copying when necessary).
In [1]: import torch
In [2]: a = torch.rand(4,1,28,28)
In [3]: a.shape
Out[3]: torch.Size([4, 1, 28, 28])
In [4]: a.view(4,28*28)
Out[4]:
tensor([[0.3060, 0.2680, 0.3763, ..., 0.6596, 0.5645, 0.8772],
        [0.7751, 0.9969, 0.0864, ..., 0.7230, 0.4374, 0.1659],
        [0.5146, 0.5350, 0.1214, ..., 0.2056, 0.2646, 0.8539],
        [0.5737, 0.5637, 0.2420, ..., 0.7731, 0.6198, 0.6113]])
In [5]: a.view(4,28*28).shape
Out[5]: torch.Size([4, 784])
In [6]: a.view(4*28,28).shape
Out[6]: torch.Size([112, 28])
In [7]: a.view(4*1,28,28).shape
Out[7]: torch.Size([4, 28, 28])
In [8]: b = a.view(4,784)
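A minimal sketch (reusing the same [4, 1, 28, 28] batch) of the two rules behind the calls above: view only succeeds when the new shape has exactly the same number of elements, and reshape behaves like view but also copes with non-contiguous inputs:

import torch

a = torch.rand(4, 1, 28, 28)        # a batch of 4 single-channel 28x28 images

flat = a.view(4, 28 * 28)           # 4*1*28*28 == 4*784, so this works
restored = flat.view(4, 1, 28, 28)  # the original grouping must be remembered to undo it
print(torch.equal(a, restored))     # True

try:
    a.view(4, 28 * 28 + 1)          # element counts differ -> RuntimeError
except RuntimeError as e:
    print(e)

flat2 = a.reshape(4, -1)            # -1 lets PyTorch infer the size (784)
print(flat2.shape)                  # torch.Size([4, 784])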
2. squeeze/unsqueeze
unsqueeze inserts a new dimension of size 1 at the given index; the underlying data is unchanged.
In [12]: a.unsqueeze(0).shape
Out[12]: torch.Size([1, 4, 1, 28, 28])
In [13]: a.shape
Out[13]: torch.Size([4, 1, 28, 28])
In [14]: a.unsqueeze(0).shape
Out[14]: torch.Size([1, 4, 1, 28, 28])
In [15]: a.unsqueeze(-1).shape
Out[15]: torch.Size([4, 1, 28, 28, 1])
In [16]: a.unsqueeze(4).shape
Out[16]: torch.Size([4, 1, 28, 28, 1])
In [17]: a.unsqueeze(-4).shape
Out[17]: torch.Size([4, 1, 1, 28, 28])
In [18]: a.unsqueeze(-5).shape
Out[18]: torch.Size([1, 4, 1, 28, 28])
In [19]: a.unsqueeze(5).shape
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-19-b54eab361a50> in <module>
----> 1 a.unsqueeze(5).shape
RuntimeError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
In [20]: a = torch.tensor([1.2,2.3])
In [21]: a.unsqueeze(-1).shape
Out[21]: torch.Size([2, 1])
In [22]: a.shape
Out[22]: torch.Size([2])
In [23]: a.unsqueeze(-1)
Out[23]:
tensor([[1.2000],
        [2.3000]])
In [24]: a
Out[24]: tensor([1.2000, 2.3000])
In [25]: a.unsqueeze(0)
Out[25]: tensor([[1.2000, 2.3000]])
In [26]: a.unsqueeze(0).shape
Out[26]: torch.Size([1, 2])
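A short sketch of what unsqueeze is typically used for: lining up a per-channel tensor with a 4-D feature map so that broadcasting works (the [4, 32, 14, 14] feature map and [32] bias here are assumed for illustration):

import torch

f = torch.rand(4, 32, 14, 14)       # feature map: [batch, channel, H, W]
bias = torch.rand(32)               # one value per channel

# [32] -> [1, 32] -> [1, 32, 1] -> [1, 32, 1, 1]
b = bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
print(b.shape)                      # torch.Size([1, 32, 1, 1])

out = f + b                         # broadcasts over batch, H and W
print(out.shape)                    # torch.Size([4, 32, 14, 14])

Valid dim values for unsqueeze on an n-dimensional tensor are [-n-1, n], which is exactly why a.unsqueeze(5) above fails on a 4-D tensor.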
squeeze removes size-1 dimensions: squeeze() with no argument drops all of them, and squeeze(dim) drops that dimension only if its size is 1 (otherwise the tensor is returned unchanged).
In [27]: b = torch.rand(1,32,1,1)
In [28]: b.shape
Out[28]: torch.Size([1, 32, 1, 1])
In [29]: b.squeeze().shape
Out[29]: torch.Size([32])
In [30]: b.squeeze(0).shape
Out[30]: torch.Size([32, 1, 1])
In [31]: b.squeeze(-1).shape
Out[31]: torch.Size([1, 32, 1])
In [32]: b.squeeze(1).shape
Out[32]: torch.Size([1, 32, 1, 1])
In [33]: b.squeeze(-1).shape
Out[33]: torch.Size([1, 32, 1])
In [34]: b.squeeze(-4).shape
Out[34]: torch.Size([32, 1, 1])
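A quick sketch (same [1, 32, 1, 1] tensor) of the rule the calls above demonstrate, plus the fact that squeeze and unsqueeze undo each other:

import torch

b = torch.rand(1, 32, 1, 1)

print(b.squeeze().shape)    # torch.Size([32])          -- all size-1 dims removed
print(b.squeeze(0).shape)   # torch.Size([32, 1, 1])    -- dim 0 had size 1
print(b.squeeze(1).shape)   # torch.Size([1, 32, 1, 1]) -- dim 1 has size 32, unchanged

# round trip: squeeze everything, then rebuild the [1, 32, 1, 1] layout
c = b.squeeze().unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
print(torch.equal(b, c))    # True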
3. transpose/t/permute
transpose swaps exactly two dimensions per call (.t() is the shorthand for transposing a 2-D matrix);
permute can rearrange all dimensions in a single call.
(In the session below, a is assumed to have been re-created as a [4, 3, 32, 32] tensor, e.g. a = torch.rand(4,3,32,32).)
In [15]: a.shape
Out[15]: torch.Size([4, 3, 32, 32])
In [16]: a1 = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-16-262e84a7fdcb> in <module>
----> 1 a1 = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at c:\a\w\1\s\windows\pytorch\aten\src\th\generic/THTensor.cpp:213
In [17]: a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
In [18]: a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
In [19]: a1.shape,a2.shape
Out[19]: (torch.Size([4, 3, 32, 32]), torch.Size([4, 3, 32, 32]))
In [20]: torch.all(torch.eq(a,a1))
Out[20]: tensor(0, dtype=torch.uint8)
In [21]: torch.all(torch.eq(a,a2))
Out[21]: tensor(1, dtype=torch.uint8)
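Why a1 differs from a while a2 matches: transpose(1,3) changes the memory order, so flattening that layout to [4, 3*32*32] and reading it straight back as [4, 3, 32, 32] mixes the channel and spatial axes; a2 first restores the transposed shape [4, 32, 32, 3] and then transposes back, so the original element order is recovered. A small sketch of the same check (smaller shapes chosen here just to keep it cheap):

import torch

a = torch.rand(2, 3, 4, 4)

# wrong: flatten the transposed layout, then pretend it is still [N, C, H, W]
a1 = a.transpose(1, 3).contiguous().view(2, 3 * 4 * 4).view(2, 3, 4, 4)

# right: restore the transposed shape [N, W, H, C] first, then transpose back
a2 = a.transpose(1, 3).contiguous().view(2, 3 * 4 * 4).view(2, 4, 4, 3).transpose(1, 3)

print(torch.equal(a, a1))   # False: the data was reinterpreted, not restored
print(torch.equal(a, a2))   # True: the original order comes back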
In [22]: a = torch.rand(4,3,28,28)
In [23]: a.transpose(1,3).shape
Out[23]: torch.Size([4, 28, 28, 3])
In [24]: b = torch.rand(4,3,28,32)
In [25]: b.transpose(1,3).shape
Out[25]: torch.Size([4, 32, 28, 3])
In [26]: b.transpose(1,3).transpose(1,2).shape
Out[26]: torch.Size([4, 28, 32, 3])
In [27]: b.permute(0,2,3,1).shape
Out[27]: torch.Size([4, 28, 32, 3])
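A brief sketch (same [4, 3, 28, 32] shape) showing that permute(0,2,3,1) does in one call what the two chained transposes do above, e.g. going from [B, C, H, W] to [B, H, W, C]; like transpose, it returns a non-contiguous view, so .contiguous() is needed before .view():

import torch

b = torch.rand(4, 3, 28, 32)                # [B, C, H, W]

nhwc = b.permute(0, 2, 3, 1)                # [B, H, W, C] in one call
same = b.transpose(1, 3).transpose(1, 2)    # the two-step transpose route
print(torch.equal(nhwc, same))              # True

flat = nhwc.contiguous().view(4, -1)        # contiguous() first, then view
print(flat.shape)                           # torch.Size([4, 2688])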
4. expand/repeat
expand broadcasts size-1 dimensions to a larger size without allocating new memory; it only changes how the existing data is indexed.
repeat actually copies the data the given number of times along each dimension.
In [35]: a = torch.rand(4,32,14,14)
In [36]: b = torch.rand(1,32,1,1)
In [37]: a.shape
Out[37]: torch.Size([4, 32, 14, 14])
In [38]: b.shape
Out[38]: torch.Size([1, 32, 1, 1])
In [39]: b.expand(4,32,14,14).shape  # broadcast the size-1 dims; this returns a view, no data is copied
Out[39]: torch.Size([4, 32, 14, 14])
In [40]: b.expand(-1,32,-1,-1).shape  # -1 keeps the existing size of that dimension
Out[40]: torch.Size([1, 32, 1, 1])
In [41]: b.expand(-1,32,-1,-4).shape  # bug: an invalid -4 is not rejected and leaks into the reported shape
Out[41]: torch.Size([1, 32, 1, -4])
In [42]: b.shape
Out[42]: torch.Size([1, 32, 1, 1])
In [43]: b.repeat(4,1,1,1).shape  # repeat dim 0 four times
Out[43]: torch.Size([4, 32, 1, 1])
In [44]: b.repeat(4,32,1,1).shape  # dim 0 x4, dim 1 x32, so the channel dimension becomes 32*32=1024
Out[44]: torch.Size([4, 1024, 1, 1])
In [45]: b.repeat(4,1,32,32).shape  # dim 0 x4, dim 2 x32, dim 3 x32
Out[45]: torch.Size([4, 32, 32, 32])
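A closing sketch (same a and b as above) of the practical difference: expand returns a view that shares b's storage, while repeat materializes a full copy, so expand is the cheaper choice whenever a real copy is not required; ordinary broadcasting performs the expand implicitly:

import torch

a = torch.rand(4, 32, 14, 14)
b = torch.rand(1, 32, 1, 1)

e = b.expand(4, 32, 14, 14)           # view: no new memory allocated
r = b.repeat(4, 1, 14, 14)            # copy: 4*32*14*14 elements materialized

print(e.shape, r.shape)               # both torch.Size([4, 32, 14, 14])
print(e.data_ptr() == b.data_ptr())   # True: expand shares b's storage
print(r.data_ptr() == b.data_ptr())   # False: repeat owns its own storage

out = a + b                           # broadcasting expands b implicitly
print(out.shape)                      # torch.Size([4, 32, 14, 14])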