Pybind11 numpy实战

Pybind11 中 NumPy 的常用操作指南

—— 在 C++ 与 Python 间实现高性能数据交互


引言

在科学计算和机器学习领域,NumPy 是 Python 生态的核心库,而 Pybind11 则是连接 C++ 高性能代码与 Python 的桥梁。二者结合,既能利用 C++ 的性能优势,又能享受 Python 的易用性。本文将详解如何在 Pybind11 中高效操作 NumPy 数组,涵盖数据传递、视图创建和避免复制等关键技巧。


1. 环境配置

确保环境包含:

  • Pybind11 (v2.10+)
  • NumPy (v1.21+)
  • C++ 编译器支持 C++11 或更高

CMakeLists.txt 中添加:

find_package(pybind11 REQUIRED)  
find_package(Python REQUIRED COMPONENTS NumPy)  
target_link_libraries(your_target PRIVATE pybind11::module Python::NumPy)

2. 基础操作:接收与返回 NumPy 数组
2.1 从 Python 接收数组

使用 py::array_t<T> 直接接收 NumPy 数组:

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

// 示例:数组求和
py::array_t<double> add_arrays(py::array_t<double> a, py::array_t<double> b) {
    auto buf_a = a.request(), buf_b = b.request();
    
    // 检查维度与形状
    if (buf_a.ndim != 1 || buf_b.ndim != 1)
        throw std::runtime_error("Only 1D arrays allowed!");
    if (buf_a.size != buf_b.size)
        throw std::runtime_error("Array sizes must match!");

    // 申请结果内存
    auto result = py::array_t<double>(buf_a.size);
    auto buf_res = result.request();
    
    // 获取数据指针
    double *ptr_a = static_cast<double*>(buf_a.ptr);
    double *ptr_b = static_cast<double*>(buf_b.ptr);
    double *ptr_res = static_cast<double*>(buf_res.ptr);

    // 计算
    for (size_t i = 0; i < buf_a.size; i++) {
        ptr_res[i] = ptr_a[i] + ptr_b[i];
    }
    return result;
}
2.2 返回数组给 Python

使用 py::array_t 封装 C++ 数据:

// 创建并返回一个全零数组
py::array_t<float> create_zeros(int size) {
    // 自动管理内存
    auto arr = py::array_t<float>(size);
    auto buf = arr.request();
    float* ptr = static_cast<float*>(buf.ptr);
    std::fill(ptr, ptr + size, 0.0f);
    return arr;
}

3. 高级技巧:避免数据复制
3.1 使用 array_t::unchecked 直接访问数据

跳过边界检查提升性能:

double sum_array(py::array_t<double> input) {
    auto arr = input.unchecked<1>(); // 1D 数组视图
    double sum = 0.0;
    for (size_t i = 0; i < arr.shape(0); i++) {
        sum += arr[i]; // 直接访问元素
    }
    return sum;
}
3.2 创建非复制视图(View)

将 C++ 数据暴露为 NumPy 数组而不复制:

// 将 C++ 向量转为 NumPy 视图
py::array_t<int> vector_to_numpy(std::vector<int>& vec) {
    return py::array_t<int>(
        {vec.size()},      // 形状
        {sizeof(int)},     // 步长
        vec.data()         // 原始指针
    );
}

4. 处理多维数组

操作形状为 (H, W) 的图像数组:

// 反转 RGB 图像通道 (H, W, 3)
py::array_t<uint8_t> reverse_channels(py::array_t<uint8_t> img) {
    auto buf = img.request();
    if (buf.ndim != 3 || buf.shape[2] != 3)
        throw std::runtime_error("Expected H x W x 3 array!");
    
    auto result = py::array_t<uint8_t>(buf.shape);
    auto res_buf = result.request();
    
    uint8_t* in = static_cast<uint8_t*>(buf.ptr);
    uint8_t* out = static_cast<uint8_t*>(res_buf.ptr);
    
    size_t H = buf.shape[0], W = buf.shape[1];
    for (size_t i = 0; i < H; i++) {
        for (size_t j = 0; j < W; j++) {
            // 反转通道顺序: RGB -> BGR
            out[i*W*3 + j*3] = in[i*W*3 + j*3 + 2]; // B
            out[i*W*3 + j*3 + 1] = in[i*W*3 + j*3 + 1]; // G
            out[i*W*3 + j*3 + 2] = in[i*W*3 + j*3];     // R
        }
    }
    return result;
}

5. 关键注意事项
  1. 内存连续性

    • 使用 arr.request().flags & py::array::c_contiguous 检查连续性。
    • 非连续数组需显式处理步长(buf.strides)。
  2. 数据类型匹配

    • 通过 dtype() 检查类型:
      if (a.dtype() != py::dtype::of<float>())
          throw py::type_error("Expected float32 array!");
      
  3. GIL 管理

    • 长时间计算前释放 GIL:
      py::call_guard<py::gil_scoped_release>();
      

6. 完整示例:矩阵乘法
// 绑定到 Python 模块
PYBIND11_MODULE(numpy_ops, m) {
    m.def("add_arrays", &add_arrays, "Add two NumPy arrays");
    m.def("sum_array", &sum_array, "Sum elements without copy");
    m.def("reverse_channels", &reverse_channels, "Reverse RGB channels");
}

// 在 Python 中使用
import numpy_ops
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])
c = numpy_ops.add_arrays(a, b)  # [5.0, 7.0, 9.0]

结语

通过 Pybind11 操作 NumPy 数组,开发者能够在 C++ 中实现高性能计算,同时与 Python 生态无缝集成。核心要点包括:

  • 使用 py::array_t<T> 安全传递数据
  • 通过 unchecked() 和视图避免复制
  • 正确处理多维数组与内存布局

掌握这些技巧后,可显著提升混合编程的效率,尤其适用于图像处理、数值模拟等计算密集型任务。

扩展阅读

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值