Loading a TorchScript Model in C++
The following paragraphs will outline the path PyTorch provides to go from an existing Python model to a serialized representation that can be loaded and executed purely from C++, with no dependency on Python.
A PyTorch model’s journey from Python to C++ is enabled by TorchScript, a representation of a PyTorch model that can be understood, compiled, and serialized by the TorchScript compiler.
Steps:
- Step 1: Converting Your PyTorch Model to Torch Script
- Step 2: Serializing Your Script Module to a File
- Step 3: Loading Your Script Module in C++
- Step 4: Executing the Script Module in C++
- Step 5: Getting Help and Exploring the API
Summary:
First convert the model to TorchScript, an object that the TorchScript compiler can understand, compile, and serialize.
This is done either by tracing the model with an example input:
traced_script_module = torch.jit.trace(model, example)
or by compiling (scripting) the module directly:
my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)
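For a fuller picture, here is a minimal sketch of both conversion paths, assuming torchvision is installed; resnet18 and the small control-flow module MyModule mirror the examples in the official tutorial:

import torch
import torchvision

# Path 1: tracing -- run an example input through the model and record the
# operations that are executed.
model = torchvision.models.resnet18()
example = torch.rand(1, 3, 224, 224)  # example input with the expected shape
traced_script_module = torch.jit.trace(model, example)

# Path 2: scripting -- compile the module directly; unlike tracing, this also
# preserves data-dependent control flow such as the if-branch below.
class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)

Either resulting ScriptModule can then be saved with .save(), as in the next step.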
Then serialize the script module to a .pt file that can be loaded and executed purely from C++:
traced_script_module.save("traced_resnet_model.pt")
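As an optional sanity check, the serialized file can also be reloaded in Python with torch.jit.load before switching to C++; a minimal sketch, assuming the traced resnet18 from above:

import torch

loaded = torch.jit.load("traced_resnet_model.pt")
out = loaded(torch.ones(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) for the traced resnet18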
Then load and execute it in C++. A minimal program, following the official tutorial:
#include <torch/script.h>  // One-stop header for the LibTorch TorchScript API.
#include <iostream>
#include <vector>

int main(int argc, const char* argv[]) {
  // Deserialize the ScriptModule from the .pt file passed as the first argument.
  torch::jit::script::Module module = torch::jit::load(argv[1]);

  // Create a vector of inputs: a single 1x3x224x224 tensor of ones.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}));

  // Execute the model and turn its output into a tensor.
  at::Tensor output = module.forward(inputs).toTensor();
  // Print the first five entries along dimension 1.
  std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
}
Note: running the .pt model from C++ depends on LibTorch (the PyTorch C++ distribution); the tutorial builds the example with CMake, using find_package(Torch) and linking against the Torch libraries.
References
https://pytorch.org/tutorials/advanced/cpp_export.html