背景
原始deepsort主要用于行人、车辆追踪,因此不区分类别,只给出跟踪id。而在自己的使用场景中,有多种不同类别的物体需要跟踪,希望每一类物体的跟踪id都能各自独立地从头开始编号(下文的实现中,每一类的id从1开始)。
修改思路
沿用《拓展deepsort修改(1):id跳变优化》的思路,对检测到的每个类别分别重新编号。
1.建立一个类别对象的数据结构类
该类主要用于对deepsort返回结果进行封装。由于对使用场景有一些先验知识,这里还对追踪结果做了一次简单的二次更新:只保留近期的若干帧检测结果,并输出其中置信度最高的一帧。
class BankDetObject:
    """Confidence-smoothed detection state for one tracked object.

    Wraps a single tracker result and keeps a short history of recent
    detections.  The public ``xyxy`` / ``label`` / ``confidence`` attributes
    always mirror the highest-confidence entry currently in the history.
    """

    def __init__(self, idx_frame, track_id, xyxy, label, confidence,
                 min_confidence=0.5, max_history=5, max_age=15):
        """Create the object from its first detection.

        Args:
            idx_frame: Frame index of the first detection.
            track_id: Track ID reported for this object.
            xyxy: Bounding box as ``[x1, y1, x2, y2]``.
            label: Class label of the detection.
            confidence: Confidence of the detection.
            min_confidence: Upper bound for the acceptance threshold; the
                effective threshold is ``min(min_confidence, confidence)``.
            max_history: Maximum number of history entries kept.
            max_age: History entries older than this many frames are dropped
                whenever a new detection is accepted.
        """
        self.last_update_frame = idx_frame  # frame of the last accepted detection
        self.track_id = track_id
        self.xyxy = xyxy
        self.label = label
        self.confidence = confidence
        self.max_history = max_history
        self.max_age = max_age
        self.desc = [idx_frame, None]  # [frame index, description or None]
        # Never require more confidence than the detection that created us.
        self.min_confidence = min(min_confidence, confidence)
        # Each history entry is [idx_frame, xyxy, label, confidence].
        self.history = deque([[idx_frame, xyxy, label, confidence]],
                             maxlen=self.max_history)

    def update(self, idx_frame=None, xyxy=None, label=None, confidence=None):
        """Offer a new detection and return the current best estimate.

        The detection is accepted only when both ``idx_frame`` and
        ``confidence`` are given and ``confidence`` is not below
        ``self.min_confidence``.  Calling with no arguments just reads the
        current state.

        Returns:
            Tuple ``(xyxy, label, confidence, track_id)`` of the
            highest-confidence history entry.
        """
        # NOTE(review): the source lost its indentation; this assumes only
        # accepted detections refresh ``last_update_frame`` — confirm.
        accepted = (
            idx_frame is not None
            and confidence is not None  # guard: None < float raises TypeError
            and not confidence < self.min_confidence
        )
        if accepted:
            self.add_to_history(idx_frame, xyxy, label, confidence)
            self.last_update_frame = idx_frame
        return self.xyxy, self.label, self.confidence, self.track_id

    def add_to_history(self, idx_frame, xyxy, label, confidence):
        """Insert a detection into the history and refresh the best estimate.

        Entries older than ``max_age`` frames are dropped first.  If the
        history is then still full, the new detection replaces the
        lowest-confidence entry only when it is strictly more confident;
        otherwise it is discarded.
        """
        entry = [idx_frame, xyxy, label, confidence]
        self.history = deque(
            (info for info in self.history if idx_frame - info[0] <= self.max_age),
            maxlen=self.max_history,
        )
        if len(self.history) >= self.max_history:
            weakest = self.get_min_confidence_index()
            if confidence > self.history[weakest][-1]:
                self.history[weakest] = entry
        else:
            self.history.append(entry)

        best = self.history[self.get_max_confidence_index()]
        self.xyxy, self.label, self.confidence = best[1], best[2], best[3]

    def get_max_confidence_index(self):
        """Return the index of the highest-confidence history entry."""
        return max(range(len(self.history)), key=lambda i: self.history[i][-1])

    def get_min_confidence_index(self):
        """Return the index of the lowest-confidence history entry."""
        return min(range(len(self.history)), key=lambda i: self.history[i][-1])

    def add_desc(self, desc):
        """Replace the description; callers in this file pass ``[idx_frame, text]``."""
        self.desc = desc
