【LUT技术专题】LLFLUT代码讲解

本文是对LLFLUT技术的代码讲解,原文解读请看LLFLUT文章讲解

1、原文概要

LLFLUT的核心思路是利用拉普拉斯金字塔分解与重建,在低频分量进行全局色调调整,在高频分量逐步细化局部边缘细节。它的网络整体流程图如下所示:
在这里插入图片描述
整体流程比较复杂,分为2个阶段,第一个阶段是粗增强,会增强得到2个输出,第一个是大图的粗增强结果 I ˉ \bar{I} Iˉ以及下采样的增强结果 I ˉ LR \bar{I}_{\text{LR}} IˉLR。大图的粗增强结果会利用拉普拉斯分解出多个层,然后进行逐层的增强,底层还会使用canny算子来提取边缘,加入edge map的信息辅助,最终逐层向上融合得到增强结果 I ^ \hat{I} I^

2、代码结构

代码整体结构如下
在这里插入图片描述

主要关注模型部分的实现即可,即models子文件夹中的内容。

3 、核心代码模块

LLF_LUT.py 文件

整体网络结构如下所示:

class LLF_LUT(nn.Module):
    def __init__(self, config):
        super(LLF_LUT, self).__init__()
        self.transformer_config = config['transformer']
        self.filter_config = config['filter']
        self.LUT_config = config['LUT']
        self.pad_size = self.filter_config['low_freq_resolution']
        self.device = torch.device('cuda' if config['gpu_ids'] is not None else 'cpu')

        # define transformer model
        self.transformer = Spatial_Transformer(
            in_chans=self.transformer_config['input_channel'],
            embed_dim=self.transformer_config['embed_dim'],
            num_classes=self.transformer_config['num_classes'],
            out_chans=self.transformer_config['output_channel'],
            depths=self.transformer_config['depths'],
            num_heads=self.transformer_config['num_heads'],
            window_sizes=self.transformer_config['window_sizes'],
            back_RBs=self.transformer_config['back_RBs'],
            recon_type=self.transformer_config['recon_type']
        )

        # define Laplacian filter
        self.laplacian_filter = PPB(self.filter_config)
        self.pyramid = Lap_Pyramid_Conv(self.filter_config['num_lap'], self.filter_config['channels'], self.device)

        # define learnable LUTs
        self.LUT0 = Generator3DLUT_identity(dim=self.LUT_config['LUT_dim'])
        self.LUT1 = Generator3DLUT_zero(dim=self.LUT_config['LUT_dim'])
        self.LUT2 = Generator3DLUT_zero(dim=self.LUT_config['LUT_dim'])

        # Load TV_loss
        self.TV3 = TV_3D(dim=self.LUT_config['LUT_dim'])
        cuda = True if config['gpu_ids'] is not None else False
        Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
        self.TV3.weight_r = self.TV3.weight_r.type(Tensor)
        self.TV3.weight_g = self.TV3.weight_g.type(Tensor)
        self.TV3.weight_b = self.TV3.weight_b.type(Tensor)


    def forward(self, input_image):
        B, C, H, W = input_image.size()

        pyr_input = self.pyramid.pyramid_decom(input_image)
        pyr_input_low = pyr_input[-1]

        pred_weight, pred_weight_point = self.transformer(pyr_input_low)

        enhanced_low0 = self.LUT0(pyr_input_low)
        enhanced_low1 = self.LUT1(pyr_input_low)
        enhanced_low2 = self.LUT2(pyr_input_low)
        enhanced_full0 = self.LUT0(input_image)
        enhanced_full1 = self.LUT1(input_image)
        enhanced_full2 = self.LUT2(input_image)

        enhanced_low = pred_weight[:, :3] * enhanced_low0 + pred_weight[:, 3:6] * enhanced_low1 + \
                       pred_weight[:, 6:9] * enhanced_low2
        enhanced_full = pred_weight_point[:, 0] * enhanced_full0 + pred_weight_point[:, 1] * enhanced_full1 + \
                        pred_weight_point[:, 2] * enhanced_full2

        # remapping function  
        gauss_enhanced_full = self.pyramid.gauss_decom(enhanced_full)
        pyr_enhanced_full = self.pyramid.pyramid_decom(enhanced_full)
        pyr_reconstruct_results = self.laplacian_filter(gauss_enhanced_full, pyr_enhanced_full, enhanced_low)
        enhanced_image = self.pyramid.pyramid_recons(pyr_reconstruct_results)

        # define smooth loss and tv loss
        weights_norm = torch.mean(pred_weight ** 2)
        tv0, mn0 = self.TV3(self.LUT0)
        tv1, mn1 = self.TV3(self.LUT1)
        tv2, mn2 = self.TV3(self.LUT2)
        tv_cons = tv0 + tv1 + tv2
        mn_cons = mn0 + mn1 + mn2

        loss_smooth = weights_norm + tv_cons if self.LUT_config['lambda_smooth'] > 0 else 0
        loss_mono = mn_cons if self.LUT_config['lambda_mono'] > 0 else 0
        loss_LUT = self.LUT_config['lambda_smooth'] * loss_smooth + self.LUT_config['lambda_mono'] * loss_mono

        return enhanced_image, pyr_reconstruct_results, loss_LUT

