paper:Hybrid Convolutional and Attention Network for Hyperspectral Image Denoising
1、Convolution and Attention Fusion Module
目前图像去雾的研究面临以下几种困境:HSI特征复杂性: HSI包含丰富的空间和光谱信息,仅依靠卷积或注意力机制难以充分建模其复杂的特征。卷积局限性: 卷积操作擅长捕捉局部特征,但感受野有限,难以建模长距离依赖关系,例如空间上相隔较远的像素之间的关联性。注意力局限性: 注意力机制擅长提取全局特征,但忽略了局部特征,例如像素之间的空间相邻关系。论文考虑到卷积和注意力机制在特征建模方面具有互补性,结合两者可以更全面地捕捉HSI的局部和全局特征,所以提出一种 卷积注意力融合模块(Convolution and Attention Fusion Module)。
CAFM 的基本思想是将卷积和注意力机制结合,卷积操作擅长捕捉局部特征,但难以建模全局特征和长距离依赖关系。Transformer通过注意力机制擅长提取全局特征,但忽略了局部特征。从而可以更好地建模HSI的局部和全局特征,从而提升去噪效果。
对于输入X,CAFM 的实现过程:
- 输入特征经过1x1卷积和通道混洗操作,进入局部分支进行特征提取。首先使用1x1卷积调整通道维度。然后使用通道混洗操作增强跨通道信息交互和信息整合。最后使用3x3x3卷积提取特征。
- 输入特征经过1x1卷积和3x3深度卷积,生成Q、K和V,进入全局分支进行自注意力机制计算。首先使用1x1卷积和3x3深度卷积生成查询(Q)、键(K)和值(V)。然后使用自注意力机制计算注意力图,捕捉全局特征和长距离依赖关系。最后使用1x1卷积和注意力图进行特征加权,得到全局分支的输出。
- 将局部分支和全局分支的输出相加,得到CAFM模块的最终输出。
Convolution and Attention Fusion Module 结构图:
2、代码实现
import torch
import torch.nn as nn
from einops import rearrange
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, bias=False):
super(Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv3d(dim, dim * 3, kernel_size=(1, 1, 1), bias=bias)
self.qkv_dwconv = nn.Conv3d(dim * 3, dim * 3, kernel_size=(3, 3, 3), stride=1, padding=1, groups=dim * 3,
bias=bias)
self.project_out = nn.Conv3d(dim, dim, kernel_size=(1, 1, 1), bias=bias)
self.fc = nn.Conv3d(3 * self.num_heads, 9, kernel_size=(1, 1, 1), bias=True)
self.dep_conv = nn.Conv3d(9 * dim // self.num_heads, dim, kernel_size=(3, 3, 3), bias=True,
groups=dim // self.num_heads, padding=1)
def forward(self, x):
b, c, h, w = x.shape
x = x.unsqueeze(2)
qkv = self.qkv_dwconv(self.qkv(x))
qkv = qkv.squeeze(2)
f_conv = qkv.permute(0, 2, 3, 1)
f_all = qkv.reshape(f_conv.shape[0], h * w, 3 * self.num_heads, -1).permute(0, 2, 1, 3)
f_all = self.fc(f_all.unsqueeze(2))
f_all = f_all.squeeze(2)
# local conv
f_conv = f_all.permute(0, 3, 1, 2).reshape(x.shape[0], 9 * x.shape[1] // self.num_heads, h, w)
f_conv = f_conv.unsqueeze(2)
out_conv = self.dep_conv(f_conv) # B, C, H, W
out_conv = out_conv.squeeze(2)
# global SA
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
out = out.unsqueeze(2)
out = self.project_out(out)
out = out.squeeze(2)
output = out + out_conv
return output
if __name__ == '__main__':
x = torch.randn(4, 64, 128, 128).cuda()
model = Attention(64).cuda()
out = model(x)
print(out.shape)