多目标跟踪——SORT算法

 Simple Online and Realtime Tracking (SORT) 是一种简洁、高效且实用的多目标跟踪算法。它以目标检测方法(如 YOLO)为基础,结合卡尔曼滤波与匈牙利算法,极大地提升了多目标跟踪的速度。然而,仅依靠 IOU 进行匹配虽能实现快速处理,却会导致严重的 ID switch 问题。SORT 的核心优势在于其算法架构,卡尔曼滤波负责预测目标位置,匈牙利算法则优化检测结果与跟踪目标的匹配,二者协同工作,确保了算法的高效运行。

import os
import numpy as np
import cv2
from filterpy.kalman import KalmanFilter
from ultralytics import YOLO

# ---------------------- SORT 算法及辅助函数 ---------------------- #
 
def linear_assignment(cost_matrix):
    try:
        import lap
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
        return np.array([[y[i], i] for i in x if i >= 0])
    except ImportError:
        from scipy.optimize import linear_sum_assignment
        x, y = linear_sum_assignment(cost_matrix)
        return np.array(list(zip(x, y)))
 
def iou_batch(bb_test, bb_gt):
    bb_gt = np.expand_dims(bb_gt, 0)
    bb_test = np.expand_dims(bb_test, 1)
    xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
    yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
    xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
    yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
              + (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
    return o
 
def convert_bbox_to_z(bbox):
    # bbox: [x1, y1, x2, y2]
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    x = bbox[0] + w / 2.
    y = bbox[1] + h / 2.
    s = w * h
    r = w / float(h)
    return np.array([x, y, s, r]).reshape((4, 1))
 
def convert_x_to_bbox(x, score=None):
    # x: 状态向量 (7,1)
    w = np.sqrt(x[2] * x[3])
    h = x[2] / w
    if score is None:
        return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4))
    else:
        return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score]).reshape((1, 5))
 
class KalmanBoxTracker(object):
    count = 0
    def __init__(self, bbox):
        self.kf = KalmanFilter(dim_x=7, dim_z=4)
        # 状态:[center_x, center_y, s, r, vx, vy, vs]
        self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0],
                              [0, 1, 0, 0, 0, 1, 0],
                              [0, 0, 1, 0, 0, 0, 1],
                              [0, 0, 0, 1, 0, 0, 0],
                              [0, 0, 0, 0, 1, 0, 0],
                              [0, 0, 0, 0, 0, 1, 0],
                              [0, 0, 0, 0, 0, 0, 1]])
        self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0],
                              [0, 1, 0, 0, 0, 0, 0],
                              [0, 0, 1, 0, 0, 0, 0],
                              [0, 0, 0, 1, 0, 0, 0]])
        self.kf.R[2:, 2:] *= 10.
        self.kf.P[4:, 4:] *= 1000.
        self.kf.P *= 10.
        self.kf.Q[-1, -1] *= 0.01
        self.kf.Q[4:, 4:] *= 0.01
        self.kf.x[:4] = convert_bbox_to_z(bbox)
        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []
        self.hits = 0
        self.hit_streak = 0
        self.age = 0
 
    def update(self, bbox):
        self.time_since_update = 0
        self.history = []
        self.hits += 1
        self.hit_streak += 1
        self.kf.update(convert_bbox_to_z(bbox))
 
    def predict(self):
        # 若预测后的面积小于0则置零
        if ((self.kf.x[6] + self.kf.x[2]) <= 0):
            self.kf.x[6] *= 0.0
        self.kf.predict()
        self.age += 1
        if (self.time_since_update > 0):
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(convert_x_to_bbox(self.kf.x))
        return self.history[-1]
 
    def get_state(self):
        return convert_x_to_bbox(self.kf.x)
 
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)
    iou_matrix = iou_batch(detections, trackers)
    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))
    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)
    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
 
class Sort(object):
    def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.frame_count = 0
 
    def update(self, dets=np.empty((0, 5))):
        self.frame_count += 1
        trks = np.zeros((len(self.trackers), 5))
        to_del = []
        ret = []
        for t, trk in enumerate(self.trackers):
            pos = trk.predict()[0]
            trks[t] = [pos[0], pos[1], pos[2], pos[3], 0]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        if len(to_del) > 0:
            for t in reversed(to_del):
                self.trackers.pop(t)
        matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold)
        for m in matched:
            self.trackers[m[1]].update(dets[m[0], :])
        for i in unmatched_dets:
            trk = KalmanBoxTracker(dets[i, :])
            self.trackers.append(trk)
        i = len(self.trackers)
        for trk in reversed(self.trackers):
            d = trk.get_state()[0]
            if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))
            i -= 1
            if trk.time_since_update > self.max_age:
                self.trackers.pop(i)
        if len(ret) > 0:
            return np.concatenate(ret)
        return np.empty((0, 5))
 
