目录
4.2 tensor.unsqueeze指定dim插入新维度
Pytorch张量维度变化是在构建模型过程中常用且重要的操作,本文从实际应用触发,详细介绍常用的维度变化方法,这些方法包含view、reshap、squeeze、unsqueeze、transpose等。
1 view函数
Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。
view函数的操作对象是Tensor类型,返回的对象类型也为Tensor类型
def view(self, *size: _int) -> Tensor: ...
更便于理解的表示形式:
view(参数a,参数b,…),其中,总的参数个数表示将张量重构后的维度。
1.1 指定变换后的维度
通过手工指定,将一个一维tensor变换为3*8维的tensor
import torch
a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
a2 = a1.view(3, 8)
print(a1)
print(a2)
print(a1.shape)
print(a2.shape)
运行程序显示如下:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24])
tensor([[ 1, 2, 3, 4, 5, 6, 7, 8],
[ 9, 10, 11, 12, 13, 14, 15, 16],
[17, 18, 19, 20, 21, 22, 23, 24]])
torch.Size([24])
torch.Size([3, 8])
1.2 自动推理变换后的维度
如果某个参数为-1,则表示该维度取决于其它维度,由Pytorch自己补充
import torch
a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
a4 = a3.view(4, -1)
a5 = a3.view(2, 3, -1)
a6 = a3.view(-1, 3, 2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a3.shape)
print(a4.shape)
print(a5.shape)
print(a6.shape)
运行程序显示如下:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24]])
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]])
tensor([[[ 1, 2],
[ 3, 4],
[ 5, 6]],
[[

最低0.47元/天 解锁文章
3504

被折叠的 条评论
为什么被折叠?



