【CVPR 2025】低光增强SG-LLIE(ranks second in NTIRE 2025)--part2代码详解

【CVPR 2025】本文参考论文Towards Scale-Aware Low-Light Enhancement via Structure-Guided
Transformer Design
论文地址:arxiv
代码地址:github
the solution ranks second in the NTIRE 2025 Low-Light Enhancement Challenge.


细看不难发现本文章采用了 ESDNet (Towards Efficient and Scale-Robust Ultra-High-Definition Image Demoir´eing的)高效且尺度鲁棒的 U-Net 架构、DRDB 和 SAM 模块作为网络的“骨骼”和“肌肉” 。
并且借鉴了 Retinexformer 使用 Transformer 模块 (IGAB) 进行特征增强和引导的思想,并将 IGAB 作为其核心的 SGTB。
它的核心创新在于,它不使用 Retinexformer 的光照引导,而是设计了一套全新的结构先验提取机制,并将这个结构先验作为引导信号输入到复用的 IGAB (SGTB) 中,通过其内部的交叉注意力部分(即实现了 SGCA)来指导低光图像的增强过程,特别关注结构和细节的保持 。


RetinexFormer_arch.py (模块化原理解析) ⚙️

在这里插入图片描述

权重初始化函数

_no_grad_trunc_normal_, trunc_normal_, variance_scaling_, lecun_normal_

原理讲解: 这些是标准的神经网络权重初始化函数。良好的权重初始化对于神经网络的稳定训练和快速收敛至关重要。例如,trunc_normal_ 用于生成截断正态分布的权重,避免极端值;variance_scaling_ 则根据层的输入/输出单元数量(fan-in/fan-out)来调整权重的方差,以保持信号在前向和后向传播过程中的稳定性。


PreNorm(nn.Module)

class PreNorm(nn.Module):
    # 在主要操作 (fn) 前应用层归一化 (LayerNorm)
    def __init__(self, dim, fn):
        # dim: 特征维度
        # fn: 主要操作模块 (例如一个注意力层或前馈网络)
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim) ## original norm layer

    def forward(self, x, *args, **kwargs):
        # x: 输入张量
        x = self.norm(x) # 先进行层归一化
        return self.fn(x, *args, **kwargs) # 然后执行主要操作

原理讲解: PreNorm 是一种在 Transformer 结构中广泛采用的技术。它将层归一化 (LayerNorm) 操作置于多头注意力 (MSA) 或前馈网络 (FFN) 这类主要计算模块之前。相较于传统的 Post-Norm (归一化在主要模块之后),Pre-Norm 被证明能够更有效地稳定训练过程,缓解梯度消失或爆炸的问题,从而使得训练更深、更复杂的 Transformer 模型成为可能。


GELU(nn.Module)

class GELU(nn.Module):
    # GELU 激活函数
    def forward(self, x):
        return F.gelu(x)

原理讲解: GELU (Gaussian Error Linear Unit) 是一种高性能的激活函数,尤其在 Transformer 模型中表现出色。它的数学形式为 x ⋅ Φ ( x ) x \cdot \Phi(x) xΦ(x),其中 Φ ( x ) \Phi(x) Φ(x) 是高斯累积分布函数。GELU 的特点在于其平滑的非线性特性,它在输入值为负时不像 ReLU 那样直接输出零,而是允许一定的负值通过,这被认为有助于模型学习更复杂的表示。


conv(...)

def conv(in_channels, out_channels, kernel_size, bias=False, padding=1, stride=1):
    # 封装 nn.Conv2d,默认 padding 使卷积后尺寸不变
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, stride=stride)

原理讲解: 这是一个简单的二维卷积层 (nn.Conv2d) 的辅助封装函数。它将常用的参数如输入/输出通道数、核大小等作为输入,并自动计算 padding 的值(通过 kernel_size // 2),以确保当 stride=1 时,卷积操作不改变输入特征图的空间维度。这种封装可以简化模型定义时的代码。


LayerNorm(nn.Module)

# (BiasFree_LayerNorm 和 WithBias_LayerNorm 的定义已在上一回复中提供,此处假设它们已定义)
# from einops import rearrange # 确保导入

# def to_3d(x):
#     return rearrange(x, 'b c h w -> b (h w) c')

# def to_4d(x,h,w):
#     return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)

class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type='WithBias'):
        super(LayerNorm, self).__init__()
        if LayerNorm_type =='BiasFree':
            self.body = BiasFree_LayerNorm(dim) # 自定义的无偏置LN
        else:
            self.body = WithBias_LayerNorm(dim) # 自定义的带偏置LN

    def forward(self, x):
        # x: 输入特征, 形状 [B, C, H, W] (channels-first)
        h, w = x.shape[-2:]
        # 转为序列形式 [B, N, C] (N=H*W), 应用自定义LayerNorm, 再转回 [B, C, H, W]
        # to_3d 将 (B, C, H, W) 转换为 (B, H*W, C)
        # self.body 在最后一个维度 C 上进行归一化
        # to_4d 将 (B, H*W, C) 转换回 (B, C, H, W)
        return to_4d(self.body(to_3d(x)), h, w)

