快速矩阵乘法的研究
最近的工作主要在于深度学习框架的性能优化。深度学习框架在工程的优化(内存池、SIMD、汇编、GPU、DSP等等)做到接近极限之后,突破点便集中于算法。
深度学习的性能瓶颈主要在于卷积,卷积的运算方法主要是通过 Im2Col / Winograd / FFT 转化为矩阵乘,完成矩阵乘法之后,再转化为目标结果。
深度学习框架的输入是算法工程产出的网络模型,而目前网络模型都渐渐地转变为 mobilenet 那样 1x1 convolution + depthwise 的形式,在精度几乎无损的情况下,既减少了计算量,又减少了模型体积。而这类网络模型,都以 1x1 卷积为主要耗时点。
对 1x1 卷积而言,其本身就是一个矩阵乘法,FFT / Winograd 等卷积算法已经失去价值,因此研读了一些矩阵乘法相关的论文,整理如下。
传统矩阵乘算法
定义
在 1968 年之前,矩阵乘算法只有按定义实现的传统算法,:
设:
A=(a11a12...a21a22............an1an2...)B=(b11b12...b21b22............bn1bn2...)A=\begin{pmatrix}
a_{11} &a_{12} &... \\
a_{21} &a_{22} &... \\
... & ... & ... \\
a_{n1} & a_{n2} & ... \\
\end{pmatrix}
B=\begin{pmatrix}
b_{11} &b_{12} &... \\
b_{21} &b_{22} &... \\
... & ... & ... \\
b_{n1} & b_{n2} & ... \\
\end{pmatrix}
A=⎝⎜⎜⎛a11a21...an1a12a22...an2............⎠⎟⎟⎞B=⎝⎜⎜⎛b11b21...bn1b12b22...bn2............⎠⎟⎟⎞
AB 为其乘积,则:
[AB]pq=∑i=1napibiq[AB]_{pq} = \sum_{i=1}^{n}a_{pi}b_{iq}[AB]pq=i=1∑napibiq
很明显,它是一个 n3n^3n3复杂度的算法,需要 n3n^3n3 次乘法和 n3−n2n^3-n^2n3−n2次加法。
矩阵乘表示
设 C=ABC = ABC=AB,A 为 e∗le*le∗l的矩阵,B 为 l∗hl*hl∗h的矩阵,则称这个矩阵乘是一个 [e,l,h][e, l, h][e,l,h] 的矩阵乘。
快速矩阵乘法的初步探索
Winograd 算法
请注意,这个不是我们通常所说的卷积优化算法,只是同一个人(Winograd大神)在 1968 年提出一种减少乘法数的矩阵乘算法。
其思路是通过两次 n2n^2n2 的乘法预处理,将规模大的矩阵乘法减少一半,但相应的加法增加一半。为了说明简单,这里假定nnn为偶数。
θp=∑j=1⌊n/2⌋(ap,2j−1ap,2j)γq=∑j=1⌊n/2⌋(b2j−1,qb2j,q)[AB]pq=∑j=1⌊n/2⌋(ap,2j−1+b2j,q)(ap,2j+b2j−1,q)−θp−γq\theta_p = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(a_{p, 2j-1} a_{p, 2j}) \\\gamma_q = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(b_{2j-1, q}b_{2j, q}) \\
[AB]_{pq} = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(a_{p, 2j-1}+b_{2j, q})(a_{p, 2j}+b_{2j-1, q}) - \theta_p - \gamma_q
θp=j=1∑⌊n/2⌋(ap,2j−1ap,2j)γq=j=1∑⌊n/2⌋(b2j−1,qb2j,q)[AB]pq=j=1∑⌊n/2⌋(ap,2j−1+b2j,q)(ap,2j+b2j−1,q)−θp−γq
这个算法没有降低矩阵乘法的阶(还是n3n^3n3),只是以廉价计算(加法)替代昂贵运算(乘法),需要根据具体的硬件去判断是否可应用。ARM 架构的 CPU,对量化矩阵乘有帮助,但对浮点矩阵乘没有用。
Strassen 矩阵乘算法
Strassen 矩阵乘的思路是通过加减变换,将一个 [2,2,2][2, 2, 2][2,2,2]的矩阵乘法所用的乘法数由8降到7,并且递归使用,降低矩阵乘法的阶数:n3n^3n3变成n2.81n^{2.81}n2.81
A=(a11a12a21a22)B=(b11b12b21b22)AB=(c11c12c21c22) A=\begin{pmatrix}
a_{11} &a_{12} \\
a_{21} &a_{22} \\
\end{pmatrix} B=\begin{pmatrix}
b_{11} &b_{12} \\
b_{21} &b_{22} \\
\end{pmatrix} AB=\begin{pmatrix}
c_{11} &c_{12} \\
c_{21} &c_{22} \\
\end{pmatrix}A=(a11a21a12a22)B=(b11b21b12b22)AB=(c11c21c12c22)
v1=(a11+a22)(b11+b22)v2=(a21+a22)(b11)v3=(a11)(b12−b22)v4=(a22)(b21−b11)v5=(a11+a12)(b22)v6=(a21−a11)(b11+b12)v7=(a12−a22)(b21+b22)v_1 = (a_{11}+a_{22})(b_{11}+b_{22})\\ v_2 = (a_{21}+a_{22})(b_{11})\\v_3 = (a_{11})(b_{12}-b_{22})\\v_4 = (a_{22})(b_{21}-b_{11})\\v_5 = (a_{11}+a_{12})(b_{22})\\v_6 = (a_{21}-a_{11})(b_{11}+b_{12})\\v_7 = (a_{12}-a_{22})(b_{21}+b_{22})v1=(a11+a22)(b11+b22)v2=(a21+a22)(b11)v3=(a11)(b12−b22)v4=(a22)(b21−b11)v5=(a11+a12)(b22)v6=(a21−a11)(b11+b12)v7=(a12−a22)(b21+b22)
c11=v1+v4−v5+v7c21=v2+v4c12=v3+v5c22=v1+v3−v2+v6c_{11} = v_1+v_4-v_5+v_7\\c_{21} = v_2+v_4\\c_{12} = v_3+v_5\\c_{22} = v_1+v_3-v_2+v_6c11=v1+v4−v5+v7c21=v2+v4c12=v3+v5c22=v1+v3−v2+v6
请注意,其中每个元素(a11,b12,c22a_{11}, b_{12}, c_{22}a11,b12,c22等等)不限于实数,可以是一个矩阵。因为矩阵乘法满足分配率与结合率。这样算法就有了脱离硬件的普适价值,因为矩阵加减的复杂度(n2n^2n2)远低于矩阵乘(n3n^3n3)
Winograd 在 Strassen 的基础上对它的算法进行了改进,减少了加减数(18->15),这个也成为最常用的 Strassen 矩阵乘法应用。
三线性表示
为了方便矩阵乘算法的研究,人们提出一种表示矩阵乘算法的形式,叫“Trilinear-form”,即三线性形式。
我们先以 Strassen 算法为例,它的三线性形式是:
∑i=12∑j=12∑k=12aijbjkcik=(a11)(b12−b22)(c12+c22)+(a11+a12)(b22)(−c11+c12)+(a21+a22)(b11)(c21−c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12−a22)(b21+b22)(c11)+(a11−a21)(b11+b12)(−c22)\sum_{i=1}^2\sum_{j=1}^2\sum_{k=1}^2 a_{ij}b_{jk}c_{ik} = (a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22})∑i=12∑j=12∑k=12aijbjkcik=(a11)(b12−b22)(c12+c22)+(a11+a12)(b22)(−c11+c12)+(a21+a22)(b11)(c21−c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12−a22)(b21+b22)(c11)+(a11−a21)(b11+b12)(−c22)
怎么看这个公式呢,它其实是按 Trace(ABC)=ABTrace(ABC) = ABTrace(ABC)=AB 的原理去表示的。两个矩阵的乘积,等效于三个矩阵乘积的迹。在上面公式中,如果我们要算出 c11c_{11}c11 的解法,就将 c11c_{11}c11 设成 1,其他的 c 值,c12,c21,c22c_{12}, c_{21}, c_{22}c12,c21,c22 全设成 0 ,然后将对应的项相加即可。
这个算式总共有7项,这个 7 我们称之为 Rank (阶)
APA——矩阵乘算法的突破
APA,即 Any Precision Algorithm,是把矩阵乘法阶数继续往下降的重要思想,基本思路是先给出近似的矩阵乘法表达式,然后在多阶张量积之后转换为准确的矩阵乘法。
张量积
我们来看 Strassen 矩阵乘法的表达式:
λ=(a11)(b12−b22)(c12+c22)+(a11+a12)(b22)(−c11+c12)+(a21+a22)(b11)(c21−c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12−a22)(b21+b22)(c11)+(a11−a21)(b11+b12)(−c22)\lambda = (a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22})λ=(a11)(b12−b22)(c12+c22)+(a11+a12)(b22)(−c11+c12)+(a21+a22)(b11)(c21−c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12−a22)(b21+b22)(c11)+(a11−a21)(b11+b12)(−c22)
对其平方:
λ2=((a11)(b12−b22)(c12+c22)+(a11+a12)(b22)(−c11+c12)+(a21+a22)(b11)(c21−c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12−a22)(b21+b22)(c11)+(a11−a21)(b11+b12)(−c22))2\lambda^2 = ((a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22}))^2λ2=((a11)(b12−b22)(c12+c22)+(a11+a12)(b22)(−c11+c12)+(a21+a22)(b11)(c21−c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12−a22)(b21+b22)(c11)+(a11−a21)(b11+b12)(−c22))2
这是个多项式乘法,不难知λ2\lambda^2λ2 有 72=497^2=4972=49 项,我们来看其中一项:
((a11)(b12−b22)(c12+c22))((a11+a12)(b22)(−c11+c12))=(a11a11+a11a12)(b12b22−b22b22)(−c12c11+c12c12−c22c11+c22c12)((a_{11})(b_{12}-b_{22})(c_{12}+c_{22}))((a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}))=(a_{11}a_{11}+a_{11}a_{12})(b_{12}b_{22}-b_{22}b_{22})(-c_{12}c_{11}+c_{12}c_{12}-c_{22}c_{11}+c_{22}c_{12})((a11)(b12−b22)(c12+c22))((a11+a12)(b22)(−c11+c12))=(a11a11+a11a12)(b12b22−b22b22)(−c12c11+c12c12−c22c11+c22c12)
(依然是将a, b, c 分别组合在一起)
a,b,ca, b, ca,b,c间的相乘,如a11a12a_{11}a_{12}a11a12,我们将其替代为直和:a1112a_{1112}a1112,其含义可以这么理解,在a11a_{11}a11的区域(左上角)中,再划分为四块,取其a12a_{12}a12的区域(右上角)。
不难证明,我们通过这个多项式平方后得到的三线性形式,等效于一个 [4,4,4][4, 4, 4][4,4,4] 的矩阵乘法。
类似地,我们可以对矩阵乘法的三线性形式进行立方,n次方,以及两个不同的三线性形式乘积,这一系列操作可由“张量积”概括。
APA
Any Precision Algorithm(APA),即任意精度算法,通过在算式中引入一个可配置的实数λ\lambdaλ,得到更好的简化效果。
下面的式子近似用21项表示了一个[3,3,3][3, 3, 3][3,3,3]的矩阵乘法
F1(λ)=(a11+λ2a12)(λ2b11+b21)c11+(a21+λ2a22)(λ2b12+b22)c22+(a31+λ2a32)(λ2b13+b23)c33−a11(b21+b31)(c11+c12+c13)−a21(b22+b32)(c21+c22+c23)−a31(b23+b33)(c31+c32+c33)+(a11+λ2a22)(b21−λb12)c12+(a21+λ2a12)(b22−λb11)c21+(a11+λ2a32)(b21−λb13)c13+(a31+λ2a12)(b23−λb11)c31+(a21+λ2a32)(b22−λb13)c23+(a31+λ2a22)(b23−λb12)c32+(a11+λ2a23)(b31+λb12)(c12+λc21)+(a21+λ2a13)(b32+λb11)(c21+λc12)+(a11+λ2a33)(b31+λb13)(c13+λc31)+(a31+λ2a13)(b33+λb12)(c31+λc13)+(a21+λ2a33)(b32+λb13)(c23+λc32)+(a31+λ2a23)(b33+λb12)(c32+λc23)+(a11+λ2a13)b31(c11−λc31−λc21)+(a21+λ2a23)b32(c22−λc32−λc12)+(a31+λ2a33)b33(c33−λc13−λc23)=λ2(Trace(ABC)+λG(λ))F_1(\lambda) = (a_{11}+\lambda^2a_{12})(\lambda^2b_{11}+b_{21})c_{11}\\+(a_{21}+\lambda^2a_{22})(\lambda^2b_{12}+b_{22})c_{22}+(a_{31}+\lambda^2a_{32})(\lambda^2b_{13}+b_{23})c_{33}-a_{11}(b_{21}+b_{31})(c_{11}+c_{12}+c_{13})-a_{21}(b_{22}+b_{32})(c_{21}+c_{22}+c_{23})-a_{31}(b_{23}+b_{33})(c_{31}+c_{32}+c_{33})+(a_{11}+\lambda^2a_{22})(b_{21}-\lambda b_{12})c_{12}+(a_{21}+\lambda^2a_{12})(b_{22}-\lambda b_{11})c_{21}+(a_{11}+\lambda^2a_{32})(b_{21}-\lambda b_{13})c_{13}+(a_{31}+\lambda^2a_{12})(b_{23}-\lambda b_{11})c_{31}+(a_{21}+\lambda^2a_{32})(b_{22}-\lambda b_{13})c_{23}+(a_{31}+\lambda^2a_{22})(b_{23}-\lambda b_{12})c_{32}+(a_{11}+\lambda^2a_{23})(b_{31}+\lambda b_{12})(c_{12}+\lambda c_{21})+(a_{21}+\lambda^2a_{13})(b_{32}+\lambda b_{11})(c_{21}+\lambda c_{12})+(a_{11}+\lambda^2a_{33})(b_{31}+\lambda b_{13})(c_{13}+\lambda c_{31})+(a_{31}+\lambda^2a_{13})(b_{33}+\lambda b_{12})(c_{31}+\lambda c_{13})+(a_{21}+\lambda^2a_{33})(b_{32}+\lambda b_{13})(c_{23}+\lambda c_{32})+(a_{31}+\lambda^2a_{23})(b_{33}+\lambda b_{12})(c_{32}+\lambda c_{23})+(a_{11}+\lambda^2a_{13})b_{31}(c_{11}-\lambda c_{31}-\lambda c_{21})+(a_{21}+\lambda^2a_{23})b_{32}(c_{22}-\lambda c_{32}-\lambda c_{12})+(a_{31}+\lambda^2a_{33})b_{33}(c_{33}-\lambda c_{13}-\lambda c_{23}) = \lambda^2 (Trace(ABC)+\lambda G(\lambda))F1(λ)=(a11+λ2a12)(λ2b11+b21)c11+(a21+λ2a22)(λ2b12+b22)c22+(a31+λ2a32)(λ2b13+b23)c33−a11(b21+b31)(c11+c12+c13)−a21(b22+b32)(c21+c22+c23)−a31(b23+b33)(c31+c32+c33)+(a11+λ2a22)(b21−λb12)c12+(a21+λ2a12)(b22−λb11)c21+(a11+λ2a32)(b21−λb13)c13+(a31+λ2a12)(b23−λb11)c31+(a21+λ2a32)(b22−λb13)c23+(a31+λ2a22)(b23−λb12)c32+(a11+λ2a23)(b31+λb12)(c12+λc21)+(a21+λ2a13)(b32+λb11)(c21+λc12)+(a11+λ2a33)(b31+λb13)(c13+λc31)+(a31+λ2a13)(b33+λb12)(c31+λc13)+(a21+λ2a33)(b32+λb13)(c23+λc32)+(a31+λ2a23)(b33+λb12)(c32+λc23)+(a11+λ2a13)b31(c11−λc31−λc21)+(a21+λ2a23)b32(c22−λc32−λc12)+(a31+λ2a33)b33(c33−λc13−λc23)=λ2(Trace(ABC)+λG(λ))
当 λ\lambdaλ趋于无穷小时,其误差也趋于无穷小,因此我们可以设定任意的精度去使用它,这就是 APA 的由来。
对于 APA 算法,多项式的个数我们称之为 Border Rank,上述算式表示了一个[3,3,3][3, 3, 3][3,3,3]的矩阵乘法,在λ3\lambda ^3λ3的基础上分出误差,我们称之为一个降解:[3,3,3]⊴321[3, 3, 3] \unlhd_3 21[3,3,3]⊴321
现在我们来看怎么把上面的 APA 算法变成准确算法。
直观的做法就是把λ2\lambda^2λ2项取出来,如:(a11+λ2a12)(λ2b11+b21)c11(a_{11}+\lambda^2a_{12})(\lambda^2b_{11}+b_{21})c_{11}(a11+λ2a12)(λ2b11+b21)c11,取出 λ2a11b11c11+λ2a12b21c11\lambda^2a_{11}b_{11}c_{11}+\lambda^2a_{12}b_{21}c_{11}λ2a11b11c11+λ2a12b21c11,代价就是增加了多项式,不难证明,我们最多会增加到 2(2+1)/2=32(2+1)/2=32(2+1)/2=3倍的多项式个数。
无疑,这样做肯定亏了,3∗21=63>3∗3∗3=273*21=63 > 3*3*3=273∗21=63>3∗3∗3=27,我们需要施个魔法,就是张量积。
对上面APA 算法进行n次张量积之后,我们可以得到3n3^n3n大小的矩阵乘算法的降解:[3n,3n,3n]⊴2n+121n[3^n, 3^n, 3^n] \unlhd_{2n+1} 21^n[3n,3n,3n]⊴2n+121n
这时候我们再来取,就不一样了,其阶数变成了:
n(2n+1)21nn(2n+1)21^nn(2n+1)21n
很明显,当 n 足够大时,n(2n+1)n(2n+1)n(2n+1) 和指数项相比可忽略,这样我们就得到了更好的准确算法,其阶数为:
3ln(21)/ln(27)≈2.773ln(21)/ln(27)\approx2.773ln(21)/ln(27)≈2.77
下篇内容:
1、组合矩阵乘
2、渐近和定理
3、Strassen构造
4、Coppersmith–Winograd 算法