本文是对SVDLUT技术的代码讲解,原文解读请看SVDLUT文章讲解。
1、原文概要
SVDLUT最大的创新点在于其将3DLUT的查找过程使用SVD转换为了2DLUT的多次查找求和,减小了参数量和计算量。流程如图所示:

与SABLUT类比,两者在方法上是相似的,都存在双边grid的插值和颜色的LUT插值操作,不过因为查找维度变小了,因此作者额外引入了 $Grid\ weights$ 以及 $LUT\ weights$,用于对多个2DLUT的插值结果进行加权求和,从而完成与3DLUT相似的功能。
2、代码结构
代码整体结构如下:

主要关注模型部分的实现即可。
3、核心代码模块
models.py 文件
整体如下所示:
class SVDLUT(nn.Module):
    """Top-level model: backbone features -> 2D LUTs / 2D bilateral grids -> fused slicing.

    A backbone turns the input image into a feature vector. Four generator
    heads then predict, from that vector, the 2D LUTs, the LUT mixing
    weights/bias, the 2D bilateral grids and the grid mixing weights/bias.
    All of them are handed to ``bilinear_2Dslicing_lut_transform`` together
    with the input image, and the result is passed through a ReLU.

    Args:
        backbone_type: ``'resnet'`` (case-insensitive) selects ``resnet18_224``
            with 512 features; any other value selects ``Backbone`` with
            ``32 * backbone_coef`` features.
        backbone_coef: width multiplier forwarded to ``Backbone``.
        lut_n_vertices, lut_n_ranks, lut_n_singular: forwarded to
            ``Gen_2D_SVD_LUT`` as ``n_vertices``, ``n_ranks``, ``n_singlar``.
        lut_weight_ranks: ``n_ranks`` of ``Gen_2D_LUT_weight_bias``.
        grid_n_vertices, grid_n_ranks, ch_per_grid: forwarded to
            ``Gen_2D_bilateral_grids``.
        grid_weight_ranks: ``n_ranks`` of ``Gen_2D_bilateral_grids_weight_bias``.
        grid_n_singular: accepted but not used anywhere in this class.
    """

    def __init__(self, backbone_type='cnn', backbone_coef=8,
                 lut_n_vertices=17, lut_n_ranks=24,
                 grid_n_vertices=17, grid_n_ranks=24, ch_per_grid=2,
                 lut_weight_ranks=8, grid_weight_ranks=8,
                 lut_n_singular=8, grid_n_singular=8):
        super().__init__()
        self.backbone_type = backbone_type.lower()

        # Pick the feature extractor; n_feats is the length of its output vector.
        if self.backbone_type == 'resnet':
            self.backbone = resnet18_224()
            print('Resnet backbone apply')
            n_feats = 512
        else:
            self.backbone = Backbone(backbone_coef=backbone_coef)
            print('CNN backbone apply')
            n_feats = 32 * backbone_coef

        # Colour branch: 2D LUTs plus their mixing weights/bias.
        self.gen_2d_lut = Gen_2D_SVD_LUT(
            n_vertices=lut_n_vertices,
            n_feats=n_feats,
            n_ranks=lut_n_ranks,
            n_singlar=lut_n_singular,
        )
        self.gen_2d_lut_weight_bias = Gen_2D_LUT_weight_bias(
            n_vertices=lut_n_vertices,
            n_feats=n_feats,
            n_ranks=lut_weight_ranks,
        )

        # Spatial branch: 2D bilateral grids plus their mixing weights/bias.
        self.gen_2d_bilateral = Gen_2D_bilateral_grids(
            n_vertices=grid_n_vertices,
            n_feats=n_feats,
            n_ranks=grid_n_ranks,
            ch_per_grid=ch_per_grid,
        )
        self.gen_2d_grid_weight_bias = Gen_2D_bilateral_grids_weight_bias(
            n_vertices=grid_n_vertices,
            n_feats=n_feats,
            n_ranks=grid_weight_ranks,
            ch_per_grid=ch_per_grid,
        )

        self.slicing_transform = bilinear_2Dslicing_lut_transform
        self.relu = nn.ReLU()

    def init_weights(self):
        """Initialise the CNN backbone (if used) and every generator head."""

        def init_backbone_layer(module):
            # Dispatch on the class name so that every Conv*/InstanceNorm* variant matches.
            kind = type(module).__name__
            if 'Conv' in kind:
                nn.init.xavier_normal_(module.weight.data)
            elif 'InstanceNorm' in kind:
                nn.init.normal_(module.weight.data, 1.0, 0.02)
                nn.init.constant_(module.bias.data, 0.0)

        # The ResNet backbone is left untouched by this initialisation.
        if self.backbone_type != 'resnet':
            self.backbone.apply(init_backbone_layer)

        for head in (self.gen_2d_lut,
                     self.gen_2d_lut_weight_bias,
                     self.gen_2d_bilateral,
                     self.gen_2d_grid_weight_bias):
            head.init_weights()

    def forward(self, img):
        """Enhance ``img``.

        Returns:
            A 5-tuple ``(output, lut_weights, grid_weights, luts, grids)``:
            the ReLU-ed result of the slicing transform, the rank weights
            returned by the LUT and grid generators, and the generated
            LUTs and bilateral grids themselves.
        """
        feats = self.backbone(img)

        luts, lut_weights = self.gen_2d_lut(feats)
        lut_mix_w, lut_mix_b = self.gen_2d_lut_weight_bias(feats)

        grids, grid_weights = self.gen_2d_bilateral(feats)
        grid_mix_w, grid_mix_b = self.gen_2d_grid_weight_bias(feats)

        sliced = self.slicing_transform(
            grids, img, grid_mix_w, grid_mix_b, luts, lut_mix_w, lut_mix_b)
        return self.relu(sliced), lut_weights, grid_weights, luts, grids
