解决在训练完的基于mmcv框架的模型基础上继续训练时初始lr不正确的问题

        假设基于mmcv框架的模型按照设置的max_iteration或max_epoch训练完了,然后我想在它的权重基础上做resume训练, 设置更大的max_iteration或max_epoch后启动训练你会发现训练可以跑起来,但是初始的lr还是从0开始训练时设置的比较大的值,而不是上次训练完时的lr值,这会导致新启动的训练的Loss经常总是在那里震荡下不去做训练多了N轮都基本是无用功,这可以理解,loss已经不大了,还使用比较大的lr就容易发生震荡。

       那怎么解决再次启动训练时初始lr值不正确的问题呢?需要解决mmcv代码的一个缺陷。

mmcv/runner/iter_based_runner.py
    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)
        self.data_batch = data_batch
        self.call_hook('before_train_iter')   ###
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        del self.data_batch
        self._inner_iter += 1
        self._iter += 1

mmcv/runner/base_runner.py

    def call_hook(self, fn_name: str) -> None:
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:          # #1 hook == mmcv.runner.hooks.lr_updater.CosineAnnealingLrUpdaterHook
            getattr(hook, fn_name)(self)  #fn_name == 'before_train_iter'

train开始前,调用self.call_hook('before_train_iter')

call_hook('before_train_iter')调用各个注册了的hook的'before_train_iter'方法,第一个hook就是更新学习率lr的hook,例如mmcv.runner.hooks.lr_updater.CosineAnnealingLrUpdaterHook,它继承自LrUpdaterHook,调用的其实是LrUpdaterHook.before_train_iter()

class LrUpdaterHook(Hook):
    ...

    def _set_lr(self, runner, lr_groups):
        if isinstance(runner.optimizer, dict):
            for k, optim in runner.optimizer.items():
                for param_group, lr in zip(optim.param_groups, lr_groups[k]):
                    param_group['lr'] = lr
        else:
            for param_group, lr in zip(runner.optimizer.param_groups,
                                       lr_groups):
                param_group['lr'] = lr
                
    def get_regular_lr(self, runner: 'runner.BaseRunner'):                        ###
        if isinstance(runner.optimizer, dict):
            lr_groups = {}
            for k in runner.optimizer.keys():
                _lr_group = [
                    self.get_lr(runner, _base_lr)
                    for _base_lr in self.base_lr[k]
                ]
                lr_groups.update({k: _lr_group})

            return lr_groups
        else:
            return [self.get_lr(runner, _base_lr) for _base_lr in self.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Arnold-FY-Chen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值