import cv2
import random
import os
import numpy as np
from tqdm import tqdm
import glob
#################数据增强代码#############
img_w = 256
img_h = 256
#image_sets = ['1.png', '2.png', '3.png', '4.png', '5.png']
def load_data(image_sets):
for i in range(160):
a = str(i)
image_sets.append(a+'.bmp')
return image_sets
def gamma_transform(img, gamma):
gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
return cv2.LUT(img, gamma_table)
def random_gamma_transform(img, gamma_vari):
log_gamma_vari = np.log(gamma_vari)
alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
gamma = np.exp(alpha)
return gamma_transform(img, gamma)
def rotate(xb, yb, angle):
M_rotate = cv2.getRotationMatrix2D((img_w / 2, img_h / 2), angle, 1)
xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))
return xb, yb
def blur(img):
img = cv2.blur(img, (3, 3));
return img
def add_noise(img):
for i in range(200): # 添加点噪声
temp_x = np.random.randint(0, img.shape[0])
temp_y = np.random.randint(0, img.shape[1])
img[temp_x][temp_y] = 255
return img
def data_augment(xb, yb):
if np.random.random() < 0.25:
xb, yb = rotate(xb, yb, 90)
if np.random.random() < 0.25:
xb, yb = rotate(xb, yb, 180)
if np.random.random() < 0.25:
xb, yb = rotate(xb, yb, 270)
if np.random.random() < 0.25:
xb = cv2.flip(xb, 1) # flipcode > 0:沿y轴翻转
yb = cv2.flip(yb, 1)
if np.random.random() < 0.25:
xb = random_gamma_transform(xb, 1.0)
if np.random.random() < 0.25:
xb = blur(xb)
if np.random.random() < 0.2:
xb = add_noise(xb)
return xb, yb
def creat_dataset(image_num=1600, mode='original'):
print('creating dataset...')
image_sets = []
image_sets = load_data(image_sets)
image_each = image_num / len(image_sets)
g_count = 0
for i in tqdm(range(len(image_sets))):
count = 0
src_img = cv2.imread('D:/image/medical/train/Image/' + image_sets[i]) # 3 channels
label_img = cv2.imread('D:/image/medical/train/Mask/' + image_sets[i],
cv2.IMREAD_GRAYSCALE) # single channel
X_height, X_width, _ = src_img.shape
while count < image_each:
random_width = random.randint(0, X_width - img_w - 1)
random_height = random.randint(0, X_height - img_h - 1)
src_roi = src_img[random_height: random_height + img_h,
random_width: random_width + img_w, :]
label_roi = label_img[random_height: random_height + img_h,
random_width: random_width + img_w]
if mode == 'augment':
src_roi, label_roi = data_augment(src_roi, label_roi)
##可视
visualize = np.zeros((256, 256)).astype(np.uint8)
visualize = label_roi * 50
cv2.imwrite(('D:/image/train/visual/%d.png' % g_count), visualize)
cv2.imwrite(('D:/image/train/image/%d.png' % g_count), src_roi)
cv2.imwrite(('D:/image/train/mask/%d.png' % g_count), label_roi)
count += 1
g_count += 1
if __name__ == '__main__':
creat_dataset(mode='augment')
图像增强
最新推荐文章于 2020-11-01 21:22:36 发布
该博客提供了一段Python代码用于图像数据增强。代码中定义了多种数据增强方法,如伽马变换、旋转、模糊、添加噪声等。通过`creat_dataset`函数可生成指定数量的增强后图像数据,并将其保存到指定目录,适用于医学图像等数据的扩充。
1209





