一、背景
pytorch模型使用3种分辨率的输入图片,输入图片的分辨率分别为(224,224),(112,112),(56,56).
二、多输入的pytorch模型转成onnx模型
#coding:utf-8
import torch
import torch.onnx as onnx
from c3ae_net_mbv2 import MobileNetV2
torch_model_path = 'pytorch_model.pth'
onnx_model_path = 'onnx_model.onnx'
# 创建一个 PyTorch 模型实例
model = MobileNetV2()
checkpoint = torch.load(torch_model_path)
model.load_state_dict(checkpoint["state_dict"])
# 定义输入和输出的名称和形状
input_names = ['input1', 'input2','input3']
output_names = ['output']
# input_shapes = [(1, 3, 224, 224), (1, 3, 112, 112), (1, 3, 56, 56)]
# 将 PyTorch 模型转换为 ONNX 模型
model_input = [torch.randn

本文介绍了如何将一个PyTorch模型转换为支持多个输入分辨率(224x224,112x112,56x56)的ONNX模型,并演示了如何使用ONNXRuntime进行推理,包括图片预处理和输出结果的获取。
最低0.47元/天 解锁文章
5059

被折叠的 条评论
为什么被折叠?



