mmrotate实现图片的批量推理检测

在mmrotate中新建一个detect.py文件,写入以下代码即可实现对一个文件夹内的多张图片进行推理检测。

import os
from argparse import ArgumentParser
from mmdet.apis import inference_detector, init_detector, show_result_pyplot
import mmrotate  # noqa: F401

def parse_args():
    parser = ArgumentParser()
    parser.add_argument('img_dir', help='Directory of images')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('--out-dir', default='./output', help='Path to output directory')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='dota',
        choices=['dota', 'sar', 'hrsc', 'hrsc_classwise', 'random'],
        help='Color palette used for visualization')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    args = parser.parse_args()
    return args

def main(args):
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    # Create output directory if it doesn't exist
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    # Iterate through all images in the directory
    for img_file in os.listdir(args.img_dir):
        if img_file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
            img_path = os.path.join(args.img_dir, img_file)
            # test the image
            result = inference_detector(model, img_path)
            # show the results
            out_file = os.path.join(args.out_dir, img_file)
            show_result_pyplot(
                model,
                img_path,
                result,
                palette=args.palette,
                score_thr=args.score_thr,
                out_file=out_file)

if __name__ == '__main__':
    args = parse_args()
    main(args)

在终端中调用:
python demo/image_demo.py <存储图片的文件夹路径> <模型配置文件.py> <模型权重文件.pth> --out-dir <输出结果的存储文件夹路径>
示例如下:
python demo/image_demo.py data/temp/images my_demo/output/work_dir/two_stage/oriented_rcnn/train_0.01/my_oriented_rcnn_r50_fpn_1x_dota_le90.py my_demo/output/work_dir/two_stage/oriented_rcnn/train_0.01/best_mAP_epoch_14.pth --out-dir data/temp/detect

### 关于 MMROTATE 的可视化方法 MMRotate 是一个用于旋转框目标检测的开源库,提供了多种工具来进行模型训练、推理以及结果可视化。对于可视化的实现主要依赖 `mmdet` 提供的基础功能,并在此基础上进行了扩展以支持旋转矩形。 #### 使用预训练模型进行图像预测并保存结果 可以通过命令行调用官方提供的脚本完成单张图片的目标检测与可视化操作[^2]: ```bash # 下载指定配置文件对应的权重 mim download mmrotate --config oriented-rcnn-le90_r50_fpn_1x_dota --dest . # 执行检测并将结果显示在新创建的结果图上 python demo/image_demo.py demo/demo.jpg oriented-rcnn-le90_r50_fpn_1x_dota.py oriented_rcnn_r50_fpn_1x_dota_le90-6d2b2ce0.pth --out-file result.jpg ``` 上述过程会读取给定路径下的输入图片,在终端打印出识别到的对象信息的同时还会生成一张带有标注框的新图片存放在当前工作目录下名为 `result.jpg` 文件中。 #### 自定义绘图函数展示更多细节 如果希望进一步自定义视觉效果,则可以在 Python 中编写额外代码片段来处理输出数据结构中的边界框坐标等属性值。下面给出一段简单的例子说明如何利用 Matplotlib 库绘制带角度的方向边框: ```python import cv2 from matplotlib import pyplot as plt import numpy as np def draw_oriented_bbox(img, bboxes, labels=None): """ 绘制具有方向性的边界框 参数: img (str or ndarray): 输入图像路径或数组形式的数据. bboxes (list of list[float]): 形状为[N, 5]的一维列表, 每个元素表示[x_center,y_center,w,h,angle]. labels (optional[list[str]]): 对应类别标签,默认为空即不显示文字描述. 返回: fig: 包含所画图形对象的Figure实例. """ if isinstance(img, str): img = cv2.imread(img) h, w = img.shape[:2] fig, ax = plt.subplots(figsize=(w / 100., h / 100.), dpi=100) for i, bbox in enumerate(bboxes): center_x, center_y, width, height, angle = map(float, bbox) rect = ((center_x, center_y), (width, height), -np.degrees(angle)) box = cv2.boxPoints(rect).astype(np.int32) color = tuple(map(int, np.random.randint(0, 255, size=[3]))) # Draw rotated rectangle with random colors cv2.polylines(img,[box],True,color=color,thickness=2,lineType=cv2.LINE_AA) if labels is not None and len(labels)>i: label_text = f'{labels[i]}' text_size = cv2.getTextSize(label_text,cv2.FONT_HERSHEY_SIMPLEX,.7,(0,))[0] point = max(center_x-text_size[0]/2,0),max(center_y-text_size[1]-8,0) cv2.putText(img,label_text,tuple(point.astype('int')),cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7,color=color,thickness=2) ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) plt.axis('off') return fig ``` 此段程序接受任意数量的位置参数作为待渲染区域的信息源,并能够依据传入的具体数值调整最终呈现出来的样式特征。此外还可以接收可选的关键字参数用来附加分类名称至各个实体旁边以便更直观地区分不同种类的事物。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值