用YOLOV8官方代码,且采用预训练权重训练自己修改过的网络模型的方法(可以确保训练到了自己的网络架构)

另外一种方法训练自己的aa.yaml文件并用预训练权重bb.pt来监督:

第一步,重组YOLO的函数。

路径:路径ultralytics\yolo\engine\model.py

# Ultralytics YOLO 🚀, GPL-3.0 license
import sys
from pathlib import Path
from ultralytics import yolo  # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
                                  guess_model_task, nn)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
                                    is_git_dir, is_pip_package, yaml_load)
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode,intersect_dicts
# Map head to model, trainer, validator, and predictor classes
TASK_MAP = {
    'classify': [
        ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
        yolo.v8.classify.ClassificationPredictor],
    'detect': [
        DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
        yolo.v8.detect.DetectionPredictor]}
class YOLO:
    def __init__(self, model='yolov8n.yaml', weights='yolov8n.pt', task=None, session=None) -> None:
        self._reset_callbacks()
        self.predictor = None  # reuse predictor
        self.model = None  # model object
        self.trainer = None  # trainer object
        self.task = None  # task type
        self.ckpt = None  # if loaded from *.pt
        self.cfg = None  # if loaded from *.yaml
        self.ckpt_path = None # path to *.pt weights
        self.overrides = {}  # overrides for trainer object
        self.metrics = None  # validation/training metrics
        self.session = session  # HUB session
        self.pretrained_weights = weights # 预训练模型
        self.new(model,weights,task)

    def __call__(self, source=None, stream=False, **kwargs):
        return self.predict(source, stream, **kwargs)

    def __getattr__(self, attr):
        name = self.__class__.__name__
        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

    def new(self,cfg: str,weights,  task=None, verbose=True):
        self.cfg = check_yaml(cfg)  # check YAML
        cfg_dict = yaml_load(self.cfg, append_filename=True)  # model dict
        self.task = task or guess_model_task(cfg_dict)
        self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1)  # build model
        self.overrides['model'] = self.cfg
        # Below added to allow export from yamls
        args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine model and default args, preferring model args
        self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
        self.model.task = self.task
        if isinstance(weights, (str, Path)):
            weights, self.ckpt = attempt_load_one_
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值