〇、官方参考文档
正文开始之前,附几个官方链接,以官方教程为主,下文为辅进行模型部署。
昇腾社区-官网丨昇腾万里 让智能无所不及 查阅文档内容
资源-Atlas 200I DK A2-昇腾社区 翻页-->开发课程-->开发者课程-->图片分类应用开发入门教程
一、引言
由于本科毕设要求,需要将pytorch图像识别(10分类)模型进行嵌入式端部署。本文假设已经获得训练好的.pth文件(权重+模型),介绍从pth文件到onnx文件再到om模型的转换以及python/C++的推理过程。
二、软硬件准备
- Atlas 200I DK A2 开发者套件
- windows 10或ubuntu系统(用以模型转换)
- .pth模型
- Python3.11+PyTorch环境
三、pth-->onnx模型转换
1、python代码转换
import torchvision
import onnx
import torch
from vgg16 import Vgg16Classifier
from resnet18 import ResNet18Classifier
#如果保存的pth文件是权重文件不包含模型框架,
#需要在代码中加入模型类的定义或者通过from的方式import
num_classes = 10 #模型分类输出数
DIR_STATE_DIC_PTH = './model/resnet18_offical.pth' #pth模型地址
ONNX_MODEL_PATH = './onnx/' #onnx模型保存地址
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 将state_dic_pth权重转换成完整的pth文件
def get_full_pth(type, DIR_STATE_DIC_PTH):
if type=='vgg16':
model = Vgg16Classifier(1, num_classes)
elif type=='resnet18':
model = ResNet18Classifier(1, num_classes) #(通道数,输出类别数)
# 此处根据自己模型定义设置,相当于实例化一个类对象
model.load_state_dict(torch.load(DIR_STATE_DIC_PTH)) # weight=torch.load(DIR_STATE_DIC_PTH)
#将权重加载到实例化的类对象内,获得完整的pth文件
print(f'{type}.pth完整模型已获得!')
return model
def pth2onnx(type):
model = get_full_pth(type, DIR_STATE_DIC_PTH)
model.to(device)
model.eval()
dummy_img = torch.Tensor(torch.randn((1, 1, 224, 224)).cuda()) # (batchsize, channels, width, height)
torch.onnx.export(
model = model, #pth模型
args=dummy_img,
f=ONNX_MODEL_PATH+type+'.onnx', #onnx保存模型地址
export_params=True,
verbose=False,
input_names=['input'],
output_names=['output'],
opset_version=11 #pth转onnx好像都是设置=11
)
print(f'已生成{type}.onnx文件!')
# 按间距中的绿色按钮以运行脚本。
if __name__ == '__main__':
pth2onnx('resnet18')
值得注意的是,pth转onnx要保证pth文件既包含训练获得的权重(字典格式),也要包括模型结构本身。
1、方式一
一般训练过程中会在训练过程中间断地保存模型,一般选择保存字典格式:
torch.save(ResNet18Model.state_dict(), './Model_train3/ResNet18Model_{}.pth'.format(i + 1))
那么对应的加载方式为:
ResNet18Model = ResNet18Classifier(in_channels=1, num_classes=num_classes)
ResNet18Model.load_state_dict(torch.load('./Model_train2/ResNet18Model_395.pth')) # 字典的方式需要先实例化再load
2、方式二
也可以选择保存全部,占用内存会比字典格式大一丢丢:
torch.save(ResNet18Model, './Model/ResNet18Model_{}.pth'.format(i+1))
那么对应的加载方式为:
ResNet18Model = torch.load('.