前言
detector 在maskrcnn benchmark中是网络模型的入口类,它把各个网络结构组合为统一的模型。用户训练中用到的网络模型由detectors.py实例化:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from .generalized_rcnn import GeneralizedRCNN
# 目前maskrcnn benchmark只定义了一种常用的rcnn
_DETECTION_META_ARCHITECTURES = {"GeneralizedRCNN": GeneralizedRCNN}
def build_detection_model(cfg):
meta_arch = _DETECTION_META_ARCHITECTURES[cfg.MODEL.META_ARCHITECTURE]
return meta_arch(cfg)
代码中涉及到的GeneralizedRCNN为实际构造整个模型的类,其功能是使得整个网络能够组合在一起,并且调的通。其代码解析如下:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized R-CNN framework
"""
import torch
from torch import nn
from maskrcnn_benchmark.structures.image_list import to_image_list
from ..backbone import build_backbone
from ..rpn.rpn import build_rpn
from ..roi_heads.roi_heads import build_roi_heads