2021-Lite-HRNet: A Lightweight High-Resolution Network

本文探讨了如何构建一个轻量级的高分辨率网络 Lite-HRNet,该网络基于HRNet并引入ShuffleBlock以平衡性能和计算复杂度。作者分析了ShuffleBlock中1*1卷积的计算瓶颈,并提出了ConditionalChannelWeighting和Cross-ResolutionWeightComputation策略来替代1*1卷积,通过特征聚合和权重计算增强多尺度信息交互。实验表明,Lite-HRNet在人体姿态估计和语义分割任务中表现出色,同时降低了计算复杂度。

1. Title

Lite-HRNet: A Lightweight High-Resolution Network

2. Summary

本文是想制作一个高性能的轻量化HRNet网络,我个人实际使用中会发现,Small HRNet的性能一般会比同量级的UNet要差一些,个人理解是多尺度信息交互不够充分的原因,毕竟原来的交互方式很简单(带步长的卷积进行下采样,双线性插值进行上采样),因此简单对HRNet进行放缩是不能取得较好的trade-off的。
作者首先是在HRNet中引入Shuffle Block,得到了Naive Lite-HRNet,并且在性能和复杂度上取得了不错的tradeoff。通过进一步分析,作者认为Shuffle Block中的1*1 Conv成为了性能瓶颈,因此想解决这个问题。
在HRNet中多个branch独立使用1*1 Conv计算复杂度会比较高,因此,作者想到了首先把多个branch的特征聚合起来,增强后,然后再作为权重分发回原branch,聚合过程中通过Pooling的方法降低feature map的大小,以此来降低整体计算复杂度,分发过程中再重新上采样回原始分辨率。这样一来一方面可以降低计算复杂度,另一方面还能将独立的各个分支的信息聚合起来,引入多尺度交互,以弥补spatial信息的损失。
个人认为,采用类似的思路,在HRNet多尺度特征交互方面再做些文章是可以进一步提升精度。

3. Problem Statement

Human pose estimation一般比较依赖于高分辨率的特征表示以获得较好的性能,但是目前的网络计算量较大,不能称之为一个高效的网络结构,因此,本文想解决的问题就是如何在计算资源受到约束的情况下部署一个高效的高分辨率模型。
通过简单地将ShuffleNet中的Shuffle Block应用于HRNet,即可得到一个轻量级的HRNet,并且可以获得超越MobileNet、ShuffleNet以及Small HRNet的性能,但是Shuffle Blocks中大量使用的1*1 Conv成为了计算瓶颈,因此,如何能替换掉成本较高的1*1 Conv并且保持甚至取得超越其性能是本文要解决的核心问题。

4. Method(s)

4.1 Naive Lite-HRNet

(1)Shuffle Blocks

shuffle block
Shuffle Block会将通道首先分为两个部分,其中的一部分会送入一个1*1 Conv 3*3 DepthWise Conv和1*1 Conv中进行增强,处理完后会和另一部分拼接起来,最终会把通道重新shuffle。

(2)HRNet

Small HRNet Architecture
HRNet有两大优点:

  1. 通过全程保持高分辨率的特征,有利于位置信息的保留,对于位置敏感的任务例如语义分割、目标检测、人体姿态估计等都具有良好的作用。
  2. 另外通过充分地多尺度特征融合,HRNet有利于多尺度信息的挖掘,对于目标的尺度变化不敏感。

(3)Simple Combination

通过简单将Stem中的第2个3*3 Conv以及所有的Residual Block替换为Shuffle Block,并且将所有multi-resolution fusion中的Conv替换为Separable Conv,即可得到 Naive Lite-HRNet。
下面是官方代码中Stem的部分的实现,部分需要说明或者注意的地方,已经加上了中文注释:

class Stem(nn.Module):
    def __init__(self,
                 in_channels,
                 stem_channels,
                 out_channels,
                 expand_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 # 是否使用torch.utils.checkpoint用于降低显存使用,与模型实现没有关系,可以忽略
                 # 可参考博客:https://blog.youkuaiyun.com/ONE_SIX_MIX/article/details/93937091
                 with_cp=False):  
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
		
		# Stem中的第一个卷积不使用shuffle block
		# ConvModule是MMCV中的一个基本卷积模块:conv/norm/activation
        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='ReLU'))
		
        mid_channels = int(round(stem_channels * expand_ratio))
        branch_channels = stem_channels // 2
        if stem_channels == self.out_channels:
            inc_channels = self.out_channels - branch_channels
        else:
            inc_channels = self.out_channels - stem_channels
		
		# Shuffle Block中左侧不做增强的分支
        self.branch1 = nn.Sequential(
            ConvModule(
                branch_channels,
                branch_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=branch_channels,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_channels,
                inc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU')),
        )
		
		# Shuffle Block中右侧增强分支
        self.expand_conv = ConvModule(
            branch_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        self.depthwise_conv = ConvModule(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=mid_channels,  # groups=in_channels 深度可分离卷积
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.linear_conv = ConvModule(
            mid_channels,
            branch_channels
            if stem_channels == self.out_channels else stem_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))

    def forward(self, x):

        def _inner_forward(x):
            x = self.conv1(x)
            x1, x2 = x.chunk(2, dim=1)

            x2 = self.expand_conv(x2)
            x2 = self.depthwise_conv(x2)
            x2 = self.linear_conv(x2)

            out = torch.cat((self.branch1(x1), x2), dim=1)

            out = channel_shuffle(out, 2)  # shuffle channel

            return out

        if self.with_cp and x.
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值