运筹系列58:python使用numba进行加速

本文介绍Numba库的使用技巧,包括如何通过不同目标选择优化并行计算性能、利用装饰器提升循环效率、创建通用函数以减少开销、实现并发操作及高级特性如卷积计算等。

vectorize中的参数target一共有三种取值:cpu(默认)、parallel和cuda。关于选择哪个取值,官方文档上有很好的说明:The “cpu” target works well for small data sizes (approx. less than 1KB) and low compute intensity algorithms. It has the least amount of overhead. The “parallel” target works well for medium data sizes (approx. less than 1MB). Threading adds a small delay. The “cuda” target works well for big data sizes (approx. greater than 1MB) and high compute intensity algorithms. Transfering memory to and from the GPU adds significant overhead.
GPU我们在下一篇单独介绍。

1. numba快速入门

1.1 注意事项

nopython指的是完全不用python编译器,推荐使用。可以用@njit代替。
fastmath可以去掉一些数值检查的步骤
使用cache=True将代码进行缓存
使用generated_jit进行数据类型重载

1.2 加速循环

因为numba内置的函数本身是个装饰器,所以只要在自己定义好的函数前面加个@nb.jit()就行,简单上手。下面以一个求和函数为例

# 用numba加速的求和函数
@nb.jit()
def nb_sum(a):
    Sum = 0
    for i in range(len(a)):
        Sum += a[i]
    return Sum

# 没用numba加速的求和函数
def py_sum(a):
    Sum = 0
    for i in range(len(a)):
        Sum += a[i]
    return Sum

来测试一下速度

import numpy as np
a = np.linspace(0,100,100) # 创建一个长度为100的数组

%timeit np.sum(a) # numpy自带的求和函数
%timeit sum(a) # python自带的求和函数
%timeit nb_sum(a) # numba加速的求和函数
%timeit py_sum(a) # 没加速的求和函数

结果如下

# np.sum(a)
7.1 µs ± 537 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# sum(a)
27.7 µs ± 2.64 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# nb_sum(a)
1.05 µs ± 27.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
# py_sum(a)
43.7 µs ± 1.71 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

可以看出,numba甚至比号称最接近C语言速度运行的numpy还要快6倍以上。但大家都知道,numpy往往对大的数组更加友好,那我们来测试一个更长的数组

a = np.linspace(0,100,10**6) # 创建一个长度为100万的数组

测试结果如下

# np.sum(a)
2.51 ms ± 246 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# sum(a)
249 ms ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# nb_sum(a)
3.01 ms ± 59.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# py_sum(a)
592 ms ± 42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

可见即便是用很长的loop来计算,numba的表现也丝毫不亚于numpy。在这里,我们可以看到numba相对于numpy一个非常明显的优势:numba可以把各种具有很大loop的函数加到很快的速度,但numpy的加速只适用于numpy自带的函数。

1.3 尽量写的像c

我们来看下下面的代码:

import numba as nb
# 普通的 MaxPool
def max_pool_kernel(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    window = x[i, j, p:p+pool_height, q:q+pool_width]
                    rs[i, j, p, q] += np.max(window)

# 简单地加了个 jit 后的 MaxPool
@nb.jit(nopython=True)
def jit_max_pool_kernel(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值