SAM model inference with ONNX Runtime

SAM is split into two ONNX models here: a ViT image encoder that computes image embeddings once per image, and a lightweight decoder that turns point or box prompts into segmentation masks.

sam.py

import cv2
import onnxruntime
import numpy as np


class SAMImageEncoder:
    def __init__(self, model_path):
        self.session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.input_name = self.session.get_inputs()[0].name
        self.input_shape = self.session.get_inputs()[0].shape  # (1, 3, 1024, 1024) for SAM ViT-B

    def transform(self, img):
        # Normalize with SAM's pixel mean/std, pad to a square, then resize
        # to the encoder's fixed input resolution.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = (img - np.array([123.675, 116.28, 103.53])) / np.array([58.395, 57.12, 57.375])
        h, w, c = img.shape
        size = max(h, w)
        # Pad on the bottom/right so the image sits in the top-left corner.
        img = np.pad(img, ((0, size - h), (0, size - w), (0, 0)), 'constant', constant_values=(0, 0))
        img = cv2.resize(img, tuple(self.input_shape[2:]))  # dsize is (w, h); square here, so order is moot
        img = np.expand_dims(img, axis=0)                              # HWC -> NHWC
        img = np.transpose(img, axes=[0, 3, 1, 2]).astype(np.float32)  # NHWC -> NCHW
        return img

    def run(self, img):
        tensor = self.transform(img)
        # Single output: the image embeddings consumed by the decoder.
        feature = self.session.run(None, {self.input_name: tensor})[0]
        return feature


class SAMImageDecoder:
    def __init__(self, model_path):
        self.session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.img_size = (1024, 1024)  # resolution the prompts are rescaled to
        self.mask_threshold = 0.0     # mask logits above this count as foreground

    def apply_coords(self, coords, original_size, target_length):
        # Rescale (x, y) coordinates from the original image to the
        # longest-side=target_length frame the encoder saw.
        # E.g. for an 800x534 image, scale = 1024 / 800 = 1.28,
        # so the point (200, 500) maps to roughly (256, 640).
        old_h, old_w = original_size
        scale = target_length * 1.0 / max(old_h, old_w)
        new_h, new_w = old_h * scale, old_w * scale
        new_w = int(new_w + 0.5)
        new_h = int(new_h + 0.5)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes, original_size, target_length):
        # A box (x1, y1, x2, y2) is just two corner points.
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size, target_length)
        return boxes.reshape(-1, 4)

    def run(self, img_embeddings, origin_image_size, point_coords, point_labels, boxes):
        # No previous low-res mask is fed back into the decoder.
        mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        has_mask_input = np.zeros(1, dtype=np.float32)

        if len(point_coords) > 0:
            point_coords = np.array(point_coords, dtype=np.float32)
            point_labels = np.array(point_labels, dtype=np.float32)
            point_coords = self.apply_coords(point_coords, origin_image_size, self.img_size[0]).astype(np.float32)
            point_coords = np.expand_dims(point_coords, axis=0)  # add batch axis
            point_labels = np.expand_dims(point_labels, axis=0)

        if len(boxes) > 0:
            # A box is fed to the decoder as a pair of corner points with the
            # special labels 2 (top-left) and 3 (bottom-right).
            boxes = np.array(boxes, dtype=np.float32)
            boxes = self.apply_boxes(boxes, origin_image_size, self.img_size[0]).reshape((1, -1, 2)).astype(np.float32)
            box_label = np.array([[2, 3] for _ in range(boxes.shape[1] // 2)], dtype=np.float32).reshape((1, -1))
            if len(point_coords) > 0:
                point_coords = np.concatenate([point_coords, boxes], axis=1)
                point_labels = np.concatenate([point_labels, box_label], axis=1)
            else:
                point_coords = boxes
                point_labels = box_label

        input_dict = {"image_embeddings": img_embeddings,
                      "point_coords": point_coords,
                      "point_labels": point_labels,
                      "mask_input": mask_input,
                      "has_mask_input": has_mask_input,
                      "orig_im_size": np.array(origin_image_size, dtype=np.float32)}
        res = self.session.run(None, input_dict)

        result_dict = dict()
        for i in range(len(res)):
            out_name = self.session.get_outputs()[i].name
            if out_name == "masks":
                # Binarize the mask logits at the threshold.
                mask = (res[i] > self.mask_threshold).astype(np.int32)
                result_dict[out_name] = mask
            else:
                result_dict[out_name] = res[i]
        # First mask of the first batch item, with a trailing channel axis
        # so it broadcasts against an HWC image.
        return result_dict["masks"][0][0][:, :, None]


class SamImage:
    def __init__(self, vit_model_path, decoder_model_path):
        self.encoder = SAMImageEncoder(vit_model_path)
        self.decoder = SAMImageDecoder(decoder_model_path)

    def set_image(self, img):
        # Encode once; the embeddings are reused for every subsequent prompt.
        self.origin_image_size = img.shape
        self.features = self.encoder.run(img)

    def set_points(self, point_coords):
        point_labels = [1] * len(point_coords)  # 1 = foreground point
        return self.decoder.run(self.features, self.origin_image_size[:2], point_coords, point_labels, [])

    def set_box(self, box):
        return self.decoder.run(self.features, self.origin_image_size[:2], [], [], box)
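
The decoder's run() accepts point and box prompts together in a single call: the box becomes two extra corner points with the special labels 2 and 3, concatenated onto the click prompts. SamImage only exposes points-only and box-only helpers, so here is a minimal combined-prompt sketch reusing the classes above (set_points_and_box is a hypothetical helper, not part of the original API):

def set_points_and_box(sam, point_coords, box):
    # Hypothetical helper: assumes sam.set_image() has already been called.
    point_labels = [1] * len(point_coords)  # 1 = foreground click
    return sam.decoder.run(sam.features, sam.origin_image_size[:2],
                           point_coords, point_labels, box)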

inference.py

import cv2
import numpy as np
from sam import SamImage


image = cv2.imread("dog.jpg")
sam = SamImage("./model/vit_b.onnx", "./model/sam_vit_b.onnx")
sam.set_image(image)  # run the encoder once; prompts below reuse its embeddings

points = [[200, 500], [600, 240]]  # (x, y) foreground clicks
box = [0, 219, 417, 528]           # (x1, y1, x2, y2)
mask = sam.set_points(points)      # binary mask, shape (534, 800, 1) for this image
# mask = sam.set_box(box)          # or prompt with the box instead

# Blend a white overlay onto the masked region, then draw the prompts on top.
result = np.uint8(np.ones_like(image) * 255 * mask)
result = cv2.addWeighted(image, 0.5, result, 0.5, 0)
for p in points:
    cv2.circle(result, (p[0], p[1]), 2, (255, 0, 0), -1)
if len(box) == 4:
    cv2.rectangle(result, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 1)
cv2.imwrite("result.jpg", result)

Segmentation result: [image]

Model files are available from this Baidu netdisk link:
https://pan.baidu.com/s/1tLM3uzUKoTaB2Fyxftr3GA?pwd=p6cw (extraction code: p6cw)
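
If you would rather export the ONNX files yourself, the official facebookresearch/segment-anything repo provides scripts/export_onnx_model.py for the prompt/mask decoder; the image encoder is not covered by that script, but it can be exported with a plain torch.onnx.export call. A sketch, assuming the segment_anything package and a ViT-B checkpoint are available (the file names here are assumptions, not necessarily those in the netdisk archive):

# Decoder: the official export script (run from the segment-anything repo root).
#   python scripts/export_onnx_model.py --checkpoint sam_vit_b_01ec64.pth \
#       --model-type vit_b --output sam_vit_b.onnx

# Image encoder: a minimal torch.onnx.export sketch.
import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
dummy = torch.randn(1, 3, 1024, 1024)  # SAM's fixed encoder input resolution
torch.onnx.export(sam.image_encoder, dummy, "vit_b.onnx",
                  input_names=["images"], output_names=["image_embeddings"],
                  opset_version=17)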
