# voc2yolo.py - convert Pascal VOC XML annotations to YOLO TXT labels
import os
import xml.etree.ElementTree as ET
from typing import Dict, Tuple, List
def convert_voc_to_yolo(xml_file: str, class_mapping: Dict[str, int]) -> List[str]:
    """Convert one Pascal VOC XML annotation file into YOLO label lines.

    Each ``<object>`` whose class name appears in ``class_mapping`` becomes
    one line of the form ``"<class_id> <x_center> <y_center> <w> <h>"``,
    where the four box values are normalised by the image size read from
    ``<size>`` and formatted with six decimals.

    Args:
        xml_file: Path to the VOC XML annotation file.
        class_mapping: Mapping from class name to numeric class ID.

    Returns:
        The YOLO label lines, in the order the objects appear in the XML.
        Objects with an unmapped class name or a degenerate box are skipped
        with a printed warning.

    Raises:
        ValueError: If ``<size>`` is missing or not positive, or if a
            bounding box is missing or contains a non-numeric coordinate.
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
    """
    root = ET.parse(xml_file).getroot()

    size = root.find('size')
    if size is None:
        raise ValueError(f"文件 {xml_file} 缺少 <size> 节点")
    # Parse through float() because some annotation tools write "640.0".
    width = float(size.findtext('width', default='0'))
    height = float(size.findtext('height', default='0'))
    if width <= 0 or height <= 0:
        # Without a positive size the boxes cannot be normalised; fail with a
        # clear message instead of a ZeroDivisionError further down.
        raise ValueError(f"文件 {xml_file} 的图像尺寸无效: {width}x{height}")

    yolo_lines = []
    for obj in root.findall('object'):
        # Strip so that "<name> UAV </name>" still matches the mapping.
        class_name = (obj.findtext('name') or '').strip()
        if class_name not in class_mapping:
            print(f"警告: 未定义的类别 '{class_name}' 在文件 {xml_file} 中,已跳过")
            continue
        class_id = class_mapping[class_name]

        bbox = obj.find('bndbox')
        if bbox is None:
            raise ValueError(f"文件 {xml_file} 中的目标 '{class_name}' 缺少 <bndbox> 节点")
        corners = []
        for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
            text = bbox.findtext(tag)
            if text is None:
                raise ValueError(f"文件 {xml_file} 中的 <bndbox> 缺少 <{tag}>")
            corners.append(float(text))
        xmin, ymin, xmax, ymax = corners

        # Clamp to the image so the normalised values stay inside [0, 1].
        xmin = min(max(xmin, 0.0), width)
        xmax = min(max(xmax, 0.0), width)
        ymin = min(max(ymin, 0.0), height)
        ymax = min(max(ymax, 0.0), height)

        box_w = xmax - xmin
        box_h = ymax - ymin
        if box_w <= 0 or box_h <= 0:
            # Inverted or zero-area boxes carry no usable annotation.
            print(f"警告: 文件 {xml_file} 中存在无效的边界框,已跳过")
            continue

        x_center = (xmin + xmax) / 2 / width
        y_center = (ymin + ymax) / 2 / height
        w = box_w / width
        h = box_h / height
        yolo_lines.append(
            f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
        )
    return yolo_lines
def create_classes_file(output_dir: str, class_mapping: Dict[str, int]) -> None:
    """Write ``classes.txt`` listing the class names in ascending ID order.

    One class name is written per line. The line position only equals the
    class ID when the IDs in ``class_mapping`` are contiguous and start at 0.

    Args:
        output_dir: Directory in which ``classes.txt`` is created; it is
            created first if it does not exist.
        class_mapping: Mapping from class name to numeric class ID.
    """
    os.makedirs(output_dir, exist_ok=True)
    classes_path = os.path.join(output_dir, "classes.txt")
    # Sorting the keys by their ID is stable, so equal IDs keep dict order.
    ordered_names = sorted(class_mapping, key=class_mapping.get)
    # Explicit UTF-8 so non-ASCII class names do not depend on the OS locale.
    with open(classes_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{class_name}\n" for class_name in ordered_names)
    print(f"已创建类别文件: {classes_path}")
def process_folder(input_folder: str, output_folder: str, class_mapping: Dict[str, int]) -> None:
    """Convert every VOC XML file in a folder into a YOLO TXT label file.

    For each ``*.xml`` file directly inside ``input_folder`` a ``.txt`` file
    with the same base name is written to ``output_folder``. Files that yield
    no label lines are skipped with a warning, and a failure on one file is
    reported and does not stop the batch. Afterwards ``classes.txt`` is
    written to ``output_folder``.

    Args:
        input_folder: Directory containing the VOC XML files.
        output_folder: Directory receiving the YOLO TXT files; created if
            it does not exist.
        class_mapping: Mapping from class name to numeric class ID.
    """
    os.makedirs(output_folder, exist_ok=True)
    processed_count = 0

    # Sorted for reproducible logs; os.listdir returns entries in arbitrary order.
    for filename in sorted(os.listdir(input_folder)):
        # Compare case-insensitively so "IMG_001.XML" is not silently ignored.
        if not filename.lower().endswith('.xml'):
            continue

        xml_path = os.path.join(input_folder, filename)
        txt_filename = os.path.splitext(filename)[0] + '.txt'
        txt_path = os.path.join(output_folder, txt_filename)

        # Deliberately broad: this is the batch boundary, and one malformed
        # file should be reported without aborting the remaining files.
        try:
            yolo_lines = convert_voc_to_yolo(xml_path, class_mapping)
            if not yolo_lines:
                print(f"警告: 文件 {xml_path} 不包含有效标注,已跳过")
                continue
            with open(txt_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(yolo_lines) + '\n')
            processed_count += 1
            print(f"已转换: {xml_path} -> {txt_path}")
        except Exception as e:
            print(f"错误: 处理文件 {xml_path} 时出错: {e}")

    create_classes_file(output_folder, class_mapping)
    print(f"转换完成。共处理 {processed_count} 个XML文件。")
if __name__ == "__main__":
    # Only the "UAV" class is mapped (ID 0).
    class_mapping = {"UAV": 0}

    # Source folder with the VOC XML files and destination for the YOLO labels.
    input_folder = r"C:\Users\Virgil\Desktop\DetFly_train\lables_010"
    output_folder = r"C:\Users\Virgil\Desktop\DetFly_train\lables_yolo_010"

    process_folder(
        input_folder=input_folder,
        output_folder=output_folder,
        class_mapping=class_mapping,
    )