from ultralytics import YOLO
import cv2
import numpy as np
from sklearn.cluster import DBSCAN
class LightStripProcessor:
    """Extract and draw the centre lines of light strips segmented by a YOLO model.

    Attributes:
        model (YOLO): Loaded YOLO segmentation model.
        color_map (list): Drawing colours in BGR order, indexed by class id
            (wrapped modulo its length if the model reports more classes).
        dbscan_eps (float): DBSCAN neighbourhood radius, in pixels, used to
            discard outlier pixels of a mask.
        dbscan_min_samples (int): DBSCAN minimum number of samples per core point.
        iou_threshold (float): A mask whose IoU with an already-kept mask
            exceeds this value is dropped as a duplicate.
        min_aspect_ratio (float): Minimum long/short side ratio of a mask's
            minimum-area rectangle for the mask to be treated as a strip.
    """

    def __init__(self, model_path='light_seg.pt', *, dbscan_eps=5,
                 dbscan_min_samples=3, iou_threshold=0.5, min_aspect_ratio=5):
        """Load the model and store the processing parameters.

        Args:
            model_path (str): Path to the YOLO segmentation weights file.
            dbscan_eps (float): See the class attribute of the same name.
            dbscan_min_samples (int): See the class attribute of the same name.
            iou_threshold (float): See the class attribute of the same name.
            min_aspect_ratio (float): See the class attribute of the same name.
        """
        self.model = YOLO(model_path)
        self.color_map = [
            (100, 100, 100), (0, 255, 255), (255, 0, 0),
            (255, 255, 0), (255, 0, 255), (0, 128, 0), (0, 0, 255),
        ]
        self.dbscan_eps = dbscan_eps
        self.dbscan_min_samples = dbscan_min_samples
        self.iou_threshold = iou_threshold
        self.min_aspect_ratio = min_aspect_ratio

    @staticmethod
    def _mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
        """Return the intersection-over-union of two masks.

        Non-zero elements count as foreground.

        Args:
            mask1: Mask array of shape (H, W).
            mask2: Mask array of the same shape.
        Returns:
            float: IoU in [0.0, 1.0]. Returns 0.0 when both masks are empty
            (zero union) instead of the NaN a plain division would give.
        """
        union = np.count_nonzero(np.logical_or(mask1, mask2))
        if union == 0:
            return 0.0
        intersection = np.count_nonzero(np.logical_and(mask1, mask2))
        return float(intersection) / union

    def _filter_outliers(self, coords: np.ndarray) -> np.ndarray:
        """Drop points that DBSCAN labels as noise.

        Args:
            coords: Point array of shape (N, 2).
        Returns:
            np.ndarray: The points that belong to some cluster. Inputs with
            fewer than 10 points are returned unchanged.
        """
        if len(coords) < 10:
            # Too few points for clustering to be meaningful.
            return coords
        clustering = DBSCAN(eps=self.dbscan_eps,
                            min_samples=self.dbscan_min_samples).fit(coords)
        return coords[clustering.labels_ != -1]

    def _refine_masks(self, results) -> list:
        """Select elongated, non-duplicate masks from a YOLO result.

        Args:
            results: A single YOLO prediction result object.
        Returns:
            list: Indices into `results.masks.data` of the masks to keep, in
            model output order. Empty when the result contains no masks.
        """
        if results.masks is None:
            # No detections: YOLO leaves `masks` unset.
            return []
        masks = results.masks.data.cpu().numpy()
        kept = []
        for i, mask in enumerate(masks):
            # Contour analysis
            contours, _ = cv2.findContours(mask.astype(np.uint8),
                                           cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                continue
            # A mask can break into several fragments. Judge its shape by the
            # largest one, not by whichever contour OpenCV happens to list first.
            contour = max(contours, key=cv2.contourArea)
            _, (w, h), _ = cv2.minAreaRect(contour)
            # Geometric filter: keep only clearly elongated (strip-like) shapes.
            short_side, long_side = sorted((w, h))
            if short_side < 1e-5 or long_side / short_side < self.min_aspect_ratio:
                continue
            # IoU de-duplication against masks already kept.
            if any(self._mask_iou(mask, masks[k]) > self.iou_threshold
                   for k in kept):
                continue
            kept.append(i)
        return kept

    def _process_light_strip(self, mask: np.ndarray) -> list:
        """Extract the centre line of a single light strip.

        One centre point is produced per image row. This assumes strips run
        predominantly top-to-bottom; a near-horizontal strip yields few rows.
        NOTE(review): confirm this orientation assumption holds for the data.

        Args:
            mask: Mask array of shape (H, W); non-zero pixels are foreground.
        Returns:
            list: Centre points [(x, y), ...] sorted by y.
        """
        ys, xs = np.nonzero(mask)
        if xs.size == 0:
            return []
        filtered = self._filter_outliers(np.column_stack((xs, ys)))

        # Group the remaining pixels' x coordinates by row (y).
        rows = {}
        for x, y in filtered:
            rows.setdefault(int(y), []).append(int(x))

        centers = []
        for y in sorted(rows):
            row_xs = rows[y]
            # Rows containing a single pixel are skipped.
            # NOTE(review): this drops strips that are only 1 px wide.
            if len(row_xs) < 2:
                continue
            # Tukey fences (1.5 * IQR) reject stray pixels within the row.
            q1, q3 = np.percentile(row_xs, [25, 75])
            fence = 1.5 * (q3 - q1)
            inliers = [x for x in row_xs if q1 - fence <= x <= q3 + fence]
            if inliers:
                centers.append((int(np.median(inliers)), y))
        return centers

    @staticmethod
    def _smooth_centers(centers) -> np.ndarray:
        """Gaussian-smooth the x coordinates of row-ordered centre points.

        Args:
            centers: Sequence of (x, y) points sorted by y.
        Returns:
            np.ndarray: (N, 1, 2) int32 point array ready for cv2.polylines.
        """
        pts = np.asarray(centers)
        # The x values form an N x 1 column. A (width=1, height=5) kernel
        # smooths along the point sequence only.
        xs = pts[:, 0].astype(np.float64).reshape(-1, 1)
        x_smooth = cv2.GaussianBlur(xs, (1, 5), 0).ravel()
        # cv2.polylines only accepts 32-bit integer points. Plain `int` can map
        # to int64 (e.g. on Linux or with NumPy 2), which OpenCV rejects.
        smoothed = np.column_stack([x_smooth.round(), pts[:, 1]]).astype(np.int32)
        return smoothed.reshape(-1, 1, 2)

    def process_image(self, image_path: str, output_path: str = 'output.jpg') -> None:
        """Run the full pipeline on one image and save the annotated result.

        Args:
            image_path: Input image path.
            output_path: Path the annotated image is written to.
        Raises:
            FileNotFoundError: If the input image cannot be read.
            OSError: If the output image cannot be written.
        """
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        results = self.model(image_path)[0]
        # NOTE(review): results.masks.data is produced at the model's inference
        # resolution. Its pixel coordinates line up with `img` only when that
        # matches the original image size (as in the 544x640 example). Confirm
        # this for other inputs.
        for idx in self._refine_masks(results):
            mask = results.masks.data[idx].cpu().numpy()
            centers = self._process_light_strip(mask)
            if len(centers) <= 4:
                # Too few points for a meaningful 5-tap smoothing.
                continue
            polyline = self._smooth_centers(centers)
            cls_id = int(results.boxes.cls[idx])
            # Wrap the class id so models with more classes than colours still draw.
            color = self.color_map[cls_id % len(self.color_map)]
            cv2.polylines(img, [polyline], False, color, 1)
        if not cv2.imwrite(output_path, img):
            raise OSError(f"Could not write image: {output_path}")
# Usage example
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded run; pass arguments to
    # process other files without editing the script.
    parser = argparse.ArgumentParser(
        description="Extract light-strip centre lines with a YOLO segmentation model.")
    parser.add_argument("image", nargs="?",
                        default='C:/project/light_segment/cl_color_image.bmp',
                        help="input image path")
    parser.add_argument("output", nargs="?", default='final_output3.jpg',
                        help="annotated output image path")
    parser.add_argument("--model", default='light_seg.pt',
                        help="YOLO segmentation weights path")
    args = parser.parse_args()

    processor = LightStripProcessor(args.model)
    processor.process_image(args.image, args.output)
# Console output captured from an earlier run of this script. That run failed
# because `_mask_iou` was declared without `self` (and without @staticmethod)
# but was called as `self._mask_iou(...)`, so it received three positional
# arguments instead of two:
#
# image 1/1 C:\project\light_segment\cl_color_image.bmp: 544x640 6 whites, 11 blues, 14 yellows, 7 reds, 4 greens, 11 purples, 8 cyans, 52.2ms
# Speed: 3.3ms preprocess, 52.2ms inference, 90.3ms postprocess per image at shape (1, 3, 544, 640)
# Traceback (most recent call last):
#   File "C:\project\light_segment\main_mid.py", line 141, in <module>
#     processor.process_image('C:/project/light_segment/cl_color_image.bmp', 'final_output3.jpg')
#   File "C:\project\light_segment\main_mid.py", line 122, in process_image
#     for idx in self._refine_masks(results):
#   File "C:\project\light_segment\main_mid.py", line 78, in _refine_masks
#     if not any(self._mask_iou(masks[i], masks[vm]) > self.iou_threshold
#   File "C:\project\light_segment\main_mid.py", line 78, in <genexpr>
#     if not any(self._mask_iou(masks[i], masks[vm]) > self.iou_threshold
# TypeError: LightStripProcessor._mask_iou() takes 2 positional arguments but 3 were given