【LLIE专题】 Retinexformer代码解读

本文是对Retinexformer技术的代码解读,原文解读请看Retinexformer

1、原文概要

首先基于通用Retinex理论提出了自己优化的Retinex理论,基于该理论提出了一个亮度引导的Unet类型的Transformer网络结构用于图像增强。下图为整个模型的结构示意图。
在这里插入图片描述

2、代码结构

代码整体结构如下
在这里插入图片描述
train.py是训练脚本,archs文件中是网络结构,losses文件中是损失函数。

3 、核心代码模块

archs 文件夹

archs 文件夹主要用于存放网络架构相关的代码,这些代码定义了模型的具体结构。在这个项目里,archs 文件夹包含了多个文件,下面为大家详细介绍每个文件的作用:

1. __init__.py

此文件的主要功能是自动扫描和导入 archs 文件夹下所有以 _arch.py 结尾的文件,并提供了动态实例化网络的功能。

2. RetinexFormer_arch.py

此文件定义了 RetinexFormer 网络架构,这是项目的核心网络,其具体结构如下:

  • 网络结构类:定义了网络结构的类。
class RetinexFormer(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=31, stage=3, num_blocks=[1,1,1]):
        super(RetinexFormer, self).__init__()
        self.stage = stage
        modules_body = [RetinexFormer_Single_Stage(in_channels=in_channels, out_channels=out_channels, n_feat=n_feat, level=2, num_blocks=num_blocks)
                        for _ in range(stage)]
        self.body = nn.Sequential(*modules_body)
  • 前向传播函数:定义了数据在网络中的流动过程。
    def forward(self, x):
        out = self.body(x)
        return out
3. arch_util.py

该文件提供了一些网络架构的工具函数,例如初始化权重、构建层等。

4. layers.py

该文件定义了一些网络层,像多层感知机(MLP)等。

losses 文件夹

losses 文件夹主要用于存放损失函数相关的代码,这些代码定义了训练过程中使用的损失函数。该文件夹包含了多个文件,下面为你详细介绍每个文件的作用:

1. __init__.py

此文件导入了 losses.py 中定义的损失函数,并将它们添加到 __all__ 列表里,方便其他模块导入使用。

from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss)

__all__ = [
    'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss',
]
2. loss_util.py

该文件提供了一些损失函数的工具函数。

  • reduce_loss损失reduce_loss 函数用于根据指定的 reduction 方式减少损失。
def reduce_loss(loss, reduction):
    reduction_enum = F._Reduction.get_enum(reduction)
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()
  • 加权损失weighted_loss 函数是一个装饰器,用于给损失函数添加权重和 reduction 参数。
def weighted_loss(loss_func):
    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss
    return wrapper
3. losses.py

此文件定义了多种损失函数,例如 L1 损失、MSE 损失、PSNR 损失、Charbonnier 损失等。

  • L1 损失L1Loss 类定义了 L1 损失函数。
class L1Loss(nn.Module):
    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        return self.loss_weight * l1_loss(
            pred, target, weight, reduction=self.reduction)
  • MSE 损失MSELoss 类定义了 MSE 损失函数。
class MSELoss(nn.Module):
    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        return self.loss_weight * mse_loss(
            pred, target, weight, reduction=self.reduction)

综上所述,archs 文件夹定义了网络架构,而 losses 文件夹定义了训练过程中使用的损失函数,这两个文件夹共同构成了项目的核心代码部分。

4、详细代码注释(网络结构)

RetinexFormer_arch.py 文件中的 RetinexFormer 类是本文模型,采用了基于Retinex理论的Transformer架构。以下是对 RetinexFormer 类的详细介绍,将按照模块进行分析:

1. 类定义和初始化

class RetinexFormer(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=31, stage=3, num_blocks=[1,1,1]):
        super(RetinexFormer, self).__init__()
        self.stage = stage

        modules_body = [RetinexFormer_Single_Stage(in_channels=in_channels, out_channels=out_channels, n_feat=n_feat, level=2, num_blocks=num_blocks)
                        for _ in range(stage)]
        
        self.body = nn.Sequential(*modules_body)
  • 参数说明

    • in_channels:输入图像的通道数,默认为3(RGB图像)。
    • out_channels:输出图像的通道数,默认为3(RGB图像)。
    • n_feat:特征通道数,默认为31。
    • stage:网络的阶段数,即 RetinexFormer_Single_Stage 模块的堆叠数量,默认为3。
    • num_blocks:每个 IGAB 模块中的块数,默认为 [1, 1, 1]
  • 网络结构

    • RetinexFormer 类由多个 RetinexFormer_Single_Stage 模块堆叠而成,这些模块通过 nn.Sequential 组合在一起。

