Deepsort目标跟踪代码

import argparse
import os
import platform
import shutil
import time
from pathlib import Path
import cv2
import numpy as np

import torch
import torch.backends.cudnn as cudnn

from ultralytics.utils.downloads import attempt_download_asset
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.torch_utils import select_device, time_sync
from ultralytics.data.loaders import LoadStreams, LoadImagesAndVideos
from ultralytics.data.augment import LetterBox
from ultralytics.utils.ops import non_max_suppression
from ultralytics.nn.tasks import attempt_load_weights
from deep_sort_pytorch.utils.parser import get_config
from deep_sort_pytorch.deep_sort import DeepSort
from utils import scale_coords

palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


def xyxy_to_xywh(*xyxy):
    """" Calculates the relative bounding box from absolute pixel values. """
    bbox_left = min([xyxy[0].item(), xyxy[2].item()])
    bbox_top = min([xyxy[1].item(), xyxy[3].item()])
    bbox_w = abs(xyxy[0].item() - xyxy[2].item())
    bbox_h = abs(xyxy[1].item() - xyxy[3].item())
    x_c = (bbox_left + bbox_w / 2)
    y_c = (bbox_top + bbox_h / 2)
    w = bbox_w
    h = bbox_h
    return x_c, y_c, w, h


def xyxy_to_tlwh(bbox_xyxy):
    tlwh_bboxs = []
    for i, box in enumerate(bbox_xyxy):
        x1, y1, x2, y2 = [int(i) for i in box]
        top = x1
        left = y1
        w = int(x2 - x1)
        h = int(y2 - y1)
        tlwh_obj = [top, left, w, h]
        tlwh_bboxs.append(tlwh_obj)
    return tlwh_bboxs


def compute_color_for_labels(label):
    """
    Simple function that adds fixed color depending on the class
    """
    color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
    return tuple(color)


