我现在要把SAM模型pt文件转onnx格式,代码如下:import torch
import numpy as np
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel
import sys
import os
# ===== 配置部分 =====
checkpoint_path = r"E:\ketizu\XAnylabeling\sam2.1_hiera_base_plus.pt" # 模型权重路径
onnx_output_path = r"E:\ketizu\XAnylabeling\sam_model.onnx" # ONNX输出路径
model_types = ["vit_b", "vit_l", "vit_h", "vit_t"] # 尝试的模型类型列表
# ===================
def inspect_model(sam):
"""检查模型结构信息"""
print("\n" + "=" * 50)
print("模型结构检查:")
print(f"模型类型: {type(sam).__name__}")
print(f"图像编码器输入尺寸: {sam.image_encoder.img_size}")
# 检查模型组件
print("\n模型组件:")
print(f"图像编码器: {type(sam.image_encoder).__name__}")
print(f"提示编码器: {type(sam.prompt_encoder).__name__}")
print(f"掩码解码器: {type(sam.mask_decoder).__name__}")
# 检查掩码解码器的输入尺寸
if hasattr(sam.mask_decoder, 'transformer'):
transformer = sam.mask_decoder.transformer
print(f"Transformer尺寸: {transformer.num_multimask_outputs}")
print("=" * 50 + "\n")
def export_sam_to_onnx():
# 加载模型权重
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# 调试:打印权重文件内容
print("权重文件中的键:", list(state_dict.keys()))
if "model" in state_dict:
print("'model'键中的前5个键:", list(state_dict["model"].keys())[:5])
state_dict = state_dict["model"]
# 尝试不同模型类型
for model_type in model_types:
try:
print(f"\n{'=' * 40}")
print(f"尝试模型类型: {model_type}")
print('=' * 40)
# 构建 SAM 模型
sam = sam_model_registry[model_type]()
# 加载权重并处理不匹配
result = sam.load_state_dict(state_dict, strict=False)
print(f"匹配的键数量: {len(result.missing_keys)}")
print(f"不匹配的键数量: {len(result.unexpected_keys)}")
print(f"缺失的键: {result.missing_keys[:3]}...")
print(f"多余的键: {result.unexpected_keys[:3]}...")
# 检查模型结构
inspect_model(sam)
# 设置模型为评估模式
sam.eval()
# 获取实际图像尺寸
image_size = sam.image_encoder.img_size
print(f"使用图像尺寸: {image_size}x{image_size}")
# 包装为 ONNX 兼容的模型
onnx_model = SamOnnxModel(sam, return_single_mask=True)
# 动态轴配置
dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
}
# 创建匹配的虚拟输入
dummy_input = {
"image": torch.randn(1, 3, image_size, image_size).float(),
"point_coords": torch.randint(low=0, high=image_size, size=(1, 5, 2)).float(),
"point_labels": torch.randint(low=0, high=4, size=(1, 5)).float(),
"mask_input": torch.randn(1, 1, image_size // 4, image_size // 4).float(),
"has_mask_input": torch.tensor([1]).float(), # 设置为1表示有掩码输入
"orig_im_size": torch.tensor([image_size, image_size]).float(),
}
# 测试前向传播
with torch.no_grad():
print("\n测试前向传播...")
outputs = onnx_model(*tuple(dummy_input.values()))
print("前向传播成功! 输出形状:")
for i, out in enumerate(outputs):
print(f"输出 {i + 1}: {out.shape}")
# 导出ONNX模型
print("\n开始导出ONNX模型...")
torch.onnx.export(
onnx_model,
tuple(dummy_input.values()),
onnx_output_path,
export_params=True,
verbose=True, # 启用详细输出
opset_version=17,
do_constant_folding=True,
input_names=list(dummy_input.keys()),
output_names=["masks", "iou_predictions", "low_res_masks"],
dynamic_axes=dynamic_axes,
)
print(f"\n✅ ONNX模型已成功导出至: {onnx_output_path}")
print(f"使用的模型类型: {model_type}")
return # 成功则退出函数
except KeyError:
print(f"⚠️ 模型类型 {model_type} 不在注册表中,尝试下一个...")
except RuntimeError as e:
print(f"⚠️ 运行时错误: {str(e)}")
print(f"尝试下一个模型类型...")
except Exception as e:
print(f"⚠️ 发生错误: {str(e)}")
import traceback
traceback.print_exc()
print("\n❌ 所有模型类型尝试失败!")
print("可能原因:")
print("1. 模型权重与标准SAM架构不兼容")
print("2. 模型是自定义变体,需要特殊处理")
print("3. 权重文件可能已损坏")
print("建议检查模型来源文档或联系模型提供者")
if __name__ == "__main__":
# 添加项目路径到系统路径
project_root = r"E:\ketizu\XAnylabeling"
sys.path.append(project_root)
# 确保输出目录存在
os.makedirs(os.path.dirname(onnx_output_path), exist_ok=True)
export_sam_to_onnx()
报错E:\python3.11\python.exe C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py
权重文件中的键: ['model']
'model'键中的前5个键: ['maskmem_tpos_enc', 'no_mem_embed', 'no_mem_pos_enc', 'no_obj_ptr', 'no_obj_embed_spatial']
========================================
尝试模型类型: vit_b
========================================
Traceback (most recent call last):
匹配的键数量: 314
不匹配的键数量: 615
缺失的键: ['image_encoder.pos_embed', 'image_encoder.patch_embed.proj.weight', 'image_encoder.patch_embed.proj.bias']...
多余的键: ['maskmem_tpos_enc', 'no_mem_embed', 'no_mem_pos_enc']...
==================================================
模型结构检查:
模型类型: Sam
图像编码器输入尺寸: 1024
模型组件:
图像编码器: ImageEncoderViT
提示编码器: PromptEncoder
掩码解码器: MaskDecoder
⚠️ 发生错误: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs'
File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 64, in export_sam_to_onnx
inspect_model(sam)
File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 32, in inspect_model
print(f"Transformer尺寸: {transformer.num_multimask_outputs}")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "E:\python3.11\Lib\site-packages\torch\nn\modules\module.py", line 1940, in __getattr__
raise AttributeError(
AttributeError: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs'
========================================
尝试模型类型: vit_l
========================================
Traceback (most recent call last):
File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 64, in export_sam_to_onnx
inspect_model(sam)
File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 32, in inspect_model
print(f"Transformer尺寸: {transformer.num_multimask_outputs}")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "E:\python3.11\Lib\site-packages\torch\nn\modules\module.py", line 1940, in __getattr__
raise AttributeError(
AttributeError: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs'
匹配的键数量: 482
不匹配的键数量: 615
缺失的键: ['image_encoder.pos_embed', 'image_encoder.patch_embed.proj.weight', 'image_encoder.patch_embed.proj.bias']...
多余的键: ['maskmem_tpos_enc', 'no_mem_embed', 'no_mem_pos_enc']...
==================================================
模型结构检查:
模型类型: Sam
图像编码器输入尺寸: 1024
模型组件:
图像编码器: ImageEncoderViT
提示编码器: PromptEncoder
掩码解码器: MaskDecoder
⚠️ 发生错误: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs'
========================================
尝试模型类型: vit_h
========================================
Traceback (most recent call last):
File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 64, in export_sam_to_onnx
inspect_model(sam)
File "C:\Users\12890\Desktop\deepseek_python_20250730_31eda7.py", line 32, in inspect_model
print(f"Transformer尺寸: {transformer.num_multimask_outputs}")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "E:\python3.11\Lib\site-packages\torch\nn\modules\module.py", line 1940, in __getattr__
raise AttributeError(
AttributeError: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs'
匹配的键数量: 594
不匹配的键数量: 615
缺失的键: ['image_encoder.pos_embed', 'image_encoder.patch_embed.proj.weight', 'image_encoder.patch_embed.proj.bias']...
多余的键: ['maskmem_tpos_enc', 'no_mem_embed', 'no_mem_pos_enc']...
==================================================
模型结构检查:
模型类型: Sam
图像编码器输入尺寸: 1024
模型组件:
图像编码器: ImageEncoderViT
提示编码器: PromptEncoder
掩码解码器: MaskDecoder
⚠️ 发生错误: 'TwoWayTransformer' object has no attribute 'num_multimask_outputs'
========================================
尝试模型类型: vit_t
========================================
⚠️ 模型类型 vit_t 不在注册表中,尝试下一个...
❌ 所有模型类型尝试失败!
可能原因:
1. 模型权重与标准SAM架构不兼容
2. 模型是自定义变体,需要特殊处理
3. 权重文件可能已损坏
建议检查模型来源文档或联系模型提供者
Process finished with exit code 0