在PyTorch中,transpose()是一种操作,它交换张量中两个指定维度的位置。实现这一点的关键在于不实际移动数据,而是通过改变张量的元数据(包括步长(stride)和尺寸(size))来达到效果。
举例来说,假设我们有一个形状为(3, 4)的二维张量,其内存布局为行优先(row-major)即C风格的。当我们对这个张量执行transpose(0, 1)操作时,我们期望该张量行变成列,列变成行,即得到一个形状为(4,3)的新视图。
这是通过以下步骤完成的:
-
改变尺寸:改变
size元数据,使得原本第一个维度(行)的大小与第二个维度(列)的大小交换。 -
改变步长:步长(stride)是一个数组,指示了在每个维度上移动一个元素需要跳过的内存位置数。执行
transpose()时,交换了两个维度的步长。在行优先存储的张量中,行的步长通常比列的步长大。 -
不移动数据:实际上数据并没有在内存中移动,只是改变了在这块内存空间上的解释方式。
以下是一个简单的示例:
import torch
# 创建一个 3x4 的张量
x = torch.arange(12

文章详细介绍了PyTorch中的transpose()函数,它通过改变张量的尺寸和步长实现维度交换,无需实际移动数据,提高性能。示例展示了如何在PyTorch中使用transpose()创建新的张量视图,且底层C++实现依赖于ATen库优化内存操作。
最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



