【每天一篇深度学习论文】即插即用半小波注意力模块HWAB

论文介绍

题目:

HALF WAVELET ATTENTION ON M-NET+ FOR LOW-LIGHT IMAGE ENHANCEMENT

论文地址:

链接: https://arxiv.org/abs/2203.01296

创新点

  • 改进的分层架构 M-Net+:提出了一个改良的分层模型 M-Net+,专为低光图像增强设计。该架构旨在缓解采样过程中的空间信息损失问题。通过采用像素去卷积(Pixel Unshuffle)和双线性下采样,提升了多尺度特征的多样性和丰富性。
  • 半小波注意力块(Half Wavelet Attention Block, HWAB):新引入了一种高效的特征提取模块 HWAB,利用小波域信息提取更丰富的特征。这种方法结合了小波变换和注意力机制,可以同时减少计算复杂度并增强特征语义信息。
  • 改进的特征融合方法:在解码过程中,使用选择性核特征融合(Selective Kernel Feature Fusion, SKFF)方法替代传统的特征拼接方式,有效地融合了不同分辨率的特征,同时降低了网络的参数量和计算复杂度。
  • 性能表现:在 LOL 和 MIT-Adobe FiveK 两个数据集上,提出的 HWMNet 模型在图像质量(PSNR、SSIM 和 LPIPS)以及计算复杂度方面均达到了竞争性甚至领先的效果。

方法

模型总体架构

HWMNet 继承了 U-Net 和 M-Net 的分层结构,包含以下关键模块:

  • 编码器(Encoder):从输入低光图像中提取多层次特征。
  • 解码器(Decoder):将不同分辨率的特征融合,并逐步恢复到原始图像分辨率。
  • 跳跃连接(Skip Connections):连接编码器和解码器的对应层,用于保持高分辨率的特征信息。
    在这里插入图片描述

关键改进模块

M-Net+ 架构

M-Net+ 是基于 M-Net 的改进架构,解决了原始 M-Net 的两个主要问题:

  • 避免空间信息损失
    • 在 U-Net 路径中使用像素去卷积(Pixel Unshuffle)进行下采样。
    • 在门柱路径(Gatepost Path)中使用双线性插值下采样。
  • 高效特征融合
  • 在解码阶段,使用选择性核特征融合(SKFF)方法取代简单的特征拼接,减轻高维特征融合的计算复杂度。

半小波注意力块(HWAB)

HWAB 是模型的核心创新模块,用于增强特征提取的多样性:
在这里插入图片描述

  • 输入特征被分为两部分:
    • 保留部分:直接保留原始域的特征信息。
    • 变换部分:通过离散小波变换(DWT)进入小波域,从中提取更丰富的上下文信息。
  • 在小波域中,通过通道注意力(Channel Attention)和空间注意力(Spatial Attention)对特征加权,随后通过逆小波变换(IWT)回到原始域。
  • 最后,合并保留特征和加权特征,再通过卷积层生成输出特征。

特征处理流程

1. 输入处理

  • 输入图像经过一个初始 3×3 卷积层,提取初始特征。
  • 每一层都通过 HWAB 处理,分为多分辨率特征。

2.多层次特征提取

  • U-Net 路径通过像素去卷积进行下采样,逐步降低特征图分辨率。
  • 门柱路径使用双线性下采样,并保持特征与 U-Net 路径的连接。

3. 特征融合

  • 在解码阶段,通过 SKFF 将多分辨率特征高效融合,减轻计算负担并提升重建质量。

4. 输出生成

  • 经过多层次特征融合后,模型最终通过卷积层生成增强后的图像。

模型的主要优势

  • 分层结构提升了模型对多尺度信息的处理能力。
  • HWAB 模块显著提高了特征提取的多样性和语义丰富度。
  • 通过高效特征融合和轻量化设计,实现了更低的计算复杂度。

即插即用模块作用

HWAB 作为一个即插即用模块:

  • 图像增强任务
    特别适用于低光图像增强任务,如论文中提到的 LOL 和 MIT-Adobe FiveK 数据集。在需要同时提升图像亮度、对比度和细节的场景中效果显著。
  • 图像修复任务
    可用于其他图像修复任务,如图像去噪、去模糊等,因为其设计本质上有助于提取和恢复细节特征。
  • 需要低计算复杂度的场景
    HWAB 通过小波变换对特征分解并仅处理一半的特征,显著降低了计算复杂度,非常适合嵌入式设备或实时处理的应用场景。
  • 多尺度特征处理的场景
    在需要多分辨率特征提取和整合的视觉任务中,HWAB 可高效提取不同尺度下的丰富特征信息。。

