Pytorch中常见函数1

函数:unfold(), view(), reshape(), permute(), transpose(), flatten(), cat(), chunk(), split(),
stack(), take(), tile(), unbind(), squeeze(), unsqueeze(), where(), full(), cumprod(), gather()


一、Tensor操作的函数

1.unfold()

unfold函数相当于一维滑动窗口的操作
unfold(dimension, size, step) -> Tensor
dimension->表示从哪个维度展开滑动
size->滑动窗口的大小
step->每次滑动的步长
对于一维的情况
代码如下(示例):

x = torch.arange(1., 5)
print(x)
x = x.unfold(0, 2, 1)   
print(x)
print(x.shape)

最终输出的x是3个大小为2的窗口
输出结果如下:

tensor([1., 2., 3., 4.])
tensor([[1., 2.],
        [2., 3.],
        [3., 4.]])
torch.Size([3, 2])

对于二维的情况
连续两次使用unfold()函数,就变成了一个2维滑动窗口的效果

x = torch.randn(4, 3)
print(x)
x = x.unfold(0, 2, 1).unfold(1, 2, 1)  # 获得了3×2个大小为2×2的窗口
print(x)
print(x.shape)

输出结果:

tensor([[-1.0882, -0.9543,  0.7692],
        [ 1.0062, -0.5342,  2.0376],
        [-1.2618, -0.2669,  0.0763],
        [ 0.8875, -1.6452, -0.7613]])
tensor([[[[-1.0882, -0.9543],
          [ 1.0062, -0.5342]],

         [[-0.9543,  0.7692],
          [-0.5342,  2.0376]]],


        [[[ 1.0062, -0.5342],
          [-1.2618, -0.2669]],

         [[-0.5342,  2.0376],
          [-0.2669,  0.0763]]],


        [[[-1.2618, -0.2669],
          [ 0.8875, -1.6452]],

         [[-0.2669,  0.0763],
          [-1.6452, -0.7613]]]])
torch.Size([3, 2, 2, 2])

2.view()

view()将tensor重塑为想要的shape。其操作是将张量展平成一维之后,再排列成想要的形状。
代码如下(示例):

x = torch.randn(4, 3)
print(x)
x = x.view((3, 4))
print(x)
x = x.view(-1)	# 将x张量展平
print(x)

# 变形后不知道其中一个维度的大小,可用-1表示
x = x.view(2, -1, 2)
print(x.shape)

tips:view()函数,从3×4的形状变到4×3,这个操作与转置操作不同。
输出结果:

tensor([[ 0.2126, -0.8827,  1.2574],
        [ 0.7362,  0.2389, -0.9048],
        [ 0.6516, -1.2950,  0.7889],
        [ 1.5935,  1.4653,  0.5844]])
tensor([[ 0.212
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值