根据IOU对DOTA标注数据进行旋转框合并

背景

在目标检测领域,处理标注文件中的旋转框(rotated bounding boxes)是一个常见的任务,尤其是当对象之间存在重叠时,如何将重叠的旋转框合并为一个更紧凑的表示非常重要。本文将介绍如何编写Python脚本,通过并行化方式处理DOTA标注文件中的重叠旋转框,并将其合并为最小的旋转矩形。

我们将逐步讲解代码中的每个功能模块,包括如何解析DOTA格式标注文件、计算旋转框的重叠、并将它们合并为新的标注文件。

在前一篇博客中介绍了如何计算dota数据的iou,感兴趣的朋友请移步至

计算DOTA文件的IOU-优快云博客

完整代码

import os
import logging
from shapely.geometry import Polygon
import numpy as np
from itertools import combinations
from concurrent.futures import ProcessPoolExecutor, as_completed
import argparse
import pandas as pd

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)

def parse_args():
    """
    解析命令行参数。
    """
    parser = argparse.ArgumentParser(description='对标注文件进行旋转框合并处理,检测重叠并生成新的标注文件。')
    parser.add_argument('--input_folder', type=str, required=True, help='输入标注文件夹路径')
    parser.add_argument('--output_folder', type=str, required=True, help='输出标注文件夹路径')
    parser.add_argument('--by_class', action='store_true', default=True, help='是否按类别统计和合并重叠框')
    parser.add_argument('--num_workers', type=int, default=None, help='并行处理的进程数,默认为CPU核心数')
    return parser.parse_args()

def parse_annotation_file(file_path, by_class=False):
    """
    解析标注文件,提取对象的类别和多边形坐标。

    参数:
        file_path (str): 标注文件的路径。
        by_class (bool): 是否按类别解析。

    返回:
        dict or list: 按类别分组的字典或所有对象的列表。
    """
    objects = []
    try:
        with open(file_path, 'r') as file:
            for line_num, line in enumerate(file, 1):
                parts = line.strip().split()
                if len(parts) < 9:
                    logging.warning(f"{file_path} 第{line_num}行格式不正确,跳过。")
                    continue
                try:
                    # 假设坐标为前8个元素,类别为第9个元素
                    coords = list(map(float, parts[:8]))
                    obj_class = parts[8]
                    # 将坐标转换为Shapely多边形
                    polygon = Polygon(np.array(coords).reshape(-1, 2))
                    if not polygon.is_valid:
                        logging.warning(f"{file_path} 第{line_num}行的多边形无效,跳过。")
                        continue
                    objects.append((obj_class, polygon))
                except ValueError as ve:
                    logging.error(f"{file_path} 第{line_num}行坐标转换错误: {ve}")
    except Exception as e:
        logging.error(f"读取文件 {file_path} 时发生错误: {e}")
        return {} if by_class else []

    if by_class:
        class_dict = {}
        for obj_class, polygon in objects:
            class_dict.setdefault(obj_class, []).append(polygon)
        return class_dict
    else:
        return [polygon for _, polygon in objects]

def compute_iou(poly1, poly2):
    """
    计算两个多边形的IOU。

    参数:
        poly1 (Polygon): 第一个多边形。
        poly2 (Polygon): 第二个多边形。

    返回:
        float: 两个多边形的IOU值。
    """
    intersection = poly1.intersection(poly2).area
    union = poly1.union(poly2).area
    if union == 0:
        return 0
    return intersection / union

def find_connected_components(polygons):
    """
    使用并查集算法找到多边形的连通组件。

    参数:
        polygons (list of Polygon): 多边形列表。

    返回:
        list of lists: 每个子列表是一个连通组件中的多边形。
    """
    parent = {i: i for i in range(len(polygons))}

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    def union(i, j):
        pi, pj = find(i), find(j)
        if pi != pj:
            parent[pi] = pj

    # 遍历所有组合,检查是否有重叠
    for i, j in combinations(range(len(polygons)), 2):
        if polygons[i].intersects(polygons[j]):
            union(i, j)

    # 聚集连通组件
    components = {}
    for i in range(len(polygons)):
        root = find(i)
        components.setdefault(root, []).append(polygons[i])

    return list(components.values())

def merge_polygons(polygons):
    """
    合并一组多边形为一个最小旋转矩形。

    参数:
        polygons (list of Polygon): 需要合并的多边形列表。

    返回:
        Polygon: 合并后的多边形。
    """
    # 计算所有多边形的联合
    merged = polygons[0]
    for poly in polygons[1:]:
        merged = merged.union(poly)
    # 获取最小旋转矩形
    min_rot_rect = merged.minimum_rotated_rectangle
    return min_rot_rect

