Annotated train.py: the training script of the keras-yolo3 project

This article walks through train.py, the training script of the keras-yolo3 project, with line-by-line comments. Reading it gives a clear picture of the training flow and logic, and of how YOLOv3 is implemented in Keras.

In my experience, reading the training code is one of the fastest ways to understand how the pieces of a model fit together end to end. The training source of the keras-yolo3 project is as follows:

"""
Retrain the YOLO model for your own dataset.
使用自己的数据训练YOLO模型
"""

import numpy as np
import keras.backend as K
from keras.layers import Input, Lambda
from keras.models import Model
from keras.optimizers import Adam
from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

from yolo3.model import preprocess_true_boxes, yolo_body, tiny_yolo_body, yolo_loss
from yolo3.utils import get_random_data

import tensorflow as tf


def _main():
    annotation_path = '2007_train.txt'
    log_dir = 'logs/000/'
    classes_path = 'model_data/smoking_classes.txt'
    anchors_path = 'model_data/smoking_anchors.txt'
    class_names = get_classes(classes_path)
    num_classes = len(class_names)
    anchors = get_anchors(anchors_path)

    input_shape = (416,416) # multiple of 32, hw

    is_tiny_version = len(anchors)==6 # default setting
    if is_tiny_version:
        model = create_tiny_model(input_shape, anchors, num_classes,
            freeze_body=2, weights_path='model_data/tiny_yolo_weights.h5')    # build the initial tiny-YOLO model (model initialization)
    else:
        # create the initial model with most layers frozen (freeze_body=2 keeps only the 3 output layers trainable)
        model = create_model(input_shape, anchors, num_classes,
            freeze_body=2, weights_path='model_data/yolo_weights.h5') # make sure you know what you freeze

    logging = TensorBoard(log_dir=log_dir)
    checkpoint = ModelCheckpoint(log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
        monitor='val_loss', save_weights_only=True, save_best_only=True, period=3)    # checkpoint: save the best weights every 3 epochs
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1)    # reduce the learning rate when val_loss plateaus
    early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=100, verbose=1)   # early stopping on val_loss

    val_split = 0.1    # fraction of the data held out for validation
    # read the annotation lines
    with open(annotation_path) as f:
        lines = f.readlines()
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    num_val = int(len(lines)*val_split)    # number of validation samples
    num_train = len(lines) - num_val    # number of training samples

    # Train with frozen layers first, to get a stable loss.
    # Adjust num epochs to your dataset. This step is enough to obtain a not bad model.
    if True:
        model.compile(optimizer=Adam(lr=1e-3), loss={
            # use custom yolo_loss Lambda layer: the model's output already is the loss
            'yolo_loss': lambda y_true, y_pred: y_pred})

        batch_size = 32
        print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
        model.fit_generator(data_generator_wrapper(lines[:num_train], batch_size, input_shape, anchors, num_classes),
                steps_per_epoch=max(1, num_train//batch_size),
                validation_data=data_generator_wrapper(lines[num_train:], batch_size, input_shape, anchors, num_classes),
                validation_steps=max(1, num_val//batch_size),
                epochs=50,
                initial_epoch=0,
                callbacks=[logging, checkpoint])
        model.save_weights(log_dir + 'trained_weights_stage_1.h5')    # save the stage-1 (frozen-body) weights
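The listing above calls two small helpers, get_classes and get_anchors, which are defined further down in the same train.py file. A lightly commented sketch, following the upstream keras-yolo3 implementation, reads roughly as follows:

def get_classes(classes_path):
    '''Load class names, one name per line.'''
    with open(classes_path) as f:
        class_names = f.readlines()
    class_names = [c.strip() for c in class_names]
    return class_names

def get_anchors(anchors_path):
    '''Load comma-separated anchor values and reshape them into (N, 2) width/height pairs.'''
    with open(anchors_path) as f:
        anchors = f.readline()
    anchors = [float(x) for x in anchors.split(',')]
    return np.array(anchors).reshape(-1, 2)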
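create_model is where the training graph is assembled: the yolo_body network, three extra y_true inputs (one per detection scale), and a Lambda layer that computes yolo_loss inside the graph. The sketch below follows the upstream keras-yolo3 implementation; the hard-coded 185 is the number of DarkNet-53 body layers used when freeze_body=1.

def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze_body=2,
            weights_path='model_data/yolo_weights.h5'):
    '''Create the training model.'''
    K.clear_session()  # start from a fresh session
    image_input = Input(shape=(None, None, 3))
    h, w = input_shape
    num_anchors = len(anchors)

    # one y_true input per scale: 13x13, 26x26, 52x52 for a 416x416 input
    y_true = [Input(shape=(h//{0:32, 1:16, 2:8}[l], w//{0:32, 1:16, 2:8}[l],
        num_anchors//3, num_classes+5)) for l in range(3)]

    model_body = yolo_body(image_input, num_anchors//3, num_classes)
    print('Create YOLOv3 model with {} anchors and {} classes.'.format(num_anchors, num_classes))

    if load_pretrained:
        model_body.load_weights(weights_path, by_name=True, skip_mismatch=True)
        print('Load weights {}.'.format(weights_path))
        if freeze_body in [1, 2]:
            # freeze_body=1: freeze the DarkNet-53 body; freeze_body=2: freeze all but the 3 output layers
            num = (185, len(model_body.layers)-3)[freeze_body-1]
            for i in range(num):
                model_body.layers[i].trainable = False
            print('Freeze the first {} layers of total {} layers.'.format(num, len(model_body.layers)))

    # compute the loss inside the graph so Keras only has to "predict" a single scalar
    model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss',
        arguments={'anchors': anchors, 'num_classes': num_classes, 'ignore_thresh': 0.5})(
        [*model_body.output, *y_true])
    model = Model([model_body.input, *y_true], model_loss)

    return model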
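fit_generator pulls its batches from data_generator_wrapper, also defined later in train.py. A commented sketch along the lines of the upstream code:

def data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes):
    '''Yield ([images, *y_true], zeros) batches forever for fit_generator.'''
    n = len(annotation_lines)
    i = 0
    while True:
        image_data = []
        box_data = []
        for b in range(batch_size):
            if i == 0:
                np.random.shuffle(annotation_lines)  # reshuffle at the start of each pass
            # get_random_data applies random augmentation and returns a resized image plus its boxes
            image, box = get_random_data(annotation_lines[i], input_shape, random=True)
            image_data.append(image)
            box_data.append(box)
            i = (i + 1) % n
        image_data = np.array(image_data)
        box_data = np.array(box_data)
        y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes)
        # the model's output is already the loss, so the target is just a vector of zeros
        yield [image_data, *y_true], np.zeros(batch_size)

def data_generator_wrapper(annotation_lines, batch_size, input_shape, anchors, num_classes):
    '''Guard against an empty annotation list or a non-positive batch size.'''
    n = len(annotation_lines)
    if n == 0 or batch_size <= 0:
        return None
    return data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes)

In the full upstream script, a second training stage follows the frozen stage shown above: all layers are unfrozen, the model is recompiled with a smaller learning rate (Adam(lr=1e-4)), and training continues with reduce_lr and early_stopping added to the callbacks before the final weights are saved.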