将深度学习模型嵌入C++代码,通常有以下几个步骤:
-
选择合适的推理框架:如TensorFlow C++ API、PyTorch C++ API(LibTorch)、OpenVINO、ONNX Runtime、TFLite等。
-
将训练好的模型转换为对应格式:例如,PyTorch模型转换为TorchScript,TensorFlow模型转换为SavedModel或冻结的PB格式,或者转换为ONNX格式。
-
在C++项目中集成推理框架:通过CMake或其他构建工具链接推理库。
-
编写C++代码加载模型并进行推理。
下面以两种常见的框架为例简要说明:
1. 使用LibTorch(PyTorch C++)
步骤:
-
下载LibTorch(PyTorch的C++版本)。
-
将PyTorch模型转换为TorchScript。
-
在C++项目中引入LibTorch,加载模型并进行推理。
示例代码:
cpp
#include <torch/script.h> // 引入TorchScript头文件
#include <iostream>
#include <vector>
int main() {
// 加载模型
torch::jit::script::Module module;
try {
module = torch::jit::load("path/to/your/model.pt");
}
catch (const c10::Error& e) {
std::cerr << "错误:无法加载模型\n";
return -1;
}
// 创建输入向量
std::vector<float> input_data = {1.0, 2.0, 3.0, 4.0}; // 示例数据
// 将输入数据转换为tensor
torch::Tensor input_tensor = torch::from_blob(input_data.data(), {1, 4}, torch::kFloat);
// 执行推理
std::vector<torch::jit::IValue> inputs;
inputs.push_back(input_tensor);
at::Tensor output = module.forward(inputs).toTensor();
// 输出结果
std::cout << output << std::endl;
return 0;
}
编译时需要链接LibTorch库。CMakeLists.txt示例:
cmake
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example_app)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_executable(example_app example_app.cpp)
target_link_libraries(example_app "${TORCH_LIBRARIES}")
set_property(TARGET example_app PROPERTY CXX_STANDARD 14)
2. 使用ONNX Runtime
步骤:
-
将模型转换为ONNX格式。
-
在C++项目中集成ONNX Runtime。
-
加载ONNX模型并进行推理。
示例代码:
cpp
#include <onnxruntime_cxx_api.h>
#include <vector>
#include <iostream>
int main() {
// 初始化ONNX Runtime环境
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
// 加载模型
const char* model_path = "path/to/your/model.onnx";
Ort::Session session(env, model_path, session_options);
// 获取模型输入输出信息
Ort::AllocatorWithDefaultOptions allocator;
auto input_name = session.GetInputName(0, allocator);
auto output_name = session.GetOutputName(0, allocator);
std::vector<const char*> input_names = {input_name};
std::vector<const char*> output_names = {output_name};
// 准备输入数据
std::vector<float> input_tensor_values = {1.0, 2.0, 3.0, 4.0};
std::vector<int64_t> input_shape = {1, 4}; // 假设输入形状为[1,4]
// 创建输入tensor
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_values.size(), input_shape.data(), input_shape.size());
// 运行模型
auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_names.data(), &input_tensor, 1, output_names.data(), 1);
// 获取输出
float* floatarr = output_tensors[0].GetTensorMutableData<float>();
// 输出结果
std::cout << "输出: " << floatarr[0] << std::endl;
// 释放资源
allocator.Free(input_name);
allocator.Free(output_name);
return 0;
}
编译时需要链接ONNX Runtime库。
注意事项
-
模型转换:确保模型转换正确,注意输入输出节点名称和形状。
-
依赖管理:确保C++项目能正确找到推理库的头文件和库文件。
-
性能优化:根据目标硬件平台,可能需要进行特定的优化(如使用GPU、量化等)。
-
内存管理:注意内存分配和释放,避免内存泄漏。
根据你的具体需求(如模型类型、部署环境、性能要求等)选择合适的框架和优化方法。
将深度学习模型嵌入C++代码主要有以下几种方式,我会从简单到复杂为你详细介绍:
1. 使用ONNX Runtime(推荐)
这是最通用和简单的方法,支持多种框架训练的模型。
步骤:
-
将模型转换为ONNX格式
python
# PyTorch示例
import torch
import torch.onnx
model = YourModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx",
input_names=['input'], output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
-
C++代码集成
cpp
#include <onnxruntime_cxx_api.h>
#include <vector>
#include <iostream>
class ONNXModel {
private:
Ort::Env env;
Ort::Session session;
std::vector<const char*> input_names;
std::vector<const char*> output_names;
public:
ONNXModel(const std::string& model_path) :
env(ORT_LOGGING_LEVEL_WARNING, "Model"),
session(env, model_path.c_str(), Ort::SessionOptions{}) {
// 获取输入输出信息
Ort::AllocatorWithDefaultOptions allocator;
input_names.push_back(session.GetInputName(0, allocator));
output_names.push_back(session.GetOutputName(0, allocator));
}
std::vector<float> predict(const std::vector<float>& input) {
// 准备输入tensor
std::vector<int64_t> input_shape = {1, 3, 224, 224}; // 根据你的模型调整
size_t input_tensor_size = 1 * 3 * 224 * 224;
auto memory_info = Ort::MemoryInfo::CreateCpu(
OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
memory_info,
const_cast<float*>(input.data()),
input_tensor_size,
input_shape.data(),
input_shape.size()
);
// 运行推理
auto output_tensors = session.Run(
Ort::RunOptions{nullptr},
input_names.data(),
&input_tensor,
1,
output_names.data(),
1
);
// 获取输出
float* floatarr = output_tensors[0].GetTensorMutableData<float>();
auto output_shape = output_tensors[0].GetTensorTypeAndShapeInfo().GetShape();
size_t output_size = output_shape[0] * output_shape[1];
return std::vector<float>(floatarr, floatarr + output_size);
}
};
2. 使用LibTorch(PyTorch C++)
适合直接部署PyTorch模型。
步骤:
-
安装LibTorch
bash
# 从PyTorch官网下载预编译的LibTorch
-
C++代码示例
cpp
#include <torch/script.h>
#include <torch/torch.h>
class TorchModel {
private:
torch::jit::script::Module module;
public:
TorchModel(const std::string& model_path) {
try {
// 加载模型
module = torch::jit::load(model_path);
} catch (const c10::Error& e) {
std::cerr << "Error loading model\n";
}
}
std::vector<float> predict(const std::vector<float>& input) {
// 转换为tensor
auto options = torch::TensorOptions().dtype(torch::kFloat32);
torch::Tensor input_tensor = torch::from_blob(
(void*)input.data(),
{1, 3, 224, 224},
options
).clone();
// 推理
std::vector<torch::jit::IValue> inputs;
inputs.push_back(input_tensor);
auto output = module.forward(inputs).toTensor();
// 转换回vector
std::vector<float> result(output.data_ptr<float>(),
output.data_ptr<float>() + output.numel());
return result;
}
};
3. 使用TensorFlow C++ API
适合TensorFlow/Keras模型。
cpp
#include <tensorflow/cc/client/client_session.h>
#include <tensorflow/cc/ops/standard_ops.h>
#include <tensorflow/core/framework/tensor.h>
class TFModel {
private:
tensorflow::Session* session;
tensorflow::GraphDef graph_def;
public:
bool loadModel(const std::string& model_path) {
tensorflow::Status status = ReadBinaryProto(
tensorflow::Env::Default(), model_path, &graph_def);
if (!status.ok()) return false;
tensorflow::SessionOptions options;
status = tensorflow::NewSession(options, &session);
if (!status.ok()) return false;
status = session->Create(graph_def);
return status.ok();
}
std::vector<float> predict(const std::vector<float>& input) {
tensorflow::Tensor input_tensor(
tensorflow::DT_FLOAT,
tensorflow::TensorShape({1, 3, 224, 224})
);
// 填充数据...
std::vector<std::pair<std::string, tensorflow::Tensor>> inputs = {
{"input_layer", input_tensor}
};
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status status = session->Run(
inputs, {"output_layer"}, {}, &outputs);
if (status.ok()) {
// 处理输出...
}
return std::vector<float>();
}
};
4. 构建配置(CMake)
ONNX Runtime + CMake
cmake
cmake_minimum_required(VERSION 3.10) project(DLInference) set(CMAKE_CXX_STANDARD 14) # 查找ONNX Runtime find_package(ONNXRuntime REQUIRED) add_executable(inference_app main.cpp) target_link_libraries(inference_app ONNXRuntime::onnxruntime)
LibTorch + CMake
cmake
cmake_minimum_required(VERSION 3.10)
project(TorchInference)
set(CMAKE_CXX_STANDARD 14)
# 查找LibTorch
find_package(Torch REQUIRED)
add_executable(torch_app main.cpp)
target_link_libraries(torch_app "${TORCH_LIBRARIES}")
5. 完整使用示例
cpp
#include <iostream>
#include "ONNXModel.hpp" // 你的模型封装类
int main() {
// 初始化模型
ONNXModel model("path/to/your/model.onnx");
// 准备输入数据
std::vector<float> input_data(1 * 3 * 224 * 224, 0.5f); // 示例数据
// 进行推理
auto start = std::chrono::high_resolution_clock::now();
std::vector<float> result = model.predict(input_data);
auto end = std::chrono::high_resolution_clock::now();
// 输出结果
std::cout << "推理时间: "
<< std::chrono::duration<double, std::milli>(end - start).count()
<< " ms" << std::endl;
std::cout << "输出大小: " << result.size() << std::endl;
return 0;
}
选择建议:
-
ONNX Runtime - 最通用,支持多种框架,性能好
-
LibTorch - 适合PyTorch生态,开发活跃
-
TensorFlow C++ - 适合TensorFlow模型
-
OpenVINO - 英特尔硬件优化
-
TensorRT - NVIDIA GPU优化
推荐从ONNX Runtime开始,它提供了最好的兼容性和不错的性能。
4871

被折叠的 条评论
为什么被折叠?



