转换
torch.onnx.export(model, args, f, export_params=True, verbose=False, input_names=None,
output_names=None,do_constant_folding=True,dynamic_axes=None,opset_version=9)
常用参数:
1.model:torch.nn.model
要导出的模型
2.args:tuple or tensor
模型的输入参数。注意tuple的最后参数为dict要小心,详见pytorch文档。
输入参数只需满足shape正确,为什么要输入参数呢?因为后面torch.jit.trace要用到,先按下不表。
3.f:file object or string
转换输出的模型的位置,如'yolov4.onnx'
4.export_params:bool,default=True
true表示导出trained model,否则untrained model。默认即可
5.verbose:bool,default=False
true表示打印调试信息
6.input_names:list of string,default=None
指定输入节点名称
7.output_names:list of string,default=None
指定输出节点名称
8.do_constant_folding:bool,default=True
是否使用常量折叠,默认即可
9.dynamic_axes:dict<string, dict<int, string>> or dict<string, list(int)>,default=None
有时模型的输入输出是可变的,如RNN,或者输入输出图片的batch是可变的,
这时我们通过dynamic_axes来指定输入tensor的哪些参