图像样本增广,yoloV5扩展

本文介绍了如何在 YOLOv5 训练过程中集成 Albumentations 库进行数据增强,包括亮度、对比度调整、裁剪、旋转等各种图像变换。同时,展示了如何处理边界框,确保其在增强后仍然有效,并提供了数据集标签缓存和读取的实现。此外,还提及了 Mosaic 方法和 MixUp 技术,用于进一步增强训练数据的多样性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

albumentation

https://albumentations.ai/docs/#getting-started-with-albumentations

min_area 和 min_visibility 参数控制:当边界框的大小在增强后发生变化时,Albumentations 应如何处理这些边界框。如果应用了空间增强(例如,裁剪图像的一部分或调整图像大小),边界框的大小可能会发生变化。

min_area 是以像素为单位的值。如果增强后边界框的面积小于 min_area,Albumentations 将删除该框。因此,返回的增强边界框列表将不包含该边界框。

min_visibility 是介于 0 和 1 之间的值。如果增强后的边界框面积与增强前的边界框面积之比小于 min_visibility,Albumentations 将删除该框。因此,如果增强过程裁掉了边界框的大部分,该框将不会出现在返回的增强边界框列表中。

label_fields=['class_labels', 'class_categories']

class_labels 是某个大类别下的细分类别,例如当 class_categories 是“动物”时,class_labels 可以是狗、猫、牛等。

集成 Albumentations 及 YOLOv5 中的 Mosaic、MixUp

