NumPy数组的大小是固定的,因此不能就地移除元素。例如,使用del不起作用:>>> import numpy as np
>>> arr = np.arange(5)
>>> del arr[-1]
ValueError: cannot delete array elements
注意,索引-1表示最后一个元素。这是因为Python(和NumPy)中的负索引是从末尾开始计算的,所以-1是最后一个,-2是最后一个之前的那个,-len实际上是第一个元素。那只是为了你的信息,以防你不知道。
Python列表的大小是可变的,因此添加或删除元素很容易。
所以如果你想删除一个元素,你需要创建一个新的数组或视图。
创建新视图
可以使用切片表示法创建包含除最后一个元素外的所有元素的新视图:>>> arr = np.arange(5)
>>> arr
array([0, 1, 2, 3, 4])
>>> arr[:-1] # all but the last element
array([0, 1, 2, 3])
>>> arr[:-2] # all but the last two elements
array([0, 1, 2])
>>> arr[1:] # all but the first element
array([1, 2, 3, 4])
>>> arr[1:-1] # all but the first and last element
array([1, 2, 3])
但是,视图与原始数组共享数据,因此,如果其中一个被修改,则另一个被修改:>>> sub = arr[:-1]
>>> sub
array([0, 1, 2, 3])
>>> sub[0] = 100
>>> sub
array([100, 1, 2, 3])
>>> arr
array([100, 1, 2, 3, 4])
创建新数组
一。复制视图
如果不喜欢这种内存共享,则必须创建一个新数组,在这种情况下,创建一个视图然后复制(例如使用数组的^{}方法)可能是最简单的:>>> arr = np.arange(5)
>>> arr
array([0, 1, 2, 3, 4])
>>> sub_arr = arr[:-1].copy()
>>> sub_arr
array([0, 1, 2, 3])
>>> sub_arr[0] = 100
>>> sub_arr
array([100, 1, 2, 3])
>>> arr
array([0, 1, 2, 3, 4])
2。使用整数数组索引[docs]
但是,也可以使用整数数组索引删除最后一个元素并获取新数组。此整数数组索引将始终(不是100%确定)创建副本而不是视图:>>> arr = np.arange(5)
>>> arr
array([0, 1, 2, 3, 4])
>>> indices_to_keep = [0, 1, 2, 3]
>>> sub_arr = arr[indices_to_keep]
>>> sub_arr
array([0, 1, 2, 3])
>>> sub_arr[0] = 100
>>> sub_arr
array([100, 1, 2, 3])
>>> arr
array([0, 1, 2, 3, 4])
此整数数组索引对于从数组中删除任意元素非常有用(如果需要查看,这可能很棘手或不可能):>>> arr = np.arange(5, 10)
>>> arr
array([5, 6, 7, 8, 9])
>>> arr[[0, 1, 3, 4]] # keep first, second, fourth and fifth element
array([5, 6, 8, 9])
如果需要使用整数数组索引删除最后一个元素的通用函数:def remove_last_element(arr):
return arr[np.arange(arr.size - 1)]
三。使用布尔数组索引[docs]
还可以使用布尔索引,例如:>>> arr = np.arange(5, 10)
>>> arr
array([5, 6, 7, 8, 9])
>>> keep = [True, True, True, True, False]
>>> arr[keep]
array([5, 6, 7, 8])
这也创建了一个副本!一般的方法可能是这样的:def remove_last_element(arr):
if not arr.size:
raise IndexError('cannot remove last element of empty array')
keep = np.ones(arr.shape, dtype=bool)
keep[-1] = False
return arr[keep]
如果您想了解更多关于NumPys索引的信息,那么documentation on "Indexing"非常好,并且涵盖了很多情况。
四。使用^{}
通常情况下,我不会推荐NumPy函数“看起来”像是在原地修改数组(比如np.append和np.insert),但确实返回副本,因为这些通常是不必要的缓慢和误导性的。你应该尽量避免,这就是为什么这是我回答的最后一点。不过,在这种情况下,它实际上是一个完美的契合,所以我必须提到:>>> arr = np.arange(10, 20)
>>> arr
array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
>>> np.delete(arr, -1)
array([10, 11, 12, 13, 14, 15, 16, 17, 18])
5.)使用^{}
NumPy有另一个方法,听起来像是在执行就地操作,但它确实返回了一个新数组:>>> arr = np.arange(5)
>>> arr
array([0, 1, 2, 3, 4])
>>> np.resize(arr, arr.size - 1)
array([0, 1, 2, 3])
为了删除最后一个元素,我只提供了一个比以前小1的新形状,这有效地删除了最后一个元素。
就地修改数组
是的,我之前写过,不能就地修改数组。但我说,因为在大多数情况下,这是不可能的,或者只有通过禁用一些(完全有用的)安全检查。我不确定内部结构,但根据旧的大小和新的大小,这可能包括(仅内部)复制操作,因此它可能比创建视图慢。
如果阵列不与任何其他阵列共享内存,则可以就地调整阵列大小:>>> arr = np.arange(5, 10)
>>> arr.resize(4)
>>> arr
array([5, 6, 7, 8])
但是,如果另一个数组也实际引用了它,则会抛出ValueErrors:>>> arr = np.arange(5)
>>> view = arr[1:]
>>> arr.resize(4)
ValueError: cannot resize an array that references or is referenced by another array in this way. Use the resize function
您可以通过设置refcheck=False来禁用该安全检查,但这不应轻易执行,因为如果其他引用试图访问已删除的元素,您将容易受到分段错误和内存损坏的影响!此refcheck参数应视为专家专用选项!
摘要
创建视图非常快,而且不需要太多额外的内存,因此只要有可能,就应该尽可能多地使用视图。但是根据用例的不同使用基本切片移除任意元素。虽然很容易删除前n个元素和/或最后n个元素,或者删除每个x元素(用于切片的step参数),但这就是您可以使用它所做的一切。
但在您删除一维数组的最后一个元素的情况下,我建议:arr[:-1] # if you want a view
arr[:-1].copy() # if you want a new array
因为这些最清楚地表达了意图,每个有Python/NumPy经验的人都会认识到这一点。
时间安排
基于这个answer的时间框架:# Setup
import numpy as np
def view(arr):
return arr[:-1]
def array_copy_view(arr):
return arr[:-1].copy()
def array_int_index(arr):
return arr[np.arange(arr.size - 1)]
def array_bool_index(arr):
if not arr.size:
raise IndexError('cannot remove last element of empty array')
keep = np.ones(arr.shape, dtype=bool)
keep[-1] = False
return arr[keep]
def array_delete(arr):
return np.delete(arr, -1)
def array_resize(arr):
return np.resize(arr, arr.size - 1)
# Timing setup
timings = {view: [],
array_copy_view: [], array_int_index: [], array_bool_index: [],
array_delete: [], array_resize: []}
sizes = [2**i for i in range(1, 20, 2)]
# Timing
for size in sizes:
print(size)
func_input = np.random.random(size=size)
for func in timings:
print(func.__name__.ljust(20), ' ', end='')
res = %timeit -o func(func_input) # if you use IPython, otherwise use the "timeit" module
timings[func].append(res)
# Plotting
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure(1)
ax = plt.subplot(111)
for func in timings:
ax.plot(sizes,
[time.best for time in timings[func]],
label=func.__name__)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('size')
ax.set_ylabel('time [seconds]')
ax.grid(which='both')
ax.legend()
plt.tight_layout()
我得到以下时间作为日志图覆盖所有细节,较低的时间仍然意味着更快,但两个刻度之间的范围代表一个数量级,而不是一个固定的数量。如果您对特定值感兴趣,我将它们复制到这个gist:
根据这些时间安排,这两种方法也是最快的。(Python 3.6和NumPy 1.14.0)