mmlab之自定义数据处理

本文介绍了如何在mmlab中自定义数据变换类,如MyFlip,并展示了如何使用RandomChoice和RandomApply进行随机选择和执行数据增强。同时,详细讨论了利用getattr与Albumentations库结合,实现自定义数据处理的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

系列文章目录

例如:第一章mmlab数据处理模块



前言

mmcv是mmlab系列的底层支持架构之一,主要进行数据读取、数据变换、卷积神经网络、以及一些nms之类的基础算子。其主要采用register实现相应模块的继承。下面将以文档中的案例进行展开介绍


一、自定义数据变换类

要实现一个新的数据变换类,需要继承 BaseTransform,并实现 transform 方法。这里,我们使用一个简单的翻转变换(MyFlip)作为示例:

import random
import mmcv
from mmcv.transforms import BaseTransform, TRANSFORMS

@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
    def __init__(self, direction: str):
        super().__init__()
        self.direction = direction

    def transform(self, results: dict) -> dict:
        img = results['img']
        results['img'] = mmcv.imflip(img, direction=self.direction)
        return results

从而,我们可以实例化一个 MyFlip 对象,并将之作为一个可调用对象,来处理我们的数据字典。

import numpy as np

transform = MyFlip(direction='horizontal')
data_dict = {'img': np.random.rand(224, 224, 3)}
data_dict = transform(data_dict)
processed_img = data_dict['img']

又或者,在配置文件的 pipeline 中使用 MyFlip 变换

pipeline = [
    ...
    dict(type='MyFlip', direction='horizontal'),
    ...
]

二、随机选择和随机执行

随机选择包装(RandomChoice)用于从一系列数据变换组合中随机应用一个数据变换组合。利用这一包装,我们可以简单地实现一些数据增强功能,比如 AutoAugment。

如果配合注册器和配置文件使用的话,在配置文件中数据集的 pipeline 中如下例使用随机选择包装:

代码如下(示例):

pipeline = [
    ...
    dict(type='RandomChoice',
        transforms=[
            [
                dict(type='Posterize', bits=4),
                dict(type='Rotate', angle=30.)
            ],  # 第一种随机变化组合
            [
                dict(type='Equalize'),
                dict(type='Rotate', angle=30)
            ],  # 第二种随机变换组合
        ],
        prob=[0.4, 0.6]  # 两种随机变换组合各自的选用概率
        )
    ...
]

随机执行包装(RandomApply)用于以指定概率随机执行数据变换组合。例如:

pipeline = [
    ...
    dict(type='RandomApply',
        transforms=[dict(type='Rotate', angle=30.)],
        prob=0.3)  # 以 0.3 的概率执行被包装的数据变换
    ...
]

albumentation

自定义Albu方法

@TRANSFORMS.register_module(['Albumentations', 'Albu'])
class Albumentations(BaseTransform):
    """Wrapper to use augmentation from albumentations library.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img
    - img_shape

    Adds custom transformations from albumentations library.
    More details can be found in
    `Albumentations <https://albumentations.readthedocs.io>`_.
    An example of ``transforms`` is as followed:

    .. code-block::

        [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]

    Args:
        transforms (List[Dict]): List of albumentations transform configs.
        keymap (Optional[Dict]): Mapping of mmpretrain to albumentations
            fields, in format {'input key':'albumentation-style key'}.
            Defaults to None.

    Example:
        >>> import mmcv
        >>> from mmpretrain.datasets import Albumentations
        >>> transforms = [
        ...     dict(
        ...         type='ShiftScaleRotate',
        ...         shift_limit=0.0625,
        ...         scale_limit=0.0,
        ...         rotate_limit=0,
        ...         interpolation=1,
        ...         p=0.5),
        ...     dict(
        ...         type='RandomBrightnessContrast',
        ...         brightness_limit=[0.1, 0.3],
        ...         contrast_limit=[0.1, 0.3],
        ...         p=0.2),
        ...     dict(type='ChannelShuffle', p=0.1),
        ...     dict(
        ...         type='OneOf',
        ...         transforms=[
        ...             dict(type='Blur', blur_limit=3, p=1.0),
        ...             dict(type='MedianBlur', blur_limit=3, p=1.0)
        ...         ],
        ...         p=0.1),
        ... ]
        >>> albu = Albumentations(transforms)
        >>> data = {'img': mmcv.imread('./demo/demo.JPEG')}
        >>> data = albu(data)
        >>> print(data['img'].shape)
        (375, 500, 3)
    """

    def __init__(self, transforms: List[Dict], keymap: Optional[Dict] = None):
        if albumentations is None:
            raise RuntimeError('albumentations is not installed')
        else:
            from albumentations import Compose as albu_Compose

        assert isinstance(transforms, list), 'transforms must be a list.'
        if keymap is not None:
            assert isinstance(keymap, dict), 'keymap must be None or a dict. '

        self.transforms = transforms

        self.aug = albu_Compose(
            [self.albu_builder(t) for t in self.transforms])

        if not keymap:
            self.keymap_to_albu = dict(img='image')
        else:
            self.keymap_to_albu = keymap
        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

    def albu_builder(self, cfg: Dict):
        """Import a module from albumentations.

        It inherits some of :func:`build_from_cfg` logic.
        Args:
            cfg (dict): Config dict. It should at least contain the key "type".
        Returns:
            obj: The constructed object.
        """

        assert isinstance(cfg, dict) and 'type' in cfg, 'each item in ' \
            "transforms must be a dict with keyword 'type'."
        args = cfg.copy()

        obj_type = args.pop('type')
        if mmengine.is_str(obj_type):
            obj_cls = getattr(albumentations, obj_type)
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        if 'transforms' in args:
            args['transforms'] = [
                self.albu_builder(transform)
                for transform in args['transforms']
            ]

        return obj_cls(**args)

    @staticmethod
    def mapper(d, keymap):
        """Dictionary mapper.

        Renames keys according to keymap provided.
        Args:
            d (dict): old dict
            keymap (dict): {'old_key':'new_key'}
        Returns:
            dict: new dict.
        """

        updated_dict = {}
        for k, v in zip(d.keys(), d.values()):
            new_k = keymap.get(k, k)
            updated_dict[new_k] = d[k]
        return updated_dict

