Numba是Python的即时编译器,在使用 NumPy 的数组和函数以及循环的代码上效果最佳。
使用Numba的最常见方法是通过其装饰器(numba.jit 等等),把装饰器可应用于函数上可以指示 Numba 对其进行编译。
调用带有Numba装饰的函数时,该函数将被“即时”编译为机器代码以执行!
例如写了一段动态时间规整的 python 代码:
def dtw_distance1(ts_a, ts_b, k, mww=2):
"""Computes dtw distance between two time series
Args:
ts_a: time series a
ts_b: time series b
d: distance function
mww: max warping window, int, optional (default = infinity)
Returns:
dtw distance
"""
# Create cost matrix via broadcasting with large int
cost = np.ones((k, k))
# Initialize the first row and column
cost[0, 0] = abs(ts_a[0] - ts_b[0])
for i in range(1, k):
cost[i, 0] = cost[i - 1, 0] + abs(ts_a[i]-ts_b[0])
for j in range(1, k):
cost[0, j] = cost[0, j - 1] + abs(ts_a[0]-ts_b[j])
# Populate rest of cost matrix within window
for i in range(1, k):
for j in range(max(1, i - mww), min(k, i + mww)):
choices = cost[i - 1, j - 1], cost[i, j - 1], cost[i - 1, j]
cost[i, j] = min(choices) + abs(ts_a[i]- ts_b[j])
# Return DTW distance given window
return cost[-1, -1]
全是 for 循环,写在 python 代码里非常尴尬,但是算法就这么实现的,有什么办法🤷♀️
下面只需一行代码,就可以大幅加速
import numba
@numba.jit(nopython=True)
def dtw_distance2(ts_a, ts_b, k, mww=2):
"""Computes dtw distance between two time series
Args:
ts_a: time series a
ts_b: time series b
d: distance function
mww: max warping window, int, optional (default = infinity)
Returns:
dtw distance
"""
...
只需一个装饰器,python 秒变 c 语言!
nopython=True 表示这段代码必须编译成非 python 语言,否则就给我报错,终止程序!
下面来测试以下效率的提升:
我不测了,大家自己测吧😆