Paper: http://xueshu.baidu.com/usercenter/paper/show?paperid=a76088e42441c28b0fcfc5ae45e04a2d&site=xueshu_se
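FCN-DenseNet combines FCN and DenseNets: it takes the DenseNet classification network, adds an upsampling path to it, and so turns it into a network for image segmentation.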
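The DenseNet part of this combination is the dense block: every layer inside it receives the concatenation of all earlier feature maps and adds `growth_rate` new channels. The NumPy sketch below shows only this connectivity and the resulting channel growth. It is a toy under stated assumptions: `dense_block` is a hypothetical name (not the repository's `DenseBlock`), and each layer is a random 1x1 convolution followed by ReLU instead of the BN, ReLU, 3x3 convolution and dropout layer described in the paper.

```python
import numpy as np

def dense_block(stack, n_layers, growth_rate, rng):
    """Toy dense block on a (batch, height, width, channels) array.

    Returns the full concatenated stack and, separately, only the new
    feature maps produced inside the block (n_layers * growth_rate channels).
    """
    new_features = []
    for _ in range(n_layers):
        in_channels = stack.shape[-1]
        # Hypothetical stand-in for BN-ReLU-Conv: random 1x1 conv (matmul on the channel axis) + ReLU
        weights = rng.standard_normal((in_channels, growth_rate)) * 0.1
        layer_out = np.maximum(stack @ weights, 0.0)
        new_features.append(layer_out)
        # Dense connectivity: the next layer sees every earlier feature map
        stack = np.concatenate([stack, layer_out], axis=-1)
    return stack, np.concatenate(new_features, axis=-1)

rng = np.random.default_rng(0)
x = np.ones((1, 8, 8, 48))            # e.g. the 48-channel output of the first convolution
stack, new = dense_block(x, n_layers=4, growth_rate=12, rng=rng)
print(stack.shape, new.shape)         # (1, 8, 8, 96) (1, 8, 8, 48)
```

Like `DenseBlock` in the code, the sketch returns two things: the full stack (stored as a skip connection on the downsampling path) and only the new feature maps (the part that is upsampled on the upsampling path, which keeps the channel count from growing out of control).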
FCN network
FCN-DenseNet
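In FCN-DenseNet each step of the upsampling path (Transition Up) doubles the spatial resolution of the new feature maps and concatenates them with the skip connection saved at the same resolution on the downsampling path. A minimal NumPy sketch, assuming nearest-neighbour upsampling as a stand-in: the paper uses a learned 3x3 transposed convolution with stride 2, which also sets the channel count to `n_filters_keep`, and `transition_up` is a hypothetical name.

```python
import numpy as np

def transition_up(block_to_upsample, skip_connection):
    """Toy Transition Up on (batch, height, width, channels) arrays.

    Nearest-neighbour 2x upsampling stands in for the learned transposed
    convolution; the result is concatenated with the skip connection.
    """
    upsampled = block_to_upsample.repeat(2, axis=1).repeat(2, axis=2)
    # Spatial sizes must match before concatenating along the channel axis
    assert upsampled.shape[1:3] == skip_connection.shape[1:3]
    return np.concatenate([upsampled, skip_connection], axis=-1)

bottleneck_new = np.zeros((1, 7, 7, 48))   # new feature maps from the bottleneck dense block
skip = np.zeros((1, 14, 14, 288))          # deepest skip connection in FC-DenseNet56
out = transition_up(bottleneck_new, skip)
print(out.shape)                           # (1, 14, 14, 336)
```

With FC-DenseNet56 sizes the upsampled maps keep 48 channels, so concatenating them with the 288-channel skip connection gives 336 channels.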
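Code:
```python
import tensorflow as tf                  # TensorFlow 1.x API (tf.variable_scope, slim)
import tensorflow.contrib.slim as slim

# Body of the model-building function: inputs, num_classes, preset_model, scope,
# n_filters_first_conv, n_pool, growth_rate, n_layers_per_block and dropout_p are its
# parameters; DenseBlock, TransitionDown and TransitionUp are helpers from the same file.

# One entry per dense block: n_pool down blocks + 1 bottleneck + n_pool up blocks
if isinstance(n_layers_per_block, int):
    n_layers_per_block = [n_layers_per_block] * (2 * n_pool + 1)

with tf.variable_scope(scope, preset_model, [inputs]) as sc:
    # First 3x3 convolution; output shape is (?, ?, ?, 48) with n_filters_first_conv=48
    stack = slim.conv2d(inputs, n_filters_first_conv, [3, 3], scope='first_conv', activation_fn=None)
    n_filters = n_filters_first_conv

    #####################
    # Downsampling path #
    #####################
    skip_connection_list = []
    for i in range(n_pool):
        # Dense Block
        stack, _ = DenseBlock(stack, n_layers_per_block[i], growth_rate, dropout_p, scope='denseblock%d' % (i + 1))
        n_filters += growth_rate * n_layers_per_block[i]
        # At the end of the dense block, the current stack is stored in the skip_connections list
        skip_connection_list.append(stack)
        # Transition Down
        stack = TransitionDown(stack, n_filters, dropout_p, scope='transitiondown%d' % (i + 1))
    # Reverse so that the deepest skip connection is used first on the way up
    skip_connection_list = skip_connection_list[::-1]

    #####################
    #     Bottleneck    #
    #####################
    # Dense Block
    # We will only upsample the new feature maps
    stack, block_to_upsample = DenseBlock(stack, n_layers_per_block[n_pool], growth_rate, dropout_p, scope='denseblock%d' % (n_pool + 1))

    #######################
    #   Upsampling path   #
    #######################
    for i in range(n_pool):
        # Transition Up (upsampling + concatenation with the skip connection).
        # For FC-DenseNet56 (growth_rate=12, 4 layers per block) the output has
        # 336, 288, 240, 192, 144 channels for i = 0..4
        n_filters_keep = growth_rate * n_layers_per_block[n_pool + i]
        stack = TransitionUp(block_to_upsample, skip_connection_list[i], n_filters_keep,
                             scope='transitionup%d' % (n_pool + i + 1))
        # Dense Block
        # We will only upsample the new feature maps
        stack, block_to_upsample = DenseBlock(stack, n_layers_per_block[n_pool + i + 1], growth_rate, dropout_p, scope='denseblock%d' % (n_pool + i + 2))

    #####################
    #      Softmax      #
    #####################
    # 1x1 convolution to per-pixel class logits; the softmax itself is applied in the loss
    net = slim.conv2d(stack, num_classes, [1, 1], activation_fn=None, scope='logits')
```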
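The channel counts noted in the comments (48 after the first convolution, then 336, 288, 240, 192, 144 after each Transition Up) can be checked with plain arithmetic, assuming the FC-DenseNet56 setting: `n_filters_first_conv=48`, `n_pool=5`, `growth_rate=12` and 4 layers in every dense block. The helper below is a hypothetical pure-Python sketch that retraces the two loops without building any tensors.

```python
def fc_densenet_channels(n_filters_first_conv, n_pool, growth_rate, n_layers_per_block):
    """Trace channel counts through the FC-DenseNet loops without building tensors."""
    if isinstance(n_layers_per_block, int):
        n_layers_per_block = [n_layers_per_block] * (2 * n_pool + 1)

    # Downsampling path: each dense block adds growth_rate * n_layers channels;
    # Transition Down keeps the channel count (n_filters) unchanged
    n_filters = n_filters_first_conv
    skips = []
    for i in range(n_pool):
        n_filters += growth_rate * n_layers_per_block[i]
        skips.append(n_filters)
    skips = skips[::-1]

    transition_up_channels = []
    stack = None
    for i in range(n_pool):
        # Transposed conv outputs n_filters_keep channels, then concatenation with the skip
        n_filters_keep = growth_rate * n_layers_per_block[n_pool + i]
        stack = n_filters_keep + skips[i]
        transition_up_channels.append(stack)
        # The dense block after it appends growth_rate * n_layers new channels
        stack += growth_rate * n_layers_per_block[n_pool + i + 1]
    return skips, transition_up_channels, stack

skips, ups, final = fc_densenet_channels(48, n_pool=5, growth_rate=12, n_layers_per_block=4)
print(skips)   # [288, 240, 192, 144, 96]
print(ups)     # [336, 288, 240, 192, 144]
print(final)   # 192
```

Passing `growth_rate=16` and the per-block list `[4, 5, 7, 10, 12, 15, 12, 10, 7, 5, 4]` (the FC-DenseNet103 configuration) gives the counts for the deeper model; the final value is the number of channels entering the 1x1 `logits` convolution.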
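```python
import tensorflow as tf                  # TensorFlow 1.x API
from builders import model_builder       # model_builder module of the reference repository

# num_classes and args (argparse options) are defined earlier in the training script
net_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])             # RGB images
net_output = tf.placeholder(tf.float32, shape=[None, None, None, num_classes])  # labels (batch, height, width, classes): one-hot class per pixel
network, init_fn = model_builder.build_model(model_name=args.model, frontend=args.frontend, net_input=net_input, num_classes=num_classes,
                                             crop_width=args.crop_width, crop_height=args.crop_height, is_training=True)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=network, labels=net_output))  # mean per-pixel cross-entropy
opt = tf.train.RMSPropOptimizer(learning_rate=0.0001, decay=0.995).minimize(loss, var_list=tf.trainable_variables())
```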
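`softmax_cross_entropy_with_logits` (and its `_v2` variant) applies a softmax over the last (class) axis and compares it with the one-hot label, independently for every pixel; `tf.reduce_mean` then averages over batch, height and width. The NumPy sketch below reproduces that computation for illustration only; `pixelwise_softmax_cross_entropy` is a hypothetical name. It also shows one way to build the 4-D one-hot labels the placeholder expects from an integer class map.

```python
import numpy as np

def pixelwise_softmax_cross_entropy(logits, labels):
    """Mean softmax cross-entropy for (batch, height, width, num_classes) arrays.

    labels holds a one-hot class vector for every pixel.
    """
    # Subtract the per-pixel max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    per_pixel = -(labels * log_probs).sum(axis=-1)   # shape (batch, height, width)
    return per_pixel.mean()

num_classes = 4
class_map = np.zeros((2, 3, 3), dtype=int)      # every pixel labelled class 0
labels = np.eye(num_classes)[class_map]         # one-hot, shape (2, 3, 3, 4)
logits = np.zeros((2, 3, 3, num_classes))      # uniform prediction everywhere
print(pixelwise_softmax_cross_entropy(logits, labels))  # ≈ 1.386, i.e. log(4)
```

With uniform logits every pixel assigns probability 1/4 to its true class, so the mean loss is log 4; confident, correct logits drive it towards 0.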
Reference: https://github.com/GeorgeSeif/Semantic-Segmentation-Suite