在mmseg框架下是默认用iters进行迭代训练的,但是可以自行设定使用epoch进行训练。
循环控制器指的是训练, 验证和测试时的执行流程, 在配置文件里面使用 train_cfg, val_cfg 和 test_cfg 来构建这些流程。
例如,使用基于迭代次数的训练循环 (IterBasedTrainLoop) 去训练 80,000 个迭代次数, 并且每 8,000 iteration 做一次验证, 可以如下设置:
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
在这里就可以调整,例如需要使用epoch进行训练,调整配置的类别信息。所有类别都在mmengine/runner/loops.py中。整体文件可以在官网查看:mmengine/mmengine/runner/loops.py at main · open-mmlab/mmengine · GitHub
mmengine/runner/loops.py中的EpochBasedTrainLoop类用于使用epoch进行训练:
@LOOPS.register_module() # 用于实现基于 epoch(轮次)的训练循环 class EpochBasedTrainLoop(BaseLoop): """Loop for epoch-based training. Args: runner (Runner): A reference of runner. dataloader (Dataloader or dict): A dataloader object or a dict to build a dataloader. max_epochs (int): Total training epochs. val_begin (int): The epoch that begins validating. Defaults to 1. val_interval (int): Validation interval. Defaults to 1. dynamic_intervals (List[Tuple[int, int]], optional): The first element in the tuple is a milestone and the second element is a interval. The interval is used after the corresponding milestone. Defaults to None. """ def __init__( self, runner, # Runner 实例的引用,用于管理整个训练过程 dataloader: Union[DataLoader, Dict], # 可以是 DataLoader 对象或包含数据加载参数的字典 max_epochs: int, # 训练的总轮次数 val_begin: int = 1, # 指定开始验证的轮次和验证的频率 val_interval: int = 1, dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: # 个可选的里程碑列表,每个里程碑(元组)包含一个轮次和一个间隔,用于在达到该轮次后更改验证频率 super().__init__(runner, dataloader) self._max_epochs = int(max_epochs) assert self._max_epochs == max_epochs, \ f'`max_epochs` should be a integer number, but get {max_epochs}.' self._max_iters = self._max_epochs * len(self.dataloader) # 计算得到的最大迭代次数 iters = pochs * data批次数量 self._epoch = 0 self._iter = 0 # 用于跟踪当前轮次和迭代次数 self.val_begin = val_begin self.val_interval = val_interval # This attribute will be updated by `EarlyStoppingHook` # when it is enabled. self.stop_training = False # 用于控制提前停止训练的标志,可以通过 EarlyStoppingHook 将其设为 True,实现早停功能。 if hasattr(self.dataloader.dataset, 'metainfo'): self.runner.visualizer.dataset_meta = \ self.dataloader.dataset.metainfo else: print_log( f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' 'metainfo. ``dataset_meta`` in visualizer will be ' 'None.', logger='current', level=logging.WARNING) self.dynamic_milestones, self.dynamic_intervals = \ calc_dynamic_intervals( self.val_interval, dynamic_intervals)
在使用时,就需要更改训练设置的配置,不同的类需要传入的参数不同。
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=200, val_interval=20)