YOLOv5-4.0-general.py 源代码导读(通用工具)

本文介绍了YOLOv5的通用工具文件`general.py`,包括日志设置、随机种子初始化、模型路径获取、网络连接检查、依赖包验证等功能。此外,还详细讲解了坐标转换、IOU计算、非极大值抑制等关键操作,是深入理解YOLOv5实现的重要参考资料。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

YOLOv5介绍

YOLOv5为兼顾速度与性能的目标检测算法。笔者将在近期更新一系列YOLOv5的代码导读博客。YOLOv5为2021.1.5日发布的4.0版本。
YOLOv5开源项目github网址
源代码导读汇总网址
本博客导读的代码为utils文件夹下的general.py,取自1.27日更新的版本。

general.py

该文件提供了模型多个过程中用到的通用方法,每个功能以函数的方式进行定义。

以下为该文件必须导入的模块,其中部分文件来源于其他项目子文件。

# General utils
import glob      # 仅支持部分通配符得文件搜索模块
import logging   # 日志模块
import math      # 数学公式模块
import os        # 与操作系统进行交互的模块
import platform  # 获得操作系统相关信息的模块
import random    # 生成随机数的模块
import re        # 用来匹配字符串(动态、模糊)的模块
import subprocess# 创建子进程的模块
import time      # 用来获取系统时间的模块
from pathlib import Path  #Path对象 简便对path进行操作

import cv2           # opencv库
import numpy as np   # numpy矩阵处理函数库
import torch         # pytorch框架
import torchvision   # 为pytorch 提供一些辅助工具
import yaml          # yaml配置文件模块

# 以下调用三个函数具体注释见 源代码导读的其他文件
from utils.google_utils import gsutil_getsize # 用于返回网站链接对应文件的大小
from utils.metrics import fitness # 返回指标的加权值得行向量
from utils.torch_utils import init_torch_seeds # 功能为初始化随机种子

以下为运行相关的一些基本的设置

# 下两行为设置tensor和numpy array的打印格式 linewidth为每一行字符上限 precision 为精度
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={
   'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
cv2.setNumThreads(0)  # 阻止opencv参与多线程(与 Pytorch的 Dataloader不兼容)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # 确定最大的线程数 这里被限制在了8

该函数为对日志的设置进行初始化 rank为-1或0 时设置输出级别为WARN

def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)

该函数为初始化随机种子 统一random numpy torch 种子

def init_seeds(seed=0):
    # 初始化随机种子生成器
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)

该函数返回最近的模型 'last.pt’对应的路径

def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    # 从Python版本3.5开始,glob模块支持该"**"指令(仅当传递recursive标志时才会解析该指令
    # glob.glob函数匹配所有的符合条件的文件,并将其以list的形式返回
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    # os.getctime 返回路径对应文件的创建时间 
    # 也就是返回所有文件中创建时间最晚的路径
    return max(last_list, key=os.path.getctime) if last_list else ''

用socket模块 检查当前主机网络连接是否可用

def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 53))  # check host accesability 该单词拼错 应为accessbility
        return True
    except OSError:
        return False

检查当前代码是否是最新版 如果不是最新的 会提示使用git pull命令进行升级

def check_git_status():
    # Recommend 'git pull' if code is out of date
    # 彩色显示github单词 colorstr()函数后续介绍
    print(colorstr('github: '), end='')
    try:
        # 检查以git结尾的路径存在
        assert Path('.git').exists(), 'skipping check (not a git repository)'
        # 但是包含"/workspace"的路径不存在
        assert not Path('/workspace').exists(), 'skipping check (Docker image)'  # not Path('/.dockerenv').exists()
        # 保证主机网络是可用的
        assert check_online(), 'skipping check (offline)'
        # 这里是创建cmd命令 并创建子进程进行执行
        cmd = 'git fetch && git config --get remote.origin.url'
        url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git')  # github repo url
        branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
        n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
        # n大于0说明 当前版本之后还有commit 因此当前版本不是最新的 s为输出的相关提示
        if n > 0:
            s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                f"Use 'git pull' to update or 'git clone {url}' to download latest."
        else:
            s = f'up to date with {url} ✅'
        # 通过.encode().decode()的组合忽略掉无法用ascii编码的内容
        print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
    except Exception as e:
        print(e)

用于检查已经安装的包是否满足file对应txt文件的要求

def check_requirements(file='requirements.txt', exclude=()):

    import pkg_resources #此为管理安装包信息相关模块
    # x在被读取的时候转换为pkg_resources.Requirement类
    # x.name 返回安装包对应名称 x.specifier 返回符号后面内容 即安装包对应版本
    # requirements 负责将上述pkg类转换为列表对象 以供require函数使用
    requirements = [f'{x.name}{x.specifier}' for x in pkg_resources.parse_requirements(Path(file).open())
                    if x.name not in exclude]
    pkg_resources.require(requirements)  #判断已经安装的包是否符合要求

检查img_size是否能被s整除,这里默认s为32

def check_img_size(img_size, s=32):
    # make_divisible函数见下文注释
    new_size = make_divisible(img_size, int(s))  # 返回大于等于img_size的最小能被s整除的值
    if new_size != img_size: #新size和旧size不同时 打印出信息 
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size # 返回能被s整除的new_size

检查相关文件路径是否能找到文件,如果不能则在全局路径中寻找,匹配到多个时返回第一个

def check_file(file):

    if os.path.isfile(file) or file == '':
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # 寻找文件
        assert len(files), 'File Not Found: %s' % file  # assert 文件没有找到
        assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files)  # assert unique
        return files[0]  # 返回文件

检查数据集 如果本地没有则从torch库中下载并解压数据集

#          dict内容需要实际调试(暂无法逐行调试) 该方法只能适用某几个数据集
def check_dataset(dict):

    val, s = dict.get('val'), dict.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                print('Downloading %s ...' % s)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    torch.hub.download_url_to_file(s, f)
                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
                else:  # bash script
                    r = os.system(s)
                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
            else:
                raise Exception('Dataset not found.')

取大于等于x的最小值,该值能被divisor整除

def make_divisible(x, divisor
评论 34
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值