numpy数组索引遇到的问题

在numpy多维数组操作中,正确使用切片和索引至关重要。本文通过实例区分了两种常见的切片误区,并探讨了如何结合切片和花式索引从N*C的矩阵中按N*1的label选取特定行。同时,介绍了省略号...在高维数组切片中的作用,帮助更好地理解和应用numpy的索引功能。

在使用numpy多维数组时我们常会需要获取数组中的元素,这一般有两种方法:

import numpy as np

a = np.random.randint(10, 20, size=[10, 20])
print(a)
print(a[2, 2])
print(a[2][2])
'''
[[17 12 16 14 19 10 14 13 15 13 17 19 19 11 11 18 16 12 16 17]
 [15 17 15 11 19 16 19 18 12 12 19 19 15 19 18 11 18 12 10 13]
 [10 12 13 12 14 11 12 12 10 18 14 16 16 16 14 13 11 15 11 15]
 [10 17 19 11 16 17 15 11 14 12 17 14 15 17 12 17 16 15 10 14]
 [13 14 13 16 14 18 14 16 16 10 11 13 14 14 12 11 18 12 14 13]
 [17 19 14 15 19 12 10 17 14 13 19 11 17 13 17 10 19 14 18 11]
 [17 11 13 18 14 17 14 18 11 18 18 14 16 19 18 18 15 18 15 12]
 [19 17 10 13 14 12 19 16 10 18 11 11 12 18 16 15 15 13 15 19]
 [18 12 11 15 11 13 13 18 15 19 19 13 16 13 19 15 10 12 10 15]
 [14 19 12 10 11 10 14 19 12 10 19 12 18 15 18 17 19 12 18 14]]
13
13
 '''

但当我们需要用到切片时,第二种写法却是错误的:

import numpy as np

a = np.random.randint(10, 20, size=[10, 20])
print(a)
print(a[:2, :2])
print(a[:2][:2])
'''
[[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17]
 [16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18]
 [12 10 10 12 13 18 10 13 14 13 13 19 10 16 13 19 13 19 13 17]
 [17 19 11 10 11 17 16 10 18 15 10 18 12 15 17 11 16 13 12 11]
 [17 19 11 13 12 15 14 16 12 12 14 11 15 13 19 19 17 14 16 19]
 [10 15 18 19 10 15 12 13 11 18 19 11 14 15 14 17 15 10 13 16]
 [10 16 17 18 19 14 15 10 14 11 11 16 18 15 16 12 10 11 16 18]
 [13 12 13 13 16 16 17 16 15 15 15 16 14 16 15 16 19 14 19 14]
 [16 13 17 16 10 15 19 15 19 13 19 12 16 11 14 17 18 19 15 15]
 [14 10 19 14 10 11 16 14 10 16 18 12 10 14 12 10 12 14 10 15]]
[[16 16]
 [16 13]]
[[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17]
 [16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18]]
'''

第一种写法获取到了我们实际想要的子矩阵,而第二种写法实际上需要分开来看待:先获取a的前两行得到一个子矩阵,再获取这个子矩阵的前两行。
最近写代码时总弄混这两个写法,因此记录一下,numpy切片的正确用法是用逗号隔开,而不是像多维数组索引那样隔开


今天又发现了新的问题,numpy真是有趣。在做cs231n的作业时,我需要从一个N∗CN * CNC的分数中按照一个N∗1N * 1N1的label来取出N∗1N * 1N1的正确分数(每行按照label选一个分数),自然会想到花式索引和切片结合的方法,但遇到了一些问题,这里总结一下可能的写法:

a = np.random.randint(5, 10, size=(5, 10))
print(a)

y1 = np.random.randint(0, 10, size=(5, ))
print(y1)

y2 = np.random.randint(0, 10, size=(5, 1))
print(y2)
'''
[[7 6 5 7 9 5 9 6 8 8]
 [8 8 8 7 9 9 8 5 9 8]
 [8 9 8 7 5 7 7 6 8 8]
 [6 9 5 9 5 7 7 8 7 7]
 [9 6 6 9 9 6 7 9 7 6]]
[8 1 0 9 8]
[[5]
 [3]
 [5]
 [9]
 [5]]
'''

可以看到y1和y2的shape是不一样的,y1是一个数组,y2则是一个二维矩阵。

print(a[:, y1])
'''
[[8 6 7 8 8]
 [9 8 8 8 9]
 [8 9 8 8 8]
 [7 9 6 7 7]
 [7 6 9 6 7]]
 可以看到,这种写法得到一个N * N的矩阵,每一行对应a的每一行按照y1的所有元素来取值,
 即本来a的每一行取一个值就可以,但是却取了N个值,每一行相当于a[i, y1]
 a[0, y1] = [8 6 7 8 8]
'''
print(a[range(5), y1])
'''
[8 8 8 7 7]
这种写法就是我们想要的结果
'''
print(a[:, y2])
'''
[[[5]
  [7]
  [5]
  [8]
  [5]]

 [[9]
  [7]
  [9]
  [8]
  [9]]

 [[7]
  [7]
  [7]
  [8]
  [7]]

 [[7]
  [9]
  [7]
  [7]
  [7]]

 [[6]
  [9]
  [6]
  [6]
  [6]]]
 这种写法得到的结果更加离谱,是一个5 * 5 * 1的三维矩阵,
 每个5 * 1的子矩阵相当于第一种写法的结果
'''
print(a[range(5), y2])
'''
[[5 9 7 7 6]
 [7 7 7 9 9]
 [5 9 7 7 6]
 [8 8 8 7 6]
 [5 9 7 7 6]]
 一共5行,每一行都是a的某一列,按照y2来取值。
'''

结论就是要使用range和一维数组来进行切片和花式索引。

切片中省略号…的作用
有时候我们会看到这样的索引写法a[..., 1:],其中的...是一种特殊写法,常适用于高维数组的切片。比如,当a是5维数组时,a[:, :, :, :, 1:] = a[..., 1:],即...是所有完整切片的缩短,这种写法更简洁。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值