模型部署RV11126的流程大致为:训练得到.pth模型、pth2onnx、onnx2rknn,最后在边缘计算设备上完成部署,本文从初学者的角度为模型部署学习者提供一些模型格式转化的意见与经验。
一、所需条件
- ubuntu上rknn转换环境,参考:深度学习模型部署RV1126准备工作(一)——Ubuntu搭建Rknn环境_rv1126模型部署-优快云博客
- windows上rknn转换环境,参考:深度学习模型部署RV1126准备工作(一)附加篇——Windows配置rknn环境_rv1126 windows 模型转换-优快云博客
- 模型训练的.pth文件,参考:深度学习模型部署RV1126准备工作(二)——初学者如何快速上手训练所需模型_rv1126 训练-优快云博客
二、代码详解
1、调用相关文件
import torch.onnx
from torchvision import models
from onnxsim import simplify
import onnx
2、pytorch的推理模型加载
torch_model = torch.load("/home/xl/thinclient_drives/xd/Awesome-Backbones-main/Awesome-Backbones-main/logs/ResNet/2024-10-30-15-23-14/Val_Epoch056-Acc69.602.pth",
map_location='cpu') # pytorch模型加载
model = models.resnet50()
# model.load_state_dict(torch_model)
batch_size = 1 # 批处理大小
input_shape = (3, 224, 224) # 输入数据,改成自己的输入shape
3、开启模型的评估模式
model.eval()
4、设置onnx导出地址,并导出onnx模型
x = torch.randn(batch_size, *input_shape) # 生成张量
export_path = "../output/resnet.onnx"
export_onnx_file = export_path # 目的ONNX文件名
torch.onnx.export(model,
x,
export_onnx_file,
opset_version=12,
do_constant_folding=True, # 是否执行常量折叠优化
input_names=["input"], # 输入名
output_names=["output"]) # 输出名
5、简化onnx模型
# 加载导出的 ONNX 模型
onnx_model = onnx.load(export_path)
# 简化模型
simplified_model, check = simplify(onnx_model)
# 保存简化后的模型
onnx.save_model(simplified_model, "../output/resnet_simplified.onnx")
三、完整示例代码
import torch.onnx
from torchvision import models
from onnxsim import simplify
import onnx
#from models.resnet import resnet50 as model
torch_model = torch.load("/home/xl/thinclient_drives/xd/Awesome-Backbones-main/Awesome-Backbones-main/logs/ResNet/2024-10-30-15-23-14/Val_Epoch056-Acc69.602.pth",
map_location='cpu') # pytorch模型加载
model = models.resnet50()
# model.load_state_dict(torch_model)
batch_size = 1 # 批处理大小
input_shape = (3, 224, 224) # 输入数据,改成自己的输入shape
# #set the model to inference mode
model.eval()
x = torch.randn(batch_size, *input_shape) # 生成张量
export_path = "../output/resnet.onnx"
export_onnx_file = export_path # 目的ONNX文件名
torch.onnx.export(model,
x,
export_onnx_file,
opset_version=12,
do_constant_folding=True, # 是否执行常量折叠优化
input_names=["input"], # 输入名
output_names=["output"]) # 输出名
# dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
# "output":{0:"batch_size"}})
# 加载导出的 ONNX 模型
onnx_model = onnx.load(export_path)
# 简化模型
simplified_model, check = simplify(onnx_model)
# 保存简化后的模型
onnx.save_model(simplified_model, "../output/resnet_simplified.onnx")