话接上文,YOLOv11小白的进击之路(三)从YOLO类-DetectionTrainer类出发看YOLO代码运行逻辑...-优快云博客,继续来探索YOLOv11的代码到底是怎么运行起来的~
在我们自己写的train.py中,会指定模型配置文件的路径,例如
model = YOLO(model='ultralytics/cfg/models/11/my.yaml')
以及,通过关键词参数定制训练过程,例如
model.train(data='./data.yaml',epochs=3,batch=-1,imgsz=640)
那一个是YOLO类,另一个是model.train函数,可以作为我们探寻YOLO运行逻辑的两个起点。上一篇文章中介绍了前者,这次我们主要来分析一下model.train函数。
model.train函数的输入输出
按住Ctrl键+鼠标左键可以跳转到model.train,发现其位于ultralytics-main/ultralytics/engine/model.py处,model.py的Model类里的一个函数。
def train(self, trainer=None, **kwargs):
先看函数的输入和输出,输入包括:
-
self
: 一个特殊的参数,通常用于类的方法中,表示该方法所属的对象实例。这里代表当前YOLO模型。 -
trainer=None
: 自定义的训练器实例(可选),其默认值为None
。 -
**kwargs
: 任意关键字参数,允许我们传入训练配置,如数据路径、批量大小、训练轮次等。这里接收的关键字参数(关键字参数是以“键=值”形式传递的参数)是任意数量的。
输出为self.metrics,返回一个包含训练指标的字典(如损失、精确度等),如果训练失败则返回 None
。
return self.metrics
模型检查
再来看具体实现的部分,前面是一些检查的代码,我们略过,注释见如下代码
self._check_is_pytorch_model()#检查模型是否为PyTorch模型
# 检查Ultralytics HUB会话
if hasattr(self.session, "model") and self.session.model.id: # 如果当前存在一个已经加载的模型的 HUB 会话
if any(kwargs): # 且用户提供了本地参数
LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
kwargs = self.session.train_args # 发出警告并使用 HUB 训练参数,覆盖本地提供的参数
# 检查Pip更新
checks.check_pip_update_available()
读取配置文件
接下来是很重要的部分,读取配置文件(没错,我们进击之路第一篇的yaml文件就从这里输入到模型的)YOLOv11小白的进击之路(一)从yaml文件说起...-优快云博客
overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
- 输入:
kwargs["cfg"]
(我们输入的配置文件路径,当然,前提是路径存在)。 - 输出: 返回一个包含配置的字典(如果有配置文件路径)或者使用当前的进行覆盖配置。
- 作用: 如果我们提供了配置文件路径,则加载该文件并检查它的有效性;否则,使用当前的self.overrides覆盖配置。
构建训练参数,并合并配置参数
custom = {
# NOTE: handle the case when 'cfg' includes 'data'.
"data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task],
"model": self.overrides["model"],
"task": self.task,
} # method defaults
args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
值得注意的是,在合并配置参数时,输出的是一个合并后的字典 args
,其包含所有参数,优先级按输入顺序逆序(也就是右边的优先级最高)。非常合理哈哈哈
处理恢复训练
if args.get("resume"): args["resume"] = self.ckpt_path
- 输入:
args
(合并后的参数字典)。 - 输出:
args
中的"resume"
按检查点路径更新。 - 作用: 如果我们希望从某个检查点恢复训练,设置恢复路径(对于租用算力平台的uu非常友好的一个功能,不会出现训练到一半没钱了导致重头来过的问题)
初始化训练器,并设置模型
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
if not args.get("resume"): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
在初始化训练器后,检查一下这次训练是不是上次没训练完的(也就是,仅在不是恢复训练的情况下,才手动设置训练模型)这样可以确保模型初始化正确。
附加HUB会话,并开始训练
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train()
从这里开始,我们的模型真正开始了训练过程~
更新模型和配置,并返回训练指标
if RANK in {-1, 0}:
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
self.model, _ = attempt_load_one_weight(ckpt)
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
return self.metrics
这里我们留意RANK,它就是一个常量,通常用于标识当前进程的编号,在多进程训练中常用。这一步确保只有主进程会执行接下来的操作,这是为了避免多个进程同时对模型进行更新,从而导致潜在的错误或冲突。
- 输入:
self.trainer
(检索最佳或最后的检查点),以及模型和覆盖配置。 - 输出: 更新后的模型和超参数。
- 作用: 训练完成后,更新模型和覆盖信息,同时提取评估指标,最后返回训练指标。
至此,对YOLOv11代码运行的过程初步分析完毕,还比较浅显;后续会慢慢更新丰富,更进一步对YOLO的运行逻辑进行探索。
文章持续更新中...
最后
欢迎一起交流探讨 ~ 砥砺奋进,共赴山海!
文章推荐
YOLOv11小白的进击之路(五)BaseModel类下_predict_once函数源码分析-优快云博客