unsqueeze() 方法在 PyTorch 中用于在指定的维度位置插入一个维度大小为 1 的新维度。
tips:
()内指定维度位置,‘0’表示第一个维度位置,以此类推‘1’ ‘2’ ‘3’.......
1.增加一个维度
import torch
# 创建一个形状为 [4] 的一维张量
x = torch.tensor([1, 2, 3, 4])
# 使用 unsqueeze 在第一个维度位置增加一个维度,结果形状变为 [1, 4]
x_unsqueezed = x.unsqueeze(0)
print(x.shape)
print(x)
print('*'*50)
print(x_unsqueezed.shape)
print(x_unsqueezed)
# 输出:
torch.Size([4])
tensor([1, 2, 3, 4])
**************************************************
torch.Size([1, 4])
tensor([[1, 2, 3, 4]])
2.在中间维度插入一个维度
# 创建一个形状为 [3, 4] 的二维张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 使用 unsqueeze 在第二个维度位置增加一个维度,结果形状变为 [3, 1, 4]
x_unsqueezed = x.unsqueeze(1)
print(x.shape)
print(x)
print('*'*50)
print(x_unsqueezed.shape)
print(x_unsqueezed)
# 输出
torch.Size([3, 4])
te