# interface.py (modified)
import cv2
import numpy as np
from mmdeploy_runtime import Detector, Classifier
import utils
# Sliding-window inference parameters, passed to utils.WindowGenerator by
# acne_instance_seg. Index 0 is passed alongside the image height and index 1
# alongside the width. A stride smaller than the window presumably makes
# neighbouring windows overlap, which is why duplicate detections are merged
# with NMS afterwards.
INFER_WINDOW_SIZE = [1024, 1024]
INFER_WINDOW_STRIDES = (960, 960)
# Overlap threshold handed to utils.nms when merging per-window detections
# (presumably an IoU threshold -- confirm against utils.nms).
OVERLAP_NMS_THRESHOLD = 0.5
# Detections scoring below this value are not drawn by the __main__ demo.
BBOX_DISPLAY_CONFIDENCE = 0.5

# Drawing colour per category name. The tuples are passed straight to OpenCV
# drawing calls on images loaded with cv2.imread, so they are in BGR order.
category_color = {
    'papule': (140, 240, 255),
    'nevus': (162, 247, 223),
    'nodule': (152, 192, 250),
    'open_comedo': (86, 143, 128),
    'closed_comedo': (60, 57, 166),
    'atrophic_scar': (103, 49, 132),
    'hypertrophic_scar': (249, 233, 160),
    'melasma': (164, 93, 222),
    'pustule': (235, 177, 161),
    'other': (233, 222, 252),
}

# Detector label id -> category name (keys of category_color).
category_index = {
    0: 'papule',
    1: 'nevus',
    2: 'nodule',
    3: 'open_comedo',
    4: 'closed_comedo',
    5: 'atrophic_scar',
    6: 'hypertrophic_scar',
    7: 'melasma',
    8: 'pustule',
    9: 'other',
}
# Backward-compatible alias for the original (misspelled) public name.
caregory_index = category_index
def create_acne_inference(cfg):
    """Build the lesion detector and the severity classifier from ``cfg``.

    ``cfg['detector']`` and ``cfg['classifier']`` must each provide the keys
    ``model_path``, ``device_name`` and ``device_id``, which are forwarded as
    keyword arguments to the corresponding mmdeploy_runtime model class.

    Returns:
        ``(detector, classifier)``, with the detector constructed first.
    """
    def _build(model_cls, section):
        # Forward the three per-model settings from one config section.
        opts = cfg[section]
        return model_cls(model_path=opts['model_path'],
                         device_name=opts['device_name'],
                         device_id=opts['device_id'])

    # Tuple elements evaluate left to right: detector, then classifier.
    return _build(Detector, 'detector'), _build(Classifier, 'classifier')
def _detect_in_windows(img: np.ndarray, detector: Detector):
    """Run ``detector`` on every sliding window of ``img`` and collect results.

    Box corner coordinates are shifted from window-local to full-image
    coordinates; masks are collected exactly as the detector returns them
    (no offset is applied to them).

    Returns:
        ``(bbox_chunks, label_chunks, mask_list)``: one box array and one label
        array per window, plus a flat list holding one mask per detection.
    """
    win_gen = utils.WindowGenerator(img.shape[0], img.shape[1],
                                    INFER_WINDOW_SIZE[0], INFER_WINDOW_SIZE[1],
                                    INFER_WINDOW_STRIDES[0], INFER_WINDOW_STRIDES[1])
    bbox_chunks, label_chunks, mask_list = [], [], []
    for h_slice, w_slice in win_gen:
        patch = img[h_slice, w_slice, :]
        b, lab, m = detector(patch)
        # Columns 0/2 hold x coordinates, columns 1/3 hold y coordinates.
        b[:, [0, 2]] += w_slice.start
        b[:, [1, 3]] += h_slice.start
        bbox_chunks.append(b)
        label_chunks.append(lab)
        mask_list.extend(m)
    return bbox_chunks, label_chunks, mask_list


def _pack_object_array(items):
    """Return a 1-D object ndarray whose elements are exactly ``items``.

    ``np.array(list_of_arrays, dtype=object)`` is unreliable for masks: when
    all masks share a shape numpy builds an N-D object array instead of a 1-D
    array of masks, and when shapes agree only in the first dimension it
    raises ``ValueError`` (could not broadcast). Pre-allocating and assigning
    element by element always yields shape ``(len(items),)``.
    """
    packed = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        packed[i] = item
    return packed