class LoadImagesAndLabels_1(Dataset):  # for training/testing
    """YOLOv5-style image/label dataset that runs an Albumentations pipeline per sample.

    ``__init__`` collects image paths, loads (or rebuilds) the label cache and
    optionally caches decoded images in memory.  ``__getitem__`` loads either a
    mosaic (optionally blended with a second mosaic via MixUp) or a single
    letterboxed image, applies ``self.transform`` when ``augment`` is set, and
    returns labels as normalized ``xywh``.

    The class relies on names that must be provided by the surrounding module:
    ``A`` (albumentations), ``glob``, ``os``, ``tqdm``, ``logging``, ``Pool``,
    ``ThreadPool``, ``repeat``, ``img_formats``, ``help_url``, ``num_threads``,
    ``img2label_paths``, ``get_hash``, ``verify_image_label``, ``load_image``,
    ``load_mosaic``, ``letterbox``, ``xywhn2xyxy`` and ``xyxy2xywh``.
    """

    def __init__(self,
                 path,
                 img_size=640,
                 batch_size=16,
                 augment=False,
                 hyp=None,
                 rect=False,
                 image_weights=False,
                 cache_images=False,
                 single_cls=False,
                 stride=32,
                 pad=0.0,
                 prefix=''):
        """Index the dataset and prepare labels, batch shapes and the augmentation pipeline.

        Args:
            path: Directory, text file listing image paths, or a list of either.
            img_size: Target image size; also used to derive ``mosaic_border``.
            batch_size: Used to assign every image a batch index (rectangular training).
            augment: Enables mosaic loading and the Albumentations pipeline.
            hyp: Hyper-parameter mapping; ``__getitem__`` reads ``hyp['mosaic']``
                and ``hyp['mixup']`` when mosaic loading is active.
            rect: Rectangular training; forced off when ``image_weights`` is set.
            image_weights: Stored on the instance; disables ``rect``.
            cache_images: Load every image into memory up front.
            single_cls: Overwrite every class id with 0.
            stride: Batch shapes are rounded to a multiple of this value.
            pad: Padding term added when computing rectangular batch shapes.
            prefix: Text prepended to log and error messages.

        Raises:
            Exception: If ``path`` cannot be resolved or contains no images.
            AssertionError: If augmenting and the cache reports no labels.
        """
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        self.transform = A.Compose([
            # ---------------------
            # RandomBrightness / RandomContrast are deprecated thin wrappers around
            # RandomBrightnessContrast; the calls below are their direct equivalents
            # (same limits, same probabilities).
            A.RandomBrightnessContrast(brightness_limit=2, contrast_limit=0, p=0.5),  # random brightness
            A.RandomBrightnessContrast(p=0.2),  # random brightness and contrast
            A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=2.3, p=0.5),  # random contrast
            # Catalogue of further transforms that can be enabled as needed:
            # A.RGBShift(r_shift_limit=133, g_shift_limit=146, b_shift_limit=26, p=0.5),  # colour channel shift
            # A.RandomGamma(gamma_limit=148, p=0.5),  # gamma transform
            # A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=20, val_shift_limit=27, p=0.5),  # HSV shift
            # A.JpegCompression(quality_lower=80, quality_upper=100, p=0.5),  # JPEG compression
            # A.Blur(blur_limit=7, p=1),  # blur
            # A.MedianBlur(blur_limit=3, p=0.1),  # median blur
            # A.MotionBlur(p=0.2),  # motion blur
            # A.GaussianBlur(blur_limit=7, p=1),  # Gaussian blur
            # A.GlassBlur(sigma=0.7, max_delta=4),  # glass blur
            # A.Cutout(num_holes=8, max_h_size=8, max_w_size=8),  # cut out small holes
            # A.RandomSnow(p=1),  # add snow
            # A.RandomRain(p=1),  # add rain
            # A.RandomFog(p=1),  # add fog
            # A.RandomSunFlare(p=1),  # add sun flare
            # A.RandomShadow(p=1),  # add shadow
            # A.ISONoise(p=1),  # sensor noise
            # A.IAAAdditiveGaussianNoise(p=1),  # Gaussian noise
            # A.MultiplicativeNoise(p=1),  # multiplicative noise
            # A.IAAEmboss(p=1),  # emboss
            # A.IAASuperpixels(p=1),  # superpixels
            # A.IAASharpen(p=1),  # sharpen
            # # ---------------------
            # A.RandomCrop(width=450, height=450),
            # A.CenterCrop(height=480, width=480, p=0.5),  # centre crop
            # # ---------------------
            # A.Rotate(limit=89, p=0.5),  # random rotation
            # A.VerticalFlip(p=0.5),  # vertical flip
            # A.HorizontalFlip(p=0.5),  # horizontal flip
            # A.RandomRotate90(p=0.5),  # random 90-degree rotation
            # A.ShiftScaleRotate(shift_limit=0.8, scale_limit=1, rotate_limit=118, p=0.5),  # shift, scale, rotate
            # A.IAAPerspective(p=1),  # perspective transform
            # # ---------------------
            # A.ElasticTransform(alpha=155, sigma=210, alpha_affine=157, p=0.5),  # elastic deformation (suits text images)
            # A.OpticalDistortion(distort_limit=0.25, shift_limit=0.2, p=0.5),  # optical distortion
            # A.GridDistortion(num_steps=5, distort_limit=0.3, p=1),  # grid distortion
            # A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8)),  # adaptive histogram equalization
        ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, 'r') as fh:
                        lines = fh.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    f += [x.replace('./', parent) if x.startswith('./') else x for x in lines]  # local to global path
                    # f += [p.parent / x.lstrip(os.sep) for x in lines]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}') from e

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        # `p` is the last entry processed by the loop above.
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        if cache_path.is_file():
            # NOTE(review): torch.load unpickles the file; only use cache files from a trusted location.
            cache, exists = torch.load(cache_path), True  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files):  # changed
                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

        # Read cache
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        # The `np.int` alias was removed in NumPy 1.24; the builtin `int` is the documented replacement.
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # bytes of cached images so far (reported in GB)
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            # The context manager releases the worker threads once every image has been consumed.
            with ThreadPool(num_threads) as pool:
                results = pool.imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
                pbar = tqdm(enumerate(results), total=n)
                for i, x in pbar:
                    self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                    gb += self.imgs[i].nbytes
                    pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
                pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        """Verify every image/label pair, build the label cache and try to save it.

        Args:
            path: Destination of the cache file.
            prefix: Text prepended to progress and log messages.

        Returns:
            dict: ``{image_file: [labels, shape, segments]}`` plus the bookkeeping
            keys ``'hash'``, ``'results'`` and ``'version'``.  The dict is returned
            even when it could not be written to ``path``.
        """
        x = {}  # dict
        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(num_threads) as pool:
            pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                        desc=desc, total=len(self.img_files))
            for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # falsy im_file entries are counted but not cached
                    x[im_file] = [l, shape, segments]
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

        pbar.close()
        if nf == 0:
            logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, len(self.img_files)
        x['version'] = 0.2  # cache version
        try:
            torch.save(x, path)  # save cache for next time
            logging.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            # Best effort: a read-only dataset directory must not abort training.
            logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # path not writeable
        return x

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        """Load one sample, augment it if enabled, and return tensors for training.

        Args:
            index: Position into ``self.indices``.

        Returns:
            tuple: ``(image, labels_out, image_path, shapes)`` where ``image`` is a
            channel-first tensor with the channel order reversed relative to the
            loaded array, ``labels_out`` is an ``(nL, 6)`` tensor whose column 0 is
            left at zero for ``collate_fn`` to fill and whose remaining columns are
            ``class, x, y, w, h`` normalized by the image size, and ``shapes`` is
            ``None`` for mosaic samples.
        """
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic.
            # Per the original author's note, load_mosaic already covers perspective/crop style
            # geometric augmentation, so none is repeated afterwards.
            # NOTE(review): assumes load_mosaic returns labels as pixel xyxy inside the image — confirm.
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
                r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)
        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy, matching the 'pascal_voc' bbox format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        # Image augmentation
        if self.augment:
            transformed = self.transform(image=img, bboxes=labels[:, 1:], class_labels=labels[:, 0])
            img = transformed['image']
            # reshape(-1, 4) keeps a well-formed (0, 4) array when no box was passed in or
            # survived; assigning a bare empty list into a (0, 4) slice would fail to broadcast.
            boxes = np.asarray(transformed['bboxes'], dtype=np.float64).reshape(-1, 4)
            classes = np.asarray(transformed['class_labels'], dtype=np.float64)

            labels = np.zeros((len(classes), 5))
            labels[:, 0] = classes
            labels[:, 1:] = boxes

        nL = len(labels)  # number of labels
        labels_out = torch.zeros((nL, 6))
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # reverse channel order (BGR to RGB), HWC to CHW
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        """Stack images and concatenate labels, writing each sample's batch position into label column 0.

        The label tensors of ``batch`` are modified in place.
        """
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        """Collate groups of four samples into one double-size sample each.

        For every group, with probability 0.5 the first image is upsampled 2x and
        keeps its own labels; otherwise the four images are tiled 2x2 and their
        labels are offset and halved to match.  Only the first ``len(batch) // 4``
        paths and shapes are returned.
        """
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])  # +1 on label column 3
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])  # +1 on label column 2
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值