Data augmentation methods and code for semantic segmentation

# -*- coding:utf-8 -*-
"""数据增强
   1. 翻转变换 flip
   2. 随机修剪 random crop
   3. 色彩抖动 color jittering
   4. 平移变换 shift
   5. 尺度变换 scale
   6. 对比度变换 contrast
   7. 噪声扰动 noise
   8. 旋转变换/反射变换 Rotation/reflection
"""
 
from PIL import Image, ImageEnhance, ImageOps, ImageFile
import numpy as np
import random
import threading, os, time
import logging
 
logger = logging.getLogger(__name__)
ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
class DataAugmentation:
    """
    包含数据增强的八种方式
    """
 
 
    def __init__(self):
        pass
 
    @staticmethod
    def openImage(image):
        return Image.open(image, mode="r")
 
    @staticmethod
    def randomRotation(image, label, mode=Image.BICUBIC):
        """
         对图像进行随机任意角度(0~360度)旋转
        :param mode 邻近插值,双线性插值,双三次B样条插值(default)
        :param image PIL的图像image
        :return: 旋转转之后的图像
        """
        random_angle = np.random.randint(1, 360)
        return image.rotate(random_angle, mode) , label.rotate(random_angle, Image.NEAREST)
 
    # This function is currently unused.
    @staticmethod
    def randomCrop(image, label):
        """
        Randomly crop the image. Given an image size of (68, 68), a window
        larger than 36x36 is used for the crop.
        :param image: PIL image
        :param label: PIL label/mask image
        :return: the cropped image and label
        """
        image_width = image.size[0]
        image_height = image.size[1]
        crop_win_size = np.random.randint(40, 68)
        random_region = (
            (image_width - crop_win_size) >> 1, (image_height - crop_win_size) >> 1,
            (image_width + crop_win_size) >> 1, (image_height + crop_win_size) >> 1)
        # Crop the label with the same region so that image and mask stay aligned.
        return image.crop(random_region), label.crop(random_region)
 
    @staticmethod
    def randomColor(image, label):
        """
        对图像进行颜色抖动
        :param image: PIL的图像image
        :return: 有颜色色差的图像image
        """
        random_factor = np.random.randint(0, 31) / 10.  # 随机因子
        color_image = ImageEnhance.Color(image).enhance(random_factor)  # 调整图像的饱和度
        random_factor = np.random.randint(10, 21) / 10.  # 随机因子
        brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # 调整图像的亮度
        random_factor = np.random.randint(10, 21) / 10.  # 随机因1子
        contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # 调整图像对比度
        random_factor = np.random.randint(0, 31) / 10.  # 随机因子
        return ImageEnhance.Sharpness(contrast_image).enhance(random_factor) ,label # 调整图像锐度
 
    @staticmethod
    def randomGaussian(image, label, mean=0.2, sigma=0.3):
        """
         对图像进行高斯噪声处理
        :param image:
        :return:
        """
 
        def gaussianNoisy(im, mean=0.2, sigma=0.3):
            """
            对图像做高斯噪音处理
            :param im: 单通道图像
            :param mean: 偏移量
            :param sigma: 标准差
            :return:
            """
            for _i in range(len(im)):
                im[_i] += random.gauss(mean, sigma)
            return im
 
        # 将图像转化成数组
        img = np.asarray(image)
        img.flags.writeable = True  # 将数组改为读写模式
        width, height = img.shape[:2]
        img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma)
        img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma)
        img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma)
        img[:, :, 0] = img_r.reshape([width, height])
        img[:, :, 1] = img_g.reshape([width, height])
        img[:, :, 2] = img_b.reshape([width, height])
        return Image.fromarray(np.uint8(img)), label
 
    @staticmethod
    def saveImage(image, path):
        image.save(path)
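
    # The module docstring lists flip among the augmentation methods, but the
    # original script does not implement it. The method below is an added sketch
    # (not part of the original code) showing how a flip can be applied to the
    # image and its label in sync so that the mask stays aligned.
    @staticmethod
    def randomFlip(image, label):
        """Randomly flip the image and label horizontally with probability 0.5."""
        if random.random() < 0.5:
            return image.transpose(Image.FLIP_LEFT_RIGHT), label.transpose(Image.FLIP_LEFT_RIGHT)
        return image, label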
 
 
def makeDir(path):
    try:
        if not os.path.exists(path):
            if not os.path.isfile(path):
                os.makedirs(path)
            return 0
        else:
            return 1
    except Exception as e:
        print(str(e))
        return -2
 
 
def imageOps(func_name, image, label, img_des_path, label_des_path, img_file_name, label_file_name, times=5):
    funcMap = {"randomRotation": DataAugmentation.randomRotation,
               "randomCrop": DataAugmentation.randomCrop,
               "randomColor": DataAugmentation.randomColor,
               "randomGaussian": DataAugmentation.randomGaussian
               }
    if funcMap.get(func_name) is None:
        logger.error("%s does not exist", func_name)
        return -1

    for _i in range(0, times, 1):
        new_image, new_label = funcMap[func_name](image, label)
        DataAugmentation.saveImage(new_image, os.path.join(img_des_path, func_name + str(_i) + img_file_name))
        DataAugmentation.saveImage(new_label, os.path.join(label_des_path, func_name + str(_i) + label_file_name))
 
 
opsList = {"randomRotation",  "randomColor", "randomGaussian"}
 
 
def threadOPS(img_path, new_img_path, label_path, new_label_path):
    """
    多线程处理事务
    :param src_path: 资源文件
    :param des_path: 目的地文件
    :return:
    """
    #img path 
    if os.path.isdir(img_path):
        img_names = os.listdir(img_path)
    else:
        img_names = [img_path]
 
    #label path 
    if os.path.isdir(label_path):
        label_names = os.listdir(label_path)
    else:
        label_names = [label_path]
 
    img_num = 0
    label_num = 0
 
    # count images
    for img_name in img_names:
        tmp_img_name = os.path.join(img_path, img_name)
        if os.path.isdir(tmp_img_name):
            print('the image directory contains a sub-folder')
            exit()
        else:
            img_num = img_num + 1
    # count labels
    for label_name in label_names:
        tmp_label_name = os.path.join(label_path, label_name)
        if os.path.isdir(tmp_label_name):
            print('the label directory contains a sub-folder')
            exit()
        else:
            label_num = label_num + 1

    if img_num != label_num:
        print('the number of images and labels is not equal')
        exit()
    else:
        num = img_num
 
 
    for i in range(num):
        img_name = img_names[i]
        print(img_name)
        label_name = label_names[i]
        print(label_name)

        tmp_img_name = os.path.join(img_path, img_name)
        tmp_label_name = os.path.join(label_path, label_name)

        # open the image/label pair and run every operation in its own thread
        image = DataAugmentation.openImage(tmp_img_name)
        label = DataAugmentation.openImage(tmp_label_name)

        threadImage = [0] * len(opsList)
        _index = 0
        for ops_name in opsList:
            threadImage[_index] = threading.Thread(target=imageOps,
                                                   args=(ops_name, image, label, new_img_path, new_label_path,
                                                         img_name, label_name))
            threadImage[_index].start()
            _index += 1
            time.sleep(0.2)
 
 
if __name__ == '__main__':
    # make sure the output directories exist before augmenting
    makeDir("/data1/qixinyuan/data/datasets/little/new_img")
    makeDir("/data1/qixinyuan/data/datasets/little/new_label")
    threadOPS("/data1/qixinyuan/data/datasets/little/img",
              "/data1/qixinyuan/data/datasets/little/new_img",
              "/data1/qixinyuan/data/datasets/little/label",
              "/data1/qixinyuan/data/datasets/little/new_label")

Another option is the Augmentor library.

Install Augmentor:

pip install Augmentor

Randomly rotate images:

import Augmentor
p = Augmentor.Pipeline("/path/to/images")
p.rotate(probability=1, max_left_rotation=5, max_right_rotation=5)  # probability: chance of applying this operation to each image
p.sample(500)  # generate 500 augmented images

Image and ground-truth data can be augmented identically:

p = Augmentor.Pipeline("/path/to/images")
# Point to a directory containing ground truth data.
# Images with the same file names will be added as ground truth data
# and augmented in parallel to the original data.
p.ground_truth("/path/to/ground_truth_images")
# Add operations to the pipeline as normal:
p.rotate(probability=1, max_left_rotation=5, max_right_rotation=5)
p.flip_left_right(probability=0.5)
p.zoom_random(probability=0.5, percentage_area=0.8)
p.flip_top_bottom(probability=0.5)
p.sample(50)

When rotating an image, blank padding is often produced around the borders.

To deal with this, Augmentor scales the image while rotating it, so no black fill appears around the edges.

Problem

When using the ground_truth() function, if the source directory contains more than one image, the augmented masks will no longer correspond to the augmented images. The workaround is to keep only one image per directory; if there are many pairs to augment, each pair has to be stored in its own folder.

Below is my code:

# -*- coding: utf-8 -*-
import Augmentor
import glob
import os
import random
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

train_path = 'train'
groud_truth_path = 'mask'
img_type = 'jpg'
train_tmp_path = 'tmp/train'
mask_tmp_path = 'tmp/mask'

def start(train_path, groud_truth_path):
    train_img = glob.glob(train_path+'/*.'+img_type)
    masks = glob.glob(groud_truth_path+'/*.'+img_type)

    if len(train_img) != len(masks):
        print("the number of training images does not match the number of masks")
        return 0
    for i in range(len(train_img)):
        # copy each training image into its own temporary folder
        train_img_tmp_path = train_tmp_path + '/' + str(i)
        if not os.path.lexists(train_img_tmp_path):
            os.makedirs(train_img_tmp_path)
        img = load_img(train_path+'/'+str(i)+'.'+img_type)
        x_t = img_to_array(img)
        img_tmp = array_to_img(x_t)
        img_tmp.save(train_img_tmp_path+'/'+str(i)+'.'+img_type)

        # copy the matching mask into its own temporary folder
        mask_img_tmp_path = mask_tmp_path + '/' + str(i)
        if not os.path.lexists(mask_img_tmp_path):
            os.makedirs(mask_img_tmp_path)
        mask = load_img(groud_truth_path+'/'+str(i)+'.'+img_type)
        x_l = img_to_array(mask)
        mask_tmp = array_to_img(x_l)
        mask_tmp.save(mask_img_tmp_path+'/'+str(i)+'.'+img_type)
        print("%s folder has been created!" % str(i))
    return i + 1


def doAugment(num):
    total = 0
    for i in range(num):
        p = Augmentor.Pipeline(train_tmp_path+'/'+str(i))
        p.ground_truth(mask_tmp_path+'/'+str(i))
        p.rotate(probability=0.5, max_left_rotation=5, max_right_rotation=5)  # rotation
        p.flip_left_right(probability=0.5)  # random horizontal flip
        p.zoom_random(probability=0.6, percentage_area=0.99)  # randomly zoom a region of the image up to full size
        p.flip_top_bottom(probability=0.6)  # random vertical flip
        p.random_distortion(probability=0.8, grid_width=10, grid_height=10, magnitude=20)  # elastic grid distortion
        count = random.randint(40, 60)
        print("\nNo.%s data is being augmented and %s data will be created" % (i, count))
        total = total + count
        p.sample(count)
        print("Done")
    print("%s pairs of data have been created in total" % total)


a = start(train_path, groud_truth_path)
doAugment(a)
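
Augmentor writes its results into an output subdirectory of each pipeline's source folder by default, so after doAugment finishes the samples are scattered across tmp/train/0/output, tmp/train/1/output, and so on. Below is a minimal sketch of how they could be gathered back into one list; collect_outputs is a made-up helper name, and the layout it assumes is Augmentor's default output folder:

def collect_outputs(num_folders):
    # Assumes Augmentor's default behaviour of saving results (the augmented
    # images together with their ground-truth counterparts) into an "output"
    # subfolder of the source directory passed to the Pipeline.
    augmented = []
    for i in range(num_folders):
        out_dir = os.path.join(train_tmp_path, str(i), 'output')
        augmented.extend(glob.glob(out_dir + '/*'))
    return augmented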

References:

https://blog.youkuaiyun.com/qq_20852429/article/details/79137777

https://blog.youkuaiyun.com/jiuliang1916/article/details/79498885

https://github.com/mdbloice/Augmentor

### Data augmentation code for semantic segmentation in PyTorch and TensorFlow

In semantic segmentation, data augmentation is an effective way to improve model performance. Below are example implementations based on PyTorch and on TensorFlow.

---

#### PyTorch-based semantic segmentation data augmentation

In PyTorch, `torchvision.transforms` can be used to transform images and their corresponding labels. To keep the image and its annotation consistent, a custom dataset class is usually needed. A concrete implementation:

```python
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader


class SegmentationDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # read the image and its mask
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path)

        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask


# data augmentation transforms
data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=(256, 256)),  # random crop and resize
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # color jitter
    transforms.ToTensor(),  # convert to tensor
])

# build the dataset and data loader
image_paths = [...]  # list of image file paths
mask_paths = [...]   # list of mask file paths

dataset = SegmentationDataset(image_paths=image_paths, mask_paths=mask_paths, transform=data_transforms)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# test the data loader
for images, masks in dataloader:
    print(f"Images shape: {images.shape}, Masks shape: {masks.shape}")
    break
```

The code above shows a simple semantic segmentation dataset class with common augmentations such as random cropping, horizontal flipping, and color jittering. Note, however, that calling the same `transforms.Compose` separately on the image and on the mask draws new random parameters for each call, so the geometric transforms are not guaranteed to match; see the notes below.

---

#### TensorFlow/Keras-based semantic segmentation data augmentation

In TensorFlow, data augmentation can be implemented with the layers provided in `tf.keras.layers.experimental.preprocessing`. A complete code example:

```python
import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import RandomFlip, RandomRotation, Resizing, Normalization
from tensorflow.keras.models import Model
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Input, Conv2DTranspose, Concatenate, Dropout


def create_segmentation_model(input_shape=(256, 256, 3)):
    inputs = Input(shape=input_shape)

    # data augmentation layers
    augmented_inputs = RandomFlip("horizontal")(inputs)
    augmented_inputs = RandomRotation(0.2)(augmented_inputs)

    # backbone network (EfficientNet as the encoder)
    backbone = EfficientNetB0(include_top=False, weights='imagenet', input_tensor=augmented_inputs)
    backbone.trainable = True

    skips = [backbone.get_layer(name).output for name in [
        "block2a_expand_activation",
        "block3a_expand_activation",
        "block4a_expand_activation",
        "block6a_expand_activation"
    ]]

    x = backbone.output

    # decoder
    upsample_layers = []
    for i, skip in enumerate(reversed(skips)):
        x = Conv2DTranspose(256 // (2 ** i), kernel_size=3, strides=2, padding="same", activation="relu")(x)
        x = Concatenate()([x, skip])
        x = Dropout(0.3)(x)
        upsample_layers.append(x)

    outputs = Conv2DTranspose(1, kernel_size=3, strides=2, padding="same", activation="sigmoid")(upsample_layers[-1])

    model = Model(inputs, outputs)
    return model


# build the model
segmentation_model = create_segmentation_model()

# compile the model
segmentation_model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# prepare the dataset
train_dataset = (
    tf.data.Dataset.list_files("./path/to/images/*.png")  # replace with the actual path
    .map(lambda x: load_image_and_mask(x))  # load_image_and_mask must be defined by the user
    .batch(4)
    .prefetch(tf.data.AUTOTUNE)
)

# train the model
segmentation_model.fit(train_dataset, epochs=10)
```

In this example the augmentation operations are embedded in the model architecture, so they are applied automatically on every forward pass.

---

#### Notes

1. **Consistency.** In semantic segmentation, the geometric transforms (flips, rotations, crops, and so on) applied to the original image and to its mask must be identical, so it is usually necessary to implement the augmentation logic yourself to guarantee this; a sketch follows below.
2. **Choosing the augmentation strength.** The amount of augmentation should be moderate: too much can shift the data distribution, while too little may fail to exploit the latent features in the data.
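
Following up on the consistency note above, here is a minimal sketch (not from the original post) of a paired transform that draws the random crop and flip parameters once and applies them to both the image and the mask through `torchvision.transforms.functional`. The class name `SegmentationTransform` and the parameter values are illustrative assumptions rather than an established API:

```python
import random

import torchvision.transforms.functional as TF
from torchvision import transforms


class SegmentationTransform:
    """Apply the same random geometric parameters to an (image, mask) pair."""

    def __init__(self, size=(256, 256)):
        # assumes the input images are at least `size` pixels in each dimension
        self.size = size

    def __call__(self, image, mask):
        # sample the crop box once and reuse it for both image and mask
        i, j, h, w = transforms.RandomCrop.get_params(image, output_size=self.size)
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        # one shared coin toss for the horizontal flip
        if random.random() < 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # photometric jitter touches the image only, never the mask
        image = transforms.ColorJitter(brightness=0.2, contrast=0.2)(image)

        # image becomes a float tensor, mask an integer class-id tensor
        image = TF.to_tensor(image)
        mask = TF.pil_to_tensor(mask).long()
        return image, mask
```

Such a paired transform can be plugged into the `SegmentationDataset` above by replacing the two separate `self.transform` calls with a single `self.transform(image, mask)` call.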