之前了解了视频数据加载和处理的过程。mmaction2-master/mmaction/tools/train.py代码中对应的部分进行分析来看一下视频训练的过程。
1、首先在tools/train.py的main()函数中调用train_model()函数进行训练模型
2、然后train_model()函数中
该函数中最重要的部分为,生成runner,运行runner.run()开始执行训练过程。runner在单gpus训练过程中为EpochBasedRunner的子类。
Runner = OmniSourceRunner if cfg.omnisource else EpochBasedRunner # 采用EpochBasedRunner
runner = Runner( # 将数据模型和优化器全部装载到runner中
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs, **runner_kwargs)
train_model()的其他代码部分为模型运行配置和一些策略的设置。
3、runner.run()函数在其子类EpochBaseRunner中实现
下面给出run()中有关训练的关键代码,其中执行epoch_runnner()函数进行训练
class EpochBasedRunner(BaseRunner):
def run(self,
data_loaders:List[DtaLoarder],
workflow: List[Tuple[str, int]],
max_epochs: Optional[int] = None,
**kwargs):
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
# 执行训练的迭代过程
while self.epoch < self._max_epochs: # 执行一次循环self.epochs会加一
for i, flow in enumerate(workflow): # workflow记录当前训练流信息 状态+epoch信息
mode, epochs = flow # epoch为一个常量