根据forward的推理顺序来讲,首先输入图像经过pyramid_decom操作得到金字塔,对应于

pyr_input = self.pyramid.pyramid_decom(input_image)

具体操作如下所示:

 def pyramid_decom(self, img):
     current = img
     pyr = []
     for _ in range(self.num_high):
         filtered = self.conv_gauss(current, self.kernel)
         down = self.downsample(filtered)
         up = self.upsample(down)
         if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
             up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
         diff = current - up
         pyr.append(diff)
         current = down
     pyr.append(current)
     return pyr

可见经过这个函数可生成最后一层的低频和每一层的高频信息,conv_gauss是高斯滤波过程,downsample的方法直接选的第一个点,相当于nearest。
后续计算为:

pred_weight, pred_weight_point = self.transformer(pyr_input_low)

将底层分辨率拿过来送入transformer,这与讲解过程匹配,得到pixel-level和image-level的weight。
后续使用LUT进行增强并利用weight进行加权。

enhanced_low0 = self.LUT0(pyr_input_low)
enhanced_low1 = self.LUT1(pyr_input_low)
enhanced_low2 = self.LUT2(pyr_input_low)
enhanced_full0 = self.LUT0(input_image)
enhanced_full1 = self.LUT1(input_image)
enhanced_full2 = self.LUT2(input_image)
enhanced_low = pred_weight[:, :3] * enhanced_low0 + pred_weight[:, 3:6] * enhanced_low1 + \
               pred_weight[:, 6:9] * enhanced_low2
enhanced_full = pred_weight_point[:, 0] * enhanced_full0 + pred_weight_point[:, 1] * enhanced_full1 + \
                pred_weight_point[:, 2] * enhanced_full2

自此完成了一阶段的粗增强,后续是二阶段的拉普拉斯增强。

        # remapping function  
        gauss_enhanced_full = self.pyramid.gauss_decom(enhanced_full)
        pyr_enhanced_full = self.pyramid.pyramid_decom(enhanced_full)

        pyr_reconstruct_results = self.laplacian_filter(gauss_enhanced_full, pyr_enhanced_full, enhanced_low)
        # 金字塔重建 通过之前的一些细节
        enhanced_image = self.pyramid.pyramid_recons(pyr_reconstruct_results)

laplacian_filter的输入有通过3DLUT增强的图的分解 + 通过3DLUT增强的图的特征金字塔 + 底层被增强的图。

PPB.py 文件

该文件中放着二阶段增强中使用到的关键模块。

class PPB(nn.Module):
    def __init__(self, config):
        super(PPB, self).__init__()
        num_residual_blocks = config['num_residual_blocks']
        num_lap = config['num_lap']
        self.block = HFBlock(num_residual_blocks, num_lap)

    def forward(self, gauss_input, pyr_input, enhanced_low):
        low_freq_gray = kornia.color.rgb_to_grayscale(enhanced_low)
        edge_map = kornia.filters.canny(low_freq_gray)[1]
        low_freq_up = nn.functional.interpolate(enhanced_low, size=(pyr_input[-2].shape[2], pyr_input[-2].shape[3]))
        # gauss_input[-1] or pyr_input[-1]
        gauss_input_up = nn.functional.interpolate(pyr_input[-1],
                                                   size=(pyr_input[-2].shape[2], pyr_input[-2].shape[3]))
        edge_map_up = nn.functional.interpolate(edge_map, size=(pyr_input[-2].shape[2], pyr_input[-2].shape[3]))

        concat_imgs = torch.cat([pyr_input[-2], edge_map_up, low_freq_up, gauss_input_up], 1)
        pyr_reconstruct_results = self.block(concat_imgs, gauss_input, pyr_input, enhanced_low)
        return pyr_reconstruct_results

可见首先使用到了增强图的倒数第二层细节+底层预测的边缘resize到第二层的+底层增强图resize到第二层+倒数第一层resize下来的层。HFBlock是增强的具体过程:

