【LUT技术专题】MuLUT: Cooperating Multiple Look-Up Tables for Efficient Image Super-Resolution

专题介绍

Look-Up Table(查找表,LUT)是一种数据结构(也可以理解为字典),通过输入的key来查找到对应的value。其优势在于无需计算过程,不依赖于GPU、NPU等特殊硬件,本质就是一种内存换算力的思想。LUT在图像处理中是比较常见的操作,如Gamma映射,3D CLUT等。

近些年,LUT技术已被用于深度学习领域,由SR-LUT启发性地提出了模型训练+LUT推理的新范式。
本专题旨在跟进和解读LUT技术的发展趋势,为读者分享最全最新的LUT方法,欢迎一起探讨交流,对该专题感兴趣的读者可以订阅本专栏第一时间看到更新。

系列文章如下:
【1】SR-LUT



本文将从头开始对MuLUT: Cooperating Multiple Look-Up Tables
for Efficient Image Super-Resolution,这篇轻量超分算法进行讲解,本篇文章是SRLUT文章的改进版本,建议对SRLUT不熟悉的读者可以先阅读 SRLUT讲解。参考资料如下:
[1]. MuLUT论文地址
[2]. MuLUT代码地址
[3]. MuLUT论文笔记

一、研究背景

MuLUT提出的原因在于:SRLUT的性能比较弱,感受野只有3 * 3,用于超分的效果不强,也无法应用于需要更大感受野的其他图像处理任务,例如Demosaic。作者希望在保持LUT轻量化、资源消耗少的前提下,去提升网络的效果和方法的通用性,扩展到更多的任务。

二、MuLUT方法

MuLUT提出了多LUT串联和并联的方法,来提升感受野,其单一LUT流程与SRLUT基本一致,只不过因为MuLUT使用了更多的LUT,所以MuLUT提出了一个LUT finetune的策略来减小LUT量化后的误差:

2.1 训练

作者设计了下图的网络结构,图中右下角可以看到,其横向(Hierarchical Indexing)表示是一个串联的索引结构,即堆叠多个可以转换为一个LUT的卷积块,纵向(Complementary Indexing)表示是一个并联的索引结构,因为是并联,所以需要将三个不同Block(MuLUT-S、D、Y)计算的结构进行相加合并,从图中还可以看到,S、D、Y只是选取查询的点不同而已,转换后都是一个4D的LUT,串联完最后使用一个pixelshuffle进行超分即可,因为串联的结构本身是一个三个block并联的,因此最后将他们相加并进行平均得到最终结果。
在这里插入图片描述

接下来针对上面提到的流程,分点进行细节的讲解:

