【LUT技术专题】SRLUT:Practical Single-Image Super-Resolution Using Look-Up Table

该文章已生成可运行项目,


本文将从头开始对SRLUT: Practical Single-Image Super-Resolution Using Look-Up Table,这篇极轻量超分算法进行讲解。参考资料如下:
[1]. SRLUT论文地址
[2]. SRLUT代码地址
[3]. SRLUT代码讲解


专题介绍

Look-Up Table(查找表,LUT)是一种数据结构(也可以理解为字典),通过输入的key来查找到对应的value。其优势在于无需计算过程,不依赖于GPU、NPU等特殊硬件,本质就是一种内存换算力的思想。LUT在图像处理中是比较常见的操作,如Gamma映射,3D CLUT等。

近些年,LUT技术已被用于深度学习领域,由SR-LUT启发性地提出了模型训练+LUT推理的新范式。
本专题旨在跟进和解读LUT技术的发展趋势,为读者分享最全最新的LUT方法,欢迎一起探讨交流,对该专题感兴趣的读者可以订阅本专栏第一时间看到更新。

系列文章如下:
【1】SR-LUT
【2】Mu-LUT
【3】SP-LUT
【4】RC-LUT
【5】EC-LUT
【6】SPF-LUT
【7】Dn-LUT


一、研究背景

超分网络的需求日益增长,但大多数方法都需要GPU来进行推理。本文的研究目标就是在不利用GPU的情况下,使用Look-UP table查找表来实现超分任务。
查找表:本质是一个字典,可以通过输入的key来查找到对应的value,例如1D LUT就是一个gamma曲线,可以用来调整图像亮度;3D LUT在我们常见的PS等图像软件中可以用来调整颜色,他的原理就是通过3D LUT将一个RGB元素映射到另一个RGB上。

二、SRLUT方法

方法的整体流程图如下:

分为三步:

  1. 训练:与常规CNN网络一致,通过构建多个卷积层搭配最后的Depth to space(pixelshuffle)来完成网络的超分,值得注意的是它只有第一层的卷积是一个2x2的卷积核,其他的均为1,因此可以知道网络的感受野为2x2,为了进一步扩大感受野,SRLUT作者设计将图像进行3次旋转,每次旋转90度,此操作在不引入更多非1x1卷积的情况下,提升了感受野,使得网络的感受野变成了3x3。
  2. 转换:此步是将我们训练好的网络计算过程写入到一个固定的4D LUT中,前面我们知道步骤(1)中训练的网络感受野为2x2,因此针对于一个8bit的数据,4D LUT的4个输入存在256 * 256 * 256 * 256种可能性,将所有的可能性进行组合,可以得到它们相同大小的结果,同样以8bit去存储,在超分2倍的情况下,我们有2 * 2 * 256 * 256 * 256 * 256个数值,这些数值占用的存储大小可以计算得到256 * 256 * 256 * 256 * 4B=16GB,为了减小这部分大小,会对SRLUT的输入进行采样,假设我们均匀采样17个点,即只选择0、16、32…255,可以将整个LUT表的大小减小至17 * 17 * 17 * 17 * 4B=326.2KB,论文实际使用了17个点去进行存储。
  3. 测试:利用步骤(2)转换得到的LUT表可以进行测试,当我们对一个输入范围内的4个点I0-I3进行测试时,可以通过查找SR-LUT来计算得到结果。这里存在一个问题即,如何查找一个存在间隔的表,因为实际输入是8bit,而我们的表是经过进一步量化的,因此需要对其进行插值才能查找。这里的插值方法存在两种,首先以2D LUT的插值过程为例进行讲解:
    (a)最常见的双线性插值:一个例子可见下图。请添加图片描述

可见当输入的两个查询点是I0=24以及I1=60时,通过4bit的量化,可以计算得到他们夹在第1维1-2和第2维3-4之间,可以计算他们的坐标为1.5和3.75,此时,我们可以利用面积来得到I0、I1对应的结果。
P I 0 I 1 = ( 1 − d x ) ∗ ( 1 − d y ) ∗ P 00 + ( 1 − d x ) ∗ d y ∗ P 01 + d x ∗ ( 1 − d y ) ∗ P 10 + d x ∗ d y ∗ P 11 P_{I_{0} I_{1}}=\left(1-d_{x}\right) *\left(1-d_{y}\right) * P_{00}+\left(1-d_{x}\right) * d_{y} * P_{01}+d_{x} *\left(1-d_{y}\right) * P_{10}+d_{x} * d_{y} * P_{11} PI0I1=(1dx)(1dy)P00+(1dx)dyP01+dx(1dy)P10+dxdyP11

