背景
在目标检测领域,处理标注文件中的旋转框(rotated bounding boxes)是一个常见的任务,尤其是当对象之间存在重叠时,如何将重叠的旋转框合并为一个更紧凑的表示非常重要。本文将介绍如何编写Python脚本,通过并行化方式处理DOTA标注文件中的重叠旋转框,并将其合并为最小的旋转矩形。
我们将逐步讲解代码中的每个功能模块,包括如何解析DOTA格式标注文件、计算旋转框的重叠、并将它们合并为新的标注文件。
在前一篇博客中介绍了如何计算dota数据的iou,感兴趣的朋友请移步至
完整代码
import os
import logging
from shapely.geometry import Polygon
import numpy as np
from itertools import combinations
from concurrent.futures import ProcessPoolExecutor, as_completed
import argparse
import pandas as pd
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
handlers=[
logging.StreamHandler()
]
)
def parse_args():
"""
解析命令行参数。
"""
parser = argparse.ArgumentParser(description='对标注文件进行旋转框合并处理,检测重叠并生成新的标注文件。')
parser.add_argument('--input_folder', type=str, required=True, help='输入标注文件夹路径')
parser.add_argument('--output_folder', type=str, required=True, help='输出标注文件夹路径')
parser.add_argument('--by_class', action='store_true', default=True, help='是否按类别统计和合并重叠框')
parser.add_argument('--num_workers', type=int, default=None, help='并行处理的进程数,默认为CPU核心数')
return parser.parse_args()
def parse_annotation_file(file_path, by_class=False):
"""
解析标注文件,提取对象的类别和多边形坐标。
参数:
file_path (str): 标注文件的路径。
by_class (bool): 是否按类别解析。
返回:
dict or list: 按类别分组的字典或所有对象的列表。
"""
objects = []
try:
with open(file_path, 'r') as file:
for line_num, line in enumerate(file, 1):
parts = line.strip().split()
if len(parts) < 9:
logging.warning(f"{file_path} 第{line_num}行格式不正确,跳过。")
continue
try:
# 假设坐标为前8个元素,类别为第9个元素
coords = list(map(float, parts[:8]))
obj_class = parts[8]
# 将坐标转换为Shapely多边形
polygon = Polygon(np.array(coords).reshape(-1, 2))
if not polygon.is_valid:
logging.warning(f"{file_path} 第{line_num}行的多边形无效,跳过。")
continue
objects.append((obj_class, polygon))
except ValueError as ve:
logging.error(f"{file_path} 第{line_num}行坐标转换错误: {ve}")
except Exception as e:
logging.error(f"读取文件 {file_path} 时发生错误: {e}")
return {} if by_class else []
if by_class:
class_dict = {}
for obj_class, polygon in objects:
class_dict.setdefault(obj_class, []).append(polygon)
return class_dict
else:
return [polygon for _, polygon in objects]
def compute_iou(poly1, poly2):
"""
计算两个多边形的IOU。
参数:
poly1 (Polygon): 第一个多边形。
poly2 (Polygon): 第二个多边形。
返回:
float: 两个多边形的IOU值。
"""
intersection = poly1.intersection(poly2).area
union = poly1.union(poly2).area
if union == 0:
return 0
return intersection / union
def find_connected_components(polygons):
"""
使用并查集算法找到多边形的连通组件。
参数:
polygons (list of Polygon): 多边形列表。
返回:
list of lists: 每个子列表是一个连通组件中的多边形。
"""
parent = {i: i for i in range(len(polygons))}
def find(i):
while parent[i] != i:
parent[i] = parent[parent[i]]
i = parent[i]
return i
def union(i, j):
pi, pj = find(i), find(j)
if pi != pj:
parent[pi] = pj
# 遍历所有组合,检查是否有重叠
for i, j in combinations(range(len(polygons)), 2):
if polygons[i].intersects(polygons[j]):
union(i, j)
# 聚集连通组件
components = {}
for i in range(len(polygons)):
root = find(i)
components.setdefault(root, []).append(polygons[i])
return list(components.values())
def merge_polygons(polygons):
"""
合并一组多边形为一个最小旋转矩形。
参数:
polygons (list of Polygon): 需要合并的多边形列表。
返回:
Polygon: 合并后的多边形。
"""
# 计算所有多边形的联合
merged = polygons[0]
for poly in polygons[1:]:
merged = merged.union(poly)
# 获取最小旋转矩形
min_rot_rect = merged.minimum_rotated_rectangle
return min_rot_rect
def polygon_to_coords(polygon):
"""
将Shapely多边形转换为坐标列表。
参数:
polygon (Polygon): Shapely多边形。
返回:
list of float: 坐标列表,格式为[x1, y1, x2, y2, x3, y3, x4, y4]。
"""
coords = list(polygon.exterior.coords)
if len(coords) < 4:
# 无效的多边形
return []
# 最小旋转矩形应该有4个点
coords = coords[:4]
# 确保是4个点
if len(coords) != 4:
# 可能是线性或点状
return []
# 展开为平面坐标
flattened = [coord for point in coords for coord in point]
return flattened
def analyze_and_merge_file(file_path, output_folder, by_class=False):
"""
分析并合并单个标注文件中的重叠旋转框。
参数:
file_path (str): 输入标注文件的路径。
output_folder (str): 输出标注文件夹的路径。
by_class (bool): 是否按类别进行统计和合并。
返回:
None
"""
filename = os.path.basename(file_path)
merged_objects = []
if by_class:
class_dict = parse_annotation_file(file_path, by_class=True)
for obj_class, polygons in class_dict.items():
if len(polygons) < 2:
# 只有一个或零个对象,不需要合并
for poly in polygons:
coords = polygon_to_coords(poly)
if coords:
merged_objects.append((obj_class, coords))
continue
# 找到连通组件
components = find_connected_components(polygons)
for component in components:
merged_polygon = merge_polygons(component)
coords = polygon_to_coords(merged_polygon)
if coords:
merged_objects.append((obj_class, coords))
else:
polygons = parse_annotation_file(file_path, by_class=False)
if len(polygons) < 2:
# 只有一个或零个对象,不需要合并
for poly in polygons:
coords = polygon_to_coords(poly)
if coords:
merged_objects.append(("merged", coords))
else:
# 找到连通组件
components = find_connected_components(polygons)
for component in components:
merged_polygon = merge_polygons(component)
coords = polygon_to_coords(merged_polygon)
if coords:
merged_objects.append(("merged", coords))
# 写入新的标注文件
output_file_path = os.path.join(output_folder, filename)
try:
with open(output_file_path, 'w') as f:
for obj_class, coords in merged_objects:
coords_str = ' '.join(map(str, coords))
f.write(f"{coords_str} {obj_class} 0\n")
logging.info(f"已生成合并后的文件: {output_file_path}")
except Exception as e:
logging.error(f"写入文件 {output_file_path} 时发生错误: {e}")
def main():
args = parse_args()
input_folder = args.input_folder
output_folder = args.output_folder
by_class = args.by_class
num_workers = args.num_workers
# 创建输出文件夹,如果不存在
os.makedirs(output_folder, exist_ok=True)
# 获取所有标注文件
all_files = [os.path.join(input_folder, f) for f in os.listdir(input_folder) if os.path.isfile(os.path.join(input_folder, f))]
logging.info(f"找到 {len(all_files)} 个标注文件。")
# 准备并行处理
with ProcessPoolExecutor(max_workers=num_workers) as executor:
future_to_file = {executor.submit(analyze_and_merge_file, file_path, output_folder, by_class): file_path for file_path in all_files}
for future in as_completed(future_to_file):
file_path = future_to_file[future]
try:
future.result()
except Exception as exc:
logging.error(f"处理文件 {os.path.basename(file_path)} 时发生异常: {exc}")
logging.info("所有文件处理完毕。")
if __name__ == '__main__':
main()
代码详解
1. 配置日志和命令行参数解析
首先,我们配置了日志系统,用于记录脚本运行时的状态。日志系统可以帮助我们监控处理过程中的进展以及调试潜在问题。
import logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
handlers=[logging.StreamHandler()]
)
def parse_args():
parser = argparse.ArgumentParser(description='对标注文件进行旋转框合并处理,检测重叠并生成新的标注文件。')
parser.add_argument('--input_folder', type=str, required=True, help='输入标注文件夹路径')
parser.add_argument('--output_folder', type=str, required=True, help='输出标注文件夹路径')
parser.add_argument('--by_class', action='store_true', default=True, help='是否按类别统计和合并重叠框')
parser.add_argument('--num_workers', type=int, default=None, help='并行处理的进程数,默认为CPU核心数')
return parser.parse_args()
为了使脚本更加通用和易于使用,我们定义了 parse_args()
函数,支持用户通过命令行指定输入文件夹、输出文件夹、是否按类别进行处理等参数。
2. 解析DOTA标注文件
DOTA格式的标注文件包含多个目标对象的类别和坐标。我们通过 parse_annotation_file()
函数读取这些信息,并将每个对象的类别与其对应的多边形坐标提取出来。
def parse_annotation_file(file_path, by_class=False):
objects = []
try:
with open(file_path, 'r') as file:
for line_num, line in enumerate(file, 1):
parts = line.strip().split()
if len(parts) < 9:
logging.warning(f"{file_path} 第{line_num}行格式不正确,跳过。")
continue
coords = list(map(float, parts[:8])) # 前8个为坐标
obj_class = parts[8] # 类别
polygon = Polygon(np.array(coords).reshape(-1, 2)) # 转换为Shapely多边形
if polygon.is_valid:
objects.append((obj_class, polygon))
except Exception as e:
logging.error(f"读取文件 {file_path} 时发生错误: {e}")
if by_class:
class_dict = {}
for obj_class, polygon in objects:
class_dict.setdefault(obj_class, []).append(polygon)
return class_dict
else:
return [polygon for _, polygon in objects]
通过Shapely库,我们将坐标转换为支持几何计算的多边形对象,为后续的重叠检测和合并操作打下基础。
3. 计算IOU与查找连通组件
为了检测旋转框之间的重叠,我们需要计算两个多边形的 IOU(Intersection Over Union)。这是通过 compute_iou()
函数完成的,使用Shapely提供的 intersection()
和 union()
方法来计算交集和并集面积。
def compute_iou(poly1, poly2):
intersection = poly1.intersection(poly2).area
union = poly1.union(poly2).area
if union == 0:
return 0
return intersection / union
接下来,我们使用并查集(Union-Find)算法来查找连通组件,也就是说,找到那些相互重叠的多边形,并将它们聚集在一起,作为一个组件处理。
def find_connected_components(polygons):
parent = {i: i for i in range(len(polygons))}
def find(i):
while parent[i] != i:
parent[i] = parent[parent[i]]
i = parent[i]
return i
def union(i, j):
pi, pj = find(i), find(j)
if pi != pj:
parent[pi] = pj
for i, j in combinations(range(len(polygons)), 2):
if polygons[i].intersects(polygons[j]):
union(i, j)
components = {}
for i in range(len(polygons)):
root = find(i)
components.setdefault(root, []).append(polygons[i])
return list(components.values())
通过检查多边形是否相交,我们将所有连通的旋转框归为一组。
4. 合并多边形为最小旋转矩形
在找到重叠的连通组件后,我们需要将这些重叠的多边形合并为一个紧凑的最小旋转矩形。这是通过 merge_polygons()
函数实现的:
def merge_polygons(polygons):
merged = polygons[0]
for poly in polygons[1:]:
merged = merged.union(poly)
min_rot_rect = merged.minimum_rotated_rectangle
return min_rot_rect
minimum_rotated_rectangle
方法会生成包含所有输入多边形的最小旋转矩形。
5. 将合并后的结果写入新文件
对于每个处理后的文件,我们将合并后的多边形转换为坐标列表,并将其写入到新的标注文件中。这里就不再进行多余的代码展示了。
6. 并行化处理与主函数
为了加快处理速度,特别是在有大量标注文件时,我们使用 ProcessPoolExecutor
来并行化处理。用户可以通过命令行参数控制并行进程数。
推荐工具
在本文代码中,我们使用了以下Python库,它们在处理几何计算、多进程处理、文件解析等方面发挥了重要作用。如果你对这些库不太熟悉,可以通过以下链接获取更多信息和文档。
Shapely - 进行几何对象的构造和操作,比如多边形的交集、并集等计算。
- 官方文档:Shapely Documentation
- 安装方法:
pip install shapely
NumPy - 科学计算库,用于处理数值数组。在这里,我们用它来将多边形的坐标转换为二维数组。
- 官方文档:NumPy Documentation
- 安装方法:
pip install numpy
itertools - Python标准库中的组合工具,用于生成多边形配对,计算它们之间的IOU。
concurrent.futures - Python标准库中的并发工具,用于多进程并行处理标注文件。
结论
在本文中,我们介绍了如何使用Python脚本对DOTA格式的标注文件进行旋转框的重叠检测和合并。通过并行化处理,我们能够高效地处理大量标注文件,并将重叠的对象合并为紧凑的最小旋转矩形。
这种方法可以应用于需要处理大量标注数据的任务中,有助于提高数据标注的准确性和质量。希望这篇博客能为你在处理标注文件和重叠问题时提供有效的参考。
---
希望这篇博客对您有所帮助,如果您喜欢这篇文章,请点赞或关注,我会持续分享更多实用的 Python 技术内容!
---