SAM导出onnx模型报错问题记录

问题记录:

1、Unsupported ONNX opset version: 17

2、Exporting the operator repeat_interleave to ONNX opset version 11 is not supported

系统及版本

win10系统,在Anaconda中为SAM项目创建了单独的虚拟环境 python=3.8,torch=1.8,cuda=10.2

背景知识

在布料瑕疵检测项目中,发现模型文件有.pt,.onnx,.trt,.engine等格式,算法的同事更新模型之后,获得.pt格式的模型,然后导出.onnx模型给到我,我在现场部署时,转换成.trt模型,用于实时推断。
ONNX(Open Neural Network Exchange)是一个开放的格式,旨在使机器学习模型能够在不同框架和平台之间互操作。
ONNX opset(运算符集版本)是指 ONNX 模型格式中所支持的运算符(operators)的版本。它定义了模型导出和导入时可用算子(运算符)的版本和行为。每个 ONNX 模型都有一个关联的 opset 版本,这个版本表示该模型中使用的运算符版本。
了解了ONNX opset之后,上面的2个问题是怎么回事,就有个大致的概念了。

Unsupported ONNX opset version: 17

错误信息表明,环境中的torch版本不能支持ONNX opset=17版本的导出,可以尝试降低ONNX opset版本值,或升级torch版本。
SAM模型对于ONNX opset版本的要求是>=11,默认17.
我的推断模型使用的是torch=1.8,cuda=10.2,不想要升级torch版本,尝试修改opset= 11,12,13,14,15,都未能成功。opset= 11,12,13时,提示问题2:Exporting the operator repeat_interleave to ONNX opset version 11 is not supported

Exporting the operator repeat_interleave to ONNX opset version 11 is not supported

错误信息表明,导出模型过程中,PyTorch的操作符 repeat_interleave在 ONNX opset 版本 12 中不被支持。
解决办法:替换 repeat_interleave 操作
在SAM代码中查找repeat_interleave函数
在这里插入图片描述
发现build、segment_anything文件夹下的mask_decoder.py文件中都有repeat_interleave函数。尝试发现,导出onnx脚本调用的是segment_anything文件夹下的mask_decoder.py。
将图上红框中126、132行代码,替换成127、133行内容,添加打印输出(便于确认变量维度是否正确),就可以成功导出onnx模型了。
repeat_interleave函数的介绍示例如下:

import torch

