1 Introduction
When training a semantic segmentation model, the dataset is often too small, so data augmentation is frequently needed.
Commonly used augmentation methods include rotation, vertical flipping, horizontal flipping, cropping, contrast adjustment, saturation adjustment, brightness adjustment, and center cropping.
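The key constraint in segmentation is that any geometric transform must be applied identically to the image and its mask, otherwise the annotation no longer lines up with the pixels. The simplest way to guarantee this is to draw the random parameters once and apply them to both through torchvision.transforms.functional, which is the pattern the full implementation in Section 2 follows. A minimal sketch of the idea (the helper name paired_random_rotate is mine, not part of the code below):

import random
import torchvision.transforms.functional as tf

def paired_random_rotate(image, mask, degrees=180):
    """Rotate a PIL image and its mask by the same random angle."""
    angle = random.uniform(-degrees, degrees)  # draw the random parameter once
    return tf.rotate(image, angle), tf.rotate(mask, angle)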
2 Code Implementation
import random
import os
import numpy as np
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as tf
class Augmentation:
    def __init__(self):
        pass

    def rotate(self, image, mask, angle=None):
        # Draw one random angle and apply it to both the image and the mask
        if angle is None:
            angle = transforms.RandomRotation.get_params([-180, 180])
        image = tf.rotate(image, angle)
        mask = tf.rotate(mask, angle)
        return image, mask

    def flip(self, image, mask):
        # Horizontal and vertical flips, each with probability 0.5
        if random.random() > 0.5:
            image = tf.hflip(image)
            mask = tf.hflip(mask)
        if random.random() > 0.5:
            image = tf.vflip(image)
            mask = tf.vflip(mask)
        return image, mask

    def randomResizeCrop(self, image, mask, scale=(0.3, 1.0), ratio=(1, 1)):
        # Crop a random region and resize it back to the original image height
        img = np.array(image)
        h_image, w_image = img.shape[:2]
        resize_size = h_image
        i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=scale, ratio=ratio)
        image = tf.resized_crop(image, i, j, h, w, resize_size)
        mask = tf.resized_crop(mask, i, j, h, w, resize_size)
        return image, mask

    def adjustContrast(self, image, mask):
        # Photometric transforms only touch the image; the mask stays unchanged
        factor = random.uniform(0.5, 1.5)
        image = tf.adjust_contrast(image, factor)
        return image, mask

    def adjustBrightness(self, image, mask):
        factor = random.uniform(0.5, 1.5)
        image = tf.adjust_brightness(image, factor)
        return image, mask

    def centerCrop(self, image, mask, size=None):
        # Note: PIL's image.size is (width, height); for non-square images
        # you may want to pass (height, width) explicitly
        if size is None:
            size = image.size
        image = tf.center_crop(image, size)
        mask = tf.center_crop(mask, size)
        return image, mask

    def adjustSaturation(self, image, mask):
        factor = random.uniform(0.5, 1.5)
        image = tf.adjust_saturation(image, factor)
        return image, mask
def augmentationData(image_path, mask_path, option=[1, 2, 3, 4, 5, 6, 7], save_dir=None):
    """
    :param image_path: directory of the original images
    :param mask_path: directory of the original masks
    :param option: which augmentations to apply
                   (1 rotate, 2 flip, 3 resized crop, 4 contrast,
                    5 center crop, 6 brightness, 7 saturation)
    :param save_dir: directory the augmented image/mask pairs are written to
    """
    aug_image_savedDir = os.path.join(save_dir, 'img')
    aug_mask_savedDir = os.path.join(save_dir, 'mask')
    if not os.path.exists(aug_image_savedDir):
        os.makedirs(aug_image_savedDir)
        print('create aug image dir.....')
    if not os.path.exists(aug_mask_savedDir):
        os.makedirs(aug_mask_savedDir)
        print('create aug mask dir.....')
    aug = Augmentation()
    # Sort both listings so each image is paired with its own mask
    images = [os.path.join(image_path, f) for f in sorted(os.listdir(image_path))]
    masks = [os.path.join(mask_path, f) for f in sorted(os.listdir(mask_path))]
    datas = list(zip(images, masks))
    num = len(datas)
    for (img_path, msk_path) in datas:
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(msk_path).convert("L")
        for opt in option:
            num += 1
            if opt == 1:
                image_aug, mask_aug = aug.rotate(image, mask)
                aug_type = 'rotate'
            elif opt == 2:
                image_aug, mask_aug = aug.flip(image, mask)
                aug_type = 'flip'
            elif opt == 3:
                image_aug, mask_aug = aug.randomResizeCrop(image, mask)
                aug_type = 'ResizeCrop'
            elif opt == 4:
                image_aug, mask_aug = aug.adjustContrast(image, mask)
                aug_type = 'Contrast'
            elif opt == 5:
                image_aug, mask_aug = aug.centerCrop(image, mask)
                aug_type = 'centerCrop'
            elif opt == 6:
                image_aug, mask_aug = aug.adjustBrightness(image, mask)
                aug_type = 'Brightness'
            elif opt == 7:
                image_aug, mask_aug = aug.adjustSaturation(image, mask)
                aug_type = 'Saturation'
            else:
                continue
            # Convert to tensors, then save the augmented pair with a unique name
            image_tensor = tf.to_tensor(image_aug)
            mask_tensor = tf.to_tensor(mask_aug)
            transforms.ToPILImage()(image_tensor).save(os.path.join(save_dir, 'img', f'{num}_{aug_type}.jpg'))
            mask_pil = transforms.ToPILImage()(mask_tensor).convert("L")
            mask_pil.save(os.path.join(save_dir, 'mask', f'{num}_{aug_type}.png'))


augmentationData(r'D:\wheat\project\tips\jpg', r'D:\wheat\project\tips\PNG',
                 save_dir=r'D:\wheat\project\tips\finished')
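The function above expands the dataset on disk. If you prefer to augment on the fly during training instead, the same Augmentation class can be wrapped in a PyTorch Dataset. The sketch below is an illustration under that assumption; the class name WheatSegDataset and the choice of one random augmentation per sample are mine, not part of the original code, and it assumes the Augmentation class defined above is in scope.

import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms.functional as tf

class WheatSegDataset(Dataset):
    """Applies one random augmentation per sample at load time
    instead of expanding the dataset on disk."""
    def __init__(self, image_dir, mask_dir, augment=True):
        self.image_paths = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir))
        self.mask_paths = sorted(os.path.join(mask_dir, f) for f in os.listdir(mask_dir))
        self.augment = augment
        self.aug = Augmentation()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")
        if self.augment:
            # Pick one augmentation at random for this sample
            op = random.choice([self.aug.rotate, self.aug.flip,
                                self.aug.adjustContrast, self.aug.adjustBrightness,
                                self.aug.adjustSaturation])
            image, mask = op(image, mask)
        return tf.to_tensor(image), tf.to_tensor(mask)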
3 Results
4 Splitting the Wheat_Seg_Mask Dataset
(1) Before splitting
(2) After splitting
(3) Code implementation
import os
import shutil
import random

# Fix the random seed so the split is reproducible
random.seed(42)

# Source directories
image_dir = 'D:\\wheat\\project\\tips\\finished\\img'
mask_dir = 'D:\\wheat\\project\\tips\\finished\\mask'

# Target directories
train_img_dir = 'D:\\wheat\\project\\tips\\Wheat_Seg_Mask\\img_dir\\train'
val_img_dir = 'D:\\wheat\\project\\tips\\Wheat_Seg_Mask\\img_dir\\val'
train_mask_dir = 'D:\\wheat\\project\\tips\\Wheat_Seg_Mask\\ann_dir\\train'
val_mask_dir = 'D:\\wheat\\project\\tips\\Wheat_Seg_Mask\\ann_dir\\val'

# Create the target directories if they do not exist
os.makedirs(train_img_dir, exist_ok=True)
os.makedirs(val_img_dir, exist_ok=True)
os.makedirs(train_mask_dir, exist_ok=True)
os.makedirs(val_mask_dir, exist_ok=True)

# Collect all file names
images = os.listdir(image_dir)
masks = os.listdir(mask_dir)

# Sort so images and masks line up one-to-one
images.sort()
masks.sort()

# The number of images must match the number of masks
assert len(images) == len(masks), "The number of images and masks does not match!"

# Shuffle image/mask pairs together
data = list(zip(images, masks))
random.shuffle(data)
images, masks = zip(*data)

# Split into training and validation sets (70% / 30%)
split_idx = int(len(images) * 0.7)
train_images, val_images = images[:split_idx], images[split_idx:]
train_masks, val_masks = masks[:split_idx], masks[split_idx:]

# Copy the files to the target directories
for img, mask in zip(train_images, train_masks):
    shutil.copy(os.path.join(image_dir, img), os.path.join(train_img_dir, img))
    shutil.copy(os.path.join(mask_dir, mask), os.path.join(train_mask_dir, mask))

for img, mask in zip(val_images, val_masks):
    shutil.copy(os.path.join(image_dir, img), os.path.join(val_img_dir, img))
    shutil.copy(os.path.join(mask_dir, mask), os.path.join(val_mask_dir, mask))

print("Dataset split complete.")