numpy的reshape和transpose机制解释

本文详细解析了NumPy数组ndarray的base和strides属性,探讨了reshape和transpose高效性的原因,指出它们不涉及内存重排,仅改变形状或strides。通过实例演示了base、strides在reshape和transpose中的作用,揭示了shape和strides在操作中的关键作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

reshape和transpose都是非常高效的算子,究其原因,是因为二者均没有在内存中重新排列数据,只是对数据的shape或strides等信息进行了改变。下面分别简介。

ndarray的base和strides属性

为了更好地理解reshape和transpose算子,需要对ndarray的shape, base, strides三个属性有所了解,其中shape很容易理解,就不多说了,下面简单介绍一下base和strides。

base

base参考:https://numpy.org/doc/stable/reference/generated/numpy.ndarray.base.html

如果一个ndarray是通过其他ndarray经过某种操作创建出来的,那么其base就会指向最初的源头。
比如下面例子中,ab, c的源头,所以b.base 和 c.base 都等于a,而a本身没有base,所以是None。

import numpy as np

a = np.array([0, 1, 2, 3, 4, 5])
print(a)  # ==> [0 1 2 3 4 5]
print(a.base)  # ==> None

b = a.reshape([2, 3])
print(b)  # ==>
# [[0 1 2]
#  [3 4 5]]
print(b.base)  # ==> [0 1 2 3 4 5]

c = b.transpose([1, 0])
print(c)  # ==>
# [[0 3]
#  [1 4]
#  [2 5]]
print(c.base)  # ==> [0 1 2 3 4 5]

strides

strides参考:https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html

ndarray的每一个维度(axis)都有一个strides,表示从数组在某个维度进行遍历的内存偏移量

比如在下面的例子中,数组a三个维度的strides分别是(48, 16, 4),意思是:

  • a[0, 0, 0]a[0, 0, 1] = 1的内存偏移量是4字节,1个int型数字是4字节
  • a[0, 0, 0]a[0, 1, 0] = 4的内存偏移量是16字节,因为需要偏移4个int型数字
  • a[0, 0, 0]a[1, 0, 0] = 12的内存偏移量是48字节,因为需要偏移12个int型数字
import numpy as np

a = np.arange(24).reshape([2, 3, 4])

print(a.strides)  # ==> (48, 16, 4)
print(a)  # ==>
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
#
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

reshape

reshape仅仅只是改变了数组的shape属性,比如把shape从 ( 4 , ) (4,) (4,)改成 ( 2 , 2 ) (2,2) (2,2)。通过下面的测试代码,可以明白reshape的下列性质:

  • 如果我们从最后一个维度开始,依次向前循环打印数组的话,会发现无论怎么样reshape,数组打印的顺序不会发生任何变化。也就是说无论reshape多少次,数组打印顺序不变。
  • 类似于python的浅拷贝,reshape之后,尽管变量发生了变化,但是变量内的数据体却未被碰过。下面列子中,改变reshape后的b的第一个值,发现所有相关的变量的第一个值都发生了变化,所以就可以知道,经reshape后,变量用于保存数据的那块内存没有被碰过。
import numpy as np

a = np.arange(4)  # a = torch.arange(4)
print(a)  # ==> [0 1 2 3]
print(a.shape)  # ==> (4,)

b = a.reshape([2, 2])  # b = a.reshape([2, 2])
print(b)  # ==> [[0 1], [2 3]]
print(b.shape)  # ==> (2, 2)

c = b.reshape([-1])  # c = torch.reshape(b, [-1])
print(c)  # ==> [0 1 2 3]

b[0, 0] = 100
print(a)  # ==> [100   1   2   3]
print(b)  # ==> [[100 1], [2 3]]
print(c)  # ==> [100   1   2   3]

transpose

transpose改变了数组的维度(axis)排列顺序。比如对于二维数组,如果我们把两个维度的顺序互换,那就是我们很熟悉的矩阵转置。而transpose可以在更多维度的情况下生效。transpose的入参是输出数组的维度排列顺序,序号从0开始计数。

