# 数据增强 (data augmentation): generate augmented copies of VOC-style XML bounding-box annotations
import os
import xml.etree.ElementTree as ET
import random
# Folder holding the original VOC-style XML annotations, and the folder the
# augmented copies are written to. Adjust for your machine.
source_xml_folder = r"C:\Users\915324472\Desktop\2\xml"
target_xml_folder = r"C:\Users\915324472\Desktop\2\aug_xml"
# Number of augmented annotation files generated per source annotation.
num_augmentations_per_image = 5

# NOTE(review): this script only rewrites annotations; it never opens image
# files. The "<stem>_aug_<i>.jpg" images the new XML files point to must be
# produced elsewhere with the SAME transform parameters -- TODO confirm.

# Child tags of <bndbox>, in the order used for box tuples below.
_BOX_TAGS = ("xmin", "ymin", "xmax", "ymax")


def _clamp(value, low, high):
    """Return *value* limited to the closed interval [low, high]."""
    return max(low, min(value, high))


def _read_number(element, path):
    """Return the text at *path* under *element* as an int.

    Decimal text such as "12.0" is accepted (truncated toward zero).

    Raises:
        ValueError: if the node is missing, empty, or not numeric.
    """
    node = element.find(path)
    if node is None or node.text is None:
        raise ValueError(f"missing <{path}>")
    return int(float(node.text))


def sample_transform(width, height, rng=random):
    """Draw ONE set of geometric augmentation parameters for a whole image.

    Every box of an image must share the same flip/crop/translation;
    otherwise the labels no longer describe a single transformed image.
    Random erasing is intentionally not modelled here: it only occludes
    pixels and does not change box labels.

    Args:
        width: original image width in pixels.
        height: original image height in pixels.
        rng: object providing random(), randint() and uniform(); defaults
            to the module-level ``random`` generator.

    Returns:
        dict with keys flip_h, flip_v, crop_x, crop_y, crop_width,
        crop_height, translate_x, translate_y.
    """
    flip_h = rng.random() > 0.5
    flip_v = rng.random() > 0.5
    # Crop away up to a quarter of the image on each side.
    crop_x = rng.randint(0, width // 4)
    crop_y = rng.randint(0, height // 4)
    crop_width = width - crop_x - rng.randint(0, width // 4)
    crop_height = height - crop_y - rng.randint(0, height // 4)
    return {
        "flip_h": flip_h,
        "flip_v": flip_v,
        "crop_x": crop_x,
        "crop_y": crop_y,
        "crop_width": crop_width,
        "crop_height": crop_height,
        # Shift by up to +/-10% of the cropped size.
        "translate_x": rng.uniform(-0.1, 0.1) * crop_width,
        "translate_y": rng.uniform(-0.1, 0.1) * crop_height,
    }


def transform_bbox(box, width, height, params):
    """Map one (xmin, ymin, xmax, ymax) box through flip -> crop -> translate.

    Coordinates are clipped to the cropped canvas after the crop, and again
    after the translation. Pixels removed by the crop do not come back when
    the cropped image is shifted.

    Args:
        box: (xmin, ymin, xmax, ymax) in original-image pixels.
        width: original image width.
        height: original image height.
        params: dict produced by sample_transform().

    Returns:
        The new integer box, or None when no area of the box remains.
    """
    xmin, ymin, xmax, ymax = box
    if params["flip_h"]:
        xmin, xmax = width - xmax, width - xmin
    if params["flip_v"]:
        ymin, ymax = height - ymax, height - ymin

    crop_w, crop_h = params["crop_width"], params["crop_height"]
    # Clamp both ends: a box may lie partly or entirely outside the crop.
    xmin = _clamp(xmin - params["crop_x"], 0, crop_w)
    xmax = _clamp(xmax - params["crop_x"], 0, crop_w)
    ymin = _clamp(ymin - params["crop_y"], 0, crop_h)
    ymax = _clamp(ymax - params["crop_y"], 0, crop_h)

    dx, dy = params["translate_x"], params["translate_y"]
    xmin = _clamp(int(xmin + dx), 0, crop_w)
    xmax = _clamp(int(xmax + dx), 0, crop_w)
    ymin = _clamp(int(ymin + dy), 0, crop_h)
    ymax = _clamp(int(ymax + dy), 0, crop_h)

    if xmax <= xmin or ymax <= ymin:
        return None
    return xmin, ymin, xmax, ymax


def augment_annotation(root, stem, aug_idx, target_folder, rng=random):
    """Return an augmented copy of the annotation element *root*.

    *root* itself is not modified. A single transform is sampled and applied
    to every <object>/<bndbox>. <size> is set to the cropped size. Objects
    that end up with no visible area are removed. <filename> and <path> (when
    present) are pointed at "<stem>_aug_<aug_idx>.jpg" inside *target_folder*.

    Raises:
        ValueError: if <size>/<width>, <size>/<height> or a box coordinate is
            missing or not numeric.
    """
    width = _read_number(root, "size/width")
    height = _read_number(root, "size/height")
    params = sample_transform(width, height, rng)

    # Round-trip through bytes to get an independent deep copy of the tree.
    new_root = ET.fromstring(ET.tostring(root))
    new_root.find("size/width").text = str(params["crop_width"])
    new_root.find("size/height").text = str(params["crop_height"])

    for obj in new_root.findall("object"):
        bndbox = obj.find("bndbox")
        if bndbox is None:
            continue
        box = tuple(_read_number(bndbox, tag) for tag in _BOX_TAGS)
        new_box = transform_bbox(box, width, height, params)
        if new_box is None:
            # Cropped/shifted out of view: drop the object instead of
            # writing an inverted or zero-area box.
            new_root.remove(obj)
            continue
        for tag, value in zip(_BOX_TAGS, new_box):
            bndbox.find(tag).text = str(value)

    image_name = f"{stem}_aug_{aug_idx}.jpg"
    filename_node = new_root.find("filename")
    if filename_node is not None:
        filename_node.text = image_name
    path_node = new_root.find("path")
    if path_node is not None:
        path_node.text = os.path.join(target_folder, image_name)
    return new_root


def augment_folder(source_folder, target_folder,
                   num_augmentations=num_augmentations_per_image, rng=random):
    """Write *num_augmentations* augmented copies of every XML file.

    Files in *source_folder* ending in ".xml" (case-insensitive) are
    processed in sorted order. Each copy is written to *target_folder*
    (created if needed) as "<stem>_aug_<i>.xml". Annotations with a missing
    or non-numeric size or coordinate are skipped with a printed message.

    Returns:
        The number of XML files written.
    """
    os.makedirs(target_folder, exist_ok=True)
    written = 0
    for filename in sorted(os.listdir(source_folder)):
        if not filename.lower().endswith(".xml"):
            continue
        stem = os.path.splitext(filename)[0]
        root = ET.parse(os.path.join(source_folder, filename)).getroot()
        for aug_idx in range(num_augmentations):
            try:
                new_root = augment_annotation(root, stem, aug_idx, target_folder, rng)
            except ValueError as err:
                print(f"跳过 {filename}: {err}")
                break
            out_path = os.path.join(target_folder, f"{stem}_aug_{aug_idx}.xml")
            ET.ElementTree(new_root).write(out_path, encoding="utf-8", xml_declaration=True)
            written += 1
    return written


if __name__ == "__main__":
    augment_folder(source_xml_folder, target_xml_folder, num_augmentations_per_image)
    print("数据增强完成。")