# ---------------------- YOLOv8 检测模块 ---------------------- #
def yolo_detect(frame, model, conf_threshold=0.5):
    """
    使用 YOLOv8 检测图像中的目标(例如 person)。
    返回格式为:[[x1, y1, x2, y2, score], ...]
    """
    results = model(frame)
    # 检查检测结果
    if results and results[0].boxes is not None and len(results[0].boxes) > 0:
        # 获取检测框(xyxy 格式)和置信度
        boxes = results[0].boxes.xyxy.cpu().numpy() if hasattr(results[0].boxes.xyxy, 'cpu') else results[0].boxes.xyxy.numpy()
        confs = results[0].boxes.conf.cpu().numpy() if hasattr(results[0].boxes.conf, 'cpu') else results[0].boxes.conf.numpy()
        # 若需要筛选特定类别,可以参考下面的代码(此处默认所有检测均有效)
        # cls = results[0].boxes.cls.cpu().numpy() if hasattr(results[0].boxes.cls, 'cpu') else results[0].boxes.cls.numpy()
        # mask = (cls == 0)  # 假设只保留类别为 person(COCO 中 class_id == 0)
        # boxes = boxes[mask]
        # confs = confs[mask]
        dets = np.hstack((boxes, confs.reshape(-1, 1)))
    else:
        dets = np.empty((0, 5))
    return dets # x1 y1 x2 y2 score
 
# ---------------------- 主函数 ---------------------- #
def main():
    # 加载 YOLOv8 模型
    yolo_weights = r"yolov8n.pt"  # YOLO 权重文件路径
    model = YOLO(yolo_weights)
 
    # 初始化 SORT 跟踪器
    max_age, min_hits, iou_threshold = 3, 3, 0.3
    mot_tracker = Sort(max_age=max_age, min_hits=min_hits, iou_threshold=iou_threshold)
    colours = 255 * np.random.rand(32, 3)

    # 打开视频流,这里使用摄像头;若使用视频文件,将参数改为文件路径
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("无法打开视频流")
        return
 
    while True:
        ret, frame = cap.read()
        if not ret:
            break
 
        # 使用 YOLOv8 检测
        dets = yolo_detect(frame, model)
 
        # 更新 SORT 跟踪器
        tracked_objects = mot_tracker.update(dets)
        
        # 绘制跟踪结果
        for d in tracked_objects:
            # d: [x1, y1, x2, y2, track_id]
            x1, y1, x2, y2, track_id = d.astype(int)
            color = colours[int(track_id) % 32].tolist()
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, str(int(track_id)), (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
 
        cv2.imshow("YOLOv8 + SORT Tracking", frame)
        key = cv2.waitKey(1)
        if key == 27:  # 按 Esc 键退出
            break
 
    cap.release()
    cv2.destroyAllWindows()
 
if __name__ == '__main__':
    main()

### DeepSORT多目标跟踪算法概述 DeepSORT是一种先进的计算机视觉目标跟踪算法,旨在为每个对象分配唯一ID并保持其身份一致性。作为SORT算法的增强版,该算法不仅继承了原版的优点——即简单高效的数据关联策略和实时处理能力,还通过集成深度学习组件来改善长期跟踪表现。 #### 原理 DeepSORT利用深度神经网络提取目标外观特征向量,并将其融入到传统的基于检测框位置的状态估计框架之中。具体而言,在每一帧图像中获得的对象边界框会先经过一个预训练好的卷积神经网络(CNN),从而得到表征个体特性的嵌入(embedding)[^1]。这些高维空间里的表示随后会被用来计算不同时间戳下同一实体间的相似度得分矩阵;与此同时,Kalman滤波器负责维护各个轨迹的历史位移趋势并向未来时刻做出预测。最终,借助于匈牙利算法完成当前观测与已有轨迹之间的最优配对决策过程[^2]。 #### 实现 对于实际部署来说,DeepSORT的设计允许使用者灵活定制化不同的组成部分: - **目标检测模型**:可以根据特定任务需求选用合适的架构(如YOLOv3、Faster R-CNN等),只要能提供可靠的候选区域即可满足输入要求; - **重识别(Re-ID)子网**:通常采用Market1501数据集上预先训练过的ResNet变体或其他适合的人脸/车辆再认专用结构; - **参数调整**:诸如最大连续丢失次数(max_age)、最小可见比例(min_hits)之类的超参可根据应用场景特点适当调节优化性能指标[^4]。 此外,官方开源项目提供了Python接口封装良好的`Tracker`类实例,便于快速搭建原型系统或开展实验验证工作。 ```python from deep_sort import nn_matching from deep_sort.detection import Detection from deep_sort.tracker import Tracker metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance=0.2) tracker = Tracker(metric) for frame_idx, detections in enumerate(detections_sequence): # Convert raw detection results into the format expected by tracker.update() dets = [Detection(bbox, score, feature) for bbox, score, feature in detections] # Update tracks based on new observations (detections). tracker.predict() # Predict positions of existing tracked objects. matches, unmatched_detections, unmatched_tracks = tracker.match(dets) ``` #### 应用 得益于强大的泛化能力和出色的鲁棒性,DeepSORT广泛应用于智慧城市监控、自动驾驶辅助感知等多个领域内涉及大量移动物体交互分析的任务当中。特别是在人群密集场所的安全防范方面表现出色,能够有效应对遮挡干扰等问题带来的挑战,确保长时间稳定可靠地锁定感兴趣的目标个体而不发生漂移现象[^3]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

张飞飞飞飞飞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值