原理讲解: 层归一化 (LayerNorm) 的作用是针对每个样本,在特征维度上(即通道维度 C)计算均值和方差,并进行归一化,然后通过可学习的缩放因子(gamma)和偏置因子(beta)进行仿射变换。这有助于稳定隐藏层激活值的分布,减少内部协变量偏移,从而加速模型训练并提高泛化能力。此处的 LayerNorm 模块特别设计为处理图像特征图(通常为4D张量 [B, C, H, W]):它首先使用 to_3d 将输入转换为序列形式 [B, N, C] (其中 N = H*W),然后应用自定义的1D LayerNorm (BiasFree_LayerNormWithBias_LayerNorm) 在最后一个维度 C 上进行归一化,最后通过 to_4d 将结果转换回原始的4D图像特征图格式。


Illumination_Estimator(nn.Module)

class Illumination_Estimator(nn.Module):
    # 光照估计器模块:用于从输入图像中估计光照分量
    def __init__(
            self, n_fea_middle, n_fea_in=4, n_fea_out=3):
        # n_fea_middle: 中间层特征通道数
        # n_fea_in: 输入特征通道数 (默认为4,因为输入是图像3通道+1个均值通道)
        # n_fea_out: 输出光照图通道数 (默认为3,彩色光照图)
        super(Illumination_Estimator, self).__init__()

        # 1x1卷积,将输入(图像+均值通道)映射到中间特征维度
        self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)

        # 深度卷积,在每个中间特征通道上独立进行空间卷积
        # groups=n_fea_middle 使其成为一个标准的深度卷积
        self.depth_conv = nn.Conv2d(
            n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_middle) 

        # 1x1卷积,将深度卷积提取的特征映射回所需的光照图通道数
        self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)

    def forward(self, img):
        # img: 输入图像, 形状 [B, C_img=3, H, W]
        
        # 计算图像在通道维度上的均值,并扩展维度以进行拼接
        mean_c = img.mean(dim=1).unsqueeze(1) # 形状 [B, 1, H, W]
        
        # 将原始图像和其均值通道在通道维度上拼接
        input_tensor = torch.cat([img, mean_c], dim=1) # 形状 [B, C_img+1, H, W]

        x_1 = self.conv1(input_tensor) # 形状 [B, n_fea_middle, H, W]
        illu_fea = self.depth_conv(x_1) # 光照特征, 形状 [B, n_fea_middle, H, W]
        illu_map = self.conv2(illu_fea) # 估计的光照图, 形状 [B, n_fea_out, H, W]
        
        return illu_fea, illu_map

原理讲解: Illumination_Estimator 模块的设计目标是根据输入的图像估计其光照分量,这是Retinex理论的核心步骤之一(即图像可以分解为光照和反射)。它首先将输入图像(3通道)与其单通道的均值图(代表整体亮度信息)在通道维度上拼接,形成一个4通道的输入。这个拼接后的张量首先通过一个1x1卷积 (self.conv1) 进行特征提取和通道数调整,得到中间特征。随后,一个5x5的深度卷积 (self.depth_conv) 对这些中间特征进行空间上的进一步处理,深度卷积的特性是每个输入通道有自己独立的卷积核,这有助于在参数量较少的情况下捕捉空间模式。最后,另一个1x1卷积 (self.conv2) 将处理后的特征映射为最终的光照图 (illu_map,通常是3通道) 和供后续模块使用的中间光照特征 (illu_fea)。


IG_MSA(nn.Module) (Illumination-Guided Multi-Head Self-Attention)

