crnn-pth转换成onnx(pth输入输出变长转换)

注意事项

1.首先torch1.1.0是没有dynamic_axes这个开关的,所以我升级到了1.14.0。
2.其次如果是变长的话,pth模型找不到input和output的node-name,可以自己随意定义。
3.weights和data必须同样的dtype,可以是cpu也可以是gpu,如果多gpu训练的model,去掉module即可。
4.dynamic_axes找到当前变长的维度,用list或者dict表示,list自动分配名字,很方便。

import torch
import torchvision
import onnx
from crnn.models import crnn as crnn
from collections import OrderedDict

from crnn import keys
from config import ocrModel, LSTMFLAG, GPU


alphabet = keys.alphabetChinese
# LSTMFLAG=True crnn 否则 dense ocr
model = crnn.CRNN(32, 1, len(alphabet) + 1, 256,
                    1, lstmFlag=True)

state_dict = torch.load(ocrModel,map_location='cpu')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k.replace('module.', '')  # remove `module.`
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
model.eval()

input_names = ["input"]
output_names = ["output"]
 #batch,channel,h,w

image = torch.zeros(1, 1, 32, 32)
inputs = image

#w->w*32/h

dynamic_axes = {"input":[3],"output":[0]}
torch.onnx.export(model, inputs, './oup_dingchang.onnx', 
                    export_params=True, 
                    verbose=False, 
                    input_names=input_names, 
                    output_names=output_names, 
                    dynamic_axes=dynamic_axes,
                    opset_version=9)
### 将CRNN模型转换ONNX格式的方法 要将PyTorch中的CRNN模型转换ONNX格式,可以通过`torch.onnx.export()`函数实现。以下是具体方法以及注意事项: #### 1. 导入必要的库并加载预训练模型 在开始之前,需要导入所需的库,并确保已加载经过训练的CRNN模型实例。 ```python import torch from torchvision import models # 假设 CRNN 是自定义类,num_classes 表示类别数量 class CRNN(torch.nn.Module): def __init__(self, num_classes): super(CRNN, self).__init__() # 定义网络结构... def forward(self, x): # 实现前向传播逻辑... # 加载预训练权重 model = CRNN(num_classes=37) # 示例:假设字符集大小为37 model.load_state_dict(torch.load("crnn.pth")) # 替换为你自己的路径 model.eval() ``` 上述代码片段展示了如何创建一个CRNN模型实例,并将其设置为评估模式[^4]。 --- #### 2. 准备虚拟输入张量 为了导出ONNX模型,需要准备一个与实际输入形状一致的虚拟张量作为测试数据。 ```python dummy_input = torch.randn(1, 3, 32, 100) # 输入尺寸应匹配CRNN的要求 (batch_size, channels, height, width) ``` 此处的输入维度 `(1, 3, 32, 100)` 需根据具体的CRNN模型配置调整。如果不确定,请查阅原始论文或源码文档[^1]。 --- #### 3. 使用 `torch.onnx.export` 方法导出模型 通过调用 `torch.onnx.export` 方法,可将 PyTorch 模型保存为 ONNX 文件。 ```python output_onnx_file = "/path/to/crnn.onnx" # 设置目标文件路径 torch.onnx.export( model, dummy_input, output_onnx_file, input_names=["input"], # 输入节点名称 output_names=["output"], # 输出节点名称 dynamic_axes={"input": {0: "batch_size", 2: "height", 3: "width"}, "output": {0: "batch_size"}}, # 动态轴支持 opset_version=11 # ONNX 版本号 ) print(f"Model has been converted to ONNX and saved at {output_onnx_file}") ``` 此部分代码实现了模型到ONNX格式的转换过程。其中: - `dynamic_axes` 参数允许指定动态批量处理能力。 - `opset_version` 应选择兼容的目标版本(推荐使用最新稳定版)[^2]。 --- #### 4. 验证生成的ONNX模型 完成导出后,建议验证生成的ONNX文件是否有效。 ```python import onnx onnx_model = onnx.load(output_onnx_file) onnx.checker.check_model(onnx_model) print("The model is checked successfully!") ``` 这一步骤有助于确认模型无误,便于后续部署至其他框架环境。 --- #### 5. 推理阶段注意事项 当利用ONNX Runtime执行推理时,需注意以下几点: - **输入类型**:确保输入张量的数据类型严格遵循 `.float32` 要求; - **Dropout 层行为**:由于 Dropout 在推断过程中会被禁用,因此无需显式移除这些层。 --- ### 总结 以上流程涵盖了从加载模型、准备输入张量到最终导出及验证的过程。每一步均附带详细说明以帮助理解操作细节及其背后原理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值