ONNX加载和保存模型

ONNX

ONNX(Open Neural Network Exchange)是一个开放的格式,用于表示机器学习模型。它使得不同框架之间的模型可以互操作,方便模型的迁移和部署。以下是一些关于 ONNX 的基本介绍和使用方法。

在这里插入图片描述

  1. 模型转换:ONNX 允许你将模型从一个深度学习框架(如 PyTorch、TensorFlow)转换为 ONNX 格式。
  2. 互操作性:ONNX 模型可以在支持 ONNX 的不同平台和工具之间共享。
  3. 优化:ONNX 提供了工具来优化模型,以提高推理性能。

将模型转换为 ONNX 格式

以下是将 PyTorch 模型转换为 ONNX 模型的步骤:

  1. 安装 ONNX

安装了 ONNX 和相关的转换工具:

pip install onnx
pip install onnxruntime  # 用于运行 ONNX 模型
pip install torch  # PyTorch
  1. 转换 PyTorch 模型

一个已训练的 PyTorch 模型,可以使用以下代码将其转换为 ONNX 格式:

import torch
import torch.onnx
import torchvision.models as models

# 加载预训练的 PyTorch 模型
model = models.resnet18(pretrained=True)
model.eval()  # 设置模型为推理模式

# 创建示例输入张量
dummy_input = torch.randn(1, 3, 224, 224)

# 将模型导出为 ONNX 格式
torch.onnx.export(model, dummy_input, "resnet18.onnx", verbose=True)

在这个示例中,将一个预训练的 ResNet-18 模型转换为 ONNX 格式并保存为 resnet18.onnx 文件。

加载和运行 ONNX 模型

使用 ONNX Runtime 来加载和运行转换后的 ONNX 模型:

import onnx
import onnxruntime as ort
import numpy as np

# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)  # 检查模型是否有效

# 创建 ONNX Runtime 会话
ort_session = ort.InferenceSession("resnet18.onnx")

# 创建输入数据
dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 运行模型
outputs = ort_session.run(None, {"input": dummy_input})
print(outputs[0])

检查和优化 ONNX 模型

ONNX 提供了一些工具来检查和优化模型:

1. 检查模型

使用 onnx.checker 来验证模型的有效性:

import onnx

onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)

2. 优化模型

使用 onnx.optimizer 来优化模型:

import onnx
import onnx.optimizer

onnx_model = onnx.load("resnet18.onnx")

# 定义优化通道
passes = ["fuse_consecutive_transposes", "eliminate_deadend"]

# 优化模型
optimized_model = onnx.optimizer.optimize(onnx_model, passes)

# 保存优化后的模型
onnx.save(optimized_model, "resnet18_optimized.onnx")

其他常用工具和库

  • Netron:用于可视化 ONNX 模型的工具。可以下载并使用 Netron 打开 .onnx 文件进行模型可视化。
  • ONNX Model Zoo:ONNX 模型库,包含许多预训练的 ONNX 模型,可以直接下载和使用。

小结

ONNX 作为一个开放的模型格式,可以极大地提高模型在不同框架和平台之间的可移植性。通过学习如何将模型转换为 ONNX 格式,并使用 ONNX Runtime 进行推理和优化,你可以更高效地部署和管理你的机器学习模型。


只有一个元素的时候才能够使用item()转为scalar,无论是一个0维度张量,还是1维张量,还是2维度

x_t = torch.tensor([1.0])
x2_t =torch.tensor(1.0)
x4_t = torch.tensor([[[1.0]]])
  
x_n = x_t.item()       # 1.0
x2_n = x2_t.item()     # 1.0
x3_n = x3_t.item()     # 1.0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值