专题介绍
Look-Up Table(查找表,LUT)是一种数据结构(也可以理解为字典),通过输入的key来查找到对应的value。其优势在于无需计算过程,不依赖于GPU、NPU等特殊硬件,本质就是一种内存换算力的思想。LUT在图像处理中是比较常见的操作,如Gamma映射,3D CLUT等。
近些年,LUT技术已被用于深度学习领域,由SR-LUT启发性地提出了模型训练+LUT推理的新范式。
本专题旨在跟进和解读LUT技术的发展趋势,为读者分享最全最新的LUT方法,欢迎一起探讨交流,对该专题感兴趣的读者可以订阅本专栏第一时间看到更新。
系列文章如下:
【1】SR-LUT
本文将从头开始对MuLUT: Cooperating Multiple Look-Up Tables
for Efficient Image Super-Resolution,这篇轻量超分算法进行讲解,本篇文章是SRLUT文章的改进版本,建议对SRLUT不熟悉的读者可以先阅读 SRLUT讲解。参考资料如下:
[1]. MuLUT论文地址
[2]. MuLUT代码地址
[3]. MuLUT论文笔记
一、研究背景
MuLUT提出的原因在于:SRLUT的性能比较弱,感受野只有3 * 3,用于超分的效果不强,也无法应用于需要更大感受野的其他图像处理任务,例如Demosaic。作者希望在保持LUT轻量化、资源消耗少的前提下,去提升网络的效果和方法的通用性,扩展到更多的任务。
二、MuLUT方法
MuLUT提出了多LUT串联和并联的方法,来提升感受野,其单一LUT流程与SRLUT基本一致,只不过因为MuLUT使用了更多的LUT,所以MuLUT提出了一个LUT finetune的策略来减小LUT量化后的误差:
2.1 训练
作者设计了下图的网络结构,图中右下角可以看到,其横向(Hierarchical Indexing)表示是一个串联的索引结构,即堆叠多个可以转换为一个LUT的卷积块,纵向(Complementary Indexing)表示是一个并联的索引结构,因为是并联,所以需要将三个不同Block(MuLUT-S、D、Y)计算的结构进行相加合并,从图中还可以看到,S、D、Y只是选取查询的点不同而已,转换后都是一个4D的LUT,串联完最后使用一个pixelshuffle进行超分即可,因为串联的结构本身是一个三个block并联的,因此最后将他们相加并进行平均得到最终结果。
接下来针对上面提到的流程,分点进行细节的讲解:
1)并联的计算:并联用到的几个block结构如下图所示。
如果我们用卷积的视角去看的话,S是一个2x2的卷积,这与SRLUT一致,D是一个2x2的dilate卷积,跳着选点,Y的实现会比较麻烦一些,这里作者给出了一个实现方法,如下图所示。
可见,这个实现方法是通过Unfold和采样来实现,博主觉得这种方法会比较麻烦且难以理解,在后续的代码讲解部分会给出自己的实现方法。然后我们也可以用以下的公式来说明这个计算过程。
V
=
(
L
U
T
S
[
I
0
]
[
I
1
]
[
I
3
]
[
I
4
]
+
L
U
T
D
[
I
0
]
[
I
2
]
[
I
6
]
[
I
8
]
+
L
U
T
Y
[
(
I
0
]
[
I
4
]
[
I
5
]
[
I
7
]
)
/
3
,
\mathbf{V}=\left(L U T_{S}\left[I_{0}\right]\left[I_{1}\right]\left[I_{3}\right]\left[I_{4}\right]+L U T_{D}\left[I_{0}\right]\left[I_{2}\right]\left[I_{6}\right]\left[I_{8}\right]+L U T_{Y}\left[\left(I_{0}\right]\left[I_{4}\right]\left[I_{5}\right]\left[I_{7}\right]\right) / 3,\right.
V=(LUTS[I0][I1][I3][I4]+LUTD[I0][I2][I6][I8]+LUTY[(I0][I4][I5][I7])/3,
可以看到就是S、D、Y三者相加后求平均,这个平均过程作者也进行了一些修改,在下面会进行讲解。
2)相加平均的处理:作者这里使用了一个叫Re-index的策略,因为LUT需要我们的网络始终是8bit的状态,多个block的结果相加后平均才可以回到8bit,但是这里会存在浮点数的情况,为了解决这个问题,作者使用了Re-index,实际是比较简单的,如下图所示,作者会将计算后的浮点数,进行一个量化,但是保存浮点数的梯度,此被称为STE(Straight-through estimator),在量化训练中经常使用到,用这个策略让网络去适应这个量化损失。
3)串联的计算:串联的实现如下图所示。
配合上图和下式可以发现,作者是用LUT(1)的输出作为LUT(2)的索引,从而完成了串联。
V
=
L
U
T
(
2
)
[
L
U
T
(
1
)
[
I
∗
]
]
[
L
U
T
(
1
)
[
I
∗
]
]
[
L
U
T
(
1
)
[
I
∗
]
]
[
L
U
T
(
1
)
[
I
∗
]
]
.
\mathbf{V}=L U T^{(2)}\left[L U T^{(1)}\left[I_{*}\right]\right]\left[L U T^{(1)}\left[I_{*}\right]\right]\left[L U T^{(1)}\left[I_{*}\right]\right]\left[L U T^{(1)}\left[I_{*}\right]\right] .
V=LUT(2)[LUT(1)[I∗]][LUT(1)[I∗]][LUT(1)[I∗]][LUT(1)[I∗]].
串联LUT显然跟多个串联卷积是一样的,通过串联可以提升网络的感受野。
4)MuLUT感受野的计算:首先不考虑旋转,一个MuLUT-S、D、Y相加是RF=3x3,在stack一个MuLUT-S、D、Y Block,RF=5x5,然后加上3次旋转来提升感受野,就可以得到论文中给出的RF=9x9了。
2.2 转换
此步跟SRLUT一致,是将训练好的每个Block计算过程写入到一个固定的4D LUT中,前面我们知道步骤(1)中训练的网络感受野为2x2,因此针对于一个8bit的数据,4D LUT的4个输入存在256 * 256 * 256 * 256种可能性,将所有的可能性进行组合,可以得到它们相同大小的结果,同样以8bit去存储,在超分2倍的情况下,我们有2 * 2 * 256 * 256 * 256 * 256个数值,这些数值占用的存储大小可以计算得到256 * 256 * 256 * 256 * 4B=16GB,为了减小这部分大小,会对SRLUT的输入进行采样,假设我们均匀采样17个点,即只选择0、16、32…255,可以将整个LUT表的大小减小至17 * 17 * 17 * 17 * 4B=326.2KB,论文实际使用了17个点去进行存储。
而MuLUT有多个Block,我们需要针对每个Block进行转换为LUT表,然而在进行推理时会发现,LUT表是存在误差的,因为转换LUT时,我们使用了均匀采样,损失了部分精度,为了减小这部分精度的损失,作者提出了一个LUT Finetune的策略,此在下面的第三步微调中会讲解。
2.3 微调
SRLUT由于只有一个LUT,不进行微调去解决量化误差,训测模型性能差距不大,但是MuLUT存在多个LUT串并联的情况,不处理会有较大误差。作者提出的微调过程如下:
- 首先把训练好的网络进行第2步转换为LUT;
- 然后将这些LUT视为一个trainable权重,因为4-simplex插值过程是可微的,因此我们训练它,相当于第一步训练的卷积的权重;
- 第三步训练的是LUT表。
2.4 测试
利用步骤(3)微调后得到的LUT表可以进行测试,这里使用到的插值方式在SR-LUT里有讲解。
三、实验结果
定量的实验结果显示:MuLUT-SDY以及SDY-X2对比SRLUT有明显提升,当然在效果上会比传统的插值方法要更好,但是对比DNN的方法效果还是差的,当然它主打的是轻量级的超分方法。
定性的实验结果显示,该方法还是比传统插值要更好,逼近FSRCNN,初期的DNN的方法。
与多个方法进行对比性能消耗,性价比高,只用了很少的Energy Cost,不发热然后达到与DNN方法相近的效果,耗时的比较也是类似的结论,夹在传统方法和DNN之间。
接下来作者进行了消融实验:
1)并联结构:
从表中可以看到,对比了只有SRLUT只有S,MuLUT只有S和D没有Y,以及MuLUT完整体,SDY的效果是最好的。
2)re-index以及串联结构:
前两行对比了re-index策略,有re-index提升了精度,后三行对比了串联的效果,串联越多,感受野越大,效果越好。
3)LUT aware finetune策略:
对比了SRLUT4bit ft前后,有提升,3bit同理,因为量化损失大,提升更大,MULUT也是同样的结论,验证了有效性。
四、代码
同样的,实现代码分为四个部分,分别是训练、转表、微调、推理,在如下所示得sr文件夹中。
- 训练:这个步骤中,关注1_train_model.py文件,首先这个文件会从./common/option.py文件中导入模型SRNets。
class BaseOptions():
def __init__(self, debug=False):
self.initialized = False
self.debug = debug
def initialize(self, parser):
# experiment specifics
parser.add_argument('--model', type=str, default='SRNets')
parser.add_argument('--task', '-t', type=str, default='sr')
parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor")
parser.add_argument('--sigma', '-s', type=int, default=25, help="noise level")
parser.add_argument('--qf', '-q', type=int, default=20, help="deblocking quality factor")
parser.add_argument('--nf', type=int, default=64, help="number of filters of convolutional layers")
parser.add_argument('--stages', type=int, default=2, help="stages of MuLUT")
parser.add_argument('--modes', type=str, default='sdy', help="sampling modes to use in every stage")
parser.add_argument('--interval', type=int, default=4, help='N bit uniform sampling')
parser.add_argument('--modelRoot', type=str, default='../models')
parser.add_argument('--expDir', '-e', type=str, default='', help="experiment folder")
parser.add_argument('--load_from_opt_file', action='store_true', default=False)
parser.add_argument('--debug', default=False, action='store_true')
self.initialized = True
return parser
而SRNets其实就是stack了多个SRNet结构,在sr/model.py文件中,如下所示,可见这个就是论文前面提到得串联结构,这里作者会用stage去表示,可见默认设置是stage=2,与论文SDYx2也是一致的。
class SRNets(nn.Module):
""" A LUT-convertable SR network with configurable stages and patterns. """
def __init__(self, nf=64, scale=4, modes=['s', 'd', 'y'], stages=2):
super(SRNets, self).__init__()
for s in range(stages): # 2-stage
if (s + 1) == stages:
upscale = scale
flag = "N"
else:
upscale = None
flag = "1"
for mode in modes:
self.add_module("s{}_{}".format(str(s + 1), mode),
SRNet("{}x{}".format(mode.upper(), flag), nf=nf, upscale=upscale))
# print_network(self)
def forward(self, x, stage, mode):
key = "s{}_{}".format(str(stage), mode)
module = getattr(self, key)
return module(x)
之后就是SRNet的实现,在common/network.py中,如下所示。
class SRNet(nn.Module):
""" Wrapper of a generalized (spatial-wise) MuLUT block.
By specifying the unfolding patch size and pixel indices,
arbitrary sampling pattern can be implemented.
"""
def __init__(self, mode, nf=64, upscale=None, dense=True):
super(SRNet, self).__init__()
self.mode = mode
if 'x1' in mode:
assert upscale is None
if mode == 'Sx1':
self.model = MuLUTUnit('2x2', nf, upscale=1, dense=dense)
self.K = 2
self.S = 1
elif mode == 'SxN':
self.model = MuLUTUnit('2x2', nf, upscale=upscale, dense=dense)
self.K = 2
self.S = upscale
elif mode == 'Dx1':
self.model = MuLUTUnit('2x2d', nf, upscale=1, dense=dense)
self.K = 3
self.S = 1
elif mode == 'DxN':
self.model = MuLUTUnit('2x2d', nf, upscale=upscale, dense=dense)
self.K = 3
self.S = upscale
elif mode == 'Yx1':
self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense)
self.K = 3
self.S = 1
elif mode == 'YxN':
self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense)
self.K = 3
self.S = upscale
elif mode == 'Ex1':
self.model = MuLUTUnit('2x2d3', nf, upscale=1, dense=dense)
self.K = 4
self.S = 1
elif mode == 'ExN':
self.model = MuLUTUnit('2x2d3', nf, upscale=upscale, dense=dense)
self.K = 4
self.S = upscale
elif mode in ['Ox1', 'Hx1']:
self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense)
self.K = 4
self.S = 1
elif mode == ['OxN', 'HxN']:
self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense)
self.K = 4
self.S = upscale
else:
raise AttributeError
self.P = self.K - 1
def forward(self, x):
B, C, H, W = x.shape
x = F.unfold(x, self.K) # B,C*K*K,L
x = x.view(B, C, self.K * self.K, (H - self.P) * (W - self.P)) # B,C,K*K,L
x = x.permute((0, 1, 3, 2)) # B,C,L,K*K
x = x.reshape(B * C * (H - self.P) * (W - self.P),
self.K, self.K) # B*C*L,K,K
x = x.unsqueeze(1) # B*C*L,l,K,K
if 'Y' in self.mode:
x = torch.cat([x[:, :, 0, 0], x[:, :, 1, 1],
x[:, :, 1, 2], x[:, :, 2, 1]], dim=1)
x = x.unsqueeze(1).unsqueeze(1)
elif 'H' in self.mode:
x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
x[:, :, 2, 3], x[:, :, 3, 2]], dim=1)
x = x.unsqueeze(1).unsqueeze(1)
elif 'O' in self.mode:
x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
x[:, :, 1, 3], x[:, :, 3, 1]], dim=1)
x = x.unsqueeze(1).unsqueeze(1)
x = self.model(x) # B*C*L,K,K
x = x.squeeze(1)
x = x.reshape(B, C, (H - self.P) * (W - self.P), -1) # B,C,K*K,L
x = x.permute((0, 1, 3, 2)) # B,C,K*K,L
x = x.reshape(B, -1, (H - self.P) * (W - self.P)) # B,C*K*K,L
x = F.fold(x, ((H - self.P) * self.S, (W - self.P) * self.S),
self.S, stride=self.S)
return x
可以看到SRNet中最重要的模块是MuLUTUnit模块,这个也就是我们之前提到得S、D、Yblock得具体实现,如下所示。
############### MuLUT Blocks ###############
class MuLUTUnit(nn.Module):
""" Generalized (spatial-wise) MuLUT block. """
def __init__(self, mode, nf, upscale=1, out_c=1, dense=True):
super(MuLUTUnit, self).__init__()
self.act = nn.ReLU()
self.upscale = upscale
if mode == '2x2':
self.conv1 = Conv(1, nf, 2)
elif mode == '2x2d':
self.conv1 = Conv(1, nf, 2, dilation=2)
elif mode == '2x2d3':
self.conv1 = Conv(1, nf, 2, dilation=3)
elif mode == '1x4':
self.conv1 = Conv(1, nf, (1, 4))
else:
raise AttributeError
if dense:
self.conv2 = DenseConv(nf, nf)
self.conv3 = DenseConv(nf + nf * 1, nf)
self.conv4 = DenseConv(nf + nf * 2, nf)
self.conv5 = DenseConv(nf + nf * 3, nf)
self.conv6 = Conv(nf * 5, 1 * upscale * upscale, 1)
else:
self.conv2 = ActConv(nf, nf, 1)
self.conv3 = ActConv(nf, nf, 1)
self.conv4 = ActConv(nf, nf, 1)
self.conv5 = ActConv(nf, nf, 1)
self.conv6 = Conv(nf, upscale * upscale, 1)
if self.upscale > 1:
self.pixel_shuffle = nn.PixelShuffle(upscale)
def forward(self, x):
x = self.act(self.conv1(x))
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = torch.tanh(self.conv6(x))
if self.upscale > 1:
x = self.pixel_shuffle(x)
return x
可以看到,MuLUTUnit跟SRLUT得实现比较相近,只不过有两个点会有一些不同:
- 第一层conv得实现会更复杂,这个我们之前也提到了,SDY三种不同得block,取了不同得点,会带来不同得卷积核形状,或者叫计算方式。
- 最后会有一个tanh得激活函数,此是为了在stack得过程中有效控制输出得范围,试想如果没有激活函数将结果控制到[-1, 1]之间,没法保证映射得范围,也就很难做LUT得存储,常见得网络利用得是量化是不需要有这个约束。
这里,针对MuLUTUnit得第一个卷积,博主给出自己得实现代码,更加好理解和实现一些。
class SConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, mode):
super(SConv, self).__init()__
assert in_channels == out_channels
self.weight = nn.Parameter(torch.randn(out_channels, 1, kernel_size, kernel_size))
self.bias = nn.Parameter(torch.zeros(out_channel))
self.weight_mask = nn.Parameter(torch.zeros_like(self.weight), requires_grad=False)
self.stride = stride
self.padding = 0
self.dilation = 1
self.groups = out_channels
self.mode = mode
if mode == 'S':
self.weight_mask[:, :, 0, 0] = 1
self.weight_mask[:, :, 0, 1] = 1
self.weight_mask[:, :, 1, 0] = 1
self.weight_mask[:, :, 1, 1] = 1
elif mode == 'D':
self.weight_mask[:, :, 0, 0] = 1
self.weight_mask[:, :, 0, 2] = 1
self.weight_mask[:, :, 2, 0] = 1
self.weight_mask[:, :, 2, 2] = 1
elif mode == 'Y':
self.weight_mask[:, :, 0, 0] = 1
self.weight_mask[:, :, 1, 1] = 1
self.weight_mask[:, :, 1, 2] = 1
self.weight_mask[:, :, 2, 1] = 1
else:
raise NotImplementedError('cannot find')
def forward(self, x):
x = F.pad(x, (1, 1, 1, 1), mode='reflect')
return F.Conv2d(x, self.weight*self.weight_mask.detach(), self.bias, self.stride, self.padding, self.dilation, self.groups)
博主得实现是利用一个weight_mask来直接选点,这样子是比较好理解得,跟论文比较好对应起来,当然用作者得方式也是可以实现得,考虑到这个任务所需要得计算量并不大,可以这样浪费一些显存去更简单得实现。
以上就把SRNets网络得定义给解释完了,网络得前向部分如下:
def mulut_predict(model_G, x, phase="train", opt=None):
modes, stages = opt.modes, opt.stages
# Stage 1
for s in range(stages):
pred = 0
for mode in modes:
pad = mode_pad_dict[mode]
for r in [0, 1, 2, 3]:
pred += round_func(torch.rot90(model_G(F.pad(torch.rot90(x, r, [
2, 3]), (0, pad, 0, pad), mode='replicate'), stage=s + 1, mode=mode), (4 - r) % 4, [2, 3]) * 127)
if s + 1 == stages:
avg_factor, bias, norm = len(modes), 0, 1
x = round_func((pred / avg_factor) + bias)
if phase == "train":
x = x / 255.0
else:
avg_factor, bias, norm = len(modes) * 4, 127, 255.0
x = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm
return x
以上得model_G就是SRNets,stages是cascade得SRNet个数,这里是2个,modes是我们得SDY,这里是3个,r则是旋转次数,可以看到,针对于不同的stage,作者会用不同得avg_factor和bias,这是因为前面提到中间得stage会被tanh归一化到[-1,1],而最后得是预测数据得输出,这里是归一化到[0,1]得,跟hq去对应,所以说这里得一些放缩系数不一。
- 转表:这里得过程跟SRLUT是一致得,只不过需要注意得是,我们此时构建模板时,需要根据S D Y不同得kernel去设计,如下:
def get_mode_input_tensor(input_tensor, mode):
if mode == "d":
input_tensor_dil = torch.zeros(
(input_tensor.shape[0], input_tensor.shape[1], 3, 3), dtype=input_tensor.dtype).to(input_tensor.device)
input_tensor_dil[:, :, 0, 0] = input_tensor[:, :, 0, 0]
input_tensor_dil[:, :, 0, 2] = input_tensor[:, :, 0, 1]
input_tensor_dil[:, :, 2, 0] = input_tensor[:, :, 1, 0]
input_tensor_dil[:, :, 2, 2] = input_tensor[:, :, 1, 1]
input_tensor = input_tensor_dil
elif mode == "y":
input_tensor_dil = torch.zeros(
(input_tensor.shape[0], input_tensor.shape[1], 3, 3), dtype=input_tensor.dtype).to(input_tensor.device)
input_tensor_dil[:, :, 0, 0] = input_tensor[:, :, 0, 0]
input_tensor_dil[:, :, 1, 1] = input_tensor[:, :, 0, 1]
input_tensor_dil[:, :, 1, 2] = input_tensor[:, :, 1, 0]
input_tensor_dil[:, :, 2, 1] = input_tensor[:, :, 1, 1]
input_tensor = input_tensor_dil
else:
# more sampling modes can be implemented similarly
raise ValueError("Mode {} not implemented.".format(mode))
return input_tensor
不同得input_tensor,我们设置得位置是不一样得,这影响着我们计算LUT得结果,一定要跟初期设计一 一对应。
- 微调:这里前面讲到了作者是先将LUT转换为可训练得参数,然后对其进行微调,代码如下:
class MuLUT(nn.Module):
""" PyTorch version of MuLUT for LUT-aware fine-tuning. """
def __init__(self, lut_folder, stages, modes, upscale=4, interval=4):
super(MuLUT, self).__init__()
self.interval = interval
self.upscale = upscale
self.modes = modes
self.stages = stages
for s in range(stages):
stage = s + 1
scale = upscale if stage == stages else 1
for mode in modes:
lut_path = os.path.join(lut_folder,
"LUT_x{}_{}bit_int8_s{}_{}.npy".format(upscale, interval, str(stage), mode))
key = "s{}_{}".format(str(stage), mode)
lut_arr = np.load(lut_path).reshape(-1, scale * scale).astype(np.float32) / 127.0
self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))
可以看到作者先是load了一些lut,然后将他们注册(register_parameter)进了类中,这样方便后续进行训练。据此整体得推理流程如下:
def forward(self, x):
x = x * 255.0
modes, stages = self.modes, self.stages
for s in range(stages):
pred = 0
stage = s + 1
if stage == stages:
avg_factor, bias, norm = len(modes), 0, 1
scale = self.upscale
else:
avg_factor, bias, norm = len(modes) * 4, 127, 255.0
scale = 1
for mode in modes:
pad = mode_pad_dict[mode]
key = "s{}_{}".format(str(stage), mode)
weight = getattr(self, "weight_" + key)
for r in [0, 1, 2, 3]:
pred += torch.rot90(self.InterpTorchBatch(weight, scale, mode, F.pad(torch.rot90(x, r, [
2, 3]), (0, pad, 0, pad), mode='replicate'), pad), (4 - r) % 4, [2, 3])
pred = self.round_func(pred)
x = self.round_func(torch.clamp((pred / avg_factor) + bias, 0, 255))
x = x / 255.0
return x
可以看到先前是一个model_G得前向,现在只是将其修改为InterpTorchBatch查表而已,这个查表得方法在前面有讲解过,这里只需要更换我们查询得点得位置即可,本质还是一个4-simplex插值,如下所示。
if mode == "s":
# pytorch 1.5 dont support rounding_mode, use // equavilent
# https://pytorch.org/docs/1.5.0/torch.html#torch.div
img_a1 = torch.floor_divide(img_in[:, :, 0:0 + h, 0:0 + w], q).type(torch.int64)
img_b1 = torch.floor_divide(img_in[:, :, 0:0 + h, 1:1 + w], q).type(torch.int64)
img_c1 = torch.floor_divide(img_in[:, :, 1:1 + h, 0:0 + w], q).type(torch.int64)
img_d1 = torch.floor_divide(img_in[:, :, 1:1 + h, 1:1 + w], q).type(torch.int64)
# Extract LSBs
fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
fb = img_in[:, :, 0:0 + h, 1:1 + w] % q
fc = img_in[:, :, 1:1 + h, 0:0 + w] % q
fd = img_in[:, :, 1:1 + h, 1:1 + w] % q
elif mode == "d":
img_a1 = torch.floor_divide(img_in[:, :, 0:0 + h, 0:0 + w], q).type(torch.int64)
img_b1 = torch.floor_divide(img_in[:, :, 0:0 + h, 2:2 + w], q).type(torch.int64)
img_c1 = torch.floor_divide(img_in[:, :, 2:2 + h, 0:0 + w], q).type(torch.int64)
img_d1 = torch.floor_divide(img_in[:, :, 2:2 + h, 2:2 + w], q).type(torch.int64)
# Extract LSBs
fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
fb = img_in[:, :, 0:0 + h, 2:2 + w] % q
fc = img_in[:, :, 2:2 + h, 0:0 + w] % q
fd = img_in[:, :, 2:2 + h, 2:2 + w] % q
elif mode == "y":
img_a1 = torch.floor_divide(img_in[:, :, 0:0 + h, 0:0 + w], q).type(torch.int64)
img_b1 = torch.floor_divide(img_in[:, :, 1:1 + h, 1:1 + w], q).type(torch.int64)
img_c1 = torch.floor_divide(img_in[:, :, 1:1 + h, 2:2 + w], q).type(torch.int64)
img_d1 = torch.floor_divide(img_in[:, :, 2:2 + h, 1:1 + w], q).type(torch.int64)
# Extract LSBs
fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
fb = img_in[:, :, 1:1 + h, 1:1 + w] % q
fc = img_in[:, :, 1:1 + h, 2:2 + w] % q
fd = img_in[:, :, 2:2 + h, 1:1 + w] % q
else:
# more sampling modes can be implemented similarly
raise ValueError("Mode {} not implemented.".format(mode))
最后提一下前面讲到得re-index实现方法:
@staticmethod
def round_func(input):
# Backward Pass Differentiable Approximation (BPDA)
# This is equivalent to replacing round function (non-differentiable)
# with an identity function (differentiable) only when backward
forward_value = torch.round(input)
out = input.clone()
out.data = forward_value.data
return out
求导时,跳过了不可导的round算子,这样实现了一个re-index。
- 测试:测试的推理过程其实跟微调的过程是一样的了,本质上微调只是一个测试过程中的简单调整,这样减小误差。
五、总结
- MuLUT在SRLUT的基础上,利用串联和并联提升了感受野,是LUT方法扩展的关键。
- MuLUT提出了 LUT微调的方式,可以减小查表的误差,同样有效提升了LUT方法得实用性,减小了训测的误差。
- 该版本仍没有考虑通道之间的关系,同时感受野仍然太小,提升后也只有9x9,这对于追求更高质量的图像处理任务而言还不够。
感谢阅读,欢迎留言或私信,一起探讨和交流,如果对你有帮助的话,也希望可以给博主点一个关注,谢谢。