"""Simple Online and Realtime Tracking (SORT) 是一种简洁、高效且实用的多目标跟踪算法。它以目标检测器(如 YOLO)的检测结果为基础,结合卡尔曼滤波与匈牙利算法,大幅提升了多目标跟踪的速度。然而,仅依靠 IoU 进行匹配虽能实现快速处理,却容易导致较严重的 ID switch(目标身份切换)问题。SORT 的核心在于其算法架构:卡尔曼滤波负责预测目标位置,匈牙利算法负责求解检测结果与跟踪目标之间的最优匹配,二者协同工作,保证了算法的高效运行。"""
import os
import numpy as np
import cv2
from filterpy.kalman import KalmanFilter
from ultralytics import YOLO
# ---------------------- SORT 算法及辅助函数 ---------------------- #
def linear_assignment(cost_matrix):
    """Solve the (possibly rectangular) linear assignment problem, minimising cost.

    Uses ``lap.lapjv`` when the optional ``lap`` package is installed and falls
    back to ``scipy.optimize.linear_sum_assignment`` otherwise.

    Args:
        cost_matrix: 2-D array of shape (N, M); entry [i, j] is the cost of
            assigning row i to column j. N and M may differ.

    Returns:
        Integer array of shape (K, 2) whose rows are ``[row, col]`` pairs of the
        optimal assignment, ordered by row. An empty matrix yields shape (0, 2),
        so callers can always index ``result[:, 0]``.
    """
    cost_matrix = np.asarray(cost_matrix, dtype=float)
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int)

    # Keep the try body to the import only, so errors raised by the solver
    # itself are never mistaken for "lap is not installed".
    try:
        import lap
    except ImportError:
        lap = None

    if lap is not None:
        # extend_cost=True lets lapjv accept non-square matrices; x[row] is the
        # column assigned to that row, or -1 when the row stays unassigned.
        _, x, _ = lap.lapjv(cost_matrix, extend_cost=True)
        pairs = [[row, col] for row, col in enumerate(x) if col >= 0]
    else:
        from scipy.optimize import linear_sum_assignment
        rows, cols = linear_sum_assignment(cost_matrix)
        pairs = list(zip(rows, cols))
    return np.array(pairs, dtype=int).reshape(-1, 2)
def iou_batch(bb_test, bb_gt):
    """Compute the pairwise IoU between two sets of corner-format boxes.

    Args:
        bb_test: array-like of shape (N, >=4); columns 0-3 are [x1, y1, x2, y2],
            any extra columns (e.g. a score) are ignored.
        bb_gt: array-like of shape (M, >=4) in the same format.

    Returns:
        Float array of shape (N, M) where entry [i, j] is IoU(bb_test[i], bb_gt[j]).
        Pairs whose union area is zero (e.g. two degenerate boxes) get IoU 0
        instead of NaN, so the result is always safe to feed to an assignment solver.
    """
    # Broadcast to (N, 1, k) vs (1, M, k) so every test box meets every gt box.
    bb_test = np.expand_dims(np.asarray(bb_test, dtype=float), 1)
    bb_gt = np.expand_dims(np.asarray(bb_gt, dtype=float), 0)

    xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
    yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
    xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
    yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
    inter = np.maximum(0., xx2 - xx1) * np.maximum(0., yy2 - yy1)

    area_test = (bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
    area_gt = (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1])
    union = area_test + area_gt - inter

    # Guarded division: a zero union would otherwise produce 0/0 = NaN.
    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
def convert_bbox_to_z(bbox):
    """Turn a corner-format box into the Kalman measurement column vector.

    Args:
        bbox: indexable whose first four entries are [x1, y1, x2, y2]; any
            further entries (such as a detection score) are ignored.

    Returns:
        Array of shape (4, 1): [center_x, center_y, area, aspect_ratio], where
        area is width * height and aspect_ratio is width / height.
    """
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    width, height = right - left, bottom - top
    center_x = left + width / 2.
    center_y = top + height / 2.
    measurement = [center_x, center_y, width * height, width / float(height)]
    return np.array(measurement).reshape((4, 1))