2.建立一个追踪类别的类
同样建立起原始deepsort的track_id和实际id之间的映射。
class BoxTracker:
    """Maps raw tracker IDs of one object category to stable per-category IDs.

    Each physical box gets a "real" ID (1, 2, 3, ...) that survives track-ID
    switches: a detection with an unknown track ID is re-attached to an
    existing box when the two bounding boxes overlap strongly enough.
    """

    # Frames without an accepted update before a box may be dropped, and
    # frames after which an open/close description is cleared.
    STALE_FRAMES = 15

    def __init__(self):
        self.nums = 0             # number of distinct boxes seen so far
        self.track_id_map = {}    # {track_id: real_id}
        self.boxes = {}           # {real_id: BankDetObject}
        self.id_set = set()       # real IDs currently alive (mirrors self.boxes)

    def update(self, idx_frame, kx_results, ren_result, ioa_threshold=0.5):
        """Consume one frame of detections and return the live boxes.

        Args:
            idx_frame: Index of the current frame.
            kx_results: Iterable of rows
                ``(x1, y1, x2, y2, label, track_id, conf)`` for this category.
            ren_result: Iterable of rows whose first four values are an
                ``xyxy`` box; used by ``cleanup`` as possible occluders.
            ioa_threshold: Overlap ratio above which a detection is treated
                as an already known box.

        Returns:
            ``(boxes, descriptions)``: ``boxes`` is an ``int32`` array of
            shape ``(N, 7)`` with rows
            ``[x1, y1, x2, y2, label, real_id, confidence]`` and
            ``descriptions`` is the list of each box's ``desc[1]``.
        """
        for kx in kx_results:
            x1, y1, x2, y2, label, track_id, conf = kx
            xyxy = [x1, y1, x2, y2]

            known_id = self.track_id_map.get(track_id)
            if known_id is not None and known_id in self.id_set:
                self._refresh(idx_frame, known_id, xyxy, label, conf)
                continue

            # Unknown (or expired) track ID: try to re-attach by overlap.
            # Both directions are tested because an opened box and a closed
            # box differ in size.
            # NOTE(review): a detection overlapping several boxes updates all
            # of them and ends up mapped to the last one — confirm intended.
            found = False
            for real_id, box in self.boxes.items():
                if (ioa(box.xyxy, xyxy) > ioa_threshold
                        or ioa(xyxy, box.xyxy) > ioa_threshold):
                    self.track_id_map[track_id] = real_id
                    self._refresh(idx_frame, real_id, xyxy, label, conf)
                    found = True

            if not found:  # a new box, or a cleaned-up one coming back
                if track_id in self.track_id_map:
                    real_id = self.track_id_map[track_id]
                else:
                    self.nums += 1
                    real_id = self.nums
                    self.track_id_map[track_id] = real_id
                self.boxes[real_id] = BankDetObject(idx_frame, track_id, xyxy, label, conf)
                self.id_set.add(real_id)

        self.cleanup(idx_frame, ren_result, ioa_threshold)

        valid_boxes = [
            [*box.xyxy, box.label, real_id, box.confidence]
            for real_id, box in self.boxes.items()
        ]
        valid_desc = [box.desc[1] for box in self.boxes.values()]
        if not valid_boxes:
            return np.empty((0, 7), dtype=np.int32), []
        # NOTE(review): the int32 cast truncates fractional confidences.
        return np.array(valid_boxes, dtype=np.int32), valid_desc

    def _refresh(self, idx_frame, real_id, xyxy, label, conf):
        """Feed one detection to box ``real_id`` and record any label change."""
        box = self.boxes[real_id]
        old_label = box.label  # must be read before update() can change it
        box.update(idx_frame, xyxy, label, conf)
        self.check_state_change(idx_frame, old_label, real_id)

    def cleanup(self, idx_frame, ren_result, ioa_threshold=0.5):
        """Drop boxes that are stale and not covered by any ``ren_result`` box.

        A box is removed when its last accepted update is more than
        ``STALE_FRAMES`` frames old and no row of ``ren_result`` overlaps it
        by more than ``ioa_threshold`` (overlap is taken as occlusion).
        """
        to_remove = []
        for real_id, box in self.boxes.items():
            if idx_frame - box.last_update_frame <= self.STALE_FRAMES:
                continue
            is_occluded = any(
                ioa(ren[:4], box.xyxy) > ioa_threshold for ren in ren_result
            )
            if not is_occluded:
                to_remove.append(real_id)

        for real_id in to_remove:
            del self.boxes[real_id]
            self.id_set.discard(real_id)
        # track_id_map entries are kept on purpose: if the same track ID
        # shows up again later, update() reuses its old real ID.

    def check_state_change(self, idx_frame, old_label, real_id):
        """Update the box description after a possible label change.

        A label increase is recorded as "打开款箱" (cash box opened), a
        decrease as "关闭款箱" (cash box closed).  With no change, an existing
        description is cleared once it is more than ``STALE_FRAMES`` old.
        """
        box = self.boxes[real_id]
        delta = old_label - box.label
        if delta < 0:
            box.add_desc([idx_frame, "打开款箱"])
        elif delta > 0:
            box.add_desc([idx_frame, "关闭款箱"])
        elif box.desc[1] and idx_frame - box.desc[0] > self.STALE_FRAMES:
            box.add_desc([idx_frame, None])
