Introduction
- Original paper: Concurrent Spatial and Channel ‘Squeeze & Excitation’ in Fully Convolutional Networks
The SE blocks are added on top of a U-shaped network model, specifically after each convolution block, as shown in the figure.
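The attachment pattern is simple: every encoder stage ends with an scSE block before its feature map is stored for the skip connection. Below is a minimal sketch of one such stage, using the normal_conv2d, split_shuffle_module and SCAttention helpers defined later in this post (the filter count 32 is just an example; FF_layers is the list of skip connections collected by the encoder).

# One encoder stage of the U-shaped network (sketch; helpers are defined below).
x = normal_conv2d(x, kernel_size=3, filters=32, strides=2)                    # downsampling convolution block
x = split_shuffle_module(x, filters=32, unit_num=1, dilation_rate_value=[2])  # channel split/shuffle unit
x = SCAttention(x, channel=32)                                                # scSE re-weighting after the conv block
FF_layers.append(x)                                                           # keep the attended map as a skip connection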

Spatial Squeeze and Channel Excitation Block


from tensorflow import keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K
import tensorflow as tf
def normal_conv2d(x, kernel_size, filters, strides, activation='relu', Separable=False):
    """
    Convolution block: Conv -> BN -> Activation
    :param x: input tensor [b, h, w, channels]
    :param kernel_size: kernel size
    :param filters: number of filters
    :param strides: convolution stride
    :param Separable: whether to use a depthwise separable convolution
    :param activation: activation function
    :return: output tensor [b, h, w, filters]
    """
    if Separable:
        x = layers.SeparableConv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    else:
        x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation)(x)
    return x
def CAttention(x, channel):
    """Channel attention (cSE): spatial squeeze, channel excitation."""
    x_origin = x
    x = layers.GlobalAveragePooling2D()(x)      # squeeze spatially -> [b, channel]
    x = K.expand_dims(x, axis=1)
    x = K.expand_dims(x, axis=1)                # -> [b, 1, 1, channel]
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=channel//2, activation='relu')
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=channel, activation='sigmoid')
    x = layers.UpSampling2D(size=(x_origin.shape[1], x_origin.shape[2]), interpolation='nearest')(x)
    x = tf.multiply(x, x_origin)                # re-weight each channel of the input
    return x
Channel Squeeze and Spatial Excitation Block

def SAttention(x):
    """Spatial attention (sSE): channel squeeze, spatial excitation."""
    x_origin = x
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=1, activation='sigmoid')  # per-pixel weight map
    x = tf.multiply(x, x_origin)                # re-weight each spatial location of the input
    return x
Spatial and Channel Squeeze & Excitation Block (scSE)

def SCAttention(x, channel):
    """scSE block: element-wise sum of the cSE and sSE outputs."""
    x1 = CAttention(x, channel)
    x2 = SAttention(x)
    x = layers.Add()([x1, x2])
    return x
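All three blocks are shape-preserving: they only re-weight the input feature map. A quick sanity check (a sketch assuming TensorFlow 2.x and the imports above; the 64x64x32 input size is arbitrary):

inp = keras.Input(shape=(64, 64, 32))
print(CAttention(inp, 32).shape)   # (None, 64, 64, 32)
print(SAttention(inp).shape)       # (None, 64, 64, 32)
print(SCAttention(inp, 32).shape)  # (None, 64, 64, 32)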
Full model code example
"""
SCAttentionNet: with attention
Author: XG_hechao
Begin Date: 20201113
End Date:
"""
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K
from CoordConv2 import Coord_Conv2d
import tensorflow as tf
def Channel_Split(x):
    """
    Channel split
    :param x: input tensor [b, h, w, filters]
    :return: two tensors, [b, h, w, 0:filters/2] and [b, h, w, filters/2:]
    """
    channel_before = x.shape.as_list()[1:]            # [h, w, channels] of the input
    split_channel_num = channel_before[2] // 2        # half of the channel count
    channel_one = x[:, :, :, 0:split_channel_num]     # first half of the channels
    channel_two = x[:, :, :, split_channel_num:]      # second half of the channels
    return channel_one, channel_two
def Channel_Shuffle(x):
    """
    Channel shuffle
    :param x: input tensor [b, h, w, filters]
    :return: output tensor [b, h, w, filters] with the channel order shuffled
    """
    height, width, channels = x.shape.as_list()[1:]                 # spatial size and channel count
    channels_per_split = channels // 2                              # half of the channels
    x = K.reshape(x, [-1, height, width, 2, channels_per_split])    # group channels into [2, channels/2]
    x = K.permute_dimensions(x, (0, 1, 2, 4, 3))                    # transpose the two channel groups
    x = K.reshape(x, [-1, height, width, channels])                 # flatten back: the halves are now interleaved
    return x
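# Worked example (illustration only): with 4 channels [c0, c1, c2, c3],
# Channel_Split returns [c0, c1] and [c2, c3]; after the two branch outputs are
# concatenated back, Channel_Shuffle interleaves the halves to [c0, c2, c1, c3],
# so information mixes across the two groups in the next unit.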
def branch(x, filters, dilation_rate, unit_num, right=False):
    """
    Convolution branch: factorized (3,1)/(1,3) convolutions followed by dilated ones
    :param x: input tensor [b, h, w, channels]
    :param filters: number of filters
    :param dilation_rate: dilation rate of the dilated convolutions
    :param unit_num: unit index, used only for layer naming
    :param right: branch ordering; True uses the (3,1)->(1,3) 'right' ordering, False the (1,3)->(3,1) 'left' ordering
    :return: output tensor [b, h, w, filters]
    """
    if right:
        x = layers.Conv2D(filters=filters, kernel_size=(3, 1), padding='same', strides=1)(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(1, 3), padding='same', strides=1)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(3, 1), padding='same', strides=1, dilation_rate=dilation_rate,
                          name='dilation_conv2d_right_{0}_1_{1}'.format(unit_num, dilation_rate))(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(1, 3), padding='same', strides=1, dilation_rate=dilation_rate,
                          name='dilation_conv2d_right_{0}_2_{1}'.format(unit_num, dilation_rate*2))(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    else:
        x = layers.Conv2D(filters=filters, kernel_size=(1, 3), padding='same', strides=1)(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(3, 1), padding='same', strides=1)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(1, 3), padding='same', strides=1, dilation_rate=dilation_rate,
                          name='dilation_conv2d_left_{0}_1_{1}'.format(unit_num, dilation_rate))(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(3, 1), padding='same', strides=1, dilation_rate=dilation_rate,
                          name='dilation_conv2d_left_{0}_2_{1}'.format(unit_num, dilation_rate*2))(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x
def split_shuffle_module(x, filters, unit_num, dilation_rate_value):
    """
    Channel split + shuffle module (residual unit)
    :param x: input tensor [b, h, w, channels]
    :param filters: number of filters
    :param unit_num: number of repeated units
    :param dilation_rate_value: list of dilation rates, one per unit
    :return: output tensor [b, h, w, filters]
    """
    for i in range(unit_num):
        dilation_rate = dilation_rate_value[i]   # dilation rate for this unit
        add = x
        x_one, x_two = Channel_Split(x)          # split channels in half
        x_one = branch(x_one, filters=filters//2, dilation_rate=dilation_rate, unit_num=i, right=True)
        x_two = branch(x_two, filters=filters//2, dilation_rate=dilation_rate, unit_num=i, right=False)
        x = layers.Concatenate()([x_one, x_two]) # concatenate the two branches
        x = layers.Add()([add, x])               # residual connection
        x = layers.Activation('relu')(x)
        x = Channel_Shuffle(x)                   # shuffle channels across the two groups
    return x
def normal_conv2d(x, kernel_size, filters, strides, activation='relu', Separable=False, coord_conv2d=False):
    """
    Convolution block: Conv -> BN -> Activation
    :param x: input tensor [b, h, w, channels]
    :param kernel_size: kernel size
    :param filters: number of filters
    :param strides: convolution stride
    :param Separable: whether to use a depthwise separable convolution
    :param coord_conv2d: whether to insert a coordinate-convolution (CoordConv) step
    :param activation: activation function
    :return: output tensor [b, h, w, filters]
    """
    if Separable:
        x = layers.SeparableConv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    else:
        if coord_conv2d:
            x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
            x = Coord_Conv2d(x)   # append coordinate channels
            x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=1, padding='same')(x)
        else:
            x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation)(x)
    return x
def upsample(x, kernel_size, filters, strides, coord_conv2d=False):
    """
    Upsampling block: (optional CoordConv) -> transposed convolution -> BN -> ReLU
    :param x: input tensor [b, h, w, channels]
    :param kernel_size: kernel size
    :param filters: number of filters
    :param strides: stride of the transposed convolution
    :param coord_conv2d: whether to prepend a coordinate-convolution (CoordConv) step
    :return: output tensor [b, h*strides, w*strides, filters]
    """
    if coord_conv2d:
        x = Coord_Conv2d(x)
        x = layers.Conv2DTranspose(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    else:
        x = layers.Conv2DTranspose(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x
def CAttention(x, channel):
    x_origin = x
    x = layers.GlobalAveragePooling2D()(x)
    x = K.expand_dims(x, axis=1)
    x = K.expand_dims(x, axis=1)
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=channel//2)
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=channel, activation='sigmoid')
    x = layers.UpSampling2D(size=(x_origin.shape[1], x_origin.shape[2]), interpolation='nearest')(x)
    x = tf.multiply(x, x_origin)
    return x
def SAttention(x):
    x_origin = x
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=1, activation='sigmoid')
    x = tf.multiply(x, x_origin)
    return x
def SCAttention(x, channel):
    x1 = CAttention(x, channel)
    x2 = SAttention(x)
    x = layers.Add()([x1, x2])
    return x
def Encoder(x):
    """
    Encoder
    :param x: input tensor [b, h, w, channels]
    :return: bottleneck tensor and the list of skip-connection feature maps (FF_layers)
    """
    FF_layers = []
    x = normal_conv2d(x, 3, 32, 2, coord_conv2d=False)
    x = split_shuffle_module(x, 32, 1, [2])
    x = SCAttention(x, 32)
    FF_layers.append(x)
    x = normal_conv2d(x, 3, 64, 2, coord_conv2d=False)
    x = split_shuffle_module(x, 64, 1, [5])
    x = SCAttention(x, 64)
    FF_layers.append(x)
    x = normal_conv2d(x, 3, 128, 2, coord_conv2d=False)
    x = split_shuffle_module(x, 128, 1, [8])
    x = SCAttention(x, 128)
    FF_layers.append(x)
    x = normal_conv2d(x, 3, 256, 2, Separable=True)
    return x, FF_layers
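# For a 512x512x3 input, the encoder produces skip features of
# 256x256x32, 128x128x64 and 64x64x128 (stored in FF_layers) plus a 32x32x256 bottleneck.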
def Decoder(x, num_classes, FF_layers):
    """
    Decoder: upsample the bottleneck and fuse it with the encoder skip connections.
    :param x: bottleneck tensor
    :param num_classes: number of output classes
    :param FF_layers: encoder feature maps used as skip connections
    :return: output tensor at half the input resolution
    """
    x = layers.UpSampling2D(size=2, interpolation='bilinear')(x)
    x = layers.Concatenate()([x, FF_layers[2]])
    x = upsample(x, 3, num_classes, 2, coord_conv2d=False)
    x = layers.Concatenate()([x, FF_layers[1]])
    x = upsample(x, 3, num_classes, 2, coord_conv2d=False)
    x = layers.Concatenate()([x, FF_layers[0]])
    return x
def SCAttention_Net(input_size, num_classes):
    inputs = keras.Input(shape=input_size + (3,))
    x, FF_layers = Encoder(inputs)
    x = Decoder(x, num_classes, FF_layers)
    x = upsample(x=x, strides=2, kernel_size=3, filters=num_classes)
    outputs = layers.Conv2D(filters=num_classes, kernel_size=3, padding='same', activation='softmax')(x)
    models = keras.Model(inputs, outputs)
    return models
if __name__ == '__main__':
    model = SCAttention_Net((512, 512), 5)
    model.summary()
    # keras.utils.plot_model(model, dpi=96, to_file='./SCAttention_Net.png', show_shapes=True)
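To actually train the network it still needs to be compiled with a segmentation loss. A minimal sketch (the optimizer, loss, and data shapes here are assumptions for illustration, not part of the original code):

model = SCAttention_Net((512, 512), 5)
model.compile(optimizer=keras.optimizers.Adam(1e-3),
              loss='sparse_categorical_crossentropy',   # integer masks of shape [N, 512, 512]
              metrics=['accuracy'])
# model.fit(images, masks, batch_size=4, epochs=50)     # images: [N, 512, 512, 3] (hypothetical data)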
The CoordConv2 module imported above (Coord_Conv2d)
import tensorflow.keras.backend as K
def Coord_Conv2d(inputs, radius=False):
    input_shape = K.shape(inputs)
    input_shape = [input_shape[i] for i in range(4)]
    batch_shape, dim1, dim2, channels = input_shape
    xx_ones = K.ones(K.stack([batch_shape, dim2]), dtype='int32')   # ones tensor of shape [batch, dim2]
    xx_ones = K.expand_dims(xx_ones, axis=-1)                       # expand to [batch, dim2, 1], e.g. [4, 128, 1]
    xx_range = K.tile(K.expand_dims(K.arange(0, dim1), axis=0),
                      K.stack([batch_shape, 1]))                    # tile the row indices to [batch, dim1]
    xx_range = K.expand_dims(xx_range, axis=1)                      # expand to [batch, 1, dim1]
    xx_channels = K.batch_dot(xx_ones, xx_range, axes=[2, 1])       # outer product: [batch, dim2, dim1]
    xx_channels = K.expand_dims(xx_channels, axis=-1)               # [batch, dim2, dim1, 1]
    xx_channels = K.permute_dimensions(xx_channels, [0, 2, 1, 3])   # swap the spatial axes -> [batch, dim1, dim2, 1]
    yy_ones = K.ones(K.stack([batch_shape, dim1]), dtype='int32')
    yy_ones = K.expand_dims(yy_ones, axis=1)
    yy_range = K.tile(K.expand_dims(K.arange(0, dim2), axis=0),
                      K.stack([batch_shape, 1]))
    yy_range = K.expand_dims(yy_range, axis=-1)
    yy_channels = K.batch_dot(yy_range, yy_ones, axes=[2, 1])
    yy_channels = K.expand_dims(yy_channels, axis=-1)
    yy_channels = K.permute_dimensions(yy_channels, [0, 2, 1, 3])
    xx_channels = K.cast(xx_channels, K.floatx())                   # cast int -> float
    xx_channels = xx_channels / K.cast(dim1 - 1, K.floatx())        # normalize to [0, 1]
    xx_channels = (xx_channels * 2) - 1.                            # rescale to [-1, 1]
    yy_channels = K.cast(yy_channels, K.floatx())
    yy_channels = yy_channels / K.cast(dim2 - 1, K.floatx())
    yy_channels = (yy_channels * 2) - 1.
    outputs = K.concatenate([inputs, xx_channels, yy_channels], axis=-1)
    if radius:
        radius_layer = K.sqrt(K.square(xx_channels - 0.5) + K.square(yy_channels - 0.5))
        outputs = K.concatenate([outputs, radius_layer], axis=-1)
    return outputs
if __name__ == '__main__':
    x = K.ones([4, 32, 32, 3])
    x = Coord_Conv2d(x)
    print(x.shape)
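For the 3-channel input in the example above, the printed shape is (4, 32, 32, 5): the two normalized coordinate channels are appended to the input. With radius=True a third distance channel is added as well, giving (4, 32, 32, 6).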





