(即插即用模块-特征处理部分) 二十九、(2024) CAFM 卷积注意力融合模块

在这里插入图片描述

paper:Hybrid Convolutional and Attention Network for Hyperspectral Image Denoising

Code:https://github.com/summitgao/HCANet


1、Convolution and Attention Fusion Module

目前图像去雾的研究面临以下几种困境:HSI特征复杂性: HSI包含丰富的空间和光谱信息,仅依靠卷积或注意力机制难以充分建模其复杂的特征。卷积局限性: 卷积操作擅长捕捉局部特征,但感受野有限,难以建模长距离依赖关系,例如空间上相隔较远的像素之间的关联性。注意力局限性: 注意力机制擅长提取全局特征,但忽略了局部特征,例如像素之间的空间相邻关系。论文考虑到卷积和注意力机制在特征建模方面具有互补性,结合两者可以更全面地捕捉HSI的局部和全局特征,所以提出一种 卷积注意力融合模块(Convolution and Attention Fusion Module)

CAFM 的基本思想是将卷积和注意力机制结合,卷积操作擅长捕捉局部特征,但难以建模全局特征和长距离依赖关系。Transformer通过注意力机制擅长提取全局特征,但忽略了局部特征。从而可以更好地建模HSI的局部和全局特征,从而提升去噪效果。

对于输入X,CAFM 的实现过程:

  1. 输入特征经过1x1卷积和通道混洗操作,进入局部分支进行特征提取。首先使用1x1卷积调整通道维度。然后使用通道混洗操作增强跨通道信息交互和信息整合。最后使用3x3x3卷积提取特征。
  2. 输入特征经过1x1卷积和3x3深度卷积,生成Q、K和V,进入全局分支进行自注意力机制计算。首先使用1x1卷积和3x3深度卷积生成查询(Q)、键(K)和值(V)。然后使用自注意力机制计算注意力图,捕捉全局特征和长距离依赖关系。最后使用1x1卷积和注意力图进行特征加权,得到全局分支的输出。
  3. 将局部分支和全局分支的输出相加,得到CAFM模块的最终输出。

Convolution and Attention Fusion Module 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, bias=False):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv3d(dim, dim * 3, kernel_size=(1, 1, 1), bias=bias)
        self.qkv_dwconv = nn.Conv3d(dim * 3, dim * 3, kernel_size=(3, 3, 3), stride=1, padding=1, groups=dim * 3,
                                    bias=bias)
        self.project_out = nn.Conv3d(dim, dim, kernel_size=(1, 1, 1), bias=bias)
        self.fc = nn.Conv3d(3 * self.num_heads, 9, kernel_size=(1, 1, 1), bias=True)

        self.dep_conv = nn.Conv3d(9 * dim // self.num_heads, dim, kernel_size=(3, 3, 3), bias=True,
                                  groups=dim // self.num_heads, padding=1)

    def forward(self, x):
        b, c, h, w = x.shape
        x = x.unsqueeze(2)
        qkv = self.qkv_dwconv(self.qkv(x))
        qkv = qkv.squeeze(2)
        f_conv = qkv.permute(0, 2, 3, 1)
        f_all = qkv.reshape(f_conv.shape[0], h * w, 3 * self.num_heads, -1).permute(0, 2, 1, 3)
        f_all = self.fc(f_all.unsqueeze(2))
        f_all = f_all.squeeze(2)

        # local conv
        f_conv = f_all.permute(0, 3, 1, 2).reshape(x.shape[0], 9 * x.shape[1] // self.num_heads, h, w)
        f_conv = f_conv.unsqueeze(2)
        out_conv = self.dep_conv(f_conv)  # B, C, H, W
        out_conv = out_conv.squeeze(2)

        # global SA
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = out.unsqueeze(2)
        out = self.project_out(out)
        out = out.squeeze(2)
        output = out + out_conv

        return output


if __name__ == '__main__':
    x = torch.randn(4, 64, 128, 128).cuda()
    model = Attention(64).cuda()
    out = model(x)
    print(out.shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值