整体跟我们在文章讲解中一样,先提取特征,然后生成2DLUT和grid以及它们对应的weights和bias用于加权,经过一个统一的slice插值模块得到最终结果,整个过程包含以下几类。
1. Backbone类
这里用于提取图像特征,没有什么特别的。
class Backbone(nn.Module):
    """Small CNN that maps an image to a flat feature vector of length ``32 * backbone_coef``.

    The input is first resized to 256x256, so the output length does not
    depend on the input resolution.
    """

    def __init__(self, backbone_coef=8):
        super().__init__()
        self.backbone_coef = backbone_coef
        c = backbone_coef

        # Shape annotations below are the original author's, for backbone_coef=8.
        stages = [
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, c, 3, stride=2, padding=1),                      # 8 x 128 x 128
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(c, affine=True),
            *discriminator_block(c, 2 * c, normalization=True),          # 16 x 64 x 64
            *discriminator_block(2 * c, 4 * c, normalization=True),      # 32 x 32 x 32
            *discriminator_block(4 * c, 8 * c, normalization=True),      # 64 x 16 x 16
            *discriminator_block(8 * c, 8 * c),                          # 64 x 8 x 8
            nn.Dropout(p=0.5),
            nn.AvgPool2d(5, stride=2),                                   # 64 x 2 x 2
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, img_input):
        """Return the pooled feature map flattened to ``(-1, 32 * backbone_coef)``."""
        pooled = self.model(img_input)
        return pooled.view(-1, self.backbone_coef * 32)
