【YOLOv5】源码(train.py)

train.pyYOLOv5中用于模型训练的脚本文件,其主要功能是读取配置文件、设置训练参数、构建模型结构、加载数据、训练/验证模型、保存模型权重文件、输出日志等。本文梳理了train.py的整个流程,并在第3节举例如何使用train.py进行训练,并对训练结果目录中的文件作了详细解释

参考笔记:

【YOLOv3】源码(train.py)_yolo原始代码-优快云博客

【yolov5】 train.py详解_yolov5 train.py-优快云博客

学习视频:

模型训练之train.py文件_哔哩哔哩_bilibili


目录

1. train.py的主要功能

2. 主要模块

2.1 参数解析与初始化

2.2 加载模型与数据

2.3 优化器与学习率更新器设置

2.4 训练循环(核心)

2.5 最终的模型验证

3. 使用train.py训练案例

3.1 准备工作

3.2 训练结果解析

4. 完整train.py代码


1. train.py的主要功能

  • 读取命令选项、训练参数配置文件:train.py通过argparse库读取指定的命令行参数,例如batch_size、epoch、weights等;读取yaml文件中的各种训练参数,例如learing_rate、momentum、weight_decay、IoU阈值、高宽比阈值anchor_t

  • 构建模型结构:train.py中要么通过命令行参数weights指定权重文件构建模型结构,并加载参数,如果没有使用命令行参数weights,则通过命令行参数cfg指定YOLOv5模型结构构建一个新的初始化模型

  • 数据加载和预处理:train.py中定义了create_dataloader函数,用于加载训练数据和验证数据,并对其进行预处理。其中,预处理过程包括:自适应图像缩放、图像增强、标签转换等操作

  • 训练和验证过程:train.pytrain函数用于进行模型的训练和验证过程。训练过程中,train.py会对训练数据进行多次迭代,在每个epoch结束时,会对模型在验证集上的表现进行评估,记录的指标有:P,R,mAP@.5,mAP@.5-.95,val_loss(box,obj,cls)

  • 模型保存和日志输出:train.py在每个epoch结束时,保存当前epoch的模型权重(允许覆盖)last.pt,通过模型在验证集上的(P,R,mAP@.5,mAP@.5-.95)计算出模型的适应度,利用适应度保存训练过程中的最佳模型权重best.pt,并将训练和验证过程中的各种指标输出到日志文件中

注意:模型训练结束后,train.py会利用best.pt在验证集上作最后一次验证,控制台输出验证结果

2. 主要模块

2.1 参数解析与初始化

常用参数说明

  • --weights:模型初始权重存放路径
  • --cfg:模型结构的YAML配置文件路径,例如models/yolov5l.yaml
  • --data:数据集YAML配置文件路径,在YAML文件中中定义训练/验证数据集的存放路径和类别等
  • --hyp:训练超参数YAML配置文件路径,控制learning_rate、weight_decay、momentum等训练超参数
  • --epochs:训练的总轮数
  • --batch-size:批量大小
  • --imgsz:将输入图像自适应缩放到imgsz尺寸大小,其作用阶段是在加载数据时,具体的代码实现在utils/agumentations.py下的letterbox函数

  •  --resume:断点续训,有时候服务器崩溃了导致训练中断,就可以利用这个参数继续训练。训练过程保存的pt存放在runs/train/expn/weights下,分别是last.pth、best.pth,使用该参数时,指定--resume runs/train/expn/weights/last.pth即可继续上次中断的训练
  • --noautoanchor:是否不作自适应锚框计算,该功能默认开启,代码位置如下:

  •  --multi-scale:多尺度训练;该功能的执行阶段在自适应图像缩放之后模型训练过程之中,自适应图像缩放之后,在训练过程中要对图片作处理时,再对图像尺寸作随机缩放或扩充,但必须确保仍然是最大下采样倍数(通常是32)的倍数。多尺度训练能够使得模型对不同尺度的目标具有更强的鲁棒性。代码的具体位置如下:

  •  --label-smoothing:标签平滑;具体的代码实现是utils/loss.py下的smooth_BCE函数:

eps是在训练超参数YAML配置文件中定义的

  • --patience:早停机制;如果模型在指定的epoch之内仍然没有性能提升,则训练提前终止,代码的具体位置如下:

  •  --cos-lr:是否通过余弦函数更新学习率,不开启该功能则使用线性函数更新学习率,代码的具体代码位置如下:

  • --freeze:冻结某些层的所有参数不参与训练;具体代码位置如下:

 常用参数和其他参数作用可参考代码注释:

def parse_opt(known=False):
    parser = argparse.ArgumentParser()

    #权重文件路径
    parser.add_argument('--weights', type=str, default=ROOT /'weights/yolov5s.pt',
                        help='initial weights path (初始权重文件路径)')

    #模型结构yaml配置文件路径
    parser.add_argument('--cfg', type=str, default='',
                        help='model.yaml path (模型结构配置文件路径)')

    #数据集yaml配置文件路径
    parser.add_argument('--data', type=str, default=ROOT / 'data/VOC-hat.yaml',
                        help='dataset.yaml path (数据集配置文件路径)')

    #训练超参数yaml配置文件路径
    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml',
                        help='hyperparameters path (超参数配置文件路径)')

    #训练轮数
    parser.add_argument('--epochs', type=int, default=300)

    #批量大小
    parser.add_argument('--batch-size', type=int, default=16,
                        help='total batch size for all GPUs, -1 for autobatch')

    #imgsz指定训练和验证时将输入图片自适应缩放到相应的尺寸
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640,
                        help='指定训练、验证时将输入图片自适应缩放到相应的尺寸')

    #是否使用矩形训练
    parser.add_argument('--rect', action='store_true',
                        help='rectangular training(是否使用矩形训练)')

    '''断点续训,有时候服务器崩溃了导致训练中断,就可以利用这个参数继续训练,训练过
    程保存的pth存放在runs/train/expxx/weights下,分别是last.pth和best.pth,使用该参数时,
    指定--resume runs/train/expxx/weights/last.pth即可继续上次中断的训练'''
    parser.add_argument('--resume', nargs='?', const=True, default=False,
                        help='resume most recent training(断点续训)')

    '''该参数指定到了最后一个epoch才保存模型的last.pt和best.pt,
    如果你只想获取最后一个epoch的模型权重文件,那可以开启这个功能,默认是不开启的'''
    parser.add_argument('--nosave', action='store_true',
                        help='only save final checkpoint(只保存最后一个epoch的last.pt和best.pt)')

    #仅验证最终epoch,该参数指定到了最后一个epoch才去计算模型在验证集上的性能指标,默认不开启
    parser.add_argument('--noval', action='store_true',
                        help='only validate final epoch(仅验证最终epoch)')

    #是否禁用自适应锚框策略,该功能默认开启
    parser.add_argument('--noautoanchor', action='store_true',
                        help='disable AutoAnchor')

    #不保存训练中生成的图标文件,默认是保存的
    parser.add_argument('--noplots', action='store_true',
                        help='save no plot files')

    #使用遗传算法优化超参数,可指定优化代数,默认不开启该功能
    parser.add_argument('--evolve', type=int, nargs='?', const=300,
                        help='evolve hyperparameters for x generations')

    ##谷歌云盘bucket,一般用不到
    parser.add_argument('--bucket', type=str, default='',
                        help='gsutil bucket')

    #是否缓存数据集到RAM或磁盘
    parser.add_argument('--cache', type=str, nargs='?', const='ram',
                        help='--cache images in "ram" (default) or "disk"(缓存数据集)')

    #是否使用加权的图像进行训练,对于那些训练不好的图片,会在下一个epoch中增加一些权重
    parser.add_argument('--image-weights', action='store_true',
                        help='use weighted image selection for training')
    #指定训练的设备
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

    #是否使用多尺度训练(即随机将自适应缩放之后的尺寸再作随机缩放或增加,但必须确保是最大下采样倍数(通常是32)的倍数)
    #该功能使模型对不同尺度的目标具有更强的鲁棒性
    parser.add_argument('--multi-scale', action='store_true',
                        help='vary img-size +/- 50%%(多尺度训练)')

    #不知有什么用,默认不开启
    parser.add_argument('--single-cls', action='store_true',
                        help='train multi-class data as single-class')

    #优化器,默认是SGD
    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'],
                        default='SGD', help='optimizer')

    #是否启用同步BatchNorm(只在多GPU中使用),默认不启用
    parser.add_argument('--sync-bn', action='store_true',
                        help='use SyncBatchNorm, only available in DDP mode')

    #数据加载器的最大工作线程数
    parser.add_argument('--workers', type=int, default=8,
                        help='max dataloader workers (per RANK in DDP mode)')

    #训练结果存放的根目录,默认是runs/trian
    parser.add_argument('--project', default=ROOT / 'runs/train',
                        help='save to project/name(根目录)')

    #训练结果存放的子目录,默认是runs/train/expn
    parser.add_argument('--name', default='exp',
                        help='save to project/name(子目录)')

    #是否用当前的训练目录覆盖以前的expn训练目录,默认不开启
    parser.add_argument('--exist-ok', action='store_true',
                        help='existing project/name ok, do not increment(允许覆盖以前的训练结果)')

    #是否使用四元数据加载器,默认不开启
    parser.add_argument('--quad', action='store_true', help='quad dataloader')

    #是否使用余弦函数更新学习率,默认是线性函数更新学习率
    parser.add_argument('--cos-lr', action='store_true',
                        help='cosine LR scheduler')

    #是否启用标签平滑
    parser.add_argument('--label-smoothing', type=float, default=0.0,
                        help='Label smoothing epsilon(标签平滑)')

    #早停机制,设置早停的epoch数
    parser.add_argument('--patience', type=int, default=100,
                        help='EarlyStopping patience (epochs without improvement)')

    #指定冻结不进行训练的层索引
    parser.add_argument('--freeze', nargs='+', type=int, default=[0],
                        help='Freeze layers: backbone=10, first3=0 1 2')

    #设置多少个epoch保存模型权重文件,保存路径是runs/train/weights/epochx.pt
    #该功能模型不开启,一般只保存last.pt和best.pt
    parser.add_argument('--save-period', type=int, default=-1,
                        help='Save checkpoint every x epochs (disabled if < 1)')

    #本地进程排名(DDP模式用)
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='DDP parameter, do not modify(DDP模式的进程排名)')

    #---------------------------- W&B(Weights & Biases)参数配置 ----------------------------
    parser.add_argument('--entity', default=None,
                        help='W&B: Entity (W&B 实体名称)')
    parser.add_argument('--upload_dataset', nargs='?', const=True, default=False,
                        help='W&B: Upload dataset as artifact table (上传数据集到 W&B Artifact Table)')
    parser.add_argument('--bbox_interval', type=int, default=-1,
                        help='W&B: Set bounding-box image logging interval (设置目标框日志记录间隔)')
    parser.add_argument('--artifact_alias', type=str, default='latest',
                        help='W&B: Version of dataset artifact to use (使用的数据集版本别名)')

    #解析参数
    opt = parser.parse_known_args()[0] if known else parser.parse_args()

    return opt

