YOLOv5 hubconf: Model Deployment and Usage, Including Video and Camera Processing
Model Loading Workflow
PyTorch Hub is a simple API and workflow that provides the basic building blocks for reproducing research. If a project implements a hubconf.py file in its root directory, the project can be loaded as a model through PyTorch Hub. As we can see, yolov5 implements this file, so we can load the model with a single line of code:
# repo_or_dir: the directory containing hubconf.py
# model: which entrypoint function to call; pass a weights path for a custom model, or a model name for a pretrained one
# source: whether the model is local or on GitHub; for local, repo_or_dir is the project path
model = torch.hub.load(repo_or_dir='./', model='yolov5s', source='local')
# or: model = torch.hub.load('./', 'custom', './yolov5x.pt', source='local')
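That is really all PyTorch Hub requires of a project: hubconf.py simply defines entrypoint functions, plus an optional dependencies list. Here is a minimal sketch of one for an arbitrary project (the my_model entrypoint and the weights.pt path are hypothetical, not part of yolov5):

# hubconf.py -- minimal sketch; my_model and weights.pt are hypothetical
dependencies = ['torch']  # optional: pip packages the entrypoints require

import torch

def my_model(pretrained=False):
    """Entrypoint, loadable via torch.hub.load('./', 'my_model', source='local')."""
    model = torch.nn.Linear(10, 2)  # stand-in for a real model
    if pretrained:
        model.load_state_dict(torch.load('weights.pt'))  # hypothetical weights file
    return model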
Our call above resolves to the yolov5s entrypoint in yolov5's own hubconf.py:
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    return _create("yolov5s", pretrained, channels, classes, autoshape, _verbose, device)
Let's follow it into the _create function:
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    """
    Creates or loads a YOLOv5 model.

    Arguments:
        name (str): model name 'yolov5s' or path 'path/to/best.pt'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
        verbose (bool): print all information to screen
        device (str, torch.device, None): device to use for model parameters

    Returns:
        YOLOv5 model
    """
    # ...
    check_requirements(ROOT / "requirements.txt", exclude=("opencv-python", "tensorboard", "thop"))
    name = Path(name)
    path = name.with_suffix(".pt") if name.suffix == "" and not name.is_dir() else name  # checkpoint path
    try:
        device = select_device(device)
        if pretrained and channels == 3 and classes == 80:
            try:
                model = DetectMultiBackend(path, device=device, fuse=autoshape)  # detection model
                # ... (classification/segmentation special cases elided)
                if autoshape:
                    model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
            except Exception:
                model = attempt_load(path, device=device, fuse=False)  # arbitrary model
        # ...
        return model.to(device)
    except Exception as e:
        help_url = "https://docs.ultralytics.com/yolov5/tutorials/pytorch_hub_model_loading"
        s = f"{e}. Cache may be out of date, try `force_reload=True` or see {help_url} for help."
        raise Exception(s) from e
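One detail worth noting: only the default request (pretrained weights, 3 channels, 80 classes) takes the DetectMultiBackend fast path above; anything else falls through to building the model from its cfg. So, for example, a fresh 10-class variant can be requested directly, since torch.hub.load forwards extra kwargs to the entrypoint (a sketch, mirroring the official hub tutorial):

# classes=10 fails the `pretrained and channels == 3 and classes == 80` check,
# so _create() builds a new 10-class model from the yolov5s cfg instead
model = torch.hub.load('./', 'yolov5s', classes=10, source='local')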
With the dead code stripped out, we can see that the model is first loaded as a DetectMultiBackend object, the same detection model used in detect.py, and then wrapped in an AutoShape object. AutoShape handles the whole pipeline for us: letterbox image preprocessing, raw box prediction, NMS filtering, and finally scale_boxes to map the boxes back to the original image size.
We can see this in its forward method:
def forward(self, ims, size=640, augment=False, profile=False):
    """
    Performs inference on inputs with optional augment & profiling.

    Supports various formats including file, URI, OpenCV, PIL, numpy, torch.
    """
    # For size(height=640, width=1280), RGB images example inputs are:
    #   file:        ims = 'data/images/zidane.jpg'  # str or PosixPath
    #   URI:              = 'https://ultralytics.com/images/zidane.jpg'
    #   OpenCV:           = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
    #   PIL:              = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
    #   numpy:            = np.zeros((640,1280,3))  # HWC
    #   torch:            = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
    #   multiple:         = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
    dt = (Profile(), Profile(), Profile())
    with dt[0]:
        if isinstance(size, int):  # expand
            size = (size, size)
        p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device)  # param
        autocast = self.amp and (p.device.type != "cpu")  # Automatic Mixed Precision (AMP) inference
        if isinstance(ims, torch.Tensor):  # torch
            with amp.autocast(autocast):
                return self.model(ims.to(p.device).type_as(p), augment=augment)  # inference

        # Pre-process
        n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims])  # number, list of images
        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
        for i, im in enumerate(ims):
            f = f"image{i}"  # filename
            if isinstance(im, (str, Path)):  # filename or uri
                im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im), im
                im = np.asarray(exif_transpose(im))
            elif isinstance(im, Image.Image):  # PIL Image
                im, f = np.asarray(exif_transpose(im)), getattr(im, "filename", f) or f
            files.append(Path(f).with_suffix(".jpg").name)
            if im.shape[0] < 5:  # image in CHW
                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
            im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
            s = im.shape[:2]  # HWC
            shape0.append(s)  # image shape
            g = max(size) / max(s)  # gain
            shape1.append([int(y * g) for y in s])
            ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
        shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)]  # inf shape
        x = [letterbox(im, shape1, auto=False)[0] for im in ims]  # pad
        x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
        x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32

    with amp.autocast(autocast):
        # Inference
        with dt[1]:
            y = self.model(x, augment=augment)  # forward

        # Post-process
        with dt[2]:
            y = non_max_suppression(
                y if self.dmb else y[0],
                self.conf,
                self.iou,
                self.classes,
                self.agnostic,
                self.multi_label,
                max_det=self.max_det,
            )  # NMS
            for i in range(n):
                scale_boxes(shape1, y[i][:, :4], shape0[i])

        return Detections(ims, y, files, dt, self.names, x.shape)
Finally, a Detections object is returned, which gives us the detection box coordinates, the cropped detection images, and many other kinds of results.
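For example, the raw per-image tensors can be read straight off the object; a minimal sketch, assuming results = model(image_path) as in the usage section below:

# results.xyxy[0] is an (n, 6) tensor: xmin, ymin, xmax, ymax, confidence, class
for *box, conf, cls in results.xyxy[0]:
    print(f'{results.names[int(cls)]} {conf:.2f} at {[round(float(v), 1) for v in box]}')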
Basic Model Usage
- Show the annotated image: results.show()
- Rendered image arrays: results.render()
- Tabular data: results.pandas()
- Cropped detections: results.crop()
- Text summary: results.print()
import torch
from models.common import AutoShape, Detections


def test1():
    # Load the model
    '''
    repo_or_dir: project path; the directory must contain a hubconf.py file
    model: which entrypoint to call; pass a weights path for a custom model, or a model name for a pretrained one (here: custom)
    source: local or github; for local, repo_or_dir is the project path
    '''
    model: AutoShape = torch.hub.load('./', 'custom', './weights/yolov5x.pt', source='local')
    # model = torch.hub.load(repo_or_dir='./', model='yolov5x', source='local')
    print(type(model))
    # Image path
    image_path = './data/images/bus.jpg'
    # Inference
    results: Detections = model(image_path)
    # print(results.render()[0].shape)  # (1080, 810, 3)
    print(results)
    # results.show()  # display the annotated image
    print(results.pandas().xyxy[0])
    '''
             xmin        ymin        xmax        ymax  confidence  class    name
    0  667.667297  393.297791  809.253662  880.357971    0.937885      0  person
    1   15.557645  233.723160  799.038452  732.006592    0.925769      5     bus
    2   49.887505  396.291016  245.754440  904.234375    0.922690      0  person
    3  221.138351  409.338318  344.786987  860.277771    0.898539      0  person
    4    0.250775  550.910889   78.686089  870.606018    0.711578      0  person
    '''
    print(results.pandas().xyxyn[0])
    '''
           xmin      ymin      xmax      ymax  confidence  class    name
    0  0.824281  0.364165  0.999079  0.815146    0.937885      0  person
    1  0.019207  0.216410  0.986467  0.677784    0.925769      5     bus
    2  0.061590  0.366936  0.303401  0.837254    0.922690      0  person
    3  0.273010  0.379017  0.425663  0.796553    0.898539      0  person
    4  0.000310  0.510103  0.097143  0.806117    0.711578      0  person
    '''
    print(results.pandas().xywh[0])
    '''
          xcenter     ycenter       width      height  confidence  class    name
    0  738.460449  636.827881  141.586365  487.060181    0.937885      0  person
    1  407.298035  482.864868  783.480835  498.283447    0.925769      5     bus
    2  147.820969  650.262695  195.866943  507.943359    0.922690      0  person
    3  282.962677  634.808044  123.648636  450.939453    0.898539      0  person
    4   39.468433  710.758423   78.435310  319.695129    0.711578      0  person
    '''
    print(results.pandas().xywhn[0])
    '''
        xcenter   ycenter     width    height  confidence  class    name
    0  0.911680  0.589655  0.174798  0.450982    0.937885      0  person
    1  0.502837  0.447097  0.967260  0.461374    0.925769      5     bus
    2  0.182495  0.602095  0.241811  0.470318    0.922690      0  person
    3  0.349337  0.587785  0.152653  0.417537    0.898539      0  person
    4  0.048726  0.658110  0.096834  0.296014    0.711578      0  person
    '''
    # One dict per detected object: box, confidence, class, label, and the cropped image
    print(results.crop(save=False))
    '''
    [{'box': [tensor(0.25077, device='cuda:0'), tensor(550.91089, device='cuda:0'), tensor(78.68609, device='cuda:0'), tensor(870.60602, device='cuda:0')], 'conf': tensor(0.71158, device='cuda:0'), 'cls': tensor(0., device='cuda:0'), 'label': 'person 0.71', 'im': array([[[128, 119, 116],
        [123, 117, 112],
        [125, 119, 112],
        ...,
        [122, 152, 179],
        ...
    ]
    '''
    print(results.crop(save=False)[0]["im"])
    from PIL import Image
    # Display one cropped detection (BGR to RGB before Image.fromarray)
    Image.fromarray(results.crop(save=False)[3]["im"][:, :, ::-1]).show()
    print(str(results))
    '''
    image 1/1: 1080x810 4 persons, 1 bus
    Speed: 19.0ms pre-process, 73.3ms inference, 5.0ms NMS per image at shape (1, 3, 640, 480)
    '''
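The AutoShape wrapper also exposes the thresholds that feed the non_max_suppression call we saw in its forward method as plain attributes, so they can be tuned before inference; a short sketch, reusing model and image_path from test1:

# tune the NMS parameters read inside AutoShape.forward() before predicting
model.conf = 0.5        # confidence threshold
model.iou = 0.45        # NMS IoU threshold
model.classes = [0, 5]  # keep only 'person' and 'bus' (COCO class indices)
model.max_det = 100     # maximum detections per image
results = model(image_path)
results.print()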
Inference on Videos, Cameras, and Other Input Types
Below is a set of wrapper functions for running inference on images, videos, and camera streams, offered for reference:
# YOLOv5 PyTorch HUB Inference (DetectionModels only)
import math
import time
from threading import Thread
import platform

import cv2
import numpy as np
import torch

from utils.augmentations import letterbox

model = torch.hub.load(repo_or_dir='./', model='yolov5s', source='local')  # or yolov5n - yolov5x6 or custom

im = r'C:\Users\产品研发部-xxx\Desktop\yolov5-master\data\images\bus.jpg'
startTime = time.time()
results = model(im)  # inference
print(f'Single image took {((time.time() - startTime) * 1000):.2f}ms')

save_path = r'C:\Users\产品研发部-xxx\Desktop\yolov5-master\data\video\test.mp4'
vid = r'C:\Users\产品研发部-xxx\Desktop\yolov5-master\data\video\WIN_20240702_15_07_22_Pro.mp4'
def get_video(video_path: str, save_path: str, img_size=640):
    """
    Run inference on a video file, frame by frame.
    :param video_path: path to the input video
    :param save_path: output path for the annotated video
    :param img_size: letterbox target size, default (640, 640)
    :return:
    """
    vid_stride = 1  # frame stride, default 1
    cap = cv2.VideoCapture(video_path)  # video capture
    fps = cap.get(cv2.CAP_PROP_FPS)  # frame rate
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # frame width
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # frame height
    frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) / vid_stride)  # e.g. 321 / 1
    # video writer; the second argument selects the codec
    videoWriter = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    print(f'{frames} frames total, FPS: {fps:.2f}, width: {w}, height: {h}')
    for i in range(frames):
        for _ in range(vid_stride):  # apply the frame stride, e.g. treat every 2 recorded frames as 1
            cap.grab()  # grab the next frame; cap.read() = cap.grab() + cap.retrieve()
        ret_val, im0 = cap.retrieve()
        if not ret_val:  # read failed, or the video ended
            break
        origin_im = im0.copy()  # keep the original frame for any other processing
        startTime = time.time()
        results = model(im0, size=img_size)  # inference
        print(f'Frame {i + 1}/{frames} took {((time.time() - startTime) * 1000):.2f}ms')
        videoWriter.write(results.render()[0])
    videoWriter.release()
    cap.release()
def get_img(im, img_size=640):
    """
    Basic usage: returns the rendered (annotated) image as a numpy array.
    # For size(height=640, width=1280), RGB image example inputs are:
    #   file:        ims = 'data/images/zidane.jpg'  # str or PosixPath
    #   URI:              = 'https://ultralytics.com/images/zidane.jpg'
    #   OpenCV:           = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
    #   PIL:              = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
    #   numpy:            = np.zeros((640,1280,3))  # HWC
    #   torch:            = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
    #   multiple:         = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
    :return:
    """
    results = model(im, size=img_size)  # inference
    return results.render()[0]
# get_video(vid, save_path)
def get_stream(source='0', save_path='./test.mp4', view_img=True, img_size=640):
    """
    Handle webcam and streaming video sources.
    :param source: default '0' (webcam); can also be e.g. 'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
    :param save_path: output path for the annotated video
    :param view_img: whether to display the results live
    :param img_size: letterbox target size, default (640, 640)
    :return:
    """
    class LoadingStream:
        def __init__(self, sources="0", vid_stride=1):
            """
            :param sources: default webcam
            :param vid_stride: frame stride
            """
            self.vid_stride = vid_stride
            self.sources = eval(sources) if sources.isnumeric() else sources  # '0' -> 0; URLs stay strings
            cap = cv2.VideoCapture(self.sources)
            self.cap = cap
            self.w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # e.g. 640
            self.h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # e.g. 480
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float("inf")  # infinite stream fallback
            self.fps = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback
            _, self.img = cap.read()  # guarantee first frame
            # reader thread for this stream, run as a daemon
            print(f"1/1 {self.sources} Success ({self.frames} frames {self.w}x{self.h} at {self.fps:.2f} FPS)")
            self.thread = Thread(target=self.update, args=(cap, self.sources), daemon=True)
            # start reading frames in the background
            self.thread.start()

        def update(self, cap, stream):
            """
            :param cap: the capture object
            :param stream: stream URL, or 0 for the webcam
            :return:
            """
            n, f = 0, self.frames  # frame number, frame limit
            while cap.isOpened() and n < f:
                n += 1
                cap.grab()  # .read() = .grab() followed by .retrieve()
                if n % self.vid_stride == 0:
                    success, im = cap.retrieve()
                    if success:
                        self.img = im
                    else:
                        print("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
                        self.img = np.zeros_like(self.img)
                        cap.open(stream)  # re-open stream if signal was lost
                time.sleep(0.0)  # wait time

        def __iter__(self):
            self.count = -1
            return self

        def __next__(self):
            self.count += 1
            # press q to quit
            if not self.thread.is_alive() or cv2.waitKey(1) == ord("q"):
                cv2.destroyAllWindows()
                self.cap.release()
                raise StopIteration
            im0 = self.img.copy()
            return im0

        def __len__(self):
            """Returns the number of sources in the dataset, supporting up to 32 streams at 30 FPS over 30 years."""
            return 1  # 1E12 frames = 32 streams at 30 FPS for 30 years

    loadingStream = LoadingStream(sources=source)
    # video writer; the second argument selects the codec
    videoWriter = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), loadingStream.fps,
                                  (loadingStream.w, loadingStream.h))
    current_frame = 0
    for img in loadingStream:
        current_frame += 1
        origin_im = img.copy()  # keep the original frame for any other processing
        startTime = time.time()
        results = model(img, size=img_size)  # inference
        print(f'Frame {current_frame} took {((time.time() - startTime) * 1000):.2f}ms')
        process_img = results.render()[0]
        videoWriter.write(process_img)
        if view_img:
            name = 'Camera' if str(source) == '0' else str(source)
            if platform.system() == "Linux":
                cv2.namedWindow(name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
                cv2.resizeWindow(name, loadingStream.w, loadingStream.h)
            cv2.imshow(name, process_img)
            cv2.waitKey(1)  # 1 millisecond
    videoWriter.release()
# get_stream(source='0', save_path=save_path, view_img=True)
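Finally, a possible entry point tying the helpers together; this is a sketch only, not part of the original script (im, vid, and save_path are the placeholder paths defined above):

if __name__ == '__main__':
    from PIL import Image
    Image.fromarray(get_img(im)).show()  # single image: render and display (RGB array for a file-path input)
    get_video(vid, save_path)            # video file -> annotated mp4
    # get_stream(source='0', save_path=save_path, view_img=True)  # webcam; press q to quit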
I'll spare you all the webcam footage of me, ha, slacking off at work as I am.