1)并联的计算:并联用到的几个block结构如下图所示。
在这里插入图片描述
如果我们用卷积的视角去看的话,S是一个2x2的卷积,这与SRLUT一致,D是一个2x2的dilate卷积,跳着选点,Y的实现会比较麻烦一些,这里作者给出了一个实现方法,如下图所示。
在这里插入图片描述
可见,这个实现方法是通过Unfold和采样来实现,博主觉得这种方法会比较麻烦且难以理解,在后续的代码讲解部分会给出自己的实现方法。然后我们也可以用以下的公式来说明这个计算过程。
V = ( L U T S [ I 0 ] [ I 1 ] [ I 3 ] [ I 4 ] + L U T D [ I 0 ] [ I 2 ] [ I 6 ] [ I 8 ] + L U T Y [ ( I 0 ] [ I 4 ] [ I 5 ] [ I 7 ] ) / 3 , \mathbf{V}=\left(L U T_{S}\left[I_{0}\right]\left[I_{1}\right]\left[I_{3}\right]\left[I_{4}\right]+L U T_{D}\left[I_{0}\right]\left[I_{2}\right]\left[I_{6}\right]\left[I_{8}\right]+L U T_{Y}\left[\left(I_{0}\right]\left[I_{4}\right]\left[I_{5}\right]\left[I_{7}\right]\right) / 3,\right. V=(LUTS[I0][I1][I3][I4]+LUTD[I0][I2][I6][I8]+LUTY[(I0][I4][I5][I7])/3,
可以看到就是S、D、Y三者相加后求平均,这个平均过程作者也进行了一些修改,在下面会进行讲解。

2)相加平均的处理:作者这里使用了一个叫Re-index的策略,因为LUT需要我们的网络始终是8bit的状态,多个block的结果相加后平均才可以回到8bit,但是这里会存在浮点数的情况,为了解决这个问题,作者使用了Re-index,实际是比较简单的,如下图所示,作者会将计算后的浮点数,进行一个量化,但是保存浮点数的梯度,此被称为STE(Straight-through estimator),在量化训练中经常使用到,用这个策略让网络去适应这个量化损失。
在这里插入图片描述

3)串联的计算:串联的实现如下图所示。
在这里插入图片描述
配合上图和下式可以发现,作者是用LUT(1)的输出作为LUT(2)的索引,从而完成了串联。
V = L U T ( 2 ) [ L U T ( 1 ) [ I ∗ ] ] [ L U T ( 1 ) [ I ∗ ] ] [ L U T ( 1 ) [ I ∗ ] ] [ L U T ( 1 ) [ I ∗ ] ] . \mathbf{V}=L U T^{(2)}\left[L U T^{(1)}\left[I_{*}\right]\right]\left[L U T^{(1)}\left[I_{*}\right]\right]\left[L U T^{(1)}\left[I_{*}\right]\right]\left[L U T^{(1)}\left[I_{*}\right]\right] . V=LUT(2)[LUT(1)[I]][LUT(1)[I]][LUT(1)[I]][LUT(1)[I]].
串联LUT显然跟多个串联卷积是一样的,通过串联可以提升网络的感受野。

4)MuLUT感受野的计算:首先不考虑旋转,一个MuLUT-S、D、Y相加是RF=3x3,在stack一个MuLUT-S、D、Y Block,RF=5x5,然后加上3次旋转来提升感受野,就可以得到论文中给出的RF=9x9了。

2.2 转换

此步跟SRLUT一致,是将训练好的每个Block计算过程写入到一个固定的4D LUT中,前面我们知道步骤(1)中训练的网络感受野为2x2,因此针对于一个8bit的数据,4D LUT的4个输入存在256 * 256 * 256 * 256种可能性,将所有的可能性进行组合,可以得到它们相同大小的结果,同样以8bit去存储,在超分2倍的情况下,我们有2 * 2 * 256 * 256 * 256 * 256个数值,这些数值占用的存储大小可以计算得到256 * 256 * 256 * 256 * 4B=16GB,为了减小这部分大小,会对SRLUT的输入进行采样,假设我们均匀采样17个点,即只选择0、16、32…255,可以将整个LUT表的大小减小至17 * 17 * 17 * 17 * 4B=326.2KB,论文实际使用了17个点去进行存储。

而MuLUT有多个Block,我们需要针对每个Block进行转换为LUT表,然而在进行推理时会发现,LUT表是存在误差的,因为转换LUT时,我们使用了均匀采样,损失了部分精度,为了减小这部分精度的损失,作者提出了一个LUT Finetune的策略,此在下面的第三步微调中会讲解。

2.3 微调

SRLUT由于只有一个LUT,不进行微调去解决量化误差,训测模型性能差距不大,但是MuLUT存在多个LUT串并联的情况,不处理会有较大误差。作者提出的微调过程如下:

  • 首先把训练好的网络进行第2步转换为LUT;
  • 然后将这些LUT视为一个trainable权重,因为4-simplex插值过程是可微的,因此我们训练它,相当于第一步训练的卷积的权重;
  • 第三步训练的是LUT表。

2.4 测试

利用步骤(3)微调后得到的LUT表可以进行测试,这里使用到的插值方式在SR-LUT里有讲解。

三、实验结果

在这里插入图片描述

定量的实验结果显示:MuLUT-SDY以及SDY-X2对比SRLUT有明显提升,当然在效果上会比传统的插值方法要更好,但是对比DNN的方法效果还是差的,当然它主打的是轻量级的超分方法。

在这里插入图片描述

