【CVPR 2025】本文参考论文Towards Scale-Aware Low-Light Enhancement via Structure-Guided
Transformer Design
论文地址:arxiv
代码地址:github
the solution ranks second in the NTIRE 2025 Low-Light Enhancement Challenge.
文章目录
-
- `RetinexFormer_arch.py` (模块化原理解析) ⚙️
-
- **权重初始化函数**
- **`PreNorm(nn.Module)`**
- **`GELU(nn.Module)`**
- **`conv(...)`**
- **`LayerNorm(nn.Module)`**
- **`Illumination_Estimator(nn.Module)`**
- **`IG_MSA(nn.Module)`** (Illumination-Guided Multi-Head Self-Attention)
- **`FeedForward(nn.Module)`**
- **`Cross_attention(nn.Module)`**
- **`IGAB(nn.Module)`** (Illumination-Guided Attention Block)
- **`Denoiser(nn.Module)`**
- **`RetinexFormer_Single_Stage(nn.Module)`**
- **`RetinexFormer(nn.Module)`**
- `UHDM_arch.py` (模块化原理解析) 🏗️
细看不难发现本文章采用了 ESDNet (Towards Efficient and Scale-Robust Ultra-High-Definition Image Demoir´eing的)高效且尺度鲁棒的 U-Net 架构、DRDB 和 SAM 模块作为网络的“骨骼”和“肌肉” 。
并且借鉴了 Retinexformer 使用 Transformer 模块 (IGAB) 进行特征增强和引导的思想,并将 IGAB 作为其核心的 SGTB。
它的核心创新在于,它不使用 Retinexformer 的光照引导,而是设计了一套全新的结构先验提取机制,并将这个结构先验作为引导信号输入到复用的 IGAB (SGTB) 中,通过其内部的交叉注意力部分(即实现了 SGCA)来指导低光图像的增强过程,特别关注结构和细节的保持 。
RetinexFormer_arch.py (模块化原理解析) ⚙️