def convert_x_to_bbox(x, score=None):
    """Convert a Kalman state/measurement vector back to a corner-format box.

    Args:
        x: array whose first four entries are [center_x, center_y, area,
            aspect_ratio]. A (7, 1) column state, a (4, 1) measurement or a
            flat vector are all accepted.
        score: optional detection score appended as a fifth column.

    Returns:
        (1, 4) array [[x1, y1, x2, y2]], or (1, 5) array with ``score`` appended.
    """
    # Flatten first: with a (7, 1) column, x[i] is a length-1 array, and mixing
    # those with a scalar score yields a ragged list NumPy >= 1.24 rejects.
    center_x, center_y, area, ratio = np.asarray(x, dtype=float).reshape(-1)[:4]
    w = np.sqrt(area * ratio)
    h = area / w
    box = [center_x - w / 2., center_y - h / 2., center_x + w / 2., center_y + h / 2.]
    if score is None:
        return np.array(box).reshape((1, 4))
    return np.array(box + [score]).reshape((1, 5))
class KalmanBoxTracker(object):
    """Track a single object's bounding box with a constant-velocity Kalman filter.

    State vector (7 entries): [center_x, center_y, area, aspect_ratio,
    v_center_x, v_center_y, v_area]. The aspect ratio has no velocity term, so
    it is modelled as constant. Measurements are the first four entries.
    """

    # Class-wide counter that gives every new tracker a unique, increasing id.
    count = 0

    def __init__(self, bbox):
        """Create a tracker initialised from ``bbox`` ([x1, y1, x2, y2, ...])."""
        self.kf = kf = KalmanFilter(dim_x=7, dim_z=4)

        # Transition model: center_x, center_y and area each advance by their
        # own velocity every step; everything else carries over unchanged.
        transition = np.eye(7, dtype=int)
        for position_idx, velocity_idx in ((0, 4), (1, 5), (2, 6)):
            transition[position_idx, velocity_idx] = 1
        kf.F = transition
        # Measurement model: only the first four state entries are observed.
        kf.H = np.eye(4, 7, dtype=int)

        # Noise / uncertainty tuning: area and aspect-ratio measurements are
        # trusted less, and the unobserved velocities start very uncertain.
        kf.R[2:, 2:] *= 10.
        kf.P[4:, 4:] *= 1000.
        kf.P *= 10.
        kf.Q[-1, -1] *= 0.01
        kf.Q[4:, 4:] *= 0.01
        kf.x[:4] = convert_bbox_to_z(bbox)

        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []
        self.hits = self.hit_streak = self.age = 0

    def update(self, bbox):
        """Correct the filter with the detection ``bbox`` associated this frame."""
        self.hits += 1
        self.hit_streak += 1
        self.time_since_update = 0
        self.history = []
        self.kf.update(convert_bbox_to_z(bbox))

    def predict(self):
        """Advance the filter one frame and return the predicted (1, 4) box."""
        state = self.kf.x
        # If the next step would push the area to zero or below, freeze area
        # growth by zeroing the area velocity (index 6).
        if state[6] + state[2] <= 0:
            state[6] *= 0.0
        self.kf.predict()
        self.age += 1
        # A tracker that already missed the previous frame loses its hit streak.
        if self.time_since_update:
            self.hit_streak = 0
        self.time_since_update += 1
        predicted_box = convert_x_to_bbox(self.kf.x)
        self.history.append(predicted_box)
        return predicted_box

    def get_state(self):
        """Return the current box estimate as a (1, 4) array [x1, y1, x2, y2]."""
        return convert_x_to_bbox(self.kf.x)
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """Match detections to predicted tracker boxes by IoU.

    Args:
        detections: (N, >=4) array of detection boxes [x1, y1, x2, y2, ...].
        trackers: (M, >=4) array of predicted tracker boxes in the same format.
        iou_threshold: minimum IoU for a detection/tracker pair to be matched.

    Returns:
        matches: int array (K, 2) of [detection_index, tracker_index] pairs.
        unmatched_detections: 1-D int array of detection indices with no match.
        unmatched_trackers: 1-D int array of tracker indices with no match.
    """
    if len(trackers) == 0:
        return (np.empty((0, 2), dtype=int),
                np.arange(len(detections)),
                np.empty((0,), dtype=int))

    iou_matrix = iou_batch(detections, trackers)

    if min(iou_matrix.shape) > 0:
        above = (iou_matrix > iou_threshold).astype(np.int32)
        if above.sum(1).max() == 1 and above.sum(0).max() == 1:
            # Every detection and every tracker has at most one candidate above
            # the threshold, so the assignment is unambiguous: skip the solver.
            matched_indices = np.stack(np.where(above), axis=1)
        else:
            # Maximise total IoU by minimising its negation.
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty((0, 2), dtype=int)
    # Normalise to a (K, 2) int array whichever path produced it.
    matched_indices = np.asarray(matched_indices, dtype=int).reshape(-1, 2)

    matched_dets = set(matched_indices[:, 0].tolist())
    matched_trks = set(matched_indices[:, 1].tolist())
    unmatched_detections = [d for d in range(len(detections)) if d not in matched_dets]
    unmatched_trackers = [t for t in range(len(trackers)) if t not in matched_trks]

    matches = []
    for d, t in matched_indices:
        # The solver may pair boxes that barely overlap; reject those pairs.
        if iou_matrix[d, t] < iou_threshold:
            unmatched_detections.append(d)
            unmatched_trackers.append(t)
        else:
            matches.append((d, t))

    return (np.array(matches, dtype=int).reshape(-1, 2),
            np.array(unmatched_detections, dtype=int),
            np.array(unmatched_trackers, dtype=int))