定性的实验结果显示,该方法还是比传统插值要更好,逼近FSRCNN,初期的DNN的方法。

在这里插入图片描述

在这里插入图片描述

与多个方法进行对比性能消耗,性价比高,只用了很少的Energy Cost,不发热然后达到与DNN方法相近的效果,耗时的比较也是类似的结论,夹在传统方法和DNN之间。

接下来作者进行了消融实验
1)并联结构:
在这里插入图片描述

从表中可以看到,对比了只有SRLUT只有S,MuLUT只有S和D没有Y,以及MuLUT完整体,SDY的效果是最好的。

2)re-index以及串联结构:
在这里插入图片描述

前两行对比了re-index策略,有re-index提升了精度,后三行对比了串联的效果,串联越多,感受野越大,效果越好。

3)LUT aware finetune策略:
在这里插入图片描述

对比了SRLUT4bit ft前后,有提升,3bit同理,因为量化损失大,提升更大,MULUT也是同样的结论,验证了有效性。

四、代码

同样的,实现代码分为四个部分,分别是训练、转表、微调、推理,在如下所示得sr文件夹中。
在这里插入图片描述

  1. 训练:这个步骤中,关注1_train_model.py文件,首先这个文件会从./common/option.py文件中导入模型SRNets。
class BaseOptions():
    def __init__(self, debug=False):
        self.initialized = False
        self.debug = debug

    def initialize(self, parser):
        # experiment specifics
        parser.add_argument('--model', type=str, default='SRNets')
        parser.add_argument('--task', '-t', type=str, default='sr')
        parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor")
        parser.add_argument('--sigma', '-s', type=int, default=25, help="noise level")
        parser.add_argument('--qf', '-q', type=int, default=20, help="deblocking quality factor")
        parser.add_argument('--nf', type=int, default=64, help="number of filters of convolutional layers")
        parser.add_argument('--stages', type=int, default=2, help="stages of MuLUT")
        parser.add_argument('--modes', type=str, default='sdy', help="sampling modes to use in every stage")
        parser.add_argument('--interval', type=int, default=4, help='N bit uniform sampling')
        parser.add_argument('--modelRoot', type=str, default='../models')

        parser.add_argument('--expDir', '-e', type=str, default='', help="experiment folder")
        parser.add_argument('--load_from_opt_file', action='store_true', default=False)

        parser.add_argument('--debug', default=False, action='store_true')

        self.initialized = True
        return parser

而SRNets其实就是stack了多个SRNet结构,在sr/model.py文件中,如下所示,可见这个就是论文前面提到得串联结构,这里作者会用stage去表示,可见默认设置是stage=2,与论文SDYx2也是一致的。

class SRNets(nn.Module):
    """ A LUT-convertable SR network with configurable stages and patterns. """

    def __init__(self, nf=64, scale=4, modes=['s', 'd', 'y'], stages=2):
        super(SRNets, self).__init__()

        for s in range(stages):  # 2-stage
            if (s + 1) == stages:
                upscale = scale
                flag = "N"
            else:
                upscale = None
                flag = "1"
            for mode in modes:
                self.add_module("s{}_{}".format(str(s + 1), mode),
                                SRNet("{}x{}".format(mode.upper(), flag), nf=nf, upscale=upscale))
        # print_network(self)

    def forward(self, x, stage, mode):
        key = "s{}_{}".format(str(stage), mode)
        module = getattr(self, key)
        return module(x)

之后就是SRNet的实现,在common/network.py中,如下所示。