2. Gen_2D_SVD_LUT和Gen_2D_LUT_weight_bias类
用于生成多个2DLUT(由SVD分量重建)以及对它们加权求和所需的weights和bias,组合起来完成与3DLUT相似的功能。
class Gen_2D_SVD_LUT(nn.Module):
    """Predict a bank of 2D LUTs in truncated-SVD form and rebuild them.

    For every 2D LUT the network outputs ``n_vertices * n_singlar`` values
    for U, ``n_singlar`` singular values and ``n_singlar * n_vertices``
    values for V^T (``n_svd`` numbers in total). ``forward`` multiplies the
    three factors back into an ``n_vertices x n_vertices`` table.

    Args:
        n_colors: number of output colour channels.
        ch_per_lut: number of 2D LUTs per output channel.
        n_lut_dim: kept for interface compatibility; not used.
        n_vertices: side length of each 2D LUT.
        n_feats: length of the incoming feature vector.
        n_ranks: number of basis weights predicted from the features.
        n_singlar: number of singular components kept per LUT
            (parameter name kept as-is because callers pass it by keyword).
    """

    def __init__(self, n_colors=3, ch_per_lut=3, n_lut_dim=2, n_vertices=17,
                 n_feats=256, n_ranks=24, n_singlar=8):
        super(Gen_2D_SVD_LUT, self).__init__()
        # h0: features -> rank weights
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # Per-LUT parameter count: U (V x S) + singular values (S) + V^T (S x V).
        self.n_svd = n_vertices * n_singlar + n_singlar + n_singlar * n_vertices
        # h1: rank weights -> SVD factors of every LUT
        self.basis_luts_bank = nn.Linear(
            n_ranks, n_colors * ch_per_lut * self.n_svd)
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_lut = ch_per_lut
        self.n_singlar = n_singlar

    def init_weights(self):
        r"""Initialise so that a zero feature vector yields the reference LUT bank.

        The rank-weight bias is set to one and the first basis is set to the
        SVD factors of nine hand-written 2D ramps; all other bases are zero.
        The nine ramps are hard-coded, so this assumes
        ``n_colors == ch_per_lut == 3``. The original notes that the scheme
        follows [3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).
        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_luts_bank.bias)

        n = self.n_vertices
        k = self.n_singlar
        axis = torch.arange(n)
        # 'ij' is what the implicit default did; passing it explicitly
        # avoids the deprecation warning of torch.meshgrid.
        row_idx, col_idx = torch.meshgrid(axis, axis, indexing='ij')
        rows = row_idx.div(n - 1)   # rows[i, j] = i / (n - 1)
        cols = col_idx.div(n - 1)   # cols[i, j] = j / (n - 1)
        zero2d = torch.zeros(n, n)
        d = torch.stack([cols, cols, zero2d,
                         rows, zero2d, cols,
                         zero2d, rows, rows], dim=0)

        # torch.linalg.svd replaces the deprecated torch.svd; it returns V^H
        # directly, which is exactly the (k x n) factor stored per LUT.
        u, s, vh = torch.linalg.svd(d, full_matrices=False)
        n_luts = 3 * self.ch_per_lut
        u = u[:, :, :k].contiguous().view(n_luts, -1)
        s = s[:, :k]
        vh = vh[:, :k, :].contiguous().view(n_luts, -1)
        factors = torch.cat([u, s, vh], dim=1)

        # Only basis 0 carries the reference factors; the rest start at zero.
        identity_lut = torch.zeros(self.n_ranks, n_luts * self.n_svd)
        identity_lut[0] = factors.reshape(-1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())

    def forward(self, img_feature):
        """Return ``(luts, weights)``.

        ``luts`` has shape ``(-1, n_colors, ch_per_lut, n_vertices, n_vertices)``
        and ``weights`` is the raw output of the rank-weight layer.
        """
        weights = self.weights_generator(img_feature)
        lut_svd = self.basis_luts_bank(weights).view(-1, self.n_svd)

        # Split each row into its U | S | V^T segments.
        u_end = self.n_vertices * self.n_singlar
        s_end = u_end + self.n_singlar
        lut_u = lut_svd[:, :u_end].view(-1, self.n_vertices, self.n_singlar)
        lut_s = torch.diag_embed(lut_svd[:, u_end:s_end])
        lut_v = lut_svd[:, s_end:].view(-1, self.n_singlar, self.n_vertices)

        luts = torch.bmm(torch.bmm(lut_u, lut_s), lut_v)
        luts = luts.view(-1, self.n_colors, self.ch_per_lut,
                         self.n_vertices, self.n_vertices)
        return luts, weights
利用SVD分量(U、S、V)还原多个2DLUT,每个2DLUT的大小为n_vertices×n_vertices;缺少的一个维度转换为ch_per_lut(每个输出颜色通道对应ch_per_lut个2DLUT),否则无法得到多个2DLUT的组合。
class Gen_2D_LUT_weight_bias(nn.Module):
    """Predict, per output colour, ``ch_per_lut`` mixing weights and one bias.

    Args:
        n_colors: number of output colour channels.
        ch_per_lut: number of 2D LUT outputs mixed per colour channel.
        n_vertices: stored only; not used by this module's computation.
        n_feats: length of the incoming feature vector.
        n_ranks: number of basis weights predicted from the features.
    """

    def __init__(self, n_colors=3, ch_per_lut=3, n_vertices=17, n_feats=256, n_ranks=24):
        super(Gen_2D_LUT_weight_bias, self).__init__()
        # h0: features -> rank weights
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1: rank weights -> (ch_per_lut weights + 1 bias) for each colour
        self.basis_luts_bank = nn.Linear(
            n_ranks, n_colors * (ch_per_lut + 1))
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_lut = ch_per_lut

    def init_weights(self):
        """Set basis 0 to a fixed 3x4 table of weights/bias and zero the rest.

        The table is hard-coded as three rows of ``[w, w, w, bias]``, so this
        method assumes ``n_colors == 3`` and ``ch_per_lut == 3``.
        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_luts_bank.bias)
        d = torch.tensor([[0.5, 0.5, 0, 0],
                          [0.5, 0, 0.5, 0],
                          [0, 0.5, 0.5, 0]])
        empty = [torch.zeros(self.n_colors, self.ch_per_lut + 1)
                 for _ in range(self.n_ranks - 1)]
        identity_lut = torch.stack([d, *empty], dim=0).view(self.n_ranks, -1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())

    def forward(self, img_feature):
        """Return ``(lut_param_weights, lut_param_bias)``.

        Shapes are ``(-1, n_colors, ch_per_lut)`` and ``(-1, n_colors, 1)``.
        """
        weights = self.weights_generator(img_feature)
        weights_bias = self.basis_luts_bank(weights)
        weights_bias = weights_bias.view(-1, self.n_colors, self.ch_per_lut + 1)
        # The last axis holds ch_per_lut weights followed by one bias, so the
        # split point is ch_per_lut. The original split at n_colors, which is
        # only equivalent when the two values happen to be equal.
        lut_param_weights = weights_bias[:, :, :self.ch_per_lut]
        lut_param_bias = weights_bias[:, :, self.ch_per_lut:]
        return lut_param_weights, lut_param_bias
