一. 前言
上一篇文章对TVM Relay和Pass进行了介绍,但还没有介绍整体的编译流程。这一篇文章将继续介绍一下TVM的编译流程,即TVM是如何将深度学习框架的模型转换成Relay IR之后进一步编译和优化为硬件可以执行的IR,再将这个底层IR和运行时库以及模型参数打包为一个tvm.Module
返回。关于为什么要将底层IR和运行时库以及模型参数打包,根据官方文档可以知道这样是为了可以更方便的保存底层IR和运行时库,做到一次编译,可持久化推理。
二. TVM编译流程详解
TVM的编译流程在Python端的调用方式非常简单:
with tvm.transform.PassContext(opt_level=10):
lib = relay.build(func, "llvm", params=params)
这里的with tvm.transform.PassContext(opt_level=10)
是指定Pass的优化等级,在【从零开始学深度学习编译器】五,TVM Relay以及Pass简介 已经介绍了。这里就跟进一下lib = relay.build(func, "llvm", params=params)
这行代码来看一下TVM的编译流程。
首先这里的func
和params
分别代表模型的图结构以及权重参数。relay.build
这个函数定义在tvm/python/tvm/relay/build_module.py
这个函数中,入口代码如下:
@register_func("tvm.relay.build")
def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"):
"""A wrapper around build which discards the Python GraphFactoryRuntime.
This wrapper is suitable to be used from other programming languages as
the runtime::Module can be freely passed between language boundaries.
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
return build(mod, target, params=params, mod_name=mod_name).module
对于上面调用的例子,target
为llvm
代表这个模型会被TVM编译成CPU的可执行程序。Target.check_and_update_host_consist
这个函数应该是用来检查目标设备类型targer
以及target
对应的host
端是否指定正确的,如果指定正确则将这两个参数合并到一个Target
类中并返回。Target
这个类的实现在tvm/python/tvm/target/target.py
这里,是用来管理TVM支持的设备后端的。
接着就来到了build这个函数,代码实现如下:
def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"):
# fmt: off
# pylint: disable=line-too-long
"""一个将Relay Function编译成可执行程序的函数
参数
----------
ir_mod : :py:class:`~tvm.IRModule`
要编译的IR Module. 不推荐使用relay.Function
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional
对于异构编译,它是一个指示context到target映射的字典。 对于同构编译,它是一个编译target。
target_host : str or :any:`tvm.target.Target`, optional
主机编译target,如果target是device。 当 TVM 编译 CUDA 等device特定程序时,我们还需要主机(CPU)端代码与驱动程序交互,正确设置维度和参数。target_host 用于指定主机端代码生成target。 默认情况下,如果启用 llvm,则使用 llvm,否则使用 stackvm 解释器。
params : dict of str to NDArray
在推理阶段不会更改的Graph的权重参数,用于常量折叠。
mod_name: Optional[str]
The module name we will build
Returns
-------
graph_json : str
The json string that can be accepted by graph executor.
mod : tvm.Module
The module containing necessary libraries.
params : dict
The parameters of the final graph.
"""
# pylint: enable=line-too-long
# fmt: on
if not isinstance(ir_mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")
if isinstance(ir_mod, _function.Function):
if params:
ir_mod = bind_params_by_name(ir_mod, params)
ir_mod = IRModule.from_expr(ir_mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter mod (tvm.relay.function.Function)",
DeprecationWarning,
)
target = _update_target(target)
if isinstance(target_host, (str, Target)):
target_host = Target(target_host)
elif target_host:
raise ValueError("target host must be the type of str, " + "tvm.target.Target, or None")
target, target_host = Target.check_and_update_host_consist(
target, target_host, target_is_dict_key=False
)
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.utils.EmptyContext()
with tophub_context:
bld_mod = BuildModule()
graph_json, runtime_mod, params = bld_mod.build(mod=ir_mod, target=target, params=params)
executor_factory = _graph_executor_factory.GraphExe