class SRNet(nn.Module):
    """ Wrapper of a generalized (spatial-wise) MuLUT block. 
        By specifying the unfolding patch size and pixel indices,
        arbitrary sampling pattern can be implemented.
    """

    def __init__(self, mode, nf=64, upscale=None, dense=True):
        super(SRNet, self).__init__()
        self.mode = mode

        if 'x1' in mode:
            assert upscale is None
        if mode == 'Sx1':
            self.model = MuLUTUnit('2x2', nf, upscale=1, dense=dense)
            self.K = 2
            self.S = 1
        elif mode == 'SxN':
            self.model = MuLUTUnit('2x2', nf, upscale=upscale, dense=dense)
            self.K = 2
            self.S = upscale
        elif mode == 'Dx1':
            self.model = MuLUTUnit('2x2d', nf, upscale=1, dense=dense)
            self.K = 3
            self.S = 1
        elif mode == 'DxN':
            self.model = MuLUTUnit('2x2d', nf, upscale=upscale, dense=dense)
            self.K = 3
            self.S = upscale
        elif mode == 'Yx1':
            self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense)
            self.K = 3
            self.S = 1
        elif mode == 'YxN':
            self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense)
            self.K = 3
            self.S = upscale
        elif mode == 'Ex1':
            self.model = MuLUTUnit('2x2d3', nf, upscale=1, dense=dense)
            self.K = 4
            self.S = 1
        elif mode == 'ExN':
            self.model = MuLUTUnit('2x2d3', nf, upscale=upscale, dense=dense)
            self.K = 4
            self.S = upscale
        elif mode in ['Ox1', 'Hx1']:
            self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense)
            self.K = 4
            self.S = 1
        elif mode == ['OxN', 'HxN']:
            self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense)
            self.K = 4
            self.S = upscale
        else:
            raise AttributeError
        self.P = self.K - 1

    def forward(self, x):
        B, C, H, W = x.shape
        x = F.unfold(x, self.K)  # B,C*K*K,L
        x = x.view(B, C, self.K * self.K, (H - self.P) * (W - self.P))  # B,C,K*K,L
        x = x.permute((0, 1, 3, 2))  # B,C,L,K*K
        x = x.reshape(B * C * (H - self.P) * (W - self.P),
                      self.K, self.K)  # B*C*L,K,K
        x = x.unsqueeze(1)  # B*C*L,l,K,K

        if 'Y' in self.mode:
            x = torch.cat([x[:, :, 0, 0], x[:, :, 1, 1],
                           x[:, :, 1, 2], x[:, :, 2, 1]], dim=1)

            x = x.unsqueeze(1).unsqueeze(1)
        elif 'H' in self.mode:
            x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
                           x[:, :, 2, 3], x[:, :, 3, 2]], dim=1)

            x = x.unsqueeze(1).unsqueeze(1)
        elif 'O' in self.mode:
            x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
                           x[:, :, 1, 3], x[:, :, 3, 1]], dim=1)

            x = x.unsqueeze(1).unsqueeze(1)

        x = self.model(x)   # B*C*L,K,K
        x = x.squeeze(1)
        x = x.reshape(B, C, (H - self.P) * (W - self.P), -1)  # B,C,K*K,L
        x = x.permute((0, 1, 3, 2))  # B,C,K*K,L
        x = x.reshape(B, -1, (H - self.P) * (W - self.P))  # B,C*K*K,L
        x = F.fold(x, ((H - self.P) * self.S, (W - self.P) * self.S),
                   self.S, stride=self.S)
        return x

可以看到SRNet中最重要的模块是MuLUTUnit模块,这个也就是我们之前提到得S、D、Yblock得具体实现,如下所示。

############### MuLUT Blocks ###############
class MuLUTUnit(nn.Module):
    """ Generalized (spatial-wise)  MuLUT block. """

    def __init__(self, mode, nf, upscale=1, out_c=1, dense=True):
        super(MuLUTUnit, self).__init__()
        self.act = nn.ReLU()
        self.upscale = upscale

        if mode == '2x2':
            self.conv1 = Conv(1, nf, 2)
        elif mode == '2x2d':
            self.conv1 = Conv(1, nf, 2, dilation=2)
        elif mode == '2x2d3':
            self.conv1 = Conv(1, nf, 2, dilation=3)
        elif mode == '1x4':
            self.conv1 = Conv(1, nf, (1, 4))
        else:
            raise AttributeError

        if dense:
            self.conv2 = DenseConv(nf, nf)
            self.conv3 = DenseConv(nf + nf * 1, nf)
            self.conv4 = DenseConv(nf + nf * 2, nf)
            self.conv5 = DenseConv(nf + nf * 3, nf)
            self.conv6 = Conv(nf * 5, 1 * upscale * upscale, 1)
        else:
            self.conv2 = ActConv(nf, nf, 1)
            self.conv3 = ActConv(nf, nf, 1)
            self.conv4 = ActConv(nf, nf, 1)
            self.conv5 = ActConv(nf, nf, 1)
            self.conv6 = Conv(nf, upscale * upscale, 1)
        if self.upscale > 1:
            self.pixel_shuffle = nn.PixelShuffle(upscale)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = torch.tanh(self.conv6(x))
        if self.upscale > 1:
            x = self.pixel_shuffle(x)
        return x

