基于pytorch从模型中取出指定特征层的输出

重构网络模型

问题:若网络模型为ResNet-50,想取出layer4和layer3的输出作为网络的最终输出。这是语义分割中常用的操作:将layer4作为主输出,layer3用于辅助输出。分类网络(如GoogLeNet)中也有类似的辅助输出。
解决:首先需要重构原本的ResNet-50,将layer3、layer4与其输出名称定义为一个字典:
{'layer4': 'main_out', 'layer3': 'aux_out'},使用IntermediateLayerGetter将网络最终的输出调整为一个有序字典的形式,layer3、layer4的输出分别对应'aux_out'和'main_out'。
return_layer中以字典的方式传入指定层的名称以及输出后的名称,如:{'layer4': 'main_out', 'layer3': 'aux_out'}。最终网络输出aux_out和main_out,即layer3、layer4的特征图。

class IntermediateLayerGetter(nn.ModuleDict):
    """Expose the outputs of selected direct children of a model.

    The wrapper keeps the model's direct children, in ``named_children()``
    order, up to and including the last requested one. ``forward`` feeds the
    input through the kept children one after another, so any logic in the
    wrapped model's own ``forward`` is not executed.

    Args:
        model: Module whose direct children are reused.
        return_layer: Maps a child name to the key under which that child's
            output is returned, e.g. ``{'layer4': 'main_out'}``.

    Raises:
        ValueError: If a key of ``return_layer`` is not a direct child of
            ``model``.
    """

    def __init__(self, model: nn.Module, return_layer: Dict[str, str]):
        # Only direct children can be requested; validate before wrapping.
        child_names = [child_name for child_name, _ in model.named_children()]
        if not set(return_layer).issubset(child_names):
            raise ValueError('return_layers are not present in model')

        # Working copy that shrinks as requested layers are found; the
        # caller's mapping itself is kept untouched for use in forward().
        remaining = {str(key): str(value) for key, value in return_layer.items()}

        kept = OrderedDict()
        for child_name, child in model.named_children():
            kept[child_name] = child
            remaining.pop(child_name, None)
            # Every requested layer has been collected: drop the rest.
            if len(remaining) == 0:
                break

        super().__init__(kept)
        self.return_layer = return_layer

    def forward(self, x):
        """Run the kept children in order and collect the requested outputs.

        Returns:
            OrderedDict mapping each configured output name to the tensor
            produced by the corresponding child, in execution order.
        """
        collected = OrderedDict()
        for layer_name, layer in self.items():
            x = layer(x)
            if layer_name not in self.return_layer:
                continue
            collected[self.return_layer[layer_name]] = x
        return collected

整体程序

import torch
from ResNet_dilation import ResNet50
import torch.nn as nn
from typing import Dict
from collections import OrderedDict
import torch.nn.functional as F


class IntermediateLayerGetter(nn.ModuleDict):
    """Expose the outputs of selected direct children of a model.

    The wrapper keeps the model's direct children, in ``named_children()``
    order, up to and including the last requested one. ``forward`` feeds the
    input through the kept children one after another, so any logic in the
    wrapped model's own ``forward`` is not executed.

    Args:
        model: Module whose direct children are reused.
        return_layer: Maps a child name to the key under which that child's
            output is returned, e.g. ``{'layer4': 'main_out'}``.

    Raises:
        ValueError: If a key of ``return_layer`` is not a direct child of
            ``model``.
    """

    def __init__(self, model: nn.Module, return_layer: Dict[str, str]):
        # Only direct children can be requested; validate before wrapping.
        child_names = [child_name for child_name, _ in model.named_children()]
        if not set(return_layer).issubset(child_names):
            raise ValueError('return_layers are not present in model')

        # Working copy that shrinks as requested layers are found; the
        # caller's mapping itself is kept untouched for use in forward().
        remaining = {str(key): str(value) for key, value in return_layer.items()}

        kept = OrderedDict()
        for child_name, child in model.named_children():
            kept[child_name] = child
            remaining.pop(child_name, None)
            # Every requested layer has been collected: drop the rest.
            if len(remaining) == 0:
                break

        super().__init__(kept)
        self.return_layer = return_layer

    def forward(self, x):
        """Run the kept children in order and collect the requested outputs.

        Returns:
            OrderedDict mapping each configured output name to the tensor
            produced by the corresponding child, in execution order.
        """
        collected = OrderedDict()
        for layer_name, layer in self.items():
            x = layer(x)
            if layer_name not in self.return_layer:
                continue
            collected[self.return_layer[layer_name]] = x
        return collected


