pytorch代码实现注意力机制之EMA

新提出的EMA注意力机制利用跨空间学习和多尺度策略优化通道和空间注意力,实验显示在图像分类和目标检测任务中效果优于ECA、CBAM和CA,尤其在处理小目标时有显著提升。
部署运行你感兴趣的模型镜像

EMA注意力机制

EMA注意力机制,基于跨空间学习的高效多尺度注意力模块,ICASSP2023推出,效果优于ECA、CBAM、CA ,小目标涨点明显。
论文地址:Efficient Multi-Scale Attention Module with Cross-Spatial Learning
EMA结构原理图
结构对比

在各种计算机视觉任务中,通道或空间注意力机制在产生更清晰的特征表示方面的显著有效性得到了证明。然而,通过通道降维来建模跨通道关系可能会给提取深度视觉表示带来副作用。提出了一种新的高效的多尺度注意力(EMA)模块。以保留每个通道上的信息和降低计算开销为目标,将部分通道重塑为批量维度,并将通道维度分组为多个子特征,使空间语义特征在每个特征组中均匀分布。具体来说,除了对全局信息进行编码以重新校准每个并行分支中的通道权重外,还通过跨维度交互进一步聚合两个并行分支的输出特征,以捕获像素级成对关系。对图像分类和目标检测任务进行了广泛的消融研究和实验,使用流行的基准(如CIFAR-100、ImageNet-1k、MS COCO和VisDrone2019)来评估其性能。

代码实现:

import torch
from torch import nn

class EMA(nn.Module):
    def __init__(self, channels, factor=8):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)

您可能感兴趣的与本文相关的镜像

Yolo-v8.3

Yolo-v8.3

Yolo

YOLO(You Only Look Once)是一种流行的物体检测和图像分割模型,由华盛顿大学的Joseph Redmon 和Ali Farhadi 开发。 YOLO 于2015 年推出,因其高速和高精度而广受欢迎

基于resnet改进EMA注意力模块+项目说明书+代码+番茄叶片病害4分类项目实战、一键训练 【项目说明书】一千字的word,包含代码训练流程、代码简单介绍,原理等等 本项目是一个基于PyTorch框架的深度学习图像分类系统,采用卷积神经网络(CNN)实现完整的训练与评估流程。系统核心功能包括数据预处理、模型训练、性能评估和可视化分析,适用于多样化的图像分类任务。项目文件结构清晰,主要由train.py(主训练脚本)、data_utils.py(数据处理模块)和train_utils.py(训练评估工具)组成,支持命令行参数配置如数据路径、批次大小和学习率等。 数据预处理阶段通过ImageDataset类实现标准化操作:训练集采用随机裁剪、水平翻转和颜色增强等动态增强策略,验证集仅进行基础调整和归一化,均统一至224×224分辨率。训练流程支持GPU加速,自动记录损失值、准确率、精确率、召回率、特异度和F1分数六类指标,并在每轮训练后生成验证集评估报告。系统会动态保存最佳模型权重(.pth文件)至checkpoints目录,同时输出训练曲线图(含6项指标对比)和详细日志文件,便于监控过拟合/欠拟合现象。 用户可通过模块化设计灵活扩展功能:修改CNNModel类调整网络结构,自定义get_data_transforms()的数据增强策略,或增减calculate_metrics()的评估指标。项目要求数据集按类别分目录存放,依赖PyTorch、NumPy等基础库,建议合理设置batch_size以避免内存溢出。该系统整合了从数据加载到模型部署的全流程工具,兼具标准化流程与高度可定制性,为图像分类任务提供高效解决方案。
### 多尺度空间金字塔注意力机制的有效方法 #### 关键概念解析 多尺度空间金字塔注意力机制通过引入不同尺度的感受野来捕捉图像中的特征,从而提高模型对复杂场景的理解能力。这种机制通常结合卷积神经网络(CNN),利用多个分支处理同一输入的不同分辨率版本。 #### 实现细节 为了有效地应用多尺度空间金字塔注意力机制,在每个块级别计算来自各个分支生成的特征图加权平均值[^1]。权重由学习到的标量参数决定,这些参数指示相对于其他映射而言应关注的程度。此过程允许模型自适应地聚焦于重要的局部区域,而忽略不那么相关的部分。 对于具体实现方式之一——CBAM模块,则进一步区分了通道维度上的注意力建模以及空间位置上的建模,并且两者之间采用串联结构依次执行[^2]。值得注意的是,在实际操作过程中,可以同时运用最大池化与均值池化的策略获取全局上下文信息。 此外,当涉及到语义分割任务时,有研究指出使用硬标签而非软标签能够提升存储效率及训练速度[^4]。这意味着尽管某些情况下可能损失了一定程度的概率分布表达力,但在特定应用场景下依然可以获得性能优势。 ```python import torch.nn as nn class MultiScaleSpatialPyramidAttention(nn.Module): def __init__(self, channels): super(MultiScaleSpatialPyramidAttention, self).__init__() # 定义不同的感受野大小 self.branches = nn.ModuleList([ nn.Conv2d(channels, channels, kernel_size=3, padding=dilation, dilation=dilation) for dilation in [1, 2, 4] ]) def forward(self, x): attentions = [] for branch in self.branches: atten_map = branch(x).sigmoid() attentions.append(atten_map * x) out = sum(attentions) / len(attentions) return out ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

「已注销」

你的激励是我肝下去的动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值