PyTorch中的 squeeze
和 unsqueeze
操作
1. squeeze()
操作
torch.squeeze()
用于去除张量中大小为1的维度。例如,形状为 (1, 3, 1, 4)
的张量,经过 squeeze()
后将变为形状为 (3, 4)
的张量。
示例:
import torch
# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.randn(1, 3, 1, 4)
print("原始张量:", x.shape)
# 使用 squeeze 去除大小为 1 的维度
x_squeezed = x.squeeze()
print("squeeze 后的张量:", x_squeezed.shape)
输出:
原始张量: torch.Size([1, 3, 1, 4])
squeeze 后的张量: torch.Size([3, 4])
squeeze()
去除了大小为1的维度,如1
和1
。
2. unsqueeze()
操作
torch.unsqueeze()
用于在指定的位置添加一个大小为1的维度。例如,形状为 (3, 4)
的张量,经过 unsqueeze(0)
后将变为形状为 (1, 3, 4)
的张量。
示例:
import torch
# 创建一个形状为 (3, 4) 的张量
x = torch.randn(3, 4)
print("原始张量:", x.shape)
# 使用 unsqueeze 在第 0 维添加一个大小为 1 的维度
x_unsqueezed = x.unsqueeze(0)
print("unsqueeze 后的张量:", x_unsqueezed.shape)
输出:
原始张量: torch.Size([3, 4])
unsqueeze 后的张量: torch.Size([1, 3, 4])
unsqueeze(0)
在第一个维度位置添加了一个大小为1的维度,变成了(1, 3, 4)
。
总结
squeeze()
:移除所有大小为1的维度,或根据需要移除特定维度。unsqueeze()
:在指定的维度位置插入一个大小为1的维度。