def draw_boxes(img, bbox, identities=None, offset=(0, 0)):
    for i, box in enumerate(bbox):
        x1, y1, x2, y2 = [int(i) for i in box]
        x1 += offset[0]
        x2 += offset[0]
        y1 += offset[1]
        y2 += offset[1]
        # box text and bar
        id = int(identities[i]) if identities is not None else 0
        color = compute_color_for_labels(id)
        label = '{}{:d}'.format("", id)
        t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
        cv2.rectangle(
            img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
        cv2.putText(img, label, (x1, y1 +
                                 t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
    return img


def detect(opt):
    out, source, yolo_weights, deep_sort_weights, show_vid, save_vid, save_txt, imgsz, evaluate = \
        opt.output, opt.source, opt.yolo_weights, opt.deep_sort_weights, opt.show_vid, opt.save_vid, \
            opt.save_txt, opt.img_size, opt.evaluate
    webcam = source == '0' or source.startswith('rtsp') or source.startswith('http') or source.endswith('.txt')

    # initialize deepsort
    cfg = get_config()
    cfg.merge_from_file(opt.config_deepsort)
    attempt_download_asset(deep_sort_weights, repo='mikel-brostrom/Yolov5_DeepSort_Pytorch')
    deepsort = DeepSort(cfg.DEEPSORT.REID_CKPT,
                        max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
                        nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
                        max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET,
                        use_cuda=True)

    # Initialize
    device = select_device(opt.device)

    # The MOT16 evaluation runs multiple inference streams in parallel, each one writing to
    # its own .txt file. Hence, in that case, the output folder is not restored
    if not evaluate:
        if os.path.exists(out):
            pass
            shutil.rmtree(out)  # delete output folder
        os.makedirs(out)  # make new output folder
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load_weights(yolo_weights, device=device)  # load FP32 model
    stride = int(model.stride.max())  # model stride
    imgsz = check_imgsz(imgsz, stride=stride)  # check img_size
    if half:
        model.half()  # to FP16

    # Set Dataloader
    vid_path, vid_writer = None, None
    # Check if environment supports image displays
    if show_vid:
        show_vid = check_imshow()

    if webcam:
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source)
    else:
        dataset = LoadImagesAndVideos(source)

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names

    # Run inference
    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
    t0 = time.time()

    save_path = str(Path(out))
    # extract what is in between the last '/' and last '.'
    txt_file_name = source.split('/')[-1].split('.')[0]
    txt_path = str(Path(out)) + '/' + txt_file_name + '.txt'
    letterbox = LetterBox(imgsz, auto=True, stride=stride)

    for frame_idx, (path, im0s, _) in enumerate(dataset):
        vid_cap = dataset.cap
        path = str(path[0])
        im0s = np.array(im0s).squeeze()
        img = letterbox(image=im0s)
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_sync()
        pred = model(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_sync()

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
            else:
                p, s, im0 = path, '', im0s

            s += '%gx%g ' % img.shape[2:]  # print string
            save_path = str(Path(out) / Path(p).name)

            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                xywh_bboxs = []
                confs = []

                # Adapt detections to deep sort input format
                for *xyxy, conf, cls in det:
                    # to deep sort format
                    x_c, y_c, bbox_w, bbox_h = xyxy_to_xywh(*xyxy)
                    xywh_obj = [x_c, y_c, bbox_w, bbox_h]
                    xywh_bboxs.append(xywh_obj)
                    confs.append([conf.item()])

                xywhs = torch.Tensor(xywh_bboxs)
                confss = torch.Tensor(confs)

                # pass detections to deepsort
                outputs = deepsort.update(xywhs, confss, im0)

                # draw boxes for visualization
                if len(outputs) > 0:
                    bbox_xyxy = outputs[:, :4]
                    identities = outputs[:, -1]
                    draw_boxes(im0, bbox_xyxy, identities)
                    # to MOT format
                    tlwh_bboxs = xyxy_to_tlwh(bbox_xyxy)

                    # Write MOT compliant results to file
                    if save_txt:
                        for j, (tlwh_bbox, output) in enumerate(zip(tlwh_bboxs, outputs)):
                            bbox_top = tlwh_bbox[0]
                            bbox_left = tlwh_bbox[1]
                            bbox_w = tlwh_bbox[2]
                            bbox_h = tlwh_bbox[3]
                            identity = output[-1]
                            with open(txt_path, 'a') as f:
                                f.write(('%g ' * 10 + '\n') % (frame_idx, identity, bbox_top,
                                                               bbox_left, bbox_w, bbox_h, -1, -1, -1,
                                                               -1))  # label format

            else:
                deepsort.increment_ages()

            # Print time (inference + NMS)
            # print('%sDone. (%.3fs)' % (s, t2 - t1))

            # Stream results
            if show_vid:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # q to quit
                    raise StopIteration

            # Save results (image with detections)
            if save_vid:
                if vid_path != save_path:  # new video
                    vid_path = save_path
                    if isinstance(vid_writer, cv2.VideoWriter):
                        vid_writer.release()  # release previous video writer
                    if vid_cap:  # video
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    else:  # stream
                        fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path += '.mp4'

                    vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                vid_writer.write(im0)

    if save_txt or save_vid:
        print('Results saved to %s' % os.getcwd() + os.sep + out)
        if platform == 'darwin':  # MacOS
            os.system('open ' + save_path)

    print('Done. (%.3fs)' % (time.time() - t0))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo_weights', type=str, default='weights/yolo11s.pt', help='model.pt path')
    parser.add_argument('--deep_sort_weights', type=str, default='weights/ckpt.t7', help='ckpt.t7 path')
    # file/folder, 0 for webcam
    parser.add_argument('--source', type=str, default='test.mp4', help='source')
    parser.add_argument('--output', type=str, default='inference/output', help='output folder')  # output folder
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--fourcc', type=str, default='mp4v', help='output video codec (verify ffmpeg support)')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--show-vid', action='store_true', help='display tracking video results')
    parser.add_argument('--save-vid', default=True, action='store_true', help='save video tracking results')
    parser.add_argument('--save-txt', action='store_true', help='save MOT compliant results to *.txt')
    # class 0 is person, 1 is bycicle, 2 is car... 79 is oven
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 16 17')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--evaluate', action='store_true', help='augmented inference')
    parser.add_argument("--config_deepsort", type=str, default="deep_sort_pytorch/configs/deep_sort.yaml")
    args = parser.parse_args()
    args.img_size = check_imgsz(args.img_size)

    with torch.no_grad():
        detect(args)
import argparse  # 用于解析命令行参数
import os  # 操作系统接口
import platform  # 获取底层平台信息
import shutil  # 高级文件操作
import time  # 时间相关函数
from pathlib import Path  # 面向对象的文件系统路径
import cv2  # OpenCV,用于图像和视频处理
import numpy as np  # 数组和数值计算

import torch  # PyTorch 深度学习框架
import torch.backends.cudnn as cudnn  # cuDNN 后端控制

# 从 ultralytics YOLO 实现中导入工具
from ultralytics.utils.downloads import attempt_download_asset  # 如果缺少资源则下载
from ultralytics.utils.checks import check_imgsz, check_imshow  # 检查图像尺寸和显示支持
from ultralytics.utils.torch_utils import select_device, time_sync  # 设备选择和计时
from ultralytics.data.loaders import LoadStreams, LoadImagesAndVideos  # 数据加载类
from ultralytics.data.augment import LetterBox  # 保持宽高比的缩放
from ultralytics.utils.ops import non_max_suppression  # 非极大值抑制
from ultralytics.nn.tasks import attempt_load_weights  # 加载 YOLO 权重

# 从 DeepSORT 实现中导入工具
from deep_sort_pytorch.utils.parser import get_config  # 解析 DeepSORT 配置
from deep_sort_pytorch.deep_sort import DeepSort  # DeepSORT 跟踪器类
from utils import scale_coords  # 将坐标缩放回原图尺寸

# 颜色调色板,用于给不同跟踪 ID 分配不同颜色
palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
# 允许重复加载 OpenMP 库,避免 KMP 冲突错误
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


def xyxy_to_xywh(*xyxy):
    """ 将 [x1, y1, x2, y2] 绝对坐标转换为 [中心x, 中心y, 宽, 高] """
    bbox_left = min([xyxy[0].item(), xyxy[2].item()])  # 左边界
    bbox_top = min([xyxy[1].item(), xyxy[3].item()])  # 上边界
    bbox_w = abs(xyxy[0].item() - xyxy[2].item())  # 宽度
    bbox_h = abs(xyxy[1].item() - xyxy[3].item())  # 高度
    x_c = (bbox_left + bbox_w / 2)  # 中心 x
    y_c = (bbox_top + bbox_h / 2)  # 中心 y
    return x_c, y_c, bbox_w, bbox_h  # 返回 (中心x, 中心y, 宽, 高)


def xyxy_to_tlwh(bbox_xyxy):
    """ 将 [x1, y1, x2, y2] 转换为 [top, left, width, height] 格式,用于 MOT 格式 """
    tlwh_bboxs = []
    for box in bbox_xyxy:
        x1, y1, x2, y2 = [int(i) for i in box]  # 强制转换为整数
        top = x1  # top 坐标
        left = y1  # left 坐标
        w = int(x2 - x1)  # 宽度
        h = int(y2 - y1)  # 高度
        tlwh_bboxs.append([top, left, w, h])  # 添加到列表
    return tlwh_bboxs


def compute_color_for_labels(label):
    """ 根据跟踪 ID 生成唯一颜色 """
    color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
    return tuple(color)


def draw_boxes(img, bbox, identities=None, offset=(0, 0)):
    """ 在图像上绘制检测框和 ID 标签 """
    for i, box in enumerate(bbox):
        x1, y1, x2, y2 = [int(i) for i in box]  # 坐标转换为整数
        x1 += offset[0]; x2 += offset[0]  # 应用偏移
        y1 += offset[1]; y2 += offset[1]
        track_id = int(identities[i]) if identities is not None else 0  # 获取跟踪 ID
        color = compute_color_for_labels(track_id)  # 计算颜色
        label = f"{track_id}"  # 文本标签
        # 获取文本尺寸
        t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        # 绘制边框
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
        # 绘制标签背景
        cv2.rectangle(img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
        # 绘制标签文字
        cv2.putText(img, label, (x1, y1 + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 2, (255, 255, 255), 2)
    return img


def detect(opt):
    """ 主检测和跟踪循环 """
    # 解包参数
    out, source, yolo_weights, deep_sort_weights, show_vid, save_vid, save_txt, imgsz, evaluate = (
        opt.output, opt.source, opt.yolo_weights, opt.deep_sort_weights,
        opt.show_vid, opt.save_vid, opt.save_txt, opt.img_size, opt.evaluate
    )
    # 判断输入源是否为摄像头或流
    webcam = source == '0' or source.startswith('rtsp') or source.startswith('http') or source.endswith('.txt')

    # 初始化 DeepSORT
    cfg = get_config()  # 加载默认配置
    cfg.merge_from_file(opt.config_deepsort)  # 合并用户配置
    attempt_download_asset(deep_sort_weights, repo='mikel-brostrom/Yolov5_DeepSort_Pytorch')  # 下载权重
    deepsort = DeepSort(
        cfg.DEEPSORT.REID_CKPT,  # reid 特征提取模型
        max_dist=cfg.DEEPSORT.MAX_DIST,  # 最大余弦距离
        min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,  # 最小置信度
        nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP,  # NMS 重叠阈值
        max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,  # 最大 IOU 距离
        max_age=cfg.DEEPSORT.MAX_AGE,  # 最大存活帧数
        n_init=cfg.DEEPSORT.N_INIT,  # 确认跟踪所需最小命中数
        nn_budget=cfg.DEEPSORT.NN_BUDGET,  # 特征存储预算
        use_cuda=True  # 使用 GPU
    )

    # 选择运行设备
    device = select_device(opt.device)  # CUDA 或 CPU

    # 准备输出文件夹
    if not evaluate:
        if os.path.exists(out):
            shutil.rmtree(out)  # 删除已有输出
        os.makedirs(out)  # 创建新输出文件夹
    half = device.type != 'cpu'  # GPU 上使用半精度

    # 加载 YOLO 模型
    model = attempt_load_weights(yolo_weights, device=device)  # 加载模型权重
    stride = int(model.stride.max())  # 获取最大步幅
    imgsz = check_imgsz(imgsz, stride=stride)  # 检查图像尺寸
    if half:
        model.half()  # 转为 FP16

    # 初始化数据加载器
    vid_path, vid_writer = None, None  # 视频保存路径和写入器
    if show_vid:
        show_vid = check_imshow()  # 检查显示支持
    if webcam:
        cudnn.benchmark = True  # 优化固定尺寸推理
        dataset = LoadStreams(source)  # 流加载
    else:
        dataset = LoadImagesAndVideos(source)  # 文件加载

    # 获取类别名称
    names = model.module.names if hasattr(model, 'module') else model.names

    # 预热模型(仅 GPU)
    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))
    t0 = time.time()  # 记录开始时间

    # 准备保存路径和文本文件
    save_path = str(Path(out))
    txt_file_name = source.split('/')[-1].split('.')[0]  # 文本文件名
    txt_path = f"{save_path}/{txt_file_name}.txt"  # 文本路径
    letterbox = LetterBox(imgsz, auto=True, stride=stride)  # letterbox 缩放

    # 主循环:逐帧处理
    for frame_idx, (path, im0s, _) in enumerate(dataset):
        vid_cap = dataset.cap  # 视频捕获对象
        path = str(path[0])  # 当前帧路径
        im0s = np.array(im0s).squeeze()  # 原始图像
        img = letterbox(image=im0s)  # letterbox 缩放
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR->RGB, HWC->CHW
        img = np.ascontiguousarray(img)  # 保持连续内存
        img = torch.from_numpy(img).to(device)  # 转为张量并移动到设备
        img = img.half() if half else img.float()  # 转为 FP16/FP32
        img /= 255.0  # 归一化到 [0,1]
        if img.ndimension() == 3:
            img = img.unsqueeze(0)  # 添加批次维度

        # 推理
        t1 = time_sync()  # 推理前计时
        pred = model(img, augment=opt.augment)[0]  # 前向推理

        # 非极大抑制
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
                                   classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_sync()  # 推理后计时

        # 处理每张图像的检测结果
        for i, det in enumerate(pred):
            if webcam:
                p, s, im0 = path[i], f"{i}: ", im0s[i].copy()  # 多流处理
            else:
                p, s, im0 = path, '', im0s  # 单流处理

            s += f"{img.shape[2]}x{img.shape[3]} "  # 添加尺寸信息
            save_path = str(Path(out) / Path(p).name)  # 更新保存路径

            if det is not None and len(det):  # 若有检测到目标
                # 将坐标从缩放图像映射回原图
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # 统计各类别检测数量
                for c in det[:, -1].unique():
                    n = int((det[:, -1] == c).sum())  # 计数
                    s += f"{n} {names[int(c)]}s, "

                xywh_bboxs, confs = [], []  # 准备 DeepSORT 输入
                for *xyxy, conf, cls in det:
                    x_c, y_c, w, h = xyxy_to_xywh(*xyxy)  # 转为中心格式
                    xywh_bboxs.append([x_c, y_c, w, h])
                    confs.append([conf.item()])

                xywhs = torch.Tensor(xywh_bboxs)  # 转张量
                confss = torch.Tensor(confs)

                # 更新跟踪器
                outputs = deepsort.update(xywhs, confss, im0)

                if len(outputs) > 0:  # 若有跟踪结果
                    bbox_xyxy = outputs[:, :4]  # 获取边框
                    identities = outputs[:, -1]  # 获取 ID
                    draw_boxes(im0, bbox_xyxy, identities)  # 绘制

                    # 转为 MOT txt 格式并保存
                    tlwhs = xyxy_to_tlwh(bbox_xyxy)
                    if save_txt:
                        with open(txt_path, 'a') as f:
                            for idx, tlwh in enumerate(tlwhs):
                                frame_id = frame_idx
                                track_id = int(outputs[idx, -1])
                                top, left, w, h = tlwh
                                f.write(f"{frame_id} {track_id} {top} {left} {w} {h} -1 -1 -1 -1\n")
            else:
                deepsort.increment_ages()  # 若无检测,则更新 track 年龄

            # 显示结果
            if show_vid:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # 按 'q' 键退出
                    raise StopIteration

            # 保存视频
            if save_vid:
                if vid_path != save_path:  # 若切换到新文件
                    vid_path = save_path
                    if isinstance(vid_writer, cv2.VideoWriter):
                        vid_writer.release()  # 释放旧写入器
                    if vid_cap:
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    else:
                        fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path += '.mp4'
                    vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*opt.fourcc), fps, (w, h))
                vid_writer.write(im0)

    # 处理完成后的提示
    if save_txt or save_vid:
        print(f"结果已保存至 {os.getcwd()}/{out}")
        if platform.system() == 'Darwin':
            os.system(f"open {save_path}")  # MacOS 自动打开

    print(f"完成,总耗时 {(time.time() - t0):.3f} 秒")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()  # 创建命令行参数解析器
    parser.add_argument('--yolo_weights', type=str, default='weights/yolo11s.pt', help='YOLO 模型权重路径')
    parser.add_argument('--deep_sort_weights', type=str, default='weights/ckpt.t7', help='DeepSORT 权重路径')
    parser.add_argument('--source', type=str, default='test.mp4', help='输入源(文件/摄像头/流)')
    parser.add_argument('--output', type=str, default='inference/output', help='输出文件夹')
    parser.add_argument('--img-size', type=int, default=640, help='推理图像尺寸(像素)')
    parser.add_argument('--conf-thres', type=float, default=0.4, help='目标置信度阈值')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IOU 阈值')
    parser.add_argument('--fourcc', type=str, default='mp4v', help='输出视频编码格式')
    parser.add_argument('--device', default='', help='设备 (CUDA 设备号 或 cpu)')
    parser.add_argument('--show-vid', action='store_true', help='显示推理结果窗口')
    parser.add_argument('--save-vid', action='store_true', default=True, help='保存输出视频')
    parser.add_argument('--save-txt', action='store_true', help='保存 MOT txt 文件')
    parser.add_argument('--classes', nargs='+', type=int, help='仅检测指定类别,例如 --classes 0 2')
    parser.add_argument('--agnostic-nms', action='store_true', help='类别无关的 NMS')
    parser.add_argument('--augment', action='store_true', help='使用增强推理')
    parser.add_argument('--evaluate', action='store_true', help='评估模式,不清空输出')
    parser.add_argument('--config_deepsort', type=str, default='deep_sort_pytorch/configs/deep_sort.yaml', help='DeepSORT 配置文件路径')
    args = parser.parse_args()  # 解析参数
    args.img_size = check_imgsz(args.img_size)  # 检查图像尺寸合法性

    with torch.no_grad():  # 禁用梯度计算
        detect(args)  # 运行检测
