sam2 导出onnx

export_2_onnx2.py

import numpy as np
import torch
import torch.nn as nn
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


import threading

import numpy as np
import torch
import os
import sys
import glob
import json
import os
import time
import numpy as np
import cv2
import torch

import gc

current_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_dir)
print('current_dir', current_dir)
paths = [current_dir, current_dir + '/../']
paths.append(current_dir + '/../sam2')
for path in paths:
    sys.path.insert(0, path)
    os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')


from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


class SAM2ImageApi:
    def __init__(self, device: str = "cuda:0"):
        self.device = device
        self.model_configs = {}
        self.predictors = {}
        self.color = [(255, 0, 0)]
        if sys.platform.startswith('win'):
            self.model_path = r"D:\data\models\sam2.1_hiera_large.pt"
        else:
            self.model_path = "/home/common/models/seg/sam2.1_hiera_large.pt"

        self.initialize_model()

    def determine_model_cfg(self, model_path: str):
        if "large" in model_path:
            return "configs/samurai/sam2.1_hiera_l.yaml"
        elif "base_plus" in model_path:
            return "configs/samurai/sam2.1_hiera_b+.yaml"
        elif "small" in model_path:
            return "configs/samurai/sam2.1_hiera_s.yaml"
        else:
            raise ValueError("Unknown model size in path!")

    def initialize_model(self):
        """加载并缓存模型"""
        if self.model_path in self.predictors:
            return self.predictors[self.model_path]

        model_cfg = self.determine_model_cfg(self.model_path)

        self.img_predictor = SAM2ImagePredictor(
            build_sam2(model_cfg, self.model_path, device=self.device)
        )

    def img_segmentation(self, params):
        """执行图像分割(单线程安全版)"""
        self.img_predictor.set_image(params.frame)

        if params.mode == "points":
            point_coords = torch.Tensor(params.points_coords).unsqueeze(0) if params.points_coords is not None else None
            point_labels = torch.Tensor(params.points_labels).unsqueeze(0) if params.points_labels is not None else None
            box = None
        elif params.mode == "box":
            point_coords = None
            point_labels = None
            box = torch.Tensor(params.box) if params.box is not None else None

        masks, scores, _ = self.img_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            multimask_output=True
        )

        #  锁释放后再处理 CPU 的部分
        best_mask_idx = np.argmax(scores)
        mask = masks[best_mask_idx]
        score = scores[best_mask_idx]

        mask_uint8 = (mask * 255).astype(np.uint8)
        contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        max_contour = max(contours, key=cv2.contourArea)

        # 面积过滤,比如最小100
        if cv2.contourArea(max_contour) > 100:
            epsilon = 0.002 * cv2.arcLength(max_contour, True)
            approx = cv2.approxPolyDP(max_contour, epsilon, True)

            # 转换为点列表
            points = approx.reshape(-1, 2).tolist()
        else:
            points = []
        return points, score


import torch
import torch.nn as nn
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


class SAM2ImageEncoder(nn.Module):
    """图像编码器部分 - 用于提取图像特征"""

    def __init__(self, sam2_model):
        super().__init__()
        self.image_encoder = sam2_model.image_encoder

    def forward(self, x):
        # x: [batch_size, 3, height, width]
        return self.image_encoder(x)


class SAM2PromptDecoder(nn.Module):
    """提示解码器部分 - 使用图像特征和提示生成掩码"""

    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor
        # 修正属性名
        self.mask_decoder = predictor.model.sam_mask_decoder
        self.prompt_encoder = predictor.model.sam_prompt_encoder

    def forward(self, image_embeddings, point_coords, point_labels, box=None):
        """
        image_embeddings: 图像特征 [batch_size, embedding_dim, H, W]
        point_coords: 点坐标 [batch_size, num_points, 2]
        point_labels: 点标签 [batch_size, num_points]
        box: 边界框 [batch_size, 4] 或 None
        """
        # 编码提示
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=(point_coords, point_labels) if point_coords is not None else None,
            boxes=box,
            masks=None,
        )

        # 低分辨率掩码预测
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )

        return low_res_masks, iou_predictions


