This post is a code walkthrough of the CLUT technique; for an interpretation of the paper itself, see the companion CLUT paper-explanation post.
1. Paper Overview
CLUT uses matrix decomposition to drastically reduce the parameter count while preserving the mapping capacity of a 3D LUT. The overall pipeline is shown in the figure below.
The framework is still the 3D-LUT one; CLUT simply adds an adaptively compressed transformation matrix on top of it. For training, the authors add a cosine-similarity loss term on top of the standard 3DLUT losses, as sketched below.
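For reference, here is a minimal sketch of such a cosine-similarity term. The exact operands the authors use are defined in the paper; comparing the RGB vector of each output pixel against the ground truth, as done here, is purely an illustrative assumption.

```python
# Hedged sketch of a cosine-similarity color loss. That it compares the
# enhanced output against the ground truth per pixel is an assumption for
# illustration; see the paper / official repo for the exact definition.
import torch
import torch.nn.functional as F

def cosine_color_loss(fake, gt):
    # fake, gt: (n, 3, h, w); cosine similarity of the RGB vector at each pixel
    cos = F.cosine_similarity(fake, gt, dim=1, eps=1e-8)  # (n, h, w)
    return (1.0 - cos).mean()  # 0 when all color directions match exactly
```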
2. Code Structure
The overall structure of the code is as follows:
The core code lives in the models.py and LUT.py files.
3. Core Code Modules
The models.py file
1. The CLUTNet class
This is the overall network implementation: it defines the backbone, the classifier, and the CLUT module.
```python
class CLUTNet(nn.Module):
    def __init__(self, nsw, dim=33, backbone='Backbone', *args, **kwargs):
        super().__init__()
        self.TrilinearInterpolation = TrilinearInterpolation()
        # ImageNet normalization applied before the backbone
        self.pre = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
        self.backbone = eval(backbone)()
        last_channel = self.backbone.last_channel
        # 1x1-conv head that predicts one weight per basis LUT;
        # nsw is a string like "20+05+20", so nsw[:2] is the
        # (two-digit) number of basis LUTs
        self.classifier = nn.Sequential(
            nn.Conv2d(last_channel, 128, 1, 1),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.2, inplace=True),
            nn.Conv2d(128, int(nsw[:2]), 1, 1),
        )
        nsw = nsw.split("+")
        num, s, w = int(nsw[0]), int(nsw[1]), int(nsw[2])
        self.CLUTs = CLUT(num, dim, s, w)

    def fuse_basis_to_one(self, img, TVMN=None):
        # backbone features -> per-LUT weights -> one fused 3D LUT
        mid_results = self.backbone(self.pre(img))
        weights = self.classifier(mid_results)[:, :, 0, 0]  # n, num
        D3LUT, tvmn_loss = self.CLUTs(weights, TVMN)
        return D3LUT, tvmn_loss

    def forward(self, img, img_org, TVMN=None):
        D3LUT, tvmn_loss = self.fuse_basis_to_one(img, TVMN)
        # apply the fused LUT to the full-resolution original image
        img_res = self.TrilinearInterpolation(D3LUT, img_org)
        return {
            "fakes": img_res + img_org,  # residual output
            "3DLUT": D3LUT,
            "tvmn_loss": tvmn_loss,
        }
```
The forward method lays out the computation: the image first passes through the backbone to produce intermediate features, the classifier then turns them into per-basis-LUT weights, and finally the CLUT module fuses the basis LUTs into the 3DLUT that is actually applied, via trilinear interpolation on the full-resolution image (note the residual output img_res + img_org). A minimal usage sketch follows.
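The sketch below checks shapes only. It assumes Backbone and TrilinearInterpolation are importable from the repo, that the backbone accepts a 256×256 input, and uses "20+05+20" (num=20, s=5, w=20) as a hypothetical nsw string.

```python
# Hypothetical usage; hyperparameter values are illustrative.
net = CLUTNet("20+05+20", dim=33)
img = torch.rand(1, 3, 256, 256)       # resized input fed to the backbone
img_org = torch.rand(1, 3, 720, 1080)  # full-resolution original
out = net(img, img_org)                # TVMN omitted -> tvmn_loss == 0
print(out["fakes"].shape)              # torch.Size([1, 3, 720, 1080])
print(out["3DLUT"].shape)              # torch.Size([1, 3, 33, 33, 33])
```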
2. The CLUT class
This defines the CLUT computation. As mentioned in the paper walkthrough, there are three main hyperparameters: num is the number of basis LUTs, while s and w are the compression parameters.
```python
class CLUT(nn.Module):
    def __init__(self, num, dim=33, s="-1", w="-1", *args, **kwargs):
        super(CLUT, self).__init__()
        self.num = num
        self.dim = dim
        self.s, self.w = s, w = eval(str(s)), eval(str(w))
        # +: compressed; -: uncompressed
        if s == -1 and w == -1:  # standard 3DLUT
            self.mode = '--'
            self.LUTs = nn.Parameter(torch.zeros(num, 3, dim, dim, dim))
        elif s != -1 and w == -1:  # compressed along s only
            self.mode = '+-'
            self.s_Layers = nn.Parameter(torch.rand(dim, s) / 5 - 0.1)
            self.LUTs = nn.Parameter(torch.zeros(s, num * 3 * dim * dim))
        elif s == -1 and w != -1:  # compressed along w only
            self.mode = '-+'
            self.w_Layers = nn.Parameter(torch.rand(w, dim * dim) / 5 - 0.1)
            self.LUTs = nn.Parameter(torch.zeros(num * 3 * dim, w))
        else:  # full-version CLUT
            self.mode = '++'
            self.s_Layers = nn.Parameter(torch.rand(dim, s) / 5 - 0.1)
            self.w_Layers = nn.Parameter(torch.rand(w, dim * dim) / 5 - 0.1)
            self.LUTs = nn.Parameter(torch.zeros(s * num * 3, w))
        print("n=%d s=%d w=%d" % (num, s, w), self.mode)

    def reconstruct_luts(self):
        dim = self.dim
        num = self.num
        if self.mode == "--":
            D3LUTs = self.LUTs
        else:
            if self.mode == "+-":
                # d,s x s,num*3dd -> d,num*3dd -> d,num*3,dd -> num,3,d,dd -> num,-1
                CUBEs = self.s_Layers.mm(self.LUTs).reshape(dim, num * 3, dim * dim).permute(1, 0, 2).reshape(num, 3, self.dim, self.dim, self.dim)
            if self.mode == "-+":
                # num*3d,w x w,dd -> num*3d,dd -> num,3ddd
                CUBEs = self.LUTs.mm(self.w_Layers).reshape(num, 3, self.dim, self.dim, self.dim)
            if self.mode == "++":
                # s*num*3,w x w,dd -> s*num*3,dd -> s,num*3*dd -> d,num*3*dd -> num,-1
                CUBEs = self.s_Layers.mm(self.LUTs.mm(self.w_Layers).reshape(-1, num * 3 * dim * dim)).reshape(dim, num * 3, dim ** 2).permute(1, 0, 2).reshape(num, 3, self.dim, self.dim, self.dim)
            D3LUTs = cube_to_lut(CUBEs)
        return D3LUTs

    def combine(self, weights, TVMN):  # weights: n,num
        dim = self.dim
        num = self.num
        D3LUTs = self.reconstruct_luts()
        if TVMN is None:
            tvmn_loss = 0
        else:
            tvmn_loss = TVMN(D3LUTs)
        # weighted sum of the basis LUTs -> one fused LUT per image
        D3LUT = weights.mm(D3LUTs.reshape(num, -1)).reshape(-1, 3, dim, dim, dim)
        return D3LUT, tvmn_loss

    def forward(self, weights, TVMN=None):
        lut, tvmn_loss = self.combine(weights, TVMN)
        return lut, tvmn_loss
```
mode selects the compression variant; the one we actually want is the fully compressed version, mode == "++". As the code shows, self.LUTs is first matrix-multiplied with w_Layers, and the result is then matrix-multiplied with s_Layers, which matches the derivation in the paper walkthrough.
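It is worth working out how much the "++" mode actually saves. The arithmetic below follows directly from the parameter shapes in __init__; the concrete values num=20, s=5, w=20 are illustrative (use whatever your config specifies), while dim=33 matches the default above.

```python
# Parameter count: a bank of standard 3DLUTs vs. the '++' CLUT factors.
num, s, w, dim = 20, 5, 20, 33
standard = num * 3 * dim ** 3                     # 2,156,220 parameters
clut = s * num * 3 * w + dim * s + w * dim * dim  # LUTs + s_Layers + w_Layers
print(standard, clut, round(standard / clut))     # 2156220 27945 77
```

At these settings the compressed representation holds roughly 28K parameters instead of 2.16M, about a 77x reduction.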
The utils/LUT.py file
1. The cube_to_lut function
This function is called at the end of the CLUT class's LUT reconstruction (see reconstruct_luts above), after the compressed representation has been expanded back into cubes.
```python
def cube_to_lut(cube):  # (n,)3,d,d,d
    # Axis permutations applied to channels 0 and 1 respectively;
    # channel 2 is copied through unchanged.
    if len(cube.shape) == 5:
        to_shape = [
            [0, 2, 3, 1],
            [0, 2, 1, 3],
        ]
    else:
        to_shape = [
            [1, 2, 0],
            [1, 0, 2],
        ]
    if isinstance(cube, torch.Tensor):
        lut = torch.empty_like(cube)
        lut[..., 0, :, :, :] = cube[..., 0, :, :, :].permute(*to_shape[0])
        lut[..., 1, :, :, :] = cube[..., 1, :, :, :].permute(*to_shape[1])
        lut[..., 2, :, :, :] = cube[..., 2, :, :, :]
    else:  # numpy array
        lut = np.empty_like(cube)
        lut[..., 0, :, :, :] = cube[..., 0, :, :, :].transpose(*to_shape[0])
        lut[..., 1, :, :, :] = cube[..., 1, :, :, :].transpose(*to_shape[1])
        lut[..., 2, :, :, :] = cube[..., 2, :, :, :]
    return lut
```
From the CLUT class we can see that the input to this function has shape (num, 3, dim, dim, dim). Since the shape has length 5, to_shape is [0,2,3,1] and [0,2,1,3]; in other words, the actual lut is the cube tensor with the sampling-grid axes reordered inside channels 0 and 1, while channel 2 is copied as-is.
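A quick self-contained check of this behavior (num=20 and dim=33 are just example values, and cube_to_lut is assumed importable from utils/LUT.py):

```python
import torch
cube = torch.rand(20, 3, 33, 33, 33)  # (num, 3, d, d, d)
lut = cube_to_lut(cube)
print(lut.shape)                       # torch.Size([20, 3, 33, 33, 33])
# channel 2 is unchanged; channels 0 and 1 have their grid axes permuted
assert torch.equal(lut[:, 2], cube[:, 2])
```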
4. Summary
That covers the core parts of the implementation. The biggest departure from previous 3DLUT work is the CLUT compression matrices; once that part is clear, the rest falls into place.
Thanks for reading; feel free to leave a comment or send a message to discuss and exchange ideas.
If this post helped you, a follow would be much appreciated. Thanks!