einsum用于矩阵乘法
直接上例子吧
比如
'bhqd, bhkd -> bhqk'
虽然是4维,但是前两维是不变的,先不看,只看后2维,qd, kd -> qk
这是两个矩阵相乘,两个矩阵的shape分别为A=qxd, B=kxd, 得到的结果形状是C =qxk
根据矩阵乘法,我们知道(qxd) x (dxk)结果的形状为qxk,
也就是说上面相当于是AxBT=C
验证一下
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print('energy.shape',energy.shape)

该博客介绍了如何利用PyTorch中的einsum函数进行矩阵乘法操作。通过实例解析了'einsum'的字符串表示法,展示了它如何等价于传统的矩阵乘法和张量相乘。einsum简化了多维数组间的复杂运算,适用于(197x96)x(197x96)和(197x197)x(197x96)的矩阵计算,验证了运算结果的正确性。
最低0.47元/天 解锁文章
1920

被折叠的 条评论
为什么被折叠?