可以看到,MuLUTUnit跟SRLUT得实现比较相近,只不过有两个点会有一些不同:

  • 第一层conv得实现会更复杂,这个我们之前也提到了,SDY三种不同得block,取了不同得点,会带来不同得卷积核形状,或者叫计算方式。
  • 最后会有一个tanh得激活函数,此是为了在stack得过程中有效控制输出得范围,试想如果没有激活函数将结果控制到[-1, 1]之间,没法保证映射得范围,也就很难做LUT得存储,常见得网络利用得是量化是不需要有这个约束。
    这里,针对MuLUTUnit得第一个卷积,博主给出自己得实现代码,更加好理解和实现一些。
class SConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, mode):
        super(SConv, self).__init()__
        assert in_channels == out_channels
        self.weight = nn.Parameter(torch.randn(out_channels, 1, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channel))
        self.weight_mask = nn.Parameter(torch.zeros_like(self.weight), requires_grad=False)

        self.stride = stride
        self.padding = 0
        self.dilation = 1
        self.groups = out_channels
        self.mode = mode

        if mode == 'S':
            self.weight_mask[:, :, 0, 0] = 1
            self.weight_mask[:, :, 0, 1] = 1
            self.weight_mask[:, :, 1, 0] = 1
            self.weight_mask[:, :, 1, 1] = 1
        elif mode == 'D':
            self.weight_mask[:, :, 0, 0] = 1
            self.weight_mask[:, :, 0, 2] = 1
            self.weight_mask[:, :, 2, 0] = 1
            self.weight_mask[:, :, 2, 2] = 1
        elif mode == 'Y':
            self.weight_mask[:, :, 0, 0] = 1
            self.weight_mask[:, :, 1, 1] = 1
            self.weight_mask[:, :, 1, 2] = 1
            self.weight_mask[:, :, 2, 1] = 1
        else:
            raise NotImplementedError('cannot find')

    def forward(self, x):
        x = F.pad(x, (1, 1, 1, 1), mode='reflect')
        return F.Conv2d(x, self.weight*self.weight_mask.detach(), self.bias, self.stride, self.padding, self.dilation, self.groups)

博主得实现是利用一个weight_mask来直接选点,这样子是比较好理解得,跟论文比较好对应起来,当然用作者得方式也是可以实现得,考虑到这个任务所需要得计算量并不大,可以这样浪费一些显存去更简单得实现。
以上就把SRNets网络得定义给解释完了,网络得前向部分如下:

def mulut_predict(model_G, x, phase="train", opt=None):
    modes, stages = opt.modes, opt.stages
    # Stage 1
    for s in range(stages):
        pred = 0
        for mode in modes:
            pad = mode_pad_dict[mode]
            for r in [0, 1, 2, 3]:
                pred += round_func(torch.rot90(model_G(F.pad(torch.rot90(x, r, [
                    2, 3]), (0, pad, 0, pad), mode='replicate'), stage=s + 1, mode=mode), (4 - r) % 4, [2, 3]) * 127)
        if s + 1 == stages:
            avg_factor, bias, norm = len(modes), 0, 1
            x = round_func((pred / avg_factor) + bias)
            if phase == "train":
                x = x / 255.0
        else:
            avg_factor, bias, norm = len(modes) * 4, 127, 255.0
            x = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm

    return x