class FCNHead(nn.Sequential):
    """Classification head: 3x3 conv -> BatchNorm -> ReLU -> Dropout -> 1x1 conv.

    Args:
        in_channels: Channel count of the incoming feature map.
        channels: Channel count of the head's output.
    """

    def __init__(self, in_channels, channels):
        # The hidden width is a quarter of the input width (e.g. 1024 -> 256).
        hidden = in_channels // 4

        super().__init__(
            # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            # 1x1 projection to the requested number of output channels.
            nn.Conv2d(hidden, channels, kernel_size=1),
        )


class FCN(nn.Module):
    """Segmentation model: a backbone followed by a main and an optional aux head.

    Args:
        backbone: Callable returning a mapping that contains ``'main_out'``
            and, when an auxiliary head is used, ``'aux_out'``.
        main_classifier: Head applied to the ``'main_out'`` features.
        aux_classifier: Head applied to the ``'aux_out'`` features, or
            ``None`` to disable the auxiliary branch.
    """

    def __init__(self, backbone, main_classifier, aux_classifier):
        super().__init__()

        self.backbone = backbone
        self.main_classifier = main_classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        """Return head outputs resized to the spatial size of ``x``.

        Returns:
            OrderedDict with ``'main_out'`` and, if an auxiliary head is
            configured, ``'aux_out'``.
        """
        # Trailing (spatial) dims of the input; every output is resized to them.
        spatial_size = x.shape[2:]
        feature_maps = self.backbone(x)

        # The main branch always runs; the auxiliary branch is optional.
        branches = [('main_out', self.main_classifier)]
        if self.aux_classifier is not None:
            branches.append(('aux_out', self.aux_classifier))

        outputs = OrderedDict()
        for key, head in branches:
            logits = head(feature_maps[key])
            outputs[key] = F.interpolate(
                logits, size=spatial_size, mode='bilinear', align_corners=False
            )
        return outputs


def fcn_resnet50(aux, num_classes=21):
    """Assemble an FCN on top of a dilated ResNet-50 backbone.

    Args:
        aux: When truthy, also expose ``layer3`` and attach an auxiliary head.
        num_classes: Output channel count of each head; also forwarded to the
            ``ResNet50`` constructor.

    Returns:
        An ``FCN`` whose forward pass yields ``'main_out'`` and, when ``aux``
        is truthy, ``'aux_out'``.
    """
    # NOTE(review): assumes layer4 / layer3 of the project's ResNet50 emit
    # 2048 / 1024 channels -- confirm against ResNet_dilation.
    main_layer, main_channels = 'layer4', 2048
    aux_layer, aux_channels = 'layer3', 1024

    resnet = ResNet50(num_classes=num_classes, dilation_replace_stride=[False, True, True])

    # Child name -> key under which its feature map is returned.
    wanted = {main_layer: 'main_out'}
    if aux:
        wanted[aux_layer] = 'aux_out'
    feature_extractor = IntermediateLayerGetter(resnet, return_layer=wanted)

    # The aux head is built before the main head, as in the original ordering.
    aux_head = FCNHead(aux_channels, num_classes) if aux else None
    main_head = FCNHead(main_channels, num_classes)

    return FCN(feature_extractor, main_head, aux_head)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

卡子爹

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值