deepsort修改(2):多类别追踪,id根据类别从1编号

背景

原始deepsort主要用于行人、车辆追踪,因此不区分类别,只给出跟踪id。自己使用的场景中,存在不同类别的物体需要跟踪,希望能够对每一类物体的跟踪id,分别从0开始编号。

修改思路

拓展deepsort修改(1):id跳变优化的思路。对检测到的每个类别重新编号。

1.建立一个类别对象的数据结构类

该类别主要用于对deepsort返回结果的封装,这里因为有一些场景的先验知识,对追踪结果进行了二次简单的更新。

class BankDetObject:
    """
    返回目标的 [检测框, 置信度, 追踪ID]
    """
    def __init__(self, idx_frame, track_id, xyxy, label, confidence, min_confidence=0.5, max_history=5):
        self.last_update_frame = idx_frame  # 最后更新帧
        self.track_id = track_id            # 当前跟踪ID
        self.xyxy = xyxy                    # 检测框
        self.label = label                  # 标签
        self.confidence = confidence        # 置信度
        self.max_history = max_history      # 最大保留的历史帧数
        self.desc = [idx_frame, None]       # 描述信息
        self.min_confidence = min(min_confidence, confidence)       # 可加入历史信息的置信度下限
        self.history = deque([[idx_frame, xyxy, label, confidence]], maxlen=self.max_history)     # 保存的历史信息

    def update(self, idx_frame=None, xyxy=None, label=None, confidence=None):
        if not (idx_frame is None or confidence < self.min_confidence):    # 置信度高,历史信息作用小,当前结果加入历史信息
            self.add_to_history(idx_frame, xyxy, label, confidence)
            self.last_update_frame = idx_frame

        return self.xyxy, self.label, self.confidence, self.track_id                    # 返回最大置信度对应的检测结果

    def add_to_history(self, idx_frame, xyxy, label, confidence):      # 删除超过15帧的历史信息,相当于记录1秒内最有用的几帧
        self.history = deque([info for info in self.history if idx_frame - info[0] <= 15], maxlen=self.max_history)

        if len(self.history) >= self.max_history:
            min_confidence_index = self.get_min_confidence_index()      # 找到置信度最小的历史信息索引
            if confidence > self.history[min_confidence_index][-1]:     # 如果最小置信度也比当前的大,当前帧的置信度就无效,不加入
                self.history[min_confidence_index] = [idx_frame, xyxy, label, confidence]
        else:
            self.history.append([idx_frame, xyxy, label, confidence])

        max_confidence_index = self.get_max_confidence_index()
        self.xyxy = self.history[max_confidence_index][1]
        self.label = self.history[max_confidence_index][2]
        self.confidence = self.history[max_confidence_index][3]

    def get_max_confidence_index(self):  # 返回最大置信度索引
        return max(range(len(self.history)), key=lambda i: self.history[i][-1])

    def get_min_confidence_index(self):  # 返回最小置信度索引
        return min(range(len(self.history)), key=lambda i: self.history[i][-1])

    def add_desc(self, desc):
        self.desc = desc

2.建立一个追踪类别的类

同样建立起原始deepsort的track_id和实际id之间的映射。

