sam.py
import cv2
import onnxruntime
import numpy as np
class SAMImageEncoder:
    def __init__(self, model_path):
        self.session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.input_name = self.session.get_inputs()[0].name
        self.input_shape = self.session.get_inputs()[0].shape  # e.g. [1, 3, 1024, 1024]

    def transform(self, img):
        # BGR (OpenCV default) -> RGB, then normalize with SAM's pixel mean/std
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = (img - np.array([123.675, 116.28, 103.53])) / np.array([58.395, 57.12, 57.375])
        # zero-pad the shorter side to make the image square, then resize
        # to the encoder's fixed input resolution
        h, w, c = img.shape
        size = max(h, w)
        img = np.pad(img, ((0, size - h), (0, size - w), (0, 0)), 'constant', constant_values=(0, 0))
        img = cv2.resize(img, tuple(self.input_shape[2:]))  # square input, so (W, H) order is moot
        # HWC -> NCHW float32 tensor
        img = np.expand_dims(img, axis=0)
        img = np.transpose(img, axes=[0, 3, 1, 2]).astype(np.float32)
        return img

    def run(self, img):
        tensor = self.transform(img)
        feature = self.session.run(None, {self.input_name: tensor})[0]
        return feature
class SAMImageDecoder:
    def __init__(self, model_path):
        self.session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.img_size = (1024, 1024)
        self.mask_threshold = 0.0

    def apply_coords(self, coords, original_size, target_length):
        # map pixel coordinates from the original image into the
        # longest-side-resized coordinate system the model was prompted in;
        # e.g. for an 800x534 image and target_length 1024: scale = 1.28,
        # so a point (200, 500) maps to roughly (256, 640)
        old_h, old_w = original_size
        scale = target_length * 1.0 / max(old_h, old_w)
        new_h, new_w = old_h * scale, old_w * scale
        new_w = int(new_w + 0.5)
        new_h = int(new_h + 0.5)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes, original_size, target_length):
        # a box (x1, y1, x2, y2) is just two corner points
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size, target_length)
        return boxes.reshape(-1, 4)
    def run(self, img_embeddings, origin_image_size, point_coords, point_labels, boxes):
        # no previous low-res mask is fed back in
        mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        has_mask_input = np.zeros(1, dtype=np.float32)
        if len(point_coords) > 0:
            point_coords = np.array(point_coords, dtype=np.float32)
            point_labels = np.array(point_labels, dtype=np.float32)
            point_coords = self.apply_coords(point_coords, origin_image_size, self.img_size[0]).astype(np.float32)
            point_coords = np.expand_dims(point_coords, axis=0)
            point_labels = np.expand_dims(point_labels, axis=0)
        if len(boxes) > 0:
            boxes = np.array(boxes, dtype=np.float32)
            boxes = self.apply_boxes(boxes, origin_image_size, self.img_size[0]).reshape((1, -1, 2)).astype(np.float32)
            # each box contributes two prompt points labeled 2 (top-left) and 3 (bottom-right)
            box_label = np.array([[2, 3] for _ in range(boxes.shape[1] // 2)], dtype=np.float32).reshape((1, -1))
            if len(point_coords) > 0:
                point_coords = np.concatenate([point_coords, boxes], axis=1)
                point_labels = np.concatenate([point_labels, box_label], axis=1)
            else:
                point_coords = boxes
                point_labels = box_label
        input_dict = {"image_embeddings": img_embeddings,
                      "point_coords": point_coords,
                      "point_labels": point_labels,
                      "mask_input": mask_input,
                      "has_mask_input": has_mask_input,
                      "orig_im_size": np.array(origin_image_size, dtype=np.float32)}
        res = self.session.run(None, input_dict)
        result_dict = dict()
        for i in range(len(res)):
            out_name = self.session.get_outputs()[i].name
            if out_name == "masks":
                # binarize the mask logits with the threshold
                mask = (res[i] > self.mask_threshold).astype(np.int32)
                result_dict[out_name] = mask
            else:
                result_dict[out_name] = res[i]
        # first mask of the first prompt batch, shaped (H, W, 1)
        return result_dict["masks"][0][0][:, :, None]
class SamImage:
    def __init__(self, vit_model_path, decoder_model_path):
        self.encoder = SAMImageEncoder(vit_model_path)
        self.decoder = SAMImageDecoder(decoder_model_path)

    def set_image(self, img):
        # encode once; the embedding is reused for every subsequent prompt
        self.origin_image_size = img.shape
        self.features = self.encoder.run(img)

    def set_points(self, point_coords):
        # every point is treated as a foreground click (label 1)
        point_labels = [1] * len(point_coords)
        return self.decoder.run(self.features, self.origin_image_size[:2], point_coords, point_labels, [])

    def set_box(self, box):
        return self.decoder.run(self.features, self.origin_image_size[:2], [], [], box)
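Before wiring the two models together, it can help to confirm that the exported decoder actually exposes the input and output names hard-coded above. A minimal sanity-check sketch, assuming the same model path used in inference.py below (the expected output names masks, iou_predictions, and low_res_masks come from the official SAM ONNX export):

import onnxruntime

sess = onnxruntime.InferenceSession("./model/sam_vit_b.onnx", providers=['CPUExecutionProvider'])
for inp in sess.get_inputs():
    print("input:", inp.name, inp.shape)   # expect image_embeddings, point_coords, point_labels, ...
for out in sess.get_outputs():
    print("output:", out.name, out.shape)  # expect masks, iou_predictions, low_res_masks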
inference.py
import cv2
import numpy as np
from sam import SamImage

image = cv2.imread("dog.jpg")
sam = SamImage("./model/vit_b.onnx", "./model/sam_vit_b.onnx")
sam.set_image(image)

points = [[200, 500], [600, 240]]
box = [0, 219, 417, 528]
mask = sam.set_points(points)  # mask shape: (534, 800, 1) for this image
# mask = sam.set_box(box)

# blend a white rendering of the mask onto the original image,
# then draw the prompt points and box for reference
result = np.uint8(np.ones_like(image) * 255 * mask)
result = cv2.addWeighted(image, 0.5, result, 0.5, 0)
for p in points:
    cv2.circle(result, (p[0], p[1]), 2, (255, 0, 0), -1)
if len(box) == 4:
    cv2.rectangle(result, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 1)
cv2.imwrite("result.jpg", result)
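SamImage only exposes point prompts or a box prompt separately, but SAMImageDecoder.run already accepts both at once. A minimal sketch that bypasses the wrapper and feeds one foreground point together with the box (reusing the sam, points, and box variables above; mask_pb is just an illustrative name):

# one foreground click plus the box corners, sent to the decoder in a single call
mask_pb = sam.decoder.run(sam.features, sam.origin_image_size[:2],
                          [points[0]], [1], box)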
Segmentation result:
The model files are available at the netdisk link:
https://pan.baidu.com/s/1tLM3uzUKoTaB2Fyxftr3GA?pwd=p6cw  extraction code: p6cw