本文是对Retinexformer技术的代码解读,原文解读请看Retinexformer。
1、原文概要
首先基于通用Retinex理论提出了自己优化的Retinex理论,基于该理论提出了一个亮度引导的Unet类型的Transformer网络结构用于图像增强。下图为整个模型的结构示意图。
2、代码结构
代码整体结构如下:
train.py是训练脚本,archs文件中是网络结构,losses文件中是损失函数。
3 、核心代码模块
archs
文件夹
archs
文件夹主要用于存放网络架构相关的代码,这些代码定义了模型的具体结构。在这个项目里,archs
文件夹包含了多个文件,下面为大家详细介绍每个文件的作用:
1. __init__.py
此文件的主要功能是自动扫描和导入 archs
文件夹下所有以 _arch.py
结尾的文件,并提供了动态实例化网络的功能。
2. RetinexFormer_arch.py
此文件定义了 RetinexFormer
网络架构,这是项目的核心网络,其具体结构如下:
- 网络结构类:定义了网络结构的类。
class RetinexFormer(nn.Module):
def __init__(self, in_channels=3, out_channels=3, n_feat=31, stage=3, num_blocks=[1,1,1]):
super(RetinexFormer, self).__init__()
self.stage = stage
modules_body = [RetinexFormer_Single_Stage(in_channels=in_channels, out_channels=out_channels, n_feat=n_feat, level=2, num_blocks=num_blocks)
for _ in range(stage)]
self.body = nn.Sequential(*modules_body)
- 前向传播函数:定义了数据在网络中的流动过程。
def forward(self, x):
out = self.body(x)
return out
3. arch_util.py
该文件提供了一些网络架构的工具函数,例如初始化权重、构建层等。
4. layers.py
该文件定义了一些网络层,像多层感知机(MLP)等。
losses
文件夹
losses
文件夹主要用于存放损失函数相关的代码,这些代码定义了训练过程中使用的损失函数。该文件夹包含了多个文件,下面为你详细介绍每个文件的作用:
1. __init__.py
此文件导入了 losses.py
中定义的损失函数,并将它们添加到 __all__
列表里,方便其他模块导入使用。
from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss)
__all__ = [
'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss',
]
2. loss_util.py
该文件提供了一些损失函数的工具函数。
- reduce_loss损失:
reduce_loss
函数用于根据指定的reduction
方式减少损失。
def reduce_loss(loss, reduction):
reduction_enum = F._Reduction.get_enum(reduction)
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
else:
return loss.sum()
- 加权损失:
weighted_loss
函数是一个装饰器,用于给损失函数添加权重和reduction
参数。
def weighted_loss(loss_func):
@functools.wraps(loss_func)
def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction)
return loss
return wrapper
3. losses.py
此文件定义了多种损失函数,例如 L1 损失、MSE 损失、PSNR 损失、Charbonnier 损失等。
- L1 损失:
L1Loss
类定义了 L1 损失函数。
class L1Loss(nn.Module):
def __init__(self, loss_weight=1.0, reduction='mean'):
super(L1Loss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. '
f'Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
return self.loss_weight * l1_loss(
pred, target, weight, reduction=self.reduction)
- MSE 损失:
MSELoss
类定义了 MSE 损失函数。
class MSELoss(nn.Module):
def __init__(self, loss_weight=1.0, reduction='mean'):
super(MSELoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. '
f'Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
return self.loss_weight * mse_loss(
pred, target, weight, reduction=self.reduction)
综上所述,archs
文件夹定义了网络架构,而 losses
文件夹定义了训练过程中使用的损失函数,这两个文件夹共同构成了项目的核心代码部分。
4、详细代码注释(网络结构)
RetinexFormer_arch.py
文件中的 RetinexFormer
类是本文模型,采用了基于Retinex理论的Transformer架构。以下是对 RetinexFormer
类的详细介绍,将按照模块进行分析:
1. 类定义和初始化
class RetinexFormer(nn.Module):
def __init__(self, in_channels=3, out_channels=3, n_feat=31, stage=3, num_blocks=[1,1,1]):
super(RetinexFormer, self).__init__()
self.stage = stage
modules_body = [RetinexFormer_Single_Stage(in_channels=in_channels, out_channels=out_channels, n_feat=n_feat, level=2, num_blocks=num_blocks)
for _ in range(stage)]
self.body = nn.Sequential(*modules_body)
-
参数说明:
in_channels
:输入图像的通道数,默认为3(RGB图像)。out_channels
:输出图像的通道数,默认为3(RGB图像)。n_feat
:特征通道数,默认为31。stage
:网络的阶段数,即RetinexFormer_Single_Stage
模块的堆叠数量,默认为3。num_blocks
:每个IGAB
模块中的块数,默认为[1, 1, 1]
。
-
网络结构:
RetinexFormer
类由多个RetinexFormer_Single_Stage
模块堆叠而成,这些模块通过nn.Sequential
组合在一起。
2. 前向传播
def forward(self, x):
"""
x: [b,c,h,w]
return out:[b,c,h,w]
"""
out = self.body(x)
return out
- 功能:前向传播函数将输入图像
x
依次通过堆叠的RetinexFormer_Single_Stage
模块,最终输出增强后的图像。
3. 子模块:RetinexFormer_Single_Stage
class RetinexFormer_Single_Stage(nn.Module):
def __init__(self, in_channels=3, out_channels=3, n_feat=31, level=2, num_blocks=[1, 1, 1]):
super(RetinexFormer_Single_Stage, self).__init__()
self.estimator = Illumination_Estimator(n_feat)
self.denoiser = Denoiser(in_dim=in_channels,out_dim=out_channels,dim=n_feat,level=level,num_blocks=num_blocks) #### 将 Denoiser 改为 img2img
def forward(self, img):
# img: b,c=3,h,w
# illu_fea: b,c,h,w
# illu_map: b,c=3,h,w
illu_fea, illu_map = self.estimator(img)
input_img = img * illu_map + img
output_img = self.denoiser(input_img,illu_fea)
return output_img
-
参数说明:
in_channels
:输入图像的通道数,默认为3。out_channels
:输出图像的通道数,默认为3。n_feat
:特征通道数,默认为31。level
:编码器和解码器的层数,默认为2。num_blocks
:每个IGAB
模块中的块数,默认为[1, 1, 1]
。
-
网络结构:
RetinexFormer_Single_Stage
模块包含两个子模块:Illumination_Estimator
:用于估计图像的光照图。Denoiser
:用于对增强后的图像进行去噪处理。
-
前向传播:
- 通过
Illumination_Estimator
模块估计光照特征illu_fea
和光照图illu_map
。 - 将输入图像
img
与光照图illu_map
相乘并加上原图像,得到增强后的输入图像input_img
。 - 将增强后的输入图像
input_img
和光照特征illu_fea
输入到Denoiser
模块中进行去噪处理,得到最终的输出图像output_img
。
- 通过
4. 子模块:Illumination_Estimator
class Illumination_Estimator(nn.Module):
def __init__(
self, n_fea_middle, n_fea_in=4, n_fea_out=3): #__init__部分是内部属性,而forward的输入才是外部输入
super(Illumination_Estimator, self).__init__()
self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)
self.depth_conv = nn.Conv2d(
n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_in)
self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)
def forward(self, img):
# img: b,c=3,h,w
# mean_c: b,c=1,h,w
# illu_fea: b,c,h,w
# illu_map: b,c=3,h,w
mean_c = img.mean(dim=1).unsqueeze(1)
# stx()
input = torch.cat([img,mean_c], dim=1)
x_1 = self.conv1(input)
illu_fea = self.depth_conv(x_1)
illu_map = self.conv2(illu_fea)
return illu_fea, illu_map
-
参数说明:
n_fea_middle
:中间特征通道数。n_fea_in
:输入特征通道数,默认为4。n_fea_out
:输出特征通道数,默认为3。
-
网络结构:
- 该模块包含三个卷积层:
conv1
:1x1卷积层,用于将输入特征映射到中间特征。depth_conv
:深度可分离卷积层,用于提取光照特征。conv2
:1x1卷积层,用于将中间特征映射到输出特征。
- 该模块包含三个卷积层:
-
前向传播:
- 计算输入图像
img
的通道均值mean_c
。 - 将输入图像
img
和通道均值mean_c
在通道维度上拼接,得到输入特征input
。 - 通过
conv1
卷积层得到中间特征x_1
。 - 通过
depth_conv
卷积层得到光照特征illu_fea
。 - 通过
conv2
卷积层得到光照图illu_map
。
- 计算输入图像
5. 子模块:Denoiser
class Denoiser(nn.Module):
def __init__(self, in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4]):
super(Denoiser, self).__init__()
self.dim = dim
self.level = level
# Input projection
self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)
# Encoder
self.encoder_layers = nn.ModuleList([])
dim_level = dim
for i in range(level):
self.encoder_layers.append(nn.ModuleList([
IGAB(
dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
]))
dim_level *= 2
# Bottleneck
self.bottleneck = IGAB(
dim=dim_level, dim_head=dim, heads=dim_level // dim, num_blocks=num_blocks[-1])
# Decoder
self.decoder_layers = nn.ModuleList([])
for i in range(level):
self.decoder_layers.append(nn.ModuleList([
nn.ConvTranspose2d(dim_level, dim_level // 2, stride=2,
kernel_size=2, padding=0, output_padding=0),
nn.Conv2d(dim_level, dim_level // 2, 1, 1, bias=False),
IGAB(
dim=dim_level // 2, num_blocks=num_blocks[level - 1 - i], dim_head=dim,
heads=(dim_level // 2) // dim),
]))
dim_level //= 2
# Output projection
self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, illu_fea):
"""
x: [b,c,h,w] x是feature, 不是image
illu_fea: [b,c,h,w]
return out: [b,c,h,w]
"""
# Embedding
fea = self.embedding(x)
# Encoder
fea_encoder = []
illu_fea_list = []
for (IGAB, FeaDownSample, IlluFeaDownsample) in self.encoder_layers:
fea = IGAB(fea,illu_fea) # bchw
illu_fea_list.append(illu_fea)
fea_encoder.append(fea)
fea = FeaDownSample(fea)
illu_fea = IlluFeaDownsample(illu_fea)
# Bottleneck
fea = self.bottleneck(fea,illu_fea)
# Decoder
for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
fea = FeaUpSample(fea)
fea = Fution(
torch.cat([fea, fea_encoder[self.level - 1 - i]], dim=1))
illu_fea = illu_fea_list[self.level-1-i]
fea = LeWinBlcok(fea,illu_fea)
# Mapping
out = self.mapping(fea) + x
return out
-
参数说明:
in_dim
:输入特征通道数,默认为3。out_dim
:输出特征通道数,默认为3。dim
:特征通道数,默认为31。level
:编码器和解码器的层数,默认为2。num_blocks
:每个IGAB
模块中的块数,默认为[2, 4, 4]
。
-
网络结构:
- 该模块采用了编码器 - 解码器架构,包含以下部分:
- 输入投影:
embedding
卷积层,用于将输入特征映射到指定维度。 - 编码器:由多个
IGAB
模块和下采样卷积层组成,用于提取特征。 - 瓶颈层:
bottleneck
IGAB
模块,用于进一步提取特征。 - 解码器:由多个上采样卷积层、融合层和
IGAB
模块组成,用于恢复特征。 - 输出投影:
mapping
卷积层,用于将特征映射到输出维度。
- 输入投影:
- 该模块采用了编码器 - 解码器架构,包含以下部分:
-
前向传播:
- 通过
embedding
卷积层将输入特征x
映射到指定维度。 - 依次通过编码器的
IGAB
模块和下采样卷积层,提取特征并下采样。 - 通过瓶颈层的
IGAB
模块进一步提取特征。 - 依次通过解码器的上采样卷积层、融合层和
IGAB
模块,恢复特征并上采样。 - 通过
mapping
卷积层将特征映射到输出维度,并加上输入特征x
,得到最终的输出特征。
- 通过
6. 子模块:IGAB模块
IGAB
模块主要由多个 IG_MSA
(光照引导多头自注意力)和前馈网络(Feed Forward)子模块交替组成。该模块的独特之处在于,它利用光照估计器生成的光照特征来指导注意力机制的计算,让模型在处理低光照图像时能够更加关注图像的光照信息。
代码实现
class IGAB(nn.Module):
def __init__(
self,
dim,
dim_head=64,
heads=8,
num_blocks=2,
):
super().__init__()
self.blocks = nn.ModuleList([])
for _ in range(num_blocks):
self.blocks.append(nn.ModuleList([
IG_MSA(dim=dim, dim_head=dim_head, heads=heads),
PreNorm(dim, FeedForward(dim=dim))
]))
def forward(self, x, illu_fea):
"""
x: [b,c,h,w]
illu_fea: [b,c,h,w]
return out: [b,c,h,w]
"""
x = x.permute(0, 2, 3, 1)
for (attn, ff) in self.blocks:
x = attn(x, illu_fea_trans=illu_fea.permute(0, 2, 3, 1)) + x
x = ff(x) + x
out = x.permute(0, 3, 1, 2)
return out
1. 初始化参数
dim
:输入特征的通道数。dim_head
:每个注意力头的维度,默认为64。heads
:注意力头的数量,默认为8。num_blocks
:堆叠的IG_MSA
和前馈网络对的数量,默认为2。
2. 子模块
IG_MSA
:光照引导的多头自注意力模块,它将光照特征融入到注意力计算过程中。PreNorm
和FeedForward
:前馈网络模块,其中PreNorm
是在进行前馈计算之前对输入进行层归一化操作。
3. 前向传播过程
- 维度调整:把输入特征
x
和光照特征illu_fea
从[b,c,h,w]
调整为[b,h,w,c]
,这样便于后续的注意力计算。 - 堆叠块处理:依次通过多个
IG_MSA
和前馈网络对,并且都使用了残差连接:- 先将特征输入到
IG_MSA
模块中,利用光照特征来引导注意力机制。 - 接着将输出输入到前馈网络中进行非线性变换。
- 先将特征输入到
- 维度恢复:将处理后的特征从
[b,h,w,c]
重新调整回[b,c,h,w]
。
核心组件:IG_MSA(光照引导多头自注意力)
class IG_MSA(nn.Module):
def __init__(
self,
dim,
dim_head=64,
heads=8,
):
super().__init__()
self.num_heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
self.proj = nn.Linear(dim_head * heads, dim, bias=True)
self.pos_emb = nn.Sequential(
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
GELU(),
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
)
self.dim = dim
def forward(self, x_in, illu_fea_trans):
"""
x_in: [b,h,w,c] # input_feature
illu_fea: [b,h,w,c] # mask shift? 为什么是 b, h, w, c?
return out: [b,h,w,c]
"""
b, h, w, c = x_in.shape
x = x_in.reshape(b, h * w, c)
q_inp = self.to_q(x)
k_inp = self.to_k(x)
v_inp = self.to_v(x)
illu_attn = illu_fea_trans
q, k, v, illu_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
(q_inp, k_inp, v_inp, illu_attn.flatten(1, 2)))
v = v * illu_attn
q = q.transpose(-2, -1)
k = k.transpose(-2, -1)
v = v.transpose(-2, -1)
q = F.normalize(q, dim=-1, p=2)
k = F.normalize(k, dim=-1, p=2)
attn = (k @ q.transpose(-2, -1))
attn = attn * self.rescale
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.permute(0, 3, 1, 2)
x = x.reshape(b, h * w, self.num_heads * self.dim_head)
out_c = self.proj(x).view(b, h, w, c)
out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(
0, 3, 1, 2)).permute(0, 2, 3, 1)
out = out_c + out_p
return out
IG_MSA工作原理
- 查询(Q)、键(K)、值(V)生成:利用线性层将输入特征映射为Q、K、V。
- 光照引导:把光照特征与值(V)相乘,以此来调整注意力机制对不同区域的关注程度。
- 注意力计算:
- 先对Q和K进行L2归一化处理,然后计算它们的点积。
- 引入可学习的缩放因子
rescale
,对注意力得分进行调整。 - 最后通过softmax函数得到最终的注意力权重。
- 特征聚合:用注意力权重对V进行加权求和。
- 位置编码:通过深度可分离卷积为特征添加位置信息。
- 输出整合:将注意力输出和位置编码输出相加,得到最终结果。
设计亮点
- 光照引导机制:借助光照特征来调制注意力机制,使模型在处理低光照图像时能够更好地利用光照信息。
- 归一化注意力:对Q和K进行归一化处理,让注意力计算更加稳定。
- 双重残差连接:在
IGAB
模块中使用了残差连接,有助于梯度的传播,使模型更容易训练。 - 位置感知:通过位置编码模块保留了图像的空间信息。
IGAB
模块通过将光照引导机制与自注意力机制相结合,能够有效处理低光照图像中的光照不均问题,同时保留图像的细节信息。这种设计让RetinexFormer网络在低光照图像增强任务中表现出色。
小结
RetinexFormer
类通过堆叠多个 RetinexFormer_Single_Stage
模块,结合光照估计和去噪处理,实现了低光照图像的增强。每个RetinexFormer_Single_Stage
模块包含光照估计器和去噪器,光照估计器用于估计图像的光照图,去噪器采用编码器 - 解码器架构对增强后的图像进行去噪处理。
5、总结
本文是Retinexformer这篇暗光增强文章的代码解读。该文章结合了 Retinex 理论与 Transformer 架构,采用单阶段 Retinex 架构和 IG_MSA 模块,能更好捕捉光照信息,且使用编码器 - 解码器架构保留图像细节。并且在 NTIRE 2024 低光照增强挑战赛中获得了第二名。
感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。