pytorch的squeeze()和unsqueeze()函数理解

一,官方文档 

torch.squeeze — PyTorch 1.11.0 documentationicon-default.png?t=M4ADhttps://pytorch.org/docs/stable/generated/torch.squeeze.html?highlight=squeeze#torch.squeezetorch.unsqueeze — PyTorch 1.11.0 documentationicon-default.png?t=M4ADhttps://pytorch.org/docs/stable/generated/torch.unsqueeze.html#torch-unsqueeze二,代码理解

import torch

################################################
#创建一个2*3的tensor
a = torch.zeros([2,3])
print(a)#tensor([[0., 0., 0.],
        #       [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])

################################################
#在第0个位置增加一个维度
a = a.unsqueeze(0)
print(a)#tensor([[[0., 0., 0.],
        #       [0., 0., 0.]]])
print(a.shape)#torch.Size([1, 2, 3])

################################################
#在第0个位置减少一个维度 前提是0处的维度大小是1
a = a.squeeze(0)
print(a)#tensor([[0., 0., 0.],
        #       [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])

################################################
#0处的维度不是1,所以不生效
a = a.squeeze(0)
print(a)#tensor([[0., 0., 0.],
        #       [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])

可以看出unsqueeze(dim)函数就是让tensor在dim处增加一个维度;

而squeeze(dim)函数就是让tensor在dim处减少一个维度;但前提是dim处的维度是1,否则squeeze(dim)函数不会生效。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值