【PyTorch】torch.squeeze() 函数:去除张量中维度为 1 的维度

torch.squeeze() 函数详解

torch.squeeze() 是 PyTorch 中用于 去除张量中维度为 1 的维度(也称为“单维”) 的函数。它常用于去除多余的维度,使张量变得更紧凑,便于后续处理。


1. 函数原型

torch.squeeze(input, dim=None) → Tensor
参数说明
input输入张量
dim (可选)只在指定的维度上尝试去除,如果该维度不是 1,则不处理
返回值去除指定维度(或所有为 1 的维度)后的张量(新张量)

2. 基本功能

  • dim 参数时:移除所有 shape 为 1 的维度。
  • dim 参数时:仅当该维度为 1 时才会被移除。

3. 示例:基本用法

3.1 无 dim(移除所有单维)

import torch

x = torch.zeros(1, 3, 1, 5)
print("原始形状:", x.shape)  # torch.Size([1, 3, 1, 5])

y = torch.squeeze(x)
print("squeeze 后形状:", y.shape)  # torch.Size([3, 5])

3.2 指定 dim(只尝试去除该维度)

x = torch.zeros(1, 3, 1, 5)

# 只尝试去除 dim=2(成功)
y = torch.squeeze(x, dim=2)
print(y.shape)  # torch.Size([1, 3, 5])

# 尝试去除 dim=1(失败,因为 dim=1 的值是 3 ≠ 1)
z = torch.squeeze(x, dim=1)
print(z.shape)  # torch.Size([1, 3, 1, 5])

4. 常见用途

场景示例
移除 batch 中 shape 为 [batch_size, 1] 的维度y = y.squeeze(1)
unsqueeze() 搭配使用x = x.unsqueeze(0).squeeze()
简化模型输出形状[batch_size, 1] 输出变为 [batch_size]

5. 对比:unsqueeze()

操作功能
torch.unsqueeze(x, dim)在指定位置添加一个维度(1
torch.squeeze(x, dim)去除指定位置的维度(如果为 1

示例:

x = torch.tensor([1.0, 2.0, 3.0])     # shape: (3,)
x = x.unsqueeze(0)                   # shape: (1, 3)
x = x.squeeze()                      # shape: (3,)

6. 注意事项

  • torch.squeeze() 不会修改原张量,返回的是新张量。
  • 如果不确定是否会存在维度为 1 的轴,推荐使用 dim 参数更安全

7. 总结

特性说明
功能删除张量中所有或指定维度为 1 的维度
常用场景输出降维、模型 reshape、简化计算
返回值新张量,原张量不变
推荐搭配unsqueeze(),用于调整维度

torch.squeeze() 是 PyTorch 中非常常见的张量变形操作,尤其在处理模型输入输出时非常有用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值