def ioa(bbox1, bbox2):
    """Return intersection-over-area: ``area(bbox1 ∩ bbox2) / area(bbox2)``.

    The measure is asymmetric — the denominator is always ``bbox2``'s area.

    Args:
        bbox1: Box as ``(x1, y1, x2, y2)``.
        bbox2: Box as ``(x1, y1, x2, y2)``; its area is the denominator.

    Returns:
        The overlap ratio, or ``0.0`` when ``bbox2`` has no positive area
        (the plain division would produce NaN and a runtime warning).
    """
    x1, y1, x2, y2 = bbox1
    x1_, y1_, x2_, y2_ = bbox2

    inter_w = np.maximum(0, np.minimum(x2, x2_) - np.maximum(x1, x1_))
    inter_h = np.maximum(0, np.minimum(y2, y2_) - np.maximum(y1, y1_))
    inter_area = inter_w * inter_h
    bbox2_area = (x2_ - x1_) * (y2_ - y1_)

    if np.ndim(bbox2_area) == 0:
        if bbox2_area <= 0:
            return 0.0
        return inter_area / bbox2_area

    # Array-valued coordinates: same guard, applied element-wise.
    valid = bbox2_area > 0
    return np.where(valid, inter_area / np.where(valid, bbox2_area, 1), 0.0)
3.对deepsort输出结果修改
将追踪结果按类别分别传入对应的追踪类。
# Core of the modification: post-process the DeepSORT output per category.
# Each output row is [x1, y1, x2, y2, label, track_id, conf].
deepsort_outputs = self.deepsort.update(bbox_xywh_d, confs_d, img, cls_d)
if len(deepsort_outputs) > 0:
    # Split the rows by class label 0..4.  Original author's note: money is
    # not fed into detection/tracking (presumably class 3, which is never
    # used below — TODO confirm).
    split_index = split_indices_deepsort(deepsort_outputs, (0, 1, 2, 3, 4))
    res_det = [np.empty((0, 7)) for _ in range(5)]
    for i in range(5):
        if split_index[i] is not None:
            res_det[i] = deepsort_outputs[split_index[i]]

    # Class 0 goes to the money-counter tracker; classes 1 and 2 are tracked
    # together by the cash-box tracker; class 4 rows are passed to both as
    # the possible occluders (presumably people — confirm against the model).
    # NOTE(review): kx is indexed with [0] below but money_counter is not; if
    # money_counter_tracker.update() also returns (boxes, descriptions), the
    # final concatenate needs money_counter[0] — verify.
    money_counter = self.money_counter_tracker.update(idx_frame, res_det[0], res_det[4])
    kx_all = np.concatenate([res_det[1], res_det[2]], axis=0)
    kx = self.kx_tracker.update(idx_frame, kx_all, res_det[4])
    deepsort_outputs = np.concatenate([money_counter, kx[0], res_det[4]], axis=0)
修改结果
人、验钞机和款箱分别追踪:
video_plot_2024-06-05-10-54