RPN是two-stage的标志性结构,并且其本身也是一个二分类的目标检测网络,因此在faster-rcnn的整个网络结构中能看到anchor的使用,回归和分类等操作,这里讲具体介绍一下。
整个rpn部分代码在torchvison/models/detection/rpn.py中,其中定义了RPNHead,AnchorGenerator,RegionProposalNetwork三个模块。
目录
AnchorGenerator
AnchorGenerator的定义:
Module that generates anchors for a set of feature maps and image sizes.
顾名思义,AnchorGenerator的主要作用就是生成与feature相对应的anchors。
输入参数:
- sizes : 用于每层feature的anchor基础尺寸
- aspect_ratios : 宽高比例
根据sizes的个数和aspect_ratios的个数,将在feature map的每个位置上生成固定数量的anchor。
并且注意AnchorGenerator继承自nn.Module,也是有forward()函数的,并且注意forward的输入包含了一个ImageList类型。ImageList类型的定义在torchvison/models/detection/image_list.py中。
写个简单的例子测试一下:
import torchvision.models.detection.rpn as rpn
import torchvision.models.detection.image_list as image_list
import torch
# 创建AnchorGenerator实例
anchor_generator = rpn.AnchorGenerator()
# 构建ImageList
batched_images = torch.Tensor(8,3,640,640)
image_sizes = [(640,640)] * 8
image_list_ = image_list.ImageList(batched_images,image_sizes)
# 构建feature_maps
feature_maps = [torch.Tensor([8,256,80,80]),torch.Tensor(8,256,160,160), torch.Tensor(8,256,320,320)]
# 生成anchors
anchors = anchor_generator(image_list_,feature_maps)
验证一下生成的anchors:
>>> print(type(anchors))
<class 'list'>
>>> for anchor in anchors:
... print(anchor.shape)
...
torch.Size([403200, 4])
torch.Size([403200, 4])
torch.Size([403200, 4])
torch.Size([403200, 4])
torch.Size([403200, 4])
torch.Size([403200, 4])
torch.Size([403200, 4])
torch.Size([403200, 4])
从结果上我们可以看到:
- 返回的结果是一个Tensor的list,list中的元素个数和batch_size相同
- anchors中的每个Tensor大小相同,均为torch.Size([403200, 4])
这里来分析下为什么anchors中输出的tensor大小都是403200×4:
首先AnchorGenerator()的默认参数aspect_scale=(0.5, 1.0, 2.0),sizes=(128, 256, 512),因为输入的aspect_scale大小为3,因此会在feature的每个位置上生成3个anchor,共生成80×80×3+160×160×3+320×320×3=403200个anchors。
注意sizes中的每个值是用于每层feature的anchor的基数大小,比如在例子中80×80的feature map每个grid设置的anchor的大小为128,并根据aspect_scale的值生成(128/sqrt(2), 128*sqrt(2)),(128*128),(128*sqrt(2),128/sqrt(2))三种尺寸的anchor
RegionProposalNetwork
RegionProposalNetwork是整个rpn的主体,其中集成了AnchorGenerator和RPNHead,功能包含生成anchors,anchor与groundtruth的匹配,nms,回归与分类损失的计算等等。
参数定义如下:
anchor_generator :传入AnchorGenerator
head:通过feature生成regression deltas和objectness的模块
fg_iou_thresh:iou大于该阈值被认为是前景
bg_iou_thresh:iou小于该阈值被认为是背景