【LUT技术专题】SPLUT代码解读

目录

原文概要

1. 训练

2. 转表

3. 测试


本文是对SPLUT技术的代码解读,原文解读请看SPLUT

原文概要

SPLUT通过增大网络感受野,来提升超分效果,实现SRLUT的改进,主要两个创新点:

  • 提出了一个无需旋转集成、多LUT串联和并联的方法来提升感受野。
  • 提出了一个并行网络做补偿,来改善LUT的量化误差问题。

其网络结构图如下:

Query Blocks:

代码结构如下:


我们这里以后缀为M的SPLUT为例进行讲解。


1. 训练

代码位于Train_SPLUT_M.py文件中,某一个分支结构实现如下,MSB和LSB都会经过下面的结构,处理后相加就得到最终结果,跟前面讲到的结构是对应的:

class SRNet(torch.nn.Module):
    def __init__(self, upscale=4):
        super(SRNet, self).__init__()

        self.upscale = upscale
        self.lvls = 16
        self.quant= BASEQ(self.lvls ,8.0)
        self.out_channal= OUT_NUM

        #cwh
        self.lut122=nn.Sequential(
            nn.Conv2d(1, 64, [2,2], stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 8, 1, stride=1, padding=0, dilation=1)
        )

        self.lut221=nn.Sequential(
            nn.Conv2d(2,  64, [2,1], stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 8, 1, stride=1, padding=0, dilation=1)
        )

        self.lut212=nn.Sequential(
            nn.Conv2d(2,  64, [1,2], stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 8, 1, stride=1, padding=0, dilation=1)
        )

        self.lut221_c12=nn.Sequential(
            nn.Conv2d(2,  64, [2,1], stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 16, 1, stride=1, padding=0, dilation=1)
        )

        self.lut212_c34=nn.Sequential(
            nn.Conv2d(2,  64, [1,2], stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 64, 1, stride=1, padding=0, dilation=1),
            nn.GELU(),
            nn.Conv2d(64, 16, 1, stride=1, padding=0, dilation=1)
        )

        self.pixel_shuffle = nn.PixelShuffle(upscale)

        # Init weights
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.lower().find('conv') != -1:
                nn.init.kaiming_normal(m.weight)
                nn.init.constant(m.bias, 0)
            elif classname.find('bn') != -1:
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)

    def forward(self, x_in):
        B, C, H, W = x_in.size()
        x_in = x_in.reshape(B*C, 1, H, W)
        x = self.lut122(x_in)+x_in[:,:,:H-1,:W-1]
        x_temp=x_in
        x221 = self.lut221(self.quant(F.pad(x[:,0:2,:,:],(0,0,0,1), mode='reflect')+F.pad(x[:,2:4,:,:],(0,0,1,0), mode='reflect')))
        x212 = self.lut212(self.quant(F.pad(x[:,4:6,:,:],(0,1,0,0), mode='reflect')+F.pad(x[:,6: ,:,:],(1,0,0,0), mode='reflect')))
        x=(x221+x212)/2.0+x

        x2 = self.lut221_c12(self.quant(F.pad(x[:,0:2,:,:],(0,0,0,1), mode='reflect')+F.pad(x[:,2:4,:,:],(0,0,1,0), mode='reflect')))
        x3 = self.lut212_c34(self.quant(F.pad(x[:,4:6,:,:],(0,1,0,0), mode='reflect')+F.pad(x[:,6: ,:,:],(1,0,0,0), mode='reflect')))
        x=(x2+x3)/2.0+x_temp[:,:,:H-1,:W-1]
        x = self.pixel_shuffle(x)
        x = x.reshape(B, C, self.upscale*(H-1), self.upscale*(W-1))
        return x

可以看到self.lut122代表了空间的lut,即Spatial Lookup

接下来计算x221,先忽略quant算子,可以看到F.pad(x[:,0:2,:,:],(0,0,0,1), mode='reflect')+F.pad(x[:,2:4,:,:],(0,0,1,0), mode='reflect'),这个过程代表了Query Blocks里的horizontal aggregate模块。再进行LUT221,即LUT WC的操作,这样子完成Query Blocks的一个支路。x212即是对应的vertical部分,这两部分加起来之后除以2即可。