class SAM2CompleteModel(nn.Module):
    """完整模型 - 图像编码 + 提示解码"""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, image, point_coords, point_labels, box=None):
        image_embeddings = self.encoder(image)
        masks, scores = self.decoder(image_embeddings, point_coords, point_labels, box)
        return masks, scores


def export_sam2_onnx():
    """导出 SAM2 模型为 ONNX 格式"""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. 加载原始模型
    model_cfg = "configs/samurai/sam2.1_hiera_l.yaml"  # 根据您的实际路径调整
    model_path = r"D:\data\models\sam2.1_hiera_large.pt"  # 您的模型路径

    sam2_model = build_sam2(model_cfg, model_path, device=device)
    predictor = SAM2ImagePredictor(sam2_model)

    # 2. 首先探索模型结构,了解可用的组件
    print("探索模型结构...")
    print("模型类型:", type(sam2_model))
    print("模型属性:")
    for name in dir(sam2_model):
        if not name.startswith('_'):
            print(f"  {name}: {type(getattr(sam2_model, name))}")

    # 3. 创建导出模型 - 只导出图像编码器(最稳定的部分)
    print("导出图像编码器...")
    image_encoder = SAM2ImageEncoder(sam2_model)
    image_encoder.eval()

    dummy_image = torch.randn(1, 3, 1024, 1024, device=device)

    torch.onnx.export(
        image_encoder,
        dummy_image,
        "sam2_image_encoder.onnx",
        input_names=["image"],
        output_names=["image_embeddings"],
        dynamic_axes={
            "image": {0: "batch_size", 2: "height", 3: "width"},
            "image_embeddings": {0: "batch_size", 2: "height", 3: "width"}
        },
        opset_version=17,
        do_constant_folding=True
    )
    print("图像编码器导出完成!")

    # 4. 尝试导出完整的提示解码部分
    try:
        print("尝试导出提示解码器...")
        prompt_decoder = SAM2PromptDecoder(predictor)
        prompt_decoder.eval()

        # 创建测试输入
        dummy_embedding = image_encoder(dummy_image)
        dummy_points = torch.tensor([[[500, 500], [600, 600]]], dtype=torch.float32, device=device)
        dummy_labels = torch.tensor([[1, 0]], dtype=torch.float32, device=device)

        # 测试前向传播
        with torch.no_grad():
            masks, scores = prompt_decoder(dummy_embedding, dummy_points, dummy_labels)
            print(f"掩码形状: {masks.shape}, 分数形状: {scores.shape}")

        # 导出提示解码器
        torch.onnx.export(
            prompt_decoder,
            (dummy_embedding, dummy_points, dummy_labels, None),
            "sam2_prompt_decoder.onnx",
            input_names=["image_embeddings", "point_coords", "point_labels", "box"],
            output_names=["masks", "scores"],
            dynamic_axes={
                "image_embeddings": {0: "batch_size"},
                "point_coords": {0: "batch_size", 1: "num_points"},
                "point_labels": {0: "batch_size", 1: "num_points"},
                "masks": {0: "batch_size"},
                "scores": {0: "batch_size"}
            },
            opset_version=17,
            do_constant_folding=True
        )
        print("提示解码器导出完成!")

        # 5. 导出完整模型
        print("导出完整模型...")
        complete_model = SAM2CompleteModel(image_encoder, prompt_decoder)
        complete_model.eval()

        torch.onnx.export(
            complete_model,
            (dummy_image, dummy_points, dummy_labels, None),
            "sam2_complete_model.onnx",
            input_names=["image", "point_coords", "point_labels", "box"],
            output_names=["masks", "scores"],
            dynamic_axes={
                "image": {0: "batch_size"},
                "point_coords": {0: "batch_size", 1: "num_points"},
                "point_labels": {0: "batch_size", 1: "num_points"},
                "masks": {0: "batch_size"},
                "scores": {0: "batch_size"}
            },
            opset_version=17,
            do_constant_folding=True
        )
        print("完整模型导出完成!")

    except Exception as e:
        print(f"导出提示解码器或完整模型时出错: {e}")
        print("只导出了图像编码器,这是最稳定的部分")

    print("ONNX 导出过程完成!")


