torch.squeeze()
函数详解
torch.squeeze()
是 PyTorch 中用于 去除张量中维度为 1 的维度(也称为“单维”) 的函数。它常用于去除多余的维度,使张量变得更紧凑,便于后续处理。
1. 函数原型
torch.squeeze(input, dim=None) → Tensor
参数 | 说明 |
---|
input | 输入张量 |
dim (可选) | 只在指定的维度上尝试去除,如果该维度不是 1,则不处理 |
返回值 | 去除指定维度(或所有为 1 的维度)后的张量(新张量) |
2. 基本功能
- 无
dim
参数时:移除所有 shape 为 1
的维度。 - 有
dim
参数时:仅当该维度为 1 时才会被移除。
3. 示例:基本用法
3.1 无 dim
(移除所有单维)
import torch
x = torch.zeros(1, 3, 1, 5)
print("原始形状:", x.shape)
y = torch.squeeze(x)
print("squeeze 后形状:", y.shape)
3.2 指定 dim
(只尝试去除该维度)
x = torch.zeros(1, 3, 1, 5)
y = torch.squeeze(x, dim=2)
print(y.shape)
z = torch.squeeze(x, dim=1)
print(z.shape)
4. 常见用途
场景 | 示例 |
---|
移除 batch 中 shape 为 [batch_size, 1] 的维度 | y = y.squeeze(1) |
与 unsqueeze() 搭配使用 | x = x.unsqueeze(0).squeeze() |
简化模型输出形状 | 把 [batch_size, 1] 输出变为 [batch_size] |
5. 对比:unsqueeze()
操作 | 功能 |
---|
torch.unsqueeze(x, dim) | 在指定位置添加一个维度(1 ) |
torch.squeeze(x, dim) | 去除指定位置的维度(如果为 1 ) |
示例:
x = torch.tensor([1.0, 2.0, 3.0])
x = x.unsqueeze(0)
x = x.squeeze()
6. 注意事项
torch.squeeze()
不会修改原张量,返回的是新张量。- 如果不确定是否会存在维度为 1 的轴,推荐使用
dim
参数更安全。
7. 总结
特性 | 说明 |
---|
功能 | 删除张量中所有或指定维度为 1 的维度 |
常用场景 | 输出降维、模型 reshape、简化计算 |
返回值 | 新张量,原张量不变 |
推荐搭配 | unsqueeze() ,用于调整维度 |
torch.squeeze()
是 PyTorch 中非常常见的张量变形操作,尤其在处理模型输入输出时非常有用。