Suppose a model built on the mmcv framework has finished training for its configured max_iteration/max_epoch, and you now want to resume training from those weights: you set a larger max_iteration/max_epoch and relaunch. Training does run, but the initial lr is back at the relatively large value used when training from scratch, rather than the lr the previous run ended on. As a result, the loss of the new run keeps oscillating and refuses to drop, and training for N more rounds is basically wasted effort. This is easy to understand: the loss is already small, so continuing with a large lr naturally causes oscillation.
So how do we fix the incorrect initial lr when relaunching training? We need to work around a flaw in the mmcv code.
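To make the scenario concrete, here is a minimal sketch of the kind of config the resumed run might use. The paths, numbers, and field names are illustrative (mmdet/mmseg-style config fields are assumed), not taken from any particular project:

# hypothetical config fragment for the resumed run (values are illustrative)
runner = dict(type='IterBasedRunner', max_iters=80000)    # first run used max_iters=40000
lr_config = dict(policy='CosineAnnealing', min_lr=1e-4)   # same schedule as the first run
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4)
resume_from = './work_dirs/my_exp/iter_40000.pth'         # checkpoint saved at the end of the first run

When this run starts, the log shows the lr back at a large value instead of continuing from the small value reached at iter 40000. To see why, follow the call chain below.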
mmcv/runner/iter_based_runner.py
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._epoch = data_loader.epoch
    data_batch = next(data_loader)
    self.data_batch = data_batch
    self.call_hook('before_train_iter')  ###
    outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        raise TypeError('model.train_step() must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs
    self.call_hook('after_train_iter')
    del self.data_batch
    self._inner_iter += 1
    self._iter += 1
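As background for the resume case: when resume_from is used, BaseRunner.resume() restores the counters and the optimizer state from the checkpoint, roughly like this (a paraphrased sketch from memory, not a verbatim quote of mmcv):

# paraphrased sketch of BaseRunner.resume(); error handling and map_location omitted
def resume(self, checkpoint, resume_optimizer=True):
    checkpoint = self.load_checkpoint(checkpoint)
    self._epoch = checkpoint['meta']['epoch']   # epoch/iteration counters come back
    self._iter = checkpoint['meta']['iter']
    if 'optimizer' in checkpoint and resume_optimizer:
        self.optimizer.load_state_dict(checkpoint['optimizer'])

So right after resuming, the optimizer still carries the small lr from the end of the first run, but that value gets overwritten on every iteration by the lr hook, which is exactly what call_hook('before_train_iter') triggers.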
mmcv/runner/base_runner.py
def call_hook(self, fn_name: str) -> None:
    """Call all hooks.

    Args:
        fn_name (str): The function name in each hook to be called, such as
            "before_train_epoch".
    """
    for hook in self._hooks:  # #1 hook == mmcv.runner.hooks.lr_updater.CosineAnnealingLrUpdaterHook
        getattr(hook, fn_name)(self)  # fn_name == 'before_train_iter'
Before each training step, train() calls self.call_hook('before_train_iter').
call_hook('before_train_iter') invokes the 'before_train_iter' method of every registered hook. The first hook is the one that updates the learning rate, for example mmcv.runner.hooks.lr_updater.CosineAnnealingLrUpdaterHook. It inherits from LrUpdaterHook, so what actually runs is LrUpdaterHook.before_train_iter(); the toy sketch below reduces this dispatch to a few runnable lines.
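To see the dispatch pattern in isolation, here is a self-contained toy version. It is not mmcv code: a linear-decay hook stands in for CosineAnnealingLrUpdaterHook, and a stripped-down runner plays the role of IterBasedRunner:

class Hook:
    def before_train_iter(self, runner):
        pass

class ToyLrUpdaterHook(Hook):
    def __init__(self, base_lr):
        self.base_lr = base_lr

    def before_train_iter(self, runner):
        # recompute lr from base_lr and the runner's iteration counter,
        # then write it into the optimizer's param_groups
        lr = self.base_lr * (1 - runner.iter / runner.max_iters)
        for group in runner.param_groups:
            group['lr'] = lr

class ToyRunner:
    def __init__(self, max_iters):
        self.iter = 0
        self.max_iters = max_iters
        self.param_groups = [{'lr': 0.0}]      # stands in for optimizer.param_groups
        self._hooks = [ToyLrUpdaterHook(base_lr=0.01)]

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def train(self):
        while self.iter < self.max_iters:
            self.call_hook('before_train_iter')   # lr is set before each step
            # ... model.train_step(...) would run here ...
            self.iter += 1

runner = ToyRunner(max_iters=10)
runner.train()
print(runner.param_groups[0]['lr'])   # lr written by the hook at the last iteration

Note that the hook never reads the lr currently stored in param_groups; it always rebuilds the value from base_lr and the iteration counter. With that in mind, here are the relevant parts of the real LrUpdaterHook: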
class LrUpdaterHook(Hook):
    ...

    def _set_lr(self, runner, lr_groups):
        if isinstance(runner.optimizer, dict):
            for k, optim in runner.optimizer.items():
                for param_group, lr in zip(optim.param_groups, lr_groups[k]):
                    param_group['lr'] = lr
        else:
            for param_group, lr in zip(runner.optimizer.param_groups,
                                       lr_groups):
                param_group['lr'] = lr

    def get_regular_lr(self, runner: 'runner.BaseRunner'):  ###
        if isinstance(runner.optimizer, dict):
            lr_groups = {}
            for k in runner.optimizer.keys():
                _lr_group = [
                    self.get_lr(runner, _base_lr)
                    for _base_lr in self.base_lr[k]
                ]
                lr_groups.update({k: _lr_group})
            return lr_groups
        else:
            return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
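get_regular_lr() rebuilds every lr from self.base_lr by calling get_lr(), and for an iteration-based cosine schedule get_lr() depends only on runner.iter, runner.max_iters, and base_lr. The following is a simplified, self-contained sketch of that computation, with illustrative numbers matching the hypothetical config above; it is not the author's fix, just a demonstration of the mechanics:

from math import cos, pi

def annealing_cos(start, end, factor):
    # half-cosine interpolation from start (factor=0) to end (factor=1)
    return end + 0.5 * (start - end) * (cos(pi * factor) + 1)

def cosine_lr(base_lr, min_lr, cur_iter, max_iters):
    # what an iteration-based cosine schedule computes on each before_train_iter
    return annealing_cos(base_lr, min_lr, cur_iter / max_iters)

# first run: by iter 40000 of 40000 the lr has annealed all the way down
print(cosine_lr(0.01, 1e-4, 40000, 40000))   # 0.0001
# resumed run with max_iters raised to 80000: at iter 40000 the schedule is
# only half way through, so the recomputed lr jumps back up
print(cosine_lr(0.01, 1e-4, 40000, 80000))   # ~0.005, far above 0.0001

The lr the previous run ended on is never consulted: on every iteration the hook recomputes the value from base_lr and the current position in the schedule, which is why the resumed run starts with a much larger lr than expected.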
