Asymmetric Non-local Neural Networks

Asymmetric Non-local Neural Networks

论文:Asymmetric Non-local Neural Networks for Semantic Segmentation,ICCV,2019.

链接:paper

代码:github

本文是在Non-local Neural Networks(cvpr,2018) 这篇论文的基础上改进得来的,non-local block Non-local模块作为语义分割任务中很有用的技术,但因为计算量较大,需要较大的显存开销,这阻碍了non-local network在实际应用中的使用。本文提出了APNB来减少non-local block的计算量和显存开销,AFNB通过提升分割性能增强non-local block的学习能力。我们着重分析这篇论文的改进点。

Revisiting Non-local Block

img

输入X \in \mathbb{R}^{C*H*W}

1.经过3个1x1的卷积Query,Key和Value变换之后,分别得到 : \varnothing \in \mathbb{R}{\hat{C}HW}==\mathbb{R}{\hat{C}*N}\theta \in \mathbb{R}{\hat{C}*H*W}==\mathbb{R}{\hat{C}*N}\gamma \in \mathbb{R}{\hat{C}*H*W}==\mathbb{R}{\hat{C}*N}

2.计算所有空间位置特征向量的相似度:V=\varnothing ^{T}*\theta

3.归一化:\overrightarrow{V}=softmax(V)

4.根据所有空间位置特征向量的相似度加权求和:O=\overrightarrow{V}*\gamma^ {T}

5.新特征经过1*1卷积变换后与原输入特征拼接:Y = cat\left(W_{o}\left(O^{T} \right ),X\right)

计算复杂度:O\left ( \hat{C}N^{2} \right )=O\left ( \hat{C}H{2}W{2} \right )

Asymmetric Pyramid Non-local Block(APNB)

Non-local 的计算复杂度主要来自第2步和第4步:

