torch.unsqueeze()与torch.squeeze()用法

本文介绍了PyTorch中的unsqueeze和squeeze函数,用于调整Tensor的维度。unsqueeze能在指定位置插入一个维度,而squeeze则删除尺寸为1的维度。示例中展示了如何在不同位置增加和移除维度,并解释了只有尺寸为1的维度才能被squeeze函数删除。
部署运行你感兴趣的模型镜像

函数描述:

unsqueeze(input, dim) → Tensor

作用:在指定位置插入一个维度,对数据维度进行扩充

input:输入的Tensor
dim:要插入的维度

a = torch.arange(6).reshape(2, 3)
print(a)
b = a.unsqueeze(1)#在第2维度加一维度
print(b)
print(b.shape)
>>>
tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[[0, 1, 2]],

        [[3, 4, 5]]])
torch.Size([2, 1, 3])

函数描述:

squeeze(input, dim) → Tensor

作用:对数据维度进行压缩

a = torch.arange(12).reshape(1, 2, 6)
print(a)
a1 = a.squeeze(0)#将第一个维度去掉 
print(a1)
print(a1.shape)
>>>
tensor([[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11]]])
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])
torch.Size([2, 6])
-----------------------
a2 = a.squeeze(-1)#最后一个维度并没有被去掉,因为不为1
print(a2)
print(a2.shape)
>>>
tensor([[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11]]])
torch.Size([1, 2, 6])

CJ:只有维度为1的才能被去掉

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

### PyTorch `unsqueeze` 方法的使用说明 #### 定义功能 `torch.unsqueeze(input, dim)` 是 PyTorch 提供的一个函数,用于向输入张量(tensor)指定位置增加一个新的维度。新增加的维度大小为 1[^1]。 #### 参数解释 - **input**: 输入的张量。 - **dim**: 新增维度的位置索引。可以是负数,表示从倒数方向计数。 #### 返回值 返回一个具有新维度的张量副本,原张量不会被修改[^4]。 --- #### 示例代码 ##### 示例 1:在第 0 维度添加新的维度 ```python import torch # 创建一个形状为 (3, 5) 的随机张量 x = torch.rand(3, 5) print("原始张量形状:", x.shape) # 输出: torch.Size([3, 5]) # 在第 0 维度插入新维度 y = torch.unsqueeze(x, 0) print("插入新维度后的张量形状:", y.shape) # 输出: torch.Size([1, 3, 5]) ``` 上述代码展示了如何通过 `unsqueeze` 将原本二维的张量扩展到三维。 --- ##### 示例 2:在不同维度上应用 `unsqueeze` ```python # 原始张量形状为 (3, 4) z = torch.randn(3, 4) # 插入新维度至第 1 维度 w = z.unsqueeze(1) print(w.shape) # 输出: torch.Size([3, 1, 4]) # 插入新维度至最后一维(等效于 -1) v = z.unsqueeze(-1) print(v.shape) # 输出: torch.Size([3, 4, 1]) ``` 此示例演示了可以在任意维度上插入新轴的操作[^3]。 --- #### 结合其他操作的应用场景 当需要调整张量形状以便其他张量广播或配合某些层结构时,`unsqueeze` 非常有用。例如,在自然语言处理领域中,可能需要用它来匹配模型输入的要求: ```python mask = torch.tensor([[True, False], [False, True]]) # 形状为 (2, 2) N = 3 # 添加新维度并扩展形状 expanded_mask = mask.unsqueeze(1).expand(-1, N, -1, -1) print(expanded_mask.shape) # 输出: torch.Size([2, 3, 2, 2]) ``` 这里先用 `unsqueeze` 扩展了一个维度,再利用 `expand` 进一步改变形状[^2]。 --- #### 注意事项 1. 如果尝试在一个已经存在的非单一维度处插入新轴,会引发错误。 2. 调整维度顺序通常可以通过 `permute()` 或者类似的重排列方法实现,而不仅仅是依赖 `unsqueeze` 和 `squeeze`。 ---
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值