深度学习-mmcv中build_runner实例化全流程详解

1、实例化的例子

# Runner type is IterBasedRunner. This dict is the config's `runner` entry,
# presumably the same object later read back as `cfg.runner` (TODO confirm
# against the config loader). It has no 'constructor' key, so build_runner
# falls back to 'DefaultRunnerConstructor'.
runner = dict(
type="IterBasedRunner",
# Total iteration budget = iterations per epoch * number of epochs.
max_iters=num_iters_per_epoch * num_epochs,
)
# Instantiate the runner; this returns an IterBasedRunner instance.
# default_args are handed to the runner constructor (DefaultRunnerConstructor),
# which passes them on to RUNNERS.build alongside cfg.runner.
runner = build_runner(
    cfg.runner,
    default_args=dict(
        model=model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta,
    ),
)

2、实例化的完整调用流程及关键注释
~/anaconda3/envs/sparsedrive/lib/python3.7/site-packages/mmcv/runner/builder.py:

# Instantiate two registries using the Registry class from
# ~/anaconda3/envs/sparsedrive/lib/python3.7/site-packages/mmcv/utils/registry.py;
# 'runner' and 'runner builder' are passed as Registry.__init__'s `name` argument.
# RUNNERS holds runner classes (e.g. IterBasedRunner is registered into it);
# RUNNER_BUILDERS holds runner-constructor classes (e.g. DefaultRunnerConstructor).
RUNNERS = Registry('runner')
RUNNER_BUILDERS = Registry('runner builder')

def build_runner_constructor(cfg: dict):
    """Build a runner-constructor object from ``cfg`` via ``RUNNER_BUILDERS``.

    ``build_runner`` calls this with ``type='DefaultRunnerConstructor'`` unless
    the runner config names another constructor. In that default case the
    result is a ``DefaultRunnerConstructor`` instance (from
    mmcv/runner/default_constructor.py).
    """
    constructor = RUNNER_BUILDERS.build(cfg)
    return constructor

def build_runner(cfg: dict, default_args: Optional[dict] = None):
    """Create a runner from a runner config.

    The config is deep-copied first, so the caller's dict is never mutated.
    An optional ``'constructor'`` key is removed from the copy and names the
    runner-constructor class; it defaults to ``'DefaultRunnerConstructor'``.
    That constructor is built with the remaining config and ``default_args``,
    then called to produce the runner.

    Args:
        cfg: Runner config, e.g. ``dict(type='IterBasedRunner', max_iters=...)``.
        default_args: Extra keyword arguments handed to the constructor
            (e.g. model, optimizer, work_dir, logger, meta).

    Returns:
        Whatever calling the constructor returns (the runner instance).
    """
    cfg_copy = copy.deepcopy(cfg)
    # No explicit constructor in the config -> use the default one.
    constructor_name = cfg_copy.pop('constructor', 'DefaultRunnerConstructor')
    constructor_cfg = dict(
        type=constructor_name,
        runner_cfg=cfg_copy,
        default_args=default_args,
    )
    # Calling the constructor object invokes its __call__, which builds the runner.
    return build_runner_constructor(constructor_cfg)()

~/anaconda3/envs/sparsedrive/lib/python3.7/site-packages/mmcv/runner/default_constructor.py:

#导入~/anaconda3/envs/sparsedrive/lib/python3.7/site-packages/mmcv/runner/builder.py中的RUNNER_BUILDERS和RUNNERS
from .builder import RUNNER_BUILDERS, RUNNERS

@RUNNER_BUILDERS.register_module()
class DefaultRunnerConstructor:
    """Default runner constructor, registered in ``RUNNER_BUILDERS``.

    ``build_runner`` instantiates this class when the runner config has no
    ``'constructor'`` key, then calls the instance to obtain the runner.

    Args:
        runner_cfg: Runner config dict, e.g.
            ``dict(type='IterBasedRunner', max_iters=...)``.
        default_args: Keyword arguments forwarded to ``RUNNERS.build`` together
            with ``runner_cfg`` (e.g. model, optimizer, work_dir, logger, meta).

    Raises:
        TypeError: If ``runner_cfg`` is not a dict.
    """

    def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
        if not isinstance(runner_cfg, dict):
            # One message string: passing two positional args would make the
            # exception message a tuple with no separator between the parts.
            raise TypeError('runner_cfg should be a dict, '
                            f'but got {type(runner_cfg)}')
        self.runner_cfg = runner_cfg
        self.default_args = default_args

    # Invoked by `runner = runner_constructor()` in mmcv/runner/builder.py.
    # Because self.runner_cfg['type'] is 'IterBasedRunner' in the example, this
    # instantiates IterBasedRunner(BaseRunner) from
    # ~/anaconda3/envs/sparsedrive/lib/python3.7/site-packages/mmcv/runner/iter_based_runner.py
    def __call__(self):
        """Build and return the runner described by ``self.runner_cfg``."""
        return RUNNERS.build(self.runner_cfg, default_args=self.default_args)

