# ResNet
# from ./fcos_core/modeling/backbone/resnet.py
"""
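ResNet backbone and its many variants, all configured through `cfg`.
Example usage. Strings may be specified in the config file: the stem and
block names are read from cfg.MODEL.RESNETS.STEM_FUNC and
cfg.MODEL.RESNETS.TRANS_FUNC, and the stage layout from
cfg.MODEL.BACKBONE.CONV_BODY (e.g. "R-50-C4" selects ResNet50StagesTo4).
The ResNet class in this file is constructed as `ResNet(cfg)`; the calls
below only illustrate which strings select which components.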
model = ResNet(
"StemWithFixedBatchNorm",
"BottleneckWithFixedBatchNorm",
"ResNet50StagesTo4",
)
OR:
model = ResNet(
"StemWithGN",
"BottleneckWithGN",
"ResNet50StagesTo4",
)
Custom implementations may be written in user code and hooked in via the
`register_*` functions.
"""
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch import nn
from fcos_core.layers import FrozenBatchNorm2d
from fcos_core.layers import Conv2d
from fcos_core.layers import DFConv2d
from fcos_core.modeling.make_layers import group_norm
from fcos_core.utils.registry import Registry
# Description of one residual stage:
#   index           -- 1-based stage number; ResNet registers the stage as
#                      the child module "layer<index>"
#   block_count     -- number of residual blocks stacked in the stage
#   return_features -- whether ResNet.forward includes this stage's output
StageSpec = namedtuple("StageSpec", "index block_count return_features")

# Stage layouts, written as (index, block_count, return_features) rows.
# The "...StagesTo4" layouts contain stages 1-3 (layer1..layer3) and the
# "...StagesTo5" layouts stages 1-4 (layer1..layer4). Non-FPN layouts expose
# only their last stage; FPN layouts expose every stage.
ResNet50StagesTo5 = tuple(map(StageSpec._make, (
    (1, 3, False),
    (2, 4, False),
    (3, 6, False),
    (4, 3, True),
)))
ResNet50StagesTo4 = tuple(map(StageSpec._make, (
    (1, 3, False),
    (2, 4, False),
    (3, 6, True),
)))
ResNet101StagesTo5 = tuple(map(StageSpec._make, (
    (1, 3, False),
    (2, 4, False),
    (3, 23, False),
    (4, 3, True),
)))
ResNet101StagesTo4 = tuple(map(StageSpec._make, (
    (1, 3, False),
    (2, 4, False),
    (3, 23, True),
)))
ResNet50FPNStagesTo5 = tuple(map(StageSpec._make, (
    (1, 3, True),
    (2, 4, True),
    (3, 6, True),
    (4, 3, True),
)))
ResNet101FPNStagesTo5 = tuple(map(StageSpec._make, (
    (1, 3, True),
    (2, 4, True),
    (3, 23, True),
    (4, 3, True),
)))
ResNet152FPNStagesTo5 = tuple(map(StageSpec._make, (
    (1, 3, True),
    (2, 8, True),
    (3, 36, True),
    (4, 3, True),
)))
class ResNet(nn.Module):
    """ResNet backbone whose stem, residual block and stage layout come from cfg.

    ``cfg.MODEL.RESNETS.STEM_FUNC`` selects the stem from ``_STEM_MODULES``,
    ``cfg.MODEL.RESNETS.TRANS_FUNC`` selects the block class from
    ``_TRANSFORMATION_MODULES`` and ``cfg.MODEL.BACKBONE.CONV_BODY`` selects a
    tuple of ``StageSpec`` from ``_STAGE_SPECS``. Each spec becomes a child
    module named ``"layer<index>"``.

    ``forward`` returns a list holding the output of every stage whose spec
    has ``return_features`` set, in stage order.
    """

    def __init__(self, cfg):
        super(ResNet, self).__init__()

        # Resolve the pluggable components before building any layer.
        stem_module = _STEM_MODULES[cfg.MODEL.RESNETS.STEM_FUNC]
        stage_specs = _STAGE_SPECS[cfg.MODEL.BACKBONE.CONV_BODY]
        transformation_module = _TRANSFORMATION_MODULES[cfg.MODEL.RESNETS.TRANS_FUNC]

        self.stem = stem_module(cfg)

        opts = cfg.MODEL.RESNETS
        groups = opts.NUM_GROUPS
        # Widths of stage index 1; stage k scales both by 2 ** (k - 1).
        base_bottleneck = groups * opts.WIDTH_PER_GROUP
        base_out = opts.RES2_OUT_CHANNELS
        channels = opts.STEM_OUT_CHANNELS

        self.stages = []
        self.return_features = {}
        for spec in stage_specs:
            layer_name = "layer" + str(spec.index)
            scale = 2 ** (spec.index - 1)
            stage_out = base_out * scale
            stage = _make_stage(
                transformation_module=transformation_module,
                in_channels=channels,
                bottleneck_channels=base_bottleneck * scale,
                out_channels=stage_out,
                block_count=spec.block_count,
                num_groups=groups,
                stride_in_1x1=opts.STRIDE_IN_1X1,
                # The first block of every stage after stage 1 uses stride 2.
                first_stride=2 if spec.index > 1 else 1,
                dcn_config={
                    "stage_with_dcn": opts.STAGE_WITH_DCN[spec.index - 1],
                    "with_modulated_dcn": opts.WITH_MODULATED_DCN,
                    "deformable_groups": opts.DEFORMABLE_GROUPS,
                },
            )
            channels = stage_out
            self.add_module(layer_name, stage)
            self.stages.append(layer_name)
            self.return_features[layer_name] = spec.return_features

        self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT)

    def _freeze_backbone(self, freeze_at):
        """Turn off gradients for the first ``freeze_at`` stages.

        Stage 0 is the stem and stage k (k >= 1) is ``layer<k>``. A negative
        ``freeze_at`` leaves everything trainable.
        """
        if freeze_at < 0:
            return
        for stage_index in range(freeze_at):
            frozen = self.stem if stage_index == 0 else getattr(self, "layer" + str(stage_index))
            frozen.requires_grad_(False)

    def forward(self, x):
        """Run the stem and all stages; collect the flagged stage outputs."""
        features = []
        x = self.stem(x)
        for layer_name in self.stages:
            x = getattr(self, layer_name)(x)
            if self.return_features[layer_name]:
                features.append(x)
        return features
class ResNetHead(nn.Module):
    """Sequence of ResNet stages applied to an existing feature map (no stem).

    Args:
        block_module (str): key into ``_TRANSFORMATION_MODULES`` selecting the
            residual block class.
        stages: non-empty sequence of ``StageSpec``. Each stage is registered
            as ``"layer<index>"`` and run in order by ``forward``.
        num_groups (int): groups of each block's 3x3 convolution.
        width_per_group (int): bottleneck width per group for stage index 1;
            stage k uses ``num_groups * width_per_group * 2 ** (k - 1)``.
        stride_in_1x1 (bool): forwarded to every block.
        stride_init: stride of the first block of the first stage. When falsy
            (default ``None``) it is 2 for stage index > 1 and 1 otherwise;
            all later stages always use that rule.
        res2_out_channels (int): output width of stage index 1; stage k
            outputs ``res2_out_channels * 2 ** (k - 1)`` channels.
        dilation (int): forwarded to every block.
        dcn_config (dict or None): forwarded to every block.

    The first stage expects an input with half as many channels as it
    produces. ``self.out_channels`` is the output width of the last stage.
    """

    def __init__(
        self,
        block_module,
        stages,
        num_groups=1,
        width_per_group=64,
        stride_in_1x1=True,
        stride_init=None,
        res2_out_channels=256,
        dilation=1,
        dcn_config=None
    ):
        super(ResNetHead, self).__init__()
        stage2_bottleneck_channels = num_groups * width_per_group
        # The first stage consumes a map half as wide as its own output.
        first_factor = 2 ** (stages[0].index - 1)
        in_channels = res2_out_channels * first_factor // 2
        block_module = _TRANSFORMATION_MODULES[block_module]

        self.stages = []
        out_channels = None
        stride = stride_init
        for stage in stages:
            name = "layer" + str(stage.index)
            # Derive widths per stage (as ResNet does) so that each stage after
            # the first takes the previous stage's output width as its input.
            stage2_relative_factor = 2 ** (stage.index - 1)
            out_channels = res2_out_channels * stage2_relative_factor
            bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor
            if not stride:
                stride = int(stage.index > 1) + 1
            module = _make_stage(
                block_module,
                in_channels,
                bottleneck_channels,
                out_channels,
                stage.block_count,
                num_groups,
                stride_in_1x1,
                first_stride=stride,
                dilation=dilation,
                dcn_config=dcn_config
            )
            # stride_init only applies to the first stage.
            stride = None
            in_channels = out_channels
            self.add_module(name, module)
            self.stages.append(name)
        self.out_channels = out_channels

    def forward(self, x):
        """Apply every stage in order and return the final feature map."""
        for stage in self.stages:
            x = getattr(self, stage)(x)
        return x
def _make_stage(
    transformation_module,
    in_channels,
    bottleneck_channels,
    out_channels,
    block_count,
    num_groups,
    stride_in_1x1,
    first_stride,
    dilation=1,
    dcn_config=None
):
    """Stack ``block_count`` residual blocks into an ``nn.Sequential``.

    Only the first block receives ``in_channels`` and ``first_stride``; every
    following block takes ``out_channels`` inputs and uses stride 1. The same
    ``dilation`` and the same ``dcn_config`` object are passed to all blocks.
    """
    return nn.Sequential(*[
        transformation_module(
            in_channels if position == 0 else out_channels,
            bottleneck_channels,
            out_channels,
            num_groups,
            stride_in_1x1,
            first_stride if position == 0 else 1,
            dilation=dilation,
            dcn_config=dcn_config
        )
        for position in range(block_count)
    ])
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv -> 3x3 conv -> 1x1 conv.

    Each convolution is followed by ``norm_func``. ReLU follows the first two
    normalized convolutions; the third is added to the shortcut and then
    passed through a final ReLU.

    Args:
        in_channels (int): channels of the block input.
        bottleneck_channels (int): width of the inner 1x1 and 3x3 convolutions.
        out_channels (int): channels of the block output.
        num_groups (int): groups of the 3x3 convolution.
        stride_in_1x1 (bool): apply the block's stride in the first 1x1
            convolution if True, otherwise in the 3x3 convolution.
        stride (int): block stride; forced to 1 when ``dilation > 1``.
        dilation (int): dilation of the 3x3 convolution (also its padding in
            the non-deformable case).
        norm_func: callable taking a channel count and returning a
            normalization module.
        dcn_config (dict or None): optional keys ``"stage_with_dcn"``,
            ``"deformable_groups"`` and ``"with_modulated_dcn"``. When
            ``"stage_with_dcn"`` is truthy the 3x3 convolution is a
            ``DFConv2d``. ``None`` is treated as an empty dict.
    """

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups,
        stride_in_1x1,
        stride,
        dilation,
        norm_func,
        dcn_config
    ):
        super(Bottleneck, self).__init__()
        # _make_stage, ResNetHead and the Bottleneck* subclasses default
        # dcn_config to None; without this, dcn_config.get(...) would raise.
        if dcn_config is None:
            dcn_config = {}

        # Projection shortcut, built only when the channel count changes.
        # NOTE(review): a block with stride > 1 but in_channels == out_channels
        # keeps an identity shortcut whose shape would not match the main
        # path; presumably callers never request that — confirm.
        self.downsample = None
        if in_channels != out_channels:
            # A dilated block does not downsample, so neither may the shortcut.
            down_stride = stride if dilation == 1 else 1
            self.downsample = nn.Sequential(
                Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=down_stride, bias=False
                ),
                norm_func(out_channels),
            )
            for layer in self.downsample.modules():
                if isinstance(layer, Conv2d):
                    nn.init.kaiming_uniform_(layer.weight, a=1)

        if dilation > 1:
            # Dilated blocks keep the spatial resolution.
            stride = 1
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
        )
        self.bn1 = norm_func(bottleneck_channels)

        with_dcn = dcn_config.get("stage_with_dcn", False)
        if with_dcn:
            deformable_groups = dcn_config.get("deformable_groups", 1)
            with_modulated_dcn = dcn_config.get("with_modulated_dcn", False)
            # NOTE(review): no padding is passed and no init is applied here;
            # DFConv2d is assumed to derive its padding and initialize its own
            # weights — confirm against fcos_core.layers.
            self.conv2 = DFConv2d(
                bottleneck_channels,
                bottleneck_channels,
                with_modulated_dcn=with_modulated_dcn,
                kernel_size=3,
                stride=stride_3x3,
                groups=num_groups,
                dilation=dilation,
                deformable_groups=deformable_groups,
                bias=False
            )
        else:
            self.conv2 = Conv2d(
                bottleneck_channels,
                bottleneck_channels,
                kernel_size=3,
                stride=stride_3x3,
                padding=dilation,
                bias=False,
                groups=num_groups,
                dilation=dilation
            )
            nn.init.kaiming_uniform_(self.conv2.weight, a=1)
        self.bn2 = norm_func(bottleneck_channels)

        self.conv3 = Conv2d(
            bottleneck_channels, out_channels, kernel_size=1, bias=False
        )
        self.bn3 = norm_func(out_channels)

        for layer in (self.conv1, self.conv3):
            nn.init.kaiming_uniform_(layer.weight, a=1)

    def forward(self, x):
        """Compute relu(main_path(x) + shortcut(x))."""
        identity = x

        out = F.relu_(self.bn1(self.conv1(x)))
        out = F.relu_(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return F.relu_(out)
class BaseStem(nn.Module):
    """ResNet stem: 7x7 stride-2 convolution on 3 input channels, normalization,
    ReLU, then 3x3 stride-2 max pooling.

    Args:
        cfg: config; ``cfg.MODEL.RESNETS.STEM_OUT_CHANNELS`` sets the width.
        norm_func: callable taking a channel count and returning a
            normalization module.
    """

    def __init__(self, cfg, norm_func):
        super(BaseStem, self).__init__()
        width = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
        self.conv1 = Conv2d(3, width, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_func(width)
        nn.init.kaiming_uniform_(self.conv1.weight, a=1)

    def forward(self, x):
        activated = F.relu_(self.bn1(self.conv1(x)))
        return F.max_pool2d(activated, kernel_size=3, stride=2, padding=1)
class BottleneckWithFixedBatchNorm(Bottleneck):
    """``Bottleneck`` that normalizes with ``FrozenBatchNorm2d``.

    Every argument is forwarded unchanged to ``Bottleneck``.
    """

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config=None
    ):
        # Positional order follows Bottleneck.__init__.
        super().__init__(
            in_channels,
            bottleneck_channels,
            out_channels,
            num_groups,
            stride_in_1x1,
            stride,
            dilation,
            norm_func=FrozenBatchNorm2d,
            dcn_config=dcn_config,
        )
class StemWithFixedBatchNorm(BaseStem):
    """``BaseStem`` that normalizes with ``FrozenBatchNorm2d``."""

    def __init__(self, cfg):
        super().__init__(cfg, FrozenBatchNorm2d)
class BottleneckWithGN(Bottleneck):
    """``Bottleneck`` that normalizes with ``group_norm``.

    Every argument is forwarded unchanged to ``Bottleneck``.
    """

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config=None
    ):
        # Positional order follows Bottleneck.__init__.
        super().__init__(
            in_channels,
            bottleneck_channels,
            out_channels,
            num_groups,
            stride_in_1x1,
            stride,
            dilation,
            norm_func=group_norm,
            dcn_config=dcn_config,
        )
class StemWithGN(BaseStem):
    """``BaseStem`` that normalizes with ``group_norm``."""

    def __init__(self, cfg):
        super().__init__(cfg, group_norm)
# Name -> component registries consulted by ResNet (and, for the block
# classes, ResNetHead). ResNet reads the keys from
# cfg.MODEL.RESNETS.TRANS_FUNC, cfg.MODEL.RESNETS.STEM_FUNC and
# cfg.MODEL.BACKBONE.CONV_BODY respectively.
_TRANSFORMATION_MODULES = Registry(dict([
    ("BottleneckWithFixedBatchNorm", BottleneckWithFixedBatchNorm),
    ("BottleneckWithGN", BottleneckWithGN),
]))

_STEM_MODULES = Registry(dict([
    ("StemWithFixedBatchNorm", StemWithFixedBatchNorm),
    ("StemWithGN", StemWithGN),
]))

# The "-RETINANET" names map to the same layouts as the plain FPN names.
_STAGE_SPECS = Registry(dict([
    ("R-50-C4", ResNet50StagesTo4),
    ("R-50-C5", ResNet50StagesTo5),
    ("R-101-C4", ResNet101StagesTo4),
    ("R-101-C5", ResNet101StagesTo5),
    ("R-50-FPN", ResNet50FPNStagesTo5),
    ("R-50-FPN-RETINANET", ResNet50FPNStagesTo5),
    ("R-101-FPN", ResNet101FPNStagesTo5),
    ("R-101-FPN-RETINANET", ResNet101FPNStagesTo5),
    ("R-152-FPN", ResNet152FPNStagesTo5),
]))