【记录】pth-->onnx-->om华为Atlas200DK深度学习Pytorch模型部署初体验

〇、官方参考文档

正文开始之前,附几个官方链接,以官方教程为主,下文为辅进行模型部署。

昇腾社区-官网丨昇腾万里 让智能无所不及   查阅文档内容

资源-Atlas 200I DK A2-昇腾社区      翻页-->开发课程-->开发者课程-->图片分类应用开发入门教程

昇腾论坛

一、引言

由于本科毕设要求,需要将pytorch图像识别(10分类)模型进行嵌入式端部署。本文假设已经获得训练好的.pth文件(权重+模型),介绍从pth文件到onnx文件再到om模型的转换以及python/C++的推理过程。

 二、软硬件准备

  1. Atlas 200I DK A2 开发者套件
  2. windows 10或ubuntu系统(用以模型转换)
  3. .pth模型
  4. Python3.11+PyTorch环境

三、pth-->onnx模型转换

1、python代码转换

import torchvision
import onnx
import torch
from vgg16 import Vgg16Classifier
from resnet18 import ResNet18Classifier 
        #如果保存的pth文件是权重文件不包含模型框架,
        #需要在代码中加入模型类的定义或者通过from的方式import
num_classes = 10 #模型分类输出数
DIR_STATE_DIC_PTH = './model/resnet18_offical.pth' #pth模型地址
ONNX_MODEL_PATH = './onnx/' #onnx模型保存地址
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 将state_dic_pth权重转换成完整的pth文件
def get_full_pth(type, DIR_STATE_DIC_PTH):
    if type=='vgg16':
        model = Vgg16Classifier(1, num_classes)

    elif type=='resnet18':
        model = ResNet18Classifier(1, num_classes) #(通道数,输出类别数)
        # 此处根据自己模型定义设置,相当于实例化一个类对象
    model.load_state_dict(torch.load(DIR_STATE_DIC_PTH))  # weight=torch.load(DIR_STATE_DIC_PTH)
    #将权重加载到实例化的类对象内,获得完整的pth文件
    print(f'{type}.pth完整模型已获得!')
    return model
def pth2onnx(type):
    model = get_full_pth(type, DIR_STATE_DIC_PTH)
    model.to(device)
    model.eval()
    dummy_img = torch.Tensor(torch.randn((1, 1, 224, 224)).cuda()) # (batchsize, channels, width, height)
    torch.onnx.export(
        model = model, #pth模型
        args=dummy_img,
        f=ONNX_MODEL_PATH+type+'.onnx', #onnx保存模型地址
        export_params=True,
        verbose=False,
        input_names=['input'],
        output_names=['output'],
        opset_version=11 #pth转onnx好像都是设置=11
    )
    print(f'已生成{type}.onnx文件!')
# 按间距中的绿色按钮以运行脚本。
if __name__ == '__main__':
    pth2onnx('resnet18')
    

值得注意的是,pth转onnx要保证pth文件既包含训练获得的权重(字典格式),也要包括模型结构本身。

1、方式一

一般训练过程中会在训练过程中间断地保存模型,一般选择保存字典格式

torch.save(ResNet18Model.state_dict(), './Model_train3/ResNet18Model_{}.pth'.format(i + 1))

那么对应的加载方式为:

ResNet18Model = ResNet18Classifier(in_channels=1, num_classes=num_classes)
ResNet18Model.load_state_dict(torch.load('./Model_train2/ResNet18Model_395.pth')) # 字典的方式需要先实例化再load


 2、方式二

也可以选择保存全部,占用内存会比字典格式大一丢丢:

torch.save(ResNet18Model, './Model/ResNet18Model_{}.pth'.format(i+1))

那么对应的加载方式为:

ResNet18Model = torch.load('.
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值