生成加权系数,跟LUT数目有关,参数量非常小。
3. Gen_2D_bilateral_grids和Gen_2D_bilateral_grids_weight_bias类
用于生成双边网格。
class Gen_2D_bilateral_grids(nn.Module):
    """Predict a bank of 2D bilateral grids directly with a linear layer.

    The output holds ``ch_per_grid * 3`` grids, each made of three
    ``n_vertices x n_vertices`` planes.

    Args:
        n_grid_dim: exponent used to size the linear layer
            (``n_vertices ** n_grid_dim`` values per plane). ``forward``
            reshapes to two spatial axes, so only the default of 2 is usable.
        n_vertices: side length of each plane.
        n_feats: length of the incoming feature vector.
        n_ranks: number of basis weights predicted from the features.
        ch_per_grid: number of grids per colour channel.
    """

    def __init__(self, n_grid_dim=2, n_vertices=17, n_feats=256, n_ranks=24, ch_per_grid=2):
        super(Gen_2D_bilateral_grids, self).__init__()
        # h0: features -> rank weights
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1: rank weights -> every plane of every grid
        self.basis_grids_bank = nn.Linear(
            n_ranks, ch_per_grid * 3 * 3 * (n_vertices ** n_grid_dim))
        self.n_grid_dim = n_grid_dim
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_grid = ch_per_grid
        self.n_grids = ch_per_grid * 3

    def init_weights(self):
        r"""Initialise so that a zero feature vector yields the reference grids.

        The rank-weight bias is set to one; basis 0 is filled with the plane
        triple ``[zeros, rows, rows]`` repeated for every grid, and all other
        bases are zero. The original notes that the scheme follows
        [3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).
        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_grids_bank.bias)

        n = self.n_vertices
        axis = torch.arange(n)
        # Explicit 'ij' keeps the behaviour of the implicit default while
        # avoiding torch.meshgrid's deprecation warning. Only the row ramp is
        # needed; the original also built an unused column ramp.
        row_idx, _ = torch.meshgrid(axis, axis, indexing='ij')
        rows = row_idx.div(n - 1)   # rows[i, j] = i / (n - 1)
        zero2d = torch.zeros(n, n)

        d = torch.stack([zero2d, rows, rows] * (3 * self.ch_per_grid), dim=0)
        empty = [torch.zeros(self.n_grids * 3, n, n)
                 for _ in range(self.n_ranks - 1)]
        identity_grid = torch.stack([d, *empty], dim=0).view(self.n_ranks, -1)
        self.basis_grids_bank.weight.data.copy_(identity_grid.t())

    def forward(self, img_feature):
        """Return ``(grids, weights)``.

        ``grids`` has shape ``(-1, n_grids, 3, n_vertices, n_vertices)`` and
        ``weights`` is the raw output of the rank-weight layer.
        """
        weights = self.weights_generator(img_feature)
        grids = self.basis_grids_bank(weights)
        grids = grids.view(-1, self.n_grids, 3, self.n_vertices, self.n_vertices)
        return grids, weights
class Gen_2D_bilateral_grids_weight_bias(nn.Module):
    """Predict the mixing weights and biases applied to the bilateral-grid samples.

    Args:
        n_colors: number of colour channels.
        ch_per_grid: number of grids per colour channel.
        n_vertices: stored only; not used by this module's computation.
        n_feats: length of the incoming feature vector.
        n_ranks: number of basis weights predicted from the features.
    """

    def __init__(self, n_colors=3, ch_per_grid=2, n_vertices=17, n_feats=256, n_ranks=24):
        super().__init__()
        # h0: features -> rank weights
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1: rank weights -> (3 * n_colors weights + n_colors biases) per grid slot
        out_features = ch_per_grid * (3 * n_colors + n_colors)
        self.basis_luts_bank = nn.Linear(n_ranks, out_features)
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_grid = ch_per_grid

    def init_weights(self):
        r"""Fill basis 0 with rows of ``[0, 1, 1, 0] / (2 * ch_per_grid)``; zero the rest.

        The original notes that the scheme follows
        [3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).

        NOTE(review): the table written here has ``3 * ch_per_grid`` rows of
        4 values, whereas ``forward`` reads the same numbers as
        ``ch_per_grid`` rows of 9 weights followed by 3 biases. The two
        layouts do not line up; confirm which one is intended.
        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_luts_bank.bias)

        n_rows = 3 * self.ch_per_grid
        first_basis = torch.tensor([[0, 1, 1, 0]] * n_rows).div(self.ch_per_grid * 2)

        identity_lut = torch.zeros(self.n_ranks, n_rows * (self.n_colors + 1))
        identity_lut[0] = first_basis.reshape(-1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())

    def forward(self, img_feature):
        """Return ``(grid_param_weights, grid_param_bias)``.

        Shapes are ``(-1, ch_per_grid, 3 * n_colors)`` and
        ``(-1, ch_per_grid, n_colors)``.
        """
        n_weight = 3 * self.n_colors
        rank_weights = self.weights_generator(img_feature)
        packed = self.basis_luts_bank(rank_weights)
        packed = packed.view(-1, self.ch_per_grid, n_weight + self.n_colors)
        grid_param_weights = packed[:, :, :n_weight]
        grid_param_bias = packed[:, :, n_weight:]
        return grid_param_weights, grid_param_bias
