import argparse
import os
import platform
import sys
from pathlib import Path
import numpy as np
import torch
from utils.augmentations import letterbox
import cv2
# Locate this file on disk and make its directory importable.
FILE = Path(__file__).resolve()
ROOT = FILE.parent  # YOLOv5 root directory (the folder holding this file)
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
# From here on ROOT is expressed relative to the current working directory.
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
from ultralytics.utils.plotting import Annotator, colors, save_one_box
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.torch_utils import select_device, smart_inference_mode
import ctypes
import platform
from pathlib import Path
import cv2
import uvicorn
from fastapi import FastAPI
def load_mode(device: str, weights: str, data: str, imgsz=(640, 640)):
    """Build the detection backend and validate the inference size.

    Args:
        device: Device specifier forwarded to ``select_device``.
        weights: Weights location forwarded to ``DetectMultiBackend``.
        data: Dataset description forwarded to ``DetectMultiBackend``.
        imgsz: Requested inference size; passed through ``check_img_size``
            together with the model stride.

    Returns:
        A ``(model, imgsz, device)`` tuple holding the backend, the value
        returned by ``check_img_size`` and the value returned by
        ``select_device``.
    """
    selected_device = select_device(device)
    backend = DetectMultiBackend(
        weights,
        device=selected_device,
        dnn=False,
        data=data,
        fp16=False,
    )
    checked_size = check_img_size(imgsz, s=backend.stride)
    return backend, checked_size, selected_device
def transform_images(im0, img_size, stride, auto):
    """Turn a loaded image into a contiguous array ready for the model.

    Args:
        im0: Source image array; must not be ``None``.
        img_size: Target size forwarded to ``letterbox``.
        stride: Stride forwarded to ``letterbox``.
        auto: ``auto`` flag forwarded to ``letterbox``.

    Returns:
        The letterboxed image with its axes reordered by ``(2, 0, 1)`` and
        its first axis reversed, stored contiguously.

    Raises:
        AssertionError: If ``im0`` is ``None``.
    """
    assert im0 is not None, f'Image is None'
    # letterbox returns several values; only the first (padded resize) is used.
    padded = letterbox(im0, img_size, stride=stride, auto=auto)[0]
    channel_first = padded.transpose((2, 0, 1))  # HWC to CHW
    channel_flipped = channel_first[::-1]  # BGR to RGB
    return np.ascontiguousarray(channel_flipped)
def process_predict(pred, im0, im, names, img_name):
    """Collect detections, draw them on the image and save the annotated copy.

    Args:
        pred: Iterable with one detection tensor per image. Every row is
            unpacked as ``(x1, y1, x2, y2, conf, cls)``.
        im0: Original image. Boxes are rescaled to ``im0.shape`` and drawn
            on it.
        im: Preprocessed batch tensor; ``im.shape[2:]`` is the size the boxes
            are rescaled from.
        names: Lookup from class index to label (mapping or sequence).
        img_name: Name used to build the output file name.

    Returns:
        A list containing one ``(label, (x1, y1, x2, y2), confidence)`` tuple
        per detection and, for every image that has at least one detection,
        the string ``'/predict_img/<img_name>_predict.jpg'``. Empty when
        nothing was detected.
    """
    line_thickness = 5
    results = []
    for det in pred:  # per image
        if not len(det):
            continue

        # Rescale boxes from the inference size to the original image size.
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

        annotator = Annotator(im0, line_width=line_thickness, example=str(names))
        for *xyxy, conf, cls in reversed(det):
            # The class arrives as a float value; cast it so the lookup also
            # works when ``names`` is a list and not only when it is a dict.
            c = int(cls)
            label = names[c]
            x1, y1, x2, y2 = xyxy
            results.append((label, (x1.item(), y1.item(), x2.item(), y2.item()), conf.item()))
            annotator.box_label(xyxy, label, color=colors(c, True))

        im0 = annotator.result()
        save_path = '/nfs/wms/profiles/predict_img/' + str(img_name) + '_predict.jpg'
        # cv2.imwrite reports failure through its return value instead of
        # raising, so surface it rather than silently returning a dead link.
        if not cv2.imwrite(save_path, im0):
            LOGGER.warning(f'Failed to write annotated image to {save_path}')
        results.append('/predict_img/' + str(img_name) + '_predict.jpg')
    print('results:', results)
    return results
