Pybind11 中 NumPy 的常用操作指南
—— 在 C++ 与 Python 间实现高性能数据交互
引言
在科学计算和机器学习领域,NumPy 是 Python 生态的核心库,而 Pybind11 则是连接 C++ 高性能代码与 Python 的桥梁。二者结合,既能利用 C++ 的性能优势,又能享受 Python 的易用性。本文将详解如何在 Pybind11 中高效操作 NumPy 数组,涵盖数据传递、视图创建和避免复制等关键技巧。
1. 环境配置
确保环境包含:
- Pybind11 (v2.10+)
- NumPy (v1.21+)
- C++ 编译器支持 C++11 或更高
在 CMakeLists.txt
中添加:
find_package(pybind11 REQUIRED)
find_package(Python REQUIRED COMPONENTS NumPy)
target_link_libraries(your_target PRIVATE pybind11::module Python::NumPy)
2. 基础操作:接收与返回 NumPy 数组
2.1 从 Python 接收数组
使用 py::array_t<T>
直接接收 NumPy 数组:
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11;
// 示例:数组求和
py::array_t<double> add_arrays(py::array_t<double> a, py::array_t<double> b) {
auto buf_a = a.request(), buf_b = b.request();
// 检查维度与形状
if (buf_a.ndim != 1 || buf_b.ndim != 1)
throw std::runtime_error("Only 1D arrays allowed!");
if (buf_a.size != buf_b.size)
throw std::runtime_error("Array sizes must match!");
// 申请结果内存
auto result = py::array_t<double>(buf_a.size);
auto buf_res = result.request();
// 获取数据指针
double *ptr_a = static_cast<double*>(buf_a.ptr);
double *ptr_b = static_cast<double*>(buf_b.ptr);
double *ptr_res = static_cast<double*>(buf_res.ptr);
// 计算
for (size_t i = 0; i < buf_a.size; i++) {
ptr_res[i] = ptr_a[i] + ptr_b[i];
}
return result;
}
2.2 返回数组给 Python
使用 py::array_t
封装 C++ 数据:
// 创建并返回一个全零数组
py::array_t<float> create_zeros(int size) {
// 自动管理内存
auto arr = py::array_t<float>(size);
auto buf = arr.request();
float* ptr = static_cast<float*>(buf.ptr);
std::fill(ptr, ptr + size, 0.0f);
return arr;
}
3. 高级技巧:避免数据复制
3.1 使用 array_t::unchecked
直接访问数据
跳过边界检查提升性能:
double sum_array(py::array_t<double> input) {
auto arr = input.unchecked<1>(); // 1D 数组视图
double sum = 0.0;
for (size_t i = 0; i < arr.shape(0); i++) {
sum += arr[i]; // 直接访问元素
}
return sum;
}
3.2 创建非复制视图(View)
将 C++ 数据暴露为 NumPy 数组而不复制:
// 将 C++ 向量转为 NumPy 视图
py::array_t<int> vector_to_numpy(std::vector<int>& vec) {
return py::array_t<int>(
{vec.size()}, // 形状
{sizeof(int)}, // 步长
vec.data() // 原始指针
);
}
4. 处理多维数组
操作形状为 (H, W)
的图像数组:
// 反转 RGB 图像通道 (H, W, 3)
py::array_t<uint8_t> reverse_channels(py::array_t<uint8_t> img) {
auto buf = img.request();
if (buf.ndim != 3 || buf.shape[2] != 3)
throw std::runtime_error("Expected H x W x 3 array!");
auto result = py::array_t<uint8_t>(buf.shape);
auto res_buf = result.request();
uint8_t* in = static_cast<uint8_t*>(buf.ptr);
uint8_t* out = static_cast<uint8_t*>(res_buf.ptr);
size_t H = buf.shape[0], W = buf.shape[1];
for (size_t i = 0; i < H; i++) {
for (size_t j = 0; j < W; j++) {
// 反转通道顺序: RGB -> BGR
out[i*W*3 + j*3] = in[i*W*3 + j*3 + 2]; // B
out[i*W*3 + j*3 + 1] = in[i*W*3 + j*3 + 1]; // G
out[i*W*3 + j*3 + 2] = in[i*W*3 + j*3]; // R
}
}
return result;
}
5. 关键注意事项
-
内存连续性
- 使用
arr.request().flags & py::array::c_contiguous
检查连续性。 - 非连续数组需显式处理步长(
buf.strides
)。
- 使用
-
数据类型匹配
- 通过
dtype()
检查类型:if (a.dtype() != py::dtype::of<float>()) throw py::type_error("Expected float32 array!");
- 通过
-
GIL 管理
- 长时间计算前释放 GIL:
py::call_guard<py::gil_scoped_release>();
- 长时间计算前释放 GIL:
6. 完整示例:矩阵乘法
// 绑定到 Python 模块
PYBIND11_MODULE(numpy_ops, m) {
m.def("add_arrays", &add_arrays, "Add two NumPy arrays");
m.def("sum_array", &sum_array, "Sum elements without copy");
m.def("reverse_channels", &reverse_channels, "Reverse RGB channels");
}
// 在 Python 中使用
import numpy_ops
import numpy as np
a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])
c = numpy_ops.add_arrays(a, b) # [5.0, 7.0, 9.0]
结语
通过 Pybind11 操作 NumPy 数组,开发者能够在 C++ 中实现高性能计算,同时与 Python 生态无缝集成。核心要点包括:
- 使用
py::array_t<T>
安全传递数据 - 通过
unchecked()
和视图避免复制 - 正确处理多维数组与内存布局
掌握这些技巧后,可显著提升混合编程的效率,尤其适用于图像处理、数值模拟等计算密集型任务。
扩展阅读