C/C++对Python模块扩展

本文介绍如何使用Cython显著提升Python代码执行效率,通过实例对比,展示Cython在矩阵乘法运算上的性能优势。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.基本知识
  • pytest1.h
#ifndef PYTEST1_H_
#define PYTEST1_H_

int fac (int n);
char *reverse(char *s);
int test(void);
#endif
  • py_test1.c
#include<stdio.h>
#include<stdlib.h>
#include<string.h>

int fac(int n){
    if(n<2)
        return 1;
    return n*fac(n-1);
}

char *reverse(char *s){
    char t,*p = s,*q = (s+strlen(s)-1);

    while(s && (p<q))
    {
         t = *p;
         *p++ = *q;
         *q-- = t;
    }

    return s;
}

int test(void)
{
     char s[1024];
 
     printf("5! = %d\n",fac(5));
     printf("10! = %d\n",fac(10));
     strcpy(s,"hello world");
     printf("reversing 'hello world',we get '%s'\n",reverse(s));
     return 0;
}
  • python
#include "Python.h"
#include <stdlib.h>
#include <string.h>
#include "py_test1.h"

static PyObject *py_test1_fac(PyObject *self,PyObject *args)
{
    int num;
    if (!PyArg_ParseTuple(args,"i",&num))
        return NULL;
    
    return (PyObject *)Py_BuildValue("i",fac(num));
}

static PyObject *py_test1_doppel(PyObject *self,PyObject *args)
{
    char *src;
    char *mstr;
    PyObject *retval;
    
    if (!PyArg_ParseTuple(args,"s",&src))
        return NULL;

    mstr = malloc(strlen(src) + 1);
    strcpy(mstr,src);
    reverse(mstr);
    retval = (PyObject *) Py_BuildValue("ss",src,mstr);
    free(mstr);
    return retval;
}

static PyObject *py_test1_test(PyObject *self,PyObject *args){
    test();
    return (PyObject *)Py_BuildValue("");
}

static PyMethodDef py_test1Methods[] = {
    {"fac",py_test1_fac,METH_VARARGS},
    {"doppel",py_test1_doppel,METH_VARARGS},
    {"test",py_test1_test,METH_VARARGS},
    {NULL,NULL},
};

void initpy_test11(void){##这里注意名称
    Py_InitModule("py_test11",py_test1Methods);
}

在终端输入 python setup.py build,将产生的py_test11.so与这些文件放在同一目录。
import py_test11
py_test11.fac()
或者gcc -shared -fPIC py_test1wrapper.c py_test1.c -I py_test1.h -I/usr/include/python2.7/ -L/usr/lib -lpython2.7 -o py_test11.so

2 矩阵乘法

假设我们现在正在编写一个很简单的矩阵乘法代码,其中矩阵是保存在 numpy.ndarray 中。Python 代码可以这么写:

# dot_python.py
import numpy as np

def naive_dot(a, b):
    if a.shape[1] != b.shape[0]:
        raise ValueError('shape not matched')
    n, p, m = a.shape[0], a.shape[1], b.shape[1]
    c = np.zeros((n, m), dtype=np.float32)
    for i in xrange(n):
        for j in xrange(m):
            s = 0
            for k in xrange(p):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    return c

使用cython

# dot_cython.pyx
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.float32_t, ndim=2] _naive_dot(np.ndarray[np.float32_t, ndim=2] a, np.ndarray[np.float32_t, ndim=2] b):
    cdef np.ndarray[np.float32_t, ndim=2] c
    cdef int n, p, m
    cdef np.float32_t s
    if a.shape[1] != b.shape[0]:
        raise ValueError('shape not matched')
    n, p, m = a.shape[0], a.shape[1], b.shape[1]
    c = np.zeros((n, m), dtype=np.float32)
    for i in xrange(n):
        for j in xrange(m):
            s = 0
            for k in xrange(p):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    return c

def naive_dot(a, b):
    return _naive_dot(a, b)

可以看到这个程序和 Python 写的几乎差不多。我们来看看不一样部分:

  • Cython 程序的扩展名是 .pyx
  • cimport 是 Cython 中用来引入 .pxd 文件的命令。有关 .pxd 文件,可以简单理解成 C/C++ 中用来写声明的头文件,更具体的我会在后面写到。这里引入的两个是 Cython 预置的。
  • @cython.boundscheck(False) 和 @cython.wraparound(False) 两个修饰符用来关闭 Cython 的边界检查
  • Cython 的函数使用 cdef 定义,并且他可以给所有参数以及返回值指定类型。比方说,我们可以这么编写整数 min 函数:

cdef int my_min(int x, int y):
return x if x <= y else y

这里 np.ndarray[np.float32_t, ndim=2] 就是一个类型名就像 int 一样,只是它比较长而且信息量比较大而已。它的意思是,这是个类型为 np.float32_t 的2维 np.ndarray。

  • 在函数体内部,我们一样可以使用 cdef typename varname 这样的语法来声明变量
  • 在 Python 程序中,是看不到 cdef 的函数的,所以我们这里 def naive_dot(a, b) 来调用 cdef 过的 _naive_dot 函数。

