1. Dataset Acquisition
import os
from IPython.display import Image,display
from tensorflow.keras.preprocessing.image import load_img
from PIL import ImageOps
1.1 Setting Up Paths and Parameters
# Image directory
input_dir = 'segdata/images/'
# Image paths
input_img_paths = sorted([os.path.join(input_dir, fname)
                          for fname in os.listdir(input_dir) if fname.endswith('.jpg')])
# Annotation directory
target_dir = 'segdata/annotations/trimaps/'
# Target (mask) paths
target_img_paths = sorted([os.path.join(target_dir, fname) for fname in os.listdir(
    target_dir) if fname.endswith('.png') and not fname.startswith('.')])
img_size = (160, 160)
batch_size = 32
num_classes = 4
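As a quick sanity check (a minimal sketch using only the lists defined above), we can confirm that the two sorted lists have the same length and pair each image with its mask:

# Sanity check: same count, and the sorted lists pair image to mask
print(len(input_img_paths), len(target_img_paths))
for in_path, tgt_path in zip(input_img_paths[:3], target_img_paths[:3]):
    print(in_path, '|', tgt_path)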
1.2 Displaying an Image
# Display one input image
display(Image(filename=input_img_paths[10]))
The annotation masks contain only three values. To display them we use PIL.ImageOps.autocontrast, which computes the histogram of the input image and remaps it so that the darkest pixel becomes black (0), the brightest becomes white (255), and all other values are spread across the intermediate gray levels. Here foreground, background, and the uncertain region are labeled 1, 2, and 3 respectively, so the foreground (the smallest value) is shown as black and the uncertain region (the largest value) as white.
# Display the annotation mask
img = ImageOps.autocontrast(load_img(target_img_paths[10]))
display(img)
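To verify that the masks really contain only the three label values claimed above, we can inspect one directly (a quick check using standard NumPy/PIL conversions):

import numpy as np
# Load one trimap as grayscale and list its distinct pixel values
mask = np.array(load_img(target_img_paths[10], color_mode='grayscale'))
print(np.unique(mask))  # expected: [1 2 3]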
1.3 Dataset Generator
from tensorflow import keras
import numpy as np
from tensorflow.keras.preprocessing.image import load_img
class OxfordPets(keras.utils.Sequence):
    # Initialization
    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        # Batch size
        self.batch_size = batch_size
        # Image size
        self.img_size = img_size
        # Paths to the input images
        self.input_img_paths = input_img_paths
        # Paths to the target masks
        self.target_img_paths = target_img_paths

    # Number of batches per epoch
    def __len__(self):
        return len(self.target_img_paths) // self.batch_size

    # Fetch one batch of data
    def __getitem__(self, idx):
        # Index of the first sample in this batch
        i = idx * self.batch_size
        # Paths for this batch
        batch_input_img_paths = self.input_img_paths[i:i + self.batch_size]
        batch_target_img_paths = self.target_img_paths[i:i + self.batch_size]
        # Build the features
        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            img = load_img(path, target_size=self.img_size)
            x[j] = img
        # Build the targets
        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype='uint8')
        for j, path in enumerate(batch_target_img_paths):
            img = load_img(path, target_size=self.img_size, color_mode='grayscale')
            y[j] = np.expand_dims(img, 2)
        return x, y
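To confirm the generator works as intended (a minimal sketch using the paths and parameters defined above), we can instantiate it and inspect one batch:

gen = OxfordPets(batch_size, img_size, input_img_paths, target_img_paths)
x, y = gen[0]
print(x.shape, y.shape)  # (32, 160, 160, 3) (32, 160, 160, 1)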
2. Model Construction
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D, Cropping2D, Concatenate
from tensorflow.keras.layers import Lambda, Activation, BatchNormalization, Dropout
from tensorflow.keras.models import Model
2.1 Encoder
def downsampling_block(input_tensor, filters):
    # input_tensor: input feature map; filters: number of channels
    # Convolution
    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(input_tensor)
    # BN
    x = BatchNormalization()(x)
    # Activation
    x = Activation('relu')(x)
    # Convolution
    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
    # BN
    x = BatchNormalization()(x)
    # Activation
    x = Activation('relu')(x)
    # Return the pooled output and the pre-pooling feature map
    return MaxPooling2D(pool_size=(2, 2))(x), x
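Note that the block returns two tensors: the pooled output that feeds the next encoder stage, and the pre-pooling feature map that is later reused as the skip connection. A minimal shape check (a sketch assuming a dummy 160x160 input):

inp = Input(shape=(160, 160, 3))
pooled, skip = downsampling_block(inp, 64)
print(pooled.shape, skip.shape)  # (None, 80, 80, 64) (None, 160, 160, 64)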
2.2 Decoder
def upsampling_block(input_tensor, skip_tensor, filters):
    # input_tensor: input feature map; skip_tensor: feature map from the encoder; filters: number of channels
    # Transposed convolution
    x = Conv2DTranspose(filters, kernel_size=(2, 2), strides=(2, 2), padding='same')(input_tensor)
    # Spatial sizes
    _, x_height, x_width, _ = x.shape
    _, s_height, s_width, _ = skip_tensor.shape
    # Size difference
    h_crop = s_height - x_height
    w_crop = s_width - x_width
    # Crop only if the sizes differ
    if h_crop == 0 and w_crop == 0:
        y = skip_tensor
    else:
        # Amount to crop on each side
        cropping = ((h_crop // 2, h_crop - h_crop // 2), (w_crop // 2, w_crop - w_crop // 2))
        y = Cropping2D(cropping=cropping)(skip_tensor)
    # Feature fusion
    x = Concatenate()([x, y])
    # Convolution
    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
    # BN
    x = BatchNormalization()(x)
    # Activation
    x = Activation('relu')(x)
    # Convolution
    x = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
    # BN
    x = BatchNormalization()(x)
    # Activation
    x = Activation('relu')(x)
    return x
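The cropping branch only comes into play when the encoder produced odd spatial sizes, so the upsampled tensor ends up smaller than its skip connection. A worked example of the arithmetic above (pure Python, with hypothetical sizes): if the skip tensor is 41 pixels tall and the upsampled tensor 40, one row is cropped from the bottom of the skip tensor:

s_height, x_height = 41, 40       # hypothetical odd/even pair
h_crop = s_height - x_height      # 1
top, bottom = h_crop // 2, h_crop - h_crop // 2
print((top, bottom))              # (0, 1): crop nothing from the top, one row from the bottom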
2.3 The U-Net Network
Combining the encoder and decoder gives the U-Net network. Here the depth of the network is set by the depth argument, and the number of convolution filters in the first encoder block is set by the features argument. The following function assembles the encoder and decoder:
def unet(imagesize, classes, features=64, depth=3):
    # Define the input
    inputs = keras.Input(shape=imagesize + (3,))
    x = inputs
    # Encoder
    skips = []
    for i in range(depth):
        x, x0 = downsampling_block(x, features)
        skips.append(x0)
        features *= 2
    # Bottleneck: convolution
    x = Conv2D(filters=features, kernel_size=(3, 3), padding='same')(x)
    # BN
    x = BatchNormalization()(x)
    # Activation
    x = Activation('relu')(x)
    # Convolution
    x = Conv2D(filters=features, kernel_size=(3, 3), padding='same')(x)
    # BN
    x = BatchNormalization()(x)
    # Activation
    x = Activation('relu')(x)
    # Decoder
    for i in reversed(range(depth)):
        features //= 2
        x = upsampling_block(x, skips[i], features)
    # 1x1 convolution
    x = Conv2D(filters=classes, kernel_size=(1, 1), padding='same')(x)
    # Activation
    outputs = Activation('softmax')(x)
    return keras.Model(inputs=inputs, outputs=outputs)
model = unet(img_size, num_classes)
model.summary()
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 160, 160, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 160, 160, 64) 1792 input_1[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 160, 160, 64) 256 conv2d[0][0]
__________________________________________________________________________________________________
activation (Activation) (None, 160, 160, 64) 0 batch_normalization[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 160, 160, 64) 36928 activation[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 160, 160, 64) 256 conv2d_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 160, 160, 64) 0 batch_normalization_1[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 80, 80, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 80, 80, 128) 73856 max_pooling2d[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 80, 80, 128) 512 conv2d_2[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 80, 80, 128) 0 batch_normalization_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 80, 80, 128) 147584 activation_2[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 80, 80, 128) 512 conv2d_3[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 80, 80, 128) 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 40, 40, 128) 0 activation_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 40, 40, 256) 295168 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 40, 40, 256) 1024 conv2d_4[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 40, 40, 256) 0 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 40, 40, 256) 590080 activation_4[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 40, 40, 256) 1024 conv2d_5[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 40, 40, 256) 0 batch_normalization_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 20, 20, 256) 0 activation_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 20, 20, 512) 1180160 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 20, 20, 512) 2048 conv2d_6[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 20, 20, 512) 0 batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 20, 20, 512) 2359808 activation_6[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 20, 20, 512) 2048 conv2d_7[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 20, 20, 512) 0 batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 40, 40, 256) 524544 activation_7[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 40, 40, 512) 0 conv2d_transpose[0][0]
activation_5[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 40, 40, 256) 1179904 concatenate[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 40, 40, 256) 1024 conv2d_8[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 40, 40, 256) 0 batch_normalization_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 40, 40, 256) 590080 activation_8[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 40, 40, 256) 1024 conv2d_9[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 40, 40, 256) 0 batch_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 80, 80, 128) 131200 activation_9[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 80, 80, 256) 0 conv2d_transpose_1[0][0]
activation_3[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 80, 80, 128) 295040 concatenate_1[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 80, 80, 128) 512 conv2d_10[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 80, 80, 128) 0 batch_normalization_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 80, 80, 128) 147584 activation_10[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 80, 80, 128) 512 conv2d_11[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 80, 80, 128) 0 batch_normalization_11[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 160, 160, 64) 32832 activation_11[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 160, 160, 128 0 conv2d_transpose_2[0][0]
activation_1[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 160, 160, 64) 73792 concatenate_2[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 160, 160, 64) 256 conv2d_12[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 160, 160, 64) 0 batch_normalization_12[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 160, 160, 64) 36928 activation_12[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 160, 160, 64) 256 conv2d_13[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 160, 160, 64) 0 batch_normalization_13[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 160, 160, 4) 260 activation_13[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 160, 160, 4) 0 conv2d_14[0][0]
==================================================================================================
Total params: 7,708,804
Trainable params: 7,703,172
Non-trainable params: 5,632
keras.utils.plot_model(model)
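keras.utils.plot_model renders the architecture to an image and requires pydot and graphviz to be installed. Passing show_shapes=True (an optional argument of plot_model) additionally annotates each layer with its output shape; the output filename here is an arbitrary choice:

keras.utils.plot_model(model, 'unet.png', show_shapes=True)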
3. Model Training
3.1 Dataset Split
import random
# Number of validation samples
val_samples = 1000
# Shuffle (same seed for both lists so images and masks stay paired)
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(target_img_paths)
# Split the dataset
# Training set
train_input_img_paths = input_img_paths[:-val_samples]
train_target_img_paths = target_img_paths[:-val_samples]
# Validation set
val_input_img_paths = input_img_paths[-val_samples:]
val_target_img_paths = target_img_paths[-val_samples:]
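A quick check of the split sizes (simple sketch):

print(len(train_input_img_paths), len(val_input_img_paths))  # training and validation counts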
3.2 Building the Generators
train_gen = OxfordPets(batch_size,img_size,train_input_img_paths,train_target_img_paths)
val_gen = OxfordPets(batch_size,img_size,val_input_img_paths,val_target_img_paths)
3.3 Model Compilation
model.compile(optimizer='rmsprop',loss="sparse_categorical_crossentropy")
3.4 Training
model.fit(train_gen,epochs=2,validation_data=val_gen,steps_per_epoch=1,validation_steps=1)
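The call above runs only one step per epoch, which is enough to verify the pipeline. For a real training run one would typically iterate over the full generators and keep the best weights; a minimal sketch using keras.callbacks.ModelCheckpoint (the filename and epoch count here are arbitrary choices):

callbacks = [
    keras.callbacks.ModelCheckpoint('oxford_segmentation.h5',  # arbitrary filename
                                    save_best_only=True)
]
model.fit(train_gen, epochs=15, validation_data=val_gen, callbacks=callbacks)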
4. Model Prediction
# Build the validation generator and run prediction
val_gen = OxfordPets(batch_size, img_size, val_input_img_paths, val_target_img_paths)
val_preds = model.predict(val_gen)
# Display a predicted mask
def display_mask(i):
    # Prediction for the i-th sample
    mask = np.argmax(val_preds[i], axis=-1)
    # Adjust dimensions
    mask = np.expand_dims(mask, axis=-1)
    # Convert to an image and display it
    img = ImageOps.autocontrast(keras.preprocessing.image.array_to_img(mask))
    display(img)
# Pick the 10th image of the validation set
i = 10
# Display the input image
display(Image(filename=val_input_img_paths[i]))
# Display the ground truth
img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
display(img)
# Display the prediction
display_mask(i)