def polygon_to_coords(polygon):
    """
    将Shapely多边形转换为坐标列表。

    参数:
        polygon (Polygon): Shapely多边形。

    返回:
        list of float: 坐标列表,格式为[x1, y1, x2, y2, x3, y3, x4, y4]。
    """
    coords = list(polygon.exterior.coords)
    if len(coords) < 4:
        # 无效的多边形
        return []
    # 最小旋转矩形应该有4个点
    coords = coords[:4]
    # 确保是4个点
    if len(coords) != 4:
        # 可能是线性或点状
        return []
    # 展开为平面坐标
    flattened = [coord for point in coords for coord in point]
    return flattened

def analyze_and_merge_file(file_path, output_folder, by_class=False):
    """
    分析并合并单个标注文件中的重叠旋转框。

    参数:
        file_path (str): 输入标注文件的路径。
        output_folder (str): 输出标注文件夹的路径。
        by_class (bool): 是否按类别进行统计和合并。

    返回:
        None
    """
    filename = os.path.basename(file_path)
    merged_objects = []

    if by_class:
        class_dict = parse_annotation_file(file_path, by_class=True)
        for obj_class, polygons in class_dict.items():
            if len(polygons) < 2:
                # 只有一个或零个对象,不需要合并
                for poly in polygons:
                    coords = polygon_to_coords(poly)
                    if coords:
                        merged_objects.append((obj_class, coords))
                continue

            # 找到连通组件
            components = find_connected_components(polygons)
            for component in components:
                merged_polygon = merge_polygons(component)
                coords = polygon_to_coords(merged_polygon)
                if coords:
                    merged_objects.append((obj_class, coords))
    else:
        polygons = parse_annotation_file(file_path, by_class=False)
        if len(polygons) < 2:
            # 只有一个或零个对象,不需要合并
            for poly in polygons:
                coords = polygon_to_coords(poly)
                if coords:
                    merged_objects.append(("merged", coords))
        else:
            # 找到连通组件
            components = find_connected_components(polygons)
            for component in components:
                merged_polygon = merge_polygons(component)
                coords = polygon_to_coords(merged_polygon)
                if coords:
                    merged_objects.append(("merged", coords))

    # 写入新的标注文件
    output_file_path = os.path.join(output_folder, filename)
    try:
        with open(output_file_path, 'w') as f:
            for obj_class, coords in merged_objects:
                coords_str = ' '.join(map(str, coords))
                f.write(f"{coords_str} {obj_class} 0\n")
        logging.info(f"已生成合并后的文件: {output_file_path}")
    except Exception as e:
        logging.error(f"写入文件 {output_file_path} 时发生错误: {e}")

def main():
    args = parse_args()
    input_folder = args.input_folder
    output_folder = args.output_folder
    by_class = args.by_class
    num_workers = args.num_workers

    # 创建输出文件夹,如果不存在
    os.makedirs(output_folder, exist_ok=True)

    # 获取所有标注文件
    all_files = [os.path.join(input_folder, f) for f in os.listdir(input_folder) if os.path.isfile(os.path.join(input_folder, f))]
    logging.info(f"找到 {len(all_files)} 个标注文件。")

    # 准备并行处理
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        future_to_file = {executor.submit(analyze_and_merge_file, file_path, output_folder, by_class): file_path for file_path in all_files}
        for future in as_completed(future_to_file):
            file_path = future_to_file[future]
            try:
                future.result()
            except Exception as exc:
                logging.error(f"处理文件 {os.path.basename(file_path)} 时发生异常: {exc}")

    logging.info("所有文件处理完毕。")

if __name__ == '__main__':
    main()

代码详解

1. 配置日志和命令行参数解析

首先,我们配置了日志系统,用于记录脚本运行时的状态。日志系统可以帮助我们监控处理过程中的进展以及调试潜在问题。

import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler()]
)


def parse_args():
    parser = argparse.ArgumentParser(description='对标注文件进行旋转框合并处理,检测重叠并生成新的标注文件。')
    parser.add_argument('--input_folder', type=str, required=True, help='输入标注文件夹路径')
    parser.add_argument('--output_folder', type=str, required=True, help='输出标注文件夹路径')
    parser.add_argument('--by_class', action='store_true', default=True, help='是否按类别统计和合并重叠框')
    parser.add_argument('--num_workers', type=int, default=None, help='并行处理的进程数,默认为CPU核心数')
    return parser.parse_args()

为了使脚本更加通用和易于使用,我们定义了 parse_args() 函数,支持用户通过命令行指定输入文件夹、输出文件夹、是否按类别进行处理等参数。

