一,官方文档
torch.squeeze — PyTorch 1.11.0 documentation
https://pytorch.org/docs/stable/generated/torch.squeeze.html?highlight=squeeze#torch.squeezetorch.unsqueeze — PyTorch 1.11.0 documentation
https://pytorch.org/docs/stable/generated/torch.unsqueeze.html#torch-unsqueeze二,代码理解
import torch
################################################
#创建一个2*3的tensor
a = torch.zeros([2,3])
print(a)#tensor([[0., 0., 0.],
# [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])
################################################
#在第0个位置增加一个维度
a = a.unsqueeze(0)
print(a)#tensor([[[0., 0., 0.],
# [0., 0., 0.]]])
print(a.shape)#torch.Size([1, 2, 3])
################################################
#在第0个位置减少一个维度 前提是0处的维度大小是1
a = a.squeeze(0)
print(a)#tensor([[0., 0., 0.],
# [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])
################################################
#0处的维度不是1,所以不生效
a = a.squeeze(0)
print(a)#tensor([[0., 0., 0.],
# [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])
可以看出unsqueeze(dim)函数就是让tensor在dim处增加一个维度;
而squeeze(dim)函数就是让tensor在dim处减少一个维度;但前提是dim处的维度是1,否则squeeze(dim)函数不会生效。