CVPR 2021: Coordinate Attention for Efficient Mobile Network Design

This post introduces a new attention mechanism called Coordinate Attention (CA). It builds on SE and CBAM and improves network performance by capturing attention features separately along the horizontal and vertical directions. Unlike SE, which only attends over channels, the CA module fuses the information from two 1D pooling operations and thereby makes effective use of positional information.


This paper proposes an attention mechanism that improves on SE and CBAM.

Its performance is somewhat better than both SE and CBAM.

The SE module only applies weights along the channel dimension and ignores positional information.

In this paper, a novel attention mechanism is proposed that uses two 1D pooling operations to encode attention features along the horizontal and vertical directions.

Without further ado, here is the figure from the paper.

[Figure: schematic comparison of (a) the SE block, (b) the CBAM block, and (c) the proposed CA block]

The figure is largely self-explanatory:

(a) is the SE module,
(b) is the CBAM module,
(c) is the CA module proposed in this paper.

It is clear that what distinguishes the CA module from the other two is that the single 2D pooling is replaced by two 1D pooling operations, whose outputs are then fused to produce the final attention maps.
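To make this concrete, here is a minimal shape check (my own sketch, not taken from the paper or the official repo) contrasting SE-style 2D global pooling with CA's pair of 1D poolings:

import torch
import torch.nn as nn

x = torch.randn(1, 16, 8, 8)              # [batch, channels, H, W]

# SE: a single 2D global average pool collapses all spatial information
se_pool = nn.AdaptiveAvgPool2d(1)
print(se_pool(x).size())                   # torch.Size([1, 16, 1, 1])

# CA: two 1D poolings, each keeping positional information along one axis
pool_h = nn.AdaptiveAvgPool2d((None, 1))   # pool over W, keep H
pool_w = nn.AdaptiveAvgPool2d((1, None))   # pool over H, keep W
print(pool_h(x).size())                    # torch.Size([1, 16, 8, 1])
print(pool_w(x).size())                    # torch.Size([1, 16, 1, 8])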

The code is also very simple to implement.

Official source code:
https://github.com/Andrew-Qibin/CoordAttention/blob/main/coordatt.py

CVPR 2021 paper: https://arxiv.org/abs/2103.02907

Below is my PyTorch implementation:

import torch
import torch.nn as nn


# Non-linearity used in the paper: h-swish(x) = x * ReLU6(x + 3) / 6
class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        sigmoid = self.relu(x + 3) / 6  # h-sigmoid
        return x * sigmoid


class CoorAttention(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=32):
        super(CoorAttention, self).__init__()
        self.poolh = nn.AdaptiveAvgPool2d((None, 1))  # pool over W, keep H
        self.poolw = nn.AdaptiveAvgPool2d((1, None))  # pool over H, keep W
        middle = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, middle, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(middle)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(middle, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(middle, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):  # [batch_size, c, h, w]
        identity = x
        batch_size, c, h, w = x.size()
        # X Avg Pool
        x_h = self.poolh(x)  # [batch_size, c, h, 1]
        # Y Avg Pool
        x_w = self.poolw(x)  # [batch_size, c, 1, w]
        x_w = x_w.permute(0, 1, 3, 2)  # [batch_size, c, w, 1]

        # Following the paper: concatenate x_h and x_w along dim=2 (length h + w),
        # then Conv2d + BatchNorm + non-linearity
        y = torch.cat((x_h, x_w), dim=2)  # [batch_size, c, h+w, 1]
        y = self.act(self.bn1(self.conv1(y)))  # [batch_size, middle, h+w, 1]
        # Split back into the two directions
        x_h, x_w = torch.split(y, [h, w], dim=2)  # [batch_size, middle, h, 1] and [batch_size, middle, w, 1]
        x_w = x_w.permute(0, 1, 3, 2)  # swap dim=2 and dim=3: [batch_size, middle, w, 1] -> [batch_size, middle, 1, w]
        # Conv2d + Sigmoid
        attention_h = self.sigmoid(self.conv_h(x_h))  # [batch_size, out_channels, h, 1]
        attention_w = self.sigmoid(self.conv_w(x_w))  # [batch_size, out_channels, 1, w]
        # Re-weight
        return identity * attention_h * attention_w


# Quick shape check:
# x = torch.ones(1, 16, 2, 2)
# ca = CoorAttention(16, 16)
# print(ca(x).size())  # torch.Size([1, 16, 2, 2])
