【Wildlife Data Detection and Filtering】
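This post wraps YOLOv5 single-image inference in a small FastAPI service for screening wildlife imagery: each request loads the model, letterbox-preprocesses the input image, runs NMS on the predictions, draws the boxes with Annotator, saves the annotated image, and returns the detections as JSON.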

import argparse
import os
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from ultralytics.utils.plotting import Annotator, colors

from models.common import DetectMultiBackend
from utils.augmentations import letterbox
from utils.general import check_img_size, non_max_suppression, scale_boxes
from utils.torch_utils import select_device, smart_inference_mode


def load_model(device: str, weights: str, data: str, imgsz=(640, 640)):
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
    stride = model.stride
    imgsz = check_img_size(imgsz, s=stride)  # verify the image size is a multiple of the stride
    return model, imgsz, device
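# Note: DetectMultiBackend picks the inference backend from the weights file
# suffix (.pt, .onnx, .engine, ...), and check_img_size rounds the requested
# size up to the nearest multiple of the model stride, e.g. 650 -> 672 at stride 32.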

def transform_images(im0, img_size, stride, auto):
    assert im0 is not None, 'Image is None'
    im = letterbox(im0, img_size, stride=stride, auto=auto)[0]  # padded resize
    im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    im = np.ascontiguousarray(im)  # contiguous
    return im
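# Quick shape check for the preprocessing (hypothetical 1920x1080 frame;
# letterbox keeps the aspect ratio and pads the short side to a stride multiple):
#   im0 = np.zeros((1080, 1920, 3), dtype=np.uint8)  # dummy BGR frame
#   im = transform_images(im0, (640, 640), 32, True)
#   print(im.shape)  # (3, 384, 640)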


def process_predict(pred, im0, im, names, img_name):
    line_thickness = 5
    results = []
    annotator = Annotator(im0, line_width=line_thickness, example=str(names))
    for i, det in enumerate(pred):  # per image
        if len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
            # Write results
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)  # integer class
                label = names[c]
                x1, y1, x2, y2 = xyxy
                results.append((label, (x1.item(), y1.item(), x2.item(), y2.item()), conf.item()))
                annotator.box_label(xyxy, label, color=colors(c, True))
            im0 = annotator.result()
    cv2.imwrite('/nfs/wms/profiles/predict_img/' + str(img_name) + '_predict.jpg', im0)
    results.append('/predict_img/' + str(img_name) + '_predict.jpg')
    print('results:', results)
    return results
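# The returned list mixes per-detection tuples with the saved image path as
# its final element; an illustrative result (values made up) might look like:
#   [('tiger', (103.0, 55.0, 412.0, 380.0), 0.91),
#    ('deer', (430.0, 120.0, 610.0, 350.0), 0.78),
#    '/predict_img/sample.jpg_predict.jpg']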



@smart_inference_mode()
def main2(weights, data, imgsz, conf_thres, iou_thres, max_det, device, img_source):
    print(weights, data, device, img_source, 'inside the model function')
    model, imgsz, device = load_model(device, weights, data, imgsz)
    im0 = cv2.imread(img_source)
    img_name = os.path.split(img_source)[-1]
    im = transform_images(im0, imgsz, model.stride, model.pt)
    im = torch.from_numpy(im).to(model.device)
    im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim
    pred = model(im, augment=False, visualize=False)
    pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=max_det)
    return process_predict(pred, im0, im, model.names, img_name)
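# For quick local testing without the HTTP layer, main2 can be called directly;
# a minimal sketch with hypothetical paths (the hard-coded
# /nfs/wms/profiles/predict_img/ output directory must exist):
#   result = main2('yolov5s.pt', 'data/coco128.yaml', [640, 640],
#                  0.25, 0.45, 1000, 'cpu', 'images/sample.jpg')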



app = FastAPI()


# Bind the route to its view function
@app.get("/get")
def index_get(weights, data, img_source, device):
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / str(weights),  # e.g. yolov5s.pt
                        help='model path or triton URL')
    parser.add_argument('--data', type=str, default=ROOT / str(data), help='(optional) dataset.yaml path')  # e.g. data/coco128.yaml
    parser.add_argument('--img_source', type=str, default=ROOT / str(img_source), help='path to the input image')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640],
                        help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default=device, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    opt = parser.parse_args([])  # parse defaults only, so the server's own CLI args are not consumed
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand [640] to [640, 640]

    print(opt.weights, opt.data, opt.imgsz, opt.conf_thres, opt.iou_thres, opt.max_det, opt.device, opt.img_source,
          'checking the received parameters')

    result = main2(opt.weights, opt.data, opt.imgsz, opt.conf_thres, opt.iou_thres, opt.max_det, opt.device,
                   opt.img_source)
    return list(result)



# http://127.0.0.1:7757/docs exposes FastAPI's built-in interactive docs for debugging
if __name__ == "__main__":
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=7757)
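
Once the service is running, any HTTP client can drive it; a minimal sketch using the requests library with hypothetical query values (the endpoint resolves all paths relative to ROOT):

import requests

resp = requests.get(
    'http://127.0.0.1:7757/get',
    params={
        'weights': 'yolov5s.pt',            # hypothetical weights file under ROOT
        'data': 'data/coco128.yaml',        # hypothetical dataset yaml under ROOT
        'img_source': 'images/sample.jpg',  # hypothetical image under ROOT
        'device': 'cpu',
    },
)
print(resp.json())  # per-detection [label, [x1, y1, x2, y2], conf] entries plus the saved image path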
