概要
ONNX模型中的结构是一个有向图,包含了很多节点。每个节点执行一个特定的操作,最终就得到了推理结果。ONNX模型格式标准并没有要求所有节点按照拓扑顺序来存储,进行模型解析的时候也基本不要求解析出来的节点一定要符合拓扑顺序排列。有些模型很简单,从输入到输出,可能只有一条通路;有些模型很复杂,不仅输入和输出节点间存在多条通路,还有可能存在多个输入和输出节点。ONNX Runtime 是如何确定模型中各个节点执行的先后顺序的呢?怎么确保某个节点被执行之前,其所有先导节点都已经被执行?这就是今天需要解决的疑惑。ONNX Runtime 执行模型的方式主要有两种:串行和并行,好像有点废话了。通过初始化的时候传递个InferenceSession
的构造函数的结构体SessionOptions
中的ExecutionMode
成员来控制。今天主要研究串行执行时节点执行顺序。
涉及文件
onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc
onnxruntime\onnxruntime\core\session\inference_session.cc
onnxruntime\onnxruntime\core\framework\sequential_executor.cc
onnxruntime\onnxruntime\core\framework\session_state_initializer.cc
onnxruntime\onnxruntime\core\graph\graph_viewer.cc
onnxruntime\onnxruntime\core\framework\session_state.cc
onnxruntime\onnxruntime\core\graph\graph.cc
正文
举个栗子,有一个简单的模型,如图1所示:
在这个简单的模型里面,一共有六个节点,从输入到输出有两条通路。由于ONNX模型格式标准并没有要求所有节点按照拓扑顺序来存储,因此模型再次加载到内存以后,节点的顺序的排列完全是随机的,有可能是1、3、2、4、6、5,也可能是其他的顺序。因此,必须要先确定节点的拓扑结构并按照结构存储起来,这样才能在跑的时候知道那个是输入,哪些节点必须先跑完。
代码调用
在上一篇文章ONNX Runtime 源码阅读:模型推理过程概览中我们说过,模型节点执行顺序的确定是在InferenceSession
实例化完毕后,在初始化阶段完成的。
// onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc
py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
.def(
"load_model", [](InferenceSession* sess, std::vector<std::string>& provider_types) {
OrtPybindThrowIfError(sess->Load());
InitializeSession(sess, provider_types);
},
R"pbdoc(Load a model saved in ONNX format.)pbdoc")
从上面代码中可以看到,初始化也分为两个阶段:1)模型加载 2)InferenceSession
实例初始化。
模型加载?模型不是在生成InferenceSession
实例的时候已经加载到内存了么?其实在InferenceSession
实例化阶段加载的模型知识编译proto文件得到的类ModelProto
的一个实例,直接使用还是不太方便,因此还需要对它进行进一步解析和封装,OrtPybindThrowIfError(sess->Load());
这句话主要做的就是这件事。
我们接着来看InitializeSession(sess, provider_types);
:
// onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc
void InitializeSession(InferenceSession* sess, const std::vector<std::string>& provider_types) {
if (provider_types.empty()) {
// use default registration priority.
RegisterExecutionProviders(sess, GetAllProviders());
} else {
RegisterExecutionProviders(sess, provider_types);
}
OrtPybindThrowIfError(sess->Initialize());
}
可以看到,InitializeSession(sess, provider_types)
在注册Provider后,最终调用到了onnxruntime\onnxruntime\core\session\inference_session.cc
中类InferenceSession
的Initiablize()
方法。
Initiablize()
方法体非常长,但是有两行非常刺眼,session_initializer.CreatePlan; InitializeSubgraphSessions(graph, *session_state_)
,字面意思就是创建执行计划,开个上帝视角执行顺序这的是在这里创建的。由于方法体很长,这就贴一部分重要的好了:
// onnxruntime\onnxruntime\core\session\inference_session.cc # InferenceSession::Initialize()
onnxruntime::Graph& graph = model_->MainGraph();
// Collect the kernel registries from execution provider instances;
// There are 2 kinds of kernel registries with priority from high to low as below,
// 1. Custom execution provider type specific kernel registries.
// 2. common execution provider type specific kernel registries.
// The 1st and 2nd ones are shared across sessions.
// The 1st ones should have already been registered via session-level API into KernelRegistryManager.
//
// Register 2nd registries into KernelRegistryManager.
ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_));
SessionStateInitializer session_initializer(session_options_.enable_mem_pattern, model_location_, graph,
*session_state_, execution_providers_, kernel_registry_manager_);
// create SessionState for subgraphs as it's needed by the transformers
ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(graph, *session_state_));
// apply any transformations to the main graph and any subgraphs
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, *graph_transformation_mgr_,
execution_providers_, kernel_registry_manager_,
insert_cast_transformer_,
*session_state_