对于numpy包中的axis参数的理解
在numpy中,对于多维数组进行sum,mean,min,max,sort的操作时,均会涉及到axis这一参数。那么axis具体是什么呢?
我们先引入一个切片(slices)的概念:
- 如果kkk维数组A∈Rn1×n2×...×nkA \in \mathbb{R}^{n_1\times n_2 \times ...\times n_k}A∈Rn1×n2×...×nk ,则AAA的Slices为固定其中1个索引位置之后形成的k−1k-1k−1维数组
ok,有了slices的概念,我们如何对高维数组进行sum,mean,min,max,sort之类的运算呢?
以sum函数为例。对于一个kkk维数组A∈Rn1×n2×...×nkA \in \mathbb{R}^{n_1\times n_2 \times ...\times n_k}A∈Rn1×n2×...×nk,在axis=i上进行运算就是在第i个维度上对相应的slices进行运算。
以sum函数在axis=i上运算为例,这个过程就相当于计算:
np.sum(A,axis=i)=∑j=1nix[:,:,...,:,j,...,:].np.sum(A,axis=i) = \sum_{j=1}^{n_i}x[:,:,...,:,j,...,:].np.sum(A,axis=i)=j=1∑nix[:,:,...,:,j,...,:].
以sort函数在axis=i上运算为例,这个过程就相当于计算:
np.sort(A,axis=i)=sorted{x[:,:,...,:,j,...,:],j=1,2,⋯ ,ni}.np.sort(A,axis=i) = sorted\{x[:,:,...,:,j,...,:],j=1,2,\cdots,n_i\}.np.sort(A,axis=i)=sorted{x[:,:,...,:,j,...,:],j=1,2,⋯,ni}.
以min函数在axis=i上运算为例,这个过程就相当于计算:
np.min(A,axis=i)=min{x[:,:,...,:,j,...,:],j=1,2,⋯ ,ni}.np.min(A,axis=i) = min\{x[:,:,...,:,j,...,:],j=1,2,\cdots,n_i\}.np.min(A,axis=i)=min{x[:,:,...,:,j,...,:],j=1,2,⋯,ni}.
以mean函数在axis=i上运算为例,这个过程就相当于计算:
np.max(A,axis=i)=mean{x[:,:,...,:,j,...,:],j=1,2,⋯ ,ni}.np.max(A,axis=i) = mean\{x[:,:,...,:,j,...,:],j=1,2,\cdots,n_i\}.np.max(A,axis=i)=mean{x[:,:,...,:,j,...,:],j=1,2,⋯,ni}.
具体操作
我们定义一个2×3×42\times 3 \times 42×3×4的变量data,打印一下它在各个axis上的slices:
import numpy as np
np.random.seed(1)
data = np.random.randint(0,24,(2,3,4))
# data变量在axis=0上的全部slices:
print('data在axis=0上的全部slices:')
for slices in range(data.shape[0]):
print(data[slices,:,:])
print('---------------')
print('data在axis=1上的全部slices:')
for slices in range(data.shape[1]):
print(data[:,slices,:])
print('---------------')
print('data在axis=2上的全部slices:')
for slices in range(data.shape[2]):
print(data[:,:,slices])
print('---------------')
输出:
data在axis=0上的全部slices:
[[ 5 11 12 8]
[ 9 11 5 15]
[ 0 16 1 12]]
---------------
[[ 7 13 6 18]
[20 5 18 20]
[11 10 14 18]]
---------------
data在axis=1上的全部slices:
[[ 5 11 12 8]
[ 7 13 6 18]]
---------------
[[ 9 11 5 15]
[20 5 18 20]]
---------------
[[ 0 16 1 12]
[11 10 14 18]]
---------------
data在axis=2上的全部slices:
[[ 5 9 0]
[ 7 20 11]]
---------------
[[11 11 16]
[13 5 10]]
---------------
[[12 5 1]
[ 6 18 14]]
---------------
[[ 8 15 12]
[18 20 18]]
---------------
np.sum函数
假设现在我们对data变量在axis=0上做sum运算:知道了data在axis=0上的slices,一个for循环就可以解决。
sum_axis0 = 0
for slices in range(data.shape[0]):
sum_axis0+=data[slices,:,:]
print(sum_axis0)
print(sum_axis0==data.sum(axis=0))
输出:
[[12 24 18 26]
[29 16 23 35]
[11 26 15 30]]
[[ True True True True]
[ True True True True]
[ True True True True]]
从上面可以看出,sum作用的基本单元其实就是data在axis=0上的各个slices。同理我们可以得到axis=1,axis=2上面的sum。这里就不写了。
np.mean函数
有了sum的结果,计算mean也是easy了:
sum_axis0 = 0
for slices in range(data.shape[0]):
sum_axis0+=data[slices,:,:]
mean_data = sum_axis0/2
print(mean_data)
print(mean_data==data.mean(axis=0))
输出:
[[ 6. 12. 9. 13. ]
[14.5 8. 11.5 17.5]
[ 5.5 13. 7.5 15. ]]
[[ True True True True]
[ True True True True]
[ True True True True]]
np.sort函数
现在来试一下sort函数。以axis=0为例,data在axis=0上的slices为:
[[ 5 11 12 8]
[ 9 11 5 15]
[ 0 16 1 12]]
---------------
[[ 7 13 6 18]
[20 5 18 20]
[11 10 14 18]]
那么sort函数的作用过程就是(以升序排列为例):依次比较上面两个slices的各个元素。具体过程为:
- 对slices(0)与slices(1)的(0,0)元素进行排序:由于5<7,故排序后的array的(0,0,0)元素为5,(1,0,0)元素为7
- 对slices(0)与slices(1)的(0,1)元素进行排序:由于11<13,故排序后的array的(0,0,1)元素为11,(1,0,1)元素为13
- 对slices(0)与slices(1)的(0,2)元素进行排序:由于6<12,故排序后的array的(0,0,2)元素为6,(1,0,2)元素为12
- …
- 对slices(0)与slices(1)的(2,3)元素进行排序:由于8<18,故排序后的array的(0,2,3)元素为12,(1,2,3)元素为18
具体代码为:
slices0 = data[0,:,:]
slices1 = data[1,:,:]
sorted_data = np.zeros((2,3,4),dtype=int)
_,row,col = data.shape
for i in range(row):
for j in range(col):
sorted_data[0,i,j],sorted_data[1,i,j] = sorted([slices0[i,j],slices1[i,j]])
print(sorted_data)
default_sort_data = np.sort(data,axis=0)
print(sorted_data==default_sort_data)
输出
[[[ 5 11 6 8]
[ 9 5 5 15]
[ 0 10 1 12]]
[[ 7 13 12 18]
[20 11 18 20]
[11 16 14 18]]]
[[[ True True True True]
[ True True True True]
[ True True True True]]
[[ True True True True]
[ True True True True]
[ True True True True]]]
np.min与np.max函数
好哒,那么min和max函数也是手到擒来了
slices0 = data[0,:,:]
slices1 = data[1,:,:]
min_data = np.zeros((3,4),dtype=int)
_,row,col = data.shape
for i in range(row):
for j in range(col):
min_data[i,j]= min([slices0[i,j],slices1[i,j]])
print(min_data)
default_min_data = np.min(data,axis=0)
print(min_data==default_min_data)
输出:
[[ 5 11 6 8]
[ 9 5 5 15]
[ 0 10 1 12]]
[[ True True True True]
[ True True True True]
[ True True True True]]
总结
从以上各个例子中我们还可以发现,sum、mean、min和max函数作用之后得到的结果与axis=i上的slices的维度相同,但是sort函数作用之后的结果是与原始数组的维度相同的。具体原因通过代码也是可以看出来的,简单提一下就是了,sort函数结果与原始数组相同是因为它是先对axis=i下的全部slices进行一个排序操作,然后再把这些排序后的结果合在一起,所以sort函数的结果维度与原始数组维度保持一致。
本文详细解释了在numpy中,如何使用axis参数进行sum、mean、min、max和sort等操作,并通过实例展示了axis参数在不同维度上的效果,强调了sort函数在保持原始数组维度上的特性。
920

被折叠的 条评论
为什么被折叠?