import numpy as np  # 导入 NumPy,用于数值计算和数组操作
import torch  # 导入 PyTorch,用于深度学习张量操作

# 从当前包中导入特征提取器 Extractor
from .deep.feature_extractor import Extractor  # 提取目标图像特征
# 导入最近邻距离度量,用于度量特征相似度
from .sort.nn_matching import NearestNeighborDistanceMetric  # 最近邻距离度量
# 导入 Detection 类,用于封装检测框、置信度和特征
from .sort.detection import Detection  # 检测结果封装
# 导入 Tracker 类,实现基于卡尔曼滤波和匈牙利匹配的跟踪器
from .sort.tracker import Tracker  # 跟踪器

__all__ = ['DeepSort']  # 定义模块对外导出 DeepSort 类


class DeepSort(object):  # 定义 DeepSort 跟踪类
    def __init__(self, model_path, max_dist=0.2, min_confidence=0.3, nms_max_overlap=1.0, max_iou_distance=0.7,
                 max_age=70, n_init=3, nn_budget=100, use_cuda=True):
        # 初始化最小置信度阈值
        self.min_confidence = min_confidence  # 小于该置信度的检测将被过滤
        # 初始化 NMS 最大重叠阈值
        self.nms_max_overlap = nms_max_overlap  # NMS 时允许的最大重叠

        # 创建特征提取器,加载预训练模型
        self.extractor = Extractor(model_path, use_cuda=use_cuda)  # 用于提取检测框对应的特征向量

        # 设置最大余弦距离,用于度量特征匹配相似度
        max_cosine_distance = max_dist  # 最大余弦距离阈值
        # 初始化度量器,使用余弦距离并设置预算
        metric = NearestNeighborDistanceMetric(
            "cosine", max_cosine_distance, nn_budget)  # 特征匹配度量器
        # 创建跟踪器,传入度量器和各项超参数
        self.tracker = Tracker(
            metric,
            max_iou_distance=max_iou_distance,  # IOU 匹配阈值
            max_age=max_age,  # 最大存活帧数
            n_init=n_init  # 确认跟踪所需最小命中数
        )

    def update(self, bbox_xywh, confidences, ori_img):  # 更新跟踪器并返回跟踪结果
        # 获取原始图像的高和宽
        self.height, self.width = ori_img.shape[:2]  # 图像尺寸
        # 提取当前所有检测框的特征
        features = self._get_features(bbox_xywh, ori_img)  # 调用特征提取器
        # 将 [x_center, y_center, w, h] 转为 [top, left, w, h]
        bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)  # 坐标转换
        # 构造 Detection 对象列表,只保留置信度大于阈值的检测
        detections = [
            Detection(bbox_tlwh[i], conf, features[i])
            for i, conf in enumerate(confidences)
            if conf > self.min_confidence
        ]  # 封装检测框、置信度、特征

        # 运行非极大值抑制(可选,此处假设已在外部完成)
        boxes = np.array([d.tlwh for d in detections])  # 提取 tlwh 坐标列表
        scores = np.array([d.confidence for d in detections])  # 提取置信度列表

        # 跟踪器预测下一帧位置
        self.tracker.predict()  # 卡尔曼滤波预测
        # 用当前检测结果更新跟踪器
        self.tracker.update(detections)  # 关联并更新轨迹

        # 输出跟踪结果:列表形式 [x1, y1, x2, y2, track_id]
        outputs = []  # 存储最终跟踪结果
        for track in self.tracker.tracks:  # 遍历所有轨迹
            # 仅保留已确认且刚更新过的轨迹
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            # 获取 tlwh 格式边框
            box = track.to_tlwh()  # 转换为 [top, left, w, h]
            # 转为 [x1, y1, x2, y2]
            x1, y1, x2, y2 = self._tlwh_to_xyxy(box)  # 坐标转换
            track_id = track.track_id  # 获取轨迹 ID
            outputs.append(
                np.array([x1, y1, x2, y2, track_id], dtype=int)
            )  # 添加到结果列表
        # 如果有结果,则堆叠为数组返回
        if len(outputs) > 0:
            outputs = np.stack(outputs, axis=0)  # 转为 NumPy 数组
        return outputs  # 返回跟踪结果

    @staticmethod
    def _xywh_to_tlwh(bbox_xywh):  # 静态方法:中心坐标转左上宽高
        if isinstance(bbox_xywh, np.ndarray):
            bbox_tlwh = bbox_xywh.copy()  # NumPy 数组复制
        elif isinstance(bbox_xywh, torch.Tensor):
            bbox_tlwh = bbox_xywh.clone()  # PyTorch 张量克隆
        # x_center - w/2 => left,y_center - h/2 => top
        bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2.
        bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2.
        return bbox_tlwh  # 返回 tlwh 格式

    def _xywh_to_xyxy(self, bbox_xywh):  # 将中心格式转为角点格式
        x, y, w, h = bbox_xywh  # 解包
        x1 = max(int(x - w / 2), 0)  # 计算左上角 x,保证 >= 0
        x2 = min(int(x + w / 2), self.width - 1)  # 计算右下角 x,保证 <= 图像宽-1
        y1 = max(int(y - h / 2), 0)  # 计算左上角 y,保证 >= 0
        y2 = min(int(y + h / 2), self.height - 1)  # 计算右下角 y,保证 <= 图像高-1
        return x1, y1, x2, y2  # 返回角点坐标

    def _tlwh_to_xyxy(self, bbox_tlwh):  # 将 tlwh 转为角点格式
        x, y, w, h = bbox_tlwh  # 解包
        x1 = max(int(x), 0)  # 左上角 x
        x2 = min(int(x + w), self.width - 1)  # 右下角 x
        y1 = max(int(y), 0)  # 左上角 y
        y2 = min(int(y + h), self.height - 1)  # 右下角 y
        return x1, y1, x2, y2  # 返回角点坐标

    def increment_ages(self):  # 当一帧没有检测到目标时,更新轨迹年龄
        self.tracker.increment_ages()  # 调用跟踪器方法

    def _xyxy_to_tlwh(self, bbox_xyxy):  # 将角点格式转为 tlwh
        x1, y1, x2, y2 = bbox_xyxy  # 解包
        t = x1  # top
        l = y1  # left
        w = int(x2 - x1)  # 宽度
        h = int(y2 - y1)  # 高度
        return t, l, w, h  # 返回 tlwh

    def _get_features(self, bbox_xywh, ori_img):  # 提取检测框特征
        im_crops = []  # 存储裁剪图像块
        for box in bbox_xywh:  # 遍历所有框
            x1, y1, x2, y2 = self._xywh_to_xyxy(box)  # 转为角点坐标
            im = ori_img[y1:y2, x1:x2]  # 从原图裁剪
            im_crops.append(im)  # 添加到列表
        if im_crops:  # 如果有裁剪图像
            features = self.extractor(im_crops)  # 提取特征
        else:
            features = np.array([])  # 无检测时返回空数组
        return features  # 返回特征数组

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值