def _classwise_nms(bboxes, labels, masks):
    """Run utils.nms independently within each label and keep the survivors.

    When there are no detections, returns empty slices of the inputs instead
    of letting ``np.concatenate([])`` raise ``ValueError``.
    """
    kept_b, kept_l, kept_m = [], [], []
    for c in np.unique(labels):
        idx = labels == c
        b = bboxes[idx]
        keep = utils.nms(b, OVERLAP_NMS_THRESHOLD)
        kept_b.append(b[keep])
        kept_l.append(labels[idx][keep])
        kept_m.append(masks[idx][keep])
    if not kept_b:
        return bboxes[:0], labels[:0], masks[:0]
    return (np.concatenate(kept_b, axis=0),
            np.concatenate(kept_l, axis=0),
            np.concatenate(kept_m, axis=0))


def acne_instance_seg(img: np.ndarray, detector: Detector):
    """Detect and segment acne lesions in ``img`` with sliding-window inference.

    The image is split into windows (INFER_WINDOW_SIZE / INFER_WINDOW_STRIDES)
    and ``detector`` is run on each one. Boxes are moved into full-image
    coordinates, and duplicates from overlapping windows are removed with
    per-label NMS (OVERLAP_NMS_THRESHOLD).

    Args:
        img: Image indexed as ``img[rows, cols, channels]``.
        detector: Callable returning ``(bboxes, labels, masks)`` for a patch.

    Returns:
        ``(bboxes, labels, masks)``. ``masks`` is always a 1-D object array
        with one mask per detection. All three are empty (instead of raising)
        when nothing is detected.

    NOTE(review): masks are not shifted by the window offset. Masks cropped
    to their box can still be placed via the (shifted) box, but masks sized
    to the window cannot -- confirm which layout the deployed model returns.
    """
    bbox_chunks, label_chunks, mask_list = _detect_in_windows(img, detector)
    if not bbox_chunks:
        # WindowGenerator yielded no windows; nothing to merge. The (0, 5) box
        # shape assumes an (x1, y1, x2, y2, score) layout and the dtypes are
        # assumptions -- TODO confirm against the detector output.
        return (np.empty((0, 5), dtype=np.float32),
                np.empty((0,), dtype=np.int32),
                np.empty((0,), dtype=object))
    p_bboxes = np.concatenate(bbox_chunks, axis=0)
    p_labels = np.concatenate(label_chunks, axis=0)
    p_masks = _pack_object_array(mask_list)
    return _classwise_nms(p_bboxes, p_labels, p_masks)
def acne_severity_grading(img: np.ndarray, classifier: Classifier):
    """Grade overall acne severity for a whole image.

    Args:
        img: Image passed unchanged to ``classifier``.
        classifier: Callable classification model (e.g. the second value
            returned by create_acne_inference).

    Returns:
        Whatever ``classifier(img)`` returns. The ``__main__`` demo reads
        ``result[0][0]`` as a class index and ``result[0][1]`` as a score.
    """
    return classifier(img)
if __name__ == '__main__':
from infer_config import acne_infer_config
import math
image = cv2.imread('infer/test.jpg')
detector, classifier = create_acne_inference(acne_infer_config)
bboxes, labels, masks = acne_instance_seg(image, detector)
cls_result = acne_severity_grading(image, classifier)
print("bbox:",bboxes)
print("labels:",labels)
# print(masks)
print("cls_result:",cls_result)
count = 0
indices = [i for i in range(len(bboxes))]
for index, bbox, label_id in zip(indices, bboxes, labels):
count = count + 1
# print(count,index,bbox,label_id)
[left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
if score < BBOX_DISPLAY_CONFIDENCE:
continue
color = category_color[caregory_index[label_id]]
cv2.rectangle(image, (left, top), (right, bottom), color)
label_text = f"{caregory_index[label_id]} | {int(100 * score)}%" # 使用痤疮的标签作为标注信息
cv2.putText(image, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) # 添加标注信息
if masks[index].size:
mask = masks[index]
# blue, green, red = cv2.split(image)
# if mask.shape == image.shape[:2]: # rtmdet-inst
# mask_img = blue
# else: # maskrcnn
# x0 = int(max(math.floor(bbox[0]) - 1, 0))
# y0 = int(max(math.floor(bbox[1]) - 1, 0))
# mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
# cv2.bitwise_or(mask, mask_img, mask_img)
# image = cv2.merge([blue, green, red])
if mask.shape == image.shape[:2]: # rtmdet-inst
mask_img = image.copy()
else: # maskrcnn
x0 = int(max(math.floor(bbox[0]) - 1, 0))
y0 = int(max(math.floor(bbox[1]) - 1, 0))
mask_img = image[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
mask_img[mask.astype(bool)] = color
text = f'Prediction: {cls_result[0][0]}, {cls_result[0][1]:.4f} ' \
f'({acne_infer_config["classifier"]["classes"][cls_result[0][0]]})'
cv2.putText(image, text, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
cv2.imwrite('output_detection.png', image)