由于所有系数的和正好是正方形的面积,因此天生是归一化的。
(b)双线性插值的浮点计算量会更大,作者选择了一种更少计算量的插值方法,2D LUT中它的名字叫三角形插值,同样利用一个例子来进行讲解。
请添加图片描述

同样的位置,作者针对这种情况会先进行选点,选择合适的点再进行插值。选点的逻辑如下,首先对角线上的点P00和P11一定会被选择,P01和P10则根据待插值点的位置来进行选择,利用dx和dy,因为dx<dy,在对角线的左上侧,则我们会选择P01作为最后一个使用到的采样点。
接下来就是得到他们的权重了,这里的权重是通过三角形的面积求得,P00就需要w0对应得三角形面积来进行加权,P11需要w2,P01需要w1,可以计算得到。
P I 0 I 1 = ( 1 − d y ) ∗ P 00 + ( d y − d x ) ∗ P 01 + d x ∗ P 11 P_{I_{0} I_{1}}=\left(1-d_{y}\right) * P_{00}+\left(d_{y}-d_{x}\right) * P_{01}+d_{x} * P_{11} PI0I1=(1dy)P00+(dydx)P01+dxP11

因为面积正好也是1/2,所以权重也是正好为1,天生是归一化得。这里作者用了LSB得思路来进行讲解,本质上是一样得逻辑,如下图:
在这里插入图片描述
LSB即每个数得低4有效位,对于I0和I1来说就分别是8和12,因为I0得有效位比I1小,因此选择P01,权重得计算是一致得,只不过作者使用LSB来进行计算,因此每个w都需要除以16。
这里提两点博主对于(b)为何作者会这样去使用这个插值方法的思考
(i):选点得方式:不固定选择对角线上得点,选择更近得3个点,这个点在实际使用时是不合理得,因为SRLUT不能够处理通道之间得关系,当我们不固定一个插值得方向时,容易出现颜色不合理得问题。
(ii):权重得计算:不使用三角形得面积,而是使用跟双线性插值一样得矩形面积来加权,这样使用得面积在三角形之外,容易发生分层问题,即插值不平滑问题,同样不合理。

上面讲完了2D LUT得情况,3D LUT得情况更加复杂一些,即我们现在是在立方体中完成插值,对于插值方式(a)来说,变成了三次线性插值,对于插值方式(b)来说,变成了四面体插值。三次线性插值此时加权系数变为了一个块得体积,插值方式(b)需要判断点位于可插值得6个四面体得不同情况,再进行同样情况得加权系数计算。

三、实验结果

在这里插入图片描述
定量的实验结果显示:SRLUT,Ours-V、Ours-F、Ours-S,分别代表感受野为2D、3D以及4D时候的SR-LUT,在效果上会比传统的插值方法要更好,但是对比DNN的方法效果还是比较差的,当然它主打的是轻量级的超分方法。
在这里插入图片描述
定性的实验结果显示,该方法还是比传统插值要更好。

四、代码

同样的,实现代码分为三个部分,分别是训练、转表、推理。

  1. 训练:网络结构如下所示,比较简单的一个网络结构,最后会使用一个pixelshuffle层来进行超分。
class SRNet(torch.nn.Module):
    def __init__(self, upscale=4):
        super(SRNet, self).__init__()

        self.upscale = upscale

        self.conv1 = nn.Conv2d(1, 64, [2,2], stride=1, padding=0, dilation=1)
        self.conv2 = nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1)
        self.conv3 = nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1)
        self.conv4 = nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1)
        self.conv5 = nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1)
        self.conv6 = nn.Conv2d(64, 1*upscale*upscale, 1, stride=1, padding=0, dilation=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale)

        # Init weights
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.lower().find('conv') != -1:
                nn.init.kaiming_normal(m.weight)
                nn.init.constant(m.bias, 0)
            elif classname.find('bn') != -1:
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)


    def forward(self, x_in):
        B, C, H, W = x_in.size()
        x_in = x_in.reshape(B*C, 1, H, W)

        x = self.conv1(x_in)
        x = self.conv2(F.relu(x))
        x = self.conv3(F.relu(x))
        x = self.conv4(F.relu(x))
        x = self.conv5(F.relu(x))
        x = self.conv6(F.relu(x))
        x = self.pixel_shuffle(x)
        x = x.reshape(B, C, self.upscale*(H-1), self.upscale*(W-1))

        return x