接下来是重复的计算x2和x3,最后经过一个pixel_shuffle完成超分即可。

然后是量化节点的实现,代码如下:

class _baseq(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, steps):
        y_step_ind=torch.floor(x / steps)
        y = y_step_ind * steps
        return y

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

class BASEQ(nn.Module):
    def __init__(self, lvls,activation_range):
        super(BASEQ, self).__init__()
        self.lvls = lvls
        self.activation_range = activation_range
        self.steps = 2 * activation_range / self.lvls

    def forward(self, x):
        x=(((-x - self.activation_range).abs() - (x - self.activation_range).abs()))/2.0
        x[x > self.activation_range-0.1*self.steps] =self.activation_range-0.1*self.steps
        return _baseq.apply(x, self.steps)

量化节点是这个BASEQ类,作者给的lvls和activation_range的默认值分别是16和8,因此steps就是1,前向第一步是等价一个torch.clip操作,大于activation range是activation range,小于-activation range是-activation range,在-activation range和activation range之间是x(这里可以将不同情况下的输入带入代码中计算,可以得到这个结论),第二步博主这里认为是由于带符号数正数和负数是不对称的,例如一个4bit的输入,它的值域范围是[-8, 7],因此这里对正数做了进一步的clip,来减小这个误差。最后应用_baseq的量化算子,因为steps是1,这里Funcition的forward完成的是一个向下取整的floor操作,backward完成的就是STE,直接将梯度传递下去,保证整个网络的可导性。


2. 转表

在Transfer_SPLUT_M.py中,这里跟SRLUT基本一致,根据LUT类型,先构造4bit的输入,如下:

L = 2 ** img_bits
 base_step_ind=torch.arange(0, L, 1)
 base=base_step_ind/16.0
 index_4D=torch.meshgrid(base,base,base,base)
 onebyfourth=torch.cat([index_4D[0].flatten().unsqueeze(1),index_4D[1].flatten().unsqueeze(1),index_4D[2].flatten().unsqueeze(1),index_4D[3].flatten().unsqueeze(1)],1)
 
 base_B=base_step_ind/16.0
 index_4D_B=torch.meshgrid(base_B,base_B,base_B,base_B)
 onebyfourth_B=torch.cat([index_4D_B[0].flatten().unsqueeze(1),index_4D_B[1].flatten().unsqueeze(1),index_4D_B[2].flatten().unsqueeze(1),index_4D_B[3].flatten().unsqueeze(1)],1)

if LUT_NUM==1:
    input_tensor   = onebyfourth.unsqueeze(1).unsqueeze(1).reshape(-1,1,2,2)
    input_tensor_B = onebyfourth_B.unsqueeze(1).unsqueeze(1).reshape(-1,1,2,2)
elif KERNAL==221:
    input_tensor = onebyfourth.unsqueeze(1).unsqueeze(1).reshape(-1,2,2,1)
    input_tensor_B = onebyfourth.unsqueeze(1).unsqueeze(1).reshape(-1,2,2,1)
elif KERNAL==212:
    input_tensor = onebyfourth.unsqueeze(1).unsqueeze(1).reshape(-1,2,1,2)
    input_tensor_B = onebyfourth.unsqueeze(1).unsqueeze(1).reshape(-1,2,1,2)


再送入之前训练好的层中输出结果并将其保存即可。


3. 测试

