BAM: Bottleneck Attention Module__实现

BAM: Bottleneck Attention Module

引言

  在此论文中,我们把重心放在了Attention对于一般深度神经网络的影响上, 我们提出了一个简单但是有效的Attention 模型—BAM,它可以结合到任何前向传播卷积神经网络中,我们的模型通过两个分离的路径 channel和spatial, 得到一个Attention Map.
整体结构图如下:
在这里插入图片描述

这里作者将BAM放在了Resnet网络中每个stage之间。有趣的是,通过可视化我们可以看到多层BAMs形成了一个分层的注意力机制,这有点像人类的感知机制。BAM在每个stage之间消除了像背景语义特征这样的低层次特征,然后逐渐聚焦于高级的语义–明确的目标(比如图中的狗).

主要思想:

channel attention branch

对于给定的feature map F ∈ R C ∗ H ∗ W \mathrm{F} \in R^{C * H * W} FRCHW,BAM可以得到一个3D的Attention map M ( F ) ∈ R C ∗ H ∗ W M(F)∈R^{C*H*W} M(F)RCHW.加强的feature map F′;
F ′ = F + F ⊗ M ( F ) \mathrm{F}^{\prime}=\mathrm{F}+\mathrm{F} \otimes \mathrm{M}(\mathrm{F}) F=F+FM(F)

为了设计一个有效且强大的模型,我们首先计算channel attention,然后计算spatial attention.这时M(F)就变成:
M ( F ) = σ ( M c ( F ) + M s ( F ) ) \mathrm{M}(\mathrm{F})=\sigma\left(\mathrm{M}_{c}(\mathrm{F})+\mathrm{M}_{s}(\mathrm{F})\right) M(F)=σ(Mc(F)+Ms(F))

这里σ 代表sigmoid函数,为了聚合feature map在每个通道维度,我们采用全局平均池化得到 F C F_{C} FC这个向量然后对全局信息在每个通道进行软编码。为了评估Attention在每个通道的效果?我们使用了一个多层感知(MLP)用一层隐藏层。在MLP之后,我们增加了BN去调整规模和空间分支一样的输出,channel attention可以被计算为:

M c ( F ) = B N ( M L P ( AvgPool ( F ) ) ) \mathbf{M}_{\mathbf{c}}(\mathbf{F})=B N(M L P(\text {AvgPool}(\mathbf{F}))) Mc(F)=BN(MLP(AvgPool(F)))