需要注意以下两点:
(a)输入之前需要做一次pad,因为卷积会损失分辨率。
(b)除了第一层是用2 × 2的卷积核提取的,其余几层都是使用了大小为1×1的滤波,不能在这个过程中继续提升感受野,否则转表时不能等效。
这个过程中tensor的shape大小变化是:(b, 3, h, w) ->(b * 3, 1, h, w)->(b * 3, 64, h-1, w-1)->(b * 3, 16, h-1, w-1)->(b, 3, (h-1) * 4,(w-1) * 4)
这里pad放在预处理中了,因此h和w会比原来大1,因为SRLUT只能对空间进行查找,因此第一步需要将通道放在batch上处理,第二步使用2x2卷积将通道进行扩充并进行一系列1x1卷积变换,第三步变化成超分前所需要得通道数16,最后使用pixelshuffle进行超分。

  1. 转表:思想就是将所有输入的可能性进行计算,我们之前已经分析了不可能将所有的可能性进行转表,因此这里使用了以下几种可能性:{0,16,32,64,80,96,112,128,144,160,176,192,208,224,240,255},每一维都取这17个值,即当遍历完所有可能性之后并将它们的结果记录在LUT中。后续推理过程不再需要模型,只需要查找这个LUT即可。
    作者先构建了这个4D的所有可能性数组。
 # 1D input
 base = torch.arange(0, 257, 2**SAMPLING_INTERVAL)   # 0-256
 base[-1] -= 1
 L = base.size(0)

 # 2D input
 first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1)  # 256*256   0 0 0...    |1 1 1...     |...|255 255 255...
 second = base.cuda().repeat(L)                             # 256*256   0 1 2 .. 255|0 1 2 ... 255|...|0 1 2 ... 255
 onebytwo = torch.stack([first, second], 1)  # [256*256, 2]

 # 3D input
 third = base.cuda().unsqueeze(1).repeat(1, L*L).reshape(-1) # 256*256*256   0 x65536|1 x65536|...|255 x65536
 onebytwo = onebytwo.repeat(L, 1)
 onebythree = torch.cat([third.unsqueeze(1), onebytwo], 1)    # [256*256*256, 3]

 # 4D input
 fourth = base.cuda().unsqueeze(1).repeat(1, L*L*L).reshape(-1) # 256*256*256*256   0 x16777216|1 x16777216|...|255 x16777216
 onebythree = onebythree.repeat(L, 1)
 onebyfourth = torch.cat([fourth.unsqueeze(1), onebythree], 1)    # [256*256*256*256, 4]

 # Rearange input: [N, 4] -> [N, C=1, H=2, W=2]
 input_tensor = onebyfourth.unsqueeze(1).unsqueeze(1).reshape(-1,1,2,2).float() / 255.0