class IG_MSA(nn.Module):
    def __init__(
            self,
            dim,          # 输入特征维度 (通道数 C)
            dim_head=64,  # 每个注意力头的维度
            heads=8,      # 注意力头的数量
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads # Q,K,V 的总维度

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(inner_dim, dim, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim), # 深度卷积
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim), # 深度卷积
        )
        self.dim = dim

    def forward(self, x_in):
        # x_in: 输入特征, 形状 [B, H, W, C_in], C_in == dim (channels-last format)
        
        b, h, w, c = x_in.shape # B, H, W, C
        x = x_in.reshape(b, h * w, c) # 重塑为序列形式 [B, N, C], N = H*W

        q_inp = self.to_q(x) 
        k_inp = self.to_k(x) 
        v_inp = self.to_v(x) 
        
        # 将 Q, K, V 分割到多个头: 'b n (h d) -> b h n d' 
        # q,k,v 形状变为: [B, num_heads, N, dim_head]
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                                 (q_inp, k_inp, v_inp))

        # L2 归一化 Q, K (沿最后一个维度,即dim_head)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)

        # 计算注意力得分: (Q @ K_transpose) / sqrt(d_k)
        # q: [B, num_heads, N, dim_head]
        # k.transpose(-2, -1): [B, num_heads, dim_head, N]
        # attn: [B, num_heads, N, N]
        attn = (q @ k.transpose(-2, -1)) # 实际计算 Q K^T (token-wise attention)
        
        # 乘以可学习的 rescale 参数 (这里 rescale 是 (heads,1,1),会广播)
        # 论文中通常是乘以 1/sqrt(dim_head)
        attn = attn * self.rescale.unsqueeze(0) # 添加 batch 维度 (1,heads,1,1)
        attn = attn.softmax(dim=-1) # Softmax 归一化 (在最后一个维度 N 上)

        # 应用注意力到 V
        # x_attn: [B, num_heads, N, N] @ [B, num_heads, N, dim_head] -> [B, num_heads, N, dim_head]
        x_attn = attn @ v
        
        # 合并多头输出
        # 'b h n d -> b n (h d)'
        x_attn_merged = rearrange(x_attn, 'b h n d -> b n (h d)') # [B, N, inner_dim]
        
        out_c = self.proj(x_attn_merged).view(b, h, w, c) # 投影回原始维度并reshape [B, H, W, C]
        
        # 计算卷积位置编码 (Convolutional Positional Embedding)
        # v_inp: [B, N, inner_dim], reshape -> [B, H, W, C(inner_dim)] -> permute -> [B, C(inner_dim), H, W]
        # 注意:pos_emb的输入通道数是dim, 而v_inp是inner_dim。这里假设dim==inner_dim,或者pos_emb应该作用于x_in
        # 如果 pos_emb 的 nn.Conv2d 输入是 dim, 那么 v_inp 需要先投影回 dim
        # 按照代码,v_inp的通道数是inner_dim, pos_emb的第一个卷积输入通道是dim。维度不匹配除非inner_dim==dim
        # 假设这里的dim指的是原始输入特征的dim,v_inp在reshape时也应对应这个dim
        # 或者 pos_emb 的输入应该是 x_in.permute(0,3,1,2)
        # 如果按原代码,v_inp.reshape(b,h,w,c) 中 c 应该等于 self.dim
        pos_emb_out = self.pos_emb(v_inp.reshape(b, h, w, self.dim).permute(0, 3, 1, 2))
        out_p = pos_emb_out.permute(0, 2, 3, 1) # 转回 BHWC [B, H, W, C]
        
        out = out_c + out_p # 自注意力输出 + 位置编码

        return out

原理讲解: IG_MSA (Multi-Head Self-Attention) 是 Transformer 的核心。它允许模型在处理序列中的一个元素(在这里是图像展平后的一个“像素块”或“patch”)时,同时权衡序列中所有其他元素的重要性。

  1. 输入投影: 输入特征 x_in (形状 [B, H, W, C]) 首先被展平为 [B, N, C] (N=H*W),然后通过三个独立的线性层 (to_q, to_k, to_v) 分别投影成查询 (Query, Q)、键 (Key, K) 和值 (Value, V) 向量。
  2. 多头机制: Q, K, V 被分割成多个“头”(num_heads)。每个头处理特征的一个子空间,这使得模型能从不同角度捕捉信息。
  3. 缩放点积注意力: 在每个头内部,注意力权重通过计算Q和K的点积得到,然后除以一个缩放因子(这里是可学习的 rescale,经典Transformer是 d k \sqrt{d_k} dk d k d_k dk是键向量的维度)并应用Softmax函数得到归一化的权重。这些权重随后用于加权求和V向量,得到该头的输出。
  4. 合并与输出: 所有头的输出被拼接起来,并通过一个最终的线性层 (proj) 投影回原始的特征维度。
  5. 卷积位置编码: pos_emb 是一个由两个深度卷积组成的简单网络,它作用于原始的 v_inp (经过reshape和permute) 来生成位置信息,并加到自注意力机制的输出上。这与ViT中常见的可学习位置嵌入或固定正弦位置编码不同,它通过卷积隐式地学习相对位置关系。
    尽管名为 “Illumination-Guided”,但此模块的实现是一个标准的自注意力机制,并没有显式地从外部接收光照特征进行引导。

FeedForward(nn.Module)

class FeedForward(nn.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值