YOLOv11小白的进击之路(四)从model.py看代码运行逻辑...

话接上文,YOLOv11小白的进击之路(三)从YOLO类-DetectionTrainer类出发看YOLO代码运行逻辑...-优快云博客,继续来探索YOLOv11的代码到底是怎么运行起来的~

在我们自己写的train.py中,会指定模型配置文件的路径,例如

 model = YOLO(model='ultralytics/cfg/models/11/my.yaml')

以及,通过关键词参数定制训练过程,例如

model.train(data='./data.yaml',epochs=3,batch=-1,imgsz=640)   

那一个是YOLO类,另一个是model.train函数,可以作为我们探寻YOLO运行逻辑的两个起点。上一篇文章中介绍了前者,这次我们主要来分析一下model.train函数。

model.train函数的输入输出

按住Ctrl键+鼠标左键可以跳转到model.train,发现其位于ultralytics-main/ultralytics/engine/model.py处,model.py的Model类里的一个函数。

def train(self, trainer=None, **kwargs):

先看函数的输入和输出,输入包括:

  • self: 一个特殊的参数,通常用于类的方法中,表示该方法所属的对象实例。这里代表当前YOLO模型。

  • trainer=None: 自定义的训练器实例(可选),其默认值为 None

  • **kwargs: 任意关键字参数,允许我们传入训练配置,如数据路径、批量大小、训练轮次等。这里接收的关键字参数(关键字参数是以“键=值”形式传递的参数)是任意数量的。

输出为self.metrics,返回一个包含训练指标的字典(如损失、精确度等),如果训练失败则返回 None

return self.metrics

模型检查

再来看具体实现的部分,前面是一些检查的代码,我们略过,注释见如下代码

self._check_is_pytorch_model()#检查模型是否为PyTorch模型
        # 检查Ultralytics HUB会话
        if hasattr(self.session, "model") and self.session.model.id:  # 如果当前存在一个已经加载的模型的 HUB 会话
            if any(kwargs): # 且用户提供了本地参数
                LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
            kwargs = self.session.train_args  # 发出警告并使用 HUB 训练参数,覆盖本地提供的参数
        # 检查Pip更新
        checks.check_pip_update_available()

读取配置文件

接下来是很重要的部分,读取配置文件没错,我们进击之路第一篇的yaml文件就从这里输入到模型的YOLOv11小白的进击之路(一)从yaml文件说起...-优快云博客

overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
  • 输入kwargs["cfg"](我们输入的配置文件路径,当然,前提是路径存在)。
  • 输出: 返回一个包含配置的字典(如果有配置文件路径)或者使用当前的进行覆盖配置。
  • 作用: 如果我们提供了配置文件路径,则加载该文件并检查它的有效性;否则,使用当前的self.overrides覆盖配置。

构建训练参数,并合并配置参数

 custom = {  
            # NOTE: handle the case when 'cfg' includes 'data'.
            "data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task],
            "model": self.overrides["model"],
            "task": self.task,
        }  # method defaults
        args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right

值得注意的是,在合并配置参数时,输出的是一个合并后的字典 args,其包含所有参数,优先级按输入顺序逆序(也就是右边的优先级最高)。非常合理哈哈哈

处理恢复训练

if args.get("resume"): args["resume"] = self.ckpt_path
  • 输入args(合并后的参数字典)。
  • 输出args 中的 "resume" 按检查点路径更新。
  • 作用: 如果我们希望从某个检查点恢复训练,设置恢复路径(对于租用算力平台的uu非常友好的一个功能,不会出现训练到一半没钱了导致重头来过的问题)

初始化训练器,并设置模型

self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
if not args.get("resume"):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model

在初始化训练器后,检查一下这次训练是不是上次没训练完的(也就是,仅在不是恢复训练的情况下,才手动设置训练模型)这样可以确保模型初始化正确。

附加HUB会话,并开始训练

self.trainer.hub_session = self.session  # attach optional HUB session
self.trainer.train()

从这里开始,我们的模型真正开始了训练过程~

更新模型和配置,并返回训练指标

if RANK in {-1, 0}:  
    ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last  
    self.model, _ = attempt_load_one_weight(ckpt)  
    self.overrides = self.model.args  
    self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
return self.metrics  

这里我们留意RANK,它就是一个常量,通常用于标识当前进程的编号,在多进程训练中常用。这一步确保只有主进程会执行接下来的操作,这是为了避免多个进程同时对模型进行更新,从而导致潜在的错误或冲突。

  • 输入self.trainer(检索最佳或最后的检查点),以及模型和覆盖配置。
  • 输出: 更新后的模型和超参数。
  • 作用: 训练完成后,更新模型和覆盖信息,同时提取评估指标,最后返回训练指标。

至此,对YOLOv11代码运行的过程初步分析完毕,还比较浅显;后续会慢慢更新丰富,更进一步对YOLO的运行逻辑进行探索。

文章持续更新中...

最后

欢迎一起交流探讨 ~ 砥砺奋进,共赴山海!
文章推荐

YOLOv11小白的进击之路(五)BaseModel类下_predict_once函数源码分析-优快云博客

YOLOv11小白的进击之路(六)创新YOLO的iou及损失函数时的源码分析-优快云博客

YOLOv11小白的进击之路(七)训练输出日志解读以及训练OOM报错解决办法-优快云博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值