detectron2整体构建
1. 各文件夹功能
- configs:各种网络的 yaml 配置文件
- datasets:存放数据集
- detectron2:核心组件
- tools:代码运行入口 及 可视化代码
- tests:一些测试代码
- projects:真实项目代码示例,自己的代码结构可以参照这些projects
2. 代码逻辑
config
detectron2/config文件夹
defaults.py: 定义了参数默认值
config.py:定义了一个 CfgNode 类,还有一个 get_cfg() 方法,该方法会返回一个包含 defaults.py 中默认配置的 CfgNode
2.1 train_net.py
(1)setup(args)
通过 merge_from_file 将 args.config_file 的参数覆盖 defaults.py 的默认配置;通过 merge_from_list 将命令行的配置参数覆盖defaults的超参。default_setup 进行了logger设置、环境/命令行/config的log记录、输出目的中config文件的记录,定义在detectron2/engine/default.py中。
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg() # 获取已经配置好默认参数的 cfg
cfg.merge_from_file(args.config_file) # args.config_file 是指定的 yaml 配置文件,通过 merge_from_file 方法,将默认配置中的超参用 yaml 文件中指定的超参进行覆盖
cfg.merge_from_list(args.opts) # merge_from_list,将命令行指定的超参进行覆盖
cfg.freeze() # cfg文件冻结
default_setup(cfg, args) # 开始时的一些基本配置,包括:1. 设置detectron2的logger;2. log关于环境、命令行参数、config信息;3. 将config信息backup到输出目录
return cfg
(2)class Trainer(DefaultTrainer)
DefaultTrainer 定义在 detectron2/engine/default.py 中, DefaultTrainer又继承了来自 detectron2/engine/train_loop.py 中的 TrainerBase。
(2.1)detectron2/engine
先来分析下detectron2/engine下的4个python文件:defaults.py,hooks.py,launch.py,train_loop.py。
(2.1.1)train_loop.py
主要实现了两个类 HookBase 和 TrainerBase,SimpleTrainer 继承 TrainerBase 以及 AMPTrainer 继承 SimpleTrainer。
(1)HookBase:Hook的基类,依据需求实现4种方法:before_train,after_train,before_step,after_step。(HookBase中都是pass)
(2)TrainerBase:
register_hook 方法:将用户自定义的hook注册到trainer中,按照注册顺序执行hook(放到一个list,遍历list依次执行)通过weakref.proxy(self)弱引用调用:defore_train 、after_train 、before_step 、after_step 、run_step 方法,该类中run_step未进行定义(raise NotImplementedError)。(run_step方法即训练过程中读数据、训模型、求loss、BP等)。
train方法:有 start_iter 和 max_iter 两个参数,按顺序执行range(start_iter, max_iter)次hook(before_step、run_step、after_step),在循环前后分别执行 before_train 和 after_train。
此外还有 state_dict、load_state_dict方法,可以使得hook checkpointable。
(3)SimpleTrainer:继承TrainerBase,写的一个简单的trainer,适用于大部分通用任务。(假设已经有了model、dataloader、optimizer),在run_step中写了一次 iter 的计算(获得loss,optimize.zero_grad,loss.backward、optimizer.step).
(4)AMPTrainer:继承SimpleTrainer,进行混合精度训练。
(2.1.2)hooks.py
一些hook的定义,都继承了来自 train_loop.py的 HookBase。
(2.1.3)defaults.py
主要就两个类:DefaultTrainer、DefaultPredictor。
DefaultTrainer :继承TrainerBase ,用于简化 standard model training workflow,(1)使用 config 文件定义的 model、optimizer、dataloader创建一个 SimpleTrainer 类(build_model、build_optimizer、build_train_dataloader、build_test_loader),根据 config 文件定义创建一个 LR scheduler (build_lr_scheduler);(2)调用 resume_or_load 加载上一个 checkpoint 或 cfg.MODEL.WEIGHTS;(3)注册一个由config定义的 common hook。
(2.1.4)launch.py
是否进行分布式训练。
【Trainer总结】
train_net.py中的关键是Trainer,Trainer继承了detectron2/engine/defaults.py中的DefaultTrainer,DefaultTrainer又继承了train_loop.py中的TrainerBase,但其中引用的Trainer是SimpleTrainer。detectron2/engine/train_loop.py 定义了 HookBase、TrainerBase、SimpleTrainer。
- HookBase:定义了 Hook 的基类,四种方法(pass)——before_train、after_train、before_step、after_step;
- TrainerBase:基本trainer,假设model、data_loader、optimizer已经有了。(1)在register_hooks中将自定义的hook放到一个list,然后遍历list按序执行,完成trainer弱引用;(2)在before_train、after_train、before_step、after_step中调用相应的hook方法;(3)新增 run_step(NotImplemented);(4)定义了方法 train,参数为(star_iter,max_iter),完成一个训练逻辑;
def before_train(self):
for h in self._hooks:
h.before_train()
def run_step(self):
raise NotImplementedError
try:
self.before_train()
for self.iter in range(start_iter, max_iter):
self.before_step()
self.run_step()
self.after_step()
# self.iter == max_iter can be used by `after_train` to
# tell whether the training successfully finished or failed
# due to exceptions.
self.iter += 1
except Exception:
logger.exception("Exception during training:")
raise
finally:
self.after_train()
- SimpleTrainer(TrainerBase):继承TrainerBase,假设data_loader、model、optimizer已经有了,完成TrainerBase中NotImplemented的run_step方法,即dataloader中一个 iter 数据的 loss计算、optimizer.zero_grad()、loss.backward()、optimizer.step();
def __init__(self, model, data_loader, optimizer):
super().__init__()
model.train()
self.model = model
self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.optimizer = optimizer
def run_step(self):
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
loss_dict = self.model(data)
if isinstance(loss_dict, torch.Tensor