[docs]    def transform(self, results: Dict) -> Dict:
        """Transform function to perform albumentations transforms.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Transformed results, 'img' and 'img_shape' keys are
                updated in result dict.
        """
        assert 'img' in results, 'No `img` field in the input.'

        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)
        results = self.aug(**results)

        # back to the original format
        results = self.mapper(results, self.keymap_back)
        results['img_shape'] = results['img'].shape[:2]

        return results


    def __repr__(self):
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={repr(self.transforms)})'
        return repr_str
Albu-transformer = [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]


train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='Albumentations', transforms = Albu-transformer),
    dict(type='PackDetInputs')
]

mmsegmentation

支持调用Albu数据流,调用代码如下

albu_train_transforms = [
    dict(
        type='ShiftScaleRotate',
        shift_limit=0.0625,
        scale_limit=0.0,
        rotate_limit=0,
        interpolation=1,
        p=0.5),
    dict(
        type='RandomBrightnessContrast',
        brightness_limit=[0.1, 0.5],
        contrast_limit=[0.1, 0.5],
        p=0.3),
    dict(
        type='OneOf',
        transforms=[
            dict(
                type='RGBShift',
                r_shift_limit=10,
                g_shift_limit=10,
                b_shift_limit=10,
                p=1.0),
            dict(
                type='HueSaturationValue',
                hue_shift_limit=20,
                sat_shift_limit=30,
                val_shift_limit=20,
                p=1.0)
        ],
        p=0.2),
    # dict(type='JpegCompression', quality_lower=85, quality_upper=95, p=0.2),
    dict(type='ChannelShuffle', p=0.3),
    dict(
        type='OneOf',
        transforms=[
            dict(type='Blur', blur_limit=3, p=1.0),
            dict(type='MedianBlur', blur_limit=3, p=1.0)
        ],
        p=0.3),
]


train_pipeline = [
    dict(type='LoadImageFromFile',imdecode_backend='tifffile'),
    dict(type='LoadAnnotations'),
    dict(
        type='Albu',
        transforms=albu_train_transforms,
        
        keymap={
            'img': 'image',
            'gt_masks': 'masks',
        }),
    dict(
        type='RandomChoiceResize',
        scales=[384, 448, 512, 576, 640, 704, 768],
        resize_type='ResizeShortestEdge',
        max_size=1536),
    dict(type='RandomCrop', crop_size=crop_size),
     dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='RandomFlip', prob=0.5, direction='vertical'),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

keymap参数将默认数据流的关键字映射为Albu支持的关键字,mmsegmentaion中设置了相应的默认值,在此处可以省略。

– 知识点:
在Python中,getattr 是一个内建函数,它用于获取一个对象的属性值。其基本用法是 getattr(object, name[, default]),其中 object 是要访问的对象,name 是要以字符串形式提供的对象属性名。

  • 对于表达式 getattr(albumentations, obj_type):

albumentations 应该是一个模块或类,它包含了一系列的属性和方法。
obj_type 是一个字符串,表示要从 albumentations 中获取的属性或方法的名称。如果 obj_type 存在于 albumentations 中,getattr 将返回该属性的值。如果不存在,并且没有提供 default 参数,则会抛出一个 AttributeError。
举个例子,如果 albumentations 是一个库,其中有一个名为 HorizontalFlip 的类,那么 getattr(albumentations, “HorizontalFlip”) 将返回 HorizontalFlip 类。

总结

mmlab其实具有比较详细的文档介绍,但是各个项目中会对一些特定的功能进行介绍,因此在使用mmlab系列时,需要对mmcv以及mmengine的相关文档进行阅读,便于根据自己的需求进行拓展。
链接:
mmcv参考内容
mmpretrain

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云朵不吃雨

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值