grid的生成不适合使用SVD分解,因为性能会下降,因此作者这里是直接用linear层给出的。
4. bilinear_2Dslicing_lut_transform函数
在kernel_code/bilateral_slicing/src/trilinear2D_slice_LUTTransform_cpu.cpp中可以看到源码。
// Fused CPU forward pass: 2D bilateral-grid slicing followed by 2D LUT lookup.
//
// For every pixel index in [0, nthreads) the function
//   1. bilinearly samples grid planes addressed by (x,y), (x,colour) and
//      (y,colour), mixes them with grid_weights / grid_bias and accumulates
//      the result per colour channel in int_img;
//   2. bilinearly samples three LUT planes addressed by (r,g), (r,b) and
//      (g,b) for each output channel, mixes them with lut_weights / lut_bias
//      and writes int_img + that mix to output.
//
// grid / lut planes are addressed as base + grid_shift * plane and
// base + lut_shift * plane respectively. image and output are read/written
// in planar layout (channel stride = width * height).
//
// scalar_t is not declared in the lines shown here; presumably a template
// parameter declared just above this function. TODO confirm.
// NOTE(review): no per-sample offset is applied to any pointer here;
// presumably the caller invokes this once per image. Verify against caller.
void TriLinearCPU2DSliceAndLUTTransformForward(const int nthreads,
const scalar_t *grid,
const scalar_t *image,
const scalar_t *grid_weights,
const scalar_t *grid_bias,
const scalar_t *lut,
const scalar_t *lut_weights,
const scalar_t *lut_bias,
scalar_t *output,
const int grid_dim,
const int grid_shift,
const scalar_t grid_binsize,
const int lut_dim,
const int lut_shift,
const scalar_t lut_binsize,
const int width,
const int height,
const int num_channels,
const int grid_per_ch)
{
for (int index = 0; index < nthreads; index++)
{
// Pixel column / row recovered from the flat index.
const int x_ = index % width;
const int y_ = index / width;
// NOTE(review): x_ and (width - 1) are both int, so the division below is
// carried out in integer arithmetic before the conversion to scalar_t
// (same for y). It also divides by zero when width == 1 or height == 1.
// Confirm against the GPU kernel whether a floating-point division was
// intended before changing it.
const scalar_t x = x_ / (width - 1);
const scalar_t y = y_ / (height - 1);
// Planar read of the three colour channels of this pixel.
const scalar_t r = image[index];
const scalar_t g = image[index + width * height];
const scalar_t b = image[index + width * height * 2];
// Lower cell index along each axis of the bilateral grid, clamped so that
// id + 1 is still a valid vertex.
const int32_t x_id = clamp((int32_t)floor(x * (grid_dim - 1)), 0, grid_dim - 2);
const int32_t y_id = clamp((int32_t)floor(y * (grid_dim - 1)), 0, grid_dim - 2);
int32_t r_id = clamp((int32_t)floor(r * (grid_dim - 1)), 0, grid_dim - 2);
int32_t g_id = clamp((int32_t)floor(g * (grid_dim - 1)), 0, grid_dim - 2);
int32_t b_id = clamp((int32_t)floor(b * (grid_dim - 1)), 0, grid_dim - 2);
// Fractional offset inside the cell, in units of one grid bin.
const scalar_t x_d = (x - grid_binsize * x_id) / grid_binsize;
const scalar_t y_d = (y - grid_binsize * y_id) / grid_binsize;
scalar_t r_d = (r - grid_binsize * r_id) / grid_binsize;
scalar_t g_d = (g - grid_binsize * g_id) / grid_binsize;
scalar_t b_d = (b - grid_binsize * b_id) / grid_binsize;
// Flat offsets of the four cell corners inside one grid plane, for each
// axis pair. Naming: idAB_pq, where A/B = 0/1 selects the lower/upper
// vertex along p/q; the first axis is the fastest-varying one.
const int id00_xy = (x_id) + (y_id)*grid_dim;
const int id10_xy = (x_id + 1) + (y_id)*grid_dim;
const int id01_xy = (x_id) + (y_id + 1) * grid_dim;
const int id11_xy = (x_id + 1) + (y_id + 1) * grid_dim;
const int id00_xr = (x_id) + (r_id)*grid_dim;
const int id10_xr = (x_id + 1) + (r_id)*grid_dim;
const int id01_xr = (x_id) + (r_id + 1) * grid_dim;
const int id11_xr = (x_id + 1) + (r_id + 1) * grid_dim;
const int id00_yr = (y_id) + (r_id)*grid_dim;
const int id10_yr = (y_id + 1) + (r_id)*grid_dim;
const int id01_yr = (y_id) + (r_id + 1) * grid_dim;
const int id11_yr = (y_id + 1) + (r_id + 1) * grid_dim;
const int id00_xg = (x_id) + (g_id)*grid_dim;
const int id10_xg = (x_id + 1) + (g_id)*grid_dim;
const int id01_xg = (x_id) + (g_id + 1) * grid_dim;
const int id11_xg = (x_id + 1) + (g_id + 1) * grid_dim;
const int id00_yg = (y_id) + (g_id)*grid_dim;
const int id10_yg = (y_id + 1) + (g_id)*grid_dim;
const int id01_yg = (y_id) + (g_id + 1) * grid_dim;
const int id11_yg = (y_id + 1) + (g_id + 1) * grid_dim;
const int id00_xb = (x_id) + (b_id)*grid_dim;
const int id10_xb = (x_id + 1) + (b_id)*grid_dim;
const int id01_xb = (x_id) + (b_id + 1) * grid_dim;
const int id11_xb = (x_id + 1) + (b_id + 1) * grid_dim;
const int id00_yb = (y_id) + (b_id)*grid_dim;
const int id10_yb = (y_id + 1) + (b_id)*grid_dim;
const int id01_yb = (y_id) + (b_id + 1) * grid_dim;
const int id11_yb = (y_id + 1) + (b_id + 1) * grid_dim;
// Bilinear interpolation weights matching the corner offsets above.
const scalar_t w00_xy = (1 - x_d) * (1 - y_d);
const scalar_t w10_xy = (x_d) * (1 - y_d);
const scalar_t w01_xy = (1 - x_d) * (y_d);
const scalar_t w11_xy = (x_d) * (y_d);
const scalar_t w00_xr = (1 - x_d) * (1 - r_d);
const scalar_t w10_xr = (x_d) * (1 - r_d);
const scalar_t w01_xr = (1 - x_d) * (r_d);
const scalar_t w11_xr = (x_d) * (r_d);
const scalar_t w00_yr = (1 - y_d) * (1 - r_d);
const scalar_t w10_yr = (y_d) * (1 - r_d);
const scalar_t w01_yr = (1 - y_d) * (r_d);
const scalar_t w11_yr = (y_d) * (r_d);
const scalar_t w00_xg = (1 - x_d) * (1 - g_d);
const scalar_t w10_xg = (x_d) * (1 - g_d);
const scalar_t w01_xg = (1 - x_d) * (g_d);
const scalar_t w11_xg = (x_d) * (g_d);
const scalar_t w00_yg = (1 - y_d) * (1 - g_d);
const scalar_t w10_yg = (y_d) * (1 - g_d);
const scalar_t w01_yg = (1 - y_d) * (g_d);
const scalar_t w11_yg = (y_d) * (g_d);
const scalar_t w00_xb = (1 - x_d) * (1 - b_d);
const scalar_t w10_xb = (x_d) * (1 - b_d);
const scalar_t w01_xb = (1 - x_d) * (b_d);
const scalar_t w11_xb = (x_d) * (b_d);
const scalar_t w00_yb = (1 - y_d) * (1 - b_d);
const scalar_t w10_yb = (y_d) * (1 - b_d);
const scalar_t w01_yb = (1 - y_d) * (b_d);
const scalar_t w11_yb = (y_d) * (b_d);
// Per-channel accumulator of the grid-slicing result; starts at zero.
scalar_t int_img[3] = {
0,
};
// Accumulate the contribution of each of the grid_per_ch grids.
// For colour channel c and grid i, slot = i + grid_per_ch * c; the slot
// owns three consecutive planes (xy, x-colour, y-colour), three consecutive
// entries of grid_weights and one entry of grid_bias.
for (int i = 0; i < grid_per_ch; ++i)
{
// Channel 0 (r): planes xy, xr, yr.
int_img[0] = int_img[0] + grid_weights[3 * (i + grid_per_ch * 0)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)]) +
grid_weights[3 * (i + grid_per_ch * 0) + 1] * (w00_xr * grid[id00_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
w10_xr * grid[id10_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
w01_xr * grid[id01_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
w11_xr * grid[id11_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)]) +
grid_weights[3 * (i + grid_per_ch * 0) + 2] * (w00_yr * grid[id00_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
w10_yr * grid[id10_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
w01_yr * grid[id01_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
w11_yr * grid[id11_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)]) +
grid_bias[(i + grid_per_ch * 0)];
// Channel 1 (g): planes xy, xg, yg.
int_img[1] = int_img[1] + grid_weights[3 * (i + grid_per_ch * 1)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)]) +
grid_weights[3 * (i + grid_per_ch * 1) + 1] * (w00_xg * grid[id00_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)] +
w10_xg * grid[id10_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)] +
w01_xg * grid[id01_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)] +
w11_xg * grid[id11_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)]) +
grid_weights[3 * (i + grid_per_ch * 1) + 2] * (w00_yg * grid[id00_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)] +
w10_yg * grid[id10_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)] +
w01_yg * grid[id01_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)] +
w11_yg * grid[id11_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)]) +
grid_bias[(i + grid_per_ch * 1)];
// Channel 2 (b): planes xy, xb, yb.
int_img[2] = int_img[2] + grid_weights[3 * (i + grid_per_ch * 2)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)]) +
grid_weights[3 * (i + grid_per_ch * 2) + 1] * (w00_xb * grid[id00_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)] +
w10_xb * grid[id10_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)] +
w01_xb * grid[id01_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)] +
w11_xb * grid[id11_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)]) +
grid_weights[3 * (i + grid_per_ch * 2) + 2] * (w00_yb * grid[id00_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)] +
w10_yb * grid[id10_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)] +
w01_yb * grid[id01_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)] +
w11_yb * grid[id11_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)]) +
grid_bias[(i + grid_per_ch * 2)];
}
// Second stage: recompute the colour cell indices and offsets at the LUT
// resolution (lut_dim / lut_binsize), reusing the r/g/b variables.
r_id = clamp((int32_t)floor(r * (lut_dim - 1)), 0, lut_dim - 2);
g_id = clamp((int32_t)floor(g * (lut_dim - 1)), 0, lut_dim - 2);
b_id = clamp((int32_t)floor(b * (lut_dim - 1)), 0, lut_dim - 2);
r_d = (r - lut_binsize * r_id) / lut_binsize;
g_d = (g - lut_binsize * g_id) / lut_binsize;
b_d = (b - lut_binsize * b_id) / lut_binsize;
// Corner offsets inside one LUT plane for the (r,g), (r,b) and (g,b) pairs.
const int id00_rg = r_id + g_id * lut_dim;
const int id10_rg = r_id + 1 + g_id * lut_dim;
const int id01_rg = r_id + (g_id + 1) * lut_dim;
const int id11_rg = r_id + 1 + (g_id + 1) * lut_dim;
const int id00_rb = r_id + b_id * lut_dim;
const int id10_rb = r_id + 1 + b_id * lut_dim;
const int id01_rb = r_id + (b_id + 1) * lut_dim;
const int id11_rb = r_id + 1 + (b_id + 1) * lut_dim;
const int id00_gb = g_id + b_id * lut_dim;
const int id10_gb = g_id + 1 + b_id * lut_dim;
const int id01_gb = g_id + (b_id + 1) * lut_dim;
const int id11_gb = g_id + 1 + (b_id + 1) * lut_dim;
// Bilinear weights for the three colour pairs.
const scalar_t w00_rg = (1 - r_d) * (1 - g_d);
const scalar_t w10_rg = (r_d) * (1 - g_d);
const scalar_t w01_rg = (1 - r_d) * (g_d);
const scalar_t w11_rg = (r_d) * (g_d);
const scalar_t w00_rb = (1 - r_d) * (1 - b_d);
const scalar_t w10_rb = (r_d) * (1 - b_d);
const scalar_t w01_rb = (1 - r_d) * (b_d);
const scalar_t w11_rb = (r_d) * (b_d);
const scalar_t w00_gb = (1 - g_d) * (1 - b_d);
const scalar_t w10_gb = (g_d) * (1 - b_d);
const scalar_t w01_gb = (1 - g_d) * (b_d);
const scalar_t w11_gb = (g_d) * (b_d);
// Output channel i owns LUT planes 3*i (rg), 3*i+1 (rb), 3*i+2 (gb), the
// matching three lut_weights entries and one lut_bias entry.
// NOTE(review): int_img has 3 elements, so this loop is only in bounds for
// num_channels <= 3.
for (int i = 0; i < num_channels; ++i)
{
scalar_t output_rg = w00_rg * lut[id00_rg + lut_shift * 3 * i] + w10_rg * lut[id10_rg + lut_shift * 3 * i] +
w01_rg * lut[id01_rg + lut_shift * 3 * i] + w11_rg * lut[id11_rg + lut_shift * 3 * i];
scalar_t output_rb = w00_rb * lut[id00_rb + lut_shift * (3 * i + 1)] + w10_rb * lut[id10_rb + lut_shift * (3 * i + 1)] +
w01_rb * lut[id01_rb + lut_shift * (3 * i + 1)] + w11_rb * lut[id11_rb + lut_shift * (3 * i + 1)];
scalar_t output_gb = w00_gb * lut[id00_gb + lut_shift * (3 * i + 2)] + w10_gb * lut[id10_gb + lut_shift * (3 * i + 2)] +
w01_gb * lut[id01_gb + lut_shift * (3 * i + 2)] + w11_gb * lut[id11_gb + lut_shift * (3 * i + 2)];
// Final value = grid-slicing result + weighted LUT samples + bias.
output[index + width * height * i] = int_img[i] + lut_weights[3 * i] * output_rg + lut_weights[3 * i + 1] * output_rb + lut_weights[3 * i + 2] * output_gb + lut_bias[i];
}
}
}
由于它合并了slice和最后lut转换的部分,所以代码比较长,我们先看第一段关于空间信息的融合部分的输出,每一个颜色通道的结果(以r为例)都是在累加项int_img(初始为0)上,对grid_per_ch个grid逐个累加xy、xr、yr三个平面的加权插值结果以及bias得到的,这跟讲解中公式对应,对应代码为:
int_img[0] = int_img[0] + grid_weights[3 * (i + grid_per_ch * 0)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)]) +
grid_weights[3 * (i + grid_per_ch * 0) + 1] * (w00_xr * grid[id00_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
w10_xr * grid[id10_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
w01_xr * grid[id01_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +
w11_xr * grid[id11_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)]) +
grid_weights[3 * (i + grid_per_ch * 0) + 2] * (w00_yr * grid[id00_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
w10_yr * grid[id10_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
w01_yr * grid[id01_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +
w11_yr * grid[id11_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)]) +
grid_bias[(i + grid_per_ch * 0)];
后续完成2DLUT的颜色插值时,我们看输出,同理(以r为例),都需要原图(这里是上面的grid slice输出结果)、rg、rb、gb、bias这几个组合而成,同样的与公式对应,对应代码为:
scalar_t output_rg = w00_rg * lut[id00_rg + lut_shift * 3 * i] + w10_rg * lut[id10_rg + lut_shift * 3 * i] +
w01_rg * lut[id01_rg + lut_shift * 3 * i] + w11_rg * lut[id11_rg + lut_shift * 3 * i];
scalar_t output_rb = w00_rb * lut[id00_rb + lut_shift * (3 * i + 1)] + w10_rb * lut[id10_rb + lut_shift * (3 * i + 1)] +
w01_rb * lut[id01_rb + lut_shift * (3 * i + 1)] + w11_rb * lut[id11_rb + lut_shift * (3 * i + 1)];
scalar_t output_gb = w00_gb * lut[id00_gb + lut_shift * (3 * i + 2)] + w10_gb * lut[id10_gb + lut_shift * (3 * i + 2)] +
w01_gb * lut[id01_gb + lut_shift * (3 * i + 2)] + w11_gb * lut[id11_gb + lut_shift * (3 * i + 2)];
output[index + width * height * i] = int_img[i] + lut_weights[3 * i] * output_rg + lut_weights[3 * i + 1] * output_rb + lut_weights[3 * i + 2] * output_gb + lut_bias[i];
4、总结
代码实现核心的部分讲解完毕,SVDLUT利用3D转2D和SVD分解的思路进一步优化了SABLUT,比较好的兼顾性能、存储与推理效率,特别适合资源受限的边缘设备部署。
感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。
1242

被折叠的 条评论
为什么被折叠?



