torchvision can build a base (backbone) network by combining a ResNet with an FPN (Feature Pyramid Network). This post uses resnet50 as the example.
BackboneWithFPN
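The class BackboneWithFPN lives in backbone_utils.py (torchvision/models/detection/). It attaches an FPN on top of a backbone network. The definition below is from older torchvision releases; newer versions differ slightly: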
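from collections import OrderedDict

from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool


class BackboneWithFPN(nn.Sequential):
    def __init__(self, backbone, return_layers, in_channels_list, out_channels):
        # body: runs the backbone and returns the outputs of the layers named in return_layers
        body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        # fpn: fuses those maps top-down; LastLevelMaxPool adds one extra, coarser level
        fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=LastLevelMaxPool(),
        )
        super(BackboneWithFPN, self).__init__(OrderedDict(
            [("body", body), ("fpn", fpn)]))
        self.out_channels = out_channels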
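Constructor parameters:
- backbone : the backbone network
- return_layers : a dict whose keys name the backbone submodules whose outputs are fed into the FPN, and whose values are the names those outputs get in the returned dict
- in_channels_list : the channel count of each layer listed in return_layers, in the same order
- out_channels : the number of channels of every FPN output map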
A simple usage example:
import torch
import torchvision.models as models
import torchvision.models.detection.backbone_utils as backbone_utils
backbone = models.resnet50()
return_layers = {'layer1': 0, 'layer2': 1, 'layer3': 2, 'layer4': 3}
in_channels_list = [256, 512, 1024, 2048]
out_channels = 256
resnet_with_fpn = backbone_utils.BackboneWithFPN(
    backbone, return_layers, in_channels_list, out_channels)
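You can predict the spatial size of each output level with plain arithmetic before running the model. The sketch assumes torchvision's standard ResNet-50 layout: a 7×7 stride-2 stem conv with padding 3, then a 3×3 stride-2 max-pool with padding 1. layer1 keeps the resolution, and layer2 through layer4 each start with a stride-2 3×3 conv. It also assumes that LastLevelMaxPool applies a max-pool with kernel 1 and stride 2 to the coarsest map. The helper names conv_out and fpn_sizes are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length along one axis of a conv/pool (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def fpn_sizes(input_size):
    """Spatial size of each FPN level for a square input, ResNet-50 + LastLevelMaxPool."""
    s = conv_out(input_size, 7, 2, 3)  # stem conv1
    s = conv_out(s, 3, 2, 1)           # stem max-pool
    sizes = [s]                        # layer1: stride 1, size unchanged
    for _ in range(3):                 # layer2..layer4: stride-2 3x3 conv, padding 1
        s = conv_out(s, 3, 2, 1)
        sizes.append(s)
    sizes.append(conv_out(s, 1, 2, 0))  # LastLevelMaxPool: kernel 1, stride 2
    return sizes


print(fpn_sizes(640))  # [160, 80, 40, 20, 10]
```

For a 640×640 input this gives strides of 4, 8, 16, 32 and 64. The channel count of every level is independent of the input size and always equals out_channels.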
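resnet_with_fpn returns an OrderedDict containing 5 FPN feature maps. Each map has out_channels = 256 channels. For a 640×640 input, their spatial sizes are 160, 80, 40, 20 and 10: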
input = torch.rand(4, 3, 640, 640)  # a batch of 4 random 640x640 RGB images
output = resnet_with_fpn(input)
# output is an OrderedDict
print(type(out