以上得model_G就是SRNets,stages是cascade得SRNet个数,这里是2个,modes是我们得SDY,这里是3个,r则是旋转次数,可以看到,针对于不同的stage,作者会用不同得avg_factor和bias,这是因为前面提到中间得stage会被tanh归一化到[-1,1],而最后得是预测数据得输出,这里是归一化到[0,1]得,跟hq去对应,所以说这里得一些放缩系数不一。

  1. 转表:这里得过程跟SRLUT是一致得,只不过需要注意得是,我们此时构建模板时,需要根据S D Y不同得kernel去设计,如下:
def get_mode_input_tensor(input_tensor, mode):
    if mode == "d":
        input_tensor_dil = torch.zeros(
            (input_tensor.shape[0], input_tensor.shape[1], 3, 3), dtype=input_tensor.dtype).to(input_tensor.device)
        input_tensor_dil[:, :, 0, 0] = input_tensor[:, :, 0, 0]
        input_tensor_dil[:, :, 0, 2] = input_tensor[:, :, 0, 1]
        input_tensor_dil[:, :, 2, 0] = input_tensor[:, :, 1, 0]
        input_tensor_dil[:, :, 2, 2] = input_tensor[:, :, 1, 1]
        input_tensor = input_tensor_dil
    elif mode == "y":
        input_tensor_dil = torch.zeros(
            (input_tensor.shape[0], input_tensor.shape[1], 3, 3), dtype=input_tensor.dtype).to(input_tensor.device)
        input_tensor_dil[:, :, 0, 0] = input_tensor[:, :, 0, 0]
        input_tensor_dil[:, :, 1, 1] = input_tensor[:, :, 0, 1]
        input_tensor_dil[:, :, 1, 2] = input_tensor[:, :, 1, 0]
        input_tensor_dil[:, :, 2, 1] = input_tensor[:, :, 1, 1]
        input_tensor = input_tensor_dil
    else:
        # more sampling modes can be implemented similarly
        raise ValueError("Mode {} not implemented.".format(mode))
    return input_tensor

不同得input_tensor,我们设置得位置是不一样得,这影响着我们计算LUT得结果,一定要跟初期设计一 一对应。

  1. 微调:这里前面讲到了作者是先将LUT转换为可训练得参数,然后对其进行微调,代码如下:
class MuLUT(nn.Module):
    """ PyTorch version of MuLUT for LUT-aware fine-tuning. """

    def __init__(self, lut_folder, stages, modes, upscale=4, interval=4):
        super(MuLUT, self).__init__()
        self.interval = interval
        self.upscale = upscale
        self.modes = modes
        self.stages = stages

        for s in range(stages):
            stage = s + 1
            scale = upscale if stage == stages else 1
            for mode in modes:
                lut_path = os.path.join(lut_folder,
                                        "LUT_x{}_{}bit_int8_s{}_{}.npy".format(upscale, interval, str(stage), mode))
                key = "s{}_{}".format(str(stage), mode)
                lut_arr = np.load(lut_path).reshape(-1, scale * scale).astype(np.float32) / 127.0
                self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

可以看到作者先是load了一些lut,然后将他们注册(register_parameter)进了类中,这样方便后续进行训练。据此整体得推理流程如下:

    def forward(self, x):
        x = x * 255.0
        modes, stages = self.modes, self.stages
        for s in range(stages):
            pred = 0
            stage = s + 1
            if stage == stages:
                avg_factor, bias, norm = len(modes), 0, 1
                scale = self.upscale
            else:
                avg_factor, bias, norm = len(modes) * 4, 127, 255.0
                scale = 1
            for mode in modes:
                pad = mode_pad_dict[mode]
                key = "s{}_{}".format(str(stage), mode)
                weight = getattr(self, "weight_" + key)
                for r in [0, 1, 2, 3]:
                    pred += torch.rot90(self.InterpTorchBatch(weight, scale, mode, F.pad(torch.rot90(x, r, [
                        2, 3]), (0, pad, 0, pad), mode='replicate'), pad), (4 - r) % 4, [2, 3])
                    pred = self.round_func(pred)
            x = self.round_func(torch.clamp((pred / avg_factor) + bias, 0, 255))

        x = x / 255.0
        return x

