Pytorch中view()、squeeze()、unsqueeze()

原文出处:http://milletpu.com/2018/04/07/pytorch-view/

本篇博客主要向大家介绍Pytorch中view()、squeeze()、unsqueeze()函数,这些函数虽然简单,但是在神经网络编程总却经常用到,希望大家看了这篇博文能够把这些函数的作用弄清楚。

import torch

a=torch.Tensor(2,3)

a


3.8686e+25 9.1836e-39 1.2771e-40

9.0079e+15 1.6751e-37 2.9775e-41

[torch.FloatTensor of size 2x3]

view()函数作用是将一个多行的Tensor,拼接成一行。如果一个数为-1,则表示这个维度的大小会根据另外几个维度的大小自动计算,可以看到2*3=1*6

a.view(1,-1)
a



3.8686e+25 9.1836e-39 1.2771e-40 9.0079e+15 1.6751e-37 2.9775e-41

[torch.FloatTensor of size 1x6]

下面是torch中squeeze()和unsqueeze()两个函数。

先介绍squeeze函数

b=torch.Tensor(1,3)

b


3.3447e+30 6.1237e-43 5.9179e+32

[torch.FloatTensor of size 1x3]


b.squeeze(0)
b

3.3447e+30

6.1237e-43

5.9179e+32

[torch.FloatTensor of size 3]


b.squeeze(1)
b

3.3447e+30 6.1237e-43 5.9179e+32

[torch.FloatTensor of size 1x3]

 

squeeze中的参数0、1分别代表第一、第二维度,squeeze(0)表示如果第一维度值为1,则去掉,否则不变。故b的维度(1,3),可去掉1成(3),但不可去掉3。如果还有更高的维度,则squeeze()中的参数可以逐渐增大。

接下来是unsqueeze函数

c=torch.Tensor(3)
c


7.5589e+28

5.2839e-11

1.8888e+31

[torch.FloatTensor of size 3]

c.unsqueeze(0)
c



7.5589e+28 5.2839e-11 1.8888e+31

[torch.FloatTensor of size 1x3]

c.unsqueeze(1)
c


7.5589e+28

5.2839e-11

1.8888e+31

[torch.FloatTensor of size 3x1]

unsqueeze()与squeeze()作用相反。参数代表的意思相同。unsqueeze中参数的输入范围取决于要处理的向量的维数,比如输入的是一个二维向量,那么参数范围是(-3,2),输入的是一个三维向量,那么参数范围是(-4,3),负数表示倒序。加入的维度总是在指定位置的前面。

  
### PyTorch 中 Tensor 的 `squeeze`、`reshape` 和 `unsqueeze` 函数详解 #### 1. `torch.squeeze()` 函数 `squeeze()` 函数用于移除张量中尺寸为 1 的维度。如果指定了某个特定的维度,则仅当该维度大小为 1 时才会被移除;如果没有指定任何维度参数,则会自动移除所有尺寸为 1 的维度。 - **语法**: ```python torch.squeeze(input, dim=None) ``` - **功能说明**: - 如果不提供 `dim` 参数,那么所有尺寸为 1 的维度都会被删除。 - 如果提供了 `dim` 参数,只有这个维度上的大小为 1 才会被删除,其他情况下不会有任何改变。 - **示例代码**: ```python import torch a = torch.tensor([[[[1], [2], [3]]]]) print("原始张量:", a) print("原始形状:", a.shape) # 输出: (1, 3, 1) b = torch.squeeze(a) print("经过 squeeze 后的张量:", b) print("经过 squeeze 后的形状:", b.shape) # 输出: (3,) c = torch.squeeze(a, dim=0) print("沿第 0 维度挤压后的张量:", c) print("沿第 0 维度挤压后的形状:", c.shape) # 输出: (3, 1) ``` --- #### 2. `torch.unsqueeze()` 函数 `unsqueeze()` 函数的作用是在指定位置增加一个新的维度,新维度的大小默认为 1- **语法**: ```python torch.unsqueeze(input, dim) ``` - **功能说明**: - 新增的维度由 `dim` 参数决定。 - 增加的新维度可以位于任意位置(前、中间或后)。 - **示例代码**: ```python d = torch.tensor([1, 2, 3]) print("原始张量:", d) print("原始形状:", d.shape) # 输出: (3,) e = torch.unsqueeze(d, dim=0) print("新增第 0 维度后的张量:", e) print("新增第 0 维度后的形状:", e.shape) # 输出: (1, 3) f = torch.unsqueeze(d, dim=1) print("新增第 1 维度后的张量:", f) print("新增第 1 维度后的形状:", f.shape) # 输出: (3, 1) ``` --- #### 3. `torch.reshape()` 函数 `reshape()` 函数允许重新定义张量的形状,只要满足总元素数量不变即可。 - **语法**: ```python input.reshape(shape) ``` - **功能说明**: - 可以通过 `-1` 自动推导某一个维度的大小。 - 不同于 `.view()` 方法,`.reshape()` 支持跨内存布局的操作,因此即使原张量不是连续存储也可以成功调用。 - **示例代码**: ```python g = torch.arange(6) print("原始张量:", g) print("原始形状:", g.shape) # 输出: (6,) h = g.reshape(2, 3) print("重塑后的张量:", h) print("重塑后的形状:", h.shape) # 输出: (2, 3) i = g.reshape(-1, 2) print("使用 -1 推导后的张量:", i) print("使用 -1 推导后的形状:", i.shape) # 输出: (3, 2) ``` --- ### 总结对比 | 功能 | 描述 | |--------------|----------------------------------------------------------------------------------------| | `squeeze()` | 移除尺寸为 1 的维度[^1] | | `unsqueeze()`| 在指定位置插入新的维度,其大小固定为 1[^5] | | `reshape()` | 更改张量的整体结构,前提是保持总的元素数目一致[^3] | 上述三个函数均适用于调整张量的维度,在实际应用中可以根据需求灵活组合它们来完成复杂的矩阵变换操作。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值