# 假设 image_pe 的形状是 (2, 3, 4, 4):表示有 2 个样本,每个样本是一个形状为 (3, 4, 4) 的张量(例如 RGB 图像)。
image_pe = torch.randn(2, 3, 4, 4)
# 假设 tokens.shape[0] 为 3
tokens = torch.randn(3, 4)
# 将image_pe 沿 dim=0 维度(样本个数)重复tokens.shape[0]次,对数据进行扩展,以便与其它张量(如 tokens)对齐
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
print(pos_src.shape)  # 输出:torch.Size([6, 3, 4, 4])
# repeat将image_pe 在dim=0 维度上重复 tokens.shape[0] 次
# view将张量展平为需要的形状,-1 代表让 PyTorch 自动计算该维度的大小
pos_src = image_pe.repeat(1, tokens.shape[0], 1, 1).view(-1, image_pe.shape[1], image_pe.shape[2],                                                     image_pe.shape[3])
print(pos_src.shape)  # 输出:torch.Size([6, 3, 4, 4])
我现在要把SAM模型pt文件转onnx格式,代码如下:import torch import numpy as np from segment_anything import sam_model_registry from segment_anything.utils.onnx import SamOnnxModel import sys import os # ===== 配置部分 ===== checkpoint_path = r"E:\ketizu\XAnylabeling\sam2.1_hiera_base_plus.pt" # 模型权重路径 onnx_output_path = r"E:\ketizu\XAnylabeling\sam_model.onnx" # ONNX输出路径 model_types = ["vit_b", "vit_l", "vit_h", "vit_t"] # 尝试的模型类型列表 # =================== def inspect_model(sam): """检查模型结构信息""" print("\n" + "=" * 50) print("模型结构检查:") print(f"模型类型: {type(sam).__name__}") print(f"图像编码器输入尺寸: {sam.image_encoder.img_size}") # 检查模型组件 print("\n模型组件:") print(f"图像编码器: {type(sam.image_encoder).__name__}") print(f"提示编码器: {type(sam.prompt_encoder).__name__}") print(f"掩码解码器: {type(sam.mask_decoder).__name__}") # 检查掩码解码器的输入尺寸 if hasattr(sam.mask_decoder, 'transformer'): transformer = sam.mask_decoder.transformer print(f"Transformer尺寸: {transformer.num_multimask_outputs}") print("=" * 50 + "\n") def export_sam_to_onnx(): # 加载模型权重 state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu')) # 调试:打印权重文件内容 print("权重文件中的键:", list(state_dict.keys())) if "model" in state_dict: print("'model'键中的前5个键:", list(state_dict["model"].keys())[:5]) state_dict = state_dict["model"] # 尝试不同模型类型 for model_type in model_types: try: print(f"\n{'=' * 40}") print(f"尝试模型类型: {model_type}") print('=' * 40) # 构建 SAM 模型 sam = sam_model_registry[model_type]() # 加载权重并处理不匹配 result = sam.load_state_dict(state_dict, strict=False) print(f"匹配的键数量: {len(result.missing_keys)}") print(f"不匹配的键数量: {len(result.unexpected_keys)}") print(f"缺失的键: {result.missing_keys[:3]}...") print(f"多余的键: {result.unexpected_keys[:3]}...") # 检查模型结构 inspect_model(sam) # 设置模型为评估模式 sam.eval() # 获取实际图像尺寸 image_size = sam.image_encoder.img_size print(f"使用图像尺寸: {image_size}x{image_size}") # 包装为 ONNX 兼容的模型 onnx_model = SamOnnxModel(sam, return_single_mask=True) # 动态轴配置 dynamic_axes = { "point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}, } # 创建匹配的虚拟输入 dummy_input = { "image": torch.randn(1, 3, image_size, image_size).float(), "point_coords": torch.randint(low=0, high=image_size, size=(1, 5, 2)).float(), "point_labels": torch.randint(low=0, high=4, size=(1, 5)).float(), "mask_input": torch.randn(1, 1, image_size // 4, image_size // 4).float(), "has_mask_input": torch.tensor([1]).float(), # 设置为1表示有掩码输入 "orig_im_size": torch.tensor([image_size, image_size]).float(), } # 测试前向传播 with torch.no_grad(): print("\n测试前向传播...") outputs = onnx_model(*tuple(dummy_input.values())) print("前向传播成功! 输出形状:") for i, out in enumerate(outputs): print(f"输出 {i + 1}: {out.shape}") # 导出ONNX模型 print("\n开始导出ONNX模型...") torch.onnx.export( onnx_model, tuple(dummy_input.values()), onnx_output_path, export_params=True, verbose=True, # 启用详细输出 opset_version=17, do_constant_folding=True, input_names=list(dummy_input.keys()), output_names=["masks", "iou_predictions", "low_res_masks"], dynamic_axes=dynamic_axes, ) print(f"\n✅ ONNX模型已成功导出: {onnx_output_path}") print(f"使用的模型类型: {model_type}") return # 成功则退出函数 except KeyError: print(f"⚠️ 模型类型 {model_type} 不在注册表中,尝试下一个...") except RuntimeError as e: print(f"⚠️ 运行时错误: {str(e)}") print(f"尝试下一个模型类型...") except Exception as e: print(f"⚠️ 发生错误: {str(e)}") import traceback traceback.print_exc() print("\n❌ 所有模型类型尝试失败!") print("可能原因:") print("1. 模型权重与标准SAM架构不兼容") print("2. 模型是自定义变体,需要特殊处理") print("3. 权重文件可能已损坏") print("建议检查模型来源文档或联系模型提供者") if __name__ == "__main__": # 添加项目路径到系统路径 project_root = r"E:\ketizu\XAnylabeling" sys.path.append(project_root) # 确保输出目录存在 os.makedirs(os.path.dirname(onnx_output_path), exist_ok=True) export_sam_to_onnx() 报错E:\python3.11\python.exe C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py 权重文件中的键: ['model'] 'model'键中的前5个键: ['maskmem_tpos_enc', 'no_mem_embed', 'no_mem_pos_enc', 'no_obj_ptr', 'no_obj_embed_spatial'] ======================================== 尝试模型类型: vit_b ======================================== Traceback (most recent call last): 匹配的键数量: 314 不匹配的键数量: 615 缺失的键: ['image_encoder.pos_embed', 'image_encoder.patch_embed.proj.weight', 'image_encoder.patch_embed.proj.bias']... 多余的键: ['maskmem_tpos_enc', 'no_mem_embed', 'no_mem_pos_enc']... ================================================== 模型结构检查: 模型类型: Sam 图像编码器输入尺寸: 1024 模型组件: 图像编码器: ImageEncoderViT 提示编码器: PromptEncoder 掩码解码器: MaskDecoder ⚠️ 发生错误: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs' File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 64, in export_sam_to_onnx inspect_model(sam) File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 32, in inspect_model print(f"Transformer尺寸: {transformer.num_multimask_outputs}") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "E:\python3.11\Lib\site-packages\torch\nn\modules\module.py", line 1940, in __getattr__ raise AttributeError( AttributeError: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs' ======================================== 尝试模型类型: vit_l ======================================== Traceback (most recent call last): File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 64, in export_sam_to_onnx inspect_model(sam) File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 32, in inspect_model print(f"Transformer尺寸: {transformer.num_multimask_outputs}") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "E:\python3.11\Lib\site-packages\torch\nn\modules\module.py", line 1940, in __getattr__ raise AttributeError( AttributeError: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs' 匹配的键数量: 482 不匹配的键数量: 615 缺失的键: ['image_encoder.pos_embed', 'image_encoder.patch_embed.proj.weight', 'image_encoder.patch_embed.proj.bias']... 多余的键: ['maskmem_tpos_enc', 'no_mem_embed', 'no_mem_pos_enc']... ================================================== 模型结构检查: 模型类型: Sam 图像编码器输入尺寸: 1024 模型组件: 图像编码器: ImageEncoderViT 提示编码器: PromptEncoder 掩码解码器: MaskDecoder ⚠️ 发生错误: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs' ======================================== 尝试模型类型: vit_h ======================================== Traceback (most recent call last): File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 64, in export_sam_to_onnx inspect_model(sam) File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 32, in inspect_model print(f"Transformer尺寸: {transformer.num_multimask_outputs}") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "E:\python3.11\Lib\site-packages\torch\nn\modules\module.py", line 1940, in __getattr__ raise AttributeError( AttributeError: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs' 匹配的键数量: 594 不匹配的键数量: 615 缺失的键: ['image_encoder.pos_embed', 'image_encoder.patch_embed.proj.weight', 'image_encoder.patch_embed.proj.bias']... 多余的键: ['maskmem_tpos_enc', 'no_mem_embed', 'no_mem_pos_enc']... ================================================== 模型结构检查: 模型类型: Sam 图像编码器输入尺寸: 1024 模型组件: 图像编码器: ImageEncoderViT 提示编码器: PromptEncoder 掩码解码器: MaskDecoder ⚠️ 发生错误: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs' ======================================== 尝试模型类型: vit_t ======================================== ⚠️ 模型类型 vit_t 不在注册表中,尝试下一个... ❌ 所有模型类型尝试失败! 可能原因: 1. 模型权重与标准SAM架构不兼容 2. 模型是自定义变体,需要特殊处理 3. 权重文件可能已损坏 建议检查模型来源文档或联系模型提供者 Process finished with exit code 0
08-01
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值