Bounding boxes augmentation for object detection

This article introduces the bounding box annotation formats used by Pascal VOC, Albumentations, COCO, and YOLO, shows how to use these formats in image processing, and explains in detail how bounding box coordinates are represented and how box size, visibility, and class labels are handled during augmentation.

Different annotations formats

Bounding boxes are rectangles that mark objects on an image. There are multiple formats of bounding box annotations. Each format uses its own representation of bounding box coordinates. Albumentations supports four formats: pascal_voc, albumentations, coco, and yolo.

Let's take a look at each of those formats and how they represent the coordinates of bounding boxes.

As an example, we will use an image from the dataset named Common Objects in Context. It contains one bounding box that marks a cat. The image width is 640 pixels, and its height is 480 pixels. The width of the bounding box is 322 pixels, and its height is 117 pixels.

 An example image with a bounding box from the COCO dataset

pascal_voc

pascal_voc is a format used by the Pascal VOC dataset. Coordinates of a bounding box are encoded with four values in pixels: [x_min, y_min, x_max, y_max]. x_min and y_min are the coordinates of the top-left corner of the bounding box. x_max and y_max are the coordinates of the bottom-right corner of the bounding box.

Coordinates of the example bounding box in this format are [98, 345, 420, 462].

albumentations

albumentations is similar to pascal_voc, because it also uses four values [x_min, y_min, x_max, y_max] to represent a bounding box. But unlike pascal_voc, albumentations uses normalized values. To normalize values, we divide the pixel coordinates for the x- and y-axis by the width and the height of the image, respectively.

Coordinates of the example bounding box in this format are [98 / 640, 345 / 480, 420 / 640, 462 / 480] which are [0.153125, 0.71875, 0.65625, 0.9625].

Albumentations uses this format internally to work with bounding boxes and augment them.
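
For illustration, here is a minimal sketch that normalizes the example pascal_voc coordinates into the albumentations format (the image size and box values are the ones from the example above):

image_width, image_height = 640, 480
x_min, y_min, x_max, y_max = 98, 345, 420, 462  # pascal_voc

bbox_albumentations = [
    x_min / image_width,   # 0.153125
    y_min / image_height,  # 0.71875
    x_max / image_width,   # 0.65625
    y_max / image_height,  # 0.9625
]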

coco

coco is a format used by the Common Objects in Context (COCO) dataset.

In coco, a bounding box is defined by four values in pixels [x_min, y_min, width, height]. They are coordinates of the top-left corner along with the width and height of the bounding box.

Coordinates of the example bounding box in this format are [98, 345, 322, 117].
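
As a quick sketch, the same box can be derived from the pascal_voc values, since the width and height are just the differences between the corner coordinates:

x_min, y_min, x_max, y_max = 98, 345, 420, 462  # pascal_voc
bbox_coco = [x_min, y_min, x_max - x_min, y_max - y_min]  # [98, 345, 322, 117]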

yolo

In yolo, a bounding box is represented by four values [x_center, y_center, width, height]. x_center and y_center are the normalized coordinates of the center of the bounding box. To normalize the coordinates, we take the pixel values of x and y that mark the center of the bounding box on the x- and y-axis. Then we divide the value of x by the width of the image and the value of y by the height of the image. width and height represent the width and the height of the bounding box. They are normalized as well.

Coordinates of the example bounding box in this format are [((420 + 98) / 2) / 640, ((462 + 345) / 2) / 480, 322 / 640, 117 / 480] which are [0.4046875, 0.840625, 0.503125, 0.24375].
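
Here is a minimal sketch of that arithmetic, again starting from the example pascal_voc box:

image_width, image_height = 640, 480
x_min, y_min, x_max, y_max = 98, 345, 420, 462  # pascal_voc

x_center = ((x_min + x_max) / 2) / image_width   # 0.4046875
y_center = ((y_min + y_max) / 2) / image_height  # 0.840625
width = (x_max - x_min) / image_width            # 0.503125
height = (y_max - y_min) / image_height          # 0.24375

bbox_yolo = [x_center, y_center, width, height]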

Bounding boxes augmentation

Just as with image and mask augmentation, the process of augmenting bounding boxes consists of four steps.

  1. You import the required libraries.
  2. You define an augmentation pipeline.
  3. You read images and bounding boxes from the disk.
  4. You pass an image and bounding boxes to the augmentation pipeline and receive augmented images and boxes.

Note

Some transforms in Albumentations don't support bounding boxes. If you try to use them, you will get an exception. Please refer to this article to check whether a transform can augment bounding boxes.

Step 1. Import the required libraries.

import albumentations as A
import cv2

Step 2. Define an augmentation pipeline.

Here is an example of a minimal declaration of an augmentation pipeline that works with bounding boxes.

transform = A.Compose([
    A.RandomCrop(width=450, height=450),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
], bbox_params=A.BboxParams(format='coco'))

Note that unlike image and mask augmentation, Compose now has an additional parameter bbox_params. You need to pass an instance of A.BboxParams to that argument. A.BboxParams specifies settings for working with bounding boxes. format sets the format for bounding box coordinates.

It can be either pascal_voc, albumentations, coco, or yolo. This value is required because Albumentations needs to know the source format of the bounding box coordinates to apply augmentations correctly.

Besides format, A.BboxParams supports a few more settings.

Here is an example of Compose that shows all available settings with A.BboxParams:

transform = A.Compose([
    A.RandomCrop(width=450, height=450),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
], bbox_params=A.BboxParams(format='coco', min_area=1024, min_visibility=0.1, label_fields=['class_labels']))

min_area and min_visibility

The min_area and min_visibility parameters control what Albumentations should do with augmented bounding boxes whose size has changed after augmentation. The size of a bounding box can change if you apply spatial augmentations, for example, when you crop a part of an image or when you resize an image.

min_area is a value in pixels. If the area of a bounding box after augmentation becomes smaller than min_area, Albumentations will drop that box. So the returned list of augmented bounding boxes won't contain that bounding box.

min_visibility is a value between 0 and 1. If the ratio of the bounding box area after augmentation to the area of the bounding box before augmentation becomes smaller than min_visibility, Albumentations will drop that box. So if the augmentation process cuts out most of a bounding box, that box won't be present in the returned list of augmented bounding boxes.

Here is an example image that contains two bounding boxes. Bounding box coordinates are declared in the coco format.

 An example image with two bounding boxes

First, we apply the CenterCrop augmentation without declaring the min_area and min_visibility parameters. The augmented image contains two bounding boxes.

 An example image with two bounding boxes after applying augmentation

Next, we apply the same CenterCrop augmentation, but now we also use the min_area parameter. Now, the augmented image contains only one bounding box, because the other bounding box's area after augmentation became smaller than min_area, so Albumentations dropped that bounding box.

 An example image with one bounding box after applying augmentation with 'min_area'

Finally, we apply the CenterCrop augmentation with min_visibility. After that augmentation, the resulting image doesn't contain any bounding boxes, because the visibility of all bounding boxes after augmentation is below the threshold set by min_visibility.

An example image with zero bounding boxes after applying augmentation with 'min_visibility' 
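
As a sketch, the three pipelines described above could be declared like this (the crop size and the min_area and min_visibility values are illustrative, not the exact ones used to produce the figures):

import albumentations as A

# Illustrative values; the figures above don't specify the exact crop size or thresholds.
transform_plain = A.Compose(
    [A.CenterCrop(width=300, height=300)],
    bbox_params=A.BboxParams(format='coco'),
)

# Drop boxes whose area after the crop is smaller than min_area (in pixels).
transform_min_area = A.Compose(
    [A.CenterCrop(width=300, height=300)],
    bbox_params=A.BboxParams(format='coco', min_area=4500),
)

# Drop boxes that keep less than 30% of their original area after the crop.
transform_min_visibility = A.Compose(
    [A.CenterCrop(width=300, height=300)],
    bbox_params=A.BboxParams(format='coco', min_visibility=0.3),
)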

Class labels for bounding boxes

Besides coordinates, each bounding box should have an associated class label that tells which object lies inside the bounding box. There are two ways to pass a label for a bounding box.

Let's say you have an example image with three objects: dog, cat, and sports ball. Bounding box coordinates in the coco format for those objects are [23, 74, 295, 388], [377, 294, 252, 161], and [333, 421, 49, 49].

An example image with 3 bounding boxes from the COCO dataset

1. You can pass labels along with bounding box coordinates by adding them as additional values to the list of coordinates.

For the image above, bounding boxes with class labels will become [23, 74, 295, 388, 'dog'], [377, 294, 252, 161, 'cat'], and [333, 421, 49, 49, 'sports ball'].

Class labels could be of any type: integer, string, or any other Python data type. For example, integer values as class labels will look like the following: [23, 74, 295, 388, 18], [377, 294, 252, 161, 17], and [333, 421, 49, 49, 37].

Also, you can use multiple class values for each bounding box, for example [23, 74, 295, 388, 'dog', 'animal'], [377, 294, 252, 161, 'cat', 'animal'], and [333, 421, 49, 49, 'sports ball', 'item'].

2. You can pass labels for bounding boxes as a separate list (the preferred way).

For example, if you have three bounding boxes like [23, 74, 295, 388], [377, 294, 252, 161], and [333, 421, 49, 49], you can create a separate list with values like ['cat', 'dog', 'sports ball'] or [18, 17, 37] that contains class labels for those bounding boxes. Next, you pass that list with class labels as a separate argument to the transform function. Albumentations needs to know the names of all those lists with class labels to join them with the augmented bounding boxes correctly. Then, if a bounding box is dropped after augmentation because it is no longer visible, Albumentations will drop the class label for that box as well. Use the label_fields parameter to set names for all arguments in transform that will contain label descriptions for bounding boxes (more on that in Step 4).

Step 3. Read images and bounding boxes from the disk.

Read an image from the disk.

image = cv2.imread("/path/to/image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

Bounding boxes can be stored on the disk in different serialization formats: JSON, XML, YAML, CSV, etc. So the code to read bounding boxes depends on the actual format of data on the disk.

After you read the data from the disk, you need to prepare bounding boxes for Albumentations.

Albumentations expects that bounding boxes will be represented as a list of lists. Each list contains information about a single bounding box. A bounding box definition should have at least four elements that represent the coordinates of that bounding box. The actual meaning of those four values depends on the format of bounding boxes (either pascal_voc, albumentations, coco, or yolo). Besides the four coordinates, each definition of a bounding box may contain one or more extra values. You can use those extra values to store additional information about the bounding box, such as a class label of the object inside the box. During augmentation, Albumentations will not process those extra values. The library will return them as is, along with the updated coordinates of the augmented bounding box.
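
As a sketch, assuming the annotations are stored in a hypothetical COCO-style JSON file, reading them could look like this (for brevity, the sketch does not filter annotations by image id):

import json

with open("/path/to/annotations.json") as f:  # hypothetical path
    coco_data = json.load(f)

bboxes = []
class_labels = []
for annotation in coco_data["annotations"]:
    bboxes.append(annotation["bbox"])              # [x_min, y_min, width, height] in pixels
    class_labels.append(annotation["category_id"])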

Step 4. Pass an image and bounding boxes to the augmentation pipeline and receive augmented images and boxes.

As discussed in Step 2, there are two ways of passing class labels along with bounding box coordinates:

1. Pass class labels along with coordinates.

So, if you have coordinates of three bounding boxes that look like this:

bboxes = [
    [23, 74, 295, 388],
    [377, 294, 252, 161],
    [333, 421, 49, 49],
]

you can add a class label for each bounding box as an additional element of the list along with the four coordinates. So now a list with bounding boxes and their coordinates will look like the following:

bboxes = [
    [23, 74, 295, 388, 'dog'],
    [377, 294, 252, 161, 'cat'],
    [333, 421, 49, 49, 'sports ball'],
]

or with multiple labels for each bounding box:

bboxes = [
    [23, 74, 295, 388, 'dog', 'animal'],
    [377, 294, 252, 161, 'cat', 'animal'],
    [333, 421, 49, 49, 'sports ball', 'item'],
]

You can use any data type for declaring class labels. It can be string, integer, or any other Python data type.

Next, you pass an image and bounding boxes for it to the transform function and receive the augmented image and bounding boxes.

transformed = transform(image=image, bboxes=bboxes)
transformed_image = transformed['image']
transformed_bboxes = transformed['bboxes']

 Example input and output data for bounding boxes augmentation

2. Pass class labels in a separate argument to transform (the preferred way).

Let's say you have the coordinates of three bounding boxes:

bboxes = [
    [23, 74, 295, 388],
    [377, 294, 252, 161],
    [333, 421, 49, 49],
]

You can create a separate list that contains class labels for those bounding boxes:

class_labels = ['cat', 'dog', 'sports ball']

Then you pass both bounding boxes and class labels to transform. Note that to pass class labels, you need to use the name of the argument that you declared in label_fields when creating an instance of Compose in step 2. In our case, we set the name of the argument to class_labels.

transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels)
transformed_image = transformed['image']
transformed_bboxes = transformed['bboxes']
transformed_class_labels = transformed['class_labels']

Example input and output data for bounding boxes augmentation with a separate argument for class labels 

Note that label_fields expects a list, so you can set multiple fields that contain labels for your bounding boxes. So if you declare Compose like

transform = A.Compose([
    A.RandomCrop(width=450, height=450),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
], bbox_params=A.BboxParams(format='coco', label_fields=['class_labels', 'class_categories']))

you can use those multiple arguments to pass info about class labels, like

class_labels = ['cat', 'dog', 'sports ball']
class_categories = ['animal', 'animal', 'item']

transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels, class_categories=class_categories)
transformed_image = transformed['image']
transformed_bboxes = transformed['bboxes']
transformed_class_labels = transformed['class_labels']
transformed_class_categories = transformed['class_categories']
