This article is a code-level walkthrough of the SPFLUT technique; for an interpretation of the paper itself, see the SPFLUT post.
Overview of the Paper
The key idea of SPFLUT is its diagonal-first compression strategy. The overall pipeline consists of four parts: training, conversion (which includes compression), fine-tuning, and testing. The overall structure of the code is as follows:
As you can see, the pipeline is essentially the same as MuLUT's; the only difference is that an extra compression step is applied to the LUTs before the conversion in step two, namely the file 2_compress_lut_from_net.py. The fine-tuning in step three also contains code dedicated to fine-tuning the compressed LUTs.
1. Training
The implementation of the SPF_LUT_net model can be found in sr/model.py:
class SPF_LUT_net(nn.Module):
def __init__(self, nf=32, scale=4, modes=['s', 'd', 'y'], stages=2):
super(SPF_LUT_net, self).__init__()
self.upscale = scale
self.modes = modes
self.convblock1 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)
self.convblock2 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)
self.convblock3 = ConvBlock(1, 2, scale=None, output_quant=False, modes=modes, nf=nf)
self.convblock4 = ConvBlock(1, 1, scale=None, output_quant=False, modes=modes, nf=nf)
self.ChannelConv = MuLUTcUnit(in_c=4, out_c=4, mode='1x1', nf=nf)
self.upblock = ConvBlock(4, 1, scale=scale, output_quant=False, modes=modes, nf=nf)
def forward(self, x, phase='train'):
B, C, H, W = x.size()
x = x.reshape((B * C, 1, H, W))
refine_list = []
# block1
x = self.convblock1(x)
avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm
refine_list.append(x[:, 0:1, :, :])
x = x[:, 1:, :, :]
# block2
x = self.convblock2(x)
avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm
refine_list.append(x[:, 0:1, :, :])
x = x[:, 1:, :, :]
# block3
x = self.convblock3(x)
avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm
refine_list.append(x[:, 0:1, :, :])
x = x[:, 1:, :, :]
# block4
x = self.convblock4(x)
avg_factor, bias, norm = len(self.modes) * 4, 127, 255.0
x = round_func(torch.clamp((x / avg_factor) + bias, 0, 255)) / norm
refine_list.append(x)
x = torch.cat(refine_list, dim=1)
x = round_func(torch.tanh(self.ChannelConv(x)) * 127.0)
x = round_func(torch.clamp(x + 127, 0, 255)) / 255.0
x = self.upblock(x)
avg_factor, bias, norm = len(self.modes), 0, 1
x = round_func((x / avg_factor) + bias)
if phase == 'train':
x = x / 255.0
x = x.reshape((B, C, self.upscale * H, self.upscale * W))
return x
From the code above we can see that the SPFLUT model is built mainly from two sub-modules, ConvBlock and MuLUTcUnit. The implementation of ConvBlock is as follows:
class ConvBlock(nn.Module):
def __init__(self, in_c, out_c, scale=None, output_quant=False, modes=['s', 'd', 'y'], nf=64):
super(ConvBlock, self).__init__()
self.in_c = in_c
self.out_c = out_c
self.modes = modes
self.module_dict = dict()
self.upscale = scale
self.output_quant = output_quant
scale_factor = 1 if scale is None else scale ** 2
for c in range(in_c):
for mode in modes:
self.module_dict['DepthwiseBlock{}_{}'.format(c, mode)] = MuLUTConv('{}x{}'.format(mode.upper(), 'N'),
nf=nf, out_c=out_c * scale_factor,
stride=1)
self.module_dict = nn.ModuleDict(self.module_dict)
if scale is None:
self.pixel_shuffle = identity
else:
self.pixel_shuffle = nn.PixelShuffle(scale)
def forward(self, x):
modes = self.modes
x_out = 0
for c in range(self.in_c):
x_c = x[:, c:c + 1, :, :]
pred = 0
for mode in modes:
pad = mode_pad_dict[mode]
sub_module = self.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]
for r in [0, 1, 2, 3]:
pred += round_func(torch.tanh(torch.rot90(self.pixel_shuffle(
sub_module(F.pad(torch.rot90(x_c, r, [2, 3]), (0, pad, 0, pad), mode='replicate'))),
(4 - r) % 4, [2, 3])) * 127)
x_out += pred
if self.output_quant:
avg_factor = len(modes) * 4 * self.in_c
x = round_func(torch.clamp(x_out / avg_factor, -1, 1) * 127) / 127
else:
x = x_out / self.in_c
return x
ConvBlock is itself built from MuLUTConv units, located in common/network.py. As covered in the MuLUT code walkthrough, MuLUTConv combines three kernel types (S, D, and Y) into a module with a 3x3 receptive field; rotation and clamping operations are also needed here to keep each layer's output from overflowing.
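The rotation-ensemble part of that forward pass can be summarized as follows. This is a simplified sketch distilled from the ConvBlock.forward loop above; tanh, the *127 rounding and pixel_shuffle are omitted for brevity, and sub_module / pad are placeholders:
import torch
import torch.nn.functional as F

# Simplified view of the rotation ensemble inside ConvBlock.forward:
# rotate the input by r*90 degrees, replicate-pad, run the per-mode sub-module,
# rotate the result back, and accumulate over the four rotations.
def rotation_ensemble(sub_module, x, pad):
    pred = 0
    for r in range(4):
        xr = torch.rot90(x, r, [2, 3])                       # rotate input
        xr = F.pad(xr, (0, pad, 0, pad), mode='replicate')   # pad bottom/right
        yr = sub_module(xr)                                   # per-mode branch
        pred = pred + torch.rot90(yr, (4 - r) % 4, [2, 3])    # rotate back and sum
    return pred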
MuLUTcUnit, also in common/network.py, is the channel-wise MuLUT module. Since it operates only across channels, its kernel size is 1, and its main role is to model the correlations between feature channels.
The overall structure is fairly clear, especially if you are already familiar with the MuLUT sub-modules. As before, readers who are unsure can instantiate a model and step through the tensor shapes to get familiar with it, as in the sketch below.
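A minimal shape check might look like this (a sketch assuming you run it from the repository's sr/ directory so that model.py is importable; the patch size is arbitrary):
import torch
from model import SPF_LUT_net  # i.e. the sr/model.py shown above

net = SPF_LUT_net(nf=32, scale=4, modes=['s', 'd', 'y'], stages=2)
lr = torch.rand(1, 1, 48, 48)        # dummy single-channel 48x48 patch in [0, 1]
with torch.no_grad():
    sr = net(lr, phase='train')
print(sr.shape)                      # torch.Size([1, 1, 192, 192]) for scale=4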
2. Compression and LUT Conversion
This part of the code lives in 2_compress_lut_from_net.py; the overall flow is as follows:
def compress_SPFLUT(opt):
def save_SPFLUT_DFC(x, lut_path, module):
# Split the input into batches to avoid exceeding GPU memory
B = x.size(0) // 100
outputs = []
# Extract input-output pairs
with torch.no_grad():
model_G.eval()
for b in range(100):
if b == 99:
batch_input = x[b * B:]
else:
batch_input = x[b * B:(b + 1) * B]
batch_output = module(batch_input)
results = torch.round(torch.tanh(batch_output) * 127).cpu().data.numpy().astype(np.int8)
outputs += [results]
results = np.concatenate(outputs, 0)
results = results.reshape(x.size(0), -1)
np.save(lut_path, results)
print("Resulting LUT size: ", results.shape, "Saved to", lut_path)
modes = [i for i in opt.modes]
stages = opt.stages
model = getattr(Model, 'SPF_LUT_net')
model_G = model(nf=opt.nf, scale=opt.scale, modes=modes, stages=stages).cuda()
lm = torch.load(os.path.join(opt.expDir, 'Model_{:06d}.pth'.format(opt.loadIter)))
model_G.load_state_dict(lm, strict=True)
input_tensor = get_input_tensor(opt)
for mode in modes:
if opt.cd == 'xyzt':
input_tensor_c1 = compress_lut_xyzt(opt, input_tensor)
elif opt.cd == 'xyz':
input_tensor_c1 = compress_lut_xyz(opt, input_tensor)
elif opt.cd == 'xy':
input_tensor_c1 = compress_lut(opt, input_tensor)
else:
raise ValueError
input_tensor_c2 = compress_lut_larger_interval(opt, input_tensor)
if mode != 's':
input_tensor_c1 = get_mode_input_tensor(input_tensor_c1, mode)
input_tensor_c2 = get_mode_input_tensor(input_tensor_c2, mode)
# conv1
module = model_G.convblock1.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 1, mode))
save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 1, mode))
save_SPFLUT_DFC(input_tensor_c2, lut_path, module)
# conv2
module = model_G.convblock2.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 2, mode))
save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 2, mode))
save_SPFLUT_DFC(input_tensor_c2, lut_path, module)
# conv3
module = model_G.convblock3.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 3, mode))
save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 3, mode))
save_SPFLUT_DFC(input_tensor_c2, lut_path, module)
# conv4
module = model_G.convblock4.module_dict['DepthwiseBlock{}_{}'.format(0, mode)]
lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress1.npy'.format(opt.lutName, 4, mode))
save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
lut_path = os.path.join(opt.expDir, '{}_s{}c0_{}_compress2.npy'.format(opt.lutName, 4, mode))
save_SPFLUT_DFC(input_tensor_c2, lut_path, module)
# conv6
for c in range(4):
module = model_G.upblock.module_dict['DepthwiseBlock{}_{}'.format(c, mode)]
lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress1.npy'.format(opt.lutName, 6,c, mode))
save_SPFLUT_DFC(input_tensor_c1, lut_path, module)
lut_path = os.path.join(opt.expDir, '{}_s{}c{}_{}_compress2.npy'.format(opt.lutName, 6,c, mode))
save_SPFLUT_DFC(input_tensor_c2, lut_path, module)
# conv5
input_tensor = input_tensor.reshape((-1,4,1,1))
module = model_G.ChannelConv
lut_path = os.path.join(opt.expDir, '{}_s{}_channel.npy'.format(opt.lutName, 5))
save_SPFLUT_DFC(input_tensor, lut_path, module)
The details worth paying attention to here are the three diagonal-compression functions compress_lut_xyzt, compress_lut_xyz, and compress_lut, corresponding to the 4D, 3D, and 2D compression processes, and the non-diagonal function compress_lut_larger_interval. Finally, note that the channel convolution conv5 is not compressed at all: a channel-wise convolution does not satisfy the diagonal prior, so diagonal-first compression cannot be applied to it.
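To summarize what the loop above writes out: every stage/channel/mode combination gets a fine diagonal LUT (suffix compress1) and a coarse larger-interval LUT (suffix compress2), plus one uncompressed channel LUT. A small sketch enumerating the expected file names ('LUT' stands in for whatever opt.lutName is set to; the ref2index_*.npy files are saved separately by the compress functions):
# Enumerate the LUT files saved by save_SPFLUT_DFC in compress_SPFLUT.
lutName, modes = 'LUT', ['s', 'd', 'y']
files = []
for mode in modes:
    for stage in [1, 2, 3, 4]:                     # convblock1-4, channel 0 only
        for tag in ['compress1', 'compress2']:
            files.append('{}_s{}c0_{}_{}.npy'.format(lutName, stage, mode, tag))
    for c in range(4):                             # upblock (conv6), channels 0-3
        for tag in ['compress1', 'compress2']:
            files.append('{}_s6c{}_{}_{}.npy'.format(lutName, c, mode, tag))
files.append('{}_s5_channel.npy'.format(lutName))  # ChannelConv (conv5), uncompressed
print(len(files))                                  # 3 * (4*2 + 4*2) + 1 = 49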
For the diagonal-related functions, take the 2D compression as an example; it matches what we covered in the paper walkthrough.
def compress_lut(opt, input_tensor):
base = torch.arange(0, 257, 2 ** opt.interval) # 0-256
base[-1] -= 1
L = base.size(0)
d = opt.dw
diag = 2 * d + 1
N = diag * L + (1 - diag ** 2) // 4
input_tensor = input_tensor.reshape(L * L, L, L, 1, 2, 2)
index_i = torch.zeros((N,)).type(torch.int64)
index_j = torch.zeros((N,)).type(torch.int64)
cnt = 0
ref2index = np.zeros((L, diag), dtype=np.int_) - 1
for i in range(L):
for j in range(L):
if abs(i - j) <= d:
index_i[cnt] = i
index_j[cnt] = j
ref2index[i, j - i] = cnt
cnt += 1
np.save(os.path.join(opt.expDir, 'ref2index_{}{}i{}.npy'.format(opt.cd, opt.dw, opt.si)),ref2index)
index_compress = index_i * L + index_j
compressed_input_tensor = input_tensor[index_compress, ...].reshape(-1, 1, 2, 2)
return compressed_input_tensor
The author implements this by reshaping and re-indexing the input_tensor: we need to gather every position of the 2D index grid that satisfies the diagonal-distance condition. Here opt.dw (the variable d) is the distance threshold discussed earlier; every (i, j) pair with |i - j| <= d is recorded in ref2index and cnt is incremented, so the diagonal positions are preserved.
As for L, it is the sampling-interval-related count we have been using all along, normally 17 (4-bit sampling). N is the total number of indices we derived earlier (you can plug diag into the expression to see that it matches the formula exactly). With that, the 2D input tensor is fully constructed and can be fed to the model, with the diagonal positions stored first.
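As a quick sanity check of the formula for N (a standalone snippet, not repository code; d = 2 is just an example value for opt.dw):
# Compare the closed-form N with a brute-force count of index pairs
# satisfying the diagonal condition |i - j| <= d, for L = 17 (4-bit sampling).
L, d = 17, 2
diag = 2 * d + 1
N = diag * L + (1 - diag ** 2) // 4
brute = sum(1 for i in range(L) for j in range(L) if abs(i - j) <= d)
print(N, brute)   # 79 79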
For the non-diagonal positions, look at compress_lut_larger_interval, implemented as follows.
def compress_lut_larger_interval(opt, input_tensor):
base = torch.arange(0, 257, 2 ** opt.interval) # 0-256
base[-1] -= 1
L = base.size(0)
input_tensor = input_tensor.reshape(L, L, L, L, 1, 2, 2)
if opt.si==5:
k = 2
elif opt.si==6:
k = 4
elif opt.si==7:
k = 8
else:
raise ValueError
compressed_input_tensor = input_tensor[::k, ::k, ::k, ::k, ...].reshape(-1, 1, 2, 2)
return compressed_input_tensor
This one is simple: it just samples at a larger stride. Since the base sampling already uses a 4-bit interval, when opt.si is 5 we only need to sub-sample the current input_tensor with a stride of 2; the remaining cases (opt.si of 6 or 7) follow in the same way, as in the example below.
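For instance (a standalone sketch, not repository code), with the 4-bit base sampling each axis has 17 entries, and a stride-2 sub-sampling keeps 9 of them:
import torch

# Effect of the larger-interval sampling on the non-diagonal LUT size.
base = torch.arange(0, 257, 2 ** 4)
base[-1] -= 1
L = base.size(0)                 # 17 entries per axis at 4-bit sampling
Lk = base[::2].size(0)           # 9 entries per axis when opt.si = 5 (stride k = 2)
print(L ** 4, Lk ** 4)           # 83521 vs 6561 table entries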
As for the channel LUT: as mentioned above, it cannot be compressed, so its input_tensor stays unchanged, exactly as before. The construction is implemented as follows and should be familiar if you have been following the LUT series of posts (if not, you may want to start with the LUT topic first):
def get_input_tensor(opt):
# 1D input
base = torch.arange(0, 257, 2 ** opt.interval) # 0-256
base[-1] -= 1
L = base.size(0)
# 2D input
# 256*256 0 0 0... |1 1 1... |...|255 255 255...
first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1)
# 256*256 0 1 2 .. 255|0 1 2 ... 255|...|0 1 2 ... 255
second = base.cuda().repeat(L)
onebytwo = torch.stack([first, second], 1) # [256*256, 2]
# 3D input
# 256*256*256 0 x65536|1 x65536|...|255 x65536
third = base.cuda().unsqueeze(1).repeat(1, L * L).reshape(-1)
onebytwo = onebytwo.repeat(L, 1)
onebythree = torch.cat(
[third.unsqueeze(1), onebytwo], 1) # [256*256*256, 3]
# 4D input
fourth = base.cuda().unsqueeze(1).repeat(1, L * L * L).reshape(
-1) # 256*256*256*256 0 x16777216|1 x16777216|...|255 x16777216
onebythree = onebythree.repeat(L, 1)
# [256*256*256*256, 4]
onebyfourth = torch.cat([fourth.unsqueeze(1), onebythree], 1)
# Rearange input: [N, 4] -> [N, C=1, H=2, W=2]
input_tensor = onebyfourth.unsqueeze(1).unsqueeze(
1).reshape(-1, 1, 2, 2).float() / 255.0
return input_tensor
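A rough usage check (a sketch; get_input_tensor uses .cuda() above, so this needs a GPU, and opt only has to expose interval here):
from types import SimpleNamespace

# With interval = 4 the enumeration covers 17**4 = 83521 input patterns,
# each arranged as a 1x2x2 patch; for the channel LUT (conv5) the same tensor
# is re-viewed as 4-channel 1x1 inputs, as done in compress_SPFLUT.
opt = SimpleNamespace(interval=4)
x = get_input_tensor(opt)
print(x.shape)                        # torch.Size([83521, 1, 2, 2])
print(x.reshape(-1, 4, 1, 1).shape)   # torch.Size([83521, 4, 1, 1])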
3. Fine-tuning
Compared with MuLUT, the fine-tuning part has no significant changes; the main thing to look at is how the author builds the SPF_LUT model, located in sr/model.py:
class SPF_LUT(nn.Module):
""" PyTorch version of MuLUT for LUT-aware fine-tuning. """
def __init__(self, lut_folder, stages, modes, lutName, upscale, interval, phase=None, **kwargs):
super(SPF_LUT, self).__init__()
self.interval = interval
self.upscale = upscale
self.modes = modes
self.stages = stages
L = 2 ** (8 - interval) + 1
for mode in modes:
# conv1
lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 1, mode))
# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(1, mode))
key = "s{}c0_{}".format(1, mode)
lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0
self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))
# conv2
lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 2, mode))
# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(2, mode))
key = "s{}c0_{}".format(2, mode)
lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0
self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))
# conv3
lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 3, mode))
# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(3, mode))
key = "s{}c0_{}".format(3, mode)
lut_arr = np.load(lut_path).reshape((-1, 2)).astype(np.float32) / 127.0
self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))
# conv4
lut_path = os.path.join(lut_folder, '{}_s{}c0_{}.npy'.format(lutName, 4, mode))
# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c0_{}.npy'.format(4, mode))
key = "s{}c0_{}".format(4, mode)
lut_arr = np.load(lut_path).reshape((-1, 1)).astype(np.float32) / 127.0
self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))
for c in range(4):
# conv6
lut_path = os.path.join(lut_folder, '{}_s{}c{}_{}.npy'.format(lutName, 6,c, mode))
# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}c{}_{}.npy'.format(6,c, mode))
key = "s{}c{}_{}".format(6,c, mode)
lut_arr = np.load(lut_path).reshape((-1, self.upscale * self.upscale)).astype(np.float32) / 127.0
self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))
# conv5
lut_path = os.path.join(lut_folder, '{}_s{}_channel.npy'.format(lutName, 5))
# lut_path = os.path.join(lut_folder, 'LUT_x4_4bit_int8_s{}_channel.npy'.format(5))
key = "s{}_channel".format(5)
lut_arr = np.load(lut_path).reshape((-1, 4)).astype(np.float32) / 127.0
self.register_parameter(name="weight_" + key, param=torch.nn.Parameter(torch.Tensor(lut_arr)))
As you can see, exactly as in MuLUT, each LUT is registered as a trainable parameter so that it can be fine-tuned directly.
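The pattern is easy to reproduce in isolation. A toy sketch (standalone, with a random LUT instead of the repository .npy files) of registering a LUT as a trainable parameter and taking one update step:
import numpy as np
import torch
import torch.nn as nn

# Toy illustration of the register_parameter pattern used by SPF_LUT:
# the LUT itself becomes a learnable tensor that the optimizer updates.
class TinyLUT(nn.Module):
    def __init__(self):
        super().__init__()
        lut = np.random.randint(-127, 128, (17 ** 4, 1)).astype(np.float32) / 127.0
        self.register_parameter("weight_s1c0_s", nn.Parameter(torch.from_numpy(lut)))

    def forward(self, idx):
        # In the real model the indices come from the 4D interpolation;
        # here we simply gather entries so that gradients reach the LUT.
        return self.weight_s1c0_s[idx]

m = TinyLUT()
optim = torch.optim.Adam(m.parameters(), lr=1e-4)
loss = m(torch.randint(0, 17 ** 4, (32,))).mean()
loss.backward()
optim.step()   # the LUT entries themselves have been updated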
4. Testing
Because the LUTs have been changed, split into diagonal and non-diagonal parts, the final table-lookup inference needs some adjustments. Taking the 2D diagonal compression as an example, see sr/4_test_SPF_LUT_DFC.py:
def InterpTorchBatch_compress_xy(weight, img_in, h, w, interval, rot, d, upscale=4, out_c=1, mode='s',ref2index=None):
q = 2 ** interval # 16
L = 2 ** (8 - interval) + 1 # 17
diag = 2 * d + 1
N = diag * L + (1 - diag ** 2) // 4
if mode == "s":
img_x = img_in[:, :, 0:0 + h, 0:0 + w]
img_y = img_in[:, :, 0:0 + h, 1:1 + w]
index_flag = (np.abs(img_x - img_y) <= d * q)
# Extract MSBs
img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
img_b1 = img_in[:, :, 0:0 + h, 1:1 + w] // q
img_c1 = img_in[:, :, 1:1 + h, 0:0 + w] // q
img_d1 = img_in[:, :, 1:1 + h, 1:1 + w] // q
# Extract LSBs
fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
fb = img_in[:, :, 0:0 + h, 1:1 + w] % q
fc = img_in[:, :, 1:1 + h, 0:0 + w] % q
fd = img_in[:, :, 1:1 + h, 1:1 + w] % q
elif mode == 'd':
img_x = img_in[:, :, 0:0 + h, 0:0 + w]
img_y = img_in[:, :, 0:0 + h, 2:2 + w]
index_flag = (np.abs(img_x - img_y) <= d * q)
img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
img_b1 = img_in[:, :, 0:0 + h, 2:2 + w] // q
img_c1 = img_in[:, :, 2:2 + h, 0:0 + w] // q
img_d1 = img_in[:, :, 2:2 + h, 2:2 + w] // q
fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
fb = img_in[:, :, 0:0 + h, 2:2 + w] % q
fc = img_in[:, :, 2:2 + h, 0:0 + w] % q
fd = img_in[:, :, 2:2 + h, 2:2 + w] % q
elif mode == 'y':
img_x = img_in[:, :, 0:0 + h, 0:0 + w]
img_y = img_in[:, :, 1:1 + h, 1:1 + w]
index_flag = (np.abs(img_x - img_y) <= d * q)
img_a1 = img_in[:, :, 0:0 + h, 0:0 + w] // q
img_b1 = img_in[:, :, 1:1 + h, 1:1 + w] // q
img_c1 = img_in[:, :, 1:1 + h, 2:2 + w] // q
img_d1 = img_in[:, :, 2:2 + h, 1:1 + w] // q
fa = img_in[:, :, 0:0 + h, 0:0 + w] % q
fb = img_in[:, :, 1:1 + h, 1:1 + w] % q
fc = img_in[:, :, 1:1 + h, 2:2 + w] % q
fd = img_in[:, :, 2:2 + h, 1:1 + w] % q
else:
# more sampling modes can be implemented similarly
raise ValueError("Mode {} not implemented.".format(mode))
img_a1 = img_a1[index_flag].flatten().astype(np.int_)
img_b1 = img_b1[index_flag].flatten().astype(np.int_)
img_c1 = img_c1[index_flag].flatten().astype(np.int_)
img_d1 = img_d1[index_flag].flatten().astype(np.int_)
fa = fa[index_flag].flatten()
fb = fb[index_flag].flatten()
fc = fc[index_flag].flatten()
fd = fd[index_flag].flatten()
img_a2 = img_a1 + 1
img_b2 = img_b1 + 1
img_c2 = img_c1 + 1
img_d2 = img_d1 + 1
k00 = ref2index[img_a1, img_b1 - img_a1]
k01 = ref2index[img_a1, img_b2 - img_a1]
k10 = ref2index[img_a2, img_b1 - img_a2]
k11 = ref2index[img_a2, img_b2 - img_a2]
p0000 = weight[k00,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
p0001 = weight[k00,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
p0010 = weight[k00,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
p0011 = weight[k00,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
p0100 = weight[k01,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
p0101 = weight[k01,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
p0110 = weight[k01,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
p0111 = weight[k01,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
p1000 = weight[k10,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
p1001 = weight[k10,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
p1010 = weight[k10,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
p1011 = weight[k10,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
p1100 = weight[k11,img_c1, img_d1].reshape((-1, out_c,upscale,upscale))
p1101 = weight[k11,img_c1, img_d2].reshape((-1, out_c,upscale,upscale))
p1110 = weight[k11,img_c2, img_d1].reshape((-1, out_c,upscale,upscale))
p1111 = weight[k11,img_c2, img_d2].reshape((-1, out_c,upscale,upscale))
# Output image holder
out = np.zeros((img_a1.shape[0],out_c, upscale, upscale))
sz = img_a1.shape[0]
out = out.reshape(sz, -1)
p0000 = p0000.reshape(sz, -1)
p0100 = p0100.reshape(sz, -1)
p1000 = p1000.reshape(sz, -1)
p1100 = p1100.reshape(sz, -1)
fa = fa.reshape(-1, 1)
p0001 = p0001.reshape(sz, -1)
p0101 = p0101.reshape(sz, -1)
p1001 = p1001.reshape(sz, -1)
p1101 = p1101.reshape(sz, -1)
fb = fb.reshape(-1, 1)
fc = fc.reshape(-1, 1)
p0010 = p0010.reshape(sz, -1)
p0110 = p0110.reshape(sz, -1)
p1010 = p1010.reshape(sz, -1)
p1110 = p1110.reshape(sz, -1)
fd = fd.reshape(-1, 1)
p0011 = p0011.reshape(sz, -1)
p0111 = p0111.reshape(sz, -1)
p1011 = p1011.reshape(sz, -1)
p1111 = p1111.reshape(sz, -1)
fab = fa > fb; fac = fa > fc; fad = fa > fd
fbc = fb > fc; fbd = fb > fd; fcd = fc > fd
i1 = i = np.logical_and.reduce((fab, fbc, fcd)).squeeze(1)
# print(p0000[i].shape,fa[i].shape,i.shape,out_c)
out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[
i] + (fd[i]) * p1111[i]
i2 = i = np.logical_and.reduce((~i1[:, None], fab, fbc, fbd)).squeeze(1)
out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fb[i]) * p1000[i] + (fb[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[
i] + (fc[i]) * p1111[i]
i3 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], fab, fbc, fad)).squeeze(1)
out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[
i] + (fc[i]) * p1111[i]
i4 = i = np.logical_and.reduce((~i1[:, None], ~i2[:, None], ~i3[:, None], fab, fbc)).squeeze(1)
out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fb[i]) * p1001[i] + (fb[i] - fc[i]) * p1101[
i] + (fc[i]) * p1111[i]
i5 = i = np.logical_and.reduce((~(fbc), fab, fac, fbd)).squeeze(1)
out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[
i] + (fd[i]) * p1111[i]
i6 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], fab, fac, fcd)).squeeze(1)
out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fc[i]) * p1000[i] + (fc[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[
i] + (fb[i]) * p1111[i]
i7 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], fab, fac, fad)).squeeze(1)
out[i] = (q - fa[i]) * p0000[i] + (fa[i] - fd[i]) * p1000[i] + (fd[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[
i] + (fb[i]) * p1111[i]
i8 = i = np.logical_and.reduce((~(fbc), ~i5[:, None], ~i6[:, None], ~i7[:, None], fab, fac)).squeeze(1)
out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fa[i]) * p0001[i] + (fa[i] - fc[i]) * p1001[i] + (fc[i] - fb[i]) * p1011[
i] + (fb[i]) * p1111[i]
i9 = i = np.logical_and.reduce((~(fbc), ~(fac), fab, fbd)).squeeze(1)
out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fb[i]) * p1010[i] + (fb[i] - fd[i]) * p1110[
i] + (fd[i]) * p1111[i]
# Fix the overflow bug in SR-LUT's implementation, should compare fd with fa first!
# i10 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:,None], fab, fcd)).squeeze(1)
# out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fa[i]) * p0010[i] + (fa[i]-fd[i]) * p1010[i] + (fd[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
# i11 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:,None], ~i10[:,None], fab, fad)).squeeze(1)
# out[i] = (q-fc[i]) * p0000[i] + (fc[i]-fd[i]) * p0010[i] + (fd[i]-fa[i]) * p0011[i] + (fa[i]-fb[i]) * p1011[i] + (fb[i]) * p1111[i]
i10 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], fab, fad)).squeeze(1) # c > a > d > b
out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fa[i]) * p0010[i] + (fa[i] - fd[i]) * p1010[i] + (fd[i] - fb[i]) * p1011[
i] + (fb[i]) * p1111[i]
i11 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], fab, fcd)).squeeze(1) # c > d > a > b
out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[
i] + (fb[i]) * p1111[i]
i12 = i = np.logical_and.reduce((~(fbc), ~(fac), ~i9[:, None], ~i10[:, None], ~i11[:, None], fab)).squeeze(1)
out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fa[i]) * p0011[i] + (fa[i] - fb[i]) * p1011[
i] + (fb[i]) * p1111[i]
i13 = i = np.logical_and.reduce((~(fab), fac, fcd)).squeeze(1)
out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fc[i]) * p1100[i] + (fc[i] - fd[i]) * p1110[
i] + (fd[i]) * p1111[i]
i14 = i = np.logical_and.reduce((~(fab), ~i13[:, None], fac, fad)).squeeze(1)
out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fa[i]) * p0100[i] + (fa[i] - fd[i]) * p1100[i] + (fd[i] - fc[i]) * p1101[
i] + (fc[i]) * p1111[i]
i15 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], fac, fbd)).squeeze(1)
out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[
i] + (fc[i]) * p1111[i]
i16 = i = np.logical_and.reduce((~(fab), ~i13[:, None], ~i14[:, None], ~i15[:, None], fac)).squeeze(1)
out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fa[i]) * p0101[i] + (fa[i] - fc[i]) * p1101[
i] + (fc[i]) * p1111[i]
i17 = i = np.logical_and.reduce((~(fab), ~(fac), fbc, fad)).squeeze(1)
out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[
i] + (fd[i]) * p1111[i]
i18 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], fbc, fcd)).squeeze(1)
out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fc[i]) * p0100[i] + (fc[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[
i] + (fa[i]) * p1111[i]
i19 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], fbc, fbd)).squeeze(1)
out[i] = (q - fb[i]) * p0000[i] + (fb[i] - fd[i]) * p0100[i] + (fd[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[
i] + (fa[i]) * p1111[i]
i20 = i = np.logical_and.reduce((~(fab), ~(fac), ~i17[:, None], ~i18[:, None], ~i19[:, None], fbc)).squeeze(1)
out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fb[i]) * p0001[i] + (fb[i] - fc[i]) * p0101[i] + (fc[i] - fa[i]) * p0111[
i] + (fa[i]) * p1111[i]
i21 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), fad)).squeeze(1)
out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fa[i]) * p0110[i] + (fa[i] - fd[i]) * p1110[
i] + (fd[i]) * p1111[i]
i22 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], fbd)).squeeze(1)
out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fb[i]) * p0010[i] + (fb[i] - fd[i]) * p0110[i] + (fd[i] - fa[i]) * p0111[
i] + (fa[i]) * p1111[i]
i23 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], fcd)).squeeze(1)
out[i] = (q - fc[i]) * p0000[i] + (fc[i] - fd[i]) * p0010[i] + (fd[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[
i] + (fa[i]) * p1111[i]
i24 = i = np.logical_and.reduce((~(fab), ~(fac), ~(fbc), ~i21[:, None], ~i22[:, None], ~i23[:, None])).squeeze(1)
out[i] = (q - fd[i]) * p0000[i] + (fd[i] - fc[i]) * p0001[i] + (fc[i] - fb[i]) * p0011[i] + (fb[i] - fa[i]) * p0111[
i] + (fa[i]) * p1111[i]
out = out / q
return out,index_flag
Notice that before the lookup an index_flag is computed. index_flag marks whether the diagonal condition is satisfied: positions that satisfy it are looked up in the diagonal LUT, while the remaining positions are looked up in the non-diagonal LUT. The detailed logic is worth tracing through on your own, although in practice you would rarely run this pure-Python implementation in production.
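To make the routing concrete, here is a tiny conceptual sketch (illustrative names and values, not the repository code) of how index_flag merges the two lookup results:
import numpy as np

# index_flag marks positions satisfying the diagonal condition |x - y| <= d*q.
index_flag = np.array([True, False, True, False])
pred_diag = np.array([10.0, 30.0])                 # fine-LUT result, flagged positions only
pred_coarse = np.array([11.0, 20.0, 31.0, 40.0])   # coarse-LUT result on all positions

out = pred_coarse.copy()
out[index_flag] = pred_diag    # overwrite flagged positions with the fine-LUT values
print(out)                     # [10. 20. 30. 40.]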
That concludes this walkthrough of the SPFLUT code implementation. If anything is unclear, feel free to raise questions.