[Onnx简化库深度剖析] OnnxSimplifier和OnnxOptimizer解读-(1)

[Onnx简化库深度剖析] OnnxSimplifier和OnnxOptimizer解读-(1)

简介

OnnxSimplifier是一个用于简化onnx模型的工具,主要工具就是:拥有折叠常量(FoldConstant)的功能、自动调度OnnxOptimizer,最为重要而且核心的是FixedPointFn这个简化调度算法。

OnnxOptimizer是一个onnx官方的一个onnx模型优化库,内部包含很多模型简化/优化的功能。用户也可直接通过python/c++/c api执行调用,但是需要比较了解内部的opt优化手段,才能够得到理想的结果。
  • 依赖情况
40% 24% 24% 12% OnnxSimplifier OnnxOptimizer Onnx OnnxRuntime pybind11

目的

从上述的描述来看,似乎OnnxSimplifier也没有干什么事情,因为OnnxOptimizer才是干简化模型的主要工具。但是OnnxSimplifier主要有以下的几点主要优点和必要性让其比较突出:

  • OnnxSimplifier接口参数较为简单,不需要了过多了解OnnxOptimizer的内部参数和优化手段
  • FixedPointFn简化调度算法让模型能够尽可能优化到最简的模型结果上,这主要因为这个迭代算法在交替使用FoldConstant和OnnxOptimizer进行优化。

OnnxSimplifier基本原理

FoldConstant功能

  • 目的:去除掉模型中那些跟输入数据流无关的叶子节点,也就是constant_node。通过单独运行constant_node,可以得到常量的output tensor,这些output tensor将被加入到模型中作为常量数据而存在,而该constant_node也将会从模型中移除。
  • constant_node条件:
    • node的domain应该属于以下的一种:[✔]
      • ai.onnx
      • ai.onnx.ml
    • node的op_type不属于以下的任何一种:[✖]
      • RandomUniform
      • RandomNormal
      • RandomUniformLike
      • RandomNormalLike
      • Multinomial
    • node不应该是以下的节点: [✖]
      • QuantizeLinear
      • DequantizeLinear
    • node不存在子图 [✖]
    • node不会产生超过threshold大小的tensor [✖]
    • node的所有输入应该都在model.graph.initializer中 [✔]

FixedPointFn迭代优化函数

  • 基本原理:就是通过两个优化函数,反复迭代优化中得到了最终无法继续优化的最终模型。

  • FixedPointFn的原始代码如下:

    template <typename T>
     std::function<T(const T&)> FixedPointFn(const std::function<T(const T&)>& f1,
                                             const std::function<T(const T&)>& f2,
                                             size_t max_iters, bool* converged) {
     return [f1, f2, max_iters, converged](const T& x) {
         size_t _max_iters = max_iters;
         T tmp1 = f1(x);
         T tmp2 = f2(tmp1);
         T& y1 = tmp1;
         T& y2 = tmp2;
         while (_max_iters-- > 0) {
         // 超出迭代次数则跳出
         if (google::protobuf::util::MessageDifferencer::Equals(y1, y2)) {
             // f1(x) == f2(f1(x))时,则无法继续优化,直接返回f2(f1(x))
             if (converged) {
             *converged = true;
             }
             return y2;
         }
         y1 = f1(y2);
         if (google::protobuf::util::MessageDifferencer::Equals(y1, y2)) {
             if (converged) {
             *converged = true;
             }
             return y1;
         }
         y2 = f2(y1);
         }
    
         if (converged) {
         *converged = false;
         }
         return y2;
     };
     }
    
  • FixedPointFn的流程图如下所示:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值