一、背景
pytorch模型使用3种分辨率的输入图片,输入图片的分辨率分别为(224,224),(112,112),(56,56)。
二、多输入的pytorch模型转成onnx模型
#coding:utf-8
"""Export a multi-input PyTorch model to ONNX.

The model takes three images at resolutions 224x224, 112x112 and 56x56,
which are bound to the ONNX input names 'input1', 'input2', 'input3'
(in that order).
"""
import torch
import torch.onnx as onnx
from c3ae_net_mbv2 import MobileNetV2

torch_model_path = 'pytorch_model.pth'
onnx_model_path = 'onnx_model.onnx'

# Instantiate the model and load the trained weights.
# map_location='cpu' lets the export run on a machine without the GPU
# the checkpoint was saved on.
model = MobileNetV2()
checkpoint = torch.load(torch_model_path, map_location='cpu')
model.load_state_dict(checkpoint["state_dict"])
# Switch to inference mode so dropout/batch-norm behave deterministically
# during tracing.
model.eval()

# Names for the three graph inputs and the single output.
input_names = ['input1', 'input2', 'input3']
output_names = ['output']

# torch.onnx.export expects the example inputs as a TUPLE; passing a list
# can be interpreted as a single list-typed argument instead of three
# separate inputs.
model_input = (
    torch.randn(1, 3, 224, 224),
    torch.randn(1, 3, 112, 112),
    torch.randn(1, 3, 56, 56),
)
onnx.export(model, model_input, onnx_model_path,
            input_names=input_names, output_names=output_names)
三、多输入的onnx模型推理
#coding:utf-8
"""Run inference on a multi-input ONNX model with one test image.

The image is resized to three resolutions and fed to the three model
inputs. Resolutions must follow the order the inputs were exported in:
input1=224, input2=112, input3=56.
"""
import os
import cv2
import onnxruntime as ort
import numpy as np

onnx_path = 'my_model.onnx'
img_path = 'test_img.jpg'
# BUG FIX: order must match the export script (input1 was exported at
# 224x224), not ascending [56, 112, 224].
img_sizes = [224, 112, 56]

sess = ort.InferenceSession(onnx_path)
## Names of the model's input/output nodes, in declaration order.
input_names = [inp.name for inp in sess.get_inputs()]
output_name = sess.get_outputs()[0].name

## Read the image; imdecode+np.fromfile also handles non-ASCII paths,
## and IMREAD_COLOR drops any alpha channel.
image = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

## Preprocess each resolution: resize, scale to [0,1], HWC->CHW,
## add batch dim, cast to float32.
inputs = []
for img_size in img_sizes:
    # BUG FIX: cv2.resize takes a (width, height) tuple, not a bare int.
    resized = cv2.resize(image, (img_size, img_size))
    chw = np.transpose(resized / 255.0, (2, 0, 1))
    inputs.append(np.float32(chw[np.newaxis, :]))

## Run the model, feeding each preprocessed image to its matching input.
results_ort = sess.run([output_name], dict(zip(input_names, inputs)))
detects = np.array(results_ort[0][0])
print('detects', detects)
参考: https://juejin.cn/s/pytorch%20%E8%BD%AConnx%20%E5%A4%9A%E4%B8%AA%E8%BE%93%E5%85%A5