Train.py代码阅读笔记
在这个py文件中,一共包含三个模块,
1、头文件导入模块,import
导入各种各样的库;
2、main文件,程序的主要进入口;
3、train函数模块,用来写训练的主要逻辑;
1 头文件导入模块
导入各种各样的支持包
import argparse # 解析命令行参数模块
import logging # logging模块,用于记录日志
import math # 数学计算模块
import os # 与操作系统交互的模块,包含文件路径操作和解析
import random # 生成随机数模块,用于设置种子
import time # 时间模块,用于记录时间
from copy import deepcopy # 深度拷贝模块,用于复制模型参数
from pathlib import Path # 路径操作模块,将str对象转为path对象,使字符串路径易于操作的模块
from threading import Thread # 多线程模块,用于多线程处理,在代码中使用多线程可以提高处理速度
import numpy as np # numpy模块,用于处理数组
import torch.distributed as dist # 用于分布式训练的模块,用于多机多卡训练
import torch.nn as nn # 神经网络模块,用于构建神经网络
import torch.nn.functional as F # 神经网络函数模块,用于构建神经网络
import torch.optim as optim # 优化器模块,用于优化神经网络
import torch.optim.lr_scheduler as lr_scheduler # 学习率调度器模块,用于调整学习率
import torch.utils.data # 数据集模块,用于加载数据集
import yaml # yaml模块,用于解析yaml文件
from torch.cuda import amp # 半精度训练模块,用于加速训练
from torch.nn.parallel import DistributedDataParallel as DDP # 用于分布式训练的模块,用于多机多卡训练
from torch.utils.tensorboard import SummaryWriter # 用于tensorboard可视化模块
from tqdm import tqdm # 进度条模块,用于显示进度条
import test # import test.py to get mAP after each epoch
from models.experimental import attempt_load # 导入模型模块,用于加载预训练模型
from models.yolo import Model # 导入模型模块,用于构建yolo模型
首先,导入一下常用的python库:
- *argparse:* 它是一个用于命令项选项与参数解析的模块,通过在程序中定义好我们需要的参数,argparse 将会从 sys.argv 中解析出这些参数,并自动生成帮助和使用信息
- *math:* 调用这个库进行数学运算
- *os:* 它提供了多种操作系统的接口。通过os模块提供的操作系统接口,我们可以对操作系统里文件、终端、进程等进行操作
- *random:* 是使用随机数的Python标准库。random库主要用于生成随机数
- *sys:* 它是与python解释器交互的一个接口,该模块提供对解释器使用或维护的一些变量的访问和获取,它提供了许多函数和变量来处理 Python 运行时环境的不同部分
- *time:* Python中处理时间的标准库,是最基础的时间处理库
- *copy:* Python 中赋值语句不复制对象,而是在目标和对象之间创建绑定 (bindings) 关系。copy模块提供了通用的浅层复制和深层复制操作
- *datetime:* 是Python常用的一个库,主要用于时间解析和计算
- *pathlib:* 这个库提供了一种面向对象的方式来与文件系统交互,可以让代码更简洁、更易读
然后再导入一些****pytorch库****:
- *numpy:* 科学计算库,提供了矩阵,线性代数,傅立叶变换等等的解决方案, 最常用的是它的N维数组对象
- *torch:* 这是主要的Pytorch库。它提供了构建、训练和评估神经网络的工具
- *torch.distributed:* torch.distributed包提供Pytorch支持和通信基元,对多进程并行,在一个或多个机器上运行的若干个计算阶段
- *torch.nn:* torch下包含用于搭建神经网络的modules和可用于继承的类的一个子包
- *yaml:* yaml是一种直观的能够被电脑识别的的数据序列化格式,容易被人类阅读,并且容易和脚本语言交互。一般用于存储配置文件
- *torch.cuda.amp:* 自动混合精度训练 —— 节省显存并加快推理速度
- *torch.nn.parallel:* 构建分布式模型,并行加速程度更高,且支持多节点多gpu的硬件拓扑结构
- *torch.optim:* 优化器 Optimizer。主要是在模型训练阶段对模型可学习参数进行更新,常用优化器有 SGD,RMSprop,Adam等
- *tqdm:* 就是我们看到的训练时进度条显示
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path,\
labels_to_image_weights, init_seeds, \
fitness, strip_optimizer, get_latest_run,\
check_dataset, check_file, check_git_status, check_img_size,check_requirements,\
print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download # 下载模型模块,用于下载预训练模型
from utils.loss import ComputeLoss # 损失函数模块,用于计算损失函数
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution # 绘图模块,用于绘制图像
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel # 工具模块,用于处理pytorch相关操作
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume # 导入wandb模块,用于记录日志到wandb平台
logger = logging.getLogger(__name__) # 获取logger对象,用于记录日志
这些都是用户自定义的库,由于上一步已经把路径加载上了,所以现在可以导入,这个顺序不可以调换。具体来说,代码从如下几个文件中导入了部分函数和类:
- *val:* 这个是测试集,我们下一篇再具体讲
- *models.experimental:* 实验性质的代码,包括MixConv2d、跨层权重Sum等
- *models.yolo:* yolo的特定模块,包括BaseModel,DetectionModel,ClassificationModel,parse_model等
- *utils.autoanchor:* 定义了自动生成锚框的方法
- *utils.autobatch:* 定义了自动生成批量大小的方法
- *utils.callbacks:* 定义了回调函数,主要为logger服务
- *utils.datasets:* dateset和dateloader定义代码
- *utils.downloads:* 谷歌云盘内容下载
- *utils.general:* 定义了一些常用的工具函数,比如检查文件是否存在、检查图像大小是否符合要求、打印命令行参数等等
- *utils.loggers :* 日志打印
- *utils.loss:* 存放各种损失函数
- *utils.metrics:* 模型验证指标,包括ap,混淆矩阵等
- *utils.plots.py:* 定义了Annotator类,可以在图像上绘制矩形框和标注信息
- *utils.torch_utils.py:* 定义了一些与PyTorch有关的工具函数,比如选择设备、同步时间等
通过导入这些模块,可以更方便地进行目标检测的相关任务,并且减少了代码的复杂度和冗余。
2 main()函数模块
2.1 main函数入口
通过__main__
开启主函数入口,
if __name__ == '__main__':
2.2 输入参数解析
在main函数中,首先是将外界输入的参数进行解释,
如外界命令为
python train.pu --weights yolov5s.py --cfg yolov5s.yaml --data coco.yaml --epoch 20 --batch-size 16
将参数后面的值解释到对应的参数中去,
代码如下
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
# 这里是权重文件(或者说是预训练模型),如果有权重文件,将基于这个权重文件进行训练,
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
# 模型配置文件
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
# 数据文件,包含训练数据,测试数据,评估数据
parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
# 超参数文件
parser.add_argument('--epochs', type=int, default=300)
# 训练的epochs次数,默认设置为300次
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
# batch-size的数目
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
# 图像尺寸信息
parser.add_argument('--rect', action='store_true', help='rectangular training')
# 使用rectangular训练方式,这个我也并不是很明白,
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
# 是否接着上一次的训练结果
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
# 是否保存checkpoint开关
parser.add_argument('--notest', action='store_true', help='only test final epoch')
# 是否只需要测试最后的epoch
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
# envolve是是否使用遗传算法进行调参
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
# bucket谷歌优盘
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
# 是否提前缓存图片到内存,以加快训练速度,默认为flase
parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
# 是否开启图片采样策略,也就是图像权重,这个有点难理解,但是在后面代码中会有解释
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
# 设备的选择,GPU CPU...
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
# 是否开启多尺度训练模式
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
# 数据集是否多类/默认True
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
# 是否开启adam优化器
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
# sync-bn是否使用跨卡同步BN,在DDP模式中使用
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
# local-rank:本地的线程数,rank=-1单卡,rank=0主卡,rank=1,2,3... 在第rank个节点上
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
# workers: 线程数
parser.add_argument('--project', default='runs/train', help='save to project/name')
# 项目的输出地址
parser.add_argument('--entity', default=None, help='W&B entity')
# 在线可视化工具,类似于Tensorboard
parser.add_argument('--name', default='exp', help='save to project/name')
# 项目结果的输出名称,输出为exp文件夹
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
# 判断项目位置是否存在,
parser.add_argument('--quad', action='store_true', help='quad dataloader')
# quad:四元数据加载器,允许在较低的--imag尺寸下进行更高--img尺寸训练的一些好处
parser.add_argument('--linear-lr', action='store_true', help='linear LR')
# 余弦学习率,线性学习率
parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
# 标签平滑/默认不增强,用户可以根据自己标签的实际情况设置这个参数
parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
# 是否上传数据集到wandb table
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
# 设置界框图像记录间隔
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
# 多少个epoch保存一下checkpoint
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
# 使用数据的版本
opt = parser.parse_args()
# 作用就是当京获取到基本设置时,如果运行命令传入了其他的配置,就替换,否之,保持默认选项
2.2.1 rectangular training
2.2.2 check-point
2.2.3 evolve
2.2.4 image-weights
2.2.5 multi-scale
2.2.6 local-rank
2.2.7 workers
2.2.8 quad
2.2.9 label_smoothing
2.2.10 box-interval
2.3 参数配置
2.3.1 world_size与global_rank本地参数的获取
# Set DDP variables
opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
# 这段代码的作用是从环境变量中获取“WORD_SIZE”这个变量的值,并将这个值转换为整数类型赋值给opt.world,
# 如果没有这个变量,则默认为1。
# 这段代码通常用于多进程/多节点分布式训练,用于指定训练的进程/节点数量
# WORLD_SIZE在不同进程中是唯一的
# WORLD_SIZE由torch.distributed.launch.py产生,具体数值为 nproc_per_node*node(主机数,这里为1),
# 这里的 opt.world_size指进程总数,在这里就是我们使用的卡数
代码的获取作用是,从环境变量中获取“WORD_SIZE”这个变量的值,并将这个值转换为整数类型赋值给opt.world
。
opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
'''
rank指进程序号,local_rank指本地序号,两者的区别在于前者用于进程间通讯,后者用于本地设备分配,
这里的local_rank是由torch.distributed.launch.py产生的,具体数值为local_rank*nproc_per_node+global_rank
在分布式训练和多 GPU 训练的上下文中,进程编号(rank)通常用于标识不同的训练进程。
进程编号为 -1 通常表示单机单卡模式,即训练在单个 GPU 上进行,没有分布式训练。以下是一些常见的进程编号含义:
rank = -1:表示单机单卡模式,训练在单个 GPU 上进行。
rank = 0:表示分布式训练中的主节点(master node),通常负责协调和管理其他节点。
rank = 1, 2, 3, ...:表示分布式训练中的其他节点(worker nodes),每个节点负责一部分训练任务。
在代码中,进程编号通常用于条件判断,以决定是否启用某些功能,例如数据并行、同步批归一化等。
'''
2.3.2 日志目录的配置
set_logging(opt.global_rank) # 设置日志级别
if opt.global_rank in [-1, 0]:
check_git_status() # 检查你的代码版本是否为最新的(不适用于windows系统),如果不是最新的,会给出警告提示
check_requirements() # 检查当前环境是否满足了依赖库
2.3.3 Resume与wandb配置
# Resume
# Wandb(Weights & Biases)是一个用于机器学习实验跟踪和可视化的工具和平台。
# 它可以帮助你跟踪你的模型的训练过程,记录超参数、指标、模型图、数据集、图像、日志、检查点等。
# 它还可以帮助你分享你的模型、结果、分析你的模型,并与他人协作。
# 这里的wandb_run变量是wandb_logger.wandb_run,它是一个wandb.run实例,用于记录当前训练的状态。
# 如果opt.resume为True,则会检查是否有wandb_run,如果有,则从wandb_run中恢复训练,否则会从opt.resume中恢复训练
wandb_run = check_wandb_resume(opt) # 检查是否有wandb_run,如果有,则从wandb_run中恢复训练
if opt.resume and not wandb_run: # 如果opt.resume为True,并且wandb_run不存在,则从opt.resume中恢复训练
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # ckpt是恢复的模型路径
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' # 断言ckpt文件是否存在
apriori = opt.global_rank, opt.local_rank # 保存apriori信息,主要是多卡训练时,需要保存apriori信息
with open(Path(ckpt).parent.parent / 'opt.yaml') as f: # 加载opt.yaml文件,opt是训练参数
opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace
opt.cfg, opt.weights, opt.resume,opt.batch_size,\
opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # 更新opt参数
logger.info('Resuming training from %s' % ckpt) # 打印恢复信息
else:
# opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
# opt.data是一个yaml文件,里面包含了训练数据集的路径、类别数、以及训练集、验证集、测试集的比例
# opt.hyp是一个yaml文件,里面包含了训练超参数,如学习率、优化器、学习率衰减策略、正则化策略等
opt.data, opt.cfg, opt.hyp = check_file(opt.data),check_file(opt.cfg), check_file(opt.hyp) # check files
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' # # 如果模型文件和权重文件为空,弹出警告
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # 扩展img_size
opt.name = 'evolve' if opt.evolve else opt.name
opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
这段代码的主要功能是处理训练过程中的恢复逻辑,包括从 WandB 恢复和从本地检查点恢复。以下是对代码的详细解释:
-
检查 WandB 恢复:
wandb_run = check_wandb_resume(opt)
这行代码调用
check_wandb_resume
函数,检查是否存在 WandB 运行记录(wandb_run
)。如果存在,则可以从 WandB 恢复训练。 -
从本地检查点恢复:
if opt.resume and not wandb_run: # 如果opt.resume为True,并且wandb_run不存在,则从opt.resume中恢复训练 ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # ckpt是恢复的模型路径 assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' # 断言ckpt文件是否存在 apriori = opt.global_rank, opt.local_rank # 保存apriori信息,主要是多卡训练时,需要保存apriori信息 with open(Path(ckpt).parent.parent / 'opt.yaml') as f: # 加载opt.yaml文件,opt是训练参数 opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace opt.cfg, opt.weights, opt.resume,opt.batch_size,\ opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # 更新opt参数 logger.info('Resuming training from %s' % ckpt) # 打印恢复信息
这段代码处理从本地检查点恢复训练的逻辑:
- 如果
opt.resume
为True
且wandb_run
不存在,则从opt.resume
指定的路径恢复训练。 ckpt
是恢复的模型路径,如果opt.resume
是字符串,则直接使用;否则调用get_latest_run
获取最新运行记录。- 断言
ckpt
文件存在,确保恢复路径正确。 - 保存
apriori
信息,包括global_rank
和local_rank
,这些信息在多卡训练时需要。 - 从
opt.yaml
文件中加载训练参数,并更新opt
对象。 - 更新
opt
对象的其他参数,如cfg
、weights
、resume
、batch_size
、global_rank
和local_rank
。 - 打印恢复信息。
- 如果
-
初始化新训练:
else: # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml') # opt.data是一个yaml文件,里面包含了训练数据集的路径、类别数、以及训练集、验证集、测试集的比例 # opt.hyp是一个yaml文件,里面包含了训练超参数,如学习率、优化器、学习率衰减策略、正则化策略等 opt.data, opt.cfg, opt.hyp = check_file(opt.data),check_file(opt.cfg), check_file(opt.hyp) # check files assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' # # 如果模型文件和权重文件为空,弹出警告 opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # 扩展img_size opt.name = 'evolve' if opt.evolve else opt.name opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
这段代码处理初始化新训练的逻辑:
- 检查并加载
opt.data
、opt.cfg
和opt.hyp
文件。 - 断言
opt.cfg
或opt.weights
必须指定,确保模型文件或权重文件存在。 - 扩展
opt.img_size
,确保其长度为 2。 - 根据
opt.evolve
设置opt.name
。 - 调用
increment_path
函数,生成新的保存路径。
- 检查并加载
总结:
- 这段代码首先检查是否可以从 WandB 恢复训练。
- 如果不能从 WandB 恢复,则检查是否可以从本地检查点(checkpoint)恢复训练。
- 如果既不能从 WandB 恢复也不能从本地检查点恢复,则初始化新训练。
2.3.3.1 wandb
2.3.3.2 本地checkpoint
3.3.5 DDP mode
# DDP mode
opt.total_batch_size = opt.batch_size # batch_size is per GPU
device = select_device(opt.device, batch_size=opt.batch_size) # 选择GPU设备
if opt.local_rank != -1: # 如果rank不是-1,即不是单卡模式,则进入DDP模式
assert torch.cuda.device_count() > opt.local_rank # 断言,GPU数量大于local_rank
# torch.cuda.set_device函数用于设置当前使用的cuda设备,
# 在当拥有多个可用的GPU且能被pytorch识别的cuda设备情况下(环境变量CUDA_VISIBLE_DEVICES可以影响GPU设备到cuda设备的映射)。
# 由于有些情况下,可以不显式写出cuda设备的编号,此时指的是当前使用的cuda设备,默认为cuda0设备,如下所示
torch.cuda.set_device(opt.local_rank) # 设置
# torch.device函数用于创建设备对象,其参数为cuda设备的编号,返回一个设备对象。
device = torch.device('cuda', opt.local_rank) # 设置当前GPU
# torch.distributed.init_process_group函数用于初始化分布式训练,
# 其参数为backend、init_method、world_size、rank、group_name等,
# 其中backend为分布式后端,init_method为初始化方法,world_size为进程总数,rank为当前进程号,group_name为分布式组名。
# 这里的backend为nccl,init_method为env://,即通过环境变量初始化,world_size为opt.world_size,
# rank为opt.global_rank,group_name为None。
# 这里的init_method为env://,即通过环境变量初始化,这意味着每个进程都需要知道其他进程的IP地址和端口号,
# 因此需要在每个进程的命令行中设置CUDA_VISIBLE_DEVICES、MASTER_ADDR、MASTER_PORT等环境变量。
# 这里的init_method也可以设置为tcp://IP:PORT,即通过IP地址和端口号初始化,但需要在每个进程的命令行中设置相同的IP地址和端口号。
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
# 断言,batch_size必须是cuda设备数量的整数倍
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
# opt.batch_size是每个GPU的batch_size,这里的batch_size是所有GPU的batch_size,因此需要除以world_size
opt.batch_size = opt.total_batch_size // opt.world_size
这段代码主要用于在分布式训练环境下初始化和设置GPU设备,确保批量大小在各个GPU上均匀分布,并初始化分布式训练环境。
这段代码是用于在分布式数据并行(DDP)模式下设置和初始化训练环境的。以下是对代码的详细解释:
-
设置全局批量大小:
opt.total_batch_size = opt.batch_size # batch_size is per GPU
这里将
opt.total_batch_size
设置为opt.batch_size
,表示每个GPU上的批量大小。 -
选择GPU设备:
device = select_device(opt.device, batch_size=opt.batch_size) # 选择GPU设备
调用
select_device
函数选择GPU设备。opt.device
可能是类似'cuda:0'
的字符串,表示使用哪个GPU设备。 -
进入DDP模式:
if opt.local_rank != -1: # 如果rank不是-1,即不是单卡模式,则进入DDP模式
如果
opt.local_rank
不是 -1,表示当前处于分布式训练模式。 -
断言GPU数量大于local_rank:
assert torch.cuda.device_count() > opt.local_rank # 断言,GPU数量大于local_rank
确保系统中GPU的数量大于
opt.local_rank
,以避免超出GPU数量范围的错误。 -
设置当前GPU设备:
torch.cuda.set_device(opt.local_rank) # 设置 device = torch.device('cuda', opt.local_rank) # 设置当前GPU
设置当前使用的GPU设备为
opt.local_rank
对应的GPU。 -
初始化分布式训练环境:
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
使用
nccl
后端和env://
初始化方法初始化分布式训练环境。nccl
是NVIDIA的集合通信库,适用于多GPU环境。 -
断言批量大小是GPU数量的倍数:
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
确保每个GPU上的批量大小是总GPU数量的倍数,以保证数据在各个GPU上均匀分布。
-
计算每个GPU上的批量大小:
opt.batch_size = opt.total_batch_size // opt.world_size
计算每个GPU上的批量大小,即全局批量大小除以GPU数量。
2.3.4 超参数设置
# Hyperparameters
with open(opt.hyp) as f:
hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps
# Train
logger.info(opt)
if not opt.evolve: # 如果不使用遗传算法,则直接训练
tb_writer = None # init loggers
if opt.global_rank in [-1, 0]: # 如果globle_rank是-1或者0,则初始化wandb_logger
prefix = colorstr('tensorboard: ')
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
train(hyp, opt, device, tb_writer) # 进入训练过程
# Evolve hyperparameters (optional) 遗传进化算法,边进化边训练
else:
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
# 超参数(1,1e-5,1e-1)==> mutation scale=1, lower_limit=1e-5, upper_limit=1e-1
meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3),初始化学习率
'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
'box': (1, 0.02, 0.2), # box loss gain
'cls': (1, 0.2, 4.0), # cls loss gain
'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
'iou_t': (0, 0.1, 0.7), # IoU training threshold
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
'translate': (1, 0.0, 0.9),# image translation (+/- fraction)
'scale': (1, 0.0, 0.9), # image scale (+/- gain)
'shear': (1, 0.0, 10.0), # image shear (+/- deg)
'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
'mosaic': (1, 0.0, 1.0), # image mixup (probability)
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve' # DDP模式下不能执行evolve
opt.notest, opt.nosave = True, True # only test/save final epoch
yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
if opt.bucket:
os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
# 默认迭代演化300次, 每次变异都需要训练一次要查看效果result, 所以计算量呈倍数递增
# 官方建议至少进行 300 代进化以获得最佳结果, 而基础场景被训练了数百次, 可能需要数百或数千个GPU小时
for _ in range(300): # generations to evolve 300次迭代进化,遗传算法
if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
# Select parent(s)
parent = 'single' # parent selection method: 'single' or 'weighted'
x = np.loadtxt('evolve.txt', ndmin=2) # 加载evolve文件
# 加载evolve.csv文件
n = min(5, len(x)) # number of previous results to consider
# 最多选择5个最好的变异结果来挑选
x = x[np.argsort(-fitness(x))][:n] # top n mutations
# np.argsort只能从小到大排序, 添加负号实现从大到小排序, 算是排序的一个代码技巧
w = fitness(x) - fitness(x).min() # weights fitness在metrics.py中
# 根据(mp, mr, map50, map)的加权和来作为权重
# 根据不同进化方式获得base hyp
# single方式: 根据每个hyp的权重随机选择一个之前的hyp作为base hyp
# weighted方式: 根据每个hyp的权重对之前所有的hyp进行融合获得一个base hyp
if parent == 'single' or len(x) == 1:
x = x[random.choices(range(n), weights=w)[0]] # weighted selection
elif parent == 'weighted':
x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
# Mutate 超参数进化
mp, s = 0.8, 0.2 # mutation probability, sigma 设置突变概率和方差
npr = np.random
npr.seed(int(time.time()))
# 获取突变初始值, 也就是meta三个值的第一个数据
# 三个数值分别对应着: 变异初始概率, 最低限值, 最大限值(mutation scale 0-1, lower_limit, upper_limit)
g = np.array([x[0] for x in meta.values()]) # gains 0-1
ng = len(meta)
v = np.ones(ng)
while all(v == 1): # mutate until a change occurs (prevent duplicates)
v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
hyp[k] = float(x[i + 7] * v[i]) # mutate
# Constrain to limits 限制草参数范围
for k, v in meta.items():
hyp[k] = max(hyp[k], v[1]) # lower limit
hyp[k] = min(hyp[k], v[2]) # upper limit
hyp[k] = round(hyp[k], 5) # significant digits
# Train mutation
# Train mutation: result{tuple 7}: mp, mr, map50, map, box, obj, cls
# 具体返回的是: 'metrics/precision', 'metrics/recall',
# 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
# 'val/box_loss', 'val/obj_loss', 'val/cls_loss'
results = train(hyp.copy(), opt, device)
# Write mutation results
# 将每一代的演化结果与训练结果记录在evolve.csv文件中
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
# Plot results
# 绘图: 每个超参数有一个子图, 显示适应度(y 轴)与超参数值(x 轴).黄色表示更高的浓度
plot_evolution(yaml_file)
print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
fitness函数的代码如下
# 适应度的计算(用来挑选变异结果)
def fitness(x):
# Model fitness as a weighted combination of metrics
w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
return (x[:, :4] * w).sum(1)
这段代码主要涉及超参数的加载、训练过程以及可选的超参数进化(遗传算法)。以下是对代码的详细解释:
2.3.4.1 超参数加载
with open(opt.hyp) as f:
hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps
这行代码从指定的文件中加载超参数,并将其存储在 hyp
字典中。
2.3.4.2 训练过程
logger.info(opt)
if not opt.evolve: # 如果不使用遗传算法,则直接训练
tb_writer = None # init loggers
if opt.global_rank in [-1, 0]: # 如果globle_rank是-1或者0,则初始化wandb_logger
prefix = colorstr('tensorboard: ')
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
train(hyp, opt, device, tb_writer) # 进入训练过程
- 如果
opt.evolve
为False
,则直接进行训练。 - 初始化写入器
tb_writer
,并调用train
函数进行训练。
3.4.3 在else中
- 如果
opt.evolve
为True
,则进行超参数进化。 - 定义超参数进化的元数据
meta
,包括每个超参数的变异范围和限制。 - 确保
opt.local_rank
为-1
,即不支持分布式训练模式下的超参数进化。 - 设置
opt.notest
和opt.nosave
为True
,表示只在最终 epoch 进行测试和保存。 - 定义保存最佳结果的 YAML 文件路径
yaml_file
。 - 如果存在
evolve.txt
,则从中选择最佳超参数并进行变异。 - 进行 300 代的进化过程,每代选择最佳超参数并进行变异,然后训练模型并记录结果。
- 最终绘制进化结果并输出最佳超参数的训练命令。
2.3.4.4 meta:(mutation scale 0-1, lower_limit, upper_limit)
meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3),初始化学习率
'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
... ...
在超参数进化(遗传算法)的上下文中,mutation scale 0-1, lower_limit, upper_limit
是指每个超参数的变异范围和限制。具体来说:
- mutation scale 0-1: 变异初始概率,这是一个介于 0 和 1 之间的值,表示变异的比例或强度。在遗传算法中,变异是为了引入新的遗传物质,以避免陷入局部最优解。这个比例决定了变异的程度,通常是通过随机数生成器来实现的。
- lower_limit: 这是每个超参数的下限值。在进行变异时,生成的新的超参数值不能低于这个下限。
- upper_limit: 这是每个超参数的上限值。在进行变异时,生成的新的超参数值不能高于这个上限。
举个例子,假设我们有一个超参数 learning_rate
,其元数据定义如下:
'learning_rate': (0.5, 1e-5, 1e-1)
这意味着:
- mutation scale 0.5: 变异的比例是 0.5,表示变异的程度是 50%。
- lower_limit 1e-5: 学习率的下限是 1e-5。
- upper_limit 1e-1: 学习率的上限是 1e-1。
在进行变异时,新的学习率值将会在 1e-5 和 1e-1 之间,并且变异的程度会根据 0.5 的比例进行调整。
这种定义方式确保了在超参数进化过程中,每个超参数的变异都在合理的范围内进行,从而避免生成无效或不合理的新超参数值
可以参考这个大佬写的博客
YOLOv5系列(三十八) 解读遗传算法实现超参数进化(Hyperparameter Evolution)(详尽)_yolov5中的进化算法-优快云博客