class BoxTracker:
    def __init__(self):
        self.nums = 0                   # 历史追踪到的箱子数量
        self.track_id_map = dict()      # 将追踪ID映射成箱子的序号,字典:{track_id: real_id}
        self.boxes = dict()             # 箱子的实际ID,字典: {real_id: BankDetObject类}
        self.id_set = set()             # 当前检测到的ID,集合:{real_id}

    def update(self, idx_frame, kx_results, ren_result, ioa_threshold=0.5):
        for kx in kx_results:
            x1, y1, x2, y2, label, track_id, conf = kx
            xyxy = [x1, y1, x2, y2]

            if track_id in self.track_id_map and self.track_id_map[track_id] in self.id_set:
                real_id = self.track_id_map[track_id]
                old_label = self.boxes[real_id].label
                self.boxes[real_id].update(idx_frame, xyxy, label, conf)
                self.check_state_change(idx_frame, old_label, real_id)
            else:
                found = False
                for real_id, box in self.boxes.items():
                    if ioa(box.xyxy, xyxy) > ioa_threshold or ioa(xyxy, box.xyxy) > ioa_threshold:  # 打开和关闭两种情况
                        self.track_id_map[track_id] = real_id
                        box.update(idx_frame, xyxy, label, conf)
                        found = True

                        old_label = box.label
                        box.update(idx_frame, xyxy, label, conf)
                        self.check_state_change(idx_frame, old_label, real_id)
                        # print(f"标签变化:{old_label}, {box.label}")

                if not found:  # 新的箱子
                    if track_id in self.track_id_map:
                        real_id = self.track_id_map[track_id]
                    else:
                        self.nums += 1
                        real_id = self.nums
                    self.track_id_map[track_id] = real_id
                    self.boxes[real_id] = BankDetObject(idx_frame, track_id, xyxy, label, conf)
                    self.id_set.add(real_id)

        self.cleanup(idx_frame, ren_result, ioa_threshold)

        valid_boxes = []    # 返回所有有效的箱子
        valid_desc = []     # 箱子的描述
        for real_id, box in self.boxes.items():
            valid_boxes.append([*box.xyxy, box.label, real_id, box.confidence])
            valid_desc.append(box.desc[1])

        if len(valid_boxes) == 0:
            return np.empty((0, 7), dtype=np.int32), []

        return np.array(valid_boxes, dtype=np.int32), valid_desc

    def cleanup(self, idx_frame, ren_result, ioa_threshold=0.5):
        to_remove = []
        for real_id, box in self.boxes.items():
            if idx_frame - box.last_update_frame > 15:
                is_occluded = False
                for ren in ren_result:
                    ren_xyxy = ren[:4]
                    if ioa(ren_xyxy, box.xyxy) > ioa_threshold:  # 假设IOA大于0.5认为被遮挡
                        is_occluded = True
                        break
                if not is_occluded:
                    to_remove.append(real_id)

        for real_id in to_remove:
            del self.boxes[real_id]
            self.id_set.discard(real_id)

        # 清理track_id_map中对应的track_id,由于间隔问题,历史信息可能有用,所以不删除。
        # track_ids_to_remove = [track_id for track_id, real_id in self.track_id_map.items() if real_id in to_remove]
        # for track_id in track_ids_to_remove:
        #     del self.track_id_map[track_id]

    def check_state_change(self, idx_frame, old_label, real_id):
        if old_label - self.boxes[real_id].label < 0:  # 款箱的标签发生变化,说明进行了打开或者关闭
            self.boxes[real_id].add_desc([idx_frame, "打开款箱"])
            return
        if old_label - self.boxes[real_id].label > 0:
            self.boxes[real_id].add_desc([idx_frame, "关闭款箱"])
            return
        if self.boxes[real_id].desc[1] and idx_frame - self.boxes[real_id].desc[0] > 15:
            self.boxes[real_id].add_desc([idx_frame, None])
        return


def ioa(bbox1, bbox2):
    """
    计算两个检测框的交集面积比上当前检测框的面积(IOA)
    """
    x1, y1, x2, y2 = bbox1
    x1_, y1_, x2_, y2_ = bbox2
    inter_x1 = np.maximum(x1, x1_)
    inter_y1 = np.maximum(y1, y1_)
    inter_x2 = np.minimum(x2, x2_)
    inter_y2 = np.minimum(y2, y2_)
    inter_area = np.maximum(0, inter_x2 - inter_x1) * np.maximum(0, inter_y2 - inter_y1)
    bbox2_area = (x2_ - x1_) * (y2_ - y1_)
    return inter_area / bbox2_area

3.对deepsort输出结果修改

将追踪结果分类别的传入对应的追踪类

# 修改的核心代码,获取deepsort结果后,做后处理
deepsort_outputs = self.deepsort.update(bbox_xywh_d, confs_d, img, cls_d)   # x1,y1,x2,y2,label,track_ID,confs
            # print(f"bbox_xywh: {bbox_xywh}, confs: {confs}, cls: {cls}, outputs: {outputs}")

            if len(deepsort_outputs) > 0:
                split_index = split_indices_deepsort(deepsort_outputs, (0, 1, 2, 3, 4))    # 钱没放入检测

                res_det = [np.empty((0, 7)) for i in range(5)]

                for i in range(5):
                    if split_index[i] is not None:
                        res_det[i] = deepsort_outputs[split_index[i]]

                money_counter = self.money_counter_tracker.update(idx_frame, res_det[0], res_det[4])

                kx_all = np.concatenate([res_det[1], res_det[2]], axis=0)
                kx = self.kx_tracker.update(idx_frame, kx_all, res_det[4])

                deepsort_outputs = np.concatenate([money_counter, kx[0], res_det[4]], axis=0)

修改结果

人、验钞机和款箱分别追踪:

video_plot_2024-06-05-10-54

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值