This post is a code-level walkthrough of the HVI technique; for an interpretation of the paper itself, see the HVI post.
Paper overview
The paper proposes the HVI color space and the CIDNet model, bringing a new breakthrough to low-light image enhancement.
- The HVI color space uses polarized HS maps and a learnable intensity compression function to remedy the shortcomings of traditional color spaces in low-light image enhancement.
- CIDNet's dual-branch structure and cross-attention mechanism further improve enhancement quality.
- Experiments show that the HVI color space and the CIDNet model achieve excellent performance on multiple datasets, demonstrating their effectiveness and practicality for low-light image enhancement.
1. Color space conversion formulas
$$\mathbf{H} = \cos\left(\frac{\pi \mathbf{h}}{3}\right), \quad \mathbf{V} = \sin\left(\frac{\pi \mathbf{h}}{3}\right)$$
$$\mathbf{C}_k(x) = \sqrt[k]{\sin\left(\frac{\pi \mathbf{I}_{max}(x)}{2}\right) + \varepsilon}$$
$$\hat{\mathbf{H}} = \mathbf{C}_k \odot \mathbf{S} \odot \mathbf{H}, \qquad \hat{\mathbf{V}} = \mathbf{C}_k \odot \mathbf{S} \odot \mathbf{V}$$
$$\mathbf{I}_{max}(x) = \max_{c \in \{R,G,B\}} \left(\mathbf{I}_c(x)\right)$$
The RGB-to-HVI conversion formulas are given above. Here $\mathbf{h}$ is the hue $h$ of the HSV model; $\mathbf{I}_{max}(x)$ is likewise the value of the HSV model, which is also the I channel of HVI; $\mathbf{S}$ is the HSV saturation $s$. The exponent $k \in \mathbb{Q}$ is a trainable parameter, and a small $\varepsilon = 1 \times 10^{-8}$ is used to avoid exploding gradients. In other words, converting to HVI first requires converting to HSV and then applying the transforms above; the details are in the code below.
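As a quick worked example: a fully saturated red pixel $(R,G,B)=(1,0,0)$ has $h=0$, $s=1$, $\mathbf{I}_{max}=1$, so $\mathbf{C}_k=\sqrt[k]{\sin(\pi/2)+\varepsilon}\approx 1$ and $(\hat{\mathbf{H}},\hat{\mathbf{V}},\mathbf{I})\approx(1,0,1)$. A dark red pixel $(0.25,0,0)$ has the same $h$ and $s$ but $\mathbf{I}_{max}=0.25$, and with the code's initial density_k = 0.2 (applied as an exponent, i.e. the reciprocal of the paper's root index $k$), $\mathbf{C}_k=(\sin(\pi/8)+\varepsilon)^{0.2}\approx 0.3827^{0.2}\approx 0.83$, so $\hat{\mathbf{H}}$ shrinks from $1$ to about $0.83$: the chroma plane collapses toward the black point as intensity drops.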
The HVI-to-RGB conversion formulas are as follows. First HVI is converted back to HSV (here $\hat{h}$ and $\hat{v}$ denote the H and V channels after dividing out the factor $\mathbf{C}_k$, which is what the code does):
$$\begin{aligned} \mathbf{H} &= \arctan\left(\frac{\hat{v}}{\hat{h}}\right) \bmod 1 \\ \mathbf{S} &= \alpha_S \sqrt{\hat{h}^2 + \hat{v}^2} \\ \mathbf{V} &= \alpha_I \hat{\mathbf{I}} \end{aligned}$$
Then HSV is converted to RGB:
$$\begin{aligned} h_i &\equiv \left\lfloor \frac{h}{60} \right\rfloor \pmod{6} \\ f &= \frac{h}{60} - h_i \\ p &= v \times (1 - s) \\ q &= v \times (1 - f \times s) \\ t &= v \times (1 - (1 - f) \times s) \end{aligned}$$
$$(r,g,b) = \begin{cases} (v,t,p), & \text{if } h_i = 0 \\ (q,v,p), & \text{if } h_i = 1 \\ (p,v,t), & \text{if } h_i = 2 \\ (p,q,v), & \text{if } h_i = 3 \\ (t,p,v), & \text{if } h_i = 4 \\ (v,p,q), & \text{if } h_i = 5 \end{cases}$$
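As a quick check of the case table: $h = 120^{\circ}$, $s = v = 1$ (pure green) gives $h_i = \lfloor 120/60 \rfloor \bmod 6 = 2$ and $f = 0$, hence $p = 0$, $q = 1$, $t = 0$, and the $h_i = 2$ row yields $(r,g,b) = (p,v,t) = (0,1,0)$, as expected.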
```python
import torch
import torch.nn as nn

pi = 3.141592653589793

class RGB_HVI(nn.Module):
    def __init__(self):
        super(RGB_HVI, self).__init__()
        self.density_k = torch.nn.Parameter(torch.full([1], 0.2))  # k is reciprocal to the paper mentioned
        self.gated = False
        self.gated2 = False
        self.alpha = 1.0
        self.alpha_s = 1.3
        self.this_k = 0

    def HVIT(self, img):
        # RGB -> HVI; expects img of shape (B, 3, H, W) with values in [0, 1]
        eps = 1e-8
        device = img.device
        dtypes = img.dtype
        hue = torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(device).to(dtypes)
        value = img.max(1)[0].to(dtypes)      # V of HSV = per-pixel max over R, G, B
        img_min = img.min(1)[0].to(dtypes)
        # standard sextant-based hue, depending on which channel attains the max
        hue[img[:,2]==value] = 4.0 + ((img[:,0]-img[:,1]) / (value - img_min + eps))[img[:,2]==value]
        hue[img[:,1]==value] = 2.0 + ((img[:,2]-img[:,0]) / (value - img_min + eps))[img[:,1]==value]
        hue[img[:,0]==value] = (0.0 + ((img[:,1]-img[:,2]) / (value - img_min + eps))[img[:,0]==value]) % 6
        hue[img.min(1)[0]==value] = 0.0       # gray pixels (max == min) get hue 0
        hue = hue / 6.0                        # normalize hue to [0, 1)

        saturation = (value - img_min) / (value + eps)
        saturation[value==0] = 0

        hue = hue.unsqueeze(1)
        saturation = saturation.unsqueeze(1)
        value = value.unsqueeze(1)

        k = self.density_k
        self.this_k = k.item()                 # cache k for the inverse transform

        # learnable intensity-collapse factor C_k; note pow(k), so the code's k
        # is the reciprocal of the paper's root index
        color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
        ch = (2.0 * pi * hue).cos()
        cv = (2.0 * pi * hue).sin()
        H = color_sensitive * saturation * ch
        V = color_sensitive * saturation * cv
        I = value
        xyz = torch.cat([H, V, I], dim=1)
        return xyz

    def PHVIT(self, img):
        # HVI -> RGB; expects img of shape (B, 3, H, W) with channels (H, V, I)
        eps = 1e-8
        H, V, I = img[:,0,:,:], img[:,1,:,:], img[:,2,:,:]

        # clip
        H = torch.clamp(H, -1, 1)
        V = torch.clamp(V, -1, 1)
        I = torch.clamp(I, 0, 1)

        v = I
        k = self.this_k
        color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
        H = H / (color_sensitive + eps)       # divide out C_k before recovering h, s
        V = V / (color_sensitive + eps)
        H = torch.clamp(H, -1, 1)
        V = torch.clamp(V, -1, 1)
        h = torch.atan2(V + eps, H + eps) / (2 * pi)
        h = h % 1
        s = torch.sqrt(H ** 2 + V ** 2 + eps)

        if self.gated:
            s = s * self.alpha_s

        s = torch.clamp(s, 0, 1)
        v = torch.clamp(v, 0, 1)

        # standard HSV -> RGB via the six hue sextants
        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)

        hi = torch.floor(h * 6.0)
        f = h * 6.0 - hi
        p = v * (1. - s)
        q = v * (1. - (f * s))
        t = v * (1. - ((1. - f) * s))

        hi0 = hi == 0
        hi1 = hi == 1
        hi2 = hi == 2
        hi3 = hi == 3
        hi4 = hi == 4
        hi5 = hi == 5

        r[hi0] = v[hi0]
        g[hi0] = t[hi0]
        b[hi0] = p[hi0]

        r[hi1] = q[hi1]
        g[hi1] = v[hi1]
        b[hi1] = p[hi1]

        r[hi2] = p[hi2]
        g[hi2] = v[hi2]
        b[hi2] = t[hi2]

        r[hi3] = p[hi3]
        g[hi3] = q[hi3]
        b[hi3] = v[hi3]

        r[hi4] = t[hi4]
        g[hi4] = p[hi4]
        b[hi4] = v[hi4]

        r[hi5] = v[hi5]
        g[hi5] = p[hi5]
        b[hi5] = q[hi5]

        r = r.unsqueeze(1)
        g = g.unsqueeze(1)
        b = b.unsqueeze(1)
        rgb = torch.cat([r, g, b], dim=1)

        if self.gated2:
            rgb = rgb * self.alpha
        return rgb
```
RGB-to-HVI conversion (the HVIT method)
- The HVIT method implements the conversion from the RGB color space to the HVI color space. It first computes the three components value, hue, and saturation: the value is the per-pixel maximum over the RGB channels, the hue is determined by which channel attains that maximum, and the saturation is derived from the difference between the value and the per-pixel minimum.
- It then computes a color-sensitivity factor from the value, governed by a learnable parameter, and multiplies this factor with the saturation and with the cosine and sine of the hue to obtain the H and V channels of the HVI space; the value itself becomes the I channel.
- Finally, the three channels are concatenated into a single tensor and returned, completing the RGB-to-HVI conversion (see the quick property check after this list).
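A quick property check (a minimal sketch, assuming the RGB_HVI class above and inputs in [0, 1]): the I channel returned by HVIT is exactly the per-pixel maximum over R, G, B.

```python
import torch

trans = RGB_HVI()
rgb = torch.rand(1, 3, 8, 8)              # RGB batch with values in [0, 1]
hvi = trans.HVIT(rgb)                     # output channels are (H, V, I)
print(torch.allclose(hvi[:, 2], rgb.max(dim=1).values))  # True: I == max(R, G, B)
```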
HVI-to-RGB conversion (the PHVIT method)
- The PHVIT method converts the HVI color space back to RGB. Given an HVI image tensor, it first separates the H, V, and I channels, clamping H and V to [-1, 1] and I to [0, 1].
- It then recomputes the color-sensitivity factor from the I channel, divides H and V by it, clamps again, and recovers the hue (via atan2) and the saturation (via the vector norm); if the gating flag is enabled, the saturation is additionally scaled, and both saturation and value are clamped back into range.
- Finally, the hue selects one of six sextants, the corresponding formula produces the R, G, and B values, the channels are concatenated, the result is optionally scaled when the second gate is enabled, and the converted RGB image is returned (see the round-trip sanity check after this list).
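A round-trip sanity check (a minimal sketch with the default flags, i.e. gated = gated2 = False): PHVIT should approximately invert HVIT, up to the eps terms and clamping.

```python
import torch

trans = RGB_HVI()
rgb = torch.rand(2, 3, 64, 64)            # B x 3 x H x W in [0, 1]
hvi = trans.HVIT(rgb)                     # RGB -> HVI (also caches this_k)
rgb_back = trans.PHVIT(hvi)               # HVI -> RGB
print((rgb - rgb_back).abs().max())       # expected to be small (eps and clamping)
```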
2. Network architecture
The network architecture diagram is shown below:
The code is as follows:
```python
# CIDNet depends on the CAB / IEL / HV_LCA / I_LCA and NormDownsample /
# NormUpsample modules defined below, plus the repo's LayerNorm.
from huggingface_hub import PyTorchModelHubMixin


class CIDNet(nn.Module, PyTorchModelHubMixin):
    def __init__(self,
                 channels=[36, 36, 72, 144],
                 heads=[1, 2, 4, 8],
                 norm=False
                 ):
        super(CIDNet, self).__init__()
        [ch1, ch2, ch3, ch4] = channels
        [head1, head2, head3, head4] = heads

        # HV_ways: encoder/decoder of the HV (chroma) branch
        self.HVE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(3, ch1, 3, stride=1, padding=0, bias=False)
        )
        self.HVE_block1 = NormDownsample(ch1, ch2, use_norm=norm)
        self.HVE_block2 = NormDownsample(ch2, ch3, use_norm=norm)
        self.HVE_block3 = NormDownsample(ch3, ch4, use_norm=norm)

        self.HVD_block3 = NormUpsample(ch4, ch3, use_norm=norm)
        self.HVD_block2 = NormUpsample(ch3, ch2, use_norm=norm)
        self.HVD_block1 = NormUpsample(ch2, ch1, use_norm=norm)
        self.HVD_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 2, 3, stride=1, padding=0, bias=False)
        )

        # I_ways: encoder/decoder of the I (intensity) branch
        self.IE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(1, ch1, 3, stride=1, padding=0, bias=False),
        )
        self.IE_block1 = NormDownsample(ch1, ch2, use_norm=norm)
        self.IE_block2 = NormDownsample(ch2, ch3, use_norm=norm)
        self.IE_block3 = NormDownsample(ch3, ch4, use_norm=norm)

        self.ID_block3 = NormUpsample(ch4, ch3, use_norm=norm)
        self.ID_block2 = NormUpsample(ch3, ch2, use_norm=norm)
        self.ID_block1 = NormUpsample(ch2, ch1, use_norm=norm)
        self.ID_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 1, 3, stride=1, padding=0, bias=False),
        )

        # cross-attention at each scale: three pairs on the encoder side,
        # three on the decoder side
        self.HV_LCA1 = HV_LCA(ch2, head2)
        self.HV_LCA2 = HV_LCA(ch3, head3)
        self.HV_LCA3 = HV_LCA(ch4, head4)
        self.HV_LCA4 = HV_LCA(ch4, head4)
        self.HV_LCA5 = HV_LCA(ch3, head3)
        self.HV_LCA6 = HV_LCA(ch2, head2)

        self.I_LCA1 = I_LCA(ch2, head2)
        self.I_LCA2 = I_LCA(ch3, head3)
        self.I_LCA3 = I_LCA(ch4, head4)
        self.I_LCA4 = I_LCA(ch4, head4)
        self.I_LCA5 = I_LCA(ch3, head3)
        self.I_LCA6 = I_LCA(ch2, head2)

        self.trans = RGB_HVI()

    def forward(self, x):
        dtypes = x.dtype
        hvi = self.trans.HVIT(x)                       # RGB -> HVI
        i = hvi[:, 2, :, :].unsqueeze(1).to(dtypes)    # intensity channel feeds the I branch

        # low: encoder, with cross-attention between the branches at each scale
        i_enc0 = self.IE_block0(i)
        i_enc1 = self.IE_block1(i_enc0)
        hv_0 = self.HVE_block0(hvi)
        hv_1 = self.HVE_block1(hv_0)
        i_jump0 = i_enc0
        hv_jump0 = hv_0

        i_enc2 = self.I_LCA1(i_enc1, hv_1)
        hv_2 = self.HV_LCA1(hv_1, i_enc1)
        v_jump1 = i_enc2
        hv_jump1 = hv_2
        i_enc2 = self.IE_block2(i_enc2)
        hv_2 = self.HVE_block2(hv_2)

        i_enc3 = self.I_LCA2(i_enc2, hv_2)
        hv_3 = self.HV_LCA2(hv_2, i_enc2)
        v_jump2 = i_enc3
        hv_jump2 = hv_3
        i_enc3 = self.IE_block3(i_enc2)
        hv_3 = self.HVE_block3(hv_2)

        # bottleneck
        i_enc4 = self.I_LCA3(i_enc3, hv_3)
        hv_4 = self.HV_LCA3(hv_3, i_enc3)

        i_dec4 = self.I_LCA4(i_enc4, hv_4)
        hv_4 = self.HV_LCA4(hv_4, i_enc4)

        # decoder, fusing skip connections from the encoder
        hv_3 = self.HVD_block3(hv_4, hv_jump2)
        i_dec3 = self.ID_block3(i_dec4, v_jump2)
        i_dec2 = self.I_LCA5(i_dec3, hv_3)
        hv_2 = self.HV_LCA5(hv_3, i_dec3)

        hv_2 = self.HVD_block2(hv_2, hv_jump1)
        i_dec2 = self.ID_block2(i_dec3, v_jump1)

        i_dec1 = self.I_LCA6(i_dec2, hv_2)
        hv_1 = self.HV_LCA6(hv_2, i_dec2)

        i_dec1 = self.ID_block1(i_dec1, i_jump0)
        i_dec0 = self.ID_block0(i_dec1)
        hv_1 = self.HVD_block1(hv_1, hv_jump0)
        hv_0 = self.HVD_block0(hv_1)

        # residual connection in HVI space, then back to RGB
        output_hvi = torch.cat([hv_0, i_dec0], dim=1) + hvi
        output_rgb = self.trans.PHVIT(output_hvi)

        return output_rgb

    def HVIT(self, x):
        hvi = self.trans.HVIT(x)
        return hvi
```
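A minimal forward-pass sketch for CIDNet (assuming the supporting modules shown below are in scope, and an input whose side lengths are divisible by 8 because of the three downsampling stages):

```python
import torch

net = CIDNet()                    # default channels [36, 36, 72, 144]
x = torch.rand(1, 3, 256, 256)    # low-light RGB input in [0, 1]
with torch.no_grad():
    y = net(x)                    # enhanced RGB output
print(y.shape)                    # torch.Size([1, 3, 256, 256])
```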
The most important building block is the LCA (Lightweight Cross-Attention) module; its structure and code are as follows:
```python
from einops import rearrange


# Cross Attention Block
class CAB(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(CAB, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, y):
        # queries from x, keys/values from y: cross-attention between branches
        b, c, h, w = x.shape

        q = self.q_dwconv(self.q(x))
        kv = self.kv_dwconv(self.kv(y))
        k, v = kv.chunk(2, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # transposed attention: the attention map is c x c per head
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = nn.functional.softmax(attn, dim=-1)

        out = (attn @ v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out


# Intensity Enhancement Layer
class IEL(nn.Module):
    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super(IEL, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)
        self.dwconv1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
        self.dwconv2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

        self.Tanh = nn.Tanh()

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        # two gated depthwise paths with tanh residuals, multiplied together
        x1 = self.Tanh(self.dwconv1(x1)) + x1
        x2 = self.Tanh(self.dwconv2(x2)) + x2
        x = x1 * x2
        x = self.project_out(x)
        return x


# Lightweight Cross Attention
class HV_LCA(nn.Module):
    def __init__(self, dim, num_heads, bias=False):
        super(HV_LCA, self).__init__()
        self.gdfn = IEL(dim)  # IEL and CDL have same structure
        self.norm = LayerNorm(dim)
        self.ffn = CAB(dim, num_heads, bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        x = self.gdfn(self.norm(x))   # note: no residual here, unlike I_LCA
        return x


class I_LCA(nn.Module):
    def __init__(self, dim, num_heads, bias=False):
        super(I_LCA, self).__init__()
        self.norm = LayerNorm(dim)
        self.gdfn = IEL(dim)
        self.ffn = CAB(dim, num_heads, bias=bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        x = x + self.gdfn(self.norm(x))
        return x
```
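Note that LayerNorm above is the custom channel-first LayerNorm defined elsewhere in the repository, not nn.LayerNorm. A minimal shape sketch of CAB itself: queries come from x and keys/values from y, and because the attention map is $c \times c$ per head (transposed attention), the cost grows linearly with the number of pixels:

```python
import torch

cab = CAB(dim=36, num_heads=2, bias=False)
x = torch.rand(1, 36, 64, 64)     # e.g. I-branch features (queries)
y = torch.rand(1, 36, 64, 64)     # e.g. HV-branch features (keys/values)
print(cab(x, y).shape)            # torch.Size([1, 36, 64, 64])
```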
The downsampling and upsampling modules used in the network are as follows:
```python
class NormDownsample(nn.Module):
    def __init__(self, in_ch, out_ch, scale=0.5, use_norm=False):
        super(NormDownsample, self).__init__()
        self.use_norm = use_norm
        if self.use_norm:
            self.norm = LayerNorm(out_ch)
        self.prelu = nn.PReLU()
        # 3x3 conv followed by bilinear resampling (scale=0.5 halves H and W)
        self.down = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.UpsamplingBilinear2d(scale_factor=scale))

    def forward(self, x):
        x = self.down(x)
        x = self.prelu(x)
        if self.use_norm:
            x = self.norm(x)
            return x
        else:
            return x


class NormUpsample(nn.Module):
    def __init__(self, in_ch, out_ch, scale=2, use_norm=False):
        super(NormUpsample, self).__init__()
        self.use_norm = use_norm
        if self.use_norm:
            self.norm = LayerNorm(out_ch)
        self.prelu = nn.PReLU()
        self.up_scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.UpsamplingBilinear2d(scale_factor=scale))
        # 1x1 conv that fuses the upsampled features with the skip connection
        self.up = nn.Conv2d(out_ch * 2, out_ch, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x, y):
        x = self.up_scale(x)
        x = torch.cat([x, y], dim=1)   # concatenate skip connection y
        x = self.up(x)
        x = self.prelu(x)
        if self.use_norm:
            return self.norm(x)
        else:
            return x
```
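A quick shape check for the pair (a minimal sketch with the default use_norm=False, so the custom LayerNorm is not needed):

```python
import torch

down = NormDownsample(36, 72)             # conv, then bilinear x0.5
up = NormUpsample(72, 36)                 # bilinear x2, then fuse the skip

x = torch.rand(1, 36, 64, 64)
d = down(x)                               # torch.Size([1, 72, 32, 32])
u = up(d, x)                              # skip y=x must match the upsampled size
print(d.shape, u.shape)                   # (1, 72, 32, 32) and (1, 36, 64, 64)
```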
As the structure and code show, the network adopts a UNet-like encoder-decoder, and inserts cross-attention (the LCA blocks) between the two branches at each scale to exchange information between HV (chroma) and I (intensity).
3. Summary
The HVI code boils down to two parts: the mutual conversion between the HVI and RGB color spaces, and the CIDNet network architecture.
4. Hands-on test
We took a night shot with our own phone and ran it through the authors' Hugging Face demo. The result is shown below: fairly noticeable banding appears in the dark regions, which may be related to the following factors: the dark regions are extremely dark (extreme low light), and the training data lacks such extreme scenes.