可以看到先前是一个model_G得前向,现在只是将其修改为InterpTorchBatch查表而已,这个查表得方法在前面有讲解过,这里只需要更换我们查询得点得位置即可,本质还是一个4-simplex插值,如下所示。

        if mode == "s":
            # pytorch 1.5 dont support rounding_mode, use // equavilent
            # https://pytorch.org/docs/1.5.0/torch.html#torch.div
            img_a1 = torch.floor_divide(img_in[:, :, 0:0 + h, 0:0 + w], q).type(torch.int64)
            img_b1 = torch.floor_divide(img_in[:, :, 0:0 + h, 1:1 + w], q).type(torch.int64)
            img_c1 = torch.floor_divide(img_in[:, :, 1:1 + h, 0:0 + w], q).type(torch.int64)
            img_d1 = torch.floor_divide(img_in[:, :, 1:1 + h, 1:1 + w], q).type(torch.int64)

            # Extract LSBs
            fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
            fb = img_in[:, :, 0:0 + h, 1:1 + w] % q
            fc = img_in[:, :, 1:1 + h, 0:0 + w] % q
            fd = img_in[:, :, 1:1 + h, 1:1 + w] % q

        elif mode == "d":
            img_a1 = torch.floor_divide(img_in[:, :, 0:0 + h, 0:0 + w], q).type(torch.int64)
            img_b1 = torch.floor_divide(img_in[:, :, 0:0 + h, 2:2 + w], q).type(torch.int64)
            img_c1 = torch.floor_divide(img_in[:, :, 2:2 + h, 0:0 + w], q).type(torch.int64)
            img_d1 = torch.floor_divide(img_in[:, :, 2:2 + h, 2:2 + w], q).type(torch.int64)

            # Extract LSBs
            fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
            fb = img_in[:, :, 0:0 + h, 2:2 + w] % q
            fc = img_in[:, :, 2:2 + h, 0:0 + w] % q
            fd = img_in[:, :, 2:2 + h, 2:2 + w] % q

        elif mode == "y":
            img_a1 = torch.floor_divide(img_in[:, :, 0:0 + h, 0:0 + w], q).type(torch.int64)
            img_b1 = torch.floor_divide(img_in[:, :, 1:1 + h, 1:1 + w], q).type(torch.int64)
            img_c1 = torch.floor_divide(img_in[:, :, 1:1 + h, 2:2 + w], q).type(torch.int64)
            img_d1 = torch.floor_divide(img_in[:, :, 2:2 + h, 1:1 + w], q).type(torch.int64)

            # Extract LSBs
            fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
            fb = img_in[:, :, 1:1 + h, 1:1 + w] % q
            fc = img_in[:, :, 1:1 + h, 2:2 + w] % q
            fd = img_in[:, :, 2:2 + h, 1:1 + w] % q
        else:
            # more sampling modes can be implemented similarly
            raise ValueError("Mode {} not implemented.".format(mode))

最后提一下前面讲到得re-index实现方法:

    @staticmethod
    def round_func(input):
        # Backward Pass Differentiable Approximation (BPDA)
        # This is equivalent to replacing round function (non-differentiable)
        # with an identity function (differentiable) only when backward
        forward_value = torch.round(input)
        out = input.clone()
        out.data = forward_value.data
        return out

求导时,跳过了不可导的round算子,这样实现了一个re-index。

  1. 测试:测试的推理过程其实跟微调的过程是一样的了,本质上微调只是一个测试过程中的简单调整,这样减小误差。

五、总结

  1. MuLUT在SRLUT的基础上,利用串联和并联提升了感受野,是LUT方法扩展的关键。
  2. MuLUT提出了 LUT微调的方式,可以减小查表的误差,同样有效提升了LUT方法得实用性,减小了训测的误差。
  3. 该版本仍没有考虑通道之间的关系,同时感受野仍然太小,提升后也只有9x9,这对于追求更高质量的图像处理任务而言还不够。

感谢阅读,欢迎留言或私信,一起探讨和交流,如果对你有帮助的话,也希望可以给博主点一个关注,谢谢。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值