在上一篇博客中介绍了Faster RCNN算法的整体结构:Faster RCNN代码详解(一):算法整体结构,在该结构中最主要的两部分是网络结构的构建和数据的读取,因此这篇博客就来介绍下Faster RCNN算法的网络结构构建细节。因为其中特征提取网络选择多样,所以这里以常用的ResNet为例来介绍。
这部分内容非常重要,因为这里我将详细介绍RPN网络的结构、RPN网络的损失函数、proposal的生成、Roi Pooling层、检测网络的结构、检测网络的损失函数等细节,也就是论文中的Figure2。从最近两年follow的论文来看,大部分都是在网络结构上做改进,比如R-FCN是在Roi Pooling层做了调整,FPN是基于融合后的多层特征做检测等。因此了解清楚Faster RCNN算法的网络结构对后续理解基于Faster RCNN算法的改进非常有帮助,而了解网络结构的最好办法就是阅读代码。
网络结构的构造通过get_resnet_train函数进行,该函数所在脚本:~mx-rcnn/rcnn/symbol/symbol_resnet.py。接下来就来看看该函数的代码细节:
def get_resnet_train(num_classes=config.NUM_CLASSES,num_anchors=config.NUM_ANCHORS):
# 这里设置了几个变量,其中参数name的命名和数据读取类AnchorLoader的provide_data和
# provide_label属性是对应相同的,如果两处命名冲突则会报错。
data = mx.symbol.Variable(name="data")
im_info = mx.symbol.Variable(name="im_info")
gt_boxes = mx.symbol.Variable(name="gt_boxes")
rpn_label = mx.symbol.Variable(name='label')
rpn_bbox_target = mx.symbol.Variable(name='bbox_target')
rpn_bbox_weight = mx.symbol.Variable(name='bbox_weight')
# shared convolutional layers
# get_resnet_conv方法返回的是ResNet网络从conv1到conv4_x的部分,该部分就是用来做特征提取的,
# 之所以这些层叫shared convolutional layers,是因为输出conv_feat不仅仅作为接下来RPN网络的输入,
# 还作为后续的ROI Pooling的输入。另外关于conv_feat的维度大小,假设输入图像大小为600*800,
# 则conv_feat的feature map大小是38*50。
# 接下来关于输出feature map的维度分析时都基于输入大小是600*800的假设,这样理解起来比较清晰。
conv_feat = get_resnet_conv(data)
# RPN layers
# 接下来这一部分是RPN网络的层,输入就是ResNet的conv4_x层的输出,也就是conv_feat。
# 首先说明一下anchor和region proposal(后面就用proposal代替)的区别:
# anchor是固定不变的(可以参考~/mx-rcnn/rcnn/io/rpn.py的assign_anchor函数,
# 系列四博客也会介绍),只是一个初始值或者参考值,其标签通过和ground truth计算
# IOU后就可以得到,也是固定的,而proposal是模型预测的框(RPN网络的目的
# 就是输出proposal,这个输出的proposal要尽可能和ground truth接近),
# 当然anchor和proposal并不是相互独立的,其实每个proposal都和一个anchor对应,
# 所以anchor和proposal的数量是一样的。RPN的网络结构非常简单,
# 先是一个3*3的卷积层,这个3*3的卷积层就对应论文中的sliding window操作。
# 然后基于该卷积层的输出引出两条支路(这两条支路都用1*1卷积层实现),
# 一条支路用来预测proposal的标签(卷积核数量是2*anchor数量,表示每个anchor属于背景类和目标类的概率);
# 另一条支路用来预测proposal的坐标偏置(卷积核数量是4*anchor数量,表示每个anchor的四个坐标信息的偏置)。
rpn_conv = mx.symbol.Convolution(
data=conv_feat, kernel=(3, 3), pad=(1, 1), num_filter=512, name="rpn_conv_3x3")
rpn_relu = mx.symbol.Activation(data=rpn_conv, act_type="relu", name="rpn_relu")
# num_anchors默认是9