关于torch.unsqueeze与torch.squeeze

torch内置的unsqueeze方法是增加1的维度,squeeze和unsqueeze相反,是删除数据中1的维度,当指定数据时,删除指定位置的1的维度,当不指定参数时,删除数据中所有1的维度,具体操作如下:​​​​​​​

a =torch.ones(3,3)
b = torch.unsqueeze(a,1)#第二个参数指定在哪一个维度增加1
#b.shape为([3,1,3])
c = b.squeeze()
c.shape
#结果为([3,3])

除了squeeze和unsqueeze之外,还有另一种方法可以增加或删除维度,即利用[None]的形式,具体操作如下,当然也可以指定在某一维度中添加。使用这种方法如果想删除维度可以采用切片的方式将其取出,也可以使用squeeze方法。

a = torch.ones(3,3)
b = a[None]
#b的shape变为[1,3,3]
c = a[:,None,None]
#c的shape变为[3,3,1,1]

### PyTorch `unsqueeze` 方法的使用说明 #### 定义功能 `torch.unsqueeze(input, dim)` 是 PyTorch 提供的一个函数,用于向输入张量(tensor)指定位置增加一个新的维度。新增加的维度大小为 1[^1]。 #### 参数解释 - **input**: 输入的张量。 - **dim**: 新增维度的位置索引。可以是负数,表示从倒数方向计数。 #### 返回值 返回一个具有新维度的张量副本,原张量不会被修改[^4]。 --- #### 示例代码 ##### 示例 1:在第 0 维度添加新的维度 ```python import torch # 创建一个形状为 (3, 5) 的随机张量 x = torch.rand(3, 5) print("原始张量形状:", x.shape) # 输出: torch.Size([3, 5]) # 在第 0 维度插入新维度 y = torch.unsqueeze(x, 0) print("插入新维度后的张量形状:", y.shape) # 输出: torch.Size([1, 3, 5]) ``` 上述代码展示了如何通过 `unsqueeze` 将原本二维的张量扩展到三维。 --- ##### 示例 2:在不同维度上应用 `unsqueeze` ```python # 原始张量形状为 (3, 4) z = torch.randn(3, 4) # 插入新维度至第 1 维度 w = z.unsqueeze(1) print(w.shape) # 输出: torch.Size([3, 1, 4]) # 插入新维度至最后一维(等效于 -1) v = z.unsqueeze(-1) print(v.shape) # 输出: torch.Size([3, 4, 1]) ``` 此示例演示了可以在任意维度上插入新轴的操作[^3]。 --- #### 结合其他操作的应用场景 当需要调整张量形状以便其他张量广播或配合某些层结构时,`unsqueeze` 非常有用。例如,在自然语言处理领域中,可能需要用它来匹配模型输入的要求: ```python mask = torch.tensor([[True, False], [False, True]]) # 形状为 (2, 2) N = 3 # 添加新维度并扩展形状 expanded_mask = mask.unsqueeze(1).expand(-1, N, -1, -1) print(expanded_mask.shape) # 输出: torch.Size([2, 3, 2, 2]) ``` 这里先用 `unsqueeze` 扩展了一个维度,再利用 `expand` 进一步改变形状[^2]。 --- #### 注意事项 1. 如果尝试在一个已经存在的非单一维度处插入新轴,会引发错误。 2. 调整维度顺序通常可以通过 `permute()` 或者类似的重排列方法实现,而不仅仅是依赖 `unsqueeze` 和 `squeeze`。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值