Articles in this series
Chapter 1: The mmlab data processing module
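Preface
mmcv is one of the foundation libraries underneath the mmlab series. It covers data loading, data transforms, convolutional network building blocks, and basic operators such as NMS. Modules are extended mainly through the registry mechanism: a new class inherits from the appropriate base class and is registered, so that it can later be built by name from a config. The sections below walk through the examples from the documentation.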
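The registry idea can be illustrated with a few lines of plain Python. This is only a minimal sketch: SimpleRegistry, TOY_TRANSFORMS and Scale are hypothetical names invented for this example, and the real registry in mmengine has more features (scopes, parent registries, and so on).

```python
class SimpleRegistry:
    """Toy stand-in for a registry: maps a name to a class."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self, name=None):
        """Return a decorator that stores a class under `name` (default: class name)."""
        def decorator(cls):
            self._modules[name or cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        """Instantiate the class named by cfg['type'] with the remaining keys."""
        args = dict(cfg)  # copy so the caller's config is not modified
        cls = self._modules[args.pop('type')]
        return cls(**args)


TOY_TRANSFORMS = SimpleRegistry('transform')  # hypothetical registry instance


@TOY_TRANSFORMS.register_module()
class Scale:
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


# A config dict is turned into an object purely by looking up its 'type'.
scale = TOY_TRANSFORMS.build(dict(type='Scale', factor=3))
print(scale(2))  # 6
```

The point of the design is that a config file only needs to carry strings and plain values; the registry resolves them to classes at build time.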
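1. Custom data transform classes
To implement a new data transform class, inherit from BaseTransform and implement the transform method. Here we use a simple flip transform (MyFlip) as an example: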
import mmcv
from mmcv.transforms import BaseTransform, TRANSFORMS


@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
    def __init__(self, direction: str):
        super().__init__()
        self.direction = direction

    def transform(self, results: dict) -> dict:
        # Read the image from the data dict, flip it, and write it back.
        img = results['img']
        results['img'] = mmcv.imflip(img, direction=self.direction)
        return results
We can then instantiate a MyFlip object and use it as a callable to process a data dict.
import numpy as np
transform = MyFlip(direction='horizontal')
data_dict = {'img': np.random.rand(224, 224, 3)}
data_dict = transform(data_dict)
processed_img = data_dict['img']
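The reason the instance can be called like a function is that the base class forwards __call__ to transform. The following is a minimal sketch of that pattern in plain Python, assuming the base class does nothing more than forward the call. ToyBaseTransform and ToyFlip are hypothetical stand-ins, and the image is a nested list rather than a NumPy array so that the example has no dependencies.

```python
from abc import ABC, abstractmethod


class ToyBaseTransform(ABC):
    """Hypothetical stand-in: calling the object forwards to transform()."""

    def __call__(self, results: dict) -> dict:
        return self.transform(results)

    @abstractmethod
    def transform(self, results: dict) -> dict:
        """Subclasses implement the actual processing here."""


class ToyFlip(ToyBaseTransform):
    def __init__(self, direction: str):
        assert direction in ('horizontal', 'vertical')
        self.direction = direction

    def transform(self, results: dict) -> dict:
        img = results['img']  # image stored as a list of rows
        if self.direction == 'horizontal':
            results['img'] = [row[::-1] for row in img]  # reverse each row
        else:
            results['img'] = img[::-1]  # reverse the order of the rows
        return results


flip = ToyFlip(direction='horizontal')
out = flip({'img': [[1, 2, 3], [4, 5, 6]]})
print(out['img'])  # [[3, 2, 1], [6, 5, 4]]
```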
Alternatively, use the MyFlip transform in the pipeline of a config file:
pipeline = [
    ...,
    dict(type='MyFlip', direction='horizontal'),
    ...,
]
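2. Random choice and random apply
The random choice wrapper (RandomChoice) randomly applies one group out of a list of transform groups. With this wrapper it is easy to implement augmentation strategies such as AutoAugment.
When used together with the registry and config files, the wrapper appears in the dataset pipeline of a config file as in the following example: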
pipeline = [
    ...,
    dict(type='RandomChoice',
         transforms=[
             [
                 dict(type='Posterize', bits=4),
                 dict(type='Rotate', angle=30.)
             ],  # first group of transforms
             [
                 dict(type='Equalize'),
                 dict(type='Rotate', angle=30)
             ],  # second group of transforms
         ],
         prob=[0.4, 0.6]  # probability of choosing each group
         ),
    ...,
]
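The selection logic behind such a wrapper can be sketched in plain Python. This is an illustration only and not the mmcv implementation: random_choice is a hypothetical helper, and the "transforms" act on a number instead of an image so that the result is easy to check.

```python
import random


def random_choice(results, transform_groups, prob=None, rng=random):
    """Pick one group of callables (weighted by prob) and apply it in order."""
    if prob is not None:
        assert len(prob) == len(transform_groups)
    group = rng.choices(transform_groups, weights=prob, k=1)[0]
    for t in group:
        results = t(results)
    return results


# Toy transforms acting on a number.
def add_one(x):
    return x + 1


def double(x):
    return x * 2


def negate(x):
    return -x


groups = [[add_one, double], [negate]]
rng = random.Random(0)  # seeded so the run is reproducible
outputs = {random_choice(5, groups, prob=[0.4, 0.6], rng=rng) for _ in range(200)}
print(outputs)  # contains 12 (first group) and -5 (second group)
```

Each call applies exactly one group in full, which is what distinguishes this wrapper from applying every transform with its own probability.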
The random apply wrapper (RandomApply) executes a group of transforms with a given probability. For example:
pipeline = [
    ...,
    dict(type='RandomApply',
         transforms=[dict(type='Rotate', angle=30.)],
         prob=0.3),  # run the wrapped transforms with probability 0.3
    ...,
]
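A plain-Python sketch of this behaviour, again as an illustration rather than the mmcv code: random_apply is a hypothetical helper and rotate is a toy stand-in that adds 30 to a number instead of rotating an image.

```python
import random


def random_apply(results, transforms, prob=0.5, rng=random):
    """Apply all transforms in order with probability `prob`, else pass through."""
    if rng.random() < prob:
        for t in transforms:
            results = t(results)
    return results


def rotate(x):
    return x + 30  # toy stand-in for a 30-degree rotation


rng = random.Random(42)  # seeded so the run is reproducible
trials = [random_apply(0, [rotate], prob=0.3, rng=rng) for _ in range(10_000)]
applied_ratio = sum(v == 30 for v in trials) / len(trials)
print(f'applied in {applied_ratio:.1%} of trials')  # close to 30%
```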
Albumentations
A custom Albu wrapper (the Albumentations transform class in mmpretrain):
import inspect
from typing import Dict, List, Optional

import mmengine
from mmcv.transforms import BaseTransform
from mmpretrain.registry import TRANSFORMS

try:
    import albumentations
except ImportError:
    albumentations = None


@TRANSFORMS.register_module(['Albumentations', 'Albu'])
class Albumentations(BaseTransform):
    """Wrapper to use augmentation from albumentations library.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img
    - img_shape

    Adds custom transformations from albumentations library.
    More details can be found in
    `Albumentations <https://albumentations.readthedocs.io>`_.
    An example of ``transforms`` is as follows:

    .. code-block::

        [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]

    Args:
        transforms (List[Dict]): List of albumentations transform configs.
        keymap (Optional[Dict]): Mapping of mmpretrain to albumentations
            fields, in format {'input key':'albumentation-style key'}.
            Defaults to None.

    Example:
        >>> import mmcv
        >>> from mmpretrain.datasets import Albumentations
        >>> transforms = [
        ...     dict(
        ...         type='ShiftScaleRotate',
        ...         shift_limit=0.0625,
        ...         scale_limit=0.0,
        ...         rotate_limit=0,
        ...         interpolation=1,
        ...         p=0.5),
        ...     dict(
        ...         type='RandomBrightnessContrast',
        ...         brightness_limit=[0.1, 0.3],
        ...         contrast_limit=[0.1, 0.3],
        ...         p=0.2),
        ...     dict(type='ChannelShuffle', p=0.1),
        ...     dict(
        ...         type='OneOf',
        ...         transforms=[
        ...             dict(type='Blur', blur_limit=3, p=1.0),
        ...             dict(type='MedianBlur', blur_limit=3, p=1.0)
        ...         ],
        ...         p=0.1),
        ... ]
        >>> albu = Albumentations(transforms)
        >>> data = {'img': mmcv.imread('./demo/demo.JPEG')}
        >>> data = albu(data)
        >>> print(data['img'].shape)
        (375, 500, 3)
    """

    def __init__(self, transforms: List[Dict], keymap: Optional[Dict] = None):
        if albumentations is None:
            raise RuntimeError('albumentations is not installed')
        else:
            from albumentations import Compose as albu_Compose

        assert isinstance(transforms, list), 'transforms must be a list.'
        if keymap is not None:
            assert isinstance(keymap, dict), 'keymap must be None or a dict. '

        self.transforms = transforms

        # Build every config dict into an albumentations object and compose them.
        self.aug = albu_Compose(
            [self.albu_builder(t) for t in self.transforms])

        # Key mapping in both directions: pipeline keys <-> albumentations keys.
        if not keymap:
            self.keymap_to_albu = dict(img='image')
        else:
            self.keymap_to_albu = keymap
        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

    def albu_builder(self, cfg: Dict):
        """Import a module from albumentations.

        It inherits some of :func:`build_from_cfg` logic.

        Args:
            cfg (dict): Config dict. It should at least contain the key "type".

        Returns:
            obj: The constructed object.
        """
        assert isinstance(cfg, dict) and 'type' in cfg, 'each item in ' \
            "transforms must be a dict with keyword 'type'."
        args = cfg.copy()

        obj_type = args.pop('type')
        if mmengine.is_str(obj_type):
            # Look the class up by name in the albumentations module.
            obj_cls = getattr(albumentations, obj_type)
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        # Nested transforms (e.g. inside OneOf) are built recursively.
        if 'transforms' in args:
            args['transforms'] = [
                self.albu_builder(transform)
                for transform in args['transforms']
            ]

        return obj_cls(**args)

    @staticmethod
    def mapper(d, keymap):
        """Dictionary mapper.

        Renames keys according to keymap provided.

        Args:
            d (dict): old dict
            keymap (dict): {'old_key':'new_key'}

        Returns:
            dict: new dict.
        """
        updated_dict = {}
        for k, v in d.items():
            # Keys missing from keymap keep their original name.
            new_k = keymap.get(k, k)
            updated_dict[new_k] = v
        return updated_dict

    def transform(self, results: Dict) -> Dict:
        """Transform function to perform albumentations transforms.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Transformed results, 'img' and 'img_shape' keys are
                updated in result dict.
        """
        assert 'img' in results, 'No `img` field in the input.'

        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)
        results = self.aug(**results)

        # back to the original format
        results = self.mapper(results, self.keymap_back)
        results['img_shape'] = results['img'].shape[:2]

        return results

    def __repr__(self):
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={repr(self.transforms)})'
        return repr_str
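The key renaming done by mapper, first into the albumentations naming and then back again, can be tried in isolation. The sketch below mirrors that logic with a hypothetical helper rename_keys and placeholder values; it does not need albumentations installed.

```python
def rename_keys(d, keymap):
    """Return a new dict whose keys are renamed via keymap; other keys are kept."""
    return {keymap.get(k, k): v for k, v in d.items()}


keymap_to_albu = {'img': 'image'}
# Inverting the mapping gives the way back after augmentation.
keymap_back = {v: k for k, v in keymap_to_albu.items()}

results = {'img': 'pixels', 'img_path': 'a.jpg'}  # placeholder values
albu_style = rename_keys(results, keymap_to_albu)
restored = rename_keys(albu_style, keymap_back)

print(albu_style)  # {'image': 'pixels', 'img_path': 'a.jpg'}
print(restored)    # {'img': 'pixels', 'img_path': 'a.jpg'}
```

Keys that are not listed in the keymap pass through unchanged, which is why the round trip restores the original dict.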
albu_transforms = [
    dict(
        type='ShiftScaleRotate',
        shift_limit=0.0625,
        scale_limit=0.0,
        rotate_limit=0,
        interpolation=1,
        p=0.5),
    dict(
        type='RandomBrightnessContrast',
        brightness_limit=[0.1, 0.3],
        contrast_limit=[0.1, 0.3],
        p=0.2),
    dict(type='ChannelShuffle', p=0.1),
    dict(
        type='OneOf',
        transforms=[
            dict(type='Blur', blur_limit=3, p=1.0),
            dict(type='MedianBlur', blur_limit=3, p=1.0)
        ],
        p=0.1),
]
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='Albumentations', transforms=albu_transforms),
    dict(type='PackDetInputs')
]
mmsegmentation
mmsegmentation also supports the Albu data pipeline. It is used as follows:
albu_train_transforms = [
    dict(
        type='ShiftScaleRotate',
        shift_limit=0.0625,
        scale_limit=0.0,
        rotate_limit=0,
        interpolation=1,
        p=0.5),
    dict(
        type='RandomBrightnessContrast',
        brightness_limit=[0.1, 0.5],
        contrast_limit=[0.1, 0.5],
        p=0.3),
    dict(
        type='OneOf',
        transforms=[
            dict(
                type='RGBShift',
                r_shift_limit=10,
                g_shift_limit=10,
                b_shift_limit=10,
                p=1.0),
            dict(
                type='HueSaturationValue',
                hue_shift_limit=20,
                sat_shift_limit=30,
                val_shift_limit=20,
                p=1.0)
        ],
        p=0.2),
    # dict(type='JpegCompression', quality_lower=85, quality_upper=95, p=0.2),
    dict(type='ChannelShuffle', p=0.3),
    dict(
        type='OneOf',
        transforms=[
            dict(type='Blur', blur_limit=3, p=1.0),
            dict(type='MedianBlur', blur_limit=3, p=1.0)
        ],
        p=0.3),
]
train_pipeline = [
    dict(type='LoadImageFromFile', imdecode_backend='tifffile'),
    dict(type='LoadAnnotations'),
    dict(
        type='Albu',
        transforms=albu_train_transforms,
        keymap={
            'img': 'image',
            'gt_masks': 'masks',
        }),
    dict(
        type='RandomChoiceResize',
        scales=[384, 448, 512, 576, 640, 704, 768],
        resize_type='ResizeShortestEdge',
        max_size=1536),
    # crop_size is not defined in this snippet; it is expected to be set
    # earlier in the config file.
    dict(type='RandomCrop', crop_size=crop_size),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='RandomFlip', prob=0.5, direction='vertical'),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
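The keymap argument maps the keys of the default data pipeline to the keys that Albu expects. mmsegmentation sets a default value for it, so it can be omitted here.
Key point:
In Python, getattr is a built-in function that returns the value of an attribute of an object. Its basic form is getattr(object, name[, default]), where object is the object to access and name is the attribute name, given as a string.
- For the expression getattr(albumentations, obj_type):
albumentations should be a module or class that exposes a set of attributes and methods.
obj_type is a string naming the attribute or method to fetch from albumentations. If obj_type exists in albumentations, getattr returns the value of that attribute. If it does not exist and no default argument is given, an AttributeError is raised.
For example, if albumentations is a library that contains a class named HorizontalFlip, then getattr(albumentations, "HorizontalFlip") returns the HorizontalFlip class.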
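The lookup and the recursive handling of nested transforms can be tried without installing albumentations. In the sketch below, fake_albu, Blur and OneOf are toy stand-ins invented for this example (they only store their arguments), and build is a hypothetical helper that follows the same steps as albu_builder.

```python
import inspect
from types import SimpleNamespace


class Blur:  # toy stand-in, not the albumentations class
    def __init__(self, blur_limit=3, p=0.5):
        self.blur_limit, self.p = blur_limit, p


class OneOf:  # toy stand-in that only stores its children
    def __init__(self, transforms, p=0.5):
        self.transforms, self.p = transforms, p


# Plays the role of the albumentations module in getattr(albumentations, name).
fake_albu = SimpleNamespace(Blur=Blur, OneOf=OneOf)


def build(cfg, namespace=fake_albu):
    """Build an object from a config dict, resolving 'type' with getattr."""
    args = dict(cfg)  # copy so the caller's config is not modified
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = getattr(namespace, obj_type)  # AttributeError if missing
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(f'type must be a str or a class, but got {type(obj_type)}')
    if 'transforms' in args:  # nested configs are built recursively
        args['transforms'] = [build(t, namespace) for t in args['transforms']]
    return obj_cls(**args)


one_of = build(dict(
    type='OneOf',
    transforms=[dict(type='Blur', blur_limit=3, p=1.0)],
    p=0.1))
print(type(one_of.transforms[0]).__name__)     # Blur
print(getattr(fake_albu, 'MedianBlur', None))  # None, because a default was given
```

Passing a third argument to getattr is a convenient way to probe whether a name exists; without it, a misspelled type in a config surfaces as an AttributeError at build time.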
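Summary
The mmlab projects are fairly well documented, but the documentation of each project covers only its own specific features. When working with the mmlab series you therefore also need to read the mmcv and mmengine documentation, so that you can extend the libraries to fit your own needs.
Links:
mmcv reference
mmpretrain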