对于numpy包中的axis参数的理解

本文详细解释了在numpy中,如何使用axis参数进行sum、mean、min、max和sort等操作,并通过实例展示了axis参数在不同维度上的效果,强调了sort函数在保持原始数组维度上的特性。

对于numpy包中的axis参数的理解

在numpy中,对于多维数组进行sum,mean,min,max,sort的操作时,均会涉及到axis这一参数。那么axis具体是什么呢?

我们先引入一个切片(slices)的概念:

  • 如果kkk维数组A∈Rn1×n2×...×nkA \in \mathbb{R}^{n_1\times n_2 \times ...\times n_k}ARn1×n2×...×nk ,则AAA的Slices为固定其中1个索引位置之后形成的k−1k-1k1维数组

ok,有了slices的概念,我们如何对高维数组进行sum,mean,min,max,sort之类的运算呢?

以sum函数为例。对于一个kkk维数组A∈Rn1×n2×...×nkA \in \mathbb{R}^{n_1\times n_2 \times ...\times n_k}ARn1×n2×...×nk,在axis=i上进行运算就是在第i个维度上对相应的slices进行运算。

以sum函数在axis=i上运算为例,这个过程就相当于计算:
np.sum(A,axis=i)=∑j=1nix[:,:,...,:,j,...,:].np.sum(A,axis=i) = \sum_{j=1}^{n_i}x[:,:,...,:,j,...,:].np.sum(A,axis=i)=j=1nix[:,:,...,:,j,...,:].

以sort函数在axis=i上运算为例,这个过程就相当于计算:
np.sort(A,axis=i)=sorted{x[:,:,...,:,j,...,:],j=1,2,⋯ ,ni}.np.sort(A,axis=i) = sorted\{x[:,:,...,:,j,...,:],j=1,2,\cdots,n_i\}.np.sort(A,axis=i)=sorted{x[:,:,...,:,j,...,:],j=1,2,,ni}.

以min函数在axis=i上运算为例,这个过程就相当于计算:
np.min(A,axis=i)=min{x[:,:,...,:,j,...,:],j=1,2,⋯ ,ni}.np.min(A,axis=i) = min\{x[:,:,...,:,j,...,:],j=1,2,\cdots,n_i\}.np.min(A,axis=i)=min{x[:,:,...,:,j,...,:],j=1,2,,ni}.

以mean函数在axis=i上运算为例,这个过程就相当于计算:
np.max(A,axis=i)=mean{x[:,:,...,:,j,...,:],j=1,2,⋯ ,ni}.np.max(A,axis=i) = mean\{x[:,:,...,:,j,...,:],j=1,2,\cdots,n_i\}.np.max(A,axis=i)=mean{x[:,:,...,:,j,...,:],j=1,2,,ni}.

具体操作

我们定义一个2×3×42\times 3 \times 42×3×4的变量data,打印一下它在各个axis上的slices:

import numpy as np

np.random.seed(1)
data = np.random.randint(0,24,(2,3,4))
# data变量在axis=0上的全部slices:
print('data在axis=0上的全部slices:')
for slices in range(data.shape[0]):  
    print(data[slices,:,:])
    print('---------------')
print('data在axis=1上的全部slices:')
for slices in range(data.shape[1]):  
    print(data[:,slices,:])
    print('---------------')
print('data在axis=2上的全部slices:')
for slices in range(data.shape[2]):  
    print(data[:,:,slices])
    print('---------------')

输出:

data在axis=0上的全部slices:
[[ 5 11 12  8]
 [ 9 11  5 15]
 [ 0 16  1 12]]
---------------
[[ 7 13  6 18]
 [20  5 18 20]
 [11 10 14 18]]
---------------
data在axis=1上的全部slices:
[[ 5 11 12  8]
 [ 7 13  6 18]]
---------------
[[ 9 11  5 15]
 [20  5 18 20]]
---------------
[[ 0 16  1 12]
 [11 10 14 18]]
---------------
data在axis=2上的全部slices:
[[ 5  9  0]
 [ 7 20 11]]
---------------
[[11 11 16]
 [13  5 10]]
