关于numpy.take()用法
此文章是对我当时使用错误后的改正,用以记录下来,并作分享。
首先用numpy定义一个二维数组
import numpy as np
a = np.random.uniform(-10, 10, size=(3, 5))
print(a)
output:
[[ 9.56258678 0.78449895 1.48800984 3.86522118 3.0777375 ]
[ 8.6176819 6.58367454 6.16471974 -3.71131304 -8.22857358]
[ 9.73814042 8.80724384 6.70023151 -2.12578197 3.01890996]]
然后本意是想借助index用numpy.take方法得出反序结果
indexes = np.argsort(-a) # argsort返回数组从小到大的数值的索引
q = len(indexes[0])
print(indexes)
print(np.take(a, indexes))
output:
[[0 3 4 2 1]
[0 1 2 3 4]
[0 1 2 4 3]]
array([[ 9.56258678, 3.86522118, 3.0777375 , 1.48800984, 0.78449895],
[ 8.6176819 , 6.58367454, 6.16471974, -3.71131304, -8.22857358],
[ 9.73814042, 8.80724384, 6.70023151, 3.01890996, -2.12578197]])
结果不料它返回的数据中全是a[0]的数据,a[1]a[2]被雪藏
查阅了官方文档得到下面的话:
If indices is not one dimensional, the output also has these dimensions.
翻译为:如果索引不是一维的,则输出也具有这些维度。
并有一例:

可以得知其后的index应为看做一维数组的索引,代码修正后便无误了
indexes = np.argsort(-a)
print(indexes)
print('*' * 30)
q = len(indexes[0])
j = 1
for i in indexes[1:]:
i += q * j
j += 1
print(indexes)
print(np.take(a, indexes))
output:
[[0 3 4 2 1]
[0 1 2 3 4]
[0 1 2 4 3]]
******************************
[[ 0 3 4 2 1]
[ 5 6 7 8 9]
[10 11 12 14 13]]
array([[ 9.56258678, 3.86522118, 3.0777375 , 1.48800984, 0.78449895],
[ 8.6176819 , 6.58367454, 6.16471974, -3.71131304, -8.22857358],
[ 9.73814042, 8.80724384, 6.70023151, 3.01890996, -2.12578197]])

本文记录了一次使用numpy.take()方法时遇到的问题及解决过程。作者原本希望通过索引获得数组元素的反序排列,但遇到了输出数据仅包含首个子数组元素的情况。通过查阅官方文档发现,索引必须是一维的。最终通过调整索引生成方式成功获取了预期结果。
2859

被折叠的 条评论
为什么被折叠?



