使用SAHI提升YOLO模型对小目标的检测效果

请添加图片描述

SAHI(Slicing Aided Hyper Inference)是一个用于对大型图像/小物体执行切片推理的计算机视觉库。它主要用于解决目标检测任务中,特别是小目标检测的问题。以下是关于 Python 的 SAHI 软件包的详细介绍:

SAHI 的主要功能和特点:

2. SAHI的主要特点

  • 支持多种目标检测模型(YOLO、MMDetection等)
  • 自动切片和合并结果
  • 可配置的切片大小和重叠率
  • 支持批量处理
  • 结果可视化

默认支持的权重格式

MODEL_TYPE_TO_MODEL_CLASS_NAME = {
    "ultralytics": "UltralyticsDetectionModel",
    "rtdetr": "RTDetrDetectionModel",
    "mmdet": "MmdetDetectionModel",
    "yolov5": "Yolov5DetectionModel",
    "detectron2": "Detectron2DetectionModel",
    "huggingface": "HuggingfaceDetectionModel",
    "torchvision": "TorchVisionDetectionModel",
    "yolov5sparse": "Yolov5SparseDetectionModel",
    "yolov8onnx": "Yolov8OnnxDetectionModel",
}

ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"]

SAHI 的安装:

可以使用 pip 安装 SAHI:

pip install opencv-python matplotlib torch ultralytics sahi -i https://mirrors.cloud.tencent.com/pypi/simple

代码如下

import cv2
import matplotlib.pyplot as plt
import torch
from ultralytics import YOLO
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import read_image

def load_yolo_model(model_path, confidence_threshold=0.5):
    """
    加载YOLO模型
    Args:
        model_path: 模型路径
        confidence_threshold: 置信度阈值,默认0.5
    Returns:
        加载好的YOLO模型
    """
    yolo_model = YOLO(model_path)
    yolo_model.conf = confidence_threshold
    return yolo_model

def load_sahi_model(model_path, confidence_threshold=0.5, device="cpu"):
    """
    加载SAHI模型
    Args:
        model_path: 模型路径
        confidence_threshold: 置信度阈值,默认0.5
        device: 运行设备,默认CPU
    Returns:
        加载好的SAHI模型
    """
    return AutoDetectionModel.from_pretrained(
        model_type="ultralytics",
        model_path=model_path,
        confidence_threshold=confidence_threshold,
        device=device,
    )

def read_and_prepare_image(image_path):
    """
    读取并准备图像
    Args:
        image_path: 图像路径
    Returns:
        RGB格式的图像数组
    """
    try:
        image = cv2.imread(image_path)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转换为RGB格式
    except FileNotFoundError:
        print(f"错误:找不到图像文件 {image_path}")
        exit()

def yolo_inference(yolo_model, image, device="cuda"):
    """
    使用YOLO模型进行推理
    Args:
        yolo_model: YOLO模型
        image: 输入图像
        device: 运行设备,默认CUDA
    Returns:
        检测结果数组
    """
    yolo_results = yolo_model(image, device=device)
    return yolo_results[0].boxes.data.cpu().numpy()

def sahi_inference(detection_model, image, slice_height=512, slice_width=512, 
                  overlap_height_ratio=0.2, overlap_width_ratio=0.2):
    """
    使用SAHI进行切片推理
    Args:
        detection_model: 检测模型
        image: 输入图像
        slice_height: 切片高度
        slice_width: 切片宽度
        overlap_height_ratio: 高度方向重叠比例
        overlap_width_ratio: 宽度方向重叠比例
    Returns:
        SAHI检测结果列表
    """
    sahi_result = get_sliced_prediction(
        image,
        detection_model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
    )
    return sahi_result.object_prediction_list

