YOLACT++代码分析1——数据增强

本文深入解析了YOLACT++目标检测模型中的数据增强技术,包括光度扭曲、随机对比度调整、颜色空间转换等操作,以及随机裁剪、翻转和旋转等几何变换,旨在提高模型的泛化能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

**

数据增强

**
本系列的博客都是从yolact++源码(https://github.com/dbolya/yolact)的trian.py文件开始讲解。

1.首先加载数据

    #image_path训练图片文件夹
    #info_file标签文件夹
    dataset = COCODetection(image_path=cfg.dataset.train_images,
                            info_file=cfg.dataset.train_info,
                            transform=SSDAugmentation(MEANS))

主要讲一下transform=SSDAugmentation(MEANS)

class SSDAugmentation(object):
    """ Transform to be used when training. """

    def __init__(self, mean=MEANS, std=STD):
        //构造Compose的实例对象,但是传入的是一个列表
        self.augment = Compose([ 
            ConvertFromInts(),  //将图片数据转为np.float32的数据类型:image.astype(np.float32)
            ToAbsoluteCoords(), //计算bbox的绝对坐标,注意这里只是定义一个临时对象。
            enable_if(cfg.augment_photometric_distort, PhotometricDistort()),  //光度扭曲
            enable_if(cfg.augment_expand, Expand(mean)),//扩张
            enable_if(cfg.augment_random_sample_crop, RandomSampleCrop()),//随机裁剪
            enable_if(cfg.augment_random_mirror, RandomMirror()),//随机镜像
            enable_if(cfg.augment_random_flip, RandomFlip()),//随机翻转
            enable_if(cfg.augment_random_flip, RandomRot90()), //随机旋转
            Resize(),
            enable_if(not cfg.preserve_aspect_ratio, Pad(cfg.max_size, cfg.max_size, mean)),
            ToPercentCoords(),
            PrepareMasks(cfg.mask_size, cfg.use_gt_bboxes),
            BackboneTransform(cfg.backbone.transform, mean, std, 'BGR')
        ])

我们下面来看下该代码用了那些图像增强:
1 . PhotometricDistort 光度扭曲

class PhotometricDistort(object):
    def __init__(self):
        self.pd = [
            RandomContrast(),
            ConvertColor(transform='HSV'),
            RandomSaturation(),
            RandomHue(),
            ConvertColor(current='HSV', transform='BGR'),
            RandomContrast()
        ]
        self.rand_brightness = RandomBrightness()
        self.rand_light_noise = RandomLightingNoise()
    def __call__(self, image, masks, boxes, labels):
        im = image.copy()
        im, masks, boxes, labels = self.rand_brightness(im, masks, boxes, labels)
        if random.randint(2):
            distort = Compose(self.pd[:-1])
        else:
            distort = Compose(self.pd[1:])
        im, masks, boxes, labels = distort(im, masks, boxes, labels)
        return self.rand_light_noise(im, masks, boxes, labels)    

2.RandomContrast() 随机对比度

class RandomContrast(object):
    def __init__(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "contrast upper must be >= lower." //随机选取值的约束声明
        assert self.lower >= 0, "contrast lower must be non-negative."

    # expects float image
    def __call__(self, image, masks=None, boxes=None, labels=None):
        if random.randint(2):
            alpha = random.uniform(self.lower, self.upper)
            image *= alpha   //在self.lower, self.upper的正态分布中选一个随机值,对图片整体进行加权。
        return image, masks, boxes, labels

3.ConvertColor(transform=‘HSV’) 颜色通道转换

class ConvertColor(object):
    def __init__(self, current='BGR', transform='HSV'):
        self.transform = transform
        self.current = current

    def __call__(self, image, masks=None, boxes=None, labels=None):
        if self.current == 'BGR' and self.transform == 'HSV':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)  //将原图像中BGR颜色通道转为HSV颜色通道
        elif self.current == 'HSV' and self.transform == 'BGR':
            image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
        else:
            raise NotImplementedError
        return image, masks, boxes, labels

4.RandomSaturation()

class RandomSaturation(object):
    def __init__(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "contrast upper must be >= lower."
        assert self.lower >= 0, "contrast lower must be non-negative."

    def __call__(self, image, masks=None, boxes=None, labels=None):
        if random.randint(2):
            image[:, :, 1] *= random.uniform(self.lower, self.upper) //对图片数据的第2个通道进行随机加权

        return image, masks, boxes, labels

5.RandomHue()

class RandomHue(object):
    def __init__(self, delta=18.0):
        assert delta >= 0.0 and delta <= 360.0
        self.delta = delta

    def __call__(self, image, masks=None, boxes=None, labels=None):
        if random.randint(2):
            image[:, :, 0] += random.uniform(-self.delta, self.delta) //对图像数据第一个通道的像素点 加上一个随机值
            image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0  //对图像数据第一个通道中大于360的像素点减上一个360
            image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 
        return image, masks, boxes, labels

6.RandomBrightness() 随机亮度

class RandomBrightness(object):
    def __init__(self, delta=32):
        assert delta >= 0.0
        assert delta <= 255.0
        self.delta = delta

    def __call__(self, image, masks=None, boxes=None, labels=None):
        if random.randint(2):
            delta = random.uniform(-self.delta, self.delta)
            image += delta 
        return image, masks, boxes, labels

7.随机裁剪RandomSampleCrop()

class RandomSampleCrop(object):
    """Crop
    Arguments:
        img (Image): the image being input during training

    Return:
        img (Image): the cropped image

    """
    def __init__(self):
        self.sample_options = (
            # using entire original input image
            None,
            # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9
            (0.1, None),
            (0.3, None),
            (0.7, None),
            (0.9, None),
            # randomly sample a patch
            (None, None),
        )

    def __call__(self,image):
        height, width, _ = image.shape
        crop_list = []
        while True:
            # randomly choose a mode
            mode = random.choice(self.sample_options)
            if mode is None:
                return image, masks, boxes, labels

            min_iou, max_iou = mode
            if min_iou is None:
                min_iou = float('-inf')
            if max_iou is None:
                max_iou = float('inf')
            
            
            # max trails (50)
            for _ in range(50):
                current_image = image

                w = random.uniform(0.3 * width, width)
                h = random.uniform(0.3 * height, height)

                # aspect ratio constraint b/t .5 & 2
                if h / w < 0.5 or h / w > 2:
                    continue
                left = random.uniform(width - w)
                top = random.uniform(height - h)
                # convert to integer rect x1,y1,x2,y2
                rect = np.array([int(left), int(top), int(left+w), int(top+h)])

                # cut the crop from the image
                current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
                                              :]
                crop_list.append(current_image)

            return crop_list

8.随机翻转,随机旋转90度:

class RandomFlip(object):
    def __call__(self, image):
        height , _ , _ = image.shape
        if random.randint(2):
            image = image[::-1, :]   
        return image
        
class RandomRot90(object):
    def __call__(self, image, masks=None, boxes=None, labels=None):
        old_height , old_width , _ = image.shape
        k = random.randint(4)
        image = np.rot90(image,k)
#         masks = np.array([np.rot90(mask,k) for mask in masks])
#         boxes = boxes.copy()
#         for _ in range(k):
#             boxes = np.array([[box[1], old_width - 1 - box[2], box[3], old_width - 1 - box[0]] for box in boxes])
#             old_width, old_height = old_height, old_width
        return image, masks, boxes, labels

9.image_Expand

class Expand(object):
    def __init__(self, mean):
        self.mean = mean

    def __call__(self, image):
        if random.randint(2):
            return image

        height, width, depth = image.shape
        ratio = random.uniform(1, 4)
        left = random.uniform(0, width*ratio - width)
        top = random.uniform(0, height*ratio - height)

        expand_image = np.zeros(
            (int(height*ratio), int(width*ratio), depth),
            dtype=image.dtype)
   
        expand_image[:, :, :] = self.mean
        expand_image[int(top):int(top + height),
                     int(left):int(left + width)] = image
        image = expand_image
        return image

我们实际用一下上面的数据增强
输入图片:
在这里插入图片描述
1.光度扭曲
代码调用:

image_path = './images/dog.jpg'
image = cv2.imread(image_path) #uint8数据类型
image= image.astype(np.float32)
aug_ = PhotometricDistort()
image_aug_whole = aug_(image,None,None,None)
cv2.imshow('dog_aug_whole',image_aug_whole[0])
cv2.waitKey(0)
cv2.destroyAllWindows()

在这里插入图片描述
2.RandomContrast() 随机对比度
在这里插入图片描述
3.ConvertColor(transform=‘HSV’) 颜色通道转换
在这里插入图片描述
4.RandomSaturation()
在这里插入图片描述
5.RandomHue()
在这里插入图片描述
6.RandomBrightness() 随机亮度
在这里插入图片描述
7.随机裁剪RandomSampleCrop()

代码调用

image_path = './images/dog.jpg'
image = cv2.imread(image_path) #uint8数据类型
cv2.imshow('dog',image)
cv2.waitKey(0)
cv2.destroyAllWindows()

images= image.astype(np.float32)
rand_crop = RandomSampleCrop()
current_image=rand_crop(image)
cv2.imshow('dog',current_image[2])  
cv2.waitKey(0)
cv2.destroyAllWindows()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

8.随机翻转,随机旋转90度
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
9.image_Expand
代码调用

image_path = './images/dog.jpg'
image = cv2.imread(image_path) #uint8数据类型
cv2.imshow('dog',image)
cv2.waitKey(0)
cv2.destroyAllWindows()
print(image.shape)

MEANS = (103.94, 116.78, 123.68)
image_Expand = Expand(MEANS)

expand_image = image_Expand(image)
expand_image.shape

cv2.imshow('dog',expand_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

在这里插入图片描述
关于数据增强部分,到此讲解完毕,后面会继续讲解train.py
YOLACT++代码分析2——Yolact模型

Reference:
YOLACT++ Better Real-time Instance Segmentation:https://arxiv.org/abs/1912.06218
YOLACT++源码:https://github.com/dbolya/yolact

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值