Applying Data Augmentation to YOLOv5/YOLOv8 Datasets: Methods, Code, and Results

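Introduction to Data Augmentation

Data augmentation is a technique in machine learning and deep learning that creates new samples from existing data by applying transformations such as flipping, rotation, or brightness/contrast changes. It is most common in computer vision tasks, but it also applies to fields such as natural language processing and speech recognition.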

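Why Data Augmentation Matters

Data augmentation helps prevent overfitting: by increasing the diversity of the training data, it helps the model generalize to new, unseen data. Exposing the model to a wider range of variation can also improve its accuracy and robustness. When only a limited amount of training data is available, augmentation enlarges the dataset, which can improve model performance.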

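Common Data Augmentation Methods

Common data augmentation methods include, but are not limited to, the following:

  • Flipping: mirror the image horizontally or vertically.
  • Rotation: rotate the image by some angle.
  • Color transforms: adjust the brightness, contrast, saturation, and so on.
  • Random cropping: cut a random sub-region out of the image.
  • Scaling: enlarge or shrink the image.
  • Adding noise: add random noise to the image.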
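
To make a few of these transforms concrete, here is a minimal sketch that applies a horizontal flip, a vertical flip, and a brightness shift to a tiny grayscale image stored as nested lists. The function names and pixel values are illustrative assumptions, not part of any library; a real pipeline would use OpenCV or Albumentations instead.

```python
def hflip(img):
    """Mirror each row left-to-right."""
    return [row[::-1] for row in img]


def vflip(img):
    """Reverse the row order (top-to-bottom mirror)."""
    return img[::-1]


def adjust_brightness(img, delta):
    """Add delta to every pixel and clamp to the 8-bit range [0, 255]."""
    return [[max(0, min(255, px + delta)) for px in row] for row in img]


# A 2x3 grayscale "image" with hypothetical pixel values.
image = [
    [10, 20, 30],
    [40, 50, 250],
]

flipped_h = hflip(image)                  # columns reversed within each row
flipped_v = vflip(image)                  # rows reversed
brighter = adjust_brightness(image, 20)   # 250 + 20 is clamped to 255
```

Each function returns a new list, so the original image stays available and several augmented variants can be produced from it.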

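Dataset Format for YOLOv5 and YOLOv8

YOLOv8 and YOLOv5 use the same dataset format, which consists of two main directories:

  • images directory: contains the image files.
  • labels directory: contains the .txt label files. Each line of a .txt file describes one object as a class label followed by a bounding box normalized to the image width and height, in this order:
[class_label, x_center, y_center, width, height]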

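Dataset Format Example

Suppose there is an image file image1.jpg whose label file image1.txt contains:

0 0.5 0.5 0.2 0.3
1 0.3 0.4 0.1 0.2

This means the image contains two objects. The first has class label 0, its center is at (0.5, 0.5), and its width and height are 0.2 and 0.3. The second has class label 1, its center is at (0.3, 0.4), and its width and height are 0.1 and 0.2.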
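
As a worked example of what the normalized values mean, the following sketch parses the label text above and converts each box to pixel corner coordinates. The 640x480 image size and both function names are assumptions chosen for illustration.

```python
def parse_yolo_labels(text):
    """Parse YOLO label text into (class_id, x_center, y_center, width, height) tuples."""
    boxes = []
    for line in text.splitlines():
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        class_id = int(parts[0])
        x_c, y_c, w, h = map(float, parts[1:5])
        boxes.append((class_id, x_c, y_c, w, h))
    return boxes


def yolo_to_pixel_corners(box, img_w, img_h):
    """Convert one normalized YOLO box to integer (x_min, y_min, x_max, y_max) pixels."""
    _, x_c, y_c, w, h = box
    x_min = round((x_c - w / 2) * img_w)
    y_min = round((y_c - h / 2) * img_h)
    x_max = round((x_c + w / 2) * img_w)
    y_max = round((y_c + h / 2) * img_h)
    return x_min, y_min, x_max, y_max


label_text = "0 0.5 0.5 0.2 0.3\n1 0.3 0.4 0.1 0.2\n"
boxes = parse_yolo_labels(label_text)

# Assumed image size: 640 pixels wide, 480 pixels high.
corners = [yolo_to_pixel_corners(b, 640, 480) for b in boxes]
# corners == [(256, 168, 384, 312), (160, 144, 224, 240)]
```

Because the stored values are fractions of the image size, the same label file stays valid if the image is resized without changing its content layout.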

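How to Use the Albumentations Library

Albumentations is a Python library for image augmentation that offers a simple, flexible way to apply a wide range of image transforms. The steps for using it are described below.

Installing Albumentations

First, install the Albumentations library with the following command:

pip install -U albumentations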

A Minimal Augmentation Pipeline

The following is a simple example of an augmentation pipeline:

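import albumentations as A
import cv2

# Define the augmentation pipeline
transform = A.Compose([
    A.RandomCrop(width=450, height=450),  # random crop; the input image must be at least 450x450
    A.HorizontalFlip(p=0.5),  # horizontal flip with probability 0.5
    A.RandomBrightnessContrast(p=0.2),  # random brightness/contrast change with probability 0.2
], bbox_params=A.BboxParams(format='yolo'))

# Load the image and its bounding boxes
image = cv2.imread('image1.jpg')
if image is None:
    # cv2.imread returns None instead of raising when the file cannot be read
    raise FileNotFoundError('Could not read image1.jpg')
# Each box is [x_center, y_center, width, height, class_label]
bboxes = [[0.5, 0.5, 0.2, 0.3, 0], [0.3, 0.4, 0.1, 0.2, 1]]

# Apply the augmentations
transformed = transform(image=image, bboxes=bboxes)
transformed_image = transformed['image']
transformed_bboxes = transformed['bboxes']

# Display the augmented image
cv2.imshow('Transformed Image', transformed_image)
cv2.waitKey(0)
cv2.destroyAllWindows()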

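Code Walkthrough

  1. Import the libraries: import albumentations, plus cv2 for reading and displaying the image.
  2. Define the augmentations: list the transforms to apply inside an A.Compose object. Several transforms can be chained, and most accept a probability p that controls how often they are applied.
  3. Convert the YOLO labels: a YOLO label line puts the class first, whereas the boxes passed to Albumentations here have the form [x_center, y_center, width, height, class_name], with the class as the last element. Reorder each label accordingly.
  4. Apply the augmentations: pass the image and the list of bounding boxes to the transform object. The augmented results are returned in the transformed dictionary.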
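
The reordering in step 3 and the effect of a horizontal flip on a box can be sketched in a few lines. This shows only the geometry and is not Albumentations' own implementation; the class list and function names are hypothetical.

```python
def yolo_line_to_albumentations(line, class_names):
    """Convert 'class x y w h' into [x, y, w, h, class_name]."""
    parts = line.split()
    class_name = class_names[int(parts[0])]
    return [float(v) for v in parts[1:5]] + [class_name]


def hflip_yolo_bbox(bbox):
    """Mirror a normalized box horizontally: only x_center changes (x -> 1 - x)."""
    x_c, y_c, w, h, label = bbox
    return [1.0 - x_c, y_c, w, h, label]


CLASS_NAMES = ["cat", "dog"]  # hypothetical class list

box = yolo_line_to_albumentations("1 0.3 0.4 0.1 0.2", CLASS_NAMES)
# box is [0.3, 0.4, 0.1, 0.2, "dog"]

flipped = hflip_yolo_bbox(box)
# x_center moves to about 0.7; y_center, width, height, and the label are unchanged
```

Keeping the class as the last element means it travels with its box through the pipeline, so boxes that get dropped by a crop take their labels with them.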

Saving the Augmented Results

To save the augmented image and labels, you can use the following:

# Save the augmented image
cv2.imwrite('transformed_image.jpg', transformed_image)

# Save the augmented labels, moving the class back to the front of each line
with open('transformed_labels.txt', 'w') as f:
    for bbox in transformed_bboxes:
        class_label = int(bbox[4])
        x_center, y_center, width, height = bbox[:4]
        f.write(f'{class_label} {x_center} {y_center} {width} {height}\n')
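
Before writing labels, it can help to drop boxes that ended up outside the image and to write coordinates with a fixed number of decimals. The sketch below is one possible way to do that, in the same spirit as the negative-value check in the full code later in this article; the function names, the tolerance, and the sample boxes are assumptions.

```python
def is_valid_yolo_box(x_c, y_c, w, h, eps=1e-6):
    """True if the box has positive size and lies inside the [0, 1] frame (within eps)."""
    if w <= 0 or h <= 0:
        return False
    return (x_c - w / 2 >= -eps and x_c + w / 2 <= 1 + eps
            and y_c - h / 2 >= -eps and y_c + h / 2 <= 1 + eps)


def format_yolo_line(bbox, precision=6):
    """Convert [x, y, w, h, class_id] into 'class_id x y w h' with fixed decimals."""
    x_c, y_c, w, h, class_id = bbox
    coords = " ".join(f"{v:.{precision}f}" for v in (x_c, y_c, w, h))
    return f"{int(class_id)} {coords}"


boxes = [
    [0.5, 0.5, 0.2, 0.3, 0],
    [0.02, 0.4, 0.1, 0.2, 1],  # left edge at -0.03, so it falls outside the image
]

lines = [format_yolo_line(b) for b in boxes if is_valid_yolo_box(*b[:4])]
# lines == ['0 0.500000 0.500000 0.200000 0.300000']
```

The small tolerance avoids rejecting boxes that touch the image border but overshoot it by floating-point rounding.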

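Results

Input

  • Input image: [figure: input image]
  • Input labels: [figure: input label file]

Output

  • Augmented image: [figure: augmented image]
  • Augmented label file: [figure: augmented label file]
  • Visualization of the augmented output, with the labels drawn on the augmented image: [figure: augmented image with bounding boxes]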

Full Code

import albumentations as A
import cv2
import os
import yaml
import pybboxes as pbx


with open("contants.yaml", 'r') as stream:
    CONSTANTS = yaml.safe_load(stream)


def is_image_by_extension(file_name):
    """
    Check if the given file has a recognized image extension.

    Args:
        file_name (str): Name of the file.

    Returns:
        bool: True if the file has a recognized image extension, False otherwise.

    """
    # List of common image extensions
    image_extensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff', 'webp']
    # Get the file extension
    file_extension = file_name.lower().split('.')[-1]
    # Check if the file has a recognized image extension
    return file_extension in image_extensions


def get_inp_data(img_file):
    """
    Get input data for image processing.

    Args:
        img_file (str): Name of the input image file.

    Returns:
        tuple: A tuple containing the image, ground truth bounding boxes, and augmented file name.

    """
    file_name = os.path.splitext(img_file)[0]
    aug_file_name = f"{file_name}_{CONSTANTS['transformed_file_name']}"
    image = cv2.imread(os.path.join(CONSTANTS["inp_img_pth"], img_file))
    lab_pth = os.path.join(CONSTANTS["inp_lab_pth"], f"{file_name}.txt")
    gt_bboxes = get_bboxes_list(lab_pth, CONSTANTS['CLASSES'])
    return image, gt_bboxes, aug_file_name


def get_album_bb_list(yolo_bbox, class_names):
    """
    Extracts bounding box information for a single object from YOLO format.

    Args:
        yolo_bbox (str): YOLO format string representing bounding box information.
        class_names (list): List of class names corresponding to class numbers.

    Returns:
        list: A list containing [x_center, y_center, width, height, class_name].
    """
    str_bbox_list = yolo_bbox.split()
    class_number = int(str_bbox_list[0])
    class_name = class_names[class_number]
    bbox_values = list(map(float, str_bbox_list[1:]))
    album_bb = bbox_values + [class_name]
    return album_bb


def get_album_bb_lists(yolo_str_labels, classes):
    """
    Extracts bounding box information for multiple objects from YOLO format.

    Args:
        yolo_str_labels (str): YOLO format string containing bounding box information for multiple objects.
        classes (list): List of class names corresponding to class numbers.

    Returns:
        list: A list of lists, each containing [x_center, y_center, width, height, class_name].
    """
    album_bb_lists = []
    yolo_list_labels = yolo_str_labels.split('\n')
    for yolo_str_label in yolo_list_labels:
        if yolo_str_label:
            album_bb_list = get_album_bb_list(yolo_str_label, classes)
            album_bb_lists.append(album_bb_list)
    return album_bb_lists


def get_bboxes_list(inp_lab_pth, classes):
    """
    Reads YOLO format labels from a file and returns bounding box information.

    Args:
        inp_lab_pth (str): Path to the YOLO format labels file.
        classes (list): List of class names corresponding to class numbers.

    Returns:
        list: A list of lists, each containing [x_center, y_center, width, height, class_name].
    """
    # Use a context manager so the label file is always closed.
    with open(inp_lab_pth, "r") as label_file:
        yolo_str_labels = label_file.read()

    # Keep only non-empty lines; a whitespace-only file counts as "no object".
    lines = [line.strip() for line in yolo_str_labels.split("\n") if line.strip()]
    if not lines:
        print("No object")
        return []

    # get_album_bb_lists handles one line or many, so no special case is needed.
    return get_album_bb_lists("\n".join(lines), classes)


def single_obj_bb_yolo_conversion(transformed_bboxes, class_names):
    """
    Convert bounding boxes for a single object to YOLO format.

    Parameters:
    - transformed_bboxes (list): Bounding box coordinates and class name.
    - class_names (list): List of class names.

    Returns:
    - list: Bounding box coordinates in YOLO format.
    """
    if transformed_bboxes:
        class_num = class_names.index(transformed_bboxes[-1])
        bboxes = list(transformed_bboxes)[:-1]
        bboxes.insert(0, class_num)
    else:
        bboxes = []
    return bboxes


def multi_obj_bb_yolo_conversion(aug_labs, class_names):
    """
    Convert bounding boxes for multiple objects to YOLO format.

    Parameters:
    - aug_labs (list): List of bounding box coordinates and class names.
    - class_names (list): List of class names.

    Returns:
    - list: List of bounding box coordinates in YOLO format for each object.
    """
    yolo_labels = [single_obj_bb_yolo_conversion(aug_lab, class_names) for aug_lab in aug_labs]
    return yolo_labels


def save_aug_lab(transformed_bboxes, lab_pth, lab_name):
    """
    Save augmented bounding boxes to a label file.

    Args:
        transformed_bboxes (list): List of augmented bounding boxes.
        lab_pth (str): Path to the output label directory.
        lab_name (str): Name of the label file.

    """
    lab_out_pth = os.path.join(lab_pth, lab_name)
    with open(lab_out_pth, 'w') as output:
        for bbox in transformed_bboxes:
            updated_bbox = str(bbox).replace(',', ' ').replace('[', '').replace(']', '')
            output.write(updated_bbox + '\n')


def save_aug_image(transformed_image, out_img_pth, img_name):
    """
    Save augmented image to an output directory.

    Args:
        transformed_image (numpy.ndarray): Augmented image.
        out_img_pth (str): Path to the output image directory.
        img_name (str): Name of the image file.

    """
    out_img_path = os.path.join(out_img_pth, img_name)
    cv2.imwrite(out_img_path, transformed_image)


def draw_yolo(image, labels, file_name):
    """
    Draw bounding boxes on an image based on YOLO format and save the result.

    Args:
        image (numpy.ndarray): Input image (drawn on in place).
        labels (list): List of bounding boxes in YOLO format.
        file_name (str): Base name of the output file written under bb_image/.

    """
    H, W = image.shape[:2]
    for label in labels:
        yolo_normalized = label[1:]
        box_voc = pbx.convert_bbox(tuple(yolo_normalized), from_type="yolo", to_type="voc", image_size=(W, H))
        cv2.rectangle(image, (box_voc[0], box_voc[1]),
                      (box_voc[2], box_voc[3]), (0, 0, 255), 1)
    # cv2.imwrite returns False without raising if the directory is missing, so create it first.
    os.makedirs("bb_image", exist_ok=True)
    cv2.imwrite(f"bb_image/{file_name}.png", image)
    # cv2.imshow(f"{file_name}.png", image)
    # cv2.waitKey(0)


def get_augmented_results(image, bboxes):
    """
    Apply data augmentation to an input image and bounding boxes.

    Parameters:
    - image (numpy.ndarray): Input image.
    - bboxes (list): List of bounding boxes in YOLO format [x_center, y_center, width, height, class_name].

    Returns:
    - tuple: A tuple containing the augmented image and the transformed bounding boxes.
    """
    # Define the augmentations
    transform = A.Compose([
        A.RandomCrop(width=300, height=300),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0),
        A.CLAHE(clip_limit=(0, 1), tile_grid_size=(8, 8), p=1.0),  # p=1.0 always applies; always_apply is deprecated in newer releases
        A.Resize(300, 300)
    ], bbox_params=A.BboxParams(format='yolo'))

    # Apply the augmentations
    transformed = transform(image=image, bboxes=bboxes)
    transformed_image, transformed_bboxes = transformed['image'], transformed['bboxes']
    
    return transformed_image, transformed_bboxes


def has_negative_element(matrix):
    """
    Check if there is a negative element in the 2D list of augmented bounding boxes.

    Args:
        matrix (list[list]): The 2D list.

    Returns:
        bool: True if a negative element is found, False otherwise.

    """
    return any(element < 0 for row in matrix for element in row)


def save_augmentation(trans_image, trans_bboxes, trans_file_name):
    """
    Saves the augmented label and image if no negative elements are found in the transformed bounding boxes.

    Parameters:
        trans_image (numpy.ndarray): The augmented image.
        trans_bboxes (list): The transformed bounding boxes.
        trans_file_name (str): The name for the augmented output.

    Returns:
        None
    """
    tot_objs = len(trans_bboxes)
    if tot_objs:
        # Convert bounding boxes to YOLO format
        trans_bboxes = multi_obj_bb_yolo_conversion(trans_bboxes, CONSTANTS['CLASSES']) if tot_objs > 1 else [single_obj_bb_yolo_conversion(trans_bboxes[0], CONSTANTS['CLASSES'])]
        if not has_negative_element(trans_bboxes):
            # Save augmented label and image
            save_aug_lab(trans_bboxes, CONSTANTS["out_lab_pth"], trans_file_name + ".txt")
            save_aug_image(trans_image, CONSTANTS["out_img_pth"], trans_file_name + ".png")
            # Draw bounding boxes on the augmented image
            draw_yolo(trans_image, trans_bboxes, trans_file_name)
        else:
            print("Found Negative element in Transformed Bounding Box...")
    else:
        print("Label file is empty")


# The helpers above are presumably saved as utils.py; the code below is the entry-point script.
from utils import *


def run_yolo_augmentor():
    """
    Run the YOLO augmentor on a set of images.

    This function processes each image in the input directory, applies augmentations,
    and saves the augmented images and labels to the output directories.

    """
    imgs = [img for img in os.listdir(CONSTANTS["inp_img_pth"]) if is_image_by_extension(img)]

    for img_num, img_file in enumerate(imgs):
        print(f"{img_num+1}-image is processing...\n")
        image, gt_bboxes, aug_file_name = get_inp_data(img_file)
        aug_img, aug_label = get_augmented_results(image, gt_bboxes)
        if len(aug_img) and len(aug_label):
            save_augmentation(aug_img, aug_label, aug_file_name)


if __name__ == "__main__":
    run_yolo_augmentor()
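
The script reads its settings from a YAML file that the article does not show. As a rough guide, the sketch below lists the keys that the code above looks up in CONSTANTS and checks that a loaded configuration contains them. Only the key names come from the code; every value, and the helper name missing_keys, is a hypothetical placeholder.

```python
# Keys that the code above reads from CONSTANTS.
REQUIRED_KEYS = (
    "inp_img_pth",
    "inp_lab_pth",
    "out_img_pth",
    "out_lab_pth",
    "transformed_file_name",
    "CLASSES",
)

# Hypothetical values; the YAML file would hold the same keys at top level.
EXAMPLE_CONSTANTS = {
    "inp_img_pth": "data/images",
    "inp_lab_pth": "data/labels",
    "out_img_pth": "out/images",
    "out_lab_pth": "out/labels",
    "transformed_file_name": "aug",
    "CLASSES": ["cat", "dog"],
}


def missing_keys(constants):
    """Return the required keys that are absent from the loaded configuration."""
    return [key for key in REQUIRED_KEYS if key not in constants]


problems = missing_keys(EXAMPLE_CONSTANTS)
# An empty list means every key the script needs is present.
```

Running such a check right after yaml.safe_load would surface a missing key as one clear message rather than a KeyError partway through processing.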

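Summary

This article showed how to produce an augmented dataset for object detection with YOLOv5 and YOLOv8. The same Python library can be used to augment datasets in other formats, and the approach applies not only to object detection but also to classification and segmentation tasks.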

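### Augmenting Data for YOLO Object Detection Models with Albumentations

Applying image augmentation during data preprocessing is an important step toward improving a YOLO object detection model. Albumentations is an efficient, easy-to-use Python library designed for data augmentation in computer vision tasks.

#### Install the dependencies

First install the required packages:

```bash
pip install albumentations==1.0.7 imgaug opencv-python matplotlib numpy pandas torch torchvision
```

#### Import the modules and define the transform

Next, import the required Python modules and create a custom transform class that can be plugged into a YOLO workflow[^1].

```python
import cv2
from albumentations import Compose, BboxParams, RandomBrightnessContrast, ShiftScaleRotate, HorizontalFlip
import numpy as np

class YoloAugmentationPipeline:
    def __init__(self):
        self.transform = Compose([
            HorizontalFlip(p=0.5),
            ShiftScaleRotate(scale_limit=0.2, rotate_limit=20, p=0.5),
            RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5)
        ], bbox_params=BboxParams(format='yolo', label_fields=['category_ids']))

    def apply(self, image: np.ndarray, bboxes=None, category_ids=None):
        transformed = self.transform(image=image, bboxes=bboxes or [], category_ids=category_ids or [])
        return transformed['image'], transformed.get('bboxes', []), transformed.get('category_ids', [])
```

This snippet builds a simple augmentation pipeline that combines horizontal flipping, random shift/scale/rotate, and brightness/contrast adjustment.

#### Apply the augmentations to training samples

With these tools in place, they can be applied to a real dataset. For each input image and its bounding box coordinates and class labels, call the `apply()` method to obtain the transformed version.

This expands the amount of usable training data without any additional annotation cost, which can help improve the final model's performance.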