self.bn1 = layers.BatchNormalization()
self.relu = layers.Activation(‘relu’)
self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding=‘same’)
self.bn2 = layers.BatchNormalization()
# SE (squeeze-and-excitation) block
self.se_globalpool = keras.layers.GlobalAveragePooling2D()
self.se_resize = keras.layers.Reshape((1, 1, filter_num))
self.se_fc1 = keras.layers.Dense(units=filter_num // 16, activation=‘relu’,
use_bias=False)
self.se_fc2 = keras.layers.Dense(units=filter_num, activation=‘sigmoid’,
use_bias=False)
if stride != 1:
self.downsample = Sequential()
self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))
else:
self.downsample = lambda x: x
def call(self, input, training=None):
    """Apply the SE residual block to ``input``.

    The main path is conv1 -> bn1 -> relu -> conv2 -> bn2. Its output is
    rescaled by a squeeze-and-excitation gate (global average pooling,
    reshape, then the two dense layers ``se_fc1``/``se_fc2``), added to
    the shortcut branch, and passed through a final ReLU.

    Args:
        input: Tensor fed to both the main path and the shortcut. The
            name shadows the builtin but is kept because it is part of
            the method's interface.
            NOTE(review): presumably a 4-D channels-last feature map,
            given the Conv2D / GlobalAveragePooling2D layers -- confirm.
        training: Forwarded to both batch-normalization layers. ``None``
            leaves the decision to Keras.

    Returns:
        The tensor produced by ``tf.nn.relu`` on the sum of the gated
        main-path features and the shortcut.
    """
    # Main path. `training` is passed explicitly so the batch-norm mode
    # does not depend on implicit call-context propagation.
    features = self.conv1(input)
    features = self.bn1(features, training=training)
    features = self.relu(features)
    features = self.conv2(features)
    features = self.bn2(features, training=training)

    # SE block: squeeze the feature map to one value per channel, pass it
    # through the two dense layers, and use the result as a gate.
    gate = self.se_globalpool(features)
    gate = self.se_resize(gate)
    gate = self.se_fc1(gate)
    gate = self.se_fc2(gate)
    gated = keras.layers.Multiply()([features, gate])

    # Shortcut branch. `self.downsample` may be a plain callable rather
    # than a Keras layer, so it is called with the tensor only.
    identity = self.downsample(input)
    output = layers.add([gated, identity])
    output = tf.nn.relu(output)
    return output
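# Second residual block
# The second residual block is used to build the ResNet50, ResNet101 and ResNet152 models; the SE (SENet) module is inserted after the third convolution.
# Second residual block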
class Block(layers.Layer):
def init(self, filters, downsample=False, stride=1):
super(Block, self).init()
self.downsample = downsample
self.conv1 = layers.Conv2D(filters, (1, 1), strides=stride, padding=‘same’)
self.bn1 = layers.BatchNormalization()
self.relu = layers.Activation(‘relu’)
self.conv2 = layers.Conv2D(filters, (3, 3), strides=1, padding=‘same’)
self.bn2 = layers.BatchNormalization()
self.conv3 = layers.Conv2D(4 * filters, (1, 1), strides=1, padding=‘same’)
self.bn3 = layers.BatchNormalization()
# SE (squeeze-and-excitation) block
self.se_globalpool = keras.layers.GlobalAveragePooling2D()
self.se_resize = keras.layers.Reshape((1, 1, 4 * filters))
self.se_fc1 = keras.layers.Dense(units=4 * filters // 16, activation=‘re