【无标题】

函数

torch.roll(input, shifts, dims=None)

作用

  • 沿给定维度滚动张量输入。
  • 超出最后一个位置的元素在第一个位置重新引入。
  • 如果 dims 为 None,则张量将在滚动前被展平,然后恢复到原始形状。

参数

  • input (Tensor) : 输入张量
  • shifts (int or tuple of python: ints) : 张量中元素(按照维度)移动的位置数。
    • 如果 shifts 是元组,则 dims 必须与 input 的 shape 大小一致,每个维度都会移动对应的值
  • dims (int or tuple of python: ints) :沿着 dims 指定的维度进行滚动

举例 1

import torch

x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
print(x)

x1 = torch.roll(x, 1)
print(x1)

在这里插入图片描述
注 : 因为函数中没有指定参数 dim,所以操作是会先将 x 展平为 [1, 2, 3, 4, 5, 6, 7, 8] , 移动 1 位,变成 [8, 1, 2, 3, 4, 5, 6, 7] ,再 reshape 成原来的尺寸,得到 tensor([[8, 1], [2, 3], [4, 5], [6, 7]])


举例 2

import torch

x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
print(x)
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

x2 = torch.roll(x, 1, 0)
print(x2)
# tensor([[7, 8],
#         [1, 2],
#         [3, 4],
#         [5, 6]])

x3 = torch.roll(x, -1, 0)
print(x3)
# tensor([[3, 4],
#         [5, 6],
#         [7, 8],
#         [1, 2]])

x4 = torch.roll(x, shifts=(2, 1), dims=(0, 1))
print(x4)
# tensor([[6, 5],
#         [8, 7],
#         [2, 1],
#         [4, 3]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Enzo 想砸电脑

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值