<think> 嗯,用户这次聚焦在Bottleneck Attention ModuleBAM)的详细解析上,结合对话历史里那些系统级指令(比如数学公式的LaTeX格式要求)和丰富的站内引用素材,我需要整合出一个专业且易理解的回答。 用户可能正在研究注意力机制在CNN中的应用,特别是这个轻量级模块的设计思路。从引用[1][4]能看出BAM的核心卖点是双路径结构和即插即用特性,而引用[3][5]提供了关键参数和数学实现细节。有趣的是,用户特意提到"不是真正需求"的提示,说明ta更期待原创性整合而非单纯堆砌引用。 准备从三个维度展开:原理部分重点讲双分支的协同机制(空间分支用膨胀卷积的数学表达要按指令用$$格式);实现部分解析代码片段时保留GitHub链接但强调通用性;应用部分则关联引用[4]的多数据集效果。特别注意所有公式都要严格遵循LaTeX规范,比如通道分支的压缩公式得写成 $$ M_c(F) = BN(W_1 \delta(W_0 AvgPool(F) + b_0) + b_1) $$ 这种完整形式。 用户没有明确场景需求,但引用里提到目标检测(VOC/COCO),所以应用案例会侧重CV任务。最后生成的问题要引导深度探索,比如结合姊妹篇CBAM的对比——这在引用[1]提到但未展开,正好作为延伸点。 </think> 以下是对**Bottleneck Attention Module (BAM)** 的详细介绍,涵盖定义、核心原理、实现方法及应用场景: --- ### **一、BAM 的定义与目标** BAM 是一种轻量级**注意力机制模块**,由 Jongchan Park 等人在 2018 年提出。其核心目标是通过**自适应增强特征图中的重要信息**(通道与空间维度),提升卷积神经网络(CNN)的特征表达能力。BAM 可直接集成到现有 CNN 的瓶颈层(Bottleneck),无需修改主干网络结构[^1][^4]。 --- ### **二、核心结构与工作原理** #### 1. **双分支注意力机制** BAM 包含两个并行分支,分别生成**通道注意力图**(Channel Attention Map)和**空间注意力图**(Spatial Attention Map),最终融合为统一的注意力权重: $$ M(\mathbf{F}) = \sigma(M_c(\mathbf{F}) + M_s(\mathbf{F})) $$ 其中 $\mathbf{F}$ 为输入特征图,$\sigma$ 为 Sigmoid 函数[^1][^4]。 #### 2. **通道分支(Channel Branch)** - **步骤**:全局平均池化 → 全连接层压缩 → ReLU → 全连接层恢复维度 → BN层。 - **数学表达**: $$ M_c(\mathbf{F}) = \mathrm{BN}\left( \mathbf{W}_1 \left( \delta(\mathbf{W}_0 \mathrm{AvgPool}(\mathbf{F}) + \mathbf{b}_0) \right) + \mathbf{b}_1 \right) $$ - $\mathbf{W}_0 \in \mathbb{R}^{C/r \times C}$:压缩层($r$ 为压缩比,默认为 16) - $\mathbf{W}_1 \in \mathbb{R}^{C \times C/r}$:恢复层[^3][^5]。 #### 3. **空间分支(Spatial Branch)** - **结构**:两层膨胀卷积(Dilated Convolution)扩大感受野。 - **流程**: $$ M_s(\mathbf{F}) = f_{3\times3}^{d=4} \left( f_{3\times3}^{d=2} (\mathbf{F}) \right) $$ - $d$ 为膨胀率(默认 $d=4$),用于捕捉上下文信息[^3][^4]。 #### 4. **特征图增强** 最终输出特征图通过注意力权重调整: $$ \mathbf{F}' = \mathbf{F} + \mathbf{F} \otimes M(\mathbf{F}) $$ $\otimes$ 表示逐元素乘法,强化重要区域并抑制噪声[^1][^4]。 --- ### **三、使用方法(集成到 CNN)** BAM 以**即插即用**方式嵌入 CNN 的瓶颈层(如 ResNet 的残差块之间): ```python import torch import torch.nn as nn class BAM(nn.Module): def __init__(self, channel, reduction=16, dilation=4): super().__init__() # 通道分支 self.channel_att = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Linear(channel, channel // reduction), nn.ReLU(), nn.Linear(channel // reduction, channel), nn.BatchNorm1d(channel) ) # 空间分支 self.spatial_att = nn.Sequential( nn.Conv2d(channel, channel // reduction, 3, padding=dilation, dilation=dilation), nn.ReLU(), nn.Conv2d(channel // reduction, 1, 3, padding=dilation, dilation=dilation), nn.BatchNorm2d(1) ) def forward(self, x): channel_att = self.channel_att(x).sigmoid() spatial_att = self.spatial_att(x).sigmoid() att = channel_att * spatial_att # 融合双分支 return x + x * att.expand_as(x) # 示例:嵌入 ResNet 的残差块 class ResBlockWithBAM(nn.Module): def __init__(self, in_channels): super().__init__() self.conv_block = ... # 标准卷积层 self.bam = BAM(in_channels) def forward(self, x): residual = x x = self.conv_block(x) x = self.bam(x) # 在残差连接前加入BAM return x + residual ``` [^2][^5] --- ### **四、关键优势与实验效果** 1. **轻量化**:参数增量仅 0.1%~0.5%,几乎不增加计算负担[^4]。 2. **通用性**:在 ResNet、MobileNet 等架构中均能提升性能[^1][^4]。 3. **显著效果**: - **ImageNet**:ResNet-50 + BAM → Top-1 错误率降低 1.3%[^1]。 - **目标检测**(VOC/COCO):mAP 提升 1.5~2.0%[^4]。 - **语义分割**:IoU 平均提高 1.8%[^3]。 --- ### **五、应用场景** 1. **图像分类**:增强关键特征判别性。 2. **目标检测**:提升小目标和遮挡目标的识别能力。 3. **语义分割**:细化边界与上下文理解[^1][^4]。 --- ### **六、参数配置建议** | **参数** | **推荐值** | **作用** | |------------|------------|------------------------------| | 压缩比 $r$ | 16 | 平衡计算量与效果 | | 膨胀率 $d$ | 4 | 控制空间分支的感受野大小[^3] | ---
评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cpp编程小茶馆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值