下面例子中我们改变了transpose后的b的第一个元素的值,发现a也随之改变,说明transpose也没有去碰数组的内存。那么问题来了,既然数组没有在内存中重新排列,那么打印顺序是受什么影响而发生了改变呢?是strides。

import numpy as np

a = np.arange(24).reshape([2, 3, 4])
print(a.base)  # ==>
# [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]

print(a.shape)  # ==> (2, 3, 4)
print(a.strides)  # ==> (48, 16, 4)
print(a)  # ==>
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
#
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

b = a.transpose([1, 2, 0])
print(b.shape)  # ==>(3, 4, 2)
print(b.strides)  # ==> (16, 4, 48)
print(b)  # ==>
# [[[ 0 12]
#   [ 1 13]
#   [ 2 14]
#   [ 3 15]]
#
#  [[ 4 16]
#   [ 5 17]
#   [ 6 18]
#   [ 7 19]]
#
#  [[ 8 20]
#   [ 9 21]
#   [10 22]
#   [11 23]]]


b[0, 0, 0] = 100
print(a)  # ==>
# [[[100   1   2   3]
#   [  4   5   6   7]
#   ...]]]

print(b)  # ==>
# [[[100  12]
#   [  1  13]
#   ...]]]

下面图示一下strides的含义。

在这里插入图片描述

首先明确一个很重要的概念,strides都是相对于base数组而言进行遍历的,所以无论是a还是b,遍历时需要参考的源头都是a.base / b.base,也就是最上面的一维数组。

数组a的strides情况我们前面已经讲过了,接下来主要看看b

  • 由于b.strides最后一个维度的值是48,所以b[0, 0, 1]b[0, 0, 0]b.base中偏移48字节后的数字,也就是12
  • b.strides中间维度的值是4,所以b[0, 1, 0]b[0, 0, 0]b.base中偏移4字节后的数字,也就是1
  • b.strides第一个维度的值是16,所以b[1, 0, 0]b[0, 0, 0]b.base中偏移16字节后的数字,也就是4

所以ranspose操作只是改变了strides的顺序,没有重新排列内存中的数据。

总结

前面我们在解释reshape和transpose的机制时,分别从ndarray的shape和strides属性进行了侧重解释。实际上reshape既改变shape也改变strides,而transpose也可能会改变shape。

但这两个算子均不会在内存中重新排列数据。

### Numpy `transpose` 函数详解 #### 1. 基本概念 Numpy 的 `transpose` 函数用于改变数组的轴顺序,从而实现矩阵转置的效果。对于二维数组而言,这相当于行列互换;而对于多维数组,则可以重新排列各个维度的位置。 #### 2. API 接口说明 函数签名如下所示: ```python numpy.transpose(a, axes=None) ``` 参数解释: - `a`: 输入数组。 - `axes`: 可选参数,默认为 None。如果提供此参数,则按照给定的新轴序来重排数据;如果不给出该参数,则默认反转所有轴的方向[^3]。 返回值是一个视图对象,表示原数组经过变换后的版本。 #### 3. 示例代码展示 ##### 例子一:简单二阶张量(即矩阵)转置 ```python import numpy as np A = np.array([[1, 2], [3, 4]]) print("Original matrix:\n", A) B = np.transpose(A) print("\nTransposed matrix using transpose():\n", B) ``` 输出结果将是原始矩阵与其转置形式之间的对比显示。 ##### 例子二:三维数组沿特定轴方向进行转置操作 考虑一个形状为 `(2, 3, 4)` 的三维数组 X ,现在希望将其第二第三两个轴位置调换得到新的数组 Y 。可以通过下面的方式完成这一目标: ```python X = np.random.rand(2, 3, 4) Y = X.transpose((0, 2, 1)) print('Shape of original array:', X.shape) print('Shape after transposing axis 1 and 2:', Y.shape) ``` 这段程序会打印出转换前后数组的具体尺寸信息[^4]。 #### 4. 结合 reshape 进行复杂的数据重组 有时候不仅需要调整轴的顺序还需要改变某些维度大小,在这种情况下就可以联合使用 `reshape()` `transpose()` 来达到目的。例如先通过 `reshape()` 将高维数据降成低维再做必要的转置处理,反之亦然。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值