在使用numpy多维数组时我们常会需要获取数组中的元素,这一般有两种方法:
import numpy as np
a = np.random.randint(10, 20, size=[10, 20])
print(a)
print(a[2, 2])
print(a[2][2])
'''
[[17 12 16 14 19 10 14 13 15 13 17 19 19 11 11 18 16 12 16 17]
[15 17 15 11 19 16 19 18 12 12 19 19 15 19 18 11 18 12 10 13]
[10 12 13 12 14 11 12 12 10 18 14 16 16 16 14 13 11 15 11 15]
[10 17 19 11 16 17 15 11 14 12 17 14 15 17 12 17 16 15 10 14]
[13 14 13 16 14 18 14 16 16 10 11 13 14 14 12 11 18 12 14 13]
[17 19 14 15 19 12 10 17 14 13 19 11 17 13 17 10 19 14 18 11]
[17 11 13 18 14 17 14 18 11 18 18 14 16 19 18 18 15 18 15 12]
[19 17 10 13 14 12 19 16 10 18 11 11 12 18 16 15 15 13 15 19]
[18 12 11 15 11 13 13 18 15 19 19 13 16 13 19 15 10 12 10 15]
[14 19 12 10 11 10 14 19 12 10 19 12 18 15 18 17 19 12 18 14]]
13
13
'''
但当我们需要用到切片时,第二种写法却是错误的:
import numpy as np
a = np.random.randint(10, 20, size=[10, 20])
print(a)
print(a[:2, :2])
print(a[:2][:2])
'''
[[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17]
[16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18]
[12 10 10 12 13 18 10 13 14 13 13 19 10 16 13 19 13 19 13 17]
[17 19 11 10 11 17 16 10 18 15 10 18 12 15 17 11 16 13 12 11]
[17 19 11 13 12 15 14 16 12 12 14 11 15 13 19 19 17 14 16 19]
[10 15 18 19 10 15 12 13 11 18 19 11 14 15 14 17 15 10 13 16]
[10 16 17 18 19 14 15 10 14 11 11 16 18 15 16 12 10 11 16 18]
[13 12 13 13 16 16 17 16 15 15 15 16 14 16 15 16 19 14 19 14]
[16 13 17 16 10 15 19 15 19 13 19 12 16 11 14 17 18 19 15 15]
[14 10 19 14 10 11 16 14 10 16 18 12 10 14 12 10 12 14 10 15]]
[[16 16]
[16 13]]
[[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17]
[16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18]]
'''
第一种写法获取到了我们实际想要的子矩阵,而第二种写法实际上需要分开来看待:先获取a的前两行得到一个子矩阵,再获取这个子矩阵的前两行。
最近写代码时总弄混这两个写法,因此记录一下,numpy切片的正确用法是用逗号隔开,而不是像多维数组索引那样隔开。
今天又发现了新的问题,numpy真是有趣。在做cs231n的作业时,我需要从一个N∗CN * CN∗C的分数中按照一个N∗1N * 1N∗1的label来取出N∗1N * 1N∗1的正确分数(每行按照label选一个分数),自然会想到花式索引和切片结合的方法,但遇到了一些问题,这里总结一下可能的写法:
a = np.random.randint(5, 10, size=(5, 10))
print(a)
y1 = np.random.randint(0, 10, size=(5, ))
print(y1)
y2 = np.random.randint(0, 10, size=(5, 1))
print(y2)
'''
[[7 6 5 7 9 5 9 6 8 8]
[8 8 8 7 9 9 8 5 9 8]
[8 9 8 7 5 7 7 6 8 8]
[6 9 5 9 5 7 7 8 7 7]
[9 6 6 9 9 6 7 9 7 6]]
[8 1 0 9 8]
[[5]
[3]
[5]
[9]
[5]]
'''
可以看到y1和y2的shape是不一样的,y1是一个数组,y2则是一个二维矩阵。
print(a[:, y1])
'''
[[8 6 7 8 8]
[9 8 8 8 9]
[8 9 8 8 8]
[7 9 6 7 7]
[7 6 9 6 7]]
可以看到,这种写法得到一个N * N的矩阵,每一行对应a的每一行按照y1的所有元素来取值,
即本来a的每一行取一个值就可以,但是却取了N个值,每一行相当于a[i, y1]
a[0, y1] = [8 6 7 8 8]
'''
print(a[range(5), y1])
'''
[8 8 8 7 7]
这种写法就是我们想要的结果
'''
print(a[:, y2])
'''
[[[5]
[7]
[5]
[8]
[5]]
[[9]
[7]
[9]
[8]
[9]]
[[7]
[7]
[7]
[8]
[7]]
[[7]
[9]
[7]
[7]
[7]]
[[6]
[9]
[6]
[6]
[6]]]
这种写法得到的结果更加离谱,是一个5 * 5 * 1的三维矩阵,
每个5 * 1的子矩阵相当于第一种写法的结果
'''
print(a[range(5), y2])
'''
[[5 9 7 7 6]
[7 7 7 9 9]
[5 9 7 7 6]
[8 8 8 7 6]
[5 9 7 7 6]]
一共5行,每一行都是a的某一列,按照y2来取值。
'''
结论就是要使用range和一维数组来进行切片和花式索引。
切片中省略号…的作用
有时候我们会看到这样的索引写法a[..., 1:],其中的...是一种特殊写法,常适用于高维数组的切片。比如,当a是5维数组时,a[:, :, :, :, 1:] = a[..., 1:],即...是所有完整切片的缩短,这种写法更简洁。
在numpy多维数组操作中,正确使用切片和索引至关重要。本文通过实例区分了两种常见的切片误区,并探讨了如何结合切片和花式索引从N*C的矩阵中按N*1的label选取特定行。同时,介绍了省略号...在高维数组切片中的作用,帮助更好地理解和应用numpy的索引功能。
355

被折叠的 条评论
为什么被折叠?



