A Torch2ONNX Example

This post shows how to convert a PyTorch model (using SimpleConvAE as the example) to the ONNX format, then validate the model, run inference with ONNX Runtime, and visualize the result. First, create an environment and install the required libraries; then define a function that converts a .pth weight file into an ONNX model; next, check that the ONNX model is well formed; finally, run inference on the ONNX model and visualize it with Netron.


Environment setup

conda create -n onnx python=3.8
conda activate onnx
pip install numpy          # numpy <= 1.23.5
pip install pandas
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install onnx onnxruntime   # needed by the conversion and inference code below
pip install netron
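To confirm the environment is ready before running the conversion, a quick sanity check using only the standard library can print each installed version (the package list mirrors the install commands above):

```python
from importlib.metadata import version, PackageNotFoundError

# pip package names from the install commands above
for pkg in ["numpy", "torch", "onnx", "onnxruntime", "netron"]:
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "is NOT installed")
```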

Converting a generated .pth file to ONNX

pth_path is the path to the .pth weights file.
dummy_input is the dummy input used to trace the model.

def torch2onnx(onnx_name, pth_path, model, batch_size):
    onnx_file_name = onnx_name

    # Load the trained weights and switch to inference mode.
    model.load_state_dict(torch.load(pth_path))
    model.eval()

    # Dummy input used to trace the model; shape (batch, 7, 100).
    dummy_input = torch.randn(batch_size, 7, 100, requires_grad=True)
    #output = model(dummy_input)
    #print(output.shape)

    torch.onnx.export(model,
                      dummy_input,
                      onnx_file_name,
                      export_params=True,
                      opset_version=10,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}})  # omit dynamic_axes if batch is fixed at 1

Checking that the ONNX file is valid

def check(model_path):
    try:
        onnx.checker.check_model(model_path)
    except onnx.checker.ValidationError as e:
        print("The model is invalid: %s" % e)
    else:
        print("The model is valid!")

ONNX Runtime inference

def runtime(model_path, input_tensor):
    # CPU inference mode
    ort_session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_tensor)}
    ort_outputs = ort_session.run(None, ort_inputs)
    return ort_outputs

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
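After running the same dummy input through both backends, a common follow-up is to confirm that ONNX Runtime reproduces the PyTorch output within floating-point tolerance. The comparison step itself needs only NumPy; the two arrays below are stand-ins for to_numpy(model(x)) and runtime(model_path, x)[0]:

```python
import numpy as np

# Hypothetical outputs: in practice these come from the PyTorch model
# and the ONNX Runtime session, respectively.
torch_out = np.random.randn(1, 1, 7, 100).astype(np.float32)
ort_out = torch_out + np.random.uniform(-1e-6, 1e-6, torch_out.shape).astype(np.float32)

# Raises AssertionError if any element differs beyond the tolerances.
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-03, atol=1e-05)
print("Exported model matches PyTorch output!")
```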

ONNX visualization

def vis(model_path):
    netron.start(model_path, browse=True)

Full script

import torch
import torch.onnx
import onnx
import onnxruntime
import numpy as np
import netron

from idps_train import SimpleConvAE

def torch2onnx(onnx_name, pth_path, model, batch_size):
    onnx_file_name = onnx_name

    # Load the trained weights and switch to inference mode.
    model.load_state_dict(torch.load(pth_path))
    model.eval()

    # Dummy input used to trace the model; shape (batch, 7, 100).
    dummy_input = torch.randn(batch_size, 7, 100, requires_grad=True)
    #output = model(dummy_input)
    #print(output.shape)

    torch.onnx.export(model,
                      dummy_input,
                      onnx_file_name,
                      export_params=True,
                      opset_version=10,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}})

def check(model_path):
    try:
        onnx.checker.check_model(model_path)
    except onnx.checker.ValidationError as e:
        print("The model is invalid: %s" % e)
    else:
        print("The model is valid!")

def runtime(model_path, input_tensor):
    # CPU inference mode
    ort_session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_tensor)}
    ort_outputs = ort_session.run(None, ort_inputs)
    return ort_outputs

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def vis(model_path):
    netron.start(model_path, browse=True)

if __name__ == '__main__':
    pth_path = "***.pth"
    model_path = "***.onnx"     # path of the ONNX model
    onnx_name = "***.onnx"      # target ONNX file name
    model = SimpleConvAE()

    batch_size = 1
    input_tensor = torch.randn(batch_size, 7, 100, requires_grad=True)

    torch2onnx(onnx_name, pth_path, model, batch_size)
    check(model_path)
    output = runtime(model_path, input_tensor)  # output -> list, shape (1, 1, 7, 100)
    #print(np.array(output[0]).shape)
    vis(model_path)