class Sort(object):
    """SORT multi-object tracker: Kalman prediction plus IoU-based assignment.

    Args:
        max_age: a tracker is removed once it has gone more than ``max_age``
            consecutive frames without a matched detection.
        min_hits: consecutive matched frames required before a track is
            reported (waived during the first ``min_hits`` frames).
        iou_threshold: minimum IoU for detection/tracker association.
    """

    def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.frame_count = 0

    def update(self, dets=None):
        """Advance all tracks by one frame using this frame's detections.

        Must be called once per frame, even when there are no detections.

        Args:
            dets: (N, 5) array of [x1, y1, x2, y2, score]; ``None`` or an empty
                array means "no detections this frame". (Defaults to ``None``
                rather than a shared ndarray instance.)

        Returns:
            (K, 5) array of [x1, y1, x2, y2, track_id] for confirmed tracks
            matched this frame; track ids are 1-based. K may differ from N.
        """
        if dets is None:
            dets = np.empty((0, 5))
        dets = np.atleast_2d(np.asarray(dets, dtype=float))
        if dets.size == 0:
            dets = np.empty((0, 5))
        self.frame_count += 1

        # Predict every tracker forward and drop those whose prediction became
        # NaN. The predicted boxes are collected only for surviving trackers so
        # that row t of `trks` always corresponds to self.trackers[t].
        predicted = []
        survivors = []
        for trk in self.trackers:
            pos = trk.predict()[0]
            if np.any(np.isnan(pos)):
                continue
            survivors.append(trk)
            predicted.append([pos[0], pos[1], pos[2], pos[3], 0])
        self.trackers[:] = survivors
        trks = np.array(predicted, dtype=float).reshape(-1, 5)

        matched, unmatched_dets, _ = associate_detections_to_trackers(
            dets, trks, self.iou_threshold)

        for det_idx, trk_idx in matched:
            self.trackers[trk_idx].update(dets[det_idx, :])
        for det_idx in unmatched_dets:
            self.trackers.append(KalmanBoxTracker(dets[det_idx, :]))

        # Walk backwards so trackers can be popped in place while iterating.
        ret = []
        for idx in range(len(self.trackers) - 1, -1, -1):
            trk = self.trackers[idx]
            box = trk.get_state()[0]
            confirmed = trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits
            if trk.time_since_update < 1 and confirmed:
                ret.append(np.concatenate((box, [trk.id + 1])).reshape(1, -1))
            if trk.time_since_update > self.max_age:
                self.trackers.pop(idx)

        if ret:
            return np.concatenate(ret)
        return np.empty((0, 5))