# Gradient tracking is disabled inside main2 with torch.no_grad() around the
# forward pass and NMS only.
def main2(weights, data, imgsz, conf_thres, iou_thres, max_det, device, img_source):
    """Run detection on one image file and return the processed predictions.

    Args:
        weights: Model weights location, forwarded to ``load_mode``.
        data: Dataset description, forwarded to ``load_mode``.
        imgsz: Requested inference size, forwarded to ``load_mode``.
        conf_thres: Confidence threshold for ``non_max_suppression``.
        iou_thres: IoU threshold for ``non_max_suppression``.
        max_det: Maximum number of detections kept per image.
        device: Device specifier, forwarded to ``load_mode``.
        img_source: Path (``str`` or path-like) of the image to analyse.

    Returns:
        Whatever ``process_predict`` returns for this image.

    Raises:
        FileNotFoundError: If the image cannot be read from ``img_source``.
    """
    print(weights, data, device, img_source, '在模型函数里面')
    model, imgsz, device = load_mode(device, weights, data, imgsz)

    # cv2.imread expects a plain string on many OpenCV builds, while the
    # caller may hand over a pathlib.Path.
    img_path = os.fspath(img_source)
    im0 = cv2.imread(img_path)
    if im0 is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'Image not found or unreadable: {img_path}')
    img_name = os.path.basename(img_path)

    im = transform_images(im0, imgsz, model.stride, model.pt)
    im = torch.from_numpy(im).to(model.device)
    im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if im.ndim == 3:
        im = im[None]  # expand for batch dim

    # no_grad (rather than inference_mode) keeps the outputs ordinary tensors,
    # so process_predict can still rescale the boxes in place afterwards.
    with torch.no_grad():
        pred = model(im, augment=False, visualize=False)
        pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=max_det)

    return process_predict(pred, im0, im, model.names, img_name)
app = FastAPI()


# Bind the route to its view function.
@app.get("/get")
def index_get(weights, data, img_source, device):
    """Run detection for the image named in the query and return the results.

    Args:
        weights: Weights path, joined onto ``ROOT``.
        data: Dataset yaml path, joined onto ``ROOT``.
        img_source: Image path, joined onto ``ROOT``.
        device: Device specifier passed through to ``main2`` unchanged.

    Returns:
        The list produced by ``main2``.
    """
    # The settings are built directly instead of going through
    # argparse.parse_args(): inside a request handler that call would parse
    # the server process's own sys.argv and exit the process on any
    # unrecognised argument.
    # NOTE(review): the three paths come straight from the query string, and
    # joining an absolute path onto ROOT yields that absolute path, so callers
    # can point at any file the process can read — confirm this is intended.
    weights_path = ROOT / str(weights)
    data_path = ROOT / str(data)
    source_path = ROOT / str(img_source)
    imgsz = [640, 640]  # inference size h,w
    conf_thres = 0.25  # confidence threshold
    iou_thres = 0.45  # NMS IoU threshold
    max_det = 1000  # maximum detections per image

    print(weights_path, data_path, imgsz, conf_thres, iou_thres, max_det, device, source_path,
          '看看收到了没有')
    return main2(weights_path, data_path, imgsz, conf_thres, iou_thres, max_det, device, source_path)
# Built-in interactive debugging interface: http://127.0.0.1:7757/docs
if __name__ == "__main__":
    # uvicorn is given an import string "<this module's name>:app".
    app_target = f"{Path(__file__).stem}:app"
    uvicorn.run(app_target, host="0.0.0.0", port=7757)
# [Wildlife data detection and screening]
# Latest recommended article published on 2024-09-10 15:58:02