另外,Cython 程序需要先编译之后才能被 Python 调用,流程是:

  1. Cython 编译器把 Cython 代码编译成调用了 Python 源码的 C/C++ 代码
  2. 把生成的代码编译成动态链接库
  3. Python 解释器载入动态链接库
# setup.py
from distutils.core import setup, Extension
from Cython.Build import cythonize
import numpy
setup(ext_modules = cythonize(Extension(
    'dot_cython',
    sources=['dot_cython.pyx'],
    language='c',
    include_dirs=[numpy.get_include()],
    library_dirs=[],
    libraries=[],
    extra_compile_args=[],
    extra_link_args=[]
)))

这段代码对于我们这个简单的例子来说有些太复杂了,不过实际上,再复杂也就这么复杂了,为了省得后面再贴一遍,所以索性就在这里把最复杂的列出来好了。这里顺带解释一下好了:

  1. ‘dot_cython’ 是我们要生成的动态链接库的名字
  2. sources 里面可以包含 .pyx 文件,以及后面如果我们要调用 C/C++ 程序的话,还可以往里面加 .c / .cpp 文件
  3. language 其实默认就是 c,如果要用 C++,就改成 c++ 就好了
  4. include_dirs 这个就是传给 gcc 的 -I 参数
  5. library_dirs 这个就是传给 gcc 的 -L 参数
  6. libraries 这个就是传给 gcc 的 -l 参数
  7. extra_compile_args 就是传给 gcc 的额外的编译参数,比方说你可以传一个 -std=c++11
  8. extra_link_args 就是传给 gcc 的额外的链接参数(也就是生成动态链接库的时候用的)
  9. 如果你从来没见过上面几个 gcc 参数,说明你暂时还没这些需求,等你遇到了你就懂了
    然后我们只需要执行下面命令就可以把 Cython 程序编译成动态链接库了。

python setup.py build_ext --inplace

成功运行完上面这句话,可以看到在当前目录多出来了 dot_cython.c 和 dot_cython.so。前者是生成的 C 程序,后者是编译好了的动态链接库。

下面让我们来试试看效果:

ipython
Python 2.7.12 (default, Oct 11 2016, 05:20:59)
Type “copyright”, “credits” or “license” for more information.
IPython 4.0.1 – An enhanced Interactive Python.
? -> Introduction and overview of IPython’s features.
%quickref -> Quick reference.
help -> Python’s own help system.
object? -> Details about ‘object’, use ‘object??’ for extra details.
In [1]: import numpy as np
In [2]: import dot_python
In [3]: import dot_cython
In [4]: a = np.random.randn(100, 200).astype(np.float32)
In [5]: b = np.random.randn(200, 50).astype(np.float32)
In [6]: %timeit -n 100 -r 3 dot_python.naive_dot(a, b)
100 loops, best of 3: 560 ms per loop
In [7]: %timeit -n 100 -r 3 dot_cython.naive_dot(a, b)
100 loops, best of 3: 982 µs per loop
In [8]: %timeit -n 100 -r 3 np.dot(a, b)
100 loops, best of 3: 49.2 µs per loop

所以说,提升了大概 570 倍的效率!而我们的代码基本上就没有改动过!当然啦,你要跟高度优化过的 numpy 实现比,当然还是慢了很多啦。不过掐指一算,这 0.982ms 其实跟直接写 C++ 是差不多的,能实现这个这样的效果已经很令人满意了。不信我们可以试试看手写一次 C++ 版本:

// dot.cpp
#include <ctime>
#include <cstdlib>
#include <chrono>
#include <iostream>

class Matrix {
    float *data;
public:
    size_t n, m;
    Matrix(size_t r, size_t c): data(new float[r*c]), n(r), m(c) {}
    ~Matrix() { delete[] data; }
    float& operator() (size_t x, size_t y) { return data[x*m+y]; }
    float operator() (size_t x, size_t y) const { return data[x*m+y]; }
};

float dot(const Matrix &a, const Matrix& b) {
    Matrix c(a.n, b.m);
    for (size_t i = 0; i < a.n; ++i)
        for (size_t j = 0; j < b.m; ++j) {
            float s = 0;
            for (size_t k = 0; k < a.m; ++k)
                s += a(i, k) * b(k, j);
            c(i, j) = s;
        }
    return c(0, 0); // to comfort -O2 optimization
}

void fill_rand(Matrix &a) {
    for (size_t i = 0; i < a.n; ++i)
        for (size_t j = 0; j < a.m; ++j)
            a(i, j) = rand() / static_cast<float>(RAND_MAX) * 2 - 1;
}

int main() {
    srand((unsigned)time(NULL));
    const int n = 100, p = 200, m = 50, T = 100;
    Matrix a(n, p), b(p, m);
    fill_rand(a);
    fill_rand(b);
    auto st = std::chrono::system_clock::now();
    float s = 0;
    for (int i = 0; i < T; ++i) {
        s += dot(a, b);
    }
    auto ed = std::chrono::system_clock::now();
    std::chrono::duration<double> diff = ed-st;
    std::cerr << s << std::endl;
    std::cout << T << " loops. average " << diff.count() * 1e6 / T << "us" << std::endl;
}

$ g++ -O2 -std=c++11 -o dot dot.cpp
$ ./dot 2>/dev/null
100 loops. average 1112.11us

https://zhuanlan.zhihu.com/p/38212302

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值