# 简化的导出方法 - 只导出图像编码器
def export_simple_encoder():
    """只导出图像编码器,这是最稳定的部分"""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 加载模型
    model_cfg = "configs/samurai/sam2.1_hiera_l.yaml"
    model_path = r"D:\data\models\sam2.1_hiera_large.pt"

    sam2_model = build_sam2(model_cfg, model_path, device=device)

    # 创建并导出图像编码器
    image_encoder = SAM2ImageEncoder(sam2_model)
    image_encoder.eval()

    dummy_image = torch.randn(1, 3, 1024, 1024, device=device)

    torch.onnx.export(
        image_encoder,
        dummy_image,
        "sam2_image_encoder_simple.onnx",
        input_names=["image"],
        output_names=["image_embeddings"],
        opset_version=17,
        do_constant_folding=True
    )
    print("简单图像编码器导出完成!")


# 运行导出
if __name__ == "__main__":
    # 先尝试完整导出
    export_sam2_onnx()

    # 如果完整导出失败,只导出图像编码器
    # export_simple_encoder()

SAM2模型转换为ONNX格式是一个相对技术性的过程,涉及多个步骤,包括模型加载、结构解析、转换以及最终的验证。以下是实现这一目标的一般性指导流程: ### 准备工作 在开始之前,请确保已安装必要的依赖库,包括PyTorch、ONNX、以及SAM2模型相关的代码和权重文件。可以通过以下命令安装基础依赖: ```bash pip install torch onnx ``` ### 模型导出步骤 1. **加载SAM2模型** 根据官方文档或代码库加载预训练的SAM2模型。通常,模型的加载可以通过如下方式完成: ```python from sam2 import build_sam2 # 指定模型配置文件和权重路径 model_cfg = "sam2_hiera_l.yaml" checkpoint = "sam2_hiera_large.pt" # 构建并加载模型 sam2_model = build_sam2(model_cfg, checkpoint) ``` 2. **模型设置** 将模型切换为评估模式,并确保所有必要的组件(如图像编码器和提示编码器)都已正确配置。 ```python sam2_model.eval() ``` 3. **定义输入张量** 根据模型的输入要求定义一个或多个示例输入张量。对于图像编码器,通常需要一个表示图像的张量。 ```python import torch # 假设输入图像尺寸为 3x1024x1024 dummy_input = torch.randn(1, 3, 1024, 1024) ``` 4. **导出ONNX格式** 使用`torch.onnx.export`函数将模型导出ONNX格式。注意需要指定输入名称、输出名称以及动态轴(如果需要支持动态输入尺寸)。 ```python import torch.onnx # 导出模型 torch.onnx.export( sam2_model, dummy_input, "sam2_model.onnx", export_params=True, # 存储训练参数 opset_version=13, # ONNX算子集版本 do_constant_folding=True, # 优化常量 input_names=['input'], # 模型输入名称 output_names=['output'], # 模型输出名称 dynamic_axes={ 'input': {0: 'batch_size'}, # 动态维度 'output': {0: 'batch_size'} } ) ``` 5. **验证ONNX模型** 使用ONNX运行时加载并运行模型,以验证导出是否成功。 ```python import onnx import onnxruntime as ort # 加载ONNX模型 onnx_model = onnx.load("sam2_model.onnx") onnx.checker.check_model(onnx_model) # 创建推理会话 ort_session = ort.InferenceSession("sam2_model.onnx") # 运行模型 outputs = ort_session.run( None, {'input': dummy_input.numpy()} ) ``` ### 注意事项 - **模型结构复杂性**:SAM2模型可能包含复杂的控制流或条件分支,这可能需要额外的处理才能成功导出。 - **自定义操作**:如果模型中包含自定义操作或未被ONNX支持的操作,可能需要编写自定义转换逻辑或使用替代实现。 - **性能优化**:导出后可以使用ONNX Runtime或其他工具对模型进行进一步的优化,以提升推理速度。 ###
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法网奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值