def visualize_results(image, yolo_detections, sahi_detections):
    """
    可视化检测结果
    Args:
        image: 原始图像
        yolo_detections: YOLO检测结果
        sahi_detections: SAHI检测结果
    """
    # 创建1x2的图像网格
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))

    # 绘制不使用SAHI的结果
    image_no_sahi = image.copy()
    for *xyxy, conf, cls in yolo_detections:
        x1, y1, x2, y2 = map(int, xyxy)
        # 绘制边界框
        cv2.rectangle(image_no_sahi, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # 添加类别和置信度文本
        cv2.putText(image_no_sahi, f"{int(cls)}:{conf:.2f}", (x1, y1 - 10), 
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    axes[0].imshow(image_no_sahi)
    axes[0].set_title("Without SAHI")

    # 绘制使用SAHI的结果
    image_sahi = image.copy()
    for prediction in sahi_detections:
        bbox = prediction.bbox.to_xyxy()
        x1, y1, x2, y2 = map(int, bbox)
        # 确保score是浮点数
        score = float(prediction.score) if hasattr(prediction.score, '__float__') else prediction.score.value
        # 绘制边界框
        cv2.rectangle(image_sahi, (x1, y1), (x2, y2), (0, 0, 255), 2)
        # 添加类别和置信度文本
        cv2.putText(image_sahi, f"{prediction.category.id}:{score:.2f}", (x1, y1 - 10), 
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)

    axes[1].imshow(image_sahi)
    axes[1].set_title("With SAHI")

    plt.show()
    fig.savefig("comparison.jpg")
    print("图像已保存为 comparison.jpg")

def main():
    """
    主函数
    """
    # 加载YOLO和SAHI模型
    yolo_model = load_yolo_model(r"D:\project\w\v11\yolo11n.engine")
    detection_model = load_sahi_model(r"D:\project\w\v11\yolo11n.engine", device="cuda")

    # 读取待检测图像
    image = read_and_prepare_image("1.jpg")

    # 分别使用YOLO和SAHI进行推理
    yolo_detections = yolo_inference(yolo_model, image, device="cuda")
    sahi_detections = sahi_inference(detection_model, image)

    # 可视化并保存结果
    visualize_results(image, yolo_detections, sahi_detections)

if __name__ == "__main__":
    main()

额外说明

默认只支持pt推理,如果想要直接使用onnx和engine推理,需要修改sahi库中代码
位置在 E:\miniconda3\envs\yolov9\Lib\site-packages\sahi\models\ultralytics.py中
修改后如下


class UltralyticsDetectionModel(DetectionModel):
    def check_dependencies(self) -> None:
        check_requirements(["ultralytics"])

    def load_model(self):
        """
        Detection model is initialized and set to self.model.
        """

        from ultralytics import YOLO

        try:
            model = YOLO(self.model_path)
            # model.to(self.device)
            self.set_model(model)
        except Exception as e:
            raise TypeError("model_path is not a valid Ultralytics model path: ", e)
03-30
### Sahi简介及其在IT工具和框架中的定位 Sahi是一种用于Web应用程序测试的开源自动化工具[^3]。它通过模拟真实用户的操作行为来验证网站的功能性和稳定性,从而帮助企业快速发现并修复潜在问题。与ChangeAgent不同的是,Sahi专注于功能测试而非文件链接管理[^1]。 #### 功能特点 Sahi的主要特性包括但不限于以下几点: - **跨浏览器支持**:能够兼容多种主流浏览器(如Chrome、Firefox等),确保应用在不同环境下的表现一致性。 - **易于编写脚本**:提供简单直观的API接口以及JavaScript作为主要编程语言,降低学习成本的同时提高了开发效率。 - **无插件运行模式**:无需安装额外扩展程序即可完成大部分任务,简化部署流程。 以下是利用Python调用selenium库实现类似功能的一个例子: ```python from selenium import webdriver driver = webdriver.Chrome() driver.get("http://example.com") element = driver.find_element_by_id("loginButton") if element.is_displayed(): print("Login button found.") else: print("Element not visible.") driver.quit() ``` 尽管上述代码片段展示的是基于 Selenium 的解决方案,但它可以用来对比说明 Sahi 如何更高效地处理某些特定场景下的交互逻辑[^4]。 #### 应用价值 对于关注于提升产品质量的企业而言,采用合适的测试手段至关重要[^2]。而像Sahi这样的专用框架则能有效缩短迭代周期时间,并减少因人为失误引发的风险概率。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值