YOLOv5介绍
YOLOv5为兼顾速度与性能的目标检测算法。笔者将在近期更新一系列YOLOv5的代码导读博客。YOLOv5为2021.1.5日发布的4.0版本。
YOLOv5开源项目github网址
源代码导读汇总网址
本博客导读的代码为utils文件夹下的general.py,取自1.27日更新的版本。
general.py
该文件提供了模型多个过程中用到的通用方法,每个功能以函数的方式进行定义。
以下为该文件必须导入的模块,其中部分文件来源于其他项目子文件。
# General utils
import glob # 仅支持部分通配符得文件搜索模块
import logging # 日志模块
import math # 数学公式模块
import os # 与操作系统进行交互的模块
import platform # 获得操作系统相关信息的模块
import random # 生成随机数的模块
import re # 用来匹配字符串(动态、模糊)的模块
import subprocess# 创建子进程的模块
import time # 用来获取系统时间的模块
from pathlib import Path #Path对象 简便对path进行操作
import cv2 # opencv库
import numpy as np # numpy矩阵处理函数库
import torch # pytorch框架
import torchvision # 为pytorch 提供一些辅助工具
import yaml # yaml配置文件模块
# 以下调用三个函数具体注释见 源代码导读的其他文件
from utils.google_utils import gsutil_getsize # 用于返回网站链接对应文件的大小
from utils.metrics import fitness # 返回指标的加权值得行向量
from utils.torch_utils import init_torch_seeds # 功能为初始化随机种子
以下为运行相关的一些基本的设置
# 下两行为设置tensor和numpy array的打印格式 linewidth为每一行字符上限 precision 为精度
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={
'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
cv2.setNumThreads(0) # 阻止opencv参与多线程(与 Pytorch的 Dataloader不兼容)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # 确定最大的线程数 这里被限制在了8
该函数为对日志的设置进行初始化 rank为-1或0 时设置输出级别为WARN
def set_logging(rank=-1):
logging.basicConfig(
format="%(message)s",
level=logging.INFO if rank in [-1, 0] else logging.WARN)
该函数为初始化随机种子 统一random numpy torch 种子
def init_seeds(seed=0):
# 初始化随机种子生成器
random.seed(seed)
np.random.seed(seed)
init_torch_seeds(seed)
该函数返回最近的模型 'last.pt’对应的路径
def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
# 从Python版本3.5开始,glob模块支持该"**"指令(仅当传递recursive标志时才会解析该指令
# glob.glob函数匹配所有的符合条件的文件,并将其以list的形式返回
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
# os.getctime 返回路径对应文件的创建时间
# 也就是返回所有文件中创建时间最晚的路径
return max(last_list, key=os.path.getctime) if last_list else ''
用socket模块 检查当前主机网络连接是否可用
def check_online():
# Check internet connectivity
import socket
try:
socket.create_connection(("1.1.1.1", 53)) # check host accesability 该单词拼错 应为accessbility
return True
except OSError:
return False
检查当前代码是否是最新版 如果不是最新的 会提示使用git pull命令进行升级
def check_git_status():
# Recommend 'git pull' if code is out of date
# 彩色显示github单词 colorstr()函数后续介绍
print(colorstr('github: '), end='')
try:
# 检查以git结尾的路径存在
assert Path('.git').exists(), 'skipping check (not a git repository)'
# 但是包含"/workspace"的路径不存在
assert not Path('/workspace').exists(), 'skipping check (Docker image)' # not Path('/.dockerenv').exists()
# 保证主机网络是可用的
assert check_online(), 'skipping check (offline)'
# 这里是创建cmd命令 并创建子进程进行执行
cmd = 'git fetch && git config --get remote.origin.url'
url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git') # github repo url
branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
# n大于0说明 当前版本之后还有commit 因此当前版本不是最新的 s为输出的相关提示
if n > 0:
s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
f"Use 'git pull' to update or 'git clone {url}' to download latest."
else:
s = f'up to date with {url} ✅'
# 通过.encode().decode()的组合忽略掉无法用ascii编码的内容
print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
except Exception as e:
print(e)
用于检查已经安装的包是否满足file对应txt文件的要求
def check_requirements(file='requirements.txt', exclude=()):
import pkg_resources #此为管理安装包信息相关模块
# x在被读取的时候转换为pkg_resources.Requirement类
# x.name 返回安装包对应名称 x.specifier 返回符号后面内容 即安装包对应版本
# requirements 负责将上述pkg类转换为列表对象 以供require函数使用
requirements = [f'{x.name}{x.specifier}' for x in pkg_resources.parse_requirements(Path(file).open())
if x.name not in exclude]
pkg_resources.require(requirements) #判断已经安装的包是否符合要求
检查img_size是否能被s整除,这里默认s为32
def check_img_size(img_size, s=32):
# make_divisible函数见下文注释
new_size = make_divisible(img_size, int(s)) # 返回大于等于img_size的最小能被s整除的值
if new_size != img_size: #新size和旧size不同时 打印出信息
print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
return new_size # 返回能被s整除的new_size
检查相关文件路径是否能找到文件,如果不能则在全局路径中寻找,匹配到多个时返回第一个
def check_file(file):
if os.path.isfile(file) or file == '':
return file
else:
files = glob.glob('./**/' + file, recursive=True) # 寻找文件
assert len(files), 'File Not Found: %s' % file # assert 文件没有找到
assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files) # assert unique
return files[0] # 返回文件
检查数据集 如果本地没有则从torch库中下载并解压数据集
# dict内容需要实际调试(暂无法逐行调试) 该方法只能适用某几个数据集
def check_dataset(dict):
val, s = dict.get('val'), dict.get('download')
if val and len(val):
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
if not all(x.exists() for x in val):
print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
if s and len(s): # download script
print('Downloading %s ...' % s)
if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename
torch.hub.download_url_to_file(s, f)
r = os.system('unzip -q %s -d ../ && rm %s' % (f, f)) # unzip
else: # bash script
r = os.system(s)
print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value
else:
raise Exception('Dataset not found.')
取大于等于x的最小值,该值能被divisor整除
def make_divisible(x, divisor