当然我们将他们放在B上即可,这样我们就得到了一个输入tensor,它得维度是(17 * 17 * 17 * 17,2,2),这个送入模型可以计算出结果,把这个结果进行保存就得到了我们得超分结果LUT表了,因为论文这里使用得是4倍超分,因此这个LUT表得大小可以计算得到17 * 17 * 17 * 17 * 4 * 4B=1.274MB。PS:如果看不明白的,可以自己对着每一步进行debug,每一步的输出可以看出来就是再不断的扩充维度。

  1. 测试:刚才我们已经研究了两种插值方式,在这里就需要用到,这里因为作者分析了使用更少得点去进行插值会更快,因此就直接讲这种方式,同时因为双线性、三次、四次这类比较常见,就不进行代码得讲解。作者代码是一个4D LUT,使用得插值是4-simplex插值,与我们前面得讲的原理是一致得。
    (a)首先利用旋转,得到4张图像,原图,旋转90和180以及270得图像,这样可以将2x2得感受野扩充至3x3。
    (b)然后进行pad,扩充后图像得宽度和高度+1,这与训练一致。
    (c)使用插值得到HR结果,并进行反向得旋转,旋转回原图。
    (d)将4个结果进行平均得到最终输出。
    由于LUT每次只可以查找2x2得一个小块,所以我们在查找时与进行卷积操作一样,是需要滑窗得,通过这个滑窗得过程来不断拿到2x2得小块,小块拿到后进行查找即可,比较困难得点在于如何进行查找。如下图所示:
 if fa > fb:
     if fb > fc:
         if fc > fd:
             out[c,y,x] = (q-fa) * p0000[c,y,x] + (fa-fb) * p1000[c,y,x] + (fb-fc) * p1100[c,y,x] + (fc-fd) * p1110[c,y,x] + (fd) * p1111[c,y,x]
         elif fb > fd:
             out[c,y,x] = (q-fa) * p0000[c,y,x] + (fa-fb) * p1000[c,y,x] + (fb-fd) * p1100[c,y,x] + (fd-fc) * p1101[c,y,x] + (fc) * p1111[c,y,x]
         elif fa > fd:
             out[c,y,x] = (q-fa) * p0000[c,y,x] + (fa-fd) * p1000[c,y,x] + (fd-fb) * p1001[c,y,x] + (fb-fc) * p1101[c,y,x] + (fc) * p1111[c,y,x]
         else:
             out[c,y,x] = (q-fd) * p0000[c,y,x] + (fd-fa) * p0001[c,y,x] + (fa-fb) * p1001[c,y,x] + (fb-fc) * p1101[c,y,x] + (fc) * p1111[c,y,x]
     elif fa > fc:
         if fb > fd:
             out[c,y,x] = (q-fa) * p0000[c,y,x] + (fa-fc) * p1000[c,y,x] + (fc-fb) * p1010[c,y,x] + (fb-fd) * p1110[c,y,x] + (fd) * p1111[c,y,x]
         elif fc > fd:
             out[c,y,x] = (q-fa) * p0000[c,y,x] + (fa-fc) * p1000[c,y,x] + (fc-fd) * p1010[c,y,x] + (fd-fb) * p1011[c,y,x] + (fb) * p1111[c,y,x]
         elif fa > fd:
             out[c,y,x] = (q-fa) * p0000[c,y,x] + (fa-fd) * p1000[c,y,x] + (fd-fc) * p1001[c,y,x] + (fc-fb) * p1011[c,y,x] + (fb) * p1111[c,y,x]
         else:
             out[c,y,x] = (q-fd) * p0000[c,y,x] + (fd-fa) * p0001[c,y,x] + (fa-fc) * p1001[c,y,x] + (fc-fb) * p1011[c,y,x] + (fb) * p1111[c,y,x]
     else:
         if fb > fd:
             out[c,y,x] = (q-fc) * p0000[c,y,x] + (fc-fa) * p0010[c,y,x] + (fa-fb) * p1010[c,y,x] + (fb-fd) * p1110[c,y,x] + (fd) * p1111[c,y,x]
         elif fc > fd:
             out[c,y,x] = (q-fc) * p0000[c,y,x] + (fc-fa) * p0010[c,y,x] + (fa-fd) * p1010[c,y,x] + (fd-fb) * p1011[c,y,x] + (fb) * p1111[c,y,x]
         elif fa > fd:
             out[c,y,x] = (q-fc) * p0000[c,y,x] + (fc-fd) * p0010[c,y,x] + (fd-fa) * p0011[c,y,x] + (fa-fb) * p1011[c,y,x] + (fb) * p1111[c,y,x]
         else:
             out[c,y,x] = (q-fd) * p0000[c,y,x] + (fd-fc) * p0001[c,y,x] + (fc-fa) * p0011[c,y,x] + (fa-fb) * p1011[c,y,x] + (fb) * p1111[c,y,x]

 else:
     if fa > fc:
         if fc > fd:
             out[c,y,x] = (q-fb) * p0000[c,y,x] + (fb-fa) * p0100[c,y,x] + (fa-fc) * p1100[c,y,x] + (fc-fd) * p1110[c,y,x] + (fd) * p1111[c,y,x]
         elif fa > fd:
             out[c,y,x] = (q-fb) * p0000[c,y,x] + (fb-fa) * p0100[c,y,x] + (fa-fd) * p1100[c,y,x] + (fd-fc) * p1101[c,y,x] + (fc) * p1111[c,y,x]
         elif fb > fd:
             out[c,y,x] = (q-fb) * p0000[c,y,x] + (fb-fd) * p0100[c,y,x] + (fd-fa) * p0101[c,y,x] + (fa-fc) * p1101[c,y,x] + (fc) * p1111[c,y,x]
         else:
             out[c,y,x] = (q-fd) * p0000[c,y,x] + (fd-fb) * p0001[c,y,x] + (fb-fa) * p0101[c,y,x] + (fa-fc) * p1101[c,y,x] + (fc) * p1111[c,y,x]
     elif fb > fc:
         if fa > fd:
             out[c,y,x] = (q-fb) * p0000[c,y,x] + (fb-fc) * p0100[c,y,x] + (fc-fa) * p0110[c,y,x] + (fa-fd) * p1110[c,y,x] + (fd) * p1111[c,y,x]
         elif fc > fd:
             out[c,y,x] = (q-fb) * p0000[c,y,x] + (fb-fc) * p0100[c,y,x] + (fc-fd) * p0110[c,y,x] + (fd-fa) * p0111[c,y,x] + (fa) * p1111[c,y,x]
         elif fb > fd:
             out[c,y,x] = (q-fb) * p0000[c,y,x] + (fb-fd) * p0100[c,y,x] + (fd-fc) * p0101[c,y,x] + (fc-fa) * p0111[c,y,x] + (fa) * p1111[c,y,x]
         else:
             out[c,y,x] = (q-fd) * p0000[c,y,x] + (fd-fb) * p0001[c,y,x] + (fb-fc) * p0101[c,y,x] + (fc-fa) * p0111[c,y,x] + (fa) * p1111[c,y,x]
     else:
         if fa > fd:
             out[c,y,x] = (q-fc) * p0000[c,y,x] + (fc-fb) * p0010[c,y,x] + (fb-fa) * p0110[c,y,x] + (fa-fd) * p1110[c,y,x] + (fd) * p1111[c,y,x]
         elif fb > fd:
             out[c,y,x] = (q-fc) * p0000[c,y,x] + (fc-fb) * p0010[c,y,x] + (fb-fd) * p0110[c,y,x] + (fd-fa) * p0111[c,y,x] + (fa) * p1111[c,y,x]
         elif fc > fd:
             out[c,y,x] = (q-fc) * p0000[c,y,x] + (fc-fd) * p0010[c,y,x] + (fd-fb) * p0011[c,y,x] + (fb-fa) * p0111[c,y,x] + (fa) * p1111[c,y,x]
         else:
             out[c,y,x] = (q-fd) * p0000[c,y,x] + (fd-fc) * p0001[c,y,x] + (fc-fb) * p0011[c,y,x] + (fb-fa) * p0111[c,y,x] + (fa) * p1111[c,y,x]

