# export_2_onnx2.py
# Export SAM2 (image encoder / prompt decoder / complete model) to ONNX.
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
# Make the local sam2 checkout importable before importing from it.
current_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_dir)
print('current_dir', current_dir)

paths = [current_dir, current_dir + '/../', current_dir + '/../sam2']
for path in paths:
    sys.path.insert(0, path)
    os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


class SAM2ImageApi:
    def __init__(self, device: str = "cuda:0"):
        self.device = device
        self.model_configs = {}
        self.predictors = {}
        self.color = [(255, 0, 0)]
        if sys.platform.startswith('win'):
            self.model_path = r"D:\data\models\sam2.1_hiera_large.pt"
        else:
            self.model_path = "/home/common/models/seg/sam2.1_hiera_large.pt"
        self.initialize_model()
    def determine_model_cfg(self, model_path: str):
        if "large" in model_path:
            return "configs/samurai/sam2.1_hiera_l.yaml"
        elif "base_plus" in model_path:
            return "configs/samurai/sam2.1_hiera_b+.yaml"
        elif "small" in model_path:
            return "configs/samurai/sam2.1_hiera_s.yaml"
        else:
            raise ValueError("Unknown model size in path!")
    def initialize_model(self):
        """Load the model and cache the predictor per model path."""
        if self.model_path in self.predictors:
            self.img_predictor = self.predictors[self.model_path]
            return self.img_predictor
        model_cfg = self.determine_model_cfg(self.model_path)
        self.img_predictor = SAM2ImagePredictor(
            build_sam2(model_cfg, self.model_path, device=self.device)
        )
        self.predictors[self.model_path] = self.img_predictor
        return self.img_predictor
    def img_segmentation(self, params):
        """Run image segmentation (single-thread-safe version)."""
        self.img_predictor.set_image(params.frame)
        if params.mode == "points":
            point_coords = torch.Tensor(params.points_coords).unsqueeze(0) if params.points_coords is not None else None
            point_labels = torch.Tensor(params.points_labels).unsqueeze(0) if params.points_labels is not None else None
            box = None
        elif params.mode == "box":
            point_coords = None
            point_labels = None
            box = torch.Tensor(params.box) if params.box is not None else None
        else:
            raise ValueError(f"Unknown segmentation mode: {params.mode!r}")
        masks, scores, _ = self.img_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            multimask_output=True
        )
        # CPU post-processing (done after any GPU lock would be released).
        best_mask_idx = np.argmax(scores)
        mask = masks[best_mask_idx]
        score = scores[best_mask_idx]
        mask_uint8 = (mask * 255).astype(np.uint8)
        contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        points = []
        if contours:
            max_contour = max(contours, key=cv2.contourArea)
            # Area filter: drop tiny blobs (minimum 100 px^2 here).
            if cv2.contourArea(max_contour) > 100:
                epsilon = 0.002 * cv2.arcLength(max_contour, True)
                approx = cv2.approxPolyDP(max_contour, epsilon, True)
                # Flatten the simplified contour into a list of [x, y] points.
                points = approx.reshape(-1, 2).tolist()
        return points, score
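

# A minimal usage sketch for SAM2ImageApi, assuming the caller's `params`
# object exposes frame / mode / points_coords / points_labels / box attributes
# (SimpleNamespace is a hypothetical stand-in for the real request object):
def _example_img_segmentation(frame):
    from types import SimpleNamespace
    api = SAM2ImageApi(device="cuda:0" if torch.cuda.is_available() else "cpu")
    params = SimpleNamespace(
        frame=frame,                 # H x W x 3 RGB image as a numpy array
        mode="points",               # or "box"
        points_coords=[[500, 500]],  # one foreground click
        points_labels=[1],           # 1 = foreground, 0 = background
        box=None,
    )
    polygon, score = api.img_segmentation(params)
    return polygon, score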


class SAM2ImageEncoder(nn.Module):
    """Image encoder wrapper - extracts image features."""

    def __init__(self, sam2_model):
        super().__init__()
        self.image_encoder = sam2_model.image_encoder

    def forward(self, x):
        # x: [batch_size, 3, height, width]
        # Note: depending on the SAM2 version, the image encoder may return a
        # dict of backbone outputs rather than a single tensor, in which case
        # the ONNX export below may need to select a specific tensor.
        return self.image_encoder(x)


class SAM2PromptDecoder(nn.Module):
    """Prompt decoder wrapper - generates masks from image features and prompts."""

    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor
        # Use the correct attribute names on the underlying SAM2 model.
        self.mask_decoder = predictor.model.sam_mask_decoder
        self.prompt_encoder = predictor.model.sam_prompt_encoder

    def forward(self, image_embeddings, point_coords, point_labels, box=None):
        """
        image_embeddings: image features [batch_size, embedding_dim, H, W]
        point_coords:     point coordinates [batch_size, num_points, 2]
        point_labels:     point labels [batch_size, num_points]
        box:              bounding box [batch_size, 4] or None
        """
        # Encode the prompts.
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=(point_coords, point_labels) if point_coords is not None else None,
            boxes=box,
            masks=None,
        )
        # Low-resolution mask prediction. Note: depending on the SAM2 version,
        # the mask decoder may expect extra arguments (e.g. repeat_image,
        # high_res_features) and return extra outputs; the export path below
        # wraps this call in try/except for that reason.
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        return low_res_masks, iou_predictions


class SAM2CompleteModel(nn.Module):
    """Complete model - image encoding + prompt decoding."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, image, point_coords, point_labels, box=None):
        image_embeddings = self.encoder(image)
        masks, scores = self.decoder(image_embeddings, point_coords, point_labels, box)
        return masks, scores


def export_sam2_onnx():
    """Export the SAM2 model to ONNX format."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Load the original model.
    model_cfg = "configs/samurai/sam2.1_hiera_l.yaml"  # adjust to your actual config path
    model_path = r"D:\data\models\sam2.1_hiera_large.pt"  # adjust to your model path
    sam2_model = build_sam2(model_cfg, model_path, device=device)
    predictor = SAM2ImagePredictor(sam2_model)

    # 2. First, explore the model structure to see which components are available.
    print("Exploring model structure...")
    print("Model type:", type(sam2_model))
    print("Model attributes:")
    for name in dir(sam2_model):
        if not name.startswith('_'):
            print(f"  {name}: {type(getattr(sam2_model, name))}")

    # 3. Create the export model - export only the image encoder (the most stable part).
    print("Exporting image encoder...")
    image_encoder = SAM2ImageEncoder(sam2_model)
    image_encoder.eval()
    dummy_image = torch.randn(1, 3, 1024, 1024, device=device)
    torch.onnx.export(
        image_encoder,
        dummy_image,
        "sam2_image_encoder.onnx",
        input_names=["image"],
        output_names=["image_embeddings"],
        dynamic_axes={
            "image": {0: "batch_size", 2: "height", 3: "width"},
            "image_embeddings": {0: "batch_size", 2: "height", 3: "width"}
        },
        opset_version=17,
        do_constant_folding=True
    )
    print("Image encoder export complete!")
    # 4. Try to export the full prompt-decoding part.
    try:
        print("Trying to export prompt decoder...")
        prompt_decoder = SAM2PromptDecoder(predictor)
        prompt_decoder.eval()

        # Create test inputs.
        dummy_embedding = image_encoder(dummy_image)
        dummy_points = torch.tensor([[[500, 500], [600, 600]]], dtype=torch.float32, device=device)
        dummy_labels = torch.tensor([[1, 0]], dtype=torch.float32, device=device)

        # Test the forward pass.
        with torch.no_grad():
            masks, scores = prompt_decoder(dummy_embedding, dummy_points, dummy_labels)
            print(f"Mask shape: {masks.shape}, score shape: {scores.shape}")

        # Export the prompt decoder.
        torch.onnx.export(
            prompt_decoder,
            (dummy_embedding, dummy_points, dummy_labels, None),
            "sam2_prompt_decoder.onnx",
            input_names=["image_embeddings", "point_coords", "point_labels", "box"],
            output_names=["masks", "scores"],
            dynamic_axes={
                "image_embeddings": {0: "batch_size"},
                "point_coords": {0: "batch_size", 1: "num_points"},
                "point_labels": {0: "batch_size", 1: "num_points"},
                "masks": {0: "batch_size"},
                "scores": {0: "batch_size"}
            },
            opset_version=17,
            do_constant_folding=True
        )
        print("Prompt decoder export complete!")

        # 5. Export the complete model.
        print("Exporting complete model...")
        complete_model = SAM2CompleteModel(image_encoder, prompt_decoder)
        complete_model.eval()
        torch.onnx.export(
            complete_model,
            (dummy_image, dummy_points, dummy_labels, None),
            "sam2_complete_model.onnx",
            input_names=["image", "point_coords", "point_labels", "box"],
            output_names=["masks", "scores"],
            dynamic_axes={
                "image": {0: "batch_size"},
                "point_coords": {0: "batch_size", 1: "num_points"},
                "point_labels": {0: "batch_size", 1: "num_points"},
                "masks": {0: "batch_size"},
                "scores": {0: "batch_size"}
            },
            opset_version=17,
            do_constant_folding=True
        )
        print("Complete model export complete!")
    except Exception as e:
        print(f"Error exporting prompt decoder or complete model: {e}")
        print("Only the image encoder was exported; it is the most stable part.")

    print("ONNX export finished!")


# Simplified export - image encoder only.
def export_simple_encoder():
    """Export only the image encoder, the most stable part."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the model.
    model_cfg = "configs/samurai/sam2.1_hiera_l.yaml"
    model_path = r"D:\data\models\sam2.1_hiera_large.pt"
    sam2_model = build_sam2(model_cfg, model_path, device=device)

    # Create and export the image encoder.
    image_encoder = SAM2ImageEncoder(sam2_model)
    image_encoder.eval()
    dummy_image = torch.randn(1, 3, 1024, 1024, device=device)
    torch.onnx.export(
        image_encoder,
        dummy_image,
        "sam2_image_encoder_simple.onnx",
        input_names=["image"],
        output_names=["image_embeddings"],
        opset_version=17,
        do_constant_folding=True
    )
    print("Simple image encoder export complete!")


# Run the export.
if __name__ == "__main__":
    # Try the full export first.
    export_sam2_onnx()
    # If the full export fails, export only the image encoder:
    # export_simple_encoder()