权重初始化函数
(_no_grad_trunc_normal_, trunc_normal_, variance_scaling_, lecun_normal_)
原理讲解: 这些是标准的神经网络权重初始化函数。良好的权重初始化对于神经网络的稳定训练和快速收敛至关重要。例如,trunc_normal_ 用于生成截断正态分布的权重,避免极端值;variance_scaling_ 则根据层的输入/输出单元数量(fan-in/fan-out)来调整权重的方差,以保持信号在前向和后向传播过程中的稳定性。
PreNorm(nn.Module)
class PreNorm(nn.Module):
# 在主要操作 (fn) 前应用层归一化 (LayerNorm)
def __init__(self, dim, fn):
# dim: 特征维度
# fn: 主要操作模块 (例如一个注意力层或前馈网络)
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim) ## original norm layer
def forward(self, x, *args, **kwargs):
# x: 输入张量
x = self.norm(x) # 先进行层归一化
return self.fn(x, *args, **kwargs) # 然后执行主要操作
原理讲解: PreNorm 是一种在 Transformer 结构中广泛采用的技术。它将层归一化 (LayerNorm) 操作置于多头注意力 (MSA) 或前馈网络 (FFN) 这类主要计算模块之前。相较于传统的 Post-Norm (归一化在主要模块之后),Pre-Norm 被证明能够更有效地稳定训练过程,缓解梯度消失或爆炸的问题,从而使得训练更深、更复杂的 Transformer 模型成为可能。
GELU(nn.Module)
class GELU(nn.Module):
# GELU 激活函数
def forward(self, x):
return F.gelu(x)
原理讲解: GELU (Gaussian Error Linear Unit) 是一种高性能的激活函数,尤其在 Transformer 模型中表现出色。它的数学形式为 x ⋅ Φ ( x ) x \cdot \Phi(x) x⋅Φ(x),其中 Φ ( x ) \Phi(x) Φ(x) 是高斯累积分布函数。GELU 的特点在于其平滑的非线性特性,它在输入值为负时不像 ReLU 那样直接输出零,而是允许一定的负值通过,这被认为有助于模型学习更复杂的表示。
conv(...)
def conv(in_channels, out_channels, kernel_size, bias=False, padding=1, stride=1):
# 封装 nn.Conv2d,默认 padding 使卷积后尺寸不变
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size // 2), bias=bias, stride=stride)
原理讲解: 这是一个简单的二维卷积层 (nn.Conv2d) 的辅助封装函数。它将常用的参数如输入/输出通道数、核大小等作为输入,并自动计算 padding 的值(通过 kernel_size // 2),以确保当 stride=1 时,卷积操作不改变输入特征图的空间维度。这种封装可以简化模型定义时的代码。
LayerNorm(nn.Module)
# (BiasFree_LayerNorm 和 WithBias_LayerNorm 的定义已在上一回复中提供,此处假设它们已定义)
# from einops import rearrange # 确保导入
# def to_3d(x):
# return rearrange(x, 'b c h w -> b (h w) c')
# def to_4d(x,h,w):
# return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type='WithBias'):
super(LayerNorm, self).__init__()
if LayerNorm_type =='BiasFree':
self.body = BiasFree_LayerNorm(dim) # 自定义的无偏置LN
else:
self.body = WithBias_LayerNorm(dim) # 自定义的带偏置LN
def forward(self, x):
# x: 输入特征, 形状 [B, C, H, W] (channels-first)
h, w = x.shape[-2:]
# 转为序列形式 [B, N, C] (N=H*W), 应用自定义LayerNorm, 再转回 [B, C, H, W]
# to_3d 将 (B, C, H, W) 转换为 (B, H*W, C)
# self.body 在最后一个维度 C 上进行归一化
# to_4d 将 (B, H*W, C) 转换回 (B, C, H, W)
return to_4d(self.body(to_3d(x)), h, w)
原理讲解: 层归一化 (LayerNorm) 的作用是针对每个样本,在特征维度上(即通道维度 C)计算均值和方差,并进行归一化,然后通过可学习的缩放因子(gamma)和偏置因子(beta)进行仿射变换。这有助于稳定隐藏层激活值的分布,减少内部协变量偏移,从而加速模型训练并提高泛化能力。此处的 LayerNorm 模块特别设计为处理图像特征图(通常为4D张量 [B, C, H, W]):它首先使用 to_3d 将输入转换为序列形式 [B, N, C] (其中 N = H*W),然后应用自定义的1D LayerNorm (BiasFree_LayerNorm 或 WithBias_LayerNorm) 在最后一个维度 C 上进行归一化,最后通过 to_4d 将结果转换回原始的4D图像特征图格式。
Illumination_Estimator(nn.Module)
class Illumination_Estimator(nn.Module):
# 光照估计器模块:用于从输入图像中估计光照分量
def __init__(
self, n_fea_middle, n_fea_in=4, n_fea_out=3):
# n_fea_middle: 中间层特征通道数
# n_fea_in: 输入特征通道数 (默认为4,因为输入是图像3通道+1个均值通道)
# n_fea_out: 输出光照图通道数 (默认为3,彩色光照图)
super(Illumination_Estimator, self).__init__()
# 1x1卷积,将输入(图像+均值通道)映射到中间特征维度
self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)
# 深度卷积,在每个中间特征通道上独立进行空间卷积
# groups=n_fea_middle 使其成为一个标准的深度卷积
self.depth_conv = nn.Conv2d(
n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_middle)
# 1x1卷积,将深度卷积提取的特征映射回所需的光照图通道数
self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)
def forward(self, img):
# img: 输入图像, 形状 [B, C_img=3, H, W]
# 计算图像在通道维度上的均值,并扩展维度以进行拼接
mean_c = img.mean(dim=1).unsqueeze(1) # 形状 [B, 1, H, W]
# 将原始图像和其均值通道在通道维度上拼接
input_tensor = torch.cat([img, mean_c], dim=1) # 形状 [B, C_img+1, H, W]
x_1 = self.conv1(input_tensor) # 形状 [B, n_fea_middle, H, W]
illu_fea = self.depth_conv(x_1) # 光照特征, 形状 [B, n_fea_middle, H, W]
illu_map = self.conv2(illu_fea) # 估计的光照图, 形状 [B, n_fea_out, H, W]
return illu_fea, illu_map
原理讲解: Illumination_Estimator 模块的设计目标是根据输入的图像估计其光照分量,这是Retinex理论的核心步骤之一(即图像可以分解为光照和反射)。它首先将输入图像(3通道)与其单通道的均值图(代表整体亮度信息)在通道维度上拼接,形成一个4通道的输入。这个拼接后的张量首先通过一个1x1卷积 (self.conv1) 进行特征提取和通道数调整,得到中间特征。随后,一个5x5的深度卷积 (self.depth_conv) 对这些中间特征进行空间上的进一步处理,深度卷积的特性是每个输入通道有自己独立的卷积核,这有助于在参数量较少的情况下捕捉空间模式。最后,另一个1x1卷积 (self.conv2) 将处理后的特征映射为最终的光照图 (illu_map,通常是3通道) 和供后续模块使用的中间光照特征 (illu_fea)。
IG_MSA(nn.Module) (Illumination-Guided Multi-Head Self-Attention)
class IG_MSA(nn.Module):
def __init__(
self,
dim, # 输入特征维度 (通道数 C)
dim_head=64, # 每个注意力头的维度
heads=8, # 注意力头的数量
):
super().__init__()
self.num_heads = heads
self.dim_head = dim_head
inner_dim = dim_head * heads # Q,K,V 的总维度
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_k = nn.Linear(dim, inner_dim, bias=False)
self.to_v = nn.Linear(dim, inner_dim, bias=False)
self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
self.proj = nn.Linear(inner_dim, dim, bias=True)
self.pos_emb = nn.Sequential(
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim), # 深度卷积
GELU(),
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim), # 深度卷积
)
self.dim = dim
def forward(self, x_in):
# x_in: 输入特征, 形状 [B, H, W, C_in], C_in == dim (channels-last format)
b, h, w, c = x_in.shape # B, H, W, C
x = x_in.reshape(b, h * w, c) # 重塑为序列形式 [B, N, C], N = H*W
q_inp = self.to_q(x)
k_inp = self.to_k(x)
v_inp = self.to_v(x)
# 将 Q, K, V 分割到多个头: 'b n (h d) -> b h n d'
# q,k,v 形状变为: [B, num_heads, N, dim_head]
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
(q_inp, k_inp, v_inp))
# L2 归一化 Q, K (沿最后一个维度,即dim_head)
q = F.normalize(q, dim=-1, p=2)
k = F.normalize(k, dim=-1, p=2)
# 计算注意力得分: (Q @ K_transpose) / sqrt(d_k)
# q: [B, num_heads, N, dim_head]
# k.transpose(-2, -1): [B, num_heads, dim_head, N]
# attn: [B, num_heads, N, N]
attn = (q @ k.transpose(-2, -1)) # 实际计算 Q K^T (token-wise attention)
# 乘以可学习的 rescale 参数 (这里 rescale 是 (heads,1,1),会广播)
# 论文中通常是乘以 1/sqrt(dim_head)
attn = attn * self.rescale.unsqueeze(0) # 添加 batch 维度 (1,heads,1,1)
attn = attn.softmax(dim=-1) # Softmax 归一化 (在最后一个维度 N 上)
# 应用注意力到 V
# x_attn: [B, num_heads, N, N] @ [B, num_heads, N, dim_head] -> [B, num_heads, N, dim_head]
x_attn = attn @ v
# 合并多头输出
# 'b h n d -> b n (h d)'
x_attn_merged = rearrange(x_attn, 'b h n d -> b n (h d)') # [B, N, inner_dim]
out_c = self.proj(x_attn_merged).view(b, h, w, c) # 投影回原始维度并reshape [B, H, W, C]
# 计算卷积位置编码 (Convolutional Positional Embedding)
# v_inp: [B, N, inner_dim], reshape -> [B, H, W, C(inner_dim)] -> permute -> [B, C(inner_dim), H, W]
# 注意:pos_emb的输入通道数是dim, 而v_inp是inner_dim。这里假设dim==inner_dim,或者pos_emb应该作用于x_in
# 如果 pos_emb 的 nn.Conv2d 输入是 dim, 那么 v_inp 需要先投影回 dim
# 按照代码,v_inp的通道数是inner_dim, pos_emb的第一个卷积输入通道是dim。维度不匹配除非inner_dim==dim
# 假设这里的dim指的是原始输入特征的dim,v_inp在reshape时也应对应这个dim
# 或者 pos_emb 的输入应该是 x_in.permute(0,3,1,2)
# 如果按原代码,v_inp.reshape(b,h,w,c) 中 c 应该等于 self.dim
pos_emb_out = self.pos_emb(v_inp.reshape(b, h, w, self.dim).permute(0, 3, 1, 2))
out_p = pos_emb_out.permute(0, 2, 3, 1) # 转回 BHWC [B, H, W, C]
out = out_c + out_p # 自注意力输出 + 位置编码
return out
原理讲解: IG_MSA (Multi-Head Self-Attention) 是 Transformer 的核心。它允许模型在处理序列中的一个元素(在这里是图像展平后的一个“像素块”或“patch”)时,同时权衡序列中所有其他元素的重要性。
- 输入投影: 输入特征
x_in(形状[B, H, W, C]) 首先被展平为[B, N, C](N=H*W),然后通过三个独立的线性层 (to_q,to_k,to_v) 分别投影成查询 (Query, Q)、键 (Key, K) 和值 (Value, V) 向量。 - 多头机制: Q, K, V 被分割成多个“头”(
num_heads)。每个头处理特征的一个子空间,这使得模型能从不同角度捕捉信息。 - 缩放点积注意力: 在每个头内部,注意力权重通过计算Q和K的点积得到,然后除以一个缩放因子(这里是可学习的
rescale,经典Transformer是 d k \sqrt{d_k} dk, d k d_k dk是键向量的维度)并应用Softmax函数得到归一化的权重。这些权重随后用于加权求和V向量,得到该头的输出。 - 合并与输出: 所有头的输出被拼接起来,并通过一个最终的线性层 (
proj) 投影回原始的特征维度。 - 卷积位置编码:
pos_emb是一个由两个深度卷积组成的简单网络,它作用于原始的v_inp(经过reshape和permute) 来生成位置信息,并加到自注意力机制的输出上。这与ViT中常见的可学习位置嵌入或固定正弦位置编码不同,它通过卷积隐式地学习相对位置关系。
尽管名为 “Illumination-Guided”,但此模块的实现是一个标准的自注意力机制,并没有显式地从外部接收光照特征进行引导。
FeedForward(nn.Module)
class FeedForward(nn.

最低0.47元/天 解锁文章
1344

被折叠的 条评论
为什么被折叠?



