[LUT Technology Series] SVDLUT Code Walkthrough

This article is a code walkthrough of the SVDLUT technique. For an interpretation of the paper itself, see the SVDLUT paper walkthrough.

1. Paper Overview

SVDLUT's biggest innovation is using SVD to turn the 3D LUT lookup into multiple 2D LUT lookups followed by a summation, which reduces both the parameter count and the computation. The pipeline is shown below:
(figure: SVDLUT pipeline overview)
The approach is analogous to SABLUT: both involve bilateral-grid interpolation and color LUT interpolation. But because the lookup dimensionality is reduced, the authors additionally introduce Grid weights and LUT weights to form a weighted sum over the multiple 2D LUT interpolation results, thereby reproducing the behavior of a 3D LUT.
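To get a feel for the savings, the raw table sizes can be compared directly. The sketch below is illustrative arithmetic only (it ignores the exact bookkeeping of the weight/bias terms in the repo) and uses the default lut_n_vertices=17:

```python
# Rough parameter comparison: a classic 3D LUT stores N^3 entries per output
# channel, while the 2D decomposition stores a few N x N planes plus a handful
# of mixing weights (illustrative counts, not the repo's exact layout).
N = 17            # vertices per dimension (lut_n_vertices in the code)
n_planes = 3      # 2D planes per output channel (rg, rb, gb)

params_3d = N ** 3                        # entries in one channel of a 3D LUT
params_2d = n_planes * N ** 2 + n_planes  # three 2D planes + mixing weights

print(params_3d, params_2d)  # 4913 vs 870
```

Even before the SVD factorization compresses each plane further, the 2D decomposition alone shrinks the table by roughly N/3.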

2. Code Structure

The overall code layout is as follows:
(figure: repository directory structure)

We only need to focus on the model implementation.

3. Core Code Modules

The models.py file

The overall structure is shown below:

class SVDLUT(nn.Module):
    def __init__(self, backbone_type='cnn', backbone_coef=8,
                 lut_n_vertices=17, lut_n_ranks=24,  
                 grid_n_vertices=17, grid_n_ranks=24, ch_per_grid=2,
                 lut_weight_ranks=8, grid_weight_ranks=8,
                 lut_n_singular=8, grid_n_singular=8):
        super(SVDLUT, self).__init__()
        
        self.backbone_type = backbone_type.lower()
        
        if backbone_type.lower() == 'resnet':
            self.backbone = resnet18_224()
            print('Resnet backbone apply')
            n_feats = 512
        else:
            self.backbone = Backbone(backbone_coef=backbone_coef)
            print('CNN backbone apply')
            n_feats = 32*backbone_coef

        self.gen_2d_lut = Gen_2D_SVD_LUT(n_vertices=lut_n_vertices, n_feats=n_feats, n_ranks=lut_n_ranks, n_singlar=lut_n_singular) 
        self.gen_2d_lut_weight_bias = Gen_2D_LUT_weight_bias(n_vertices=lut_n_vertices, n_feats=n_feats, n_ranks=lut_weight_ranks)
        self.gen_2d_bilateral = Gen_2D_bilateral_grids(n_vertices=grid_n_vertices, n_feats=n_feats, n_ranks=grid_n_ranks, ch_per_grid=ch_per_grid)

        self.gen_2d_grid_weight_bias =Gen_2D_bilateral_grids_weight_bias(n_vertices=grid_n_vertices, n_feats=n_feats, n_ranks=grid_weight_ranks, ch_per_grid=ch_per_grid)
        

        self.slicing_transform = bilinear_2Dslicing_lut_transform

        self.relu = nn.ReLU()
    
    def init_weights(self):
        def special_initilization(m):
            classname = m.__class__.__name__
            if 'Conv' in classname:
                nn.init.xavier_normal_(m.weight.data)
            elif 'InstanceNorm' in classname:
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0.0)
        
        if self.backbone_type != 'resnet':
            self.backbone.apply(special_initilization)
            
        self.gen_2d_lut.init_weights()
        self.gen_2d_lut_weight_bias.init_weights()
        self.gen_2d_bilateral.init_weights()
        self.gen_2d_grid_weight_bias.init_weights()
       
    def forward(self, img):
        img_feature = self.backbone(img)
        
        g3d_lut, lut_weights = self.gen_2d_lut(img_feature)
        lut_param_weights, lut_param_bias = self.gen_2d_lut_weight_bias(img_feature)
        
        gbilateral, grid_weights = self.gen_2d_bilateral(img_feature)
        grid_param_weights, grid_param_bias = self.gen_2d_grid_weight_bias(img_feature)
        

        output = self.slicing_transform(gbilateral, img, grid_param_weights, grid_param_bias, g3d_lut, lut_param_weights, lut_param_bias)
        output = self.relu(output)
        return output, lut_weights, grid_weights, g3d_lut, gbilateral

This matches the flow from the paper walkthrough: features are extracted first, then the 2D LUTs and grids are generated together with their corresponding weights and biases for the weighted sum, and a unified slicing/interpolation module produces the final result. The process involves the following classes.

1. The Backbone class

This extracts image features; nothing special here.

class Backbone(nn.Module):
    def __init__(self, backbone_coef=8):
        super(Backbone, self).__init__()
        self.backbone_coef = backbone_coef
        self.model = nn.Sequential(
            nn.Upsample(size=(256,256),mode='bilinear'),
            nn.Conv2d(3, backbone_coef, 3, stride=2, padding=1), #8 x 128 x 128
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(backbone_coef, affine=True),
            *discriminator_block(backbone_coef, 2*backbone_coef, normalization=True), #16 x 64 x 64
            *discriminator_block(2*backbone_coef, 4*backbone_coef, normalization=True), #32 x 32 x 32
            *discriminator_block(4*backbone_coef, 8*backbone_coef, normalization=True), #64 x 16 x 16
            *discriminator_block(8*backbone_coef, 8*backbone_coef),   #64 x 8 x 8
            #*discriminator_block(128, 128, normalization=True),
            nn.Dropout(p=0.5),
            nn.AvgPool2d(5, stride=2) #64 x 2 x 2
        )

    def forward(self, img_input):

        return self.model(img_input).view([-1,self.backbone_coef*32])
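As a quick sanity check, the flattened length used in view([-1, self.backbone_coef*32]) can be reproduced with plain arithmetic (not repo code, just tracing the commented shapes above):

```python
# Trace the default backbone shapes: the last conv block leaves 8*coef channels
# on an 8x8 map, and AvgPool2d(5, stride=2) shrinks 8x8 down to 2x2.
backbone_coef = 8
channels = 8 * backbone_coef        # channels after the final conv block
spatial = (8 - 5) // 2 + 1          # AvgPool2d(5, stride=2) on an 8x8 map -> 2
flat = channels * spatial * spatial

print(flat, backbone_coef * 32)     # both 256 for the default coefficient
```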
2. The Gen_2D_SVD_LUT and Gen_2D_LUT_weight_bias classes

These generate the set of 2D LUTs that together replace the 3D LUT.

class Gen_2D_SVD_LUT(nn.Module):
    def __init__(self, n_colors=3, ch_per_lut = 3, n_lut_dim=2, n_vertices=17, n_feats=256, n_ranks=24, n_singlar=8):
        super(Gen_2D_SVD_LUT, self).__init__()
        
        # h0
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        
        self.n_svd = n_vertices * n_singlar + n_singlar + n_singlar * n_vertices
        # h1
        self.basis_luts_bank = nn.Linear(
            n_ranks, n_colors * ch_per_lut * self.n_svd)

        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_lut = ch_per_lut
        self.n_singlar = n_singlar
    
    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the initialization in
            [3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).

        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_luts_bank.bias)
        cols, rows = torch.stack(torch.meshgrid(*[torch.arange(self.n_vertices) for _ in range(2)]),dim=0).div(self.n_vertices - 1).flip(0)
        zero2d = torch.zeros(self.n_vertices, self.n_vertices)
        d = torch.stack([cols,cols,zero2d, 
                         rows,zero2d,cols,
                         zero2d,rows,rows], dim=0)
        
        u,s,v =torch.svd(d)
        
        u = u[:,:,:self.n_singlar].contiguous().view([3*self.ch_per_lut,-1])
        s = s[:,:self.n_singlar]
        v = v[:,:,:self.n_singlar].mT.contiguous().view([3*self.ch_per_lut,-1])
        
        d= torch.cat([u,s,v], dim=1)
        
        identity_lut = torch.stack([d,*[torch.zeros(3 * self.ch_per_lut, self.n_svd) for _ in range(self.n_ranks - 1)]], dim=0).view(self.n_ranks, -1)
        
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())
       
    def forward(self, img_feature):
        weights = self.weights_generator(img_feature)
        lut_svd = self.basis_luts_bank(weights)

        lut_svd = lut_svd.view([-1, self.n_svd])
        
        lut_u = lut_svd[:,:self.n_vertices * self.n_singlar]
        lut_s = lut_svd[:,self.n_vertices * self.n_singlar:self.n_vertices * self.n_singlar + self.n_singlar]
        lut_v = lut_svd[:,self.n_vertices * self.n_singlar + self.n_singlar:]
        
        lut_u = lut_u.view([-1, self.n_vertices, self.n_singlar])
        lut_s = torch.diag_embed(lut_s)
        lut_v = lut_v.view([-1, self.n_singlar, self.n_vertices])
        
        luts = torch.bmm(torch.bmm(lut_u,lut_s), lut_v)
        
        luts = luts.view([-1,self.n_colors, self.ch_per_lut, self.n_vertices,self.n_vertices])
        return luts, weights

The SVD factors are used to reconstruct multiple 2D LUTs, each with n_vertices vertices per axis; the missing third dimension is converted into ch_per_lut channels, without which the combination of multiple 2D LUTs could not be formed.
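To see why truncating to the leading n_singlar components loses nothing for the identity initialization, here is a minimal NumPy sketch (illustrative, not repo code) of the truncated-SVD round trip on a rank-1 ramp table like those built from meshgrid in init_weights:

```python
import numpy as np

n_vertices, n_singular = 17, 8

# A rank-1 "identity" plane: every row is the same 0..1 ramp, analogous to the
# cols/rows planes produced by meshgrid in Gen_2D_SVD_LUT.init_weights.
ramp = np.linspace(0.0, 1.0, n_vertices)
table = np.outer(np.ones(n_vertices), ramp)

u, s, vt = np.linalg.svd(table)
u_k, s_k, vt_k = u[:, :n_singular], s[:n_singular], vt[:n_singular, :]
recon = (u_k * s_k) @ vt_k          # keep only the leading components

print(float(np.abs(recon - table).max()))  # essentially zero: rank 1 <= 8
```

Since the initialization planes have rank 1, truncating to 8 singular components reconstructs them exactly; only the learned residuals are approximated.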

class Gen_2D_LUT_weight_bias(nn.Module):
    def __init__(self, n_colors=3, ch_per_lut = 3, n_vertices=17, n_feats=256, n_ranks=24):
        super(Gen_2D_LUT_weight_bias, self).__init__()
        
        # h0
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1
        self.basis_luts_bank = nn.Linear(
            n_ranks, n_colors * (ch_per_lut + 1))

        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_lut = ch_per_lut
  
    
    def init_weights(self):
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_luts_bank.bias)
        
        d = torch.tensor([[0.5,0.5,0,0],
                         [0.5,0,0.5,0],
                         [0,0.5,0.5,0]])
        
        
        identity_lut = torch.stack([d,
            *[torch.zeros(self.n_colors, self.ch_per_lut + 1) for _ in range(self.n_ranks - 1)]], dim=0).view(self.n_ranks, -1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())
       
    def forward(self, img_feature):
        weights = self.weights_generator(img_feature)
        weights_bias = self.basis_luts_bank(weights)

        weights_bias = weights_bias.view([-1,self.n_colors, self.ch_per_lut + 1])
        lut_param_weights = weights_bias[:, :, :self.n_colors]
        lut_param_bias = weights_bias[:, :, self.n_colors:]
        return lut_param_weights, lut_param_bias

This generates the weighting coefficients; their number depends only on the number of LUTs, so the parameter count is tiny.
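For concreteness, here is how the 12 outputs of the linear layer (n_colors * (ch_per_lut + 1) with the defaults) split into per-channel weights and bias, mirroring the view/slice in forward (illustrative stand-in values, not repo code):

```python
n_colors, ch_per_lut = 3, 3

flat = list(range(n_colors * (ch_per_lut + 1)))   # stand-in for the 12 outputs
rows = [flat[i * (ch_per_lut + 1):(i + 1) * (ch_per_lut + 1)]
        for i in range(n_colors)]                 # one row per output channel

weights = [r[:n_colors] for r in rows]  # 3 mixing weights per channel
bias    = [r[n_colors:] for r in rows]  # 1 bias per channel

print(weights, bias)
```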

3. The Gen_2D_bilateral_grids and Gen_2D_bilateral_grids_weight_bias classes

These generate the bilateral grids.

class Gen_2D_bilateral_grids(nn.Module):
    def __init__(self, n_grid_dim=2, n_vertices=17, n_feats=256, n_ranks=24, ch_per_grid=2):
        super(Gen_2D_bilateral_grids, self).__init__()
        
        # h0
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1
        self.basis_grids_bank = nn.Linear(
            n_ranks, ch_per_grid * 3 * 3 * (n_vertices ** n_grid_dim))

        self.n_grid_dim = n_grid_dim
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_grid = ch_per_grid
        self.n_grids = ch_per_grid * 3
  
    
    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the initialization in
            [3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).

        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_grids_bank.bias)
        cols, rows = torch.stack(torch.meshgrid(*[torch.arange(self.n_vertices) for _ in range(2)]),dim=0).div(self.n_vertices - 1).flip(0)
        zero2d = torch.zeros(self.n_vertices, self.n_vertices)
        d = torch.stack([*[zero2d,rows,rows, 
                         zero2d,rows,rows,
                         zero2d,rows,rows] * self.ch_per_grid], dim=0)
        identity_grid = torch.stack([d,*[torch.zeros(self.n_grids * 3,self.n_vertices, self.n_vertices) for _ in range(self.n_ranks - 1)]], dim=0).view(self.n_ranks, -1)
        self.basis_grids_bank.weight.data.copy_(identity_grid.t())
        
    def forward(self, img_feature):
        weights = self.weights_generator(img_feature)
        grids = self.basis_grids_bank(weights)
        grids = grids.view([-1,self.n_grids,3,self.n_vertices,self.n_vertices])
        return grids, weights

class Gen_2D_bilateral_grids_weight_bias(nn.Module):
    def __init__(self, n_colors=3,  ch_per_grid=2, n_vertices=17, n_feats=256, n_ranks=24):
        super(Gen_2D_bilateral_grids_weight_bias, self).__init__()
        
        # h0
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1
        self.basis_luts_bank = nn.Linear(
            n_ranks, ch_per_grid * (3 * n_colors  + n_colors))

        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_grid = ch_per_grid

  
    
    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the initialization in
            [3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).

        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_luts_bank.bias)
        d = torch.tensor([*[[0,1,1,0],
                         [0,1,1,0],
                         [0,1,1,0]] * self.ch_per_grid]).div(self.ch_per_grid * 2)
       
        
        identity_lut = torch.stack([d,
            *[torch.zeros(3*self.ch_per_grid, self.n_colors + 1) for _ in range(self.n_ranks - 1)]], dim=0).view(self.n_ranks, -1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())
       
    def forward(self, img_feature):
        weights = self.weights_generator(img_feature)
        weights_bias = self.basis_luts_bank(weights)
        
        weights_bias = weights_bias.view([-1,self.ch_per_grid, 3 *self.n_colors + self.n_colors])
        
        grid_param_weights = weights_bias[:, :, : 3 * self.n_colors]
        grid_param_bias = weights_bias[:, :, 3 * self.n_colors:]
        return grid_param_weights, grid_param_bias      

SVD is not applied to the grid generation because it would degrade quality, so the author outputs the grids directly from a linear layer instead.

4. The bilinear_2Dslicing_lut_transform function

The source can be found in kernel_code/bilateral_slicing/src/trilinear2D_slice_LUTTransform_cpu.cpp.

void TriLinearCPU2DSliceAndLUTTransformForward(const int nthreads,
                                               const scalar_t *grid,
                                               const scalar_t *image,
                                               const scalar_t *grid_weights,
                                               const scalar_t *grid_bias,
                                               const scalar_t *lut,
                                               const scalar_t *lut_weights,
                                               const scalar_t *lut_bias,
                                               scalar_t *output,
                                               const int grid_dim,
                                               const int grid_shift,
                                               const scalar_t grid_binsize,
                                               const int lut_dim,
                                               const int lut_shift,
                                               const scalar_t lut_binsize,
                                               const int width,
                                               const int height,
                                               const int num_channels,
                                               const int grid_per_ch)
{
    for (int index = 0; index < nthreads; index++)
    {

        const int x_ = index % width;
        const int y_ = index / width;

        const scalar_t x = x_ / (width - 1);
        const scalar_t y = y_ / (height - 1);

        const scalar_t r = image[index];
        const scalar_t g = image[index + width * height];
        const scalar_t b = image[index + width * height * 2];

        const int32_t x_id = clamp((int32_t)floor(x * (grid_dim - 1)), 0, grid_dim - 2);
        const int32_t y_id = clamp((int32_t)floor(y * (grid_dim - 1)), 0, grid_dim - 2);

        int32_t r_id = clamp((int32_t)floor(r * (grid_dim - 1)), 0, grid_dim - 2);
        int32_t g_id = clamp((int32_t)floor(g * (grid_dim - 1)), 0, grid_dim - 2);
        int32_t b_id = clamp((int32_t)floor(b * (grid_dim - 1)), 0, grid_dim - 2);

        const scalar_t x_d = (x - grid_binsize * x_id) / grid_binsize;
        const scalar_t y_d = (y - grid_binsize * y_id) / grid_binsize;

        scalar_t r_d = (r - grid_binsize * r_id) / grid_binsize;
        scalar_t g_d = (g - grid_binsize * g_id) / grid_binsize;
        scalar_t b_d = (b - grid_binsize * b_id) / grid_binsize;

        const int id00_xy = (x_id) + (y_id)*grid_dim;
        const int id10_xy = (x_id + 1) + (y_id)*grid_dim;
        const int id01_xy = (x_id) + (y_id + 1) * grid_dim;
        const int id11_xy = (x_id + 1) + (y_id + 1) * grid_dim;

        const int id00_xr = (x_id) + (r_id)*grid_dim;
        const int id10_xr = (x_id + 1) + (r_id)*grid_dim;
        const int id01_xr = (x_id) + (r_id + 1) * grid_dim;
        const int id11_xr = (x_id + 1) + (r_id + 1) * grid_dim;

        const int id00_yr = (y_id) + (r_id)*grid_dim;
        const int id10_yr = (y_id + 1) + (r_id)*grid_dim;
        const int id01_yr = (y_id) + (r_id + 1) * grid_dim;
        const int id11_yr = (y_id + 1) + (r_id + 1) * grid_dim;

        const int id00_xg = (x_id) + (g_id)*grid_dim;
        const int id10_xg = (x_id + 1) + (g_id)*grid_dim;
        const int id01_xg = (x_id) + (g_id + 1) * grid_dim;
        const int id11_xg = (x_id + 1) + (g_id + 1) * grid_dim;

        const int id00_yg = (y_id) + (g_id)*grid_dim;
        const int id10_yg = (y_id + 1) + (g_id)*grid_dim;
        const int id01_yg = (y_id) + (g_id + 1) * grid_dim;
        const int id11_yg = (y_id + 1) + (g_id + 1) * grid_dim;

        const int id00_xb = (x_id) + (b_id)*grid_dim;
        const int id10_xb = (x_id + 1) + (b_id)*grid_dim;
        const int id01_xb = (x_id) + (b_id + 1) * grid_dim;
        const int id11_xb = (x_id + 1) + (b_id + 1) * grid_dim;

        const int id00_yb = (y_id) + (b_id)*grid_dim;
        const int id10_yb = (y_id + 1) + (b_id)*grid_dim;
        const int id01_yb = (y_id) + (b_id + 1) * grid_dim;
        const int id11_yb = (y_id + 1) + (b_id + 1) * grid_dim;

        const scalar_t w00_xy = (1 - x_d) * (1 - y_d);
        const scalar_t w10_xy = (x_d) * (1 - y_d);
        const scalar_t w01_xy = (1 - x_d) * (y_d);
        const scalar_t w11_xy = (x_d) * (y_d);

        const scalar_t w00_xr = (1 - x_d) * (1 - r_d);
        const scalar_t w10_xr = (x_d) * (1 - r_d);
        const scalar_t w01_xr = (1 - x_d) * (r_d);
        const scalar_t w11_xr = (x_d) * (r_d);

        const scalar_t w00_yr = (1 - y_d) * (1 - r_d);
        const scalar_t w10_yr = (y_d) * (1 - r_d);
        const scalar_t w01_yr = (1 - y_d) * (r_d);
        const scalar_t w11_yr = (y_d) * (r_d);

        const scalar_t w00_xg = (1 - x_d) * (1 - g_d);
        const scalar_t w10_xg = (x_d) * (1 - g_d);
        const scalar_t w01_xg = (1 - x_d) * (g_d);
        const scalar_t w11_xg = (x_d) * (g_d);

        const scalar_t w00_yg = (1 - y_d) * (1 - g_d);
        const scalar_t w10_yg = (y_d) * (1 - g_d);
        const scalar_t w01_yg = (1 - y_d) * (g_d);
        const scalar_t w11_yg = (y_d) * (g_d);

        const scalar_t w00_xb = (1 - x_d) * (1 - b_d);
        const scalar_t w10_xb = (x_d) * (1 - b_d);
        const scalar_t w01_xb = (1 - x_d) * (b_d);
        const scalar_t w11_xb = (x_d) * (b_d);

        const scalar_t w00_yb = (1 - y_d) * (1 - b_d);
        const scalar_t w10_yb = (y_d) * (1 - b_d);
        const scalar_t w01_yb = (1 - y_d) * (b_d);
        const scalar_t w11_yb = (y_d) * (b_d);

        scalar_t int_img[3] = {
            0,
        };

        for (int i = 0; i < grid_per_ch; ++i)
        {
            int_img[0] = int_img[0] + grid_weights[3 * (i + grid_per_ch * 0)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)]) +

                         grid_weights[3 * (i + grid_per_ch * 0) + 1] * (w00_xr * grid[id00_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
                                                                        w10_xr * grid[id10_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
                                                                        w01_xr * grid[id01_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
                                                                        w11_xr * grid[id11_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)]) +

                         grid_weights[3 * (i + grid_per_ch * 0) + 2] * (w00_yr * grid[id00_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
                                                                        w10_yr * grid[id10_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
                                                                        w01_yr * grid[id01_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
                                                                        w11_yr * grid[id11_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)]) +
                         grid_bias[(i + grid_per_ch * 0)];

            int_img[1] = int_img[1] + grid_weights[3 * (i + grid_per_ch * 1)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)]) +

                         grid_weights[3 * (i + grid_per_ch * 1) + 1] * (w00_xg * grid[id00_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)] +
                                                                        w10_xg * grid[id10_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)] +
                                                                        w01_xg * grid[id01_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)] +
                                                                        w11_xg * grid[id11_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)]) +

                         grid_weights[3 * (i + grid_per_ch * 1) + 2] * (w00_yg * grid[id00_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)] +
                                                                        w10_yg * grid[id10_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)] +
                                                                        w01_yg * grid[id01_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)] +
                                                                        w11_yg * grid[id11_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)]) +
                         grid_bias[(i + grid_per_ch * 1)];

            int_img[2] = int_img[2] + grid_weights[3 * (i + grid_per_ch * 2)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)]) +

                         grid_weights[3 * (i + grid_per_ch * 2) + 1] * (w00_xb * grid[id00_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)] +
                                                                        w10_xb * grid[id10_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)] +
                                                                        w01_xb * grid[id01_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)] +
                                                                        w11_xb * grid[id11_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)]) +

                         grid_weights[3 * (i + grid_per_ch * 2) + 2] * (w00_yb * grid[id00_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)] +
                                                                        w10_yb * grid[id10_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)] +
                                                                        w01_yb * grid[id01_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)] +
                                                                        w11_yb * grid[id11_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)]) +
                         grid_bias[(i + grid_per_ch * 2)];
        }

        r_id = clamp((int32_t)floor(r * (lut_dim - 1)), 0, lut_dim - 2);
        g_id = clamp((int32_t)floor(g * (lut_dim - 1)), 0, lut_dim - 2);
        b_id = clamp((int32_t)floor(b * (lut_dim - 1)), 0, lut_dim - 2);

        r_d = (r - lut_binsize * r_id) / lut_binsize;
        g_d = (g - lut_binsize * g_id) / lut_binsize;
        b_d = (b - lut_binsize * b_id) / lut_binsize;

        const int id00_rg = r_id + g_id * lut_dim;
        const int id10_rg = r_id + 1 + g_id * lut_dim;
        const int id01_rg = r_id + (g_id + 1) * lut_dim;
        const int id11_rg = r_id + 1 + (g_id + 1) * lut_dim;

        const int id00_rb = r_id + b_id * lut_dim;
        const int id10_rb = r_id + 1 + b_id * lut_dim;
        const int id01_rb = r_id + (b_id + 1) * lut_dim;
        const int id11_rb = r_id + 1 + (b_id + 1) * lut_dim;

        const int id00_gb = g_id + b_id * lut_dim;
        const int id10_gb = g_id + 1 + b_id * lut_dim;
        const int id01_gb = g_id + (b_id + 1) * lut_dim;
        const int id11_gb = g_id + 1 + (b_id + 1) * lut_dim;

        const scalar_t w00_rg = (1 - r_d) * (1 - g_d);
        const scalar_t w10_rg = (r_d) * (1 - g_d);
        const scalar_t w01_rg = (1 - r_d) * (g_d);
        const scalar_t w11_rg = (r_d) * (g_d);

        const scalar_t w00_rb = (1 - r_d) * (1 - b_d);
        const scalar_t w10_rb = (r_d) * (1 - b_d);
        const scalar_t w01_rb = (1 - r_d) * (b_d);
        const scalar_t w11_rb = (r_d) * (b_d);

        const scalar_t w00_gb = (1 - g_d) * (1 - b_d);
        const scalar_t w10_gb = (g_d) * (1 - b_d);
        const scalar_t w01_gb = (1 - g_d) * (b_d);
        const scalar_t w11_gb = (g_d) * (b_d);

        for (int i = 0; i < num_channels; ++i)
        {
            scalar_t output_rg = w00_rg * lut[id00_rg + lut_shift * 3 * i] + w10_rg * lut[id10_rg + lut_shift * 3 * i] +
                                 w01_rg * lut[id01_rg + lut_shift * 3 * i] + w11_rg * lut[id11_rg + lut_shift * 3 * i];

            scalar_t output_rb = w00_rb * lut[id00_rb + lut_shift * (3 * i + 1)] + w10_rb * lut[id10_rb + lut_shift * (3 * i + 1)] +
                                 w01_rb * lut[id01_rb + lut_shift * (3 * i + 1)] + w11_rb * lut[id11_rb + lut_shift * (3 * i + 1)];

            scalar_t output_gb = w00_gb * lut[id00_gb + lut_shift * (3 * i + 2)] + w10_gb * lut[id10_gb + lut_shift * (3 * i + 2)] +
                                 w01_gb * lut[id01_gb + lut_shift * (3 * i + 2)] + w11_gb * lut[id11_gb + lut_shift * (3 * i + 2)];

            output[index + width * height * i] = int_img[i] + lut_weights[3 * i] * output_rg + lut_weights[3 * i + 1] * output_rb + lut_weights[3 * i + 2] * output_gb + lut_bias[i];
        }
    }
}

Because this function fuses the slicing step and the final LUT transform, the code is long. First look at the spatial-information fusion stage: each grid channel's sum (taking r as an example) is assembled from the original image and the xy, xr, and yr planes plus a bias, matching the formula in the paper walkthrough. The corresponding code is:

int_img[0] = int_img[0] + grid_weights[3 * (i + grid_per_ch * 0)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)]) +

grid_weights[3 * (i + grid_per_ch * 0) + 1] * (w00_xr * grid[id00_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
                                               w10_xr * grid[id10_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
                                               w01_xr * grid[id01_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
                                               w11_xr * grid[id11_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)]) +

grid_weights[3 * (i + grid_per_ch * 0) + 2] * (w00_yr * grid[id00_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
                                               w10_yr * grid[id10_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
                                               w01_yr * grid[id01_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
                                               w11_yr * grid[id11_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)]) +
grid_bias[(i + grid_per_ch * 0)];
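The id00/id10/id01/id11 indices and w00/w10/w01/w11 factors above are standard 2D bilinear interpolation. A minimal Python sketch of a single-plane lookup (assumed equivalent in spirit to the kernel's per-plane step, not repo code):

```python
def bilinear_lookup(grid, u, v):
    """Blend the four surrounding corners; grid is dim x dim, u and v in [0, 1]."""
    dim = len(grid)
    binsize = 1.0 / (dim - 1)
    # cell index, clamped so id+1 stays valid (mirrors the kernel's clamp to dim-2)
    u_id = min(max(int(u * (dim - 1)), 0), dim - 2)
    v_id = min(max(int(v * (dim - 1)), 0), dim - 2)
    u_d = (u - binsize * u_id) / binsize   # fractional offset inside the cell
    v_d = (v - binsize * v_id) / binsize
    return ((1 - u_d) * (1 - v_d) * grid[v_id][u_id]        # w00
            + u_d * (1 - v_d) * grid[v_id][u_id + 1]        # w10
            + (1 - u_d) * v_d * grid[v_id + 1][u_id]        # w01
            + u_d * v_d * grid[v_id + 1][u_id + 1])         # w11

# On a grid storing f(u, v) = u, interpolation reproduces u exactly.
dim = 17
ramp_grid = [[x / (dim - 1) for x in range(dim)] for _ in range(dim)]
print(bilinear_lookup(ramp_grid, 0.3, 0.7))  # ~0.3 up to float rounding
```

The kernel repeats this lookup for each of the xy/xr/yr (or xg/yg, xb/yb) planes, then mixes the results with grid_weights.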

Next comes the 2D LUT color interpolation. Looking at the output (again taking r as an example), it is assembled from the input (here the grid-slicing result above) and the rg, rb, and gb planes plus a bias, again matching the formula. The corresponding code is:

scalar_t output_rg = w00_rg * lut[id00_rg + lut_shift * 3 * i] + w10_rg * lut[id10_rg + lut_shift * 3 * i] +
                     w01_rg * lut[id01_rg + lut_shift * 3 * i] + w11_rg * lut[id11_rg + lut_shift * 3 * i];

scalar_t output_rb = w00_rb * lut[id00_rb + lut_shift * (3 * i + 1)] + w10_rb * lut[id10_rb + lut_shift * (3 * i + 1)] +
                     w01_rb * lut[id01_rb + lut_shift * (3 * i + 1)] + w11_rb * lut[id11_rb + lut_shift * (3 * i + 1)];

scalar_t output_gb = w00_gb * lut[id00_gb + lut_shift * (3 * i + 2)] + w10_gb * lut[id10_gb + lut_shift * (3 * i + 2)] +
                     w01_gb * lut[id01_gb + lut_shift * (3 * i + 2)] + w11_gb * lut[id11_gb + lut_shift * (3 * i + 2)];

output[index + width * height * i] = int_img[i] + lut_weights[3 * i] * output_rg + lut_weights[3 * i + 1] * output_rb + lut_weights[3 * i + 2] * output_gb + lut_bias[i];
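The final line is simply the sliced grid value plus a weighted sum of the three plane lookups plus a bias. With made-up numbers for one channel (purely illustrative, not from the repo):

```python
int_img_r = 0.2                           # grid-slicing result for the r channel
out_rg, out_rb, out_gb = 0.1, 0.2, 0.3    # the three 2D-LUT plane lookups
w_rg, w_rb, w_gb = 0.5, 0.3, 0.2          # lut_weights for this channel
bias_r = 0.05                             # lut_bias for this channel

# Mirrors: output = int_img + w_rg*out_rg + w_rb*out_rb + w_gb*out_gb + bias
out_r = int_img_r + w_rg * out_rg + w_rb * out_rb + w_gb * out_gb + bias_r
print(round(out_r, 4))  # 0.42
```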

4. Summary

That wraps up the core parts of the implementation. SVDLUT further optimizes SABLUT with the 3D-to-2D conversion and SVD decomposition, striking a good balance among quality, memory footprint, and inference efficiency, which makes it well suited to deployment on resource-constrained edge devices.


Thanks for reading! Feel free to leave a comment or send a message to discuss.
If this article helped you, please consider following the author.