2. 前向传播

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """
        out = self.body(x)

        return out
  • 功能:前向传播函数将输入图像 x 依次通过堆叠的 RetinexFormer_Single_Stage 模块,最终输出增强后的图像。

3. 子模块:RetinexFormer_Single_Stage

class RetinexFormer_Single_Stage(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=31, level=2, num_blocks=[1, 1, 1]):
        super(RetinexFormer_Single_Stage, self).__init__()
        self.estimator = Illumination_Estimator(n_feat)
        self.denoiser = Denoiser(in_dim=in_channels,out_dim=out_channels,dim=n_feat,level=level,num_blocks=num_blocks)  #### 将 Denoiser 改为 img2img
    
    def forward(self, img):
        # img:        b,c=3,h,w
        
        # illu_fea:   b,c,h,w
        # illu_map:   b,c=3,h,w

        illu_fea, illu_map = self.estimator(img)
        input_img = img * illu_map + img
        output_img = self.denoiser(input_img,illu_fea)

        return output_img
  • 参数说明

    • in_channels:输入图像的通道数,默认为3。
    • out_channels:输出图像的通道数,默认为3。
    • n_feat:特征通道数,默认为31。
    • level:编码器和解码器的层数,默认为2。
    • num_blocks:每个 IGAB 模块中的块数,默认为 [1, 1, 1]
  • 网络结构

    • RetinexFormer_Single_Stage 模块包含两个子模块:
      • Illumination_Estimator:用于估计图像的光照图。
      • Denoiser:用于对增强后的图像进行去噪处理。
  • 前向传播

    1. 通过 Illumination_Estimator 模块估计光照特征 illu_fea 和光照图 illu_map
    2. 将输入图像 img 与光照图 illu_map 相乘并加上原图像,得到增强后的输入图像 input_img
    3. 将增强后的输入图像 input_img 和光照特征 illu_fea 输入到 Denoiser 模块中进行去噪处理,得到最终的输出图像 output_img

4. 子模块:Illumination_Estimator

class Illumination_Estimator(nn.Module):
    def __init__(
            self, n_fea_middle, n_fea_in=4, n_fea_out=3):  #__init__部分是内部属性,而forward的输入才是外部输入
        super(Illumination_Estimator, self).__init__()

        self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)

        self.depth_conv = nn.Conv2d(
            n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_in)

        self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)

    def forward(self, img):
        # img:        b,c=3,h,w
        # mean_c:     b,c=1,h,w
        
        # illu_fea:   b,c,h,w
        # illu_map:   b,c=3,h,w
        
        mean_c = img.mean(dim=1).unsqueeze(1)
        # stx()
        input = torch.cat([img,mean_c], dim=1)

        x_1 = self.conv1(input)
        illu_fea = self.depth_conv(x_1)
        illu_map = self.conv2(illu_fea)
        return illu_fea, illu_map
  • 参数说明

    • n_fea_middle:中间特征通道数。
    • n_fea_in:输入特征通道数,默认为4。
    • n_fea_out:输出特征通道数,默认为3。
  • 网络结构

    • 该模块包含三个卷积层:
      • conv1:1x1卷积层,用于将输入特征映射到中间特征。
      • depth_conv:深度可分离卷积层,用于提取光照特征。
      • conv2:1x1卷积层,用于将中间特征映射到输出特征。
  • 前向传播

    1. 计算输入图像 img 的通道均值 mean_c
    2. 将输入图像 img 和通道均值 mean_c 在通道维度上拼接,得到输入特征 input
    3. 通过 conv1 卷积层得到中间特征 x_1
    4. 通过 depth_conv 卷积层得到光照特征 illu_fea
    5. 通过 conv2 卷积层得到光照图 illu_map

5. 子模块:Denoiser

class Denoiser(nn.Module):
    def __init__(self, in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4]):
        super(Denoiser, self).__init__()
        self.dim = dim
        self.level = level

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder
        self.encoder_layers = nn.ModuleList([])
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(nn.ModuleList([
                IGAB(
                    dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
            ]))
            dim_level *= 2

        # Bottleneck
        self.bottleneck = IGAB(
            dim=dim_level, dim_head=dim, heads=dim_level // dim, num_blocks=num_blocks[-1])

        # Decoder
        self.decoder_layers = nn.ModuleList([])
        for i in range(level):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_level, dim_level // 2, stride=2,
                                   kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_level, dim_level // 2, 1, 1, bias=False),
                IGAB(
                    dim=dim_level // 2, num_blocks=num_blocks[level - 1 - i], dim_head=dim,
                    heads=(dim_level // 2) // dim),
            ]))
            dim_level //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, illu_fea):
        """
        x:          [b,c,h,w]         x是feature, 不是image
        illu_fea:   [b,c,h,w]
        return out: [b,c,h,w]
        """

        # Embedding
        fea = self.embedding(x)

        # Encoder
        fea_encoder = []
        illu_fea_list = []
        for (IGAB, FeaDownSample, IlluFeaDownsample) in self.encoder_layers:
            fea = IGAB(fea,illu_fea)  # bchw
            illu_fea_list.append(illu_fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)
            illu_fea = IlluFeaDownsample(illu_fea)

        # Bottleneck
        fea = self.bottleneck(fea,illu_fea)

        # Decoder
        for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(
                torch.cat([fea, fea_encoder[self.level - 1 - i]], dim=1))
            illu_fea = illu_fea_list[self.level-1-i]
            fea = LeWinBlcok(fea,illu_fea)

        # Mapping
        out = self.mapping(fea) + x

        return out
  • 参数说明

    • in_dim:输入特征通道数,默认为3。
    • out_dim:输出特征通道数,默认为3。
    • dim:特征通道数,默认为31。
    • level:编码器和解码器的层数,默认为2。
    • num_blocks:每个 IGAB 模块中的块数,默认为 [2, 4, 4]
  • 网络结构

    • 该模块采用了编码器 - 解码器架构,包含以下部分:
      • 输入投影embedding 卷积层,用于将输入特征映射到指定维度。
      • 编码器:由多个 IGAB 模块和下采样卷积层组成,用于提取特征。
      • 瓶颈层bottleneck IGAB 模块,用于进一步提取特征。
      • 解码器:由多个上采样卷积层、融合层和 IGAB 模块组成,用于恢复特征。
      • 输出投影mapping 卷积层,用于将特征映射到输出维度。
  • 前向传播

    1. 通过 embedding 卷积层将输入特征 x 映射到指定维度。
    2. 依次通过编码器的 IGAB 模块和下采样卷积层,提取特征并下采样。
    3. 通过瓶颈层的 IGAB 模块进一步提取特征。
    4. 依次通过解码器的上采样卷积层、融合层和 IGAB 模块,恢复特征并上采样。
    5. 通过 mapping 卷积层将特征映射到输出维度,并加上输入特征 x,得到最终的输出特征。

6. 子模块:IGAB模块

IGAB 模块主要由多个 IG_MSA(光照引导多头自注意力)和前馈网络(Feed Forward)子模块交替组成。该模块的独特之处在于,它利用光照估计器生成的光照特征来指导注意力机制的计算,让模型在处理低光照图像时能够更加关注图像的光照信息。

代码实现
class IGAB(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
            num_blocks=2,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(num_blocks):
            self.blocks.append(nn.ModuleList([
                IG_MSA(dim=dim, dim_head=dim_head, heads=heads),
                PreNorm(dim, FeedForward(dim=dim))
            ]))

    def forward(self, x, illu_fea):
        """
        x: [b,c,h,w]
        illu_fea: [b,c,h,w]
        return out: [b,c,h,w]
        """
        x = x.permute(0, 2, 3, 1)
        for (attn, ff) in self.blocks:
            x = attn(x, illu_fea_trans=illu_fea.permute(0, 2, 3, 1)) + x
            x = ff(x) + x
        out = x.permute(0, 3, 1, 2)
        return out
1. 初始化参数
  • dim:输入特征的通道数。
  • dim_head:每个注意力头的维度,默认为64。
  • heads:注意力头的数量,默认为8。
  • num_blocks:堆叠的 IG_MSA 和前馈网络对的数量,默认为2。
2. 子模块
  • IG_MSA:光照引导的多头自注意力模块,它将光照特征融入到注意力计算过程中。
  • PreNormFeedForward:前馈网络模块,其中 PreNorm 是在进行前馈计算之前对输入进行层归一化操作。
3. 前向传播过程
  1. 维度调整:把输入特征 x 和光照特征 illu_fea[b,c,h,w] 调整为 [b,h,w,c],这样便于后续的注意力计算。
  2. 堆叠块处理:依次通过多个 IG_MSA 和前馈网络对,并且都使用了残差连接:
    • 先将特征输入到 IG_MSA 模块中,利用光照特征来引导注意力机制。
    • 接着将输出输入到前馈网络中进行非线性变换。
  3. 维度恢复:将处理后的特征从 [b,h,w,c] 重新调整回 [b,c,h,w]
核心组件:IG_MSA(光照引导多头自注意力)
class IG_MSA(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(dim_head * heads, dim, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim

    def forward(self, x_in, illu_fea_trans):
        """
        x_in: [b,h,w,c]         # input_feature
        illu_fea: [b,h,w,c]         # mask shift? 为什么是 b, h, w, c?
        return out: [b,h,w,c]
        """
        b, h, w, c = x_in.shape
        x = x_in.reshape(b, h * w, c)
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        illu_attn = illu_fea_trans
        q, k, v, illu_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                                 (q_inp, k_inp, v_inp, illu_attn.flatten(1, 2)))
        v = v * illu_attn
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        attn = (k @ q.transpose(-2, -1))
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)
        x = attn @ v
        x = x.permute(0, 3, 1, 2)
        x = x.reshape(b, h * w, self.num_heads * self.dim_head)
        out_c = self.proj(x).view(b, h, w, c)
        out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(
            0, 3, 1, 2)).permute(0, 2, 3, 1)
        out = out_c + out_p

        return out
IG_MSA工作原理
  1. 查询(Q)、键(K)、值(V)生成:利用线性层将输入特征映射为Q、K、V。
  2. 光照引导:把光照特征与值(V)相乘,以此来调整注意力机制对不同区域的关注程度。
  3. 注意力计算
    • 先对Q和K进行L2归一化处理,然后计算它们的点积。
    • 引入可学习的缩放因子 rescale,对注意力得分进行调整。
    • 最后通过softmax函数得到最终的注意力权重。
  4. 特征聚合:用注意力权重对V进行加权求和。
  5. 位置编码:通过深度可分离卷积为特征添加位置信息。
  6. 输出整合:将注意力输出和位置编码输出相加,得到最终结果。
设计亮点
  1. 光照引导机制:借助光照特征来调制注意力机制,使模型在处理低光照图像时能够更好地利用光照信息。
  2. 归一化注意力:对Q和K进行归一化处理,让注意力计算更加稳定。
  3. 双重残差连接:在 IGAB 模块中使用了残差连接,有助于梯度的传播,使模型更容易训练。
  4. 位置感知:通过位置编码模块保留了图像的空间信息。

IGAB 模块通过将光照引导机制与自注意力机制相结合,能够有效处理低光照图像中的光照不均问题,同时保留图像的细节信息。这种设计让RetinexFormer网络在低光照图像增强任务中表现出色。

小结

RetinexFormer 类通过堆叠多个 RetinexFormer_Single_Stage 模块,结合光照估计和去噪处理,实现了低光照图像的增强。每个RetinexFormer_Single_Stage 模块包含光照估计器和去噪器,光照估计器用于估计图像的光照图,去噪器采用编码器 - 解码器架构对增强后的图像进行去噪处理。

5、总结

本文是Retinexformer这篇暗光增强文章的代码解读。该文章结合了 Retinex 理论与 Transformer 架构,采用单阶段 Retinex 架构和 IG_MSA 模块,能更好捕捉光照信息,且使用编码器 - 解码器架构保留图像细节。并且在 NTIRE 2024 低光照增强挑战赛中获得了第二名


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值