注意力机制SimAM代码(Pytorch)

介绍SimAM,一种用于卷积神经网络的简单且无需额外参数的注意力模块。该模块通过计算输入特征图的均值并利用sigmoid激活函数来增强重要特征,同时抑制不重要的特征,以此提高模型的性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks

论文链接:

http://proceedings.mlr.press/v139/yang21o.html

code: https://github.com/ZjjConan/SimAM

import torch
import torch.nn as nn


class simam_module(torch.nn.Module):
    def __init__(self, channels = None, e_lambda = 1e-4):
        super(simam_module, self).__init__()

        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):

        b, c, h, w = x.size()
        
        n = w * h - 1

        x_minus_mu_square = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activaton(y)

### SimAM注意力机制概述 SimAM是一种轻量级的注意力机制,在深度学习领域被广泛应用于图像分类、目标检测以及语义分割等任务中。它通过简单而高效的计算方式,增强了模型对重要特征的关注能力,同时减少了额外参数的数量和计算开销。 SimAM的核心思想在于利用局部区域内的相似性来衡量像素的重要性。具体而言,对于输入特征图中的每一个位置 \(i\),其对应的响应值可以通过比较该位置与其他所有位置之间的差异得到。这种设计使得SimAM能够捕捉到全局上下文信息并增强显著特征的表现力[^1]。 以下是SimAM的具体实现方法: ### 实现细节 #### 数学表达 给定一个大小为\(H \times W \times C\) 的特征图 \(X\)SimAM定义如下公式用于计算每个通道上的权重矩阵: \[ E_{ij} = 1 - \exp(-\lVert X_i - M_j\rVert_2^2), \] 其中, - \(M_j=\frac{1}{N}\sum_k^{N}(X_k)\),表示第j个窗口内所有元素均值; - \(E_{ij}\in[0,1]\), 表达的是当前点相对于邻域平均值偏离程度;当某一点越接近于周围邻居,则它的能量值会更趋近于零,反之则越大。 最终输出可以写成: \[ Y=X\times (1-E). \] #### PyTorch代码示例 下面提供了一个基于PyTorch框架下的SimAM模块实现版本: ```python import torch import torch.nn as nn class SimAM(nn.Module): def __init__(self, lambda_param=0.01): super(SimAM, self).__init__() self.lambda_param = lambda_param def forward(self, x): b, c, h, w = x.size() n = w * h - 1 # 计算空间维度上每一点与其余点的距离平方和 distance = (x.view(b,c,-1).unsqueeze(3) - x.view(b,c,-1).unsqueeze(2)).pow(2) # 排除自身距离项 mask = (torch.ones((h*w,h*w)) - torch.eye(h*w)).cuda().bool() distance = distance[:,:,:,mask].view(b,c,n) exp_distance = (-distance / (n ** .5 + 1e-8)).exp() norm_exp_dis = exp_distance.sum(dim=-1).clamp(min=self.lambda_param)**-.5 output = x * norm_exp_dis.unsqueeze(-1).unsqueeze(-1) return output ``` 此段代码实现了上述提到的能量函数及其后续操作过程,并将其封装成了可调用形式以便集成至其他神经网络结构当中去使用。 ### 性能优势 相比传统复杂的SENet或者CBAM等多分支架构的设计方案来说,SimAM仅需极少新增运算成本即可达到相近甚至超越的效果表现。因此非常适合资源受限环境下的移动端部署场景应用需求.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值