消融实验结果

在这里插入图片描述

在这里插入图片描述

  • 表 1 是在 LOL 数据集上的结果对比,表明 HWAB 和 M-Net+ 架构结合后在 PSNR、SSIM 和 LPIPS 三个指标上表现优异。
  • 表 2 是在 MIT-Adobe FiveK 数据集上的结果对比,展示了 HWMNet 在多个任务下的稳健性和高效性。
  • HWAB 的引入使模型在保持较低计算复杂度的情况下,实现了比大多数方法更好的性能(如 PSNR 和 LPIPS 指标)。

即插即用模块代码


import torch
import torch.nn as nn
#论文:HALF WAVELET ATTENTION ON M-NET+ FOR LOW-LIGHT IMAGE ENHANCEMENT
#论文地址:https://arxiv.org/abs/2203.01296

def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, stride=stride)

def dwt_init(x):
    x01 = x[:, :, 0::2, :] / 2
    x02 = x[:, :, 1::2, :] / 2
    x1 = x01[:, :, :, 0::2]
    x2 = x02[:, :, :, 0::2]
    x3 = x01[:, :, :, 1::2]
    x4 = x02[:, :, :, 1::2]
    x_LL = x1 + x2 + x3 + x4
    x_HL = -x1 - x2 + x3 + x4
    x_LH = -x1 + x2 - x3 + x4
    x_HH = x1 - x2 - x3 + x4
    # print(x_HH[:, 0, :, :])
    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)

def iwt_init(x):
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r ** 2)), r * in_height, r * in_width
    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
    h = torch.zeros([out_batch, out_channel, out_height, out_width])

    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

    return h


class DWT(nn.Module):
    def __init__(self):
        super(DWT, self).__init__()
        self.requires_grad = True

    def forward(self, x):
        return dwt_init(x)


class IWT(nn.Module):
    def __init__(self):
        super(IWT, self).__init__()
        self.requires_grad = True

    def forward(self, x):
        return iwt_init(x)


# Spatial Attention Layer
class SALayer(nn.Module):
    def __init__(self, kernel_size=5, bias=False):
        super(SALayer, self).__init__()
        self.conv_du = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        # torch.max will output 2 things, and we want the 1st one
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        avg_pool = torch.mean(x, 1, keepdim=True)
        channel_pool = torch.cat([max_pool, avg_pool], dim=1) # [N,2,H,W] could add 1x1 conv -> [N,3,H,W]
        y = self.conv_du(channel_pool)

        return x * y

# Channel Attention Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y

# Half Wavelet Attention Block (HWAB)
class HWAB(nn.Module):
    def __init__(self, n_feat, o_feat, kernel_size=3, reduction=16, bias=False, act=nn.PReLU()):
        super(HWAB, self).__init__()
        self.dwt = DWT()
        self.iwt = IWT()

        modules_body = \
            [
                conv(n_feat*2, n_feat, kernel_size, bias=bias),
                act,
                conv(n_feat, n_feat*2, kernel_size, bias=bias)
            ]
        self.body = nn.Sequential(*modules_body)

        self.WSA = SALayer()
        self.WCA = CALayer(n_feat*2, reduction, bias=bias)

        self.conv1x1 = nn.Conv2d(n_feat*4, n_feat*2, kernel_size=1, bias=bias)
        self.conv3x3 = nn.Conv2d(n_feat, o_feat, kernel_size=3, padding=1, bias=bias)
        self.activate = act
        self.conv1x1_final = nn.Conv2d(n_feat, o_feat, kernel_size=1, bias=bias)

    def forward(self, x):
        residual = x

        # Split 2 part
        wavelet_path_in, identity_path = torch.chunk(x, 2, dim=1)

        # Wavelet domain (Dual attention)
        x_dwt = self.dwt(wavelet_path_in)
        res = self.body(x_dwt)
        branch_sa = self.WSA(res)
        branch_ca = self.WCA(res)
        res = torch.cat([branch_sa, branch_ca], dim=1)
        res = self.conv1x1(res) + x_dwt
        wavelet_path = self.iwt(res)

        out = torch.cat([wavelet_path, identity_path], dim=1)
        out = self.activate(self.conv3x3(out))
        out += self.conv1x1_final(residual)

        return out


if __name__ == '__main__':


    block = HWAB(n_feat=64, o_feat=64)

    input = torch.randn(1, 64, 128, 128) # B C H W

    output = block(input)

    print(input.size())    print(output.size())
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值