2.2 加载模型与数据

加载模型权重和配置文件,设置模型参数,加载训练数据

分析:该部分代码属于yolov5结构中的哪个阶段?

主要发生在训练前的准备工作,即还没有进入模型的前向传播和反向传播阶段

运行逻辑分析

  • 模型加载与构建
    • 如果提供了预训练权重,加载模型结构和权重参数
    • 如果没有提供权重,则根据模型结构配置文件yolov5xx.yaml构建新模型
  • 冻结层设置
    • 通过命令行的freeze参数指定需要冻结哪些层(例如Backbone层),以适应迁移学习或微调场景
  • 训练数据准备
    • 加载数据集,创建数据迭代器Dataloader
    '-----------------------------模型部分---------------------------'
    #Model,模型加载
    check_suffix(weights, '.pt')  #检查权重文件的后缀是否为.pt
    pretrained = weights.endswith('.pt') #判断权重文件是否为以pt结尾,如果是的话则为true,说明在命令行指定了pt文件路径

    if pretrained:
        #如果本地找不到权重文件,则去YOLOv5的官方仓库中下载权重文件
        with torch_distributed_zero_first(LOCAL_RANK):
            weights = attempt_download(weights)

        #加载权重文件
        ckpt = torch.load(weights, map_location='cpu')

        #创建YOLOv5模型
        model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
        '''
        cfg为模型结构的yaml说明书文件,pt权重文件里也是会保存模型结构的yaml说明书文件的,所以这两个任选一个即可
        ch:为输入通道数(一般为3) nc:数据集类别数 anchors:先验anchor尺寸
        '''

        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []#定义需要排除的键
        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) #交集,获取匹配的参数
        model.load_state_dict(csd, strict=False)  #给创建的YOLOv5模型加载权重参数
        LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  #输出转移的参数数量

    else:
        #如果在命令行的weights参数没有指定使用预训练权重,则使用给定的cfg创建新模型
        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create

    
    '-----------------------------冻结层部分---------------------------'
    
    #Freeze,冻结一些层的参数,训练过程中不更新
    #保存需要冻结的层,eg:[model.0,model.1,model.6..]
    freeze = [f'model.{x}.' for x in (freeze if len(freeze)>1 else range(freeze[0]))]
    for k, v in model.named_parameters():
        v.requires_grad = True  #默认所有参数参与训练
        if any(x in k for x in freeze):#如果当前参数所在的层是否在冻结列表中
            LOGGER.info(f'freezing {k}')
            v.requires_grad = False#冻结参数,不进行训练


    '-----------------------------数据加载部分---------------------------'
    
    #Trainloader,加载训练集迭代器
    train_loader, dataset = create_dataloader(path=train_path,#训练集存放路径
                                              imgsz=imgsz,#自适应缩放图片大小
                                              batch_size=batch_size // WORLD_SIZE,#批量大小
                                              stride=gs,#最大下采样倍数
                                              single_cls=single_cls,
                                              hyp=hyp,#训练超参数存放路径
                                              augment=True,#数据增强
                                              cache=None if opt.cache == 'val' else opt.cache,
                                              rect=opt.rect,#是否使用矩形训练
                                              rank=LOCAL_RANK,
                                              workers=workers,#线程数
                                              image_weights=opt.image_weights,#是否使用加权的图像进行训练
                                              quad=opt.quad,#是否启用四元数据加载方式
                                              prefix
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值