---------------
[[12  5  1]
 [ 6 18 14]]
---------------
[[ 8 15 12]
 [18 20 18]]
---------------

np.sum函数

假设现在我们对data变量在axis=0上做sum运算:知道了data在axis=0上的slices,一个for循环就可以解决。

sum_axis0 = 0
for slices in range(data.shape[0]):  
    sum_axis0+=data[slices,:,:]
print(sum_axis0)
print(sum_axis0==data.sum(axis=0))

输出:

[[12 24 18 26]
 [29 16 23 35]
 [11 26 15 30]]
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]

从上面可以看出,sum作用的基本单元其实就是data在axis=0上的各个slices。同理我们可以得到axis=1,axis=2上面的sum。这里就不写了。

np.mean函数

有了sum的结果,计算mean也是easy了:

sum_axis0 = 0
for slices in range(data.shape[0]):  
    sum_axis0+=data[slices,:,:]
mean_data = sum_axis0/2
print(mean_data)
print(mean_data==data.mean(axis=0))

输出:

[[ 6.  12.   9.  13. ]
 [14.5  8.  11.5 17.5]
 [ 5.5 13.   7.5 15. ]]
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]

np.sort函数

现在来试一下sort函数。以axis=0为例,data在axis=0上的slices为:

[[ 5 11 12  8]
 [ 9 11  5 15]
 [ 0 16  1 12]]
---------------
[[ 7 13  6 18]
 [20  5 18 20]
 [11 10 14 18]]

那么sort函数的作用过程就是(以升序排列为例):依次比较上面两个slices的各个元素。具体过程为:

  • 对slices(0)与slices(1)的(0,0)元素进行排序:由于5<7,故排序后的array的(0,0,0)元素为5,(1,0,0)元素为7
  • 对slices(0)与slices(1)的(0,1)元素进行排序:由于11<13,故排序后的array的(0,0,1)元素为11,(1,0,1)元素为13
  • 对slices(0)与slices(1)的(0,2)元素进行排序:由于6<12,故排序后的array的(0,0,2)元素为6,(1,0,2)元素为12
  • 对slices(0)与slices(1)的(2,3)元素进行排序:由于8<18,故排序后的array的(0,2,3)元素为12,(1,2,3)元素为18

具体代码为:

slices0 = data[0,:,:]
slices1 = data[1,:,:]
sorted_data  =  np.zeros((2,3,4),dtype=int)
_,row,col = data.shape
for i in range(row):
    for j in range(col):
        sorted_data[0,i,j],sorted_data[1,i,j] = sorted([slices0[i,j],slices1[i,j]])

print(sorted_data)
default_sort_data = np.sort(data,axis=0)
print(sorted_data==default_sort_data)

输出

[[[ 5 11  6  8]
  [ 9  5  5 15]
  [ 0 10  1 12]]

 [[ 7 13 12 18]
  [20 11 18 20]
  [11 16 14 18]]]
[[[ True  True  True  True]
  [ True  True  True  True]
  [ True  True  True  True]]

 [[ True  True  True  True]
  [ True  True  True  True]
  [ True  True  True  True]]]

np.min与np.max函数

好哒,那么min和max函数也是手到擒来了

slices0 = data[0,:,:]
slices1 = data[1,:,:]
min_data  =  np.zeros((3,4),dtype=int)
_,row,col = data.shape
for i in range(row):
    for j in range(col):
        min_data[i,j]= min([slices0[i,j],slices1[i,j]])

print(min_data)
default_min_data = np.min(data,axis=0)
print(min_data==default_min_data)

输出:

[[ 5 11  6  8]
 [ 9  5  5 15]
 [ 0 10  1 12]]
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]

总结

从以上各个例子中我们还可以发现,sum、mean、min和max函数作用之后得到的结果与axis=i上的slices的维度相同,但是sort函数作用之后的结果是与原始数组的维度相同的。具体原因通过代码也是可以看出来的,简单提一下就是了,sort函数结果与原始数组相同是因为它是先对axis=i下的全部slices进行一个排序操作,然后再把这些排序后的结果合在一起,所以sort函数的结果维度与原始数组维度保持一致。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值