【WebGPU学习杂记】WebAssembly中的relaxed_madd指令到底做了什么?

relaxed_madd 这条指令到底做了什么


核心:relaxed_madd 是一个分量级别 (Component-wise) 的操作

首先,最重要的一点是:v128.relaxed_madd<f32>(a, b, c) 不是矩阵乘法。它是一个在三个向量 a, b, c 之间进行的、逐个分量的、并行的融合乘加操作。

这三个向量 a, b, c 都是 v128 类型,我们可以把它们看作包含四个 32位浮点数 (f32) 的数组。

1. 定义我们的输入向量

让我们用数学形式来表示这三个输入向量 a, b, c

  • 向量 a⃗=(a0a1a2a3)\vec{a} = \begin{pmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \end{pmatrix}a=a0a1a2a3
  • 向量 b⃗=(b0b1b2b3)\vec{b} = \begin{pmatrix} b_0 \\ b_1 \\ b_2 \\ b_3 \end{pmatrix}b=b0b1b2b3
  • 向量 c⃗=(c0c1c2c3)\vec{c} = \begin{pmatrix} c_0 \\ c_1 \\ c_2 \\ c_3 \end{pmatrix}c=c0c1c2c3

这里的 a0,a1,…,c3a_0, a_1, \dots, c_3a0,a1,,c3 都是普通的 32 位浮点数。

2. relaxed_madd 的数学公式

v128.relaxed_madd<f32>(a, b, c) 这条指令执行的计算,其结果是一个新的向量,我们称之为 $ \vec{r} $ (result)。这个计算的数学公式是:

r⃗=(a⃗⊙b⃗)+c⃗\vec{r} = (\vec{a} \odot \vec{b}) + \vec{c}r=(ab)+c

这里的 ⊙\odot 符号代表 哈达玛积 (Hadamard Product),也就是分量相乘 (component-wise multiplication)

3. 展开公式,看清细节

把上面的公式展开到每一个分量上,就能看清它到底发生了什么。结果向量 r⃗\vec{r}r 的四个分量 r0,r1,r2,r3r_0, r_1, r_2, r_3r0,r1,r2,r3 是这样并行计算出来的:

r⃗=(r0r1r2r3)=(a0⋅b0+c0a1⋅b1+c1a2⋅b2+c2a3⋅b3+c3)\vec{r} = \begin{pmatrix} r_0 \\ r_1 \\ r_2 \\ r_3 \end{pmatrix} = \begin{pmatrix} a_0 \cdot b_0 + c_0 \\ a_1 \cdot b_1 + c_1 \\ a_2 \cdot b_2 + c_2 \\ a_3 \cdot b_3 + c_3 \end{pmatrix}r=r0r1r2r3=a0b0+c0a1b1+c1a2b2+c2a3b3+c3

这就是 relaxed_madd 的全部真相:它在一条指令里,同时并行地完成了这四个独立的融合乘加运算


它在你的矩阵乘法代码中是如何被应用的?

现在我们把你代码中的一行拿出来,用这个公式来解释:

// 代码行
res0 = v128.relaxed_madd<f32>(sA0, rB0, res0_prev); 
// (我把之前的 res0 重命名为 res0_prev 以便区分)

这里的输入是什么?

  • sA0: 这是一个由矩阵 A 的元素 A[0][0] 广播 (splat) 而来的向量。
    sA0⃗=(A[0][0]A[0][0]A[0][0]A[0][0])\vec{sA0} = \begin{pmatrix} A[0][0] \\ A[0][0] \\ A[0][0] \\ A[0][0] \end{pmatrix}sA0=A[0][0]A[0][0]A[0][0]A[0][0]

  • rB0: 这是矩阵 B 的第一行向量。
    rB0⃗=(B[0][0]B[0][1]B[0][2]B[0][3])\vec{rB0} = \begin{pmatrix} B[0][0] \\ B[0][1] \\ B[0][2] \\ B[0][3] \end{pmatrix}rB0=B[0][0]B[0][1]B[0][2]B[0][3]

  • res0_prev: 这是上一步计算的结果(累加值)。

那么,v128.relaxed_madd<f32>(sA0, rB0, res0_prev) 这一步到底计算了什么?我们套用上面的公式:

res0⃗=(sA0⃗⊙rB0⃗)+res0prev⃗\vec{res0} = (\vec{sA0} \odot \vec{rB0}) + \vec{res0_{prev}}res0=(sA0rB0)+res0prev

展开来看就是:

res0⃗=(A[0][0]⋅B[0][0]+res0prev,0A[0][0]⋅B[0][1]+res0prev,1A[0][0]⋅B[0][2]+res0prev,2A[0][0]⋅B[0][3]+res0prev,3)\vec{res0} = \begin{pmatrix} A[0][0] \cdot B[0][0] + res0_{prev,0} \\ A[0][0] \cdot B[0][1] + res0_{prev,1} \\ A[0][0] \cdot B[0][2] + res0_{prev,2} \\ A[0][0] \cdot B[0][3] + res0_{prev,3} \end{pmatrix}res0=A[0][0]B[0][0]+res0prev,0A[0][0]B[0][1]+res0prev,1A[0][0]B[0][2]+res0prev,2A[0][0]B[0][3]+res0prev,3

这完美地对应了我们矩阵乘法思想:用 A 的一个标量元素,去数乘 B 的一整个行向量,然后加到累加器上。

当你把四次 relaxed_madd 调用链接起来后,最终的结果 res0 的第一个分量就是:
r0=(A[0][0]⋅B[0][0])+(A[0][1]⋅B[1][0])+(A[0][2]⋅B[2][0])+(A[0][3]⋅B[3][0])r_0 = (A[0][0] \cdot B[0][0]) + (A[0][1] \cdot B[1][0]) + (A[0][2] \cdot B[2][0]) + (A[0][3] \cdot B[3][0])r0=(A[0][0]B[0][0])+(A[0][1]B[1][0])+(A[0][2]B[2][0])+(A[0][3]B[3][0])
这正好是结果矩阵 C 的 C[0][0] 元素!其他分量同理。

FMA 的“融合”体现在哪里?

“融合” (Fused) 的意思是,在计算 ai⋅bi+cia_i \cdot b_i + c_iaibi+ci 时:

  1. 计算 ai⋅bia_i \cdot b_iaibi 的乘积,得到一个内部的、高精度的中间结果(比如 80 位浮点数)。
  2. 不进行舍入,直接用这个高精度结果与 cic_ici 相加。
  3. 对最终的和只进行一次舍入,得到 32 位的浮点数结果。

相比之下,非融合的 mul + add 会进行两次舍入,可能会损失精度。更重要的是,FMA 是一条硬件指令,吞吐量更高。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值