一、四大注册器类DATASETS、PIPELINES、MODELS、HOOKS
Registry 用于提供全局类注册器功能,Registry 类内部其实维护的是一个全局 key-value 对。通过 Registry 类,用户可以通过字符串方式实例化任何想要的模块。
#先构建全局的 DATASETS 注册器类
##只是设置了self._name = 'dataset',self.build_func = build_from_cfg
DATASETS = Registry('dataset')
#通过装饰器方式作用在想要加入注册器的具体类中
@DATASETS.register_module()
class KittiTinyDataset(CustomDataset):
##只是设置了self._name = 'pipeline',self.build_func = build_from_cfg
PIPELINES = Registry('pipeline')
#mmdet/models/builder.py
MODELS = Registry('models', parent=MMCV_MODELS)
#mmcv/runner/hooks/hook.py
HOOKS = Registry('hook')
二、四大注册器类实例化过程
通过register_module装饰器, 把key-value添加到字典中self._module_dict[name] = module_class
def register_module(self, name=None, force=False, module=None):
'省略。。。'
#register_module是用作一种方法
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
#register_module是用作一个装饰器
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
def _register_module(self, module_class, module_name=None, force=False):
if module_name is None:
module_name = module_class.__name__
self._module_dict[name] = module_class
下面是数据集类的实例化过程
datasets = [build_dataset(cfg.data.train)]
def build_dataset(cfg, default_args=None):
dataset = build_from_cfg(cfg, DATASETS, default_args)
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
args = cfg.copy()
#获取类和类名
obj_type = args.pop('type')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
#在这里进行实例化,我定义的类KittiTinyDataset并没有初始化函数,他的父类CustomDataset有对应初始化函数
try:
return obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{obj_cls.__name__}: {e}')
#先构建全局的 MODELS 注册器类
MODELS = Registry('models', parent=MMCV_MODELS)
这里的MMCV_MODELS就是下面的MODELS
MODELS = Registry('model', build_func=build_model_from_cfg)
BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS
cfg是list类型的情况下,循环调用build_from_cfg
def build_model_from_cfg(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
下面时配置文件中的部分代码,根据type内容,去实例化对应的模型
backbone=dict(
type='ResNet',...),
neck=dict(
type='FPN',...),
rpn_head=dict(
type='RPNHead',...),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',...),
loss_cls=dict(
type='CrossEntropyLoss', ...),
roi_head=dict(
type='StandardRoIHead',...)
这里需要注意,Models通过register_module装饰器, 把key-value添加到字典中的函数都在mmdet/models/文件路径下,下面代码在mmdet/models/faster_rcnn.py中
@DETECTORS.register_module()
class FasterRCNN(TwoStageDetector):
总结:1、Registry创建多个全局注册器类,比如DATASETS = Registry(‘dataset’), PIPELINES = Registry(‘pipeline’),MODELS = Registry(‘models’, parent=MMCV_MODELS)
2、通过@DETECTORS.register_module()这种类似的装饰器,把类名和类添加到Registry._module_dict中
3、用配置文件内容去实例化对应的类