GCNet 论文详解
GCNet:Non-local Networks Meet Squeeze-Excitation Networks and Beyond,cvpr,2019.
链接:paper
代码:github
摘要
背景:Non-local Network 提供了一种捕获远程依赖关系的开创性方法,通过将特定于查询的全局上下文聚合得到每个查询位置。但是,本文发现,对于图中不同的查询位置,non-local network 网络建模的上下文信息几乎是相同的。
贡献:基于上述问题,文中建立了一个简化的基于查询无关的网络(snl),保持了NLNet的精度,大大减少了计算量。设计了一个轻量级GC block模块,和snet 都属于一种“三步骤”通用框架:1、全局上下文建模;2、channel-wise 依赖关系转换;3、特征融合。可以有效的建模全局上下文。可以将其应用于主干网络的多层中。
引言
无论是non-local还是本文所提出的gcnet都属于"long-range dependency “。它们的目的都是尽可能从"global"的角度来理解图像。使用传统的卷积效率低,且很难达到这一点,因为一个卷积操作是对一个“local neighborhood "的pixel 关系进行建模。虽然可以通过不断的加深网络,来达到网络后半部分的神经元的感受野变大,达到一种不太”local“的效果,但这种粗暴的加深网络的做法:有三个缺点:(1)不够精巧,参数量计、算量大;(2)网络越深,优化难度越大;(3)不太”local"并不代表全局,有最大距离的限制。
对上面第三点的理解可以为:在网络的后半部分,feature map右下角的cell,感受野可能扩大了,但可能还是始终无法获得原图片上左上角的信息(因为卷积操作最大距离的限制)。与之对应的,一个理想的神经网络应该是:不要那么深,让用户优化的太困难;网络涉及的精巧一点,最好较少的参数量达到原始网络结构的性能;真正的做到全局,没有最大距离的限制。
为了解决这个问题,我们提出了non-local network,通过自我注意机制来建模远程依赖关系。对于每一个query position,计算该query position与图中其他点之间的关系来提取该点的特征。
然而,作者对non-local 的attention map进行可视化后发现不同的query point 对应的attention map 几乎是一样的。如下图所示:
基于此,文中提出了简化版的NL(SNL)。该简化模块的计算量明显少于NL。其准确率几乎没有下降。此外,我们发现这个简化块与SENet 网络具有相似的结构。它们都是通过从各个位置聚合的相同特征来强化原有特征,但在聚合策略、转化和强化功能上的选择又相互区别。通过对这些函数的抽象,我们得到了统一简化的NL块和SE块的三步通用框架:(a)上下文建模模块,它将所有位置的特征聚合在一起,形成全局上下文特征;(b)特征转换模块,以捕获渠道间的相互依赖;和©融合模块,将全局上下文特征合并为所有位置特征。简化的NL块和SE块是这个通用框架的两个实例。文中还提出了GC 模块,在上下文建模(使用全局注意力池)和融合(使用添加)步骤上,GC 模块与简化的NL块共享相同的实现,而与SE块共享相同的转换步骤(使用两层瓶颈)。GC块在多个视觉识别任务中的表现优于简化的NL块和SE块。
Method
1、Simplifying the Non-local Block
既然attention map 与具体的点的位置无关,那就不要每一个点单独算一个attention map.用另外一种能过获取全局信息,但一个feature map 上的点可以共用attention 。基于此,我们简化了NL模块,提出了SNL模块,结构如下图(b):

Non-local 定义公式:
SNL定义公式:
一个直观的感觉是block结构确实简化了。在实际运算的变化中,可以看公式,公式中关键部分在于+号后面的变化。Non-local+号后面的内容与xi有关,因此每一个点i,需要计算一次。而snl的表达公式+号后面的内容已与xi无关,也即计算一次,然后大家公用,也映证了上文所讲的query-independent。
2、global context block
GCNet在上下文信息建模这个地方使用了Simplified NL block中的机制,可以充分利用全局上下文信息,同时在Transform阶段借鉴了SE block。
实验
Ablation
- a. 看出Simplified NL与NL几乎一直,但是参数量要小一些。且每个阶段都使用GC block性能更好。
- b. 在residual block中GC添加在add操作之后效果最好。
- c. 添加在residual 的所有阶段效果最优,能比baseline高1-3%。
- d. 使用简化版的NLNet并使用1×1卷积作为transform的时候效果最好,但是其计算量太大。
- e. ratio=1/4的时候效果最好。
- f. attention pooling+add的的组合用作池化和特征融合设计方法效果最好。
pytorch 代码
import torch
import torch.nn as nn
import torchvision
classs GlobalContextBlock(nn.Module):
def __init__(slef,inplanes,ratio,pooling_type='att',fusion_types=('channel_add',)):
super(GlobalContextBlock,self).__init__()
assert pooling_type in ['avg', 'att']
assert isinstance(fusion_types, (list, tuple))
valid_fusion_types = ['channel_add', 'channel_mul']
assert all([f in valid_fusion_types for f in fusion_types])
assert len(fusion_types) > 0, 'at least one fusion should be used'
self.inplanes = inplanes
self.ratio = ratio
self.planes = int(inplanes * ratio)
self.pooling_type = pooling_type
self.fusion_types = fusion_types
if pooling_type == 'att':
slef.conv_mask = nn.Conv2d(inplanes,1,kernel_size = 1)
slef.softmax = nn.Softmax(dim=2)
else:
slef.avg_pool = self.avg_pool = nn.AdaptiveAvgPool2d(1)
if 'channel_add' in fusion_types:
self.channel_add_conv = nn.Sequential(
nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
else:
self.channel_add_conv = None
if 'channel_mul' in fusion_types:
self.channel_mul_conv = nn.Sequential(
nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
else:
self.channel_mul_conv = None
def spatial_pool(self, x):
batch, channel, height, width = x.size()
if self.pooling_type == 'att':
input_x = x
# [N, C, H * W]
input_x = input_x.view(batch, channel, height * width)
# [N, 1, C, H * W]
input_x = input_x.unsqueeze(1)
# [N, 1, H, W]
context_mask = self.conv_mask(x)
# [N, 1, H * W]
context_mask = context_mask.view(batch, 1, height * width)
# [N, 1, H * W]
context_mask = self.softmax(context_mask)
# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N, 1, C, 1]
context = torch.matmul(input_x, context_mask)
# [N, C, 1, 1]
context = context.view(batch, channel, 1, 1)
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x):
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.channel_mul_conv is not None:
# [N, C, 1, 1]
channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
if self.channel_add_conv is not None:
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
return out