<think>好的,我现在需要帮助用户解决如何在模型中添加non-local Block注意力机制的问题,使用PyTorch或TensorFlow。用户还提到参考了两篇论文,分别是Asymmetric Non-local Neural Networks和GCNet,所以我要结合这些资料来回答。 首先,我得回忆一下non-local block的基本结构。根据论文,non-local操作通过计算所有位置的特征之间的关联来捕获长程依赖。核心步骤包括计算查询(query)、键(key)、值(value),然后通过softmax得到注意力权重,最后加权求和得到输出。可能还需要一个可学习的权重矩阵来调整输出的规模。 然后,用户提到的GCNet对non-local block进行了优化,发现attention maps相似,因此简化了结构。GCNet的global context block可能更高效,可能结合了squeeze-excitation的思想,比如通道注意力。不过用户的问题是关于non-local block的实现,所以需要先给出标准的non-local block代码,再提到GCNet的改进作为优化选项。 接下来,我需要考虑如何在PyTorch中实现non-local block。通常,这涉及到定义卷积层或线性层来处理输入特征,生成query、key、value。然后进行矩阵乘法计算相似度,应用softmax,再与value相乘。最后可能加上残差连接。 可能需要分步骤说明:首先导入必要的库,然后定义NonLocalBlock类。在初始化中定义各个卷积层,比如将输入通道数减少到一半的卷积,然后前向传播中进行reshape和矩阵运算。注意处理张量的维度变换,比如将空间维度(H*W)展平,以便进行矩阵乘法。 在TensorFlow中的实现类似,但需要使用tf.keras.layers中的层,比如Conv2D,Reshape,Dot等。同样需要生成query、key、value,计算注意力权重,然后应用。 然后是如何将non-local block添加到模型中。以ResNet为例,可以在残差块之后插入non-local模块。需要调整输入输出的通道数,确保尺寸匹配,可能通过1x1卷积调整通道数,或使用池化调整空间尺寸。 还需要注意内存消耗,因为non-local的计算复杂度较高,特别是处理大尺寸特征图时。可能需要在较低分辨率的特征图上使用,或者使用下采样来减少计算量,如GCNet中的优化方法。 最后,生成相关问题,可能涉及non-local的应用场景、与其他注意力机制的比较、优化方法等,参考用户提供的引用内容中的信息。 现在需要确保代码示例正确,比如PyTorch中的实现是否正确处理了张量的维度,是否有残差连接,以及是否正确应用了softmax。同时,要引用用户提供的参考文献,比如在说明non-local结构时引用[^1],在提到GCNet优化时引用[^2]。 可能遇到的错误点:矩阵乘法的维度不匹配,忘记添加残差连接导致训练困难,或者没有正确应用归一化(如LayerNorm)。此外,输入是否需要下采样也需要考虑,以节省计算资源。 总结步骤:1. 解释non-local block原理;2. 提供PyTorch和TensorFlow的实现代码;3. 说明如何集成到现有模型;4. 注意事项和优化建议;5. 相关问题。</think>### 如何在模型中添加 Non-Local Block 注意力机制 Non-Local Block 是一种通过全局特征建模来增强模型感知能力的注意力机制,常用于计算机视觉任务(如语义分割、目标检测)。其核心思想是通过计算特征图中任意两个位置的相关性,捕捉长距离依赖关系。以下是具体实现方法及代码示例: --- #### 一、Non-Local Block 原理 1. **计算特征相关性**:对输入特征图 $X \in \mathbb{R}^{C \times H \times W}$,通过卷积生成 $Query$、$Key$、$Value$ 三个分支。 2. **注意力权重矩阵**:通过矩阵乘法计算所有位置间的相似度,公式为: $$ \text{Attention}(Q,K) = \text{Softmax}\left(\frac{QK^T}{\sqrt{C}}\right) $$ 3. **加权聚合**:将注意力权重与 $Value$ 加权求和,得到增强后的特征。 --- #### 二、PyTorch 实现代码 ```python import torch import torch.nn as nn class NonLocalBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.mid_channels = in_channels // 2 # 减少计算量 # 定义Q/K/V生成层 self.query = nn.Conv2d(in_channels, self.mid_channels, 1) self.key = nn.Conv2d(in_channels, self.mid_channels, 1) self.value = nn.Conv2d(in_channels, in_channels, 1) # 输出层 + 残差连接 self.gamma = nn.Parameter(torch.zeros(1)) # 可学习权重 self.final_conv = nn.Conv2d(in_channels, in_channels, 1) def forward(self, x): batch_size = x.size(0) # 生成Q/K/V [B, C, H, W] q = self.query(x) # [B, C_mid, H, W] k = self.key(x) # [B, C_mid, H, W] v = self.value(x) # [B, C, H, W] # 展平空间维度 [B, C, H*W] q = q.view(batch_size, self.mid_channels, -1) k = k.view(batch_size, self.mid_channels, -1) v = v.view(batch_size, self.in_channels, -1) # 计算注意力矩阵 [B, H*W, H*W] attn = torch.bmm(q.permute(0,2,1), k) # Q^T * K attn = torch.softmax(attn / (self.mid_channels**0.5), dim=-1) # 加权聚合 [B, C, H*W] out = torch.bmm(v, attn.permute(0,2,1)) # V * Attn out = out.view(batch_size, self.in_channels, *x.shape[2:]) # 残差连接 return self.gamma * self.final_conv(out) + x ``` --- #### 三、TensorFlow 实现代码 ```python import tensorflow as tf from tensorflow.keras.layers import Conv2D, Reshape, Layer class NonLocalBlock(Layer): def __init__(self, in_channels): super().__init__() self.mid_channels = in_channels // 2 # 定义Q/K/V生成层 self.query_conv = Conv2D(self.mid_channels, 1) self.key_conv = Conv2D(self.mid_channels, 1) self.value_conv = Conv2D(in_channels, 1) # 输出层 self.gamma = self.add_weight(shape=(1,), initializer='zeros') self.final_conv = Conv2D(in_channels, 1) def call(self, x): batch_size = tf.shape(x)[0] # 生成Q/K/V [B, H, W, C] q = self.query_conv(x) # [B, H, W, C_mid] k = self.key_conv(x) v = self.value_conv(x) # 展平空间维度 [B, H*W, C] q_flat = Reshape((-1, self.mid_channels))(q) k_flat = Reshape((-1, self.mid_channels))(k) v_flat = Reshape((-1, tf.shape(v)[-1]))(v) # 计算注意力矩阵 [B, H*W, H*W] attn = tf.matmul(q_flat, k_flat, transpose_b=True) attn = tf.nn.softmax(attn / tf.sqrt(tf.cast(self.mid_channels, tf.float32))) # 加权聚合 [B, H*W, C] out = tf.matmul(attn, v_flat) out = Reshape(tf.shape(x)[1:])(out) # 残差连接 return self.gamma * self.final_conv(out) + x ``` --- #### 四、集成到现有模型 以 ResNet 为例,在残差块后插入 Non-Local Block: ```python class ResNetWithNonLocal(nn.Module): def __init__(self): super().__init__() self.backbone = resnet50(pretrained=True) self.non_local = NonLocalBlock(2048) # 在最后一层特征添加 def forward(self, x): x = self.backbone.conv1(x) x = self.backbone.layer1(x) x = self.backbone.layer2(x) x = self.backbone.layer3(x) x = self.backbone.layer4(x) # [B, 2048, 7, 7] x = self.non_local(x) # 添加 Non-Local return x ``` --- #### 五、注意事项 1. **计算复杂度**:Non-Local 的复杂度为 $O((H \times W)^2)$,建议在低分辨率特征图上使用(如 $7 \times 7$)。 2. **优化方法**:GCNet 提出简化版本(Global Context Block),通过共享注意力权重减少计算量。 3. **归一化**:可添加 LayerNorm 或 BatchNorm 提升稳定性。 ---
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值