squeeze()的函数定义:
torch.
squeeze
(input, dim=None, out=None) → Tensor
返回一个张量,其中所有大小为1的输入的维都已删除。
举个例子,如果输入张量的shape为(A×1×B×C×1×D) ,那么输出张量的shape是(A×B×C×D) .
如果指定了dim,则仅在给定维度上执行挤压操作。如果输入的形状为:(A×1×B),则squeeze(input,0)保持张量不变,但squeeze(input,1)会将张量压缩为形状(A×B)。
例子:
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
unsqueeze()的函数定义:
torch.
unsqueeze
(input, dim, out=None) → Tensor
返回在指定位置插入尺寸为1的新张量。
例子:
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1, 2, 3, 4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
[ 2],
[ 3],
[ 4]])
注意:unsqueeze()和squeeze()返回的张量和原张量共用存储空间,改变其中一个另外一个都会改变。