SAHI(Slicing Aided Hyper Inference)是一个用于对大型图像/小物体执行切片推理的计算机视觉库。它主要用于解决目标检测任务中,特别是小目标检测的问题。以下是关于 Python 的 SAHI 软件包的详细介绍:
SAHI 的主要功能和特点:
- 支持多种目标检测模型(YOLO、MMDetection等)
- 自动切片和合并结果
- 可配置的切片大小和重叠率
- 支持批量处理
- 结果可视化
SAHI 默认支持的模型类型(model_type)及其对应的模型类:
# Mapping from SAHI ``model_type`` strings (the value passed as ``model_type``
# to ``AutoDetectionModel.from_pretrained``) to the name of the detection-model
# class registered for that type.
MODEL_TYPE_TO_MODEL_CLASS_NAME = dict(
    ultralytics="UltralyticsDetectionModel",
    rtdetr="RTDetrDetectionModel",
    mmdet="MmdetDetectionModel",
    yolov5="Yolov5DetectionModel",
    detectron2="Detectron2DetectionModel",
    huggingface="HuggingfaceDetectionModel",
    torchvision="TorchVisionDetectionModel",
    yolov5sparse="Yolov5SparseDetectionModel",
    yolov8onnx="Yolov8OnnxDetectionModel",
)

# Model-name strings grouped under the Ultralytics backend.
ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"]
SAHI 的安装:
可以使用 pip 安装 SAHI 以及本文示例用到的依赖(命令中 -i 指定了腾讯云 PyPI 镜像,可按需替换或删除):
pip install opencv-python matplotlib torch ultralytics sahi -i https://mirrors.cloud.tencent.com/pypi/simple
完整示例代码如下(分别使用普通 YOLO 推理和 SAHI 切片推理检测同一张图像,并将结果并排对比):
import cv2
import matplotlib.pyplot as plt
import torch
from ultralytics import YOLO
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import read_image
def load_yolo_model(model_path, confidence_threshold=0.5):
    """Load an Ultralytics YOLO model and attach a confidence threshold to it.

    Args:
        model_path: Path to the weights file handed to ``YOLO``.
        confidence_threshold: Value stored on the returned model as its
            ``conf`` attribute. Defaults to 0.5.

    Returns:
        The ``YOLO`` instance with ``conf`` set.
    """
    model = YOLO(model_path)
    # NOTE(review): ultralytics prediction calls take ``conf`` as a keyword
    # argument; this attribute is presumably not consulted automatically, so
    # callers should forward it explicitly — confirm.
    model.conf = confidence_threshold
    return model
def load_sahi_model(model_path, confidence_threshold=0.5, device="cpu"):
    """Build a SAHI detection model that uses the Ultralytics integration.

    Args:
        model_path: Path to the weights file.
        confidence_threshold: Minimum score kept by SAHI. Defaults to 0.5.
        device: Device string handed to SAHI. Defaults to "cpu".

    Returns:
        The model object produced by ``AutoDetectionModel.from_pretrained``.
    """
    model_kwargs = {
        "model_type": "ultralytics",
        "model_path": model_path,
        "confidence_threshold": confidence_threshold,
        "device": device,
    }
    return AutoDetectionModel.from_pretrained(**model_kwargs)
def read_and_prepare_image(image_path):
    """Read an image from disk and convert it from OpenCV's BGR order to RGB.

    Args:
        image_path: Path of the image file.

    Returns:
        The image as an RGB array.

    Raises:
        SystemExit: With status 1 (after printing an error) when the file
            cannot be read.
    """
    # cv2.imread signals a missing or unreadable file by returning None rather
    # than raising, so the result has to be checked explicitly.
    image = cv2.imread(image_path)
    if image is None:
        print(f"错误:找不到图像文件 {image_path}")
        # Non-zero status so shells/callers see the failure; unlike exit(),
        # SystemExit does not depend on the ``site`` module being loaded.
        raise SystemExit(1)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def yolo_inference(yolo_model, image, device="cuda", conf=None):
    """Run plain (non-sliced) YOLO inference on an RGB image.

    Args:
        yolo_model: Ultralytics YOLO model (e.g. from ``load_yolo_model``).
        image: RGB image array of shape (H, W, 3), e.g. from
            ``read_and_prepare_image``. It is not modified.
        device: Inference device passed to the model call. Defaults to "cuda".
        conf: Confidence threshold. When None, the model's ``conf`` attribute
            (as set by ``load_yolo_model``) is used; if that is also missing,
            no threshold is passed and the Ultralytics default applies.

    Returns:
        NumPy array with one row per detection: x1, y1, x2, y2, confidence, class.
    """
    if conf is None:
        conf = getattr(yolo_model, "conf", None)

    call_kwargs = {"device": device}
    if conf is not None:
        # Setting ``model.conf`` by itself does not change the Ultralytics
        # threshold (default 0.25); it must be passed to the predict call.
        call_kwargs["conf"] = conf

    # Ultralytics interprets NumPy input as BGR (OpenCV order), but ``image``
    # is RGB, so swap the channels back. ``.copy()`` gives a contiguous array.
    bgr_image = image[:, :, ::-1].copy()
    yolo_results = yolo_model(bgr_image, **call_kwargs)
    return yolo_results[0].boxes.data.cpu().numpy()
def sahi_inference(detection_model, image, slice_height=512, slice_width=512,
                   overlap_height_ratio=0.2, overlap_width_ratio=0.2):
    """Run SAHI sliced inference and return the merged object predictions.

    Args:
        detection_model: SAHI detection model (e.g. from ``load_sahi_model``).
        image: Input image handed to ``get_sliced_prediction``.
        slice_height: Height of each slice in pixels. Defaults to 512.
        slice_width: Width of each slice in pixels. Defaults to 512.
        overlap_height_ratio: Vertical overlap between slices. Defaults to 0.2.
        overlap_width_ratio: Horizontal overlap between slices. Defaults to 0.2.

    Returns:
        The ``object_prediction_list`` of the sliced prediction result.
    """
    slicing = dict(
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
    )
    prediction = get_sliced_prediction(image, detection_model, **slicing)
    return prediction.object_prediction_list
def _draw_labeled_box(canvas, xyxy, label, color):
    """Draw a box (coordinates truncated to int) and a label above its top-left corner."""
    x1, y1, x2, y2 = map(int, xyxy)
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
    cv2.putText(canvas, label, (x1, y1 - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)


def visualize_results(image, yolo_detections, sahi_detections,
                      save_path="comparison.jpg"):
    """Draw YOLO and SAHI detections side by side, save the figure, then show it.

    Args:
        image: RGB image array; it is copied for each panel, never drawn on.
        yolo_detections: Iterable of rows ``(x1, y1, x2, y2, conf, cls)``,
            e.g. the output of ``yolo_inference``.
        sahi_detections: SAHI object predictions exposing ``bbox.to_xyxy()``,
            ``score`` and ``category.id``.
        save_path: File the comparison figure is written to.
            Defaults to "comparison.jpg".
    """
    # 1x2 grid: left = plain YOLO, right = SAHI.
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))

    image_no_sahi = image.copy()
    for *xyxy, conf, cls in yolo_detections:
        _draw_labeled_box(image_no_sahi, xyxy, f"{int(cls)}:{conf:.2f}", (0, 255, 0))
    axes[0].imshow(image_no_sahi)
    axes[0].set_title("Without SAHI")

    # The canvas is RGB, so (0, 0, 255) renders as blue.
    image_sahi = image.copy()
    for prediction in sahi_detections:
        # score may be a plain number or a wrapper object exposing ``.value``.
        score = float(prediction.score) if hasattr(prediction.score, '__float__') else prediction.score.value
        _draw_labeled_box(image_sahi, prediction.bbox.to_xyxy(),
                          f"{prediction.category.id}:{score:.2f}", (0, 0, 255))
    axes[1].imshow(image_sahi)
    axes[1].set_title("With SAHI")

    # Save before show(): a blocking show() only returns after the window is
    # closed (at which point pyplot closes the figure), and matplotlib's docs
    # recommend calling savefig first.
    fig.savefig(save_path)
    print(f"图像已保存为 {save_path}")
    plt.show()
def main(model_path=r"D:\project\w\v11\yolo11n.engine", image_path="1.jpg",
         device="cuda"):
    """Compare plain YOLO inference with SAHI sliced inference on one image.

    Args:
        model_path: Weights used for both the plain YOLO model and the SAHI
            wrapper. Defaults to the original hard-coded ``.engine`` path.
        image_path: Image to run detection on. Defaults to "1.jpg".
        device: Device used by both inference paths. Defaults to "cuda".
    """
    # Load the plain YOLO model and the SAHI wrapper from the same weights.
    yolo_model = load_yolo_model(model_path)
    detection_model = load_sahi_model(model_path, device=device)

    image = read_and_prepare_image(image_path)

    # Run both inference paths on the same RGB image.
    yolo_detections = yolo_inference(yolo_model, image, device=device)
    sahi_detections = sahi_inference(detection_model, image)

    # Visualize and save the side-by-side comparison.
    visualize_results(image, yolo_detections, sahi_detections)


if __name__ == "__main__":
    main()
额外说明
SAHI 的 Ultralytics 模型封装默认只支持 .pt 权重推理。如果想直接使用 ONNX 或 TensorRT engine 权重推理,需要修改 sahi 库中的代码(加载模型时调用的 model.to(device) 只适用于 PyTorch 模型)。
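文件位置为 <Python 环境>\Lib\site-packages\sahi\models\ultralytics.py,例如作者环境中的 E:\miniconda3\envs\yolov9\Lib\site-packages\sahi\models\ultralytics.py。可以用 python -c "import sahi, os; print(os.path.dirname(sahi.__file__))" 查看本机 sahi 的安装目录。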
修改后如下
class UltralyticsDetectionModel(DetectionModel):
    """SAHI wrapper for Ultralytics models, patched to accept exported weights.

    Ultralytics only supports ``model.to(device)`` for PyTorch (``.pt``)
    weights, so the device move is skipped for exported formats such as ONNX
    or TensorRT engines. NOTE(review): for those formats the device is
    presumably supplied per prediction call instead — confirm in
    ``perform_inference``.
    """

    def check_dependencies(self) -> None:
        """Ensure the ``ultralytics`` package is installed."""
        check_requirements(["ultralytics"])

    def load_model(self):
        """Initialize the Ultralytics model and register it via ``set_model``.

        Raises:
            TypeError: If loading or registering the model from
                ``self.model_path`` fails; the original error is chained.
        """
        from ultralytics import YOLO

        try:
            model = YOLO(self.model_path)
            # Only PyTorch weights can be moved with .to(); exported formats
            # (.onnx, .engine, ...) would raise here, so leave them as loaded.
            if str(self.model_path).lower().endswith(".pt"):
                model.to(self.device)
            self.set_model(model)
        except Exception as e:
            raise TypeError(f"model_path is not a valid Ultralytics model path: {e}") from e