【detectron2】整体构建

本文详细介绍了Detectron2的整体构建,包括各文件夹功能、代码逻辑和核心组件。从train_net.py的setup过程到Trainer类的实现,分析了DefaultTrainer的构建模型、数据加载器、优化器和学习率调度器的方法。同时,讲解了如何注册和使用通用数据集,以及数据加载器的工作原理。内容涵盖了Detectron2在深度学习、计算机视觉、目标检测和人工智能领域的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. 各文件夹功能

  • configs:各种网络的 yaml 配置文件
  • datasets:存放数据集
  • detectron2:核心组件
  • tools:代码运行入口 及 可视化代码
  • tests:一些测试代码
  • projects:真实项目代码示例,自己的代码结构可以参照这些projects

2. 代码逻辑

config
detectron2/config文件夹
defaults.py: 定义了参数默认值
config.py:定义了一个 CfgNode 类,还有一个 get_cfg() 方法,该方法会返回一个包含 defaults.py 中默认配置的 CfgNode

2.1 train_net.py

(1)setup(args)

通过 merge_from_file 将 args.config_file 的参数覆盖 defaults.py 的默认配置;通过 merge_from_list 将命令行的配置参数覆盖defaults的超参。default_setup 进行了logger设置、环境/命令行/config的log记录、输出目的中config文件的记录,定义在detectron2/engine/default.py中。

def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()		# 获取已经配置好默认参数的 cfg
    cfg.merge_from_file(args.config_file)	# args.config_file 是指定的 yaml 配置文件,通过 merge_from_file 方法,将默认配置中的超参用 yaml 文件中指定的超参进行覆盖
    cfg.merge_from_list(args.opts) # merge_from_list,将命令行指定的超参进行覆盖
    cfg.freeze()		# cfg文件冻结
    default_setup(cfg, args)	# 开始时的一些基本配置,包括:1. 设置detectron2的logger;2. log关于环境、命令行参数、config信息;3. 将config信息backup到输出目录
    return cfg
(2)class Trainer(DefaultTrainer)

DefaultTrainer 定义在 detectron2/engine/default.py 中, DefaultTrainer又继承了来自 detectron2/engine/train_loop.py 中的 TrainerBase


(2.1)detectron2/engine

先来分析下detectron2/engine下的4个python文件:defaults.py,hooks.py,launch.py,train_loop.py。

(2.1.1)train_loop.py

主要实现了两个类 HookBase 和 TrainerBase,SimpleTrainer 继承 TrainerBase 以及 AMPTrainer 继承 SimpleTrainer。
(1)HookBase:Hook的基类,依据需求实现4种方法:before_train,after_train,before_step,after_step。(HookBase中都是pass)
(2)TrainerBase
register_hook 方法:将用户自定义的hook注册到trainer中,按照注册顺序执行hook(放到一个list,遍历list依次执行)通过weakref.proxy(self)弱引用调用:defore_trainafter_trainbefore_stepafter_steprun_step 方法,该类中run_step未进行定义(raise NotImplementedError)。(run_step方法即训练过程中读数据、训模型、求loss、BP等)。
train方法:有 start_iter 和 max_iter 两个参数,按顺序执行range(start_iter, max_iter)次hook(before_step、run_step、after_step),在循环前后分别执行 before_train 和 after_train。
此外还有 state_dict、load_state_dict方法,可以使得hook checkpointable。
(3)SimpleTrainer:继承TrainerBase,写的一个简单的trainer,适用于大部分通用任务。(假设已经有了model、dataloader、optimizer),在run_step中写了一次 iter 的计算(获得loss,optimize.zero_grad,loss.backward、optimizer.step).
(4)AMPTrainer:继承SimpleTrainer,进行混合精度训练。

(2.1.2)hooks.py

一些hook的定义,都继承了来自 train_loop.py的 HookBase。

(2.1.3)defaults.py

主要就两个类:DefaultTrainerDefaultPredictor
DefaultTrainer :继承TrainerBase ,用于简化 standard model training workflow,(1)使用 config 文件定义的 model、optimizer、dataloader创建一个 SimpleTrainer 类(build_model、build_optimizer、build_train_dataloader、build_test_loader),根据 config 文件定义创建一个 LR scheduler (build_lr_scheduler);(2)调用 resume_or_load 加载上一个 checkpoint 或 cfg.MODEL.WEIGHTS;(3)注册一个由config定义的 common hook。

(2.1.4)launch.py

是否进行分布式训练。

【Trainer总结】

train_net.py中的关键是Trainer,Trainer继承了detectron2/engine/defaults.py中的DefaultTrainer,DefaultTrainer又继承了train_loop.py中的TrainerBase,但其中引用的Trainer是SimpleTrainer。detectron2/engine/train_loop.py 定义了 HookBase、TrainerBase、SimpleTrainer。

  • HookBase:定义了 Hook 的基类,四种方法(pass)——before_train、after_train、before_step、after_step
  • TrainerBase:基本trainer,假设model、data_loader、optimizer已经有了。(1)在register_hooks中将自定义的hook放到一个list,然后遍历list按序执行,完成trainer弱引用;(2)在before_train、after_train、before_step、after_step中调用相应的hook方法;(3)新增 run_step(NotImplemented);(4)定义了方法 train,参数为(star_iter,max_iter),完成一个训练逻辑;
    def before_train(self):
        for h in self._hooks:
            h.before_train()
    def run_step(self):
        raise NotImplementedError
try:
    self.before_train()
    for self.iter in range(start_iter, max_iter):
        self.before_step()
        self.run_step()
        self.after_step()
    # self.iter == max_iter can be used by `after_train` to
                # tell whether the training successfully finished or failed
                # due to exceptions.
     self.iter += 1
except Exception:
    logger.exception("Exception during training:")
    raise
finally:
    self.after_train()
  • SimpleTrainer(TrainerBase)继承TrainerBase,假设data_loader、model、optimizer已经有了,完成TrainerBase中NotImplemented的run_step方法,即dataloader中一个 iter 数据的 loss计算、optimizer.zero_grad()、loss.backward()、optimizer.step()
    def __init__(self, model, data_loader, optimizer):
    	super().__init__()
   
        model.train()
        
        self.model = model
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer    
    def run_step(self):
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        loss_dict = self.model(data)
        if isinstance(loss_dict, torch.Tensor
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值