This post is a code walkthrough of the LLF-LUT technique; for an interpretation of the original paper, please see the LLF-LUT article walkthrough.
1. Overview of the Paper
The core idea of LLF-LUT is to exploit Laplacian pyramid decomposition and reconstruction: global tone adjustment is performed on the low-frequency component, while local edge details are progressively refined in the high-frequency components. The overall network pipeline is shown below:

The overall pipeline is fairly involved and consists of two stages. The first stage is coarse enhancement, which produces two outputs: the coarse enhancement of the full-resolution image, $\bar{I}$, and the enhancement of the downsampled image, $\bar{I}_{\text{LR}}$. The full-resolution coarse result is then decomposed into multiple Laplacian pyramid levels and enhanced level by level; at the bottom level, a Canny operator additionally extracts an edge map as auxiliary information. Finally, the levels are fused upward layer by layer to produce the enhanced result $\hat{I}$.
2. Code Structure
The overall code structure is shown below:

Only the model implementation matters for this walkthrough, i.e. the contents of the models subfolder.
3. Core Code Modules
The LLF_LUT.py file
The overall network structure is as follows:
```python
import torch
import torch.nn as nn
# Spatial_Transformer, PPB, Lap_Pyramid_Conv, Generator3DLUT_identity,
# Generator3DLUT_zero and TV_3D are defined elsewhere in the repository.

class LLF_LUT(nn.Module):
    def __init__(self, config):
        super(LLF_LUT, self).__init__()
        self.transformer_config = config['transformer']
        self.filter_config = config['filter']
        self.LUT_config = config['LUT']
        self.pad_size = self.filter_config['low_freq_resolution']
        self.device = torch.device('cuda' if config['gpu_ids'] is not None else 'cpu')
        # define transformer model
        self.transformer = Spatial_Transformer(
            in_chans=self.transformer_config['input_channel'],
            embed_dim=self.transformer_config['embed_dim'],
            num_classes=self.transformer_config['num_classes'],
            out_chans=self.transformer_config['output_channel'],
            depths=self.transformer_config['depths'],
            num_heads=self.transformer_config['num_heads'],
            window_sizes=self.transformer_config['window_sizes'],
            back_RBs=self.transformer_config['back_RBs'],
            recon_type=self.transformer_config['recon_type']
        )
        # define Laplacian filter
        self.laplacian_filter = PPB(self.filter_config)
        self.pyramid = Lap_Pyramid_Conv(self.filter_config['num_lap'], self.filter_config['channels'], self.device)
        # define learnable LUTs
        self.LUT0 = Generator3DLUT_identity(dim=self.LUT_config['LUT_dim'])
        self.LUT1 = Generator3DLUT_zero(dim=self.LUT_config['LUT_dim'])
        self.LUT2 = Generator3DLUT_zero(dim=self.LUT_config['LUT_dim'])
        # Load TV_loss
        self.TV3 = TV_3D(dim=self.LUT_config['LUT_dim'])
        cuda = True if config['gpu_ids'] is not None else False
        Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
        self.TV3.weight_r = self.TV3.weight_r.type(Tensor)
        self.TV3.weight_g = self.TV3.weight_g.type(Tensor)
        self.TV3.weight_b = self.TV3.weight_b.type(Tensor)

    def forward(self, input_image):
        B, C, H, W = input_image.size()
        # stage 1: Laplacian decomposition + weighted 3D LUT enhancement
        pyr_input = self.pyramid.pyramid_decom(input_image)
        pyr_input_low = pyr_input[-1]
        pred_weight, pred_weight_point = self.transformer(pyr_input_low)
        enhanced_low0 = self.LUT0(pyr_input_low)
        enhanced_low1 = self.LUT1(pyr_input_low)
        enhanced_low2 = self.LUT2(pyr_input_low)
        enhanced_full0 = self.LUT0(input_image)
        enhanced_full1 = self.LUT1(input_image)
        enhanced_full2 = self.LUT2(input_image)
        enhanced_low = pred_weight[:, :3] * enhanced_low0 + pred_weight[:, 3:6] * enhanced_low1 + \
                       pred_weight[:, 6:9] * enhanced_low2
        enhanced_full = pred_weight_point[:, 0] * enhanced_full0 + pred_weight_point[:, 1] * enhanced_full1 + \
                        pred_weight_point[:, 2] * enhanced_full2
        # remapping function (stage 2)
        gauss_enhanced_full = self.pyramid.gauss_decom(enhanced_full)
        pyr_enhanced_full = self.pyramid.pyramid_decom(enhanced_full)
        pyr_reconstruct_results = self.laplacian_filter(gauss_enhanced_full, pyr_enhanced_full, enhanced_low)
        enhanced_image = self.pyramid.pyramid_recons(pyr_reconstruct_results)
        # define smooth loss and tv loss
        weights_norm = torch.mean(pred_weight ** 2)
        tv0, mn0 = self.TV3(self.LUT0)
        tv1, mn1 = self.TV3(self.LUT1)
        tv2, mn2 = self.TV3(self.LUT2)
        tv_cons = tv0 + tv1 + tv2
        mn_cons = mn0 + mn1 + mn2
        loss_smooth = weights_norm + tv_cons if self.LUT_config['lambda_smooth'] > 0 else 0
        loss_mono = mn_cons if self.LUT_config['lambda_mono'] > 0 else 0
        loss_LUT = self.LUT_config['lambda_smooth'] * loss_smooth + self.LUT_config['lambda_mono'] * loss_mono
        return enhanced_image, pyr_reconstruct_results, loss_LUT
```
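Before walking through forward, it helps to see which configuration keys the constructor reads. Below is a hypothetical config sketch: the keys mirror the lookups in __init__, but every value is a placeholder rather than a repository default.

```python
# Hypothetical config; keys follow __init__ above, all values are placeholders.
config = {
    'gpu_ids': [0],  # or None to run on CPU
    'transformer': {
        'input_channel': 3, 'embed_dim': 96, 'num_classes': 9,
        'output_channel': 3, 'depths': [2, 2], 'num_heads': [4, 4],
        'window_sizes': [8, 8], 'back_RBs': 1, 'recon_type': 'linear',
    },
    'filter': {
        'low_freq_resolution': 64, 'num_lap': 3, 'channels': 3,
        'num_residual_blocks': 3,
    },
    'LUT': {'LUT_dim': 33, 'lambda_smooth': 1e-4, 'lambda_mono': 10.0},
}
model = LLF_LUT(config)
```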
Following the order of operations in forward: the input image first goes through pyramid_decom to build the pyramid, corresponding to

```python
pyr_input = self.pyramid.pyramid_decom(input_image)
```

which is implemented as follows:
```python
def pyramid_decom(self, img):
    current = img
    pyr = []
    for _ in range(self.num_high):
        filtered = self.conv_gauss(current, self.kernel)
        down = self.downsample(filtered)
        up = self.upsample(down)
        if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
            up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
        diff = current - up
        pyr.append(diff)
        current = down
    pyr.append(current)
    return pyr
```
As the function shows, it produces the high-frequency residual of every level plus the low-frequency image of the final level. conv_gauss performs Gaussian filtering, and downsample simply keeps the first pixel of each 2x2 block, which is equivalent to nearest-neighbor decimation.
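conv_gauss, downsample, and upsample are not quoted in this post. Here is a minimal sketch of what they plausibly look like inside Lap_Pyramid_Conv, assuming the common LPTN-style implementation (this is an assumption, not the repository code):

```python
import torch
import torch.nn as nn

def gauss_kernel(channels=3):
    # separable 5x5 binomial kernel, normalized; one copy per channel for depthwise conv
    k = torch.tensor([[1., 4., 6., 4., 1.]])
    k = (k.t() @ k) / 256.0
    return k.repeat(channels, 1, 1, 1)  # shape (C, 1, 5, 5)

def conv_gauss(img, kernel):
    # depthwise Gaussian blur with "same" padding
    return nn.functional.conv2d(img, kernel, padding=2, groups=img.shape[1])

def downsample(x):
    # keep the first pixel of every 2x2 block: nearest-style decimation
    return x[:, :, ::2, ::2]

def upsample(x, kernel):
    # zero-insertion upsampling (x4 to preserve brightness), then Gaussian smoothing
    up = torch.zeros(x.size(0), x.size(1), x.size(2) * 2, x.size(3) * 2,
                     device=x.device, dtype=x.dtype)
    up[:, :, ::2, ::2] = x * 4
    return conv_gauss(up, kernel)
```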
The next step is:

```python
pred_weight, pred_weight_point = self.transformer(pyr_input_low)
```

The lowest-resolution level is fed into the transformer, matching the paper walkthrough, producing the pixel-level and image-level weights. The three LUTs are then applied and their outputs blended with these weights:
```python
enhanced_low0 = self.LUT0(pyr_input_low)
enhanced_low1 = self.LUT1(pyr_input_low)
enhanced_low2 = self.LUT2(pyr_input_low)
enhanced_full0 = self.LUT0(input_image)
enhanced_full1 = self.LUT1(input_image)
enhanced_full2 = self.LUT2(input_image)
enhanced_low = pred_weight[:, :3] * enhanced_low0 + pred_weight[:, 3:6] * enhanced_low1 + \
               pred_weight[:, 6:9] * enhanced_low2
enhanced_full = pred_weight_point[:, 0] * enhanced_full0 + pred_weight_point[:, 1] * enhanced_full1 + \
                pred_weight_point[:, 2] * enhanced_full2
```
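For intuition, here is a minimal shape sketch of the weighted fusion above. The tensor shapes are assumptions inferred from the indexing (slicing pred_weight in groups of 3 channels suggests a pixel-level (B, 9, h, w) map), not values taken from the repository:

```python
import torch

B, h, w = 2, 64, 64
lut_outs = [torch.rand(B, 3, h, w) for _ in range(3)]  # stand-ins for LUT0/1/2 outputs
pred_weight = torch.rand(B, 9, h, w)                   # assumed pixel-level weight map
enhanced_low = sum(pred_weight[:, 3 * i:3 * (i + 1)] * o
                   for i, o in enumerate(lut_outs))
print(enhanced_low.shape)  # torch.Size([2, 3, 64, 64])
```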
This completes the stage-one coarse enhancement; what follows is the stage-two Laplacian enhancement.

```python
# remapping function
gauss_enhanced_full = self.pyramid.gauss_decom(enhanced_full)
pyr_enhanced_full = self.pyramid.pyramid_decom(enhanced_full)
pyr_reconstruct_results = self.laplacian_filter(gauss_enhanced_full, pyr_enhanced_full, enhanced_low)
# pyramid reconstruction from the refined per-level details
enhanced_image = self.pyramid.pyramid_recons(pyr_reconstruct_results)
```
The inputs to laplacian_filter are: the Gaussian decomposition of the 3D-LUT-enhanced image, the Laplacian pyramid of the 3D-LUT-enhanced image, and the enhanced low-frequency (downsampled) image.
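pyramid_recons itself is not quoted in this post. A minimal sketch of the standard Laplacian reconstruction it presumably performs, assuming the same ordering as pyramid_decom (high-frequency levels first, low-frequency last); this is an assumption, not the repository code:

```python
def pyramid_recons(self, pyr):
    # pyr is ordered fine-to-coarse: [h_0, ..., h_{n-1}, low]
    image = pyr[-1]
    for level in reversed(pyr[:-1]):
        up = self.upsample(image)
        if up.shape[2:] != level.shape[2:]:
            up = nn.functional.interpolate(up, size=(level.shape[2], level.shape[3]))
        image = up + level  # add back each high-frequency residual
    return image
```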
The PPB.py file
This file contains the key modules used in the stage-two enhancement.
```python
import kornia

class PPB(nn.Module):
    def __init__(self, config):
        super(PPB, self).__init__()
        num_residual_blocks = config['num_residual_blocks']
        num_lap = config['num_lap']
        self.block = HFBlock(num_residual_blocks, num_lap)

    def forward(self, gauss_input, pyr_input, enhanced_low):
        low_freq_gray = kornia.color.rgb_to_grayscale(enhanced_low)
        edge_map = kornia.filters.canny(low_freq_gray)[1]
        low_freq_up = nn.functional.interpolate(enhanced_low, size=(pyr_input[-2].shape[2], pyr_input[-2].shape[3]))
        # gauss_input[-1] or pyr_input[-1]
        gauss_input_up = nn.functional.interpolate(pyr_input[-1],
                                                   size=(pyr_input[-2].shape[2], pyr_input[-2].shape[3]))
        edge_map_up = nn.functional.interpolate(edge_map, size=(pyr_input[-2].shape[2], pyr_input[-2].shape[3]))
        concat_imgs = torch.cat([pyr_input[-2], edge_map_up, low_freq_up, gauss_input_up], 1)
        pyr_reconstruct_results = self.block(concat_imgs, gauss_input, pyr_input, enhanced_low)
        return pyr_reconstruct_results
```
The concatenated input consists of: the second-to-last (coarsest) high-frequency level of the enhanced image, the edge map predicted from the low-frequency level resized to that resolution, the enhanced low-frequency image resized to that resolution, and the last (lowest) level resized up to that resolution. Note that kornia.filters.canny returns a (magnitude, edges) tuple, so [1] selects the edge map, and the concatenation has 3 + 1 + 3 + 3 = 10 channels, matching in_channels=10 of the first convolution in HFBlock. HFBlock implements the actual enhancement:
```python
class HFBlock(nn.Module):
    def __init__(self, num_residual_blocks, lap_layer=3):
        super(HFBlock, self).__init__()
        self.high_freq_block = None
        self.lap_layer = lap_layer
        model = [nn.Conv2d(in_channels=10, out_channels=64, kernel_size=1, padding=0, stride=1,
                           groups=1, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)]
        for _ in range(num_residual_blocks):
            # model += [ResidualBlock_1(64)]
            model += [ResidualBlock(64)]
        model += [nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1, padding=0, stride=1,
                            groups=1, bias=True)]
        self.model = nn.Sequential(*model)
        self.high_freq_blocks = nn.ModuleList()
        for i in range(lap_layer - 1):
            high_freq_block = nn.Sequential(
                nn.Conv2d(in_channels=9, out_channels=16, kernel_size=1, padding=0, stride=1,
                          groups=1, bias=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1,
                          groups=16, bias=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(in_channels=16, out_channels=2, kernel_size=1, padding=0, stride=1,
                          groups=1, bias=True))
            self.high_freq_blocks.append(high_freq_block)
            # setattr(self, 'high_freq_block_{}'.format(str(i)), high_freq_block)
            # self.add_module('high_freq_block_{}'.format(str(i)), high_freq_block)

    def forward(self, high_with_low, gauss_lut, pyr_lut, enhanced_low):
        pyr_reconstruct_list = []
        # predict fact and sigma for the coarsest high-frequency level
        fact_sigma = self.model(high_with_low)
        fact = fact_sigma[:, 0, :, :]
        fact = fact.unsqueeze(1)
        sigma = fact_sigma[:, 1, :, :]
        sigma = sigma.unsqueeze(1)
        pyr_reconstruct_ori = remapping(gauss_lut[-2], pyr_lut[-2], sigma, fact, 10)
        pyr_reconstruct = pyr_reconstruct_ori
        up_enhanced = enhanced_low
        for i in range(self.lap_layer - 1):
            up = nn.functional.interpolate(up_enhanced, size=(pyr_lut[-2 - i].shape[2], pyr_lut[-2 - i].shape[3]))
            up_enhanced = up + pyr_reconstruct
            up_enhanced = nn.functional.interpolate(up_enhanced,
                                                    size=(pyr_lut[-3 - i].shape[2], pyr_lut[-3 - i].shape[3]))
            pyr_reconstruct = nn.functional.interpolate(pyr_reconstruct,
                                                        size=(pyr_lut[-3 - i].shape[2], pyr_lut[-3 - i].shape[3]))
            # self.high_freq_block = getattr(self, 'high_freq_block_{}'.format(str(i)))
            concat_high = torch.cat([up_enhanced, pyr_lut[-3 - i], pyr_reconstruct], 1)
            fact_sigma = self.high_freq_blocks[i](concat_high)
            fact = fact_sigma[:, 0, :, :]
            fact = fact.unsqueeze(1)
            sigma = fact_sigma[:, 1, :, :]
            sigma = sigma.unsqueeze(1)
            pyr_reconstruct = remapping(gauss_lut[-3 - i], pyr_lut[-3 - i], sigma, fact, 10)
            setattr(self, 'pyr_reconstruct_{}'.format(str(i)), pyr_reconstruct)
        # reorder results fine-to-coarse: [h_0, ..., h_{n-1}, low]
        for i in reversed(range(self.lap_layer - 1)):
            pyr_reconstruct = getattr(self, 'pyr_reconstruct_{}'.format(str(i)))
            pyr_reconstruct_list.append(pyr_reconstruct)
        pyr_reconstruct_list.append(pyr_reconstruct_ori)
        pyr_reconstruct_list.append(enhanced_low)
        return pyr_reconstruct_list
```
The predicted fact and sigma at each level adaptively drive the remapping operation, yielding the enhanced Laplacian high-frequency layers from coarse to fine. remapping is implemented as follows:
```python
def remapping(img_gauss, img_lpr, sigma, fact, N):
    # img_gauss: low-frequency (Gaussian) level; img_lpr: high-frequency (Laplacian) level
    discretisation = torch.linspace(0, 1, N)
    discretisation_step = discretisation[1]
    for ref in discretisation:
        img_remap = fact * (img_lpr - ref) * torch.exp(
            -(img_lpr - ref) * (img_lpr - ref) * (2 * sigma * sigma))
        img_lpr = img_lpr + (torch.abs(img_gauss - ref) < discretisation_step) * img_remap * (
            1 - torch.abs(img_gauss - ref) / discretisation_step)
    return img_lpr
```
This remapping operation splits the intensity range into discretisation buckets and enhances the full pixel range within each bucket (see the small sketch after this list). The author of this post finds it illustrative to analyze three types of regions:

- Flat regions: flat regions carry no high-frequency information (img_lpr - ref ≈ 0), so the gain term goes to zero and the weights have nothing to amplify, yielding zero enhancement and keeping flat regions clean and artifact-free. fact and sigma also modulate how strongly img_lpr could be enhanced in this process.
- Edge regions: the high-frequency deviation at edges is a "useful signal" (small deviation). Gaussian weighting preserves the gain, the mask selects the relevant buckets, and the linear weight smooths the result, giving precise, natural enhancement. Because the deviation is small, the Gaussian weighting factor is close to 1 and the gain term is fully retained; fact tends to be large, so the edge gets boosted. Contributions from adjacent buckets accumulate, making edge enhancement stronger yet smooth.
- Noise regions: the high-frequency deviation of noise is an "invalid signal" (large deviation). Three mechanisms together leave noise unenhanced: strong Gaussian attenuation, mask filtering, and fact suppression. The learned fact and sigma first shrink img_remap, and a linear falloff (1 - torch.abs(img_gauss - ref) / discretisation_step) further limits the weight.
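A small standalone sketch of the bucket behavior referenced above: for a pixel whose Gaussian value is g, only the discretisation points within one step of g pass the mask, and their linear weights form a tent that sums to 1, so each pixel is smoothly blended between its two neighboring buckets (the values here are arbitrary illustrations):

```python
import torch

N = 10
refs = torch.linspace(0, 1, N)
step = refs[1]
g = torch.tensor(0.47)  # an arbitrary Gaussian (low-frequency) pixel value
mask = (torch.abs(g - refs) < step).float()
weights = mask * (1 - torch.abs(g - refs) / step)
print(weights)        # non-zero only for the two refs bracketing 0.47
print(weights.sum())  # ~1.0: adjacent buckets blend linearly
```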
4. Summary
LLF-LUT++ achieves real-time processing of 4K images at 13.50 ms per image with 717K parameters (14K fewer than LLF-LUT). On two major benchmark datasets, it outperforms existing SOTA methods both quantitatively (up to 2.64 dB PSNR gain) and qualitatively at 480p and 4K resolutions. Its design, fusing global and local enhancement, effectively addresses detail loss and poor local results.
Thanks for reading; feel free to leave a comment or send a message to discuss.
If this post helped you, please consider following the author.