下面是模型训练时使用到的数据增强策略,包括几何变换,可以模拟低光照、恶劣天气场景,可以将每个数据增强方法的p都设置为1,逐步观察该方法的增强效果。照一张想要数据增强的图片,放到dataaug文件夹下即可。
import warnings warnings.filterwarnings('ignore') import os, shutil, cv2, tqdm import numpy as np np.random.seed(0) import albumentations as A from PIL import Image IMAGE_PATH = 'dataaug' AUG_IMAGE_PATH = 'results' SHOW_SAVE_PATH = 'results' ENHANCEMENT_LOOP = 1 #数据增强策略 ENHANCEMENT_STRATEGY = A.Compose([ #几何翻转 A.Compose([ #仿射变换包括缩放、平移、旋转、剪切等,同时可以保持图像中的平行线和平行关系 A.Affine(scale=[0.5, 1.5], translate_percent=[0.0, 0.3], rotate=[-360, 360], shear=[-45, 45], keep_ratio=True, cval_mask=0, p=1), #-scale在x,y轴上缩放 #-translate_percent平移的百分比(相对于图像的宽度和高度) #-rotate旋转角度范围,单位是度 #-shear剪切变换会使图像在水平或垂直方向上发生倾斜(类似斜切效果) #keep_ratio如果设置为`True`,则在缩放时保持图像的宽高比不变(即x和y轴使用相同的缩放因子) #cval_mask=0当进行变换时,图像边界外区域的填充值(对于掩模mask A.BBoxSafeRandomCrop(erosion_rate=0.2, p=0.1), # 随机裁剪图像的同时确保边界框(Bounding Boxes)不会被过度裁剪或丢失 A.D4(p=0.1), #在数学中,**D4群**(二面体群)是指正方形对称的变换群,包含以下8种操作: # 1. 恒等变换(0°旋转)2. 90°旋转3. 180°旋转4. 270°旋转5. 水平翻转6. 垂直翻7. 主对角线翻转8. 反对角线翻转 #`A.D4`操作会从这8种变换中**随机选择一种**应用于图像。 A.ElasticTransform(p=1), #随机地扭曲图像像素,使其在图像上产生弹性变形的效果 A.HorizontalFlip(p=0.05), # 水平翻转 A.VerticalFlip(p=0.05), # 垂直翻转 A.GridDistortion(p=1),#以网格状的方式扭曲图像 A.Perspective(p=1),#该变换会使图像产生透视失真,就像从不同的角度观察图像一样,直线可能会弯曲 ], p=1.0), A.Compose([ A.GaussNoise(p=1), # 添加高斯噪声后,图像会出现颗粒状的随机点,类似于老式相机在暗光条件下拍摄的照片或电视静态噪声。 # 高斯噪声应用场景: # - 低光照条件下的图像处理(如监控、自动驾驶夜视系统)。 # - 医学影像(如X光、MRI)中模拟噪声以提高模型鲁棒性。 # - 任何需要模型对噪声不敏感的任务。 A.ISONoise(p=1), # 模拟相机高感光度(ISO)设置产生的图像噪声的数据增强操作 #-ISO100-400光线充足,800-1600轻度 彩色噪点室内环境,3200-6400黄昏弱光,明显颗粒噪点;12800+极暗环境,严重噪点 +色偏 A.ImageCompression(quality_lower=0, quality_upper=50, p=1),#用于模拟JPEG图像压缩效果,降低图片质量 A.RandomBrightnessContrast(p=1), #用来随机调整图像的亮度和对比度的 A.RandomFog(p=1), # 专门用于在图像上模拟雾天效果 A.RandomRain(p=1), # 模拟雨 A.RandomSnow(p=1), # 模拟雪 A.RandomShadow(p=1), # 模拟自然光线产生的阴影 A.RandomSunFlare(p=1), # 图像中模拟太阳光晕(镜头光晕)效果,通常出现在逆光拍摄时太阳直射镜头的情况。 A.ToGray(p=1), # 使用加权平均法将RGB图像转换为灰度图像 gray = 0.299 * R + 0.587 * G + 0.114 * B, 这个权重基于人眼对不同颜色的敏感度(绿色最高,红色次之,蓝色最低) ], p=1.0) ], is_check_shapes=False) def data_aug_single(images_name): file_heads, postfix = os.path.splitext(images_name) images_path = os.path.join(IMAGE_PATH, images_name) if os.path.exists(images_path): # 读取原始图片并转换为OpenCV格式 original_pil = Image.open(images_path) original_cv = cv2.cvtColor(np.array(original_pil), cv2.COLOR_RGB2BGR) for i in range(ENHANCEMENT_LOOP): # 增强后图片保存路径 aug_image_path = os.path.join(AUG_IMAGE_PATH, f'{file_heads}_{i + 1:03d}{postfix}') # 拼接后图片保存路径 combined_path = os.path.join(AUG_IMAGE_PATH, f'{file_heads}_combined_{i + 1:03d}{postfix}') try: # 应用数据增强 transformed = ENHANCEMENT_STRATEGY(image=np.array(original_pil)) transformed_image = transformed['image'] # 将增强图片转换为OpenCV格式 aug_cv = cv2.cvtColor(transformed_image, cv2.COLOR_RGB2BGR) # 保存增强图片 cv2.imwrite(aug_image_path, aug_cv) # 创建拼接图片(水平拼接) combined_image = np.hstack((original_cv, aug_cv)) # 添加分隔线 separator = np.zeros((original_cv.shape[0], 5, 3), dtype=np.uint8) separator.fill(255) # 白色分隔线 combined_image = np.hstack((original_cv, separator, aug_cv)) # 添加文字标签 font = cv2.FONT_HERSHEY_SIMPLEX cv2.putText(combined_image, "Original", (10, 30), font, 1, (0, 255, 0), 2) cv2.putText(combined_image, "Augmented", (original_cv.shape[1] + 15, 30), font, 1, (0, 255, 0), 2) # 保存拼接图片 cv2.imwrite(combined_path, combined_image) print(f'拼接图片保存成功: {combined_path}') except Exception as e: print(f"处理图片 {images_name} 时出错: {str(e)}") continue def data_aug(): if os.path.exists(AUG_IMAGE_PATH): shutil.rmtree(AUG_IMAGE_PATH) os.makedirs(AUG_IMAGE_PATH, exist_ok=True) for images_name in tqdm.tqdm(os.listdir(IMAGE_PATH)): data_aug_single(images_name) if __name__ == '__main__': #show_labels(IMAGE_PATH) #show_labels(AUG_IMAGE_PATH, AUG_LABEL_PATH) data_aug()