在Inference_SPLUT_M.py中,这里只是将网络的部分用LUT表去实现了,而且前面也讲到了,SPLUT是4bit的,不用量化,所以它的测试中不存在插值,所以与网络推理的代码就非常相像了。如下:

 img_in = np.pad(img_lr, ((0,1), (0,1), (0,0)), mode='reflect').transpose((2,0,1))
    img_in_A255 = img_in// L
    img_in_B255 = img_in % L
    img_in_A = img_in_A255/L
    img_in_B = img_in_B255/L

    # A
    x_layer1=LUT1_122(LUTA1_122,img_in_A255)+img_in_A[:,0:h,0:w].reshape((3,1,h,w))
    x2_in1=np.pad(x_layer1[:,0:2,:,:],((0,0),(0,0),(0,1),(0,0)), mode='reflect')+np.pad(x_layer1[:,2:4,:,:],((0,0),(0,0),(1,0),(0,0)), mode='reflect')
    x2_in2=np.pad(x_layer1[:,4:6,:,:],((0,0),(0,0),(0,0),(0,1)), mode='reflect')+np.pad(x_layer1[:,6: ,:,:],((0,0),(0,0),(0,0),(1,0)), mode='reflect')
    x_layer2=(LUT23(LUTA2_221, x2_in1,8.0 ,'k221',2)+LUT23(LUTA2_212, x2_in2,8.0 ,'k212',2))/2.0+x_layer1
    x3_in1=np.pad(x_layer2[:,0:2,:,:],((0,0),(0,0),(0,1),(0,0)), mode='reflect')+np.pad(x_layer2[:,2:4,:,:],((0,0),(0,0),(1,0),(0,0)), mode='reflect')
    x3_in2=np.pad(x_layer2[:,4:6,:,:],((0,0),(0,0),(0,0),(0,1)), mode='reflect')+np.pad(x_layer2[:,6: ,:,:],((0,0),(0,0),(0,0),(1,0)), mode='reflect')
    img_out=(LUT23(LUTA3_221, x3_in1,8.0 ,'k221',3)+LUT23(LUTA3_212, x3_in2,8.0 ,'k212',3))/2.0+img_in_A[:,0:h,0:w].reshape((3,1,h,w))
    img_out=img_out.reshape((3,UPSCALE,UPSCALE,h,w))
    img_out=np.transpose(img_out,(0,3,1,4,2)).reshape((3,UPSCALE*h,UPSCALE*w))    
    img_out_A = img_out.transpose((1,2,0))
    # B
    x_layer1=LUT1_122(LUTB1_122,img_in_B255)+img_in_B[:,0:h,0:w].reshape((3,1,h,w))
    x2_in1=np.pad(x_layer1[:,0:2,:,:],((0,0),(0,0),(0,1),(0,0)), mode='reflect')+np.pad(x_layer1[:,2:4,:,:],((0,0),(0,0),(1,0),(0,0)), mode='reflect')
    x2_in2=np.pad(x_layer1[:,4:6,:,:],((0,0),(0,0),(0,0),(0,1)), mode='reflect')+np.pad(x_layer1[:,6: ,:,:],((0,0),(0,0),(0,0),(1,0)), mode='reflect')
    x_layer2=(LUT23(LUTB2_221, x2_in1,8.0 ,'k221',2)+LUT23(LUTB2_212, x2_in2,8.0 ,'k212',2))/2.0+x_layer1
    x3_in1=np.pad(x_layer2[:,0:2,:,:],((0,0),(0,0),(0,1),(0,0)), mode='reflect')+np.pad(x_layer2[:,2:4,:,:],((0,0),(0,0),(1,0),(0,0)), mode='reflect')
    x3_in2=np.pad(x_layer2[:,4:6,:,:],((0,0),(0,0),(0,0),(0,1)), mode='reflect')+np.pad(x_layer2[:,6: ,:,:],((0,0),(0,0),(0,0),(1,0)), mode='reflect')
    img_out=(LUT23(LUTB3_221, x3_in1,8.0 ,'k221',3)+LUT23(LUTB3_212, x3_in2,8.0 ,'k212',3))/2.0+img_in_B[:,0:h,0:w].reshape((3,1,h,w))
    img_out=img_out.reshape((3,UPSCALE,UPSCALE,h,w))
    img_out=np.transpose(img_out,(0,3,1,4,2)).reshape((3,UPSCALE*h,UPSCALE*w))    
    img_out_B = img_out.transpose((1,2,0))

    img_out=img_out_A+img_out_B
    img_out = np.round(np.clip(img_out, 0, 1) * 255).astype(np.uint8)

可见,开始先获取MSB和LSB,然后送入LUT表中推理,过程跟我们前面的网络的推理过程是非常相像的,数据只不过是numpy.array和torch.Tensor的区别。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值