pytorch --- tensor.squeeze(dim)和unsqueeze(dim)

本文详细介绍了PyTorch中张量的squeeze和unsqueeze函数的使用方法,通过实例展示了如何压缩和增加张量的维度,这对于数据预处理和模型构建至关重要。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tensor.squeeze(dim)

作用: 如果dim指定的维度的值为1,则将该维度删除,若指定的维度值不为1,则返回原来的tensor

例子:

x = torch.rand(2,1,3)
print(x)
print(x.squeeze(1))
print(x.squeeze(2))

输出:

tensor([[[0.7031, 0.7173, 0.0606]],

        [[0.6884, 0.4072, 0.0516]]])
        
tensor([[0.7031, 0.7173, 0.0606],
        [0.6884, 0.4072, 0.0516]])

tensor([[[0.7031, 0.7173, 0.0606]],

        [[0.6884, 0.4072, 0.0516]]])

如上结果所示:x.shape=[2, 1, 3] , 第一维度的值为1, 因此x.squeeze(dim=1)的输出会将第一维度去掉,其输出shape=[2,3], 第二维度值不为1, 因此x.squeeze(dim=2)输出tensor的shape不变

tensor.unsqueeze(dim)

这个函数主要是对数据维度进行扩充。给指定位置加上维数为1的维度,比如原本有个三行的数据(3,),在0的位置加了一维就变成一行三列(1,3)。还有一种形式就是b=torch.squeeze(tensor,dim) 就是在tensor中指定位置 dim 加上一个维数为1的维度

例子:

x = torch.rand(2,3)
print(x)
print("x.shape:", x.shape)
y = torch.unsqueeze(x, 1)
print(y)
print("y.shape:", y.shape)
z = x.unsqueeze(2)
print(z)
print("z.shape:", z.shape)

输出:

tensor([[0.1255, 0.7249, 0.5253],
        [0.9247, 0.4592, 0.3944]])
x.shape: torch.Size([2, 3])


tensor([[[0.1255, 0.7249, 0.5253]],

        [[0.9247, 0.4592, 0.3944]]])
y.shape: torch.Size([2, 1, 3])


tensor([[[0.1255],
         [0.7249],
         [0.5253]],

        [[0.9247],
         [0.4592],
         [0.3944]]])
z.shape: torch.Size([2, 3, 1])
[Finished in 2.6s]
### PyTorch 中 `tensor.squeeze(dim)` 函数详解 `tensor.squeeze(dim)` 是一种用于调整张量形状的操作,其主要功能是从指定维度移除大小为 1 的维度。如果未提供具体维度,则会自动移除所有大小为 1 的维度[^1]。 #### 基本语法 函数签名如下: ```python torch.squeeze(input, dim=None, out=None)Tensor ``` - 参数说明: - `input`: 输入张量。 - `dim`: 可选参数,表示要操作的具体维度。如果不指定该参数,则会对整个张量进行处理并移除所有大小为 1 的维度。 - `out`: 输出张量的可选项,默认为空。 当指定了某个特定维度时,只有这个维度上的大小为 1 才会被移除;其他任何尺寸不等于 1 的维度都不会受到影响。 #### 示例代码展示 以下是几个具体的使用案例: ```python import torch # 创建一个具有多余单一维度的张量 example_tensor = torch.tensor([[[1], [2], [3]]]) print(f"Original shape: {example_tensor.shape}") # Output Shape: torch.Size([1, 3, 1]) # 移除所有的单维条目 squeezed_all = example_tensor.squeeze() print(f"Squeeze all dims result shape: {squeezed_all.shape}") # Output Shape: torch.Size([3]) # 特定维度上执行squeeze操作 squeezed_dim_0 = example_tensor.squeeze(0) print(f"Squeeze on dim=0 result shape: {squeezed_dim_0.shape}") # Output Shape: torch.Size([3, 1]) try: squeezed_dim_1 = example_tensor.squeeze(1) except RuntimeError as e: print(e) # 尝试挤压非单位长度的维度将会抛出错误 ``` 上述代码展示了如何通过调用 `.squeeze()` 方法来改变原始张量结构,并验证了不同情况下可能产生的效果以及异常情况下的行为表现。 #### 应注意的地方 尽管可以自由地利用此方法简化模型输入/输出的数据形式,但在实际应用过程中需要注意以下几点事项: - 如果尝试压缩那些并非真正意义上的“冗余”轴线(即它们的实际数值并不恒等于一),那么程序运行期间就会触发相应的报错提示; - 当前版本支持动态图模式下多次反向传播计算场景中的梯度累积机制设置问题可以通过适当配置 retain_graph 来实现正常工作流程管理[^3]。 #### 结论 综上所述,在深度学习框架 PyTorch 当中,`tensor.squeeze(dim)` 提供了一种便捷的方式来消除不必要的低阶特征映射空间表达方式从而优化整体存储效率的同时也便于后续进一步加工处理阶段衔接顺畅过渡。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值