class HFBlock(nn.Module):
    def __init__(self, num_residual_blocks, lap_layer=3):
        super(HFBlock, self).__init__()
        self.high_freq_block = None
        self.lap_layer = lap_layer

        model = [nn.Conv2d(in_channels=10, out_channels=64, kernel_size=1, padding=0, stride=1,
                           groups=1,
                           bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)]

        for _ in range(num_residual_blocks):
            # model += [ResidualBlock_1(64)]
            model += [ResidualBlock(64)]

        model += [nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1, padding=0, stride=1,
                            groups=1,
                            bias=True)]
        self.model = nn.Sequential(*model)

        self.high_freq_blocks = nn.ModuleList()
        for i in range(lap_layer - 1):
            high_freq_block = nn.Sequential(
                nn.Conv2d(in_channels=9, out_channels=16, kernel_size=1, padding=0, stride=1,
                          groups=1,
                          bias=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1,
                          groups=16,
                          bias=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(in_channels=16, out_channels=2, kernel_size=1, padding=0, stride=1,
                          groups=1,
                          bias=True))
            self.high_freq_blocks.append(high_freq_block)
            # setattr(self, 'high_freq_block_{}'.format(str(i)), high_freq_block)
            # self.add_module('high_freq_block_{}'.format(str(i)), high_freq_block)

    def forward(self, high_with_low, gauss_lut, pyr_lut, enhanced_low):
        pyr_reconstruct_list = []
        fact_sigma = self.model(high_with_low)
        fact = fact_sigma[:, 0, :, :]
        fact = fact.unsqueeze(1)
        sigma = fact_sigma[:, 1, :, :]
        sigma = sigma.unsqueeze(1)
        
        pyr_reconstruct_ori = remapping(gauss_lut[-2], pyr_lut[-2], sigma, fact, 10)

        pyr_reconstruct = pyr_reconstruct_ori
        up_enhanced = enhanced_low

        for i in range(self.lap_layer - 1):
            up = nn.functional.interpolate(up_enhanced, size=(pyr_lut[-2 - i].shape[2], pyr_lut[-2 - i].shape[3]))
            up_enhanced = up + pyr_reconstruct
            up_enhanced = nn.functional.interpolate(up_enhanced,
                                                    size=(pyr_lut[-3 - i].shape[2], pyr_lut[-3 - i].shape[3]))
            pyr_reconstruct = nn.functional.interpolate(pyr_reconstruct,
                                                        size=(pyr_lut[-3 - i].shape[2], pyr_lut[-3 - i].shape[3]))
            # self.high_freq_block = getattr(self, 'high_freq_block_{}'.format(str(i)))
            concat_high = torch.cat([up_enhanced, pyr_lut[-3 - i], pyr_reconstruct], 1)
            fact_sigma = self.high_freq_blocks[i](concat_high)
            fact = fact_sigma[:, 0, :, :]
            fact = fact.unsqueeze(1)
            sigma = fact_sigma[:, 1, :, :]
            sigma = sigma.unsqueeze(1)
            pyr_reconstruct = remapping(gauss_lut[-3 - i], pyr_lut[-3 - i], sigma, fact, 10)

            setattr(self, 'pyr_reconstruct_{}'.format(str(i)), pyr_reconstruct)

        for i in reversed(range(self.lap_layer - 1)):
            pyr_reconstruct = getattr(self, 'pyr_reconstruct_{}'.format(str(i)))
            pyr_reconstruct_list.append(pyr_reconstruct)

        pyr_reconstruct_list.append(pyr_reconstruct_ori)
        pyr_reconstruct_list.append(enhanced_low)
        return pyr_reconstruct_list

通过推理出的各个fact和sigma自适应完成remapping操作,得到增强后的拉普拉斯高频。remapping实现如下:

def remapping(img_gauss, img_lpr, sigma, fact, N):
    # 低频 高频  
    discretisation = torch.linspace(0, 1, N)
    discretisation_step = discretisation[1]
    for ref in discretisation:
        img_remap = fact * (img_lpr - ref) * torch.exp(
            -(img_lpr - ref) * (img_lpr - ref) * (2 * sigma * sigma))
        img_lpr = img_lpr + (torch.abs(img_gauss - ref) < discretisation_step) * img_remap * (
                1 - torch.abs(img_gauss - ref) / discretisation_step)
    return img_lpr

这个remapping操作分为了discretisation这些桶,在桶里完成全像素区间的一个增强,这里博主认为可以分为3种类型去形象分析。

  1. 平坦区:平坦区无高频信息(img_lpr - ref ≈ 0),函数通过 “增益项归零 + 权重不影响” 实现零增强,保证平坦区干净无伪影。同时fact和sigma也会在这个过程中去控制img_lpr的增强。
  2. 边缘区:边缘区的高频偏差是 “有用信号”(小偏差),函数通过 “高斯加权保留增益 + 掩码覆盖 + 线性权重平滑” 实现精准、自然的增强。边缘的 delta 小,高斯加权因子≈1,增益项完整保留;fact会比较高,这样通过增强去提升它。多个相邻桶的增强叠加,边缘增强效果更显著且平滑。
  3. 噪声区:噪声区的高频偏差是 “无效信号”(大偏差),函数通过 “高斯强衰减 + 掩码过滤 + fact 抑制” 三重机制,实现噪声零增强。通过设计的fact、sigma第一步将remap进行缩小,然后还存在一个线性衰减来控制权重(1 - torch.abs(img_gauss - ref) / discretisation_step)。

3、总结

LLF-LUT++ 实现4K 图像 13.50ms / 张实时处理,参数 717K(比 LLF-LUT 少 14K);在两大基准数据集上,480p 和 4K 分辨率的定量(PSNR 最高提升 2.64dB)与定性表现均优于现有 SOTA 方法;它融合全局与局部增强的设计,有效解决细节丢失和局部效果差的问题。


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值