首先根据我们前面所讲的,对角线的P0000和P1111是一定会被选择的,另外就只需要选择3个点即可,这三个点的计算需要根据待插值点的位置,总共有24种可能性,fa、fb、fc、fd分别是待插值点4维的LSB值,q是量化系数,这里相当于用整型数再算,最后再除以q进行归一化,这里主要是根据文章中给出的逻辑,直接翻译成代码即可。
在这里插入图片描述

五、总结

  1. 效果上相比较DNN的方法还是有一定的差距,原因本人认为一是感受野极其有限,加上旋转也只有3 * 3,另外就是没有通道之间的联系,每个通道是单独操作的。
  2. 在使用的普遍性和轻量级这部分,优势比较明显。因为其从总体来看浮点的计算量是比较小的,比双线性会略大一些,也不需要特定的NPU或GPU去计算。但可以发现虽然其不依赖于GPU或NPU这类神经网络友好硬件,但是查表的过程对cache有一定得需求,当后续如果要对这类方法扩展时,就可能会遇到查表耗时的问题。当然在满足查表速度得前提下,是可以说用极其小的代价换了一个更好的超分结果的。

感谢阅读,欢迎留言或私信,一起探讨和交流,如果对你有帮助的话,也希望可以给博主点一个关注,谢谢。

本文章已经生成可运行项目
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值