relaxed_madd
这条指令到底做了什么
核心:relaxed_madd
是一个分量级别 (Component-wise) 的操作
首先,最重要的一点是:v128.relaxed_madd<f32>(a, b, c)
不是矩阵乘法。它是一个在三个向量 a
, b
, c
之间进行的、逐个分量的、并行的融合乘加操作。
这三个向量 a
, b
, c
都是 v128
类型,我们可以把它们看作包含四个 32位浮点数 (f32) 的数组。
1. 定义我们的输入向量
让我们用数学形式来表示这三个输入向量 a
, b
, c
:
- 向量 a⃗=(a0a1a2a3)\vec{a} = \begin{pmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \end{pmatrix}a=a0a1a2a3
- 向量 b⃗=(b0b1b2b3)\vec{b} = \begin{pmatrix} b_0 \\ b_1 \\ b_2 \\ b_3 \end{pmatrix}b=b0b1b2b3
- 向量 c⃗=(c0c1c2c3)\vec{c} = \begin{pmatrix} c_0 \\ c_1 \\ c_2 \\ c_3 \end{pmatrix}c=c0c1c2c3
这里的 a0,a1,…,c3a_0, a_1, \dots, c_3a0,a1,…,c3 都是普通的 32 位浮点数。
2. relaxed_madd
的数学公式
v128.relaxed_madd<f32>(a, b, c)
这条指令执行的计算,其结果是一个新的向量,我们称之为 $ \vec{r} $ (result)。这个计算的数学公式是:
r⃗=(a⃗⊙b⃗)+c⃗\vec{r} = (\vec{a} \odot \vec{b}) + \vec{c}r=(a⊙b)+c
这里的 ⊙\odot⊙ 符号代表 哈达玛积 (Hadamard Product),也就是分量相乘 (component-wise multiplication)。
3. 展开公式,看清细节
把上面的公式展开到每一个分量上,就能看清它到底发生了什么。结果向量 r⃗\vec{r}r 的四个分量 r0,r1,r2,r3r_0, r_1, r_2, r_3r0,r1,r2,r3 是这样并行计算出来的:
r⃗=(r0r1r2r3)=(a0⋅b0+c0a1⋅b1+c1a2⋅b2+c2a3⋅b3+c3)\vec{r} = \begin{pmatrix} r_0 \\ r_1 \\ r_2 \\ r_3 \end{pmatrix} = \begin{pmatrix} a_0 \cdot b_0 + c_0 \\ a_1 \cdot b_1 + c_1 \\ a_2 \cdot b_2 + c_2 \\ a_3 \cdot b_3 + c_3 \end{pmatrix}r=r0r1r2r3=a0⋅b0+c0a1⋅b1+c1a2⋅b2+c2a3⋅b3+c3
这就是 relaxed_madd
的全部真相:它在一条指令里,同时并行地完成了这四个独立的融合乘加运算。
它在你的矩阵乘法代码中是如何被应用的?
现在我们把你代码中的一行拿出来,用这个公式来解释:
// 代码行
res0 = v128.relaxed_madd<f32>(sA0, rB0, res0_prev);
// (我把之前的 res0 重命名为 res0_prev 以便区分)
这里的输入是什么?
-
sA0
: 这是一个由矩阵 A 的元素A[0][0]
广播 (splat) 而来的向量。
sA0⃗=(A[0][0]A[0][0]A[0][0]A[0][0])\vec{sA0} = \begin{pmatrix} A[0][0] \\ A[0][0] \\ A[0][0] \\ A[0][0] \end{pmatrix}sA0=A[0][0]A[0][0]A[0][0]A[0][0] -
rB0
: 这是矩阵 B 的第一行向量。
rB0⃗=(B[0][0]B[0][1]B[0][2]B[0][3])\vec{rB0} = \begin{pmatrix} B[0][0] \\ B[0][1] \\ B[0][2] \\ B[0][3] \end{pmatrix}rB0=B[0][0]B[0][1]B[0][2]B[0][3] -
res0_prev
: 这是上一步计算的结果(累加值)。
那么,v128.relaxed_madd<f32>(sA0, rB0, res0_prev)
这一步到底计算了什么?我们套用上面的公式:
res0⃗=(sA0⃗⊙rB0⃗)+res0prev⃗\vec{res0} = (\vec{sA0} \odot \vec{rB0}) + \vec{res0_{prev}}res0=(sA0⊙rB0)+res0prev
展开来看就是:
res0⃗=(A[0][0]⋅B[0][0]+res0prev,0A[0][0]⋅B[0][1]+res0prev,1A[0][0]⋅B[0][2]+res0prev,2A[0][0]⋅B[0][3]+res0prev,3)\vec{res0} = \begin{pmatrix} A[0][0] \cdot B[0][0] + res0_{prev,0} \\ A[0][0] \cdot B[0][1] + res0_{prev,1} \\ A[0][0] \cdot B[0][2] + res0_{prev,2} \\ A[0][0] \cdot B[0][3] + res0_{prev,3} \end{pmatrix}res0=A[0][0]⋅B[0][0]+res0prev,0A[0][0]⋅B[0][1]+res0prev,1A[0][0]⋅B[0][2]+res0prev,2A[0][0]⋅B[0][3]+res0prev,3
这完美地对应了我们矩阵乘法思想:用 A 的一个标量元素,去数乘 B 的一整个行向量,然后加到累加器上。
当你把四次 relaxed_madd
调用链接起来后,最终的结果 res0
的第一个分量就是:
r0=(A[0][0]⋅B[0][0])+(A[0][1]⋅B[1][0])+(A[0][2]⋅B[2][0])+(A[0][3]⋅B[3][0])r_0 = (A[0][0] \cdot B[0][0]) + (A[0][1] \cdot B[1][0]) + (A[0][2] \cdot B[2][0]) + (A[0][3] \cdot B[3][0])r0=(A[0][0]⋅B[0][0])+(A[0][1]⋅B[1][0])+(A[0][2]⋅B[2][0])+(A[0][3]⋅B[3][0])
这正好是结果矩阵 C 的 C[0][0]
元素!其他分量同理。
FMA 的“融合”体现在哪里?
“融合” (Fused) 的意思是,在计算 ai⋅bi+cia_i \cdot b_i + c_iai⋅bi+ci 时:
- 计算 ai⋅bia_i \cdot b_iai⋅bi 的乘积,得到一个内部的、高精度的中间结果(比如 80 位浮点数)。
- 不进行舍入,直接用这个高精度结果与 cic_ici 相加。
- 对最终的和只进行一次舍入,得到 32 位的浮点数结果。
相比之下,非融合的 mul
+ add
会进行两次舍入,可能会损失精度。更重要的是,FMA 是一条硬件指令,吞吐量更高。