[LUT Technical Series] SPFLUT Code Walkthrough

Contents

Overview of the Original Paper

1. Training

2. Compression and LUT Conversion

3. Fine-tuning

4. Testing


This post walks through the code of SPFLUT. For a walkthrough of the paper itself, see the SPFLUT post.

Overview of the Original Paper

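SPFLUT centers on a diagonal-first compression strategy. The overall pipeline has four parts: training, conversion (which includes compression), fine-tuning, and testing. The overall code structure is as follows:

[Figure: overall structure of the SPFLUT code]

The pipeline is essentially the same as MuLUT's, except that before the conversion in step 2 there is an extra step that compresses the LUT, namely the file 2_compress_lut_from_net.py. The fine-tuning in step 3 also contains code for fine-tuning the compressed LUTs.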


1. Training

The implementation of the SPF_LUT_net model can be found in sr/model.py:

class SPF_LUT_net(nn.Module):
    def __init__(self, nf=32, scale=4, modes=['s', 'd', 'y'], stages=2):
        super(SPF_LUT_net, self).__init__()
        self.upscale = scale
        self.modes = modes

        self.convblock1 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)
        self.convblock2 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)
        self.convblock3 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)
        self.convblock4 = ConvBlock(1, 1, scale=None, output_quant=False, modes=modes, nf=nf)
        self.ChannelConv = MuLUTcUnit(in_c=4, out_c=4, mode='1x1', nf=nf)
        self.upblock = ConvBlock(4, 1, scale=scale, output_quant=False, modes=modes, nf=nf)


    def forward(self, x, phase='train'):
        B, C, H, W = x.size()
        x = x.reshape((B * C, 1, H, W))

        refine_list = []

        # block1
        x = self.convblock1(x)
        avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
        x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm

        refine_list.append(x[:, 0:1, :, :])
        x = x[:, 1:, :, :]

        # block2
        x = self.convblock2(x)
        avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
        x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm

        refine_list.append(x[:, 0:1, :, :])
        x = x[:, 1:, :, :]

        # block3
        x = self.convblock3(x)
        avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
        x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm

        refine_list.append(x[:, 0:1, :, :])
        x = x[:, 1:, :, :]

        # block4
        x = self.convblock4(x)
        avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
        x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm
        refine_list.append(x)

        x = torch.cat(refine_list, dim=1)
        x = round_func(torch.tanh(self.ChannelConv(x)) * 127.0)
        x = round_func(torch.clamp(x + 127, 0, 255)) / 255.0

        x = self.upblock(x)
        avg_factor, bias, norm = len(self.modes), 0, 1
        x = round_func((x / avg_factor) + bias)

        if phase == 'train':
            x = x / 255.0
        x = x.reshape((B, C, self.upscale * H, self.upscale * W))

        return x

As the code above shows, the SPFLUT model is built mainly from two sub-modules: ConvBlock and MuLUTcUnit. ConvBlock is implemented as follows:

class ConvBlock(nn.Module):
    def __init__(self, in_c, out_c, scale=None, output_quant=False, modes=['s', 'd', 'y'], nf=64):
        super(ConvBlock, self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.modes = modes
        self.module_dict = dict()
        self.upscale = scale
        self.output_quant = output_quant

        scale_factor = 1 if scale is None else scale ** 2
        for c in range(in_c):
            for mode in modes:
                self.module_dict['DepthwiseBlock{}_{}'.format(c, mode)] = MuLUTConv('{}x{}'.format(mode.upper(), 'N'),
                                                                                    nf=nf, out_c=out_c * scale_factor,
                                                                                    stride=1)
        self.module_dict = nn.ModuleDict(self.module_dict)
        if scale is None:
            self.pixel_shuffle = identity
        else:
            self.pixel_shuffle = nn.PixelShuffle(scale)

    def forward(self, x):
        modes = self.modes

        x_out = 0
        for c in range(self.in_c):
            x_c = x[:, c:c + 1, :, :]
            pred = 0
            for mode in modes:
                pad = mode_pad_dict[mode]
                sub_module = self.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]
                for r in [0, 1, 2, 3]:
                    pred += round_func(torch.tanh(torch.rot90(self.pixel_shuffle(
                        sub_module(F.pad(torch.rot90(x_c, r, [2, 3]), (0, pad, 0, pad), mode='replicate'))),
                        (4 - r) % 4, [2, 3])) * 127)

            x_out += pred
        if self.output_quant:
            avg_factor = len(modes) * 4 * self.in_c
            x = round_func(torch.clamp(x_out / avg_factor, -1, 1) * 127) / 127
        else:
            x = x_out / self.in_c

        return x

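ConvBlock is in turn built from MuLUTConv, located in common/network.py. We covered this module in the MuLUT code walkthrough: it combines three kernel types, S, D, and Y, into a module with RF = 3x3. Rotation, clamp, and similar operations are also needed here so that the result of each layer does not overflow.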
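
The rescaling that follows each ConvBlock in SPF_LUT_net.forward can be checked numerically. Below is a minimal NumPy sketch, assuming each of the len(modes) * 4 rotated predictions lies in [-127, 127] after tanh(...) * 127. The helper name requantize is hypothetical, and np.round only stands in for the repository's round_func.

```python
import numpy as np

def requantize(x_sum, n_modes=3, bias=127, norm=255.0):
    """Map a sum of n_modes * 4 predictions back to [0, 1] in 8-bit steps."""
    avg_factor = n_modes * 4              # 3 modes x 4 rotations
    x = x_sum / avg_factor + bias         # average, then shift from [-127, 127] to [0, 254]
    return np.round(np.clip(x, 0, 255)) / norm

# Extreme sums: every one of the 12 predictions at -127, 0, or +127.
x_sum = np.array([-127.0 * 12, 0.0, 127.0 * 12])
out = requantize(x_sum)
print(out * 255)
```

Under these assumptions the averaged value already stays inside [0, 254], so the clamp acts as a safety net and does not cut into the normal range.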

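MuLUTcUnit is the channel-wise MuLUT module, also located in common/network.py. Because it operates only across channels, its kernel_size is 1, and its main job is to relate the feature channels to one another.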

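The overall structure is fairly clear, especially if you already know the MuLUT sub-modules. As before, readers who are unsure can instantiate a model and step through the tensor shapes to get familiar with it.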
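
For readers without the repository at hand, the shape bookkeeping of SPF_LUT_net.forward can be traced on paper. The following is a sketch in plain Python that does not run the model. It assumes the padding inside ConvBlock keeps H and W unchanged, and trace_shapes is a hypothetical helper, not part of the code base.

```python
def trace_shapes(B=1, C=3, H=8, W=8, scale=4):
    """Follow tensor shapes through SPF_LUT_net.forward without running the model."""
    shapes = []
    n = B * C                      # colour channels are folded into the batch
    shapes.append(("input", (n, 1, H, W)))
    refine = 0
    for k in (1, 2, 3):
        # convblock1..3 output 2 channels: one is kept, one feeds the next block
        shapes.append((f"convblock{k}", (n, 2, H, W)))
        refine += 1
    shapes.append(("convblock4", (n, 1, H, W)))
    refine += 1
    shapes.append(("concat", (n, refine, H, W)))        # 4 refined channels
    shapes.append(("ChannelConv", (n, 4, H, W)))        # 1x1 unit, 4 -> 4 channels
    shapes.append(("upblock", (n, 1, scale * H, scale * W)))
    shapes.append(("output", (B, C, scale * H, scale * W)))
    return shapes

for name, shape in trace_shapes():
    print(name, shape)
```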


2. Compression and LUT Conversion

This part of the code lives in 2_compress_lut_from_net.py. The overall flow is as follows:

def compress_SPFLUT(opt):
    def save_SPFLUT_DFC(x, lut_path, module):
        # Split input to not over GPU memory
        B = x.size(0) // 100
        outputs = []

        # Extract input-output pairs
        with torch.no_grad():
            model_G.eval()
            for b in range(100):
                if b == 99:
                    batch_input = x[b * B:]
                else:
                    batch_input = x[b * B:(b + 1) * B]

                batch_output = module(batch_input)

                results = torch.round(torch.tanh(batch_output) * 127).cpu().data.numpy().astype(np.int8)
                outputs += [results]

        results = np.concatenate(outputs, 0)
        results = results.reshape(x.size(0), -1)
        np.save(lut_path, results)
        print("Resulting LUT size: ", results.shape, "Saved to", lut_path)


    modes = [i for i in opt.modes]
    stages = opt.stages

    model = getattr(Model, 'SPF_LUT_net')

    model_G = model(nf=opt.nf, scale=opt.scale, modes=modes, stages=stages).cuda()

    lm = torch.load(os.path.join(opt.expDir, 'Model_{:06d}.pth'.format(opt.loadIter)))
    model_G.load_state_dict(lm, strict=True)

    input_tensor = get_input_tensor(opt)
    for mode in modes:
        if opt.cd == 'xyzt':
            input_tensor_c1 = compress_lut_xyzt(opt, input_tensor)
        elif opt.cd == 'xyz':
            input_tensor_c1 = compress_lut_xyz(opt, input_tensor)
        elif opt.cd == 'xy':
            input_tensor_c1 = compress_lut(opt, input_tensor)
        else:
            raise ValueError
        input_tensor_c2 = compress_lut_larger_interval(opt, input_tensor)

        if mode != 's':
            input_tensor_c1 = get_mode_input_tensor(input_tensor_c1, mode)
            input_tensor_c2 = get_mode_input_tensor(input_tensor_c2, mode)

        # conv1
        module = model_G.convblock1.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
        lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 1, mode))
        save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
        lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 1, mode))
        save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

        # conv2
        module = model_G.convblock2.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
        lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 2, mode))
        save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
        lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 2, mode))
        save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

        # conv3
        module = model_G.convblock3.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
        lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 3, mode))
        save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
        lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 3, mode))
        save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

        # conv4
        module = model_G.convblock4.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
        lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 4, mode))
        save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
        lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 4, mode))
        save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

        # conv6
        for c in range(4):
            module = model_G.upblock.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]
            lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress1.npy'.format(opt.lutName, 6,c, mode))
            save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
            lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress2.npy'.format(opt.lutName, 6,c, mode))
            save_SPFLUT_DFC(input_tensor_c2, lut_path, module)

    # conv5
    input_tensor = input_tensor.reshape((-1,4,1,1))
    module = model_G.ChannelConv
    lut_path = os.path.join(opt.expDir, '{}_s{}_channel.npy'.format(opt.lutName, 5))
    save_SPFLUT_DFC(input_tensor, lut_path, module)

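The details to focus on are the three functions for diagonal compression, compress_lut_xyzt, compress_lut_xyz, and compress_lut, which handle the 4D, 3D, and 2D compression respectively, and the function for the off-diagonal part, compress_lut_larger_interval. Finally, note that the channel convolution conv5 is not compressed: the channel conv does not satisfy the diagonal prior, so diagonal-first compression cannot be applied to it.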
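
The helper save_SPFLUT_DFC in the listing above does two things regardless of which compression is chosen: it splits the input into 100 chunks so the GPU is not overloaded, and it quantizes each output to int8 via tanh. A small NumPy sketch of both steps follows. The names chunk_bounds and to_int8 are hypothetical, and the sample inputs are arbitrary numbers, not real network outputs.

```python
import numpy as np

def chunk_bounds(n, parts=100):
    """Return (start, stop) pairs; the last chunk absorbs the remainder."""
    size = n // parts
    bounds = []
    for b in range(parts):
        start = b * size
        stop = n if b == parts - 1 else (b + 1) * size
        bounds.append((start, stop))
    return bounds

def to_int8(raw):
    """Quantize raw module outputs as in the listing: tanh, scale by 127, round."""
    return np.round(np.tanh(raw) * 127).astype(np.int8)

n = 17 ** 4                                   # rows of an uncompressed 4D table at L = 17
bounds = chunk_bounds(n)
covered = sum(stop - start for start, stop in bounds)
lut = to_int8(np.linspace(-5.0, 5.0, 11))     # arbitrary sample values
print(covered, bounds[0], bounds[-1], lut)
```

Because tanh is bounded by 1 in magnitude, the stored values always fit in [-127, 127] and never reach -128.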

For the diagonal functions, take the 2D compression as an example; it matches our earlier explanation.

def compress_lut(opt, input_tensor):
    base = torch.arange(0, 257, 2 ** opt.interval)  # 0-256
    base[-1] -= 1
    L = base.size(0)
    d = opt.dw
    diag = 2 * d + 1
    N = diag * L + (1 - diag ** 2) // 4

    input_tensor = input_tensor.reshape(L * L, L, L, 1, 2, 2)
    index_i = torch.zeros((N,)).type(torch.int64)
    index_j = torch.zeros((N,)).type(torch.int64)
    cnt = 0
    ref2index = np.zeros((L, diag), dtype=np.int_) - 1
    for i in range(L):
        for j in range(L):
            if abs(i - j) <= d:
                index_i[cnt] = i
                index_j[cnt] = j
                ref2index[i, j - i] = cnt
                cnt += 1
    np.save(os.path.join(opt.expDir, 'ref2index_{}{}i{}.npy'.format(opt.cd, opt.dw, opt.si)),ref2index)
    index_compress = index_i * L + index_j
    compressed_input_tensor = input_tensor[index_compress, ...].reshape(-1, 1, 2, 2)
    return compressed_input_tensor

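The author implements this by changing input_tensor: from the 2D tensor we need every position that satisfies the diagonal-distance condition. Here opt.dw (the variable d) corresponds to the \lambda from our earlier explanation. Each position that satisfies the condition is written into ref2index and cnt is incremented, so the diagonal positions are recorded.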
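
One detail of ref2index is easy to miss: the column index j - i ranges over [-d, d], and the negative offsets rely on NumPy's negative indexing, so they land in the last d columns. The sketch below mirrors the loop of compress_lut on its own. build_ref2index is a hypothetical name, and np.full is used here in place of np.zeros(...) - 1, which gives the same array.

```python
import numpy as np

def build_ref2index(L, d):
    """Number the (i, j) pairs with |i - j| <= d in row-major order."""
    diag = 2 * d + 1
    ref2index = np.full((L, diag), -1, dtype=np.int64)
    cnt = 0
    for i in range(L):
        for j in range(L):
            if abs(i - j) <= d:
                # j - i lies in [-d, d]; negative offsets wrap to the last d columns
                ref2index[i, j - i] = cnt
                cnt += 1
    return ref2index, cnt

ref2index, n_kept = build_ref2index(L=17, d=2)
print(n_kept)
print(ref2index[:3])
```

Entries that stay at -1 mark offsets that fall outside the grid, such as j - i = -1 in row 0.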

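L is the count tied to the sampling interval that we have used throughout, usually 17 (4-bit sampling). N is the total number of indices K derived earlier (substitute diag into N and it matches the formula exactly). That completes the 2D input tensor; it only needs to be fed to the model, and the diagonal positions are thereby stored first.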
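
Substituting diag = 2d + 1 into N gives N = (2d + 1) * L - d * (d + 1), which is the number of pairs (i, j) with |i - j| <= d. As a sanity check, here is a small sketch comparing the closed form with a brute-force count. The function names are hypothetical, and the row count n * L * L assumes the 2D case, where the other two dimensions stay dense.

```python
def diagonal_count(L, d):
    """Closed form used in compress_lut: number of (i, j) with |i - j| <= d."""
    diag = 2 * d + 1
    return diag * L + (1 - diag ** 2) // 4

def brute_force_count(L, d):
    """Count the same pairs directly."""
    return sum(1 for i in range(L) for j in range(L) if abs(i - j) <= d)

L = 17
for d in range(0, 5):
    n = diagonal_count(L, d)
    assert n == brute_force_count(L, d)
    # compressed rows (n * L * L) versus the full table (L ** 4)
    print(d, n, n * L * L, L ** 4)
```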

For the off-diagonal positions, see the function compress_lut_larger_interval, implemented as follows.


def compress_lut_larger_interval(opt, input_tensor):
    base = torch.arange(0, 257, 2 ** opt.interval)  # 0-256
    base[-1] -= 1
    L = base.size(0)
    input_tensor = input_tensor.reshape(L, L, L, L, 1, 2, 2)

    if opt.si==5:
        k = 2
    elif opt.si==6:
        k = 4
    elif opt.si==7:
        k = 8
    else:
        raise ValueError

    compressed_input_tensor = input_tensor[::k, ::k, ::k, ::k, ...].reshape(-1, 1, 2, 2)
    return compressed_input_tensor

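This is fairly simple: it uses a larger sampling interval. Since the grid already uses a 4-bit interval, when opt.si is 5 we only need to sample the current input_tensor with a stride of 2, and the other cases follow in the same way.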
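
To see how much the coarser grid shrinks the table, the sketch below lists the indices kept per axis. It assumes interval = 4, and the relation k = 2 ** (si - interval) is an assumption that merely reproduces the hard-coded values 2, 4, and 8 for si = 5, 6, and 7. coarse_grid is a hypothetical helper.

```python
def coarse_grid(interval=4, si=5):
    """Indices kept per axis when the interval grows from `interval` to `si` bits."""
    L = 2 ** (8 - interval) + 1          # 17 grid points for a 4-bit interval
    k = 2 ** (si - interval)             # 2, 4, 8 for si = 5, 6, 7 when interval = 4
    return list(range(0, L, k))

for si in (5, 6, 7):
    kept = coarse_grid(si=si)
    # number of rows left in the 4D table after striding every axis
    print(si, kept, len(kept) ** 4)
```

Since 16 is a multiple of each stride, the last grid point (pixel value 255) is always kept.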

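For the channel LUT: as already noted, it cannot be compressed, so its input_tensor is unchanged from before. The implementation is below, and the process should be familiar if you have been following the LUT series (if not, take a look at the LUT series):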

def get_input_tensor(opt):
    # 1D input
    base = torch.arange(0, 257, 2 ** opt.interval)  # 0-256
    base[-1] -= 1
    L = base.size(0)

    # 2D input
    # 256*256   0 0 0...    |1 1 1...     |...|255 255 255...
    first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1)
    # 256*256   0 1 2 .. 255|0 1 2 ... 255|...|0 1 2 ... 255
    second = base.cuda().repeat(L)
    onebytwo = torch.stack([first, second], 1)  # [256*256, 2]

    # 3D input
    # 256*256*256   0 x65536|1 x65536|...|255 x65536
    third = base.cuda().unsqueeze(1).repeat(1, L * L).reshape(-1)
    onebytwo = onebytwo.repeat(L, 1)
    onebythree = torch.cat(
        [third.unsqueeze(1), onebytwo], 1)  # [256*256*256, 3]

    # 4D input
    fourth = base.cuda().unsqueeze(1).repeat(1, L * L * L).reshape(
        -1)  # 256*256*256*256   0 x16777216|1 x16777216|...|255 x16777216
    onebythree = onebythree.repeat(L, 1)
    # [256*256*256*256, 4]
    onebyfourth = torch.cat([fourth.unsqueeze(1), onebythree], 1)

    # Rearrange input: [N, 4] -> [N, C=1, H=2, W=2]
    input_tensor = onebyfourth.unsqueeze(1).unsqueeze(
        1).reshape(-1, 1, 2, 2).float() / 255.0
    return input_tensor
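
The repeat and cat calls in get_input_tensor amount to enumerating all 4-tuples of sampled pixel values in lexicographic order. As an illustration only, the same grid can be written with itertools.product. This sketch leaves out the final reshape to [N, 1, 2, 2] and the division by 255, and sampled_levels is a hypothetical name.

```python
from itertools import product

def sampled_levels(interval=4):
    """Sampled pixel values: 0, 16, ..., 240 and 255 (256 is pulled back to 255)."""
    base = list(range(0, 257, 2 ** interval))
    base[-1] -= 1
    return base

base = sampled_levels()
# Lexicographic order: the first coordinate varies slowest, the last fastest,
# which matches the column order built by the torch.cat calls in the listing.
grid = list(product(base, repeat=4))
print(len(base), len(grid), grid[0], grid[1], grid[-1])
```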

3. Fine-tuning

Compared with MuLUT, the fine-tuning part has no notable changes. The main thing to look at is how the author builds the SPF_LUT model in sr/model.py:

class SPF_LUT(nn.Module):
    """ PyTorch version of MuLUT for LUT-aware fine-tuning. """

    def __init__(self, lut_folder, stages, modes, lutName, upscale, interval, phase=None, **kwargs):
        super(SPF_LUT, self).__init__()
        self.interval = interval
        self.upscale = upscale
        self.modes = modes
        self.stages = stages

        L = 2 ** (8 - interval) + 1


        for mode in modes:
            # conv1
            lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 1, mode))
            # lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(1, mode))
            key = "s{}c0_{}".format(1, mode)
            lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0
            self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

            # conv2
            lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 2, mode))
            # lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(2, mode))
            key = "s{}c0_{}".format(2, mode)
            lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0
            self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

            # conv3
            lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 3, mode))
            # lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(3, mode))
            key = "s{}c0_{}".format(3, mode)
            lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0
            self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

            # conv4
            lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 4, mode))
            # lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(4, mode))
            key = "s{}c0_{}".format(4, mode)
            lut_arr = np.load(lut_path).reshape((-1, 1)).astype(np.float32) / 127.0
            self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

            for c in range(4):
                # conv6
                lut_path = os.path.join(lut_folder, '{}_s{}c{}_{}.npy'.format(lutName, 6,c, mode))
                # lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c{}_{}.npy'.format(6,c, mode))
                key = "s{}c{}_{}".format(6,c, mode)
                lut_arr = np.load(lut_path).reshape((-1, self.upscale * self.upscale)).astype(np.float32) / 127.0
                self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

        # conv5
        lut_path = os.path.join(lut_folder, '{}_s{}_channel.npy'.format(lutName, 5))
        # lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}_channel.npy'.format(5))
        key = "s{}_channel".format(5)
        lut_arr = np.load(lut_path).reshape((-1, 4)).astype(np.float32) / 127.0
        self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))

As you can see, just like MuLUT, the LUTs are registered as trainable parameters and then fine-tuned.
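
The constructor turns every int8 table into float weights by dividing by 127. The sketch below shows that conversion in NumPy, together with an assumed inverse for writing a fine-tuned table back to int8. The export step weight_to_lut does not appear in the listing and is only an illustration, and the table contents here are random toy values.

```python
import numpy as np

def lut_to_weight(lut_int8, out_c):
    """int8 LUT -> float32 weights in [-1, 1], one row per LUT entry."""
    return lut_int8.reshape((-1, out_c)).astype(np.float32) / 127.0

def weight_to_lut(weight):
    """Assumed inverse step: clamp, rescale, and round back to int8."""
    return np.round(np.clip(weight, -1.0, 1.0) * 127.0).astype(np.int8)

rng = np.random.default_rng(0)
# toy table with 2 outputs per entry, as for conv1..conv3; not a real LUT
lut = rng.integers(-127, 128, size=(17 ** 2, 2)).astype(np.int8)
weight = lut_to_weight(lut, out_c=2)
restored = weight_to_lut(weight)
print(weight.dtype, np.array_equal(restored, lut))
```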


4. Testing

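Because the LUTs are now split into diagonal and off-diagonal parts, the final table-lookup inference needs some changes. Take the 2D diagonal compression as an example, in sr/4_test_SPF_LUT_DFC.py.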

def InterpTorchBatch_compress_xy(weight, img_in, h, w, interval, rot, d, upscale=4, out_c=1, mode='s',ref2index=None):
    q = 2 ** interval  # 16
    L = 2 ** (8 - interval) + 1  # 17

    diag = 2 * d + 1
    N = diag * L + (1 - diag ** 2) // 4

    if mode == "s":
        img_x = img_in[:, :, 0:0 + h, 0:0 + w]
        img_y = img_in[:, :, 0:0 + h, 1:1 + w]
        index_flag = (np.abs(img_x - img_y) <= d * q)

        # Extract MSBs
        img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
        img_b1 = img_in[:, :, 0:0 + h, 1:1 + w] // q
        img_c1 = img_in[:, :, 1:1 + h, 0:0 + w] // q
        img_d1 = img_in[:, :, 1:1 + h, 1:1 + w] // q

        # Extract LSBs
        fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
        fb = img_in[:, :, 0:0 + h, 1:1 + w] % q
        fc = img_in[:, :, 1:1 + h, 0:0 + w] % q
        fd = img_in[:, :, 1:1 + h, 1:1 + w] % q

    elif mode == 'd':
        img_x = img_in[:, :, 0:0 + h, 0:0 + w]
        img_y = img_in[:, :, 0:0 + h, 2:2 + w]
        index_flag = (np.abs(img_x - img_y) <= d * q)

        img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
        img_b1 = img_in[:, :, 0:0 + h, 2:2 + w] // q
        img_c1 = img_in[:, :, 2:2 + h, 0:0 + w] // q
        img_d1 = img_in[:, :, 2:2 + h, 2:2 + w] // q

        fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
        fb = img_in[:, :, 0:0 + h, 2:2 + w] % q
        fc = img_in[:, :, 2:2 + h, 0:0 + w] % q
        fd = img_in[:, :, 2:2 + h, 2:2 + w] % q

    elif mode == 'y':
        img_x = img_in[:, :, 0:0 + h, 0:0 + w]
        img_y = img_in[:, :, 1:1 + h, 1:1 + w]
        index_flag = (np.abs(img_x - img_y) <= d * q)

        img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
        img_b1 = img_in[:, :, 1:1 + h, 1:1 + w] // q
        img_c1 = img_in[:, :, 1:1 + h, 2:2 + w] // q
        img_d1 = img_in[:, :, 2:2 + h, 1:1 + w] // q

        fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
        fb = img_in[:, :, 1:1 + h, 1:1 + w] % q
        fc = img_in[:, :, 1:1 + h, 2:2 + w] % q
        fd = img_in[:, :, 2:2 + h, 1:1 + w] % q
    else:
        # more sampling modes can be implemented similarly
        raise ValueError("Mode {} not implemented.".format(mode))

    img_a1 = img_a1[index_flag].flatten().astype(np.int_)
    img_b1 = img_b1[index_flag].flatten().astype(np.int_)
    img_c1 = img_c1[index_flag].flatten().astype(np.int_)
    img_d1 = img_d1[index_flag].flatten().astype(np.int_)

    fa = fa[index_flag].flatten()
    fb = fb[index_flag].flatten()
    fc = fc[index_flag].flatten()
    fd = fd[index_flag].flatten()

    img_a2 = img_a1 + 1
    img_b2 = img_b1 + 1
    img_c2 = img_c1 + 1
    img_d2 = img_d1 + 1

    k00 = ref2index[img_a1, img_b1 - img_a1]
    k01 = ref2index[img_a1, img_b2 - img_a1]
    k10 = ref2index[img_a2, img_b1 - img_a2]
    k11 = ref2index[img_a2, img_b2 - img_a2]

    p0000 = weight[k00,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
    p0001 = weight[k00,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
    p0010 = weight[k00,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
    p0011 = weight[k00,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
    p0100 = weight[k01,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
    p0101 = weight[k01,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
    p0110 = weight[k01,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
    p0111 = weight[k01,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))

    p1000 = weight[k10,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
    p1001 = weight[k10,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
    p1010 = weight[k10,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
    p1011 = weight[k10,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
    p1100 = weight[k11,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
    p1101 = weight[k11,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
    p1110 = weight[k11,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
    p1111 = weight[k11,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))

    # Output image holder
    out = np.zeros((img_a1.shape[0],out_c, upscale, upscale))
    sz = img_a1.shape[0]
    out = out.reshape(sz, -1)

    p0000 = p0000.reshape(sz, -1)
    p0100 = p0100.reshape(sz, -1)
    p1000 = p1000.reshape(sz, -1)
    p1100 = p1100.reshape(sz, -1)
    fa = fa.reshape(-1, 1)

    p0001 = p0001.reshape(sz, -1)
    p0101 = p0101.reshape(sz, -1)
    p1001 = p1001.reshape(sz, -1)
    p1101 = p1101.reshape(sz, -1)
    fb = fb.reshape(-1, 1)
    fc = fc.reshape(-1, 1)

    p0010 = p0010.reshape(sz, -1)
    p0110 = p0110.reshape(sz, -1)
    p1010 = p1010.reshape(sz, -1)
    p1110 = p1110.reshape(sz, -1)
    fd = fd.reshape(-1, 1)

    p0011 = p0011.reshape(sz, -1)
    p0111 = p0111.reshape(sz, -1)
    p1011 = p1011.reshape(sz, -1)
    p1111 = p1111.reshape(sz, -1)

    fab = fa > fb
    fac = fa > fc
    fad = fa > fd

    fbc = fb > fc
    fbd = fb > fd
    fcd = fc > fd

    i1 = i = np.logical_and.reduce((fab, fbc, fcd)).squeeze(1)
    # print(p0000[i].shape,fa[i].shape,i.shape,out_c)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i2 = i = np.logical_and.reduce((~i1[:, None], fab, fbc, fbd)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i3 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], fab, fbc, fad)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i4 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], ~i3[:, None], fab, fbc)).squeeze(1)

    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]

    i5 = i = np.logical_and.reduce((~(fbc), fab, fac, fbd)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i6 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], fab, fac, fcd)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i7 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], fab, fac, fad)).squeeze(1)
    out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i8 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac)).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]

    i9 = i = np.logical_and.reduce((~(fbc), ~(fac), fab, fbd)).squeeze(1)
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    # Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first!
    # i10 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:,None], fab, fcd)).squeeze(1)
    # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    # i11 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad)).squeeze(1)
    # out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
    i10 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], fab, fad)).squeeze(1)  # c > a > d > b
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i11 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd)).squeeze(1)  # c > d > a > b
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]
    i12 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab)).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[
        i] + (fb[i]) * p1111[i]

    i13 = i = np.logical_and.reduce((~(fab), fac, fcd)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i14 = i = np.logical_and.reduce((~(fab), ~i13[:, None], fac, fad)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i15 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], fac, fbd)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]
    i16 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], ~i15[:, None], fac)).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[
        i] + (fc[i]) * p1111[i]

    i17 = i = np.logical_and.reduce((~(fab), ~(fac), fbc, fad)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i18 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], fbc, fcd)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    i19 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], fbc, fbd)).squeeze(1)
    out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    i20 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], ~i19[:, None], fbc)).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]

    i21 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), fad)).squeeze(1)
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[
        i] + (fd[i]) * p1111[i]
    i22 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], fbd)).squeeze(1)
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    i23 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd)).squeeze(1)
    out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]
    i24 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None])).squeeze(1)
    out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[
        i] + (fa[i]) * p1111[i]

    out = out / q
    return out,index_flag
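
The 24 branches i1 to i24 in the listing correspond to the 24 possible orderings of fa, fb, fc, and fd. As a compact illustration (not the author's implementation), the same 4D simplex interpolation can be written for a single pixel by sorting the fractions. simplex_interp and the corner dictionary are hypothetical names.

```python
from itertools import product

def simplex_interp(corner, f, q=16):
    """4D simplex interpolation for one pixel.

    corner maps a bit tuple such as (1, 0, 0, 0) to the LUT value at that cell
    vertex; f holds the four LSB fractions (fa, fb, fc, fd), each in [0, q).
    """
    order = sorted(range(4), key=lambda k: -f[k])    # axes by descending fraction
    bounds = [q] + [f[k] for k in order] + [0]
    vertex = [0, 0, 0, 0]
    out = 0.0
    for step in range(5):
        weight = bounds[step] - bounds[step + 1]     # q - f1, f1 - f2, ..., f4
        out += weight * corner[tuple(vertex)]
        if step < 4:
            vertex[order[step]] = 1                  # walk from 0000 towards 1111
    return out / q

# An affine function on the cell corners is reproduced exactly.
coef = (3.0, -2.0, 5.0, 1.0)
corner = {bits: 7.0 + sum(c * b for c, b in zip(coef, bits))
          for bits in product((0, 1), repeat=4)}
value = simplex_interp(corner, f=(5, 12, 3, 9))
print(value)
```

When two fractions are equal, the vertex between them gets weight zero, so the result does not depend on how the tie is ordered.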

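Before the lookup, an index_flag is computed. index_flag indicates whether the diagonal condition holds: if it does, the value is looked up in the diagonal LUT; otherwise the off-diagonal LUT is used. Readers can work through the detailed logic themselves; in this blogger's view, Python is rarely used to run this in practice anyway.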
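
The routing decision itself is only a few lines. A minimal NumPy sketch follows, assuming 8-bit input and a 4-bit interval. split_by_diagonal is a hypothetical helper, and the cast to a signed type is a precaution added here so the difference cannot wrap around if the input is uint8.

```python
import numpy as np

def split_by_diagonal(img_x, img_y, d, interval=4):
    """Return the diagonal mask plus the MSB and LSB parts of img_x."""
    q = 2 ** interval
    # cast to a signed type first so x - y cannot wrap around for uint8 input
    x = img_x.astype(np.int64)
    y = img_y.astype(np.int64)
    index_flag = np.abs(x - y) <= d * q   # True: served by the diagonal LUT
    return index_flag, x // q, x % q

img_x = np.array([[10, 200, 255]], dtype=np.uint8)
img_y = np.array([[40, 90, 250]], dtype=np.uint8)
flag, msb, lsb = split_by_diagonal(img_x, img_y, d=2)
print(flag, msb, lsb)
```

Pixels where flag is False would be handed to the off-diagonal table with the larger interval.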


That concludes the walkthrough of the SPFLUT code. If anything is unclear, feel free to ask.