# ---------------------- YOLOv8 检测模块 ---------------------- #
def _to_numpy(values):
    """Return ``values`` as a NumPy array, moving tensor-like inputs to the CPU first."""
    if hasattr(values, "cpu"):
        values = values.cpu()
    if hasattr(values, "numpy"):
        return values.numpy()
    return np.asarray(values)


def yolo_detect(frame, model, conf_threshold=0.5, classes=None):
    """Run a YOLOv8 model on ``frame`` and return detections in SORT format.

    Args:
        frame: image passed straight to ``model``.
        model: callable (e.g. an ``ultralytics.YOLO`` instance) returning a list
            of results whose first element exposes ``boxes.xyxy``, ``boxes.conf``
            and, when ``classes`` is given, ``boxes.cls``.
        conf_threshold: detections with confidence below this value are dropped.
        classes: optional iterable of class ids to keep (e.g. ``[0]`` for COCO
            "person"); ``None`` keeps every class.

    Returns:
        Float array of shape (N, 5): [[x1, y1, x2, y2, score], ...]; shape
        (0, 5) when nothing passes the filters.
    """
    results = model(frame)
    if not results or results[0].boxes is None:
        return np.empty((0, 5))

    boxes = results[0].boxes
    xyxy = _to_numpy(boxes.xyxy).reshape(-1, 4).astype(float)
    conf = _to_numpy(boxes.conf).reshape(-1).astype(float)

    keep = conf >= conf_threshold
    if classes is not None:
        cls = _to_numpy(boxes.cls).reshape(-1)
        keep &= np.isin(cls, list(classes))

    return np.hstack((xyxy[keep], conf[keep, None]))  # x1 y1 x2 y2 score
# ---------------------- 主函数 ---------------------- #
def main(source=0, weights=r"yolov8n.pt"):
    """Run YOLOv8 detection + SORT tracking on a video stream and display it.

    Args:
        source: value passed to ``cv2.VideoCapture``; the default 0 opens the
            first camera, a file path opens a video file.
        weights: path to the YOLO weights file.

    Press Esc in the display window to quit.
    """
    model = YOLO(weights)

    max_age, min_hits, iou_threshold = 3, 3, 0.3
    mot_tracker = Sort(max_age=max_age, min_hits=min_hits, iou_threshold=iou_threshold)
    # One random colour per track-id bucket (ids are mapped modulo the table size).
    colours = 255 * np.random.rand(32, 3)

    cap = cv2.VideoCapture(source)
    # try/finally so the capture and any windows are released even if
    # detection/tracking raises or the user interrupts with Ctrl+C.
    try:
        if not cap.isOpened():
            print("无法打开视频流")
            return
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            dets = yolo_detect(frame, model)
            tracked_objects = mot_tracker.update(dets)

            for d in tracked_objects:
                # d: [x1, y1, x2, y2, track_id]
                x1, y1, x2, y2, track_id = d.astype(int)
                color = colours[track_id % len(colours)].tolist()
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                # Keep the id label inside the image when the box hugs the top edge.
                cv2.putText(frame, str(int(track_id)), (x1, max(y1 - 10, 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            cv2.imshow("YOLOv8 + SORT Tracking", frame)
            # Mask to the low byte: some backends set extra high bits on the key code.
            if (cv2.waitKey(1) & 0xFF) == 27:  # Esc
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Script entry point: start the tracking demo only when run directly.
    main()