YOLOv11-ultralytics-8.3.67部分代码阅读笔记-predict.py

predict.py

ultralytics\models\yolo\detect\predict.py

目录

predict.py

1.所需的库和模块

2.class DetectionPredictor(BasePredictor): 


1.所需的库和模块

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops

2.class DetectionPredictor(BasePredictor): 

# 这段代码定义了一个名为 DetectionPredictor 的类,它继承自 BasePredictor ,用于基于检测模型进行预测。
# 这行代码定义了一个名为 DetectionPredictor 的类,它继承自 BasePredictor 。 BasePredictor 是一个基础类,包含一些通用的预测功能,而 DetectionPredictor 则在此基础上扩展了针对目标检测模型的特定功能。
class DetectionPredictor(BasePredictor):
    # 扩展 BasePredictor 类的类,用于基于检测模型进行预测。
    """
    A class extending the BasePredictor class for prediction based on a detection model.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.detect import DetectionPredictor

        args = dict(model="yolo11n.pt", source=ASSETS)
        predictor = DetectionPredictor(overrides=args)
        predictor.predict_cli()
        ```
    """

    # 这段代码定义了 DetectionPredictor 类中的 postprocess 方法,用于对目标检测模型的预测结果进行后处理。
    # 定义了一个名为 postprocess 的方法,属于 DetectionPredictor 类。该方法接收以下参数 :
    # 1.preds :模型的原始预测结果,通常是包含边界框、置信度和类别信息的张量。
    # 2.img :经过预处理的图像张量,用于后续的边界框缩放等操作。
    # 3.orig_imgs :原始图像数据,用于最终结果的可视化或保存。
    # 4.**kwargs :可选的额外参数,用于支持灵活的扩展。
    def postprocess(self, preds, img, orig_imgs, **kwargs):
        # 后期处理预测并返回结果对象列表。
        """Post-processes predictions and returns a list of Results objects."""
        # 调用了 ops.non_max_suppression 函数,对预测结果 preds 进行非极大值抑制(Non-Maximum Suppression, NMS)。NMS 是目标检测中常用的步骤,用于去除冗余的边界框,保留最优的检测结果。
        preds = ops.non_max_suppression(
            # 模型的预测结果。
            preds,
            # 置信度阈值,低于此值的预测结果将被忽略。
            self.args.conf,
            # 交并比(IoU)阈值,用于判断边界框之间的重叠程度。
            self.args.iou,
            # 指定检测的类别,可以用于过滤特定类别的检测结果。
            self.args.classes,
            # 是否进行类别无关的 NMS,即不区分类别进行边界框筛选。
            self.args.agnostic_nms,
            # 最大检测数量,限制最终返回的边界框数量。
            max_det=self.args.max_det,
            # 类别总数,通过 len(self.model.names) 获取。
            nc=len(self.model.names),
            # 是否启用端到端模式,通过 getattr(self.model, "end2end", False) 动态获取。
            end2end=getattr(self.model, "end2end", False),
            # 是否处理旋转边界框,根据 self.args.task == "obb" 判断。
            rotated=self.args.task == "obb",
        )

        # 检查 orig_imgs 是否为一个列表。如果不是(例如是一个 torch.Tensor ),则调用 ops.convert_torch2numpy_batch 函数将其转换为 NumPy 数组的列表。这种转换是为了确保后续操作能够正确处理原始图像数据,因为某些操作可能需要基于 Python 列表的灵活性。
        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            # def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
            # -> 用于将批量的 PyTorch 张量转换为 NumPy 数组,并将其值缩放到 [0, 255] 范围内,同时确保数据类型为 uint8 。这通常用于将图像数据从 PyTorch 格式转换为 NumPy 格式,以便进行可视化或其他处理。
            # -> return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        # 调用了 self.construct_results 方法,将 经过 NMS 处理后的预测结果 preds 、 预处理后的图像 img 和 原始图像列表 orig_imgs 传递给该方法,并返回最终的结果对象。 **kwargs 用于传递额外的参数,保持方法的灵活性。
        return self.construct_results(preds, img, orig_imgs, **kwargs)
    # 这段代码定义了 DetectionPredictor 类中的 postprocess 方法,其主要功能是。非极大值抑制(NMS):通过 ops.non_max_suppression 对预测结果进行筛选,去除冗余的边界框,保留置信度高且不重叠的检测结果。数据格式转换:如果原始图像数据不是列表格式,则将其从 torch.Tensor 转换为 NumPy 数组的列表,以便后续操作。结果构建:将处理后的预测结果传递给 construct_results 方法,生成最终的结果对象。整体来看, postprocess 方法是目标检测流程中的关键步骤,它确保了预测结果的准确性和可用性,为后续的可视化、保存或进一步处理提供了基础。

    # 这段代码定义了 DetectionPredictor 类中的 construct_results 方法,其主要功能是将预测结果封装为结果对象的列表。
    # 定义了一个名为 construct_results 的方法,属于 DetectionPredictor 类。它接收以下参数 :
    # 1.preds :经过后处理(如非极大值抑制)的预测结果,通常是一个包含多个预测边界框的列表。
    # 2.img :预处理后的图像张量,用于后续操作。
    # 3.orig_imgs :原始图像的列表,用于生成最终结果对象。
    def construct_results(self, preds, img, orig_imgs):
        # 根据预测构建结果对象列表。
        """
        Constructs a list of result objects from the predictions.

        Args:
            preds (List[torch.Tensor]): List of predicted bounding boxes and scores.
            img (torch.Tensor): The image after preprocessing.
            orig_imgs (List[np.ndarray]): List of original images before preprocessing.

        Returns:
            (list): List of result objects containing the original images, image paths, class names, and bounding boxes.
        """
        # 使用 Python 的列表推导式,逐个处理 预测结果 、 原始图像 和 图像路径 ,生成 结果对象的列表 。
        # zip(preds, orig_imgs, self.batch[0]) : zip 函数将 preds (预测结果)、 orig_imgs (原始图像列表)和 self.batch[0] (图像路径列表)打包成一个迭代器。每次迭代返回一个元组,包含一个预测结果、一个原始图像和一个图像路径。
        # self.construct_result(pred, img, orig_img, img_path) :
        # 对于每个元组中的数据,调用 self.construct_result 方法,将单个预测结果 pred  、预处理后的图像 img 、原始图像 orig_img 和图像路径 img_path 传递给该方法。
        # self.construct_result 方法会根据这些输入生成一个结果对象,通常包含检测到的目标信息(如边界框、类别名称等)。
        # 列表推导式 :列表推导式会将 self.construct_result 的返回值收集到一个列表中,最终返回一个 包含所有结果对象的列表 。
        return [
            self.construct_result(pred, img, orig_img, img_path)
            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
        ]
    # 这段代码定义了 construct_results 方法,其主要功能是将预测结果封装为结果对象的列表。具体步骤包括。数据打包:通过 zip 函数将预测结果、原始图像和图像路径打包成一个迭代器。结果对象生成:对每个打包后的数据调用 self.construct_result 方法,生成单个结果对象。结果列表构建:使用列表推导式将所有结果对象收集到一个列表中并返回。 construct_results 方法的作用是将批量的预测结果和原始图像数据转换为易于使用的格式,为后续的可视化、保存或进一步处理提供方便。

    # 这段代码定义了 DetectionPredictor 类中的 construct_result 方法,其主要功能是将单个预测结果封装为一个结果对象。
    # 定义了一个名为 construct_result 的方法,属于 DetectionPredictor 类。它接收以下参数 :
    # 1.pred :单个预测结果,通常是一个二维张量,包含边界框坐标、置信度和类别索引。
    # 2.img :预处理后的图像张量,用于调整边界框的尺寸。
    # 3.orig_img :原始图像的 NumPy 数组,用于生成最终结果对象。
    # 4.img_path :原始图像的路径,用于记录结果来源。
    def construct_result(self, pred, img, orig_img, img_path):
        # 根据预测构建结果对象。
        """
        Constructs the result object from the prediction.

        Args:
            pred (torch.Tensor): The predicted bounding boxes and scores.
            img (torch.Tensor): The image after preprocessing.
            orig_img (np.ndarray): The original image before preprocessing.
            img_path (str): The path to the original image.

        Returns:
            (Results): The result object containing the original image, image path, class names, and bounding boxes.
        """
        # 调用 ops.scale_boxes 函数,将预测的边界框坐标从预处理后的图像尺寸缩放到原始图像尺寸。
        # pred[:, :4] :表示 预测结果中边界框的坐标部分 (通常是 [x1, y1, x2, y2] 或 [cx, cy, w, h] )。
        # img.shape[2:] :获取预处理后图像的 高和宽 (假设 img 是一个四维张量,形状为 [batch_size, channels, height, width] )。
        # orig_img.shape :获取原始图像的高和宽。
        # ops.scale_boxes 一个工具函数,用于将边界框从一个尺寸缩放到另一个尺寸。缩放后的边界框坐标会直接更新到 pred[:, :4] 中。
        # def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
        # -> 用于将边界框从一个图像尺寸( img1_shape )缩放到另一个图像尺寸( img0_shape )。它考虑了图像缩放比例( gain )和填充( padding ),并最终将边界框裁剪到目标图像的范围内。调用 clip_boxes 函数,将缩放后的边界框裁剪到原始图像的范围内,确保边界框的坐标不会超出原始图像的边界。
        # -> return clip_boxes(boxes, img0_shape)
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        # 创建并返回一个 Results 对象,封装了 检测结果的所有相关信息 。
        # orig_img :原始图像的 NumPy 数组。
        # path=img_path :原始图像的路径。
        # names=self.model.names :模型的类别名称列表,用于将类别索引映射为类别名称。
        # boxes=pred[:, :6] :预测的边界框和相关信息,通常包括 :前 4 列 :边界框坐标( [x1, y1, x2, y2] 或 [cx, cy, w, h] )。第 5 列 :置信度(confidence score)。第 6 列 :类别索引(class index)。
        # class Results(SimpleClass):
        # -> 用于存储和操作从 YOLO 模型中得到的推理结果。该类封装了检测、分割、姿态估计和分类任务的结果处理功能。
        # -> def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None, obb=None, speed=None) -> None:
        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
    # 这段代码定义了 construct_result 方法,其主要功能是将单个预测结果封装为一个结果对象。具体步骤包括。边界框缩放:通过 ops.scale_boxes 函数,将预测的边界框从预处理后的图像尺寸缩放到原始图像尺寸。结果对象构建:创建一个 Results 对象,包含以下信息:原始图像。图像路径。类别名称列表。缩放后的边界框和相关信息(坐标、置信度、类别索引)。 construct_result 方法的作用是将单个预测结果转换为一个结构化的对象,便于后续的可视化、分析或保存。
# 这段代码实现了一个目标检测预测器 DetectionPredictor ,它继承自 BasePredictor ,并扩展了目标检测的后处理和结果构建功能。主要功能包括。非极大值抑制(NMS):通过 ops.non_max_suppression 对预测结果进行筛选,去除冗余的边界框。边界框缩放:将预测的边界框从预处理后的图像尺寸缩放到原始图像尺寸。结果对象构建:将预测结果封装为 Results 对象,包含原始图像、图像路径、类别名称和边界框信息。整体来看, DetectionPredictor 是一个用于目标检测任务的预测器,能够处理预测结果并生成易于使用的输出格式。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

红色的山茶花

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值