图像中有一类任务是对RGB通道进行变换,例如RGB转换CIE XYZ,或者颜色风格转移。这类变换可以用矩阵乘法表示
图1. 未经gamma处理的sRGB到XYZ的转换矩阵,采用BT709标准,利用D65白点归一化
逐个像素处理可以直接按上述乘法进行,实际要对单张图/多张图批量处理,输入图像维度是H,W,C或B,C,H,W。需要对处理过程进行一定优化。
1. 采用reshape+矩阵乘法
先并合并HW维度,然后利用转置将通道维度转换到第一个维度,相当于把所有像素按顺序填入A=(3,HW)的矩阵中。然后,直接进行B = MA的矩阵乘法即可对所有像素进行颜色变换。
def reshape_dot(image,matrix):
init_shape = image.shape
image = image.reshape(-1,3).T
image_out = matrix @ image
image_out = image_out.T.reshape(init_shape)
return image_out
2. 采用爱因斯坦积
张量矩阵运算可以用爱因斯坦积化简。爱因斯坦积的简单理解就是:不重复的字母分开罗列结果,重复字母进行乘加消除。例如对于一般矩阵乘法的表达式
y
i
k
=
A
i
j
B
j
k
y_{ik} = A_{ij}B_{jk}
yik=AijBjk,有
i
×
k
i\times k
i×k个输出,给定i,k, 在等式右边j出现了重复,所以就需要对所有
A
i
j
,
B
j
k
A_{ij},B_{jk}
Aij,Bjk进行相乘相加来消掉j:
∑
k
A
i
j
B
j
k
\sum_kA_{ij}B_{jk}
∑kAijBjk。在这个例子中,我们完成了矩阵乘法的表达,记为'ij,jk->ik'
回到我们的问题,我们要做的矩阵乘法是在通道维度上进行,通道维度字母应该重复。同时,为了满足矩阵乘法的样式('ij,jk->ik'
)我们最好将通道维度放到第一个维度。最终,可以写出爱因斯坦积表达式
B
v
j
u
=
M
v
i
A
i
j
k
B_{vju} = M_{vi}A_{ijk}
Bvju=MviAijk,该表达式简记为'vi,iju->vju'
. vju和ijk代表了转换后和之前图像的CHW维度。
def einsum_dot(image:np.ndarray,matrix):
image = image.transpose(2,0,1) # CHW
return np.einsum('vi,iju->vju',matrix,image).transpose(1,2,0)
对于pytorch中的批量图片输入,我们仍然只关注通道维度,其他维度保持不变即可。 B b v j u = M v i A b i j k B_{bvju} = M_{vi}A_{bijk} Bbvju=MviAbijk. 其中bvju和bijk代表BCHW维度。
def einsum_dot_tensor(batched_image,matrix): # input BCHW
return torch.einsum('vi,biju->bvju',matrix,batched_image)
下面是一个转换效果,采用如下矩阵,交换RG维度,应该能观察到红绿反转
[
0
1
0
1
0
0
0
0
1
]
\begin{bmatrix} 0&1&0\\ 1&0&0\\ 0&0&1\\ \end{bmatrix}
010100001
图2. 原图像 图片源自wikipedia
图3. 三种方法转换后图像(reshape,numpy einsum,torch einsum)
完整代码
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
def reshape_dot(image,matrix):
init_shape = image.shape
image = image.reshape(-1,3).T
image_out = matrix @ image
image_out = image_out.T.reshape(init_shape)
return image_out
def einsum_dot(image:np.ndarray,matrix):
image = image.transpose(2,0,1) # CHW
return np.einsum('vi,iju->vju',matrix,image).transpose(1,2,0)
def einsum_dot_tensor(batched_image,matrix): # input BCHW
return torch.einsum('vi,biju->bvju',matrix,batched_image)
if __name__ == '__main__':
transform_matrix = np.array([[0.,1.,0.],
[1.,0.,0.],
[0.,0.,1.]])
transform_matrix_tensor = torch.from_numpy(transform_matrix)
image = Image.open('图片地址').convert('RGB')
image_np = np.array(image)/255.
image_batched = torch.from_numpy(image_np).permute(2,0,1).unsqueeze(0)
out_0 = reshape_dot(image_np,transform_matrix)
out_1 = einsum_dot(image_np,transform_matrix)
out_2 = einsum_dot_tensor(image_batched,transform_matrix_tensor).squeeze(0).permute(1,2,0).numpy()
plt.imshow(np.hstack([out_0,out_1,out_2]))
plt.show()