NLNet论文总结和代码实现

Non-local Neural Networks(非局部神经网络):使用自注意力机制捕获远程依赖。

论文: https://arxiv.org/pdf/1711.07971.pdf

源码:https://gitcode.net/mirrors/facebookresearch/video-nonlocal-net

目录

一、背景和出发点

二、创新点

三、Non-local Neural Networks(非局部神经网络)

3.1 非局部操作的一般定义

3.2 实例

3.3 Non-local Block(NL模块)

四、实验

五、结论

六、代码实现


长距离依赖关系,顾名思义,是要和远程建立关系,在long-range的情况下有关联。在cv领域,就是考虑一个像素的时候同时考虑其邻域,甚至是邻域的邻域....,能够建模两个较远像素之间的关系,例如是不是同属一个类等等。

引用博文:MedT

一、背景和出发点

卷积运算是一次处理一个局部邻域的构建块,在捕获远程依赖关系时有许多的局限,例如:它的计算效率很低;它会导致优化困难等。作者受经典的非局部均值方法(non-local means)的启发,提出了一种非局部模块作为一种高效、通用的组件,用于在深度神经网络中捕获远程依赖关系

(该非局部模块结合了多头注意力机制。)

二、创新点

1. 与RNN和CNN的渐进行为相反,非局部运算通过计算任意两个位置之间的交互来直接捕获远程依赖,而不管它们的位置距离如何。

2. 正如本文在实验中所展示的,非局部操作是高效的,即使只有几层(例如5层),也能达到最佳效果

3. 该非局部操作保持可变的输入大小,并且可以轻松地与其他操作(例如,我们将使用的卷积)组合

三、Non-local Neural Networks(非局部神经网络)

3.1 非局部操作的一般定义

其中,i 是输入特征中要被计算的位置,jx_i 所有可能关联到位置的索引。f 是位置间的类同标量函数,用于计算两位置间的相关性。g 是位置输入信号函数,对输入的特征进行赋权,C(x)是归一化因子。

上述的非局部操作看后可能会感到一头雾水,我们需要简单了解一下前继论文:non-local means,非局部均值滤波

算法思想:图像的长边缘,纹理结构都是相似的,因此,在同一幅图像中, 对相似区域进行分类并加权平均得到的去噪图像,也应该能提高去噪效果。

方法:NLM是非局部均值操作,对每个滤波点都利用了整张图像的信息。对像素点 i 做NLM操作,先遍历整张图像,求出 i 与其他像素点的相似度,再与不同位置的像素点的灰度化值相乘(根据邻域间的相似性均值大小确定像素的权值),累加求均值,就得到像素点 i 的滤波值。

公式

其中,NL(i)表示为 i 的滤波值,v(j) 表示像素点 j 的灰度值,j\in Ij需要遍历全图w ( i , j )v(j) 的权重,由 i , j 两像素点之间的相似度来得到。

如果仅仅比较 i,j 像素值来作为相似度的依据,不能正确反映 i,j 的关系,所以通常选其邻域N ( i ) , N ( j ),邻域大小可以自己选择。

参考:Non-Local Means(非局部均值滤波)

均值、中值、高斯、non-local means算法详解

Non-local Net理解_非局部模块

由上述的non-local means转到本文的Non-local操作的定义,可知,f(x_i,x_j) 用于计算像素域 x_ix_j相似度,作为像素域 x_j的权重,这样通过累加求均值,获得的y_i便捕获了整张图像的长距离依赖关系。(y_i可以看作整张图像与在y_i位置上的原像素的相似度相乘求均值得到的。)

3.2 实例

计算像素邻域间的相似性的四种方法:

1. 高斯函数

2. 嵌入式高斯函数

其中,\theta (\textup{x}_i)=W_\theta \textup{x}_i\phi (\textup{x}_j)=W_\phi \textup{x}_j​ 是两个嵌入件。W_\thetaW_\phi是要学习的权重矩阵,通过1x1x1卷积实现。(这里是对应视频分类任务,图像任务可采用1x1卷积)。

3. 点积

4. Concatenation(标注[40]提出的方法)

其中,[·,·] 表示维度拼接操作。

3.3 Non-local Block(NL模块)

NL模块由Non-local操作多头注意力机制结合而来。可以采用以上的四种方法中的任意一种计算相似度。

思想: 由上述的Non-local操作概述,我们可以得知要在一张图像上建立起长距离依赖关系只需要计算整张图像上每个像素值其他像素值之间的相似度,以此作为其他像素值的权重,累加()求均值,得到的新值替代原来的像素值,最终得到的新的图像的像素值是按类别划分的,相似的像素值被整合为了一类,这样就完成了对长距离依赖关系的捕获。

过程:以点积计算相似度的方法为例。首先,输入图像 \textup{X}\in R^{1024 \times W \times H } 通过三个1x1的卷积得到三个特征图\theta , \phi ,g\theta\phi 通过点积操作计算每个像素值其他像素值之间的相似度,定义如下:

再通过softmax函数求均值(归一化的同时也进行均值操作),这里得到的相似度矩阵大小为WHxWH,也就是说每个像素都有对应的一个相似度矩阵(大小为W x H),每个大小为W x H的相似度矩阵作为权重与原特征图(g)按位相乘再相加(点积计算)得到的新的特征图(大小为WxHx512),即捕获到了图像的长距离依赖关系。定义如下:

再经过一个1x1的卷积将通道数恢复到1024,与原输入图像 \textup{X}进行残差连接,最终得到一个融合了图像长距离依赖关系的新特征图Z,定义如下:

这里简单的说一下残差连接(参考深度残差网络ResNet),引入残差连接的好处在于允许将NL模块插入任何预先训练的模型中,而不会破坏其初始化表现,这样就保证了NL模块在其他网络中的通用性,即插即用。

四、实验

我们仅关注在2D图像上的实验。

数据集:CoCo数据集 。

主干网络:ResNet50、ResNet101或ResNeXt-152。

1. 验证四种相似度计算方法效果差异

2. 比较单个NL模块添加到ResNet不同位置的情况

3. 比较添加不同数量的NL模块效果

 分别添加1,5,10个NL模块在ResNet50和ResNet100。

4. 单个NL模块添加在res3位置​的进行对象检测和实例分割

五、结论

NLNet通过非局部操作捕获远程依赖关系,NL模块通过残差连接保证了可以与任何现有架构相结合。NL模块巧妙地将非局部操作与多头注意力机制相结合的方法至今依然是热门的参考架构。正如本文在实验中所展示的,非局部操作是十分高效的,即使只有几层(例如5层),也能达到最佳效果。

六、代码实现

论文中给出的源码针对的是视频分类任务,在此贴出针对2D图像的源码:

# -*- coding: utf-8 -*-
"""
@Author  :   zhwzhong
@License :   (C) Copyright 2013-2018, hit
@Contact :   zhwzhong.hit@gmail.com
@Software:   PyCharm
@File    :   nlrn.py
@Time    :   2019/5/5 16:27
@Desc    :
"""
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, gamma=0, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.gamma = nn.Parameter(torch.ones(1)) * gamma

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # 参数数量与 channel 有关
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        """
        :param x: (b, c, t, h, w)
        :return:
        """

        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = self.gamma * W_y + x

        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)

if __name__ == '__main__':
    import torch

    for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
        out = net(img)
        print(out.size())

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

向岸看

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值