~/anaconda3/envs/sparsedrive/lib/python3.7/site-packages/mmcv/runner/iter_based_runner.py:

@RUNNERS.register_module()
# Inherits the parent class's __init__ (BaseRunner.__init__).
class IterBasedRunner(BaseRunner):
    """Runner built when the runner config uses ``type='IterBasedRunner'``.

    Registered in ``RUNNERS`` so that ``RUNNERS.build`` can instantiate it from
    the config. Its constructor arguments (model, optimizer, work_dir, logger,
    meta, max_iters, ...) are handled by the inherited ``BaseRunner.__init__``.
    NOTE(review): the real class body is not reproduced in this excerpt.
    """

~/anaconda3/envs/sparsedrive/lib/python3.7/site-packages/mmcv/runner/base_runner.py:

class BaseRunner(metaclass=ABCMeta):
    """Base class for mmcv runners (only ``__init__`` is shown here).

    Subclasses such as ``IterBasedRunner`` inherit this constructor, so the
    ``default_args`` given to ``build_runner`` plus the remaining runner config
    keys (e.g. ``max_iters``) arrive here as keyword arguments.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 batch_processor: Optional[Callable] = None,
                 optimizer: Union[Dict, torch.optim.Optimizer, None] = None,
                 work_dir: Optional[str] = None,
                 logger: Optional[logging.Logger] = None,
                 meta: Optional[Dict] = None,
                 max_iters: Optional[int] = None,
                 max_epochs: Optional[int] = None) -> None:
        """Validate the arguments and initialise the runner state.

        Args:
            model: The model to run. Without a ``batch_processor`` it must
                have a ``train_step`` attribute.
            batch_processor: Deprecated callable. It may not be combined with a
                model (or wrapped ``model.module``) defining ``train_step`` or
                ``val_step``.
            optimizer: A ``torch.optim.Optimizer``, a dict whose values are
                all ``Optimizer`` instances, or None.
            work_dir: Output directory, stored as an absolute path and created
                if missing; or None.
            logger: A ``logging.Logger``. Despite the ``None`` default, any
                non-Logger value (including None) raises ``TypeError``.
            meta: A dict or None, stored as ``self.meta``.
            max_iters: Iteration budget; mutually exclusive with ``max_epochs``.
            max_epochs: Epoch budget; mutually exclusive with ``max_iters``.

        Raises:
            TypeError: ``batch_processor`` is not callable, or ``optimizer``,
                ``logger``, ``meta`` or ``work_dir`` has an unsupported type.
            RuntimeError: ``batch_processor`` is given while the model defines
                ``train_step()``/``val_step()``.
            AssertionError: no ``batch_processor`` and the model has no
                ``train_step`` attribute.
            ValueError: both ``max_epochs`` and ``max_iters`` are set.
        """
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn(
                'batch_processor is deprecated, please implement '
                'train_step() and val_step() in the model instead.',
                DeprecationWarning)
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        elif not hasattr(model, 'train_step'):
            # Explicit raise instead of `assert`: asserts are stripped under
            # `python -O`, which would silently skip this validation. The
            # exception type stays AssertionError for backward compatibility.
            raise AssertionError(
                'model must implement train_step() when batch_processor '
                'is None')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger` (None is rejected here as well)
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta
        # create work_dir
        if isinstance(work_dir, str):
            self.work_dir: Optional[str] = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class (unwrap `.module` if present)
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode: Optional[str] = None
        self._hooks: List[Hook] = []
        # progress counters start at zero
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值