2. 解析DOTA标注文件

DOTA格式的标注文件包含多个目标对象的类别和坐标。我们通过 parse_annotation_file() 函数读取这些信息,并将每个对象的类别与其对应的多边形坐标提取出来。

def parse_annotation_file(file_path, by_class=False):
    objects = []
    try:
        with open(file_path, 'r') as file:
            for line_num, line in enumerate(file, 1):
                parts = line.strip().split()
                if len(parts) < 9:
                    logging.warning(f"{file_path} 第{line_num}行格式不正确,跳过。")
                    continue
                coords = list(map(float, parts[:8]))  # 前8个为坐标
                obj_class = parts[8]  # 类别
                polygon = Polygon(np.array(coords).reshape(-1, 2))  # 转换为Shapely多边形
                if polygon.is_valid:
                    objects.append((obj_class, polygon))
    except Exception as e:
        logging.error(f"读取文件 {file_path} 时发生错误: {e}")
    if by_class:
        class_dict = {}
        for obj_class, polygon in objects:
            class_dict.setdefault(obj_class, []).append(polygon)
        return class_dict
    else:
        return [polygon for _, polygon in objects]

通过Shapely库,我们将坐标转换为支持几何计算的多边形对象,为后续的重叠检测和合并操作打下基础。

3. 计算IOU与查找连通组件

为了检测旋转框之间的重叠,我们需要计算两个多边形的 IOU(Intersection Over Union)。这是通过 compute_iou() 函数完成的,使用Shapely提供的 intersection()union() 方法来计算交集和并集面积。

def compute_iou(poly1, poly2):
    intersection = poly1.intersection(poly2).area
    union = poly1.union(poly2).area
    if union == 0:
        return 0
    return intersection / union

接下来,我们使用并查集(Union-Find)算法来查找连通组件,也就是说,找到那些相互重叠的多边形,并将它们聚集在一起,作为一个组件处理。

def find_connected_components(polygons):
    parent = {i: i for i in range(len(polygons))}

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    def union(i, j):
        pi, pj = find(i), find(j)
        if pi != pj:
            parent[pi] = pj

    for i, j in combinations(range(len(polygons)), 2):
        if polygons[i].intersects(polygons[j]):
            union(i, j)

    components = {}
    for i in range(len(polygons)):
        root = find(i)
        components.setdefault(root, []).append(polygons[i])

    return list(components.values())

通过检查多边形是否相交,我们将所有连通的旋转框归为一组。

4. 合并多边形为最小旋转矩形

在找到重叠的连通组件后,我们需要将这些重叠的多边形合并为一个紧凑的最小旋转矩形。这是通过 merge_polygons() 函数实现的:

def merge_polygons(polygons):
    merged = polygons[0]
    for poly in polygons[1:]:
        merged = merged.union(poly)
    min_rot_rect = merged.minimum_rotated_rectangle
    return min_rot_rect

minimum_rotated_rectangle 方法会生成包含所有输入多边形的最小旋转矩形。

5. 将合并后的结果写入新文件

对于每个处理后的文件,我们将合并后的多边形转换为坐标列表,并将其写入到新的标注文件中。这里就不再进行多余的代码展示了。

6. 并行化处理与主函数

为了加快处理速度,特别是在有大量标注文件时,我们使用 ProcessPoolExecutor 来并行化处理。用户可以通过命令行参数控制并行进程数。

推荐工具

在本文代码中,我们使用了以下Python库,它们在处理几何计算、多进程处理、文件解析等方面发挥了重要作用。如果你对这些库不太熟悉,可以通过以下链接获取更多信息和文档。

  1. Shapely - 进行几何对象的构造和操作,比如多边形的交集、并集等计算。

  2. NumPy - 科学计算库,用于处理数值数组。在这里,我们用它来将多边形的坐标转换为二维数组。

  3. itertools - Python标准库中的组合工具,用于生成多边形配对,计算它们之间的IOU。

  4. concurrent.futures - Python标准库中的并发工具,用于多进程并行处理标注文件。

结论

在本文中,我们介绍了如何使用Python脚本对DOTA格式的标注文件进行旋转框的重叠检测和合并。通过并行化处理,我们能够高效地处理大量标注文件,并将重叠的对象合并为紧凑的最小旋转矩形。

这种方法可以应用于需要处理大量标注数据的任务中,有助于提高数据标注的准确性和质量。希望这篇博客能为你在处理标注文件和重叠问题时提供有效的参考。

---

希望这篇博客对您有所帮助,如果您喜欢这篇文章,请点赞或关注,我会持续分享更多实用的 Python 技术内容!

---

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值