目标 分割:UNet案例

1. 数据集获取

import os
from IPython.display import Image,display
from tensorflow.keras.preprocessing.image import load_img
from PIL import ImageOps

1.1 设置相关信息

# Directory holding the input JPEG images
input_dir = 'segdata/images/'
# Sorted list of input image file paths (JPEG only)
input_img_path = sorted([os.path.join(input_dir, fname)
                         for fname in os.listdir(input_dir) if fname.endswith('.jpg')])
# Directory holding the segmentation annotations (trimap PNGs)
target_dir = 'segdata/annotations/trimaps/'

# Sorted list of target mask paths; skip hidden files such as '._foo.png'
target_img_path = sorted([os.path.join(target_dir, fname) for fname in os.listdir(
    target_dir) if fname.endswith('.png') and not fname.startswith('.')])

# (height, width) every image/mask is resized to
img_size = (160,160)
# samples per training batch
batch_size = 32
# number of segmentation classes (trimap labels plus background slot)
num_classes = 4

1.2 图像展示

# Show one raw input image.
# BUGFIX: the list built above is named `input_img_path`; the original code
# referenced the undefined name `input_img_paths` and raised a NameError.
display(Image(filename=input_img_path[10]))

标注信息中只有3个值,我们使用PIL.ImageOps.autocontrast进行展示,该方法计算输入图像的直方图,然后重新映射图像,最暗像素变为黑色,即0,最亮的变为白色,即255,其他的值以其他的灰度值进行显示。在这里前景、背景和不确定区域分别标注为:1,2,3,所以前景(最小值)显示为黑色,不确定的区域(最大值)显示为白色。

# Show the annotation mask; autocontrast stretches the label values {1,2,3}
# across the full gray range so the trimap is visible (1 -> black, 3 -> white).
img = ImageOps.autocontrast(load_img(target_img_path[10]))
display(img) 

1.3 数据集生成器

from tensorflow import keras
import numpy as np
from tensorflow.keras.preprocessing.image import load_img

class OxfordPets(keras.utils.Sequence):
    """Keras Sequence that yields (image, mask) batches for segmentation.

    Each batch is a pair ``(x, y)`` where ``x`` is float32 RGB data of shape
    ``(batch_size,) + img_size + (3,)`` and ``y`` is uint8 class-index masks
    of shape ``(batch_size,) + img_size + (1,)``.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        # Number of samples per batch
        self.batch_size = batch_size
        # Target (height, width) every image/mask is resized to
        self.img_size = img_size
        # Paths of the input images
        self.input_img_paths = input_img_paths
        # Paths of the target masks (parallel to input_img_paths)
        self.target_img_paths = target_img_paths

    def __len__(self):
        # Number of complete batches per epoch (any remainder is dropped)
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return the (x, y) batch at position ``idx``."""
        # First sample index of this batch
        i = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[i:i + self.batch_size]
        batch_target_img_paths = self.target_img_paths[i:i + self.batch_size]
        # BUGFIX: allocate with self.batch_size, not the module-level
        # batch_size global, so the generator honors the configured size.
        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            x[j] = load_img(path, target_size=self.img_size)
        # Targets: single-channel masks holding integer class labels
        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype='uint8')
        for j, path in enumerate(batch_target_img_paths):
            img = load_img(path, target_size=self.img_size, color_mode='grayscale')
            y[j] = np.expand_dims(img, 2)
        return x, y

2.模型构建 

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D, Cropping2D, Concatenate
from tensorflow.keras.layers import Lambda, Activation, BatchNormalization, Dropout
from tensorflow.keras.models import Model

2.1 编码部分 

def downsampling_block(input_tensor, filters):
    """One UNet encoder stage: two Conv-BN-ReLU layers, then 2x2 max pooling.

    Returns a pair ``(pooled, pre_pool)`` where ``pre_pool`` is kept as the
    skip connection for the matching decoder stage.
    """
    x = input_tensor
    # Two identical Conv -> BatchNorm -> ReLU sub-blocks
    for _ in range(2):
        x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    # Halve the spatial resolution; also expose the un-pooled features
    return MaxPooling2D(pool_size=(2, 2))(x), x

2.2 解码部分 

def upsampling_block(input_tensor, skip_tensor, filters):
    """One UNet decoder stage.

    Upsamples ``input_tensor`` with a strided transposed convolution, crops
    the encoder feature map ``skip_tensor`` to match if needed, concatenates
    the two, then applies two Conv-BN-ReLU layers with ``filters`` channels.
    """
    # Double the spatial resolution
    x = Conv2DTranspose(filters, kernel_size=(2, 2), strides=(2, 2),
                        padding='same')(input_tensor)
    # Spatial size difference between the skip tensor and the upsampled tensor
    _, up_h, up_w, _ = x.shape
    _, skip_h, skip_w, _ = skip_tensor.shape
    h_diff = skip_h - up_h
    w_diff = skip_w - up_w
    # Center-crop the skip tensor only when the sizes disagree
    if h_diff != 0 or w_diff != 0:
        cropping = ((h_diff // 2, h_diff - h_diff // 2),
                    (w_diff // 2, w_diff - w_diff // 2))
        skip_tensor = Cropping2D(cropping=cropping)(skip_tensor)
    # Fuse decoder and encoder features along the channel axis
    x = Concatenate()([x, skip_tensor])
    # Two identical Conv -> BatchNorm -> ReLU sub-blocks
    for _ in range(2):
        x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x

2.3 unet网络 

将编码部分和解码部分组合在一起,就可构建unet网络。在这里unet网络的深度通过depth进行设置,第一个编码模块的卷积核个数通过fetures参数进行设置,通过以下模块将编码和解码部分进行组合:

def unet(imagesize, classes, fetures=64, depth=3):
    """Build a UNet segmentation model.

    imagesize: (height, width) of the RGB input.
    classes:   number of output classes (softmax over the channel axis).
    fetures:   channel count of the first encoder stage; doubled per stage.
    depth:     number of encoder/decoder stages.
    """
    inputs = keras.Input(shape=imagesize + (3,))
    x = inputs

    # --- Encoder: collect one skip tensor per stage ---
    skips = []
    channels = fetures
    for _ in range(depth):
        x, skip = downsampling_block(x, channels)
        skips.append(skip)
        channels *= 2

    # --- Bottleneck: two Conv-BN-ReLU layers at the deepest resolution ---
    for _ in range(2):
        x = Conv2D(filters=channels, kernel_size=(3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # --- Decoder: walk the skip connections in reverse ---
    for level in reversed(range(depth)):
        channels //= 2
        x = upsampling_block(x, skips[level], channels)

    # 1x1 convolution to per-pixel class scores, then softmax
    x = Conv2D(filters=classes, kernel_size=(1, 1), padding='same')(x)
    outputs = Activation('softmax')(x)
    return keras.Model(inputs=inputs, outputs=outputs)
# Build the model once.
# BUGFIX: the original text constructed the identical model twice in a row;
# the second call only wasted time and memory.
model = unet(img_size, num_classes)

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 160, 160, 64) 1792        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 160, 160, 64) 256         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 160, 160, 64) 0           batch_normalization[0][0]        
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 160, 160, 64) 36928       activation[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 160, 160, 64) 256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 160, 160, 64) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 80, 80, 64)   0           activation_1[0][0]               
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 80, 80, 128)  73856       max_pooling2d[0][0]              
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 80, 80, 128)  512         conv2d_2[0][0]                   
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 80, 80, 128)  0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 80, 80, 128)  147584      activation_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 80, 80, 128)  512         conv2d_3[0][0]                   
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 80, 80, 128)  0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 40, 40, 128)  0           activation_3[0][0]               
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 40, 40, 256)  295168      max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 40, 40, 256)  1024        conv2d_4[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 40, 40, 256)  0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 40, 40, 256)  590080      activation_4[0][0]               
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 40, 40, 256)  1024        conv2d_5[0][0]                   
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 40, 40, 256)  0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 20, 20, 256)  0           activation_5[0][0]               
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 20, 20, 512)  1180160     max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 20, 20, 512)  2048        conv2d_6[0][0]                   
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 20, 20, 512)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 20, 20, 512)  2359808     activation_6[0][0]               
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 20, 20, 512)  2048        conv2d_7[0][0]                   
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 20, 20, 512)  0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 40, 40, 256)  524544      activation_7[0][0]               
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 40, 40, 512)  0           conv2d_transpose[0][0]           
                                                                 activation_5[0][0]               
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 40, 40, 256)  1179904     concatenate[0][0]                
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 40, 40, 256)  1024        conv2d_8[0][0]                   
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 40, 40, 256)  0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 40, 40, 256)  590080      activation_8[0][0]               
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 40, 40, 256)  1024        conv2d_9[0][0]                   
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 40, 40, 256)  0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 80, 80, 128)  131200      activation_9[0][0]               
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 80, 80, 256)  0           conv2d_transpose_1[0][0]         
                                                                 activation_3[0][0]               
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 80, 80, 128)  295040      concatenate_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 80, 80, 128)  512         conv2d_10[0][0]                  
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 80, 80, 128)  0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 80, 80, 128)  147584      activation_10[0][0]              
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 80, 80, 128)  512         conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 80, 80, 128)  0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 160, 160, 64) 32832       activation_11[0][0]              
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 160, 160, 128 0           conv2d_transpose_2[0][0]         
                                                                 activation_1[0][0]               
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 160, 160, 64) 73792       concatenate_2[0][0]              
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 160, 160, 64) 256         conv2d_12[0][0]                  
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 160, 160, 64) 0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 160, 160, 64) 36928       activation_12[0][0]              
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 160, 160, 64) 256         conv2d_13[0][0]                  
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 160, 160, 64) 0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 160, 160, 4)  260         activation_13[0][0]              
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 160, 160, 4)  0           conv2d_14[0][0]                  
==================================================================================================
Total params: 7,708,804
Trainable params: 7,703,172
Non-trainable params: 5,632 

# Render the network topology as an image (requires pydot + graphviz)
keras.utils.plot_model(model)

3. 模型训练

3.1 数据集划分

import random
# Number of samples reserved for validation
val_samples = 1000
# Shuffle images and masks with the same fixed seed so that two independent
# Random(1337) instances produce the same permutation and pairs stay aligned
random.Random(1337).shuffle(input_img_path)
random.Random(1337).shuffle(target_img_path)
# Split into train / validation
# Training set: everything except the last val_samples entries
train_input_img_paths = input_img_path[:-val_samples]
train_target_img_paths = target_img_path[:-val_samples]
# Validation set: the last val_samples entries
val_input_img_paths = input_img_path[-val_samples:]
val_target_img_paths = target_img_path[-val_samples:]

3.2 数据集获取

# Batch generators for the training and validation splits
train_gen = OxfordPets(batch_size,img_size,train_input_img_paths,train_target_img_paths)
val_gen = OxfordPets(batch_size,img_size,val_input_img_paths,val_target_img_paths)

3.3 模型编译

# Sparse categorical cross-entropy: targets are integer class-index masks,
# predictions are per-pixel softmax distributions
model.compile(optimizer='rmsprop',loss="sparse_categorical_crossentropy")

3.4 模型训练

# Smoke-test training run: only 1 step per epoch for 2 epochs
model.fit(train_gen,epochs=2,validation_data=val_gen,steps_per_epoch=1,validation_steps=1)

4.模型预测 

# Build a validation generator and run prediction on it.
# BUGFIX: the original call omitted the input-image paths, so OxfordPets
# was invoked with only 3 of its 4 required positional arguments
# (a TypeError), and the masks would have been treated as inputs.
val_gen = OxfordPets(batch_size, img_size, val_input_img_paths, val_target_img_paths)
val_preds = model.predict(val_gen)

# 图像显示
def display_mask(i): 
     # 获取到第i个样本的预测结果 
     mask = np.argmax(val_preds[i], axis=-1) 
     # 维度调整 
     mask = np.expand_dims(mask, axis=-1) 
     # 转换为图像,并进⾏显示 
     img = PIL.ImageOps.autocontrast(keras.preprocessing.image.array_to_img(mask))
     display(img) 

# Inspect validation sample 10: input image, ground-truth mask, prediction
i = 10

# Show the input image
display(Image(filename=val_input_img_paths[i]))

# Show the ground-truth mask.
# BUGFIX: the original line was truncated mid-expression
# ("load_img(val_target_img_pat") and referenced the unimported "PIL"
# module name; use the imported ImageOps and